import joblib
import uvicorn
import xgboost as xgb
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
import asyncio
import json
import pickle
import warnings
import os
import io
import timeit
from PIL import Image
import numpy as np
import cv2
# OpenAPI/Swagger UI helpers for the custom documentation endpoints
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from models.detr_model import DETR
from models.glpn_model import GLPDepth
from models.lstm_model import LSTM_Model
from models.predict_z_location_single_row_lstm import predict_z_location_single_row_lstm
from utils.processing import PROCESSING
from config import CONFIG

warnings.filterwarnings("ignore")
# Initialize the FastAPI app
app = FastAPI(
    title="Real-Time WebSocket Image Processing API",
    description="API for object detection and depth estimation using WebSocket for real-time image processing.",
)
try:
    # Load models and utilities once at startup
    device = CONFIG['device']
    print("Loading models...")
    detr = DETR()  # Object detection model (DETR)
    print("DETR model loaded.")
    glpn = GLPDepth()  # Depth estimation model (GLPN)
    print("GLPDepth model loaded.")
    zlocE_LSTM = LSTM_Model()  # LSTM model for Z-location (distance) prediction
    print("LSTM model loaded.")
    lstm_scaler = pickle.load(open(CONFIG['lstm_scaler_path'], 'rb'))  # Pre-fitted scaler for the LSTM inputs
    print("LSTM scaler loaded.")
    processing = PROCESSING()  # Utility class for post-processing detections
    print("Processing utilities loaded.")
except Exception as e:
    print(f"Failed to load models: {e}")
# Serve HTML documentation for the API
# NOTE: the original route decorators were not preserved; the paths used below are assumptions.
@app.get("/", response_class=HTMLResponse)
async def get_docs():
    """
    Serve HTML documentation for the WebSocket-based image processing API.
    The HTML file must be available in the same directory.
    Returns a 404 error if the documentation file is not found.
    """
    html_path = os.path.join(os.path.dirname(__file__), "api_documentation.html")
    if not os.path.exists(html_path):
        return HTMLResponse(content="api_documentation.html file not found", status_code=404)
    with open(html_path, "r") as f:
        return HTMLResponse(f.read())
# Serve an interactive try-out page for the API
# (renamed from a duplicate get_docs definition, which would have shadowed the one above)
@app.get("/try", response_class=HTMLResponse)  # path assumed
async def get_try_page():
    """
    Serve an interactive try-out page for the WebSocket-based image processing API.
    The HTML file must be available in the same directory.
    Returns a 404 error if the page file is not found.
    """
    html_path = os.path.join(os.path.dirname(__file__), "try_page.html")
    if not os.path.exists(html_path):
        return HTMLResponse(content="try_page.html file not found", status_code=404)
    with open(html_path, "r") as f:
        return HTMLResponse(f.read())
# Decode an image received via WebSocket
async def decode_image(image_bytes):
    """
    Decode image bytes into a PIL Image and return the image along with its shape.

    Args:
        image_bytes (bytes): The image data received from the client.

    Returns:
        tuple: A tuple containing:
            - pil_image (PIL.Image): The decoded image.
            - img_shape (tuple): Shape of the image as (height, width).
            - decode_time (float): Time taken to decode the image in seconds.

    Raises:
        ValueError: If image decoding fails.
    """
    start = timeit.default_timer()
    nparr = np.frombuffer(image_bytes, np.uint8)
    frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if frame is None:
        raise ValueError("Failed to decode image")
    color_converted = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(color_converted)
    img_shape = color_converted.shape[0:2]  # (height, width)
    end = timeit.default_timer()
    return pil_image, img_shape, end - start
# Run the DETR model for object detection
async def run_detr_model(pil_image):
    """
    Run the DETR (DEtection TRansformer) model to perform object detection on the input image.

    Args:
        pil_image (PIL.Image): The image to be processed by the DETR model.

    Returns:
        tuple: A tuple containing:
            - detr_result (tuple): The DETR model output: detection scores and bounding boxes.
            - detr_time (float): Time taken to run the DETR model in seconds.
    """
    start = timeit.default_timer()
    # Run the blocking model call in a worker thread so the event loop stays responsive
    detr_result = await asyncio.to_thread(detr.detect, pil_image)
    end = timeit.default_timer()
    return detr_result, end - start
# Run the GLPN model for depth estimation
async def run_glpn_model(pil_image, img_shape):
    """
    Run the GLPN (Global-Local Path Networks) model to estimate the depth of the objects in the image.

    Args:
        pil_image (PIL.Image): The image to be processed by the GLPN model.
        img_shape (tuple): The shape of the image as (height, width).

    Returns:
        tuple: A tuple containing:
            - depth_map (numpy.ndarray): The depth map for the input image.
            - glpn_time (float): Time taken to run the GLPN model in seconds.
    """
    start = timeit.default_timer()
    depth_map = await asyncio.to_thread(glpn.predict, pil_image, img_shape)
    end = timeit.default_timer()
    return depth_map, end - start
# Combine the detections with the depth map
async def process_detections(scores, boxes, depth_map):
    """
    Process the DETR detections and integrate depth information from the GLPN model.

    Args:
        scores (numpy.ndarray): The detection scores for the detected objects.
        boxes (numpy.ndarray): The bounding boxes for the detected objects.
        depth_map (numpy.ndarray): The depth map generated by the GLPN model.

    Returns:
        tuple: A tuple containing:
            - pdata (pandas.DataFrame): Processed detection data, one row per object,
              with bounding box and depth statistics.
            - process_time (float): Time taken for processing detections in seconds.
    """
    start = timeit.default_timer()
    pdata = processing.process_detections(scores, boxes, depth_map, detr)
    end = timeit.default_timer()
    return pdata, end - start
# Generate JSON output with LSTM distance predictions
async def generate_json_output(data):
    """
    Predict the Z-location for each detected object and build the JSON output.

    Uses the module-level zlocE_LSTM model and lstm_scaler for prediction.

    Args:
        data (pandas.DataFrame): Bounding box coordinates, depth statistics, and class type per object.

    Returns:
        tuple: A tuple containing:
            - dict: JSON structure with object class, estimated distance, and relevant features.
            - float: Time taken to generate the output in seconds.
    """
    output_json = []
    start = timeit.default_timer()
    # Iterate over each detected object (one row per object)
    for i, row in data.iterrows():
        # Predict the distance for this object using the single-row prediction function
        distance = predict_z_location_single_row_lstm(row, zlocE_LSTM, lstm_scaler)
        object_info = {
            "class": row["class"],                  # Object class (e.g., 'car', 'truck')
            "distance_estimated": float(distance),  # Predicted Z-location (distance)
            "features": {
                "xmin": float(row["xmin"]),         # Bounding box x-min
                "ymin": float(row["ymin"]),         # Bounding box y-min
                "xmax": float(row["xmax"]),         # Bounding box x-max
                "ymax": float(row["ymax"]),         # Bounding box y-max
                "mean_depth": float(row["depth_mean"]),            # Mean depth within the box
                "depth_mean_trim": float(row["depth_mean_trim"]),  # Trimmed mean depth
                "depth_median": float(row["depth_median"]),        # Median depth
                "width": float(row["width"]),       # Object width in pixels
                "height": float(row["height"])      # Object height in pixels
            }
        }
        output_json.append(object_info)
    end = timeit.default_timer()
    return {"objects": output_json}, end - start
# Process a single frame (image) in the WebSocket stream
async def process_frame(frame_id, image_bytes):
    """
    Process a single frame (image) from the WebSocket stream. The pipeline:
    - Decode the image.
    - Run the DETR and GLPN models concurrently.
    - Process detections and generate the final output JSON.

    Args:
        frame_id (int): The identifier for the frame being processed.
        image_bytes (bytes): The image data received from the WebSocket.

    Returns:
        dict: A dictionary containing the output JSON and timing information for each processing step.
    """
    timings = {}
    try:
        # Step 1: Decode the image
        pil_image, img_shape, decode_time = await decode_image(image_bytes)
        timings["decode_time"] = decode_time

        # Step 2: Run the DETR and GLPN models concurrently
        (detr_result, detr_time), (depth_map, glpn_time) = await asyncio.gather(
            run_detr_model(pil_image),
            run_glpn_model(pil_image, img_shape)
        )
        models_time = max(detr_time, glpn_time)  # The models run concurrently, so the slower one dominates
        timings["models_time"] = models_time

        # Step 3: Combine detections with the depth map
        scores, boxes = detr_result
        pdata, process_time = await process_detections(scores, boxes, depth_map)
        timings["process_time"] = process_time

        # Step 4: Generate the output JSON
        output_json, json_time = await generate_json_output(pdata)
        timings["json_time"] = json_time
        timings["total_time"] = decode_time + models_time + process_time + json_time

        # Attach frame_id and timings to the JSON output
        output_json["frame_id"] = frame_id
        output_json["timings"] = timings
        return output_json
    except Exception as e:
        return {
            "error": str(e),
            "frame_id": frame_id,
            "timings": timings
        }
# REST endpoint for single-image processing
@app.post("/process-image")  # path assumed; the original route decorator was not preserved
async def process_image(file: UploadFile = File(...)):
    """
    Process a single image for object detection and depth estimation.

    The endpoint performs:
    - Object detection using the DETR model
    - Depth estimation using the GLPN model
    - Z-location prediction using the LSTM model

    Parameters:
    - file: Image file to process (JPEG, PNG)

    Returns:
    - JSON response with detected objects, estimated distances, and timing information
    """
    try:
        # Read the uploaded image content
        image_bytes = await file.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty file")

        # Use the same processing pipeline as the WebSocket endpoint
        result = await process_frame(0, image_bytes)

        # Propagate pipeline errors as HTTP 500s
        if "error" in result:
            raise HTTPException(status_code=500, detail=result["error"])
        return JSONResponse(content=result)
    except HTTPException:
        # Re-raise HTTP errors as-is instead of re-wrapping them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
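
# Example request (a sketch; the /process-image path above is an assumption):
#
#   curl -X POST "http://localhost:8000/process-image" -F "file=@frame.jpg"
#
# The response mirrors the per-frame JSON shape shown earlier, with frame_id fixed at 0.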
# Custom OpenAPI documentation endpoints
# (paths assumed; the original route decorators were not preserved. The openapi_url
# below implies the schema is served at /api/openapi.json.)
@app.get("/api/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return get_swagger_ui_html(
        openapi_url="/api/openapi.json",
        title="Real-Time Image Processing API Documentation",
        swagger_js_url="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui-bundle.js",
        swagger_css_url="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui.css",
    )

@app.get("/api/openapi.json", include_in_schema=False)
async def get_open_api_endpoint():
    return get_openapi(
        title="Real-Time Image Processing API",
        version="1.0.0",
        description="API for object detection, depth estimation, and z-location prediction using computer vision models",
        routes=app.routes,
    )
# WebSocket endpoint for the real-time stream
@app.websocket("/ws")  # path assumed; the original route decorator was not preserved
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for real-time image processing. Clients send image frames and receive
    JSON output containing object detection, depth estimation, and distance predictions.

    - Receives image bytes over the WebSocket.
    - Runs the image processing pipeline and returns the result.

    Args:
        websocket (WebSocket): The WebSocket connection to the client.
    """
    await websocket.accept()
    frame_id = 0
    try:
        while True:
            # Receive image bytes from the client
            image_bytes = await websocket.receive_bytes()
            frame_id += 1
            # Process the frame and send the result back
            result = await process_frame(frame_id, image_bytes)
            await websocket.send_text(json.dumps(result))
    except WebSocketDisconnect:
        # The socket is already closed here, so do not call close() again
        print(f"Client disconnected after processing {frame_id} frames.")
    except Exception as e:
        print(f"Unexpected error: {e}")
        await websocket.close()
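
# Minimal sketch of an entry point, under the assumption that the service listens
# on 0.0.0.0:8000 (the original deployment configuration is not shown; uvicorn is
# imported above but was otherwise unused).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

# A hypothetical client loop using the third-party `websockets` package might look like
# this, assuming the /ws route above:
#
#   import asyncio, json, websockets
#
#   async def stream_frames(paths):
#       async with websockets.connect("ws://localhost:8000/ws") as ws:
#           for path in paths:
#               with open(path, "rb") as f:
#                   await ws.send(f.read())           # send raw image bytes
#               result = json.loads(await ws.recv())  # per-frame JSON result
#               print(result["frame_id"], result.get("objects"))
#
#   asyncio.run(stream_frames(["frame1.jpg", "frame2.jpg"]))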