Yassine854 committed on
Commit
97a3613
·
1 Parent(s): 551788c

update 1.0

Browse files
Files changed (2) hide show
  1. Dockerfile +23 -9
  2. app.py +95 -20
Dockerfile CHANGED
@@ -1,6 +1,11 @@
1
  # Use official Python slim image
2
  FROM python:3.11-slim
3
 
 
 
 
 
 
4
  # Create a non-root user
5
  RUN useradd -m -u 1000 user
6
  USER user
@@ -9,16 +14,25 @@ ENV PATH="/home/user/.local/bin:$PATH"
9
  # Set working directory
10
  WORKDIR /app
11
 
12
- # Copy and install dependencies
13
- COPY --chown=user ./requirements.txt requirements.txt
14
- RUN pip install --no-cache-dir --upgrade pip \
15
- && pip install --no-cache-dir -r requirements.txt
 
 
16
 
17
- # Copy app files
18
- COPY --chown=user . /app
19
 
20
- # Expose default HF Spaces port
 
 
 
21
  EXPOSE 7860
22
 
23
- # Run FastAPI app
24
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
 
 
 
 
 
1
  # Use official Python slim image
2
  FROM python:3.11-slim
3
 
4
+ # Install system dependencies
5
+ RUN apt-get update && apt-get install -y \
6
+ curl \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
  # Create a non-root user
10
  RUN useradd -m -u 1000 user
11
  USER user
 
14
  # Set working directory
15
  WORKDIR /app
16
 
17
+ # Copy requirements first (for better caching)
18
+ COPY --chown=user requirements.txt .
19
+
20
+ # Install Python dependencies
21
+ RUN pip install --no-cache-dir --upgrade pip && \
22
+ pip install --no-cache-dir -r requirements.txt
23
 
24
+ # Copy application files
25
+ COPY --chown=user . .
26
 
27
+ # Make sure the model file exists (you'll need to add this)
28
+ # COPY --chown=user best_model.pth .
29
+
30
+ # Expose port (Hugging Face Spaces typically uses 7860)
31
  EXPOSE 7860
32
 
33
+ # Health check
34
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
35
+ CMD curl -f http://localhost:7860/health || exit 1
36
+
37
+ # Run the application
38
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1", "--log-level", "info"]
app.py CHANGED
@@ -1,42 +1,117 @@
 
1
  import os
2
  import tempfile
3
- from fastapi import FastAPI, File, Form, UploadFile
4
  from fastapi.responses import JSONResponse
5
- from predictor import SimilarityPredictor # <-- move your model code to predictor.py
6
 
7
- # Load model once at startup
8
- MODEL_PATH = os.getenv("MODEL_PATH", "best_model.pth")
9
- THRESHOLD = float(os.getenv("THRESHOLD", 0.5))
10
- predictor = SimilarityPredictor(MODEL_PATH, threshold=THRESHOLD)
11
 
12
  app = FastAPI()
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  @app.post("/predict")
15
- async def predict(
16
  file: UploadFile = File(...),
17
  text: str = Form(...)
18
  ):
 
 
 
 
 
 
 
 
 
 
19
  try:
20
- # Save uploaded image temporarily
 
 
 
 
 
 
 
21
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
22
  tmp_path = tmp.name
23
  content = await file.read()
24
  tmp.write(content)
25
-
 
 
26
  # Run prediction
27
  result = predictor.predict_similarity(tmp_path, text, verbose=False)
28
-
29
- # Cleanup temp file
30
- os.remove(tmp_path)
31
-
32
  if result is None:
33
- return JSONResponse({"error": "Prediction failed"}, status_code=500)
34
-
35
- return JSONResponse(result)
 
36
 
 
 
37
  except Exception as e:
38
- return JSONResponse({"error": str(e)}, status_code=500)
 
 
 
 
 
 
 
 
 
39
 
40
- @app.get("/")
41
- def home():
42
- return {"message": "Image-Text Similarity API is running"}
 
 
 
 
 
