Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| import os | |
| import io | |
| import base64 | |
| import threading | |
| import traceback | |
| import gc | |
| from typing import Optional | |
| from flask import Flask, request, jsonify, send_from_directory | |
| from PIL import Image | |
| import numpy as np | |
| import requests | |
| import torch | |
# Configure the process for CPU-only operation. setdefault() keeps any value
# the deployment environment already provides.
# NOTE(review): `torch` is imported above this point, so thread-count variables
# set here may not reach runtimes torch has already initialized;
# torch.set_num_threads() below is what reliably caps torch's intra-op threads.
_DEFAULT_THREADS = 4
os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
os.environ.setdefault("FONTCONFIG_FILE", "/etc/fonts/fonts.conf")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")  # hide any GPUs
os.environ.setdefault("OMP_NUM_THREADS", str(_DEFAULT_THREADS))
os.environ.setdefault("MKL_NUM_THREADS", str(_DEFAULT_THREADS))
os.environ.setdefault("OPENBLAS_NUM_THREADS", str(_DEFAULT_THREADS))
# Create the writable directories referenced by MPLCONFIGDIR and
# FONTCONFIG_PATH above.
os.makedirs("/tmp/.fontconfig", exist_ok=True)
os.makedirs("/tmp/.matplotlib", exist_ok=True)
# Limit torch threads to the same count as OMP_NUM_THREADS so an operator
# override of that variable also applies to torch; fall back to the default
# when the value is not a positive integer.
try:
    _torch_threads = int(os.environ.get("OMP_NUM_THREADS", _DEFAULT_THREADS))
except ValueError:
    _torch_threads = _DEFAULT_THREADS
if _torch_threads < 1:
    _torch_threads = _DEFAULT_THREADS
try:
    torch.set_num_threads(_torch_threads)
except RuntimeError as e:
    # Best effort: keep running with torch's default thread count.
    print(f"[WARN] torch.set_num_threads({_torch_threads}) failed: {e}")
| import supervision as sv | |
| from rfdetr import RFDETRSegPreview | |
# Serialises lazy construction of the shared model (see init_model); MODEL
# holds the cached instance once it has been built.
MODEL_LOCK = threading.Lock()
MODEL = None

# Flask application; files under ./static are exposed at the site root.
app = Flask(__name__, static_folder="static", static_url_path="/")

# Where the segmentation weights are fetched from, and where they are cached.
CHECKPOINT_URL = "https://huggingface.co/Subh775/Seg-Basil-rfdetr/resolve/main/checkpoint_best_total.pth"
CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
def download_file(url: str, dst: str, chunk_size: int = 8192):
    """Download ``url`` to ``dst`` unless a non-empty file is already there.

    The body is streamed into a sibling ``<dst>.part`` file. That file is
    renamed over ``dst`` only after the whole response has been written. An
    interrupted or failed download therefore never leaves a truncated file at
    ``dst``, which the existence check below would otherwise accept as a
    valid cached copy on the next call.

    Args:
        url: HTTP(S) URL to fetch.
        dst: Destination file path.
        chunk_size: Size in bytes of each streamed read.

    Returns:
        ``dst``.

    Raises:
        Any exception from the HTTP request or file I/O, after logging it and
        removing the partial ``.part`` file.
    """
    if os.path.exists(dst) and os.path.getsize(dst) > 0:
        print(f"[INFO] Checkpoint already exists at {dst}")
        return dst
    print(f"[INFO] Downloading weights from {url} -> {dst}")
    tmp_path = dst + ".part"
    try:
        # `with` releases the streamed response on every exit path.
        with requests.get(url, stream=True, timeout=180) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as fh:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip empty chunks
                        fh.write(chunk)
        # Publish the file only once it is complete.
        os.replace(tmp_path, dst)
        print("[INFO] Download complete.")
        return dst
    except Exception as e:
        print(f"[ERROR] Download failed: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # nothing was written, or it is already gone
        raise
def init_model():
    """Return the shared RF-DETR segmentation model, building it on first use.

    The instance is cached in the module-global ``MODEL``. Construction runs
    under ``MODEL_LOCK``, so concurrent first callers (the warm-up thread in
    ``__main__`` and ``predict``) build the model only once.

    Returns:
        The cached ``RFDETRSegPreview`` instance.

    Raises:
        Any error from fetching the checkpoint or constructing the model,
        after logging it. ``MODEL`` stays ``None`` in that case, so a later
        call retries.
    """
    global MODEL
    with MODEL_LOCK:
        if MODEL is not None:
            print("[INFO] Model already loaded, returning cached instance")
            return MODEL
        try:
            # Treat a zero-byte file (e.g. left by an interrupted download) as
            # missing, matching the check download_file() itself performs.
            if os.path.exists(CHECKPOINT_PATH) and os.path.getsize(CHECKPOINT_PATH) > 0:
                print(f"[INFO] Using existing checkpoint at {CHECKPOINT_PATH}")
            else:
                print("[INFO] Checkpoint not found, downloading...")
                download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
            print("[INFO] Loading RF-DETR model (CPU mode)...")
            model = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH)
            # Optimization is optional; fall back to the unoptimized model.
            try:
                print("[INFO] Optimizing model for inference...")
                model.optimize_for_inference()
                print("[INFO] Model optimization complete")
            except Exception as e:
                print(f"[WARN] optimize_for_inference() skipped/failed: {e}")
            # Publish only the fully prepared instance. health() reads MODEL
            # without the lock and must not report a half-prepared model.
            MODEL = model
            print("[INFO] Model ready for inference")
            return MODEL
        except Exception as e:
            print(f"[ERROR] Model initialization failed: {e}")
            traceback.print_exc()
            raise
