| """ | |
| FastAPI server for OpenAI Realtime API integration with RAG system. | |
| Provides endpoints for session management and RAG tool calls. | |
| Directory structure: | |
| /data/ # Original PDFs, HTML | |
| /embeddings/ # FAISS, Chroma, DPR vector stores | |
| /graph/ # Graph database files | |
| /metadata/ # Image metadata (SQLite or MongoDB) | |
| """ | |
import json
import logging
import os
import time
from typing import Dict, Any, Optional

from fastapi import FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from pydantic import BaseModel
import uvicorn
from openai import OpenAI

# Import all query modules
from query_graph import query as graph_query
from query_vanilla import query as vanilla_query
from query_dpr import query as dpr_query
from query_bm25 import query as bm25_query
from query_context import query as context_query
from query_vision import query as vision_query

from config import (
    OPENAI_API_KEY,
    OPENAI_CHAT_MODEL,
    OPENAI_REALTIME_MODEL,
    REALTIME_VOICE,
    REALTIME_INSTRUCTIONS,
    DEFAULT_METHOD,
)
from analytics_db import log_query

logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(title="SIGHT Realtime API Server", version="1.0.0")

# CORS middleware for frontend integration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, restrict to your domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
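# A minimal sketch of a locked-down CORS setup for production; the origin below
# is a hypothetical placeholder, not a value from this project's config:
#
#   app.add_middleware(
#       CORSMiddleware,
#       allow_origins=["https://sight.example.com"],
#       allow_credentials=True,
#       allow_methods=["GET", "POST", "OPTIONS"],
#       allow_headers=["*"],
#   )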
# Request logging middleware, registered on the app so every request passes through it
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log all incoming requests for debugging."""
    logger.info(f"Incoming request: {request.method} {request.url}")
    try:
        response = await call_next(request)
        logger.info(f"Response status: {response.status_code}")
        return response
    except Exception as e:
        logger.error(f"Request processing error: {e}")
        return JSONResponse(
            content={"error": "Internal server error"},
            status_code=500
        )
# Exception handlers
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    logger.warning(f"Validation error for {request.url}: {exc}")
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={"error": "Invalid request format", "details": str(exc)}
    )


@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    logger.warning(f"HTTP error for {request.url}: {exc.status_code} - {exc.detail}")
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": exc.detail}
    )


@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    logger.error(f"Unhandled error for {request.url}: {exc}")
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "Internal server error"}
    )
# Initialize OpenAI client
client = OpenAI(api_key=OPENAI_API_KEY)

# Query method dispatch
QUERY_DISPATCH = {
    'graph': graph_query,
    'vanilla': vanilla_query,
    'dpr': dpr_query,
    'bm25': bm25_query,
    'context': context_query,
    'vision': vision_query
}
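# Each dispatched `query` function is called with the same keyword arguments
# (see rag_query below) and returns an (answer, citations) pair. A minimal
# sketch of the assumed interface -- not the actual implementation in the
# query_* modules:
#
#   def query(question: str, image_path: Optional[str] = None,
#             top_k: int = 5) -> tuple[str, list]:
#       """Return (answer, citations) for the given question."""
#       ...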
# Use configuration from config.py with environment variable overrides
REALTIME_MODEL = os.getenv("REALTIME_MODEL", OPENAI_REALTIME_MODEL)
VOICE = os.getenv("REALTIME_VOICE", REALTIME_VOICE)
INSTRUCTIONS = os.getenv("REALTIME_INSTRUCTIONS", REALTIME_INSTRUCTIONS)


# Pydantic models for request/response
class SessionRequest(BaseModel):
    """Request model for creating ephemeral sessions."""
    model: Optional[str] = "gpt-4o-realtime-preview"
    instructions: Optional[str] = None
    voice: Optional[str] = None


class RAGRequest(BaseModel):
    """Request model for RAG queries."""
    query: str
    method: str = "graph"
    top_k: int = 5
    image_path: Optional[str] = None


class RAGResponse(BaseModel):
    """Response model for RAG queries."""
    answer: str
    citations: list
    method: str
    citations_html: Optional[str] = None
@app.post("/session")
async def create_ephemeral_session(request: SessionRequest) -> JSONResponse:
    """
    Create an ephemeral session token for OpenAI Realtime API.
    This token will be used by the frontend WebRTC client.
    """
    try:
        logger.info(f"Creating ephemeral session with model: {request.model or REALTIME_MODEL}")
        # Create ephemeral token using a direct HTTP call to the OpenAI API,
        # since the Python SDK doesn't support realtime sessions yet
        import requests

        session_data = {
            "model": request.model or REALTIME_MODEL,
            "voice": request.voice or VOICE,
            "modalities": ["audio", "text"],
            "instructions": request.instructions or INSTRUCTIONS,
        }
        headers = {
            "Authorization": f"Bearer {OPENAI_API_KEY}",
            "Content-Type": "application/json"
        }
        # Make direct HTTP request to OpenAI's realtime sessions endpoint
        response = requests.post(
            "https://api.openai.com/v1/realtime/sessions",
            json=session_data,
            headers=headers,
            timeout=30
        )
        if response.status_code == 200:
            session_result = response.json()
            # client_secret may come back as an object ({"value": ...}) or a plain string
            client_secret = session_result.get("client_secret")
            if isinstance(client_secret, dict):
                client_secret = client_secret.get("value")
            response_data = {
                "client_secret": client_secret,
                "model": request.model or REALTIME_MODEL,
                "session_id": session_result.get("id")
            }
            logger.info("Ephemeral session created successfully")
            return JSONResponse(content=response_data, status_code=200)
        else:
            logger.error(f"OpenAI API error: {response.status_code} - {response.text}")
            return JSONResponse(
                content={"error": f"OpenAI API error: {response.status_code} - {response.text}"},
                status_code=response.status_code
            )
    except requests.exceptions.RequestException as e:
        logger.error(f"Network error creating ephemeral session: {e}")
        return JSONResponse(
            content={"error": f"Network error: {str(e)}"},
            status_code=500
        )
    except Exception as e:
        logger.error(f"Error creating ephemeral session: {e}")
        return JSONResponse(
            content={"error": f"Session creation failed: {str(e)}"},
            status_code=500
        )
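# Example of how a frontend might request an ephemeral token (sketch; assumes
# the server is running locally on the default port 5050):
#
#   curl -X POST http://localhost:5050/session \
#        -H "Content-Type: application/json" \
#        -d '{"model": "gpt-4o-realtime-preview"}'
#
# On success the JSON response carries "client_secret", "model" and
# "session_id", which the WebRTC client uses to connect to the Realtime API.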
@app.post("/rag")
async def rag_query(request: RAGRequest) -> RAGResponse:
    """
    Handle RAG queries from the realtime interface.
    This endpoint is called by the JavaScript frontend when the model
    requests the ask_rag function.
    """
    try:
        logger.info(f"RAG query: {request.query} using method: {request.method}")
        # Validate and default method if needed
        method = request.method
        if method not in QUERY_DISPATCH:
            logger.warning(f"Invalid method '{method}', using default '{DEFAULT_METHOD}'")
            method = DEFAULT_METHOD
        # Get the appropriate query function
        query_func = QUERY_DISPATCH[method]
        # Execute the query
        start_time = time.time()
        answer, citations = query_func(
            question=request.query,
            image_path=request.image_path,
            top_k=request.top_k
        )
        response_time = (time.time() - start_time) * 1000  # Convert to ms
        # Format citations for HTML display (optional)
        citations_html = format_citations_html(citations, method)
        # Log to analytics database (mark as voice interaction)
        try:
            # Generate unique session ID for each voice interaction
            import uuid
            voice_session_id = f"voice_{uuid.uuid4().hex[:8]}"
            log_query(
                user_query=request.query,
                method=method,
                answer=answer,
                citations=citations,
                response_time=response_time,
                image_path=request.image_path,
                top_k=request.top_k,
                session_id=voice_session_id,
                additional_settings={'voice_interaction': True, 'interaction_type': 'speech_to_speech'}
            )
            logger.info(f"Voice interaction logged: {request.query[:50]}...")
        except Exception as log_error:
            logger.error(f"Failed to log voice query: {log_error}")
        logger.info(f"RAG query completed: {len(answer)} chars, {len(citations)} citations")
        return RAGResponse(
            answer=answer,
            citations=citations,
            method=method,
            citations_html=citations_html
        )
    except Exception as e:
        logger.error(f"Error processing RAG query: {e}")
        raise HTTPException(status_code=500, detail=f"RAG query failed: {str(e)}")
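# Example tool call from the frontend (sketch; field names follow RAGRequest
# and RAGResponse above, port 5050 is the default from the CLI args below):
#
#   curl -X POST http://localhost:5050/rag \
#        -H "Content-Type: application/json" \
#        -d '{"query": "example question", "method": "graph", "top_k": 5}'
#
# The JSON response carries "answer", "citations", "method" and
# "citations_html" for display in the voice UI.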
def format_citations_html(citations: list, method: str) -> str:
    """Format citations as HTML for display."""
    if not citations:
        return "<p><em>No citations available</em></p>"
    html_parts = ["<div style='margin-top: 1em;'><strong>Sources:</strong><ul>"]
    for citation in citations:
        if isinstance(citation, dict) and 'source' in citation:
            source = citation['source']
            cite_type = citation.get('type', 'unknown')
            # Build citation text based on type
            if cite_type == 'pdf':
                cite_text = f"📄 {source} (PDF)"
            elif cite_type == 'html':
                url = citation.get('url', '')
                if url:
                    cite_text = f"🌐 <a href='{url}' target='_blank'>{source}</a> (Web)"
                else:
                    cite_text = f"🌐 {source} (Web)"
            elif cite_type == 'image':
                page = citation.get('page', 'N/A')
                cite_text = f"🖼️ {source} (Image, page {page})"
            else:
                cite_text = f"📄 {source}"
            # Add scores if available
            scores = []
            if 'relevance_score' in citation:
                scores.append(f"relevance: {citation['relevance_score']:.3f}")
            if 'score' in citation:
                scores.append(f"score: {citation['score']:.3f}")
            if scores:
                cite_text += f" <small>({', '.join(scores)})</small>"
            html_parts.append(f"<li>{cite_text}</li>")
        elif isinstance(citation, (list, tuple)) and len(citation) >= 4:
            # Handle legacy citation format (header, score, text, source)
            header, score, text, source = citation[:4]
            cite_text = f"📄 {source} <small>(score: {score:.3f})</small>"
            html_parts.append(f"<li>{cite_text}</li>")
    html_parts.append("</ul></div>")
    return "".join(html_parts)
@app.get("/")
async def root():
    """Root endpoint to prevent invalid HTTP request warnings."""
    return {
        "service": "SIGHT Realtime API Server",
        "version": "1.0.0",
        "status": "running",
        "endpoints": {
            "session": "POST /session - Create realtime session",
            "rag": "POST /rag - Query RAG system",
            "health": "GET /health - Health check",
            "methods": "GET /methods - List available RAG methods"
        }
    }


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "service": "SIGHT Realtime API Server"}


@app.get("/methods")
async def list_methods():
    """List available RAG methods."""
    return {
        "methods": list(QUERY_DISPATCH.keys()),
        "descriptions": {
            'graph': "Graph-based RAG using NetworkX with relationship-aware retrieval",
            'vanilla': "Standard vector search with FAISS and OpenAI embeddings",
            'dpr': "Dense Passage Retrieval with bi-encoder and cross-encoder re-ranking",
            'bm25': "BM25 keyword search with neural re-ranking for exact term matching",
            'context': "Context stuffing with full document loading and heuristic selection",
            'vision': "Vision-based search using GPT-5 Vision for image analysis"
        }
    }
# Catch-all OPTIONS route (path pattern assumed) so CORS preflight requests succeed
@app.options("/{path:path}")
async def options_handler(request: Request, response: Response):
    """Handle CORS preflight requests."""
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
    response.headers["Access-Control-Allow-Headers"] = "*"
    return response
if __name__ == "__main__":
    import argparse

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="SIGHT Realtime API Server")
    parser.add_argument("--https", action="store_true", help="Enable HTTPS with self-signed certificate")
    parser.add_argument("--port", type=int, default=5050, help="Port to run the server on")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind the server to")
    args = parser.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    # Suppress uvicorn access logs for cleaner output
    uvicorn_logger = logging.getLogger("uvicorn.access")
    uvicorn_logger.setLevel(logging.WARNING)

    # Prepare uvicorn configuration
    uvicorn_config = {
        "app": "realtime_server:app",
        "host": args.host,
        "port": args.port,
        "reload": True,
        "log_level": "warning",
        "access_log": False
    }

    # Add SSL configuration if HTTPS is requested
    if args.https:
        logger.info("Starting server with HTTPS (self-signed certificate)")
        logger.warning("⚠️ Self-signed certificate will show security warnings in the browser")
        logger.info("For production, use a proper SSL certificate from a CA")
        # Note: you need to generate the SSL certificates first.
        # For development, you can create self-signed certificates with:
        #   openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes
        uvicorn_config.update({
            "ssl_keyfile": "key.pem",
            "ssl_certfile": "cert.pem"
        })
        print(f"Starting HTTPS server on https://{args.host}:{args.port}")
        print("To generate self-signed certificates, run:")
        print("  openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes")
    else:
        print(f"Starting HTTP server on http://{args.host}:{args.port}")
        print("⚠️ Plain HTTP is only suitable for localhost; use --https for remote/production deployment.")

    # Run the server
    uvicorn.run(**uvicorn_config)
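# Typical invocations (sketch; flags come from the argparse definitions above):
#
#   python realtime_server.py                    # HTTP on 0.0.0.0:5050
#   python realtime_server.py --port 8080        # custom port
#   python realtime_server.py --https            # HTTPS using key.pem / cert.pem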