1
+ # app.py
2
  import os
3
  import tempfile
4
+ from fastapi import FastAPI, File, Form, UploadFile, HTTPException
5
  from fastapi.responses import JSONResponse
6
+ import logging
7
 
8
+ # Set up logging
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
 
11
 
12
  app = FastAPI()
13
 
14
+ # Global predictor variable
15
+ predictor = None
16
+
17
+ @app.on_event("startup")
18
+ async def startup_event():
19
+ """Initialize predictor on startup"""
20
+ global predictor
21
+ try:
22
+ from predictor import SimilarityPredictor
23
+ MODEL_PATH = os.getenv("MODEL_PATH", "best_model.pth")
24
+ THRESHOLD = float(os.getenv("THRESHOLD", "0.5"))
25
+
26
+ logger.info(f"Loading model from: {MODEL_PATH}")
27
+ logger.info(f"Using threshold: {THRESHOLD}")
28
+
29
+ predictor = SimilarityPredictor(MODEL_PATH, threshold=THRESHOLD)
30
+ logger.info("✅ Model loaded successfully!")
31
+
32
+ except Exception as e:
33
+ logger.error(f"❌ Failed to load model: {e}")
34
+ # Don't raise here - let the app start but handle in endpoints
35
+
36
+ @app.get("/")
37
+ async def root():
38
+ """Root endpoint"""
39
+ return {
40
+ "message": "Image-Text Similarity API is running",
41
+ "endpoints": {
42
+ "GET /": "This endpoint",
43
+ "GET /health": "Health check",
44
+ "POST /predict": "Predict similarity between image and text"
45
+ }
46
+ }
47
+
48
+ @app.get("/health")
49
+ async def health_check():
50
+ """Health check endpoint"""
51
+ global predictor
52
+ return {
53
+ "status": "healthy" if predictor is not None else "model_not_loaded",
54
+ "model_loaded": predictor is not None
55
+ }
56
+
57
  @app.post("/predict")
58
+ async def predict_similarity(
59
  file: UploadFile = File(...),
60
  text: str = Form(...)
61
  ):
62
+ """Predict similarity between uploaded image and text"""
63
+ global predictor
64
+
65
+ if predictor is None:
66
+ raise HTTPException(
67
+ status_code=503,
68
+ detail="Model not loaded. Check logs for initialization errors."
69
+ )
70
+
71
+ tmp_path = None
72
  try:
73
+ # Validate file type
74
+ if not file.content_type.startswith('image/'):
75
+ raise HTTPException(
76
+ status_code=400,
77
+ detail=f"Invalid file type: {file.content_type}. Please upload an image."
78
+ )
79
+
80
+ # Create temporary file
81
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
82
  tmp_path = tmp.name
83
  content = await file.read()
84
  tmp.write(content)
85
+
86
+ logger.info(f"Processing image: {file.filename}, text: {text[:50]}...")
87
+
88
  # Run prediction
89
  result = predictor.predict_similarity(tmp_path, text, verbose=False)
90
+
 
 
 
91
  if result is None:
92
+ raise HTTPException(status_code=500, detail="Prediction failed")
93
+
94
+ logger.info(f"Prediction completed: {result['prediction']}")
95
+ return result
96
 
97
+ except HTTPException:
98
+ raise
99
  except Exception as e:
100
+ logger.error(f"Prediction error: {e}")
101
+ raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
102
+
103
+ finally:
104
+ # Cleanup temp file
105
+ if tmp_path and os.path.exists(tmp_path):
106
+ try:
107
+ os.remove(tmp_path)
108
+ except Exception as e:
109
+ logger.warning(f"Failed to cleanup temp file: {e}")
110
 
111
+ # Add middleware for better error handling
112
+ @app.middleware("http")
113
+ async def log_requests(request, call_next):
114
+ logger.info(f"{request.method} {request.url}")
115
+ response = await call_next(request)
116
+ logger.info(f"Response status: {response.status_code}")
117
+ return response