def decode_data_url(data_url: str) -> Image.Image:
    """Decode a base64-encoded image to an RGB PIL image.

    Accepts either a data URL (``data:<mime>;base64,<payload>``) or a bare
    base64 string. Both forms share one decode path, so they fail the same
    way.

    Raises:
        ValueError: if a data URL has no ``,`` before its payload, or the
            payload is not decodable base64.
        Errors from ``PIL.Image.open`` (e.g. an unrecognized image format)
        propagate unchanged.
    """
    if data_url.startswith("data:"):
        _header, sep, b64 = data_url.partition(",")
        if not sep:
            raise ValueError("Malformed data URL: no ',' separating header and payload")
    else:
        b64 = data_url
    try:
        data = base64.b64decode(b64)
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise ValueError(f"Invalid base64 image payload ({err})") from err
    return Image.open(io.BytesIO(data)).convert("RGB")
def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
    """Serialize ``pil_img`` in format ``fmt`` and return it as a base64 data URL.

    The MIME subtype is ``fmt`` lower-cased, e.g. ``"PNG"`` -> ``image/png``.
    """
    with io.BytesIO() as buffer:
        pil_img.save(buffer, format=fmt, optimize=False)
        payload = base64.b64encode(buffer.getvalue()).decode("ascii")
    mime = f"image/{fmt.lower()}"
    return f"data:{mime};base64,{payload}"
def annotate_segmentation(image: Image.Image, detections: sv.Detections,
                          show_labels: bool = True, show_confidence: bool = True,
                          *, label_text: str = "Tulsi") -> Image.Image:
    """
    Draw segmentation masks, outlines and optional labels on a copy of ``image``.

    Masks are filled from a fixed 12-colour palette and outlined in white.
    When ``show_labels`` or ``show_confidence`` is set, each detection is also
    labelled at its mask's centre of mass. Intended to match the visualization
    of the rfdetr_seg_infer.py script (not verifiable from this file).

    Args:
        image: Input PIL image; annotation is applied to a copy.
        detections: supervision Detections to draw.
        show_labels: Whether each label includes ``label_text``.
        show_confidence: Whether each label includes the score (2 decimals);
            skipped for detections that carry no confidence.
        label_text: Class name shown when ``show_labels`` is true.

    Returns:
        The annotated image. If annotation raises, the error is logged and
        ``image`` is returned unchanged.
    """
    try:
        palette = sv.ColorPalette.from_hex([
            "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
            "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
        ])
        # Scale label text with the image resolution.
        text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
        print(f"[INFO] Creating annotators with text_scale={text_scale}")
        mask_annotator = sv.MaskAnnotator(color=palette)
        polygon_annotator = sv.PolygonAnnotator(color=sv.Color.WHITE)

        # Masks and outlines are always drawn.
        out = image.copy()
        print("[INFO] Applying mask annotation...")
        out = mask_annotator.annotate(out, detections)
        print("[INFO] Applying polygon annotation...")
        out = polygon_annotator.annotate(out, detections)

        if show_labels or show_confidence:
            label_annotator = sv.LabelAnnotator(
                color=palette,
                text_color=sv.Color.BLACK,
                text_scale=text_scale,
                text_position=sv.Position.CENTER_OF_MASS
            )
            # Detections without scores still get (name-only) labels instead
            # of aborting the whole annotation.
            scores = detections.confidence
            if scores is None:
                scores = [None] * len(detections)
            labels = []
            for conf in scores:
                parts = []
                if show_labels:
                    parts.append(label_text)
                if show_confidence and conf is not None:
                    parts.append(f"{float(conf):.2f}")
                labels.append(" ".join(parts))
            print(f"[INFO] Applying label annotation with {len(labels)} labels...")
            out = label_annotator.annotate(out, detections, labels)
        else:
            print("[INFO] Skipping label annotation (both labels and confidence disabled)")
        print("[INFO] Annotation complete")
        return out
    except Exception as e:
        print(f"[ERROR] Annotation failed: {e}")
        traceback.print_exc()
        # Best effort: fall back to the unannotated input.
        return image
