# app.py
import os
import tempfile
import logging

from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Global predictor variable
predictor = None

@app.on_event("startup")
async def startup_event():
    """Initialize predictor on startup"""
    global predictor
    try:
        from predictor import SimilarityPredictor
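        # Assumed interface (predictor.py is not shown in this file):
        # SimilarityPredictor(model_path, threshold=...) exposes
        # predict_similarity(image_path, text, verbose=...) and returns a
        # dict with at least a "prediction" key, or None on failure.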
        MODEL_PATH = os.getenv("MODEL_PATH", "best_model.pth")
        THRESHOLD = float(os.getenv("THRESHOLD", "0.5"))
        logger.info(f"Loading model from: {MODEL_PATH}")
        logger.info(f"Using threshold: {THRESHOLD}")
        predictor = SimilarityPredictor(MODEL_PATH, threshold=THRESHOLD)
        logger.info("✅ Model loaded successfully!")
    except Exception as e:
        logger.error(f"❌ Failed to load model: {e}")
        # Don't raise here - let the app start but handle in endpoints
@app.get("/")
async def root():
"""Root endpoint"""
return {
"message": "Image-Text Similarity API is running",
"endpoints": {
"GET /": "This endpoint",
"GET /health": "Health check",
"POST /predict": "Predict similarity between image and text"
}
}
@app.get("/health")
async def health_check():
"""Health check endpoint"""
global predictor
return {
"status": "healthy" if predictor is not None else "model_not_loaded",
"model_loaded": predictor is not None
}
@app.post("/predict")
async def predict_similarity(
file: UploadFile = File(...),
text: str = Form(...)
):
"""Predict similarity between uploaded image and text"""
global predictor
if predictor is None:
raise HTTPException(
status_code=503,
detail="Model not loaded. Check logs for initialization errors."
)

    tmp_path = None
    try:
        # Validate file type (content_type may be None if the client omits it)
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid file type: {file.content_type}. Please upload an image.",
            )

        # Save the upload to a temporary file for the predictor
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
            tmp_path = tmp.name
            content = await file.read()
            tmp.write(content)
logger.info(f"Processing image: {file.filename}, text: {text[:50]}...")
# Run prediction
result = predictor.predict_similarity(tmp_path, text, verbose=False)
if result is None:
raise HTTPException(status_code=500, detail="Prediction failed")
logger.info(f"Prediction completed: {result['prediction']}")
return result
except HTTPException:
raise
except Exception as e:
logger.error(f"Prediction error: {e}")
raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
finally:
# Cleanup temp file
if tmp_path and os.path.exists(tmp_path):
try:
os.remove(tmp_path)
except Exception as e:
logger.warning(f"Failed to cleanup temp file: {e}")

# Middleware: log each request and its response status
@app.middleware("http")
async def log_requests(request, call_next):
    logger.info(f"{request.method} {request.url}")
    response = await call_next(request)
    logger.info(f"Response status: {response.status_code}")
    return response
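
# Local entry point (a convenience sketch; a Space or process manager may
# instead launch the server itself, e.g. `uvicorn app:app`):
if __name__ == "__main__":
    import uvicorn

    # Port 7860 follows the Hugging Face Spaces convention; adjust as needed.
    uvicorn.run(app, host="0.0.0.0", port=7860)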