File size: 3,627 Bytes
97a3613
551788c
 
97a3613
551788c
97a3613
551788c
97a3613
 
 
551788c
 
 
97a3613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551788c
97a3613
551788c
 
 
97a3613
 
 
 
 
 
 
 
 
 
551788c
97a3613
 
 
 
 
 
 
 
551788c
 
 
 
97a3613
 
 
551788c
 
97a3613
551788c
97a3613
 
 
 
551788c
97a3613
 
551788c
97a3613
 
 
 
 
 
 
 
 
 
551788c
97a3613
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# app.py
import os
import tempfile
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import logging

# Configure root logging at INFO once, at import time, and obtain this
# module's named logger for use by the handlers below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

app = FastAPI()

# Set by the startup hook once the model loads; remains None if loading
# fails, and the endpoints check for that before using it.
predictor = None

@app.on_event("startup")
async def startup_event():
    """Build the module-level ``predictor`` when the app starts.

    Reads ``MODEL_PATH`` (default ``"best_model.pth"``) and ``THRESHOLD``
    (default ``"0.5"``, parsed with ``float``) from the environment and
    constructs a ``SimilarityPredictor``.

    Any failure is logged with its full traceback and then swallowed. This
    includes an import error, a non-numeric THRESHOLD, or an error while
    loading the model. The app therefore still starts with ``predictor`` left
    as ``None``. In that state the endpoints report the model as not loaded.
    """
    global predictor
    try:
        # Imported here, inside the try, so a missing or broken predictor
        # module is reported as a startup failure instead of crashing import.
        from predictor import SimilarityPredictor
        model_path = os.getenv("MODEL_PATH", "best_model.pth")
        threshold = float(os.getenv("THRESHOLD", "0.5"))

        logger.info(f"Loading model from: {model_path}")
        logger.info(f"Using threshold: {threshold}")

        predictor = SimilarityPredictor(model_path, threshold=threshold)
        logger.info("✅ Model loaded successfully!")

    except Exception as e:
        # logger.exception (unlike logger.error) records the traceback,
        # which is essential for diagnosing why the model failed to load.
        logger.exception(f"❌ Failed to load model: {e}")
        # Don't raise here - let the app start but handle in endpoints

@app.get("/")
async def root():
    """Return a short banner describing the service and its routes."""
    routes = {
        "GET /": "This endpoint",
        "GET /health": "Health check",
        "POST /predict": "Predict similarity between image and text",
    }
    banner = "Image-Text Similarity API is running"
    return {"message": banner, "endpoints": routes}

@app.get("/health")
async def health_check():
    """Report whether the predictor has been loaded.

    Returns ``{"status": "healthy", "model_loaded": True}`` when the startup
    hook produced a predictor, otherwise
    ``{"status": "model_not_loaded", "model_loaded": False}``.
    """
    # Only reading the module global, so no `global` statement is required.
    loaded = predictor is not None
    if loaded:
        status = "healthy"
    else:
        status = "model_not_loaded"
    return {"status": status, "model_loaded": loaded}

@app.post("/predict")
async def predict_similarity(
    file: UploadFile = File(...),
    text: str = Form(...)
):
    """Predict how well an uploaded image matches ``text``.

    The upload is written to a temporary file, because the predictor takes a
    path. That path and ``text`` are passed to
    ``predictor.predict_similarity``, and the predictor's result is returned
    as the response body. The temporary file is always removed afterwards.

    Raises:
        HTTPException: 503 if the model failed to load at startup.
        HTTPException: 400 if the upload's content type is missing or not
            ``image/*``, or if the upload is empty.
        HTTPException: 500 if the predictor returns ``None`` or raises.
    """
    if predictor is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Check logs for initialization errors."
        )

    # content_type is None when the client omits the part's Content-Type
    # header. Treat that as "not an image" (400) rather than letting
    # None.startswith() raise AttributeError and surface as a 500.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type: {file.content_type}. Please upload an image."
        )

    tmp_path = None
    try:
        content = await file.read()
        if not content:
            # Nothing for the model to score; this is a client error, not a 500.
            raise HTTPException(status_code=400, detail="Uploaded file is empty.")

        # delete=False because the file must outlive this `with` block so the
        # predictor can open it by path; it is removed in `finally` below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
            tmp_path = tmp.name
            tmp.write(content)

        logger.info(f"Processing image: {file.filename}, text: {text[:50]}...")

        # NOTE(review): this is a synchronous call made directly on the event
        # loop, so other requests (including /health) wait while it runs.
        # Consider fastapi.concurrency.run_in_threadpool once the predictor is
        # confirmed to be safe for concurrent use.
        result = predictor.predict_similarity(tmp_path, text, verbose=False)

        if result is None:
            raise HTTPException(status_code=500, detail="Prediction failed")

        logger.info(f"Prediction completed: {result['prediction']}")
        return result

    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        # logger.exception keeps the traceback for debugging; `from e`
        # preserves the original cause on the HTTPException.
        logger.exception(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e

    finally:
        # Best-effort cleanup of the temporary file; a failure here is only
        # logged so it cannot mask the response or the original error.
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except Exception as e:
                logger.warning(f"Failed to cleanup temp file: {e}")

# Middleware: log each request's method and URL, then the response status code
@app.middleware("http")
async def log_requests(request, call_next):
    """Log the method and URL of each request, then its response status.

    There is no exception handling here: if a downstream handler raises, the
    exception propagates and only the request line will have been logged.
    """
    logger.info("%s %s", request.method, request.url)
    outgoing = await call_next(request)
    logger.info("Response status: %s", outgoing.status_code)
    return outgoing