def index():
    """Serve the bundled UI page, or a JSON status message if there is none.

    Returns ``index.html`` from the app's static folder when that file
    exists; otherwise returns a small JSON payload saying the API is up.

    NOTE(review): no ``@app.route`` decorator is visible on this function in
    this file; confirm it is registered as a view somewhere.
    """
    static_dir = app.static_folder or "static"
    if not os.path.exists(os.path.join(static_dir, "index.html")):
        return jsonify({"message": "RF-DETR Segmentation API is running.", "status": "ready"})
    return send_from_directory(app.static_folder, "index.html")
def health():
    """Report liveness, whether the model is cached, and whether the checkpoint file exists.

    NOTE(review): no ``@app.route`` decorator is visible on this function in
    this file; confirm it is registered as a view somewhere.
    """
    report = {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "checkpoint_exists": os.path.exists(CHECKPOINT_PATH),
    }
    return jsonify(report)
def _parse_flag(value, default: bool) -> bool:
    """Interpret an optional on/off request option.

    ``None`` (option absent, or JSON null) yields ``default``. Strings, from
    form fields or JSON, are true only when they equal "true"
    case-insensitively, the same rule the multipart form path always used.
    Other JSON values use Python truthiness.
    """
    if value is None:
        return default
    if isinstance(value, str):
        return value.lower() == "true"
    return bool(value)


def _parse_conf(value, default: float) -> float:
    """Parse the confidence-threshold option, returning ``default`` when absent.

    Raises:
        ValueError or TypeError: if ``value`` is not a number in [0, 1].
    """
    if value is None:
        return default
    conf = float(value)
    # The chained comparison is also False for NaN, so NaN is rejected.
    if not 0.0 <= conf <= 1.0:
        raise ValueError(f"conf must be between 0 and 1, got {value!r}")
    return conf


