# Hugging Face Space page header (scrape residue, not Python code — commented out
# so the file parses): "ML / app.py", uploaded by Yassine854, commit 97a3613 ("update 1.0")
# app.py
import os
import tempfile
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
# NOTE(review): JSONResponse is imported but never used in this file — confirm
# before removing (another chunk of the project may rely on re-exporting it).
from fastapi.responses import JSONResponse
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Global predictor variable
# Holds the SimilarityPredictor instance once the startup hook loads it.
# Stays None if loading fails, so endpoints can answer 503 instead of crashing.
predictor = None
@app.on_event("startup")
async def startup_event():
    """Load the similarity model once when the server starts.

    Reads MODEL_PATH (default "best_model.pth") and THRESHOLD (default "0.5")
    from the environment. On failure the error is logged and ``predictor``
    stays ``None``; the app still starts so /health can report the state and
    /predict can answer 503 with a useful message.
    """
    global predictor
    try:
        from predictor import SimilarityPredictor

        MODEL_PATH = os.getenv("MODEL_PATH", "best_model.pth")
        THRESHOLD = float(os.getenv("THRESHOLD", "0.5"))

        logger.info("Loading model from: %s", MODEL_PATH)
        logger.info("Using threshold: %s", THRESHOLD)

        predictor = SimilarityPredictor(MODEL_PATH, threshold=THRESHOLD)
        logger.info("βœ… Model loaded successfully!")
    except Exception:
        # Fix: logger.error(f"...{e}") dropped the traceback; logger.exception
        # logs the same message plus the full stack, which is what you need to
        # debug a failed model load from container logs.
        logger.exception("❌ Failed to load model")
        # Don't raise here - let the app start but handle in endpoints
@app.get("/")
async def root():
    """Describe the service and list the routes it exposes."""
    endpoints = {
        "GET /": "This endpoint",
        "GET /health": "Health check",
        "POST /predict": "Predict similarity between image and text",
    }
    return {
        "message": "Image-Text Similarity API is running",
        "endpoints": endpoints,
    }
@app.get("/health")
async def health_check():
    """Report whether the model finished loading (read-only; no `global` needed)."""
    loaded = predictor is not None
    return {
        "status": "healthy" if loaded else "model_not_loaded",
        "model_loaded": loaded,
    }
@app.post("/predict")
async def predict_similarity(
    file: UploadFile = File(...),
    text: str = Form(...)
):
    """Predict similarity between uploaded image and text.

    Args:
        file: Multipart image upload; must carry an ``image/*`` content type.
        text: Caption/description to compare against the image.

    Returns:
        The predictor's result dict (contains at least a 'prediction' key).

    Raises:
        HTTPException: 503 if the model never loaded, 400 for non-image
            uploads, 500 if prediction fails for any other reason.
    """
    # Read-only access; `global predictor` was unnecessary here.
    if predictor is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Check logs for initialization errors."
        )

    tmp_path = None
    try:
        # Bug fix: content_type can be None when the client omits the header,
        # which made .startswith raise AttributeError and surface as a 500
        # instead of the intended 400 validation error.
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid file type: {file.content_type}. Please upload an image."
            )

        # The predictor takes a filesystem path, so spill the upload to a
        # temp file (delete=False: the path must outlive the `with` block).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
            tmp_path = tmp.name
            content = await file.read()
            tmp.write(content)

        logger.info(f"Processing image: {file.filename}, text: {text[:50]}...")

        # Run prediction
        result = predictor.predict_similarity(tmp_path, text, verbose=False)
        if result is None:
            raise HTTPException(status_code=500, detail="Prediction failed")

        logger.info(f"Prediction completed: {result['prediction']}")
        return result

    except HTTPException:
        # Re-raise deliberate HTTP errors (400/500/503) untouched.
        raise
    except Exception as e:
        # Fix: logger.exception preserves the traceback that logger.error lost.
        logger.exception("Prediction error")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
    finally:
        # Cleanup runs on every path — success, HTTP error, or crash.
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except Exception as e:
                logger.warning(f"Failed to cleanup temp file: {e}")
# Add middleware for better error handling
@app.middleware("http")
async def log_requests(request, call_next):
    """Log each request's method/URL and the response status code.

    Runs around every route; errors raised downstream still propagate.
    """
    # Idiom fix: lazy %-args instead of f-strings so formatting is skipped
    # when the logger is disabled; the emitted text is identical.
    logger.info("%s %s", request.method, request.url)
    response = await call_next(request)
    logger.info("Response status: %s", response.status_code)
    return response