Spaces:
Running
Running
| # app.py | |
import logging
import os
import tempfile

from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
# Send INFO-and-above records from every logger to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Filled in by the startup hook once the model loads successfully; while it
# is still None the endpoints report the model as not loaded.
predictor: "SimilarityPredictor | None" = None
@app.on_event("startup")
async def startup_event():
    """Load the similarity model once, when the application starts.

    Reads ``MODEL_PATH`` (default ``"best_model.pth"``) and ``THRESHOLD``
    (default ``"0.5"``, parsed with ``float``) from the environment, and
    stores a ``SimilarityPredictor`` in the module-level ``predictor`` global.

    Any failure is logged with its traceback and deliberately swallowed. This
    covers an import error, a non-numeric THRESHOLD, or a model load error.
    The app still starts; ``/health`` then reports ``model_not_loaded`` and
    ``/predict`` answers 503.
    """
    global predictor
    try:
        # Imported lazily so a broken/missing predictor module is reported
        # here instead of preventing the whole app from importing.
        from predictor import SimilarityPredictor

        model_path = os.getenv("MODEL_PATH", "best_model.pth")
        threshold = float(os.getenv("THRESHOLD", "0.5"))
        logger.info("Loading model from: %s", model_path)
        logger.info("Using threshold: %s", threshold)
        predictor = SimilarityPredictor(model_path, threshold=threshold)
        logger.info("β Model loaded successfully!")
    except Exception:
        # logger.exception keeps the full traceback, which is what you need
        # when diagnosing why the model failed to load.
        logger.exception("β Failed to load model")
        # Don't raise here - let the app start but handle in endpoints
@app.get("/")
async def root():
    """Return a short service banner and the list of available endpoints."""
    return {
        "message": "Image-Text Similarity API is running",
        "endpoints": {
            "GET /": "This endpoint",
            "GET /health": "Health check",
            "POST /predict": "Predict similarity between image and text"
        }
    }
@app.get("/health")
async def health_check():
    """Report whether the model has been loaded.

    Returns:
        ``{"status": "healthy", "model_loaded": True}`` when the global
        ``predictor`` is set, otherwise
        ``{"status": "model_not_loaded", "model_loaded": False}``.
    """
    # Read-only access to the module global: no `global` statement needed.
    model_loaded = predictor is not None
    return {
        "status": "healthy" if model_loaded else "model_not_loaded",
        "model_loaded": model_loaded
    }
@app.post("/predict")
async def predict_similarity(
    file: UploadFile = File(...),
    text: str = Form(...)
):
    """Predict similarity between an uploaded image and a text string.

    The upload is written to a temporary ``.jpg`` file. The file path is
    passed to ``predictor.predict_similarity``, and the file is always
    removed afterwards.

    Args:
        file: Multipart upload; its declared content type must start with
            ``image/``.
        text: Text to score against the image.

    Returns:
        The non-None value returned by ``predictor.predict_similarity``. It is
        indexed with ``"prediction"`` for logging, so presumably a dict —
        confirm against ``SimilarityPredictor``.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for a missing or
            non-image content type or an empty upload; 500 if prediction
            raises or returns ``None``.
    """
    if predictor is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Check logs for initialization errors."
        )
    tmp_path = None
    try:
        # content_type is None when the client omits the header; treat that as
        # an invalid type (400) rather than crashing with AttributeError (500).
        if not (file.content_type or "").startswith("image/"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid file type: {file.content_type}. Please upload an image."
            )
        content = await file.read()
        if not content:
            raise HTTPException(status_code=400, detail="Uploaded file is empty.")
        # delete=False: the file must outlive this `with` so the predictor can
        # reopen it by path; it is removed in the `finally` clause below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
            tmp_path = tmp.name
            tmp.write(content)
        logger.info("Processing image: %s, text: %s...", file.filename, text[:50])
        # The predictor call is synchronous; run it in a worker thread so it
        # does not block the event loop for other requests while it runs.
        # NOTE(review): concurrent requests may now call the predictor from
        # several threads at once — confirm SimilarityPredictor tolerates that.
        result = await run_in_threadpool(
            predictor.predict_similarity, tmp_path, text, verbose=False
        )
        if result is None:
            raise HTTPException(status_code=500, detail="Prediction failed")
        logger.info("Prediction completed: %s", result["prediction"])
        return result
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
    finally:
        # Best-effort cleanup of the temp file; never mask the real outcome.
        if tmp_path:
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.warning(f"Failed to cleanup temp file: {e}")
# HTTP middleware that logs every request and its response status.
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log the method and URL of each request, then the response status.

    Exceptions raised further down the stack are logged with a traceback and
    re-raised unchanged, so the framework's own error handling still applies.
    """
    logger.info("%s %s", request.method, request.url)
    try:
        response = await call_next(request)
    except Exception:
        logger.exception("Unhandled error for %s %s", request.method, request.url)
        raise
    logger.info("Response status: %s", response.status_code)
    return response