def predict():
    """
    Run leaf segmentation on one image and return an annotated preview.

    Accepts:
      - multipart/form-data with file field "file" (options "conf",
        "show_labels", "show_confidence" as form fields)
      - or JSON {"image": "<data:url...>", "conf": 0.05, "show_labels": true, "show_confidence": true}
    Returns JSON:
      {"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}
    Errors are returned as {"error": "..."}: 400 for bad input (missing or
    undecodable image, malformed options), 500 for model initialization or
    inference failures.

    NOTE(review): no ``@app.route`` decorator is visible on this function in
    this file; confirm it is registered as a view somewhere.
    """
    print("\n[INFO] ========== New prediction request ==========")
    try:
        print("[INFO] Initializing model...")
        model = init_model()
        print("[INFO] Model ready")
    except Exception as e:
        error_msg = f"Model initialization failed: {e}"
        print(f"[ERROR] {error_msg}")
        return jsonify({"error": error_msg}), 500

    # Parse input: the image plus a mapping holding the display options.
    img: Optional[Image.Image] = None
    if "file" in request.files:
        file = request.files["file"]
        print(f"[INFO] Processing uploaded file: {file.filename}")
        try:
            img = Image.open(file.stream).convert("RGB")
        except Exception as e:
            error_msg = f"Invalid uploaded image: {e}"
            print(f"[ERROR] {error_msg}")
            return jsonify({"error": error_msg}), 400
        options = request.form
    else:
        payload = request.get_json(silent=True)
        # Non-object JSON bodies (lists, strings, numbers) are rejected here
        # rather than failing later with a 500.
        if not isinstance(payload, dict) or "image" not in payload:
            return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
        try:
            print("[INFO] Decoding image from data URL...")
            img = decode_data_url(payload["image"])
        except Exception as e:
            error_msg = f"Invalid image data: {e}"
            print(f"[ERROR] {error_msg}")
            return jsonify({"error": error_msg}), 400
        options = payload

    # Malformed options are client errors (400), not server errors.
    try:
        conf_threshold = _parse_conf(options.get("conf"), 0.05)
        show_labels = _parse_flag(options.get("show_labels"), True)
        show_confidence = _parse_flag(options.get("show_confidence"), True)
    except (TypeError, ValueError) as e:
        error_msg = f"Invalid request option: {e}"
        print(f"[ERROR] {error_msg}")
        return jsonify({"error": error_msg}), 400

    print(f"[INFO] Image size: {img.size}, Confidence threshold: {conf_threshold}")
    print(f"[INFO] Display options - Labels: {show_labels}, Confidence: {show_confidence}")

    # Downscale large images (longest side > MAX_SIZE) to bound memory use,
    # preserving the aspect ratio.
    MAX_SIZE = 1024
    if max(img.size) > MAX_SIZE:
        w, h = img.size
        scale = MAX_SIZE / float(max(w, h))
        new_w, new_h = int(round(w * scale)), int(round(h * scale))
        print(f"[INFO] Resizing image from {w}x{h} to {new_w}x{new_h}")
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)

    try:
        print("[INFO] Running inference...")
        with torch.no_grad():  # no autograd bookkeeping during inference
            detections = model.predict(img, threshold=conf_threshold)
        print(f"[INFO] Raw detections: {len(detections)} objects")

        if len(detections) == 0 or not hasattr(detections, 'confidence') or len(detections.confidence) == 0:
            print("[INFO] No detections above threshold")
            # Nothing to draw: return the (possibly resized) input image.
            data_url = encode_pil_to_dataurl(img, fmt="PNG")
            return jsonify({
                "annotated": data_url,
                "confidences": [],
                "count": 0
            })

        print(f"[INFO] Detections have {len(detections.confidence)} confidence scores")
        print(f"[INFO] Confidence range: {min(detections.confidence):.3f} - {max(detections.confidence):.3f}")

        # supervision's Detections keeps segmentation masks in `.mask`.
        mask = getattr(detections, "mask", None)
        if mask is not None:
            print(f"[INFO] Masks present: shape={np.asarray(mask).shape}")
        else:
            print("[WARN] No masks found in detections!")

        print("[INFO] Starting annotation...")
        annotated_pil = annotate_segmentation(img, detections, show_labels, show_confidence)

        confidences = [float(conf) for conf in detections.confidence]
        print(f"[INFO] Final confidences: {confidences}")

        print("[INFO] Encoding annotated image...")
        data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")

        # Drop the detections eagerly to reduce peak memory between requests.
        del detections
        gc.collect()
        print(f"[INFO] ========== Prediction complete: {len(confidences)} leaves detected ==========\n")
        return jsonify({
            "annotated": data_url,
            "confidences": confidences,
            "count": len(confidences)
        })
    except Exception as e:
        error_msg = f"Inference failed: {e}"
        print(f"[ERROR] {error_msg}")
        traceback.print_exc()
        return jsonify({"error": error_msg}), 500
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("Starting Tulsi Leaf Segmentation Server")
    print("=" * 60 + "\n")

    def warm():
        """Load the model in the background so requests usually find it ready.

        Failures are logged only. predict() calls init_model() again and
        reports any error to the client.
        """
        try:
            print("[INFO] Starting model warmup in background...")
            init_model()
            print("[INFO] Model warmup complete - ready for predictions")
        except Exception as e:
            print(f"[ERROR] Model warmup failed: {e}")
            traceback.print_exc()

    # daemon=True: the warm-up thread does not keep the process alive on exit.
    threading.Thread(target=warm, daemon=True).start()

    # Flask's built-in server; the port comes from $PORT (default 7860).
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=False
    )