# app.py
"""FastAPI service exposing an image-text similarity model over HTTP."""
import logging
import os
import tempfile

from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Global predictor variable; stays None until startup_event() loads the model.
predictor = None


@app.on_event("startup")
async def startup_event():
    """Initialize the global predictor on application startup.

    Reads ``MODEL_PATH`` (default ``"best_model.pth"``) and ``THRESHOLD``
    (default ``"0.5"``, parsed as float) from the environment and builds a
    ``SimilarityPredictor``. Any failure is logged and swallowed so the app
    still starts; ``predictor`` then remains ``None`` and the endpoints
    report the model as unavailable.
    """
    global predictor
    try:
        # Imported lazily so a broken/missing model package does not prevent
        # the app module itself from being imported.
        from predictor import SimilarityPredictor

        MODEL_PATH = os.getenv("MODEL_PATH", "best_model.pth")
        THRESHOLD = float(os.getenv("THRESHOLD", "0.5"))

        logger.info(f"Loading model from: {MODEL_PATH}")
        logger.info(f"Using threshold: {THRESHOLD}")

        predictor = SimilarityPredictor(MODEL_PATH, threshold=THRESHOLD)
        logger.info("✅ Model loaded successfully!")
    except Exception as e:
        # Don't raise here - let the app start but handle in endpoints.
        # logger.exception keeps the traceback, which the 503 response from
        # /predict tells operators to look for.
        logger.exception(f"❌ Failed to load model: {e}")


@app.get("/")
async def root():
    """Root endpoint: return a short description of the available routes."""
    return {
        "message": "Image-Text Similarity API is running",
        "endpoints": {
            "GET /": "This endpoint",
            "GET /health": "Health check",
            "POST /predict": "Predict similarity between image and text",
        },
    }


@app.get("/health")
async def health_check():
    """Health check endpoint: report whether the model has been loaded."""
    model_loaded = predictor is not None
    return {
        "status": "healthy" if model_loaded else "model_not_loaded",
        "model_loaded": model_loaded,
    }


@app.post("/predict")
async def predict_similarity(
    file: UploadFile = File(...),
    text: str = Form(...),
):
    """Predict similarity between an uploaded image and a text.

    The upload is written to a temporary file whose path is handed to
    ``predictor.predict_similarity``; the temporary file is removed afterwards
    whether or not the prediction succeeded.

    Args:
        file: Uploaded file; its content type must start with ``image/``.
        text: Text to compare against the image.

    Returns:
        Whatever ``predictor.predict_similarity`` returns (it is indexed with
        ``"prediction"`` for logging, so presumably a dict - confirm against
        the predictor module).

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the upload is
            not declared as an image or is empty; 500 if the prediction
            returns ``None`` or raises.
    """
    if predictor is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Check logs for initialization errors.",
        )

    tmp_path = None
    try:
        # Validate file type. content_type is None when the multipart part
        # carries no Content-Type header; treat that as "not an image" (400)
        # rather than letting an AttributeError turn into a 500.
        if not (file.content_type or "").startswith("image/"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid file type: {file.content_type}. Please upload an image.",
            )

        content = await file.read()
        if not content:
            # An empty body can never be decoded as an image; this is a
            # client error, not a server failure.
            raise HTTPException(
                status_code=400,
                detail="Uploaded file is empty. Please upload an image.",
            )

        # Create temporary file. delete=False because the predictor reopens
        # it by path; the `with` block closes (and flushes) it before use.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
            tmp_path = tmp.name
            tmp.write(content)

        logger.info(f"Processing image: {file.filename}, text: {text[:50]}...")

        # Run prediction.
        # NOTE(review): this is a synchronous call inside an async endpoint,
        # so it blocks the event loop for its duration. Offloading it to a
        # worker thread would fix that but requires the predictor to be
        # thread-safe - confirm before changing.
        result = predictor.predict_similarity(tmp_path, text, verbose=False)

        if result is None:
            raise HTTPException(status_code=500, detail="Prediction failed")

        logger.info(f"Prediction completed: {result['prediction']}")
        return result

    except HTTPException:
        # Already carries the intended status code; pass through untouched.
        raise
    except Exception as e:
        logger.exception(f"Prediction error: {e}")
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e
    finally:
        # Cleanup temp file (best effort - never mask the real outcome).
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.warning(f"Failed to cleanup temp file: {e}")


# Middleware that logs every request line and the resulting response status.
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log the method and URL of each request and the response status code."""
    logger.info(f"{request.method} {request.url}")
    response = await call_next(request)
    logger.info(f"Response status: {response.status_code}")
    return response