import os
import io
import base64
import threading
import traceback
import gc
from typing import Optional

from flask import Flask, request, jsonify, send_from_directory
from PIL import Image
import numpy as np
import requests
import torch

# Set environment variables for CPU-only operation
os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
os.environ.setdefault("FONTCONFIG_FILE", "/etc/fonts/fonts.conf")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "4")

# Create writable fontconfig/matplotlib caches
os.makedirs("/tmp/.fontconfig", exist_ok=True)
os.makedirs("/tmp/.matplotlib", exist_ok=True)

# Limit torch threads
try:
    torch.set_num_threads(4)
except Exception:
    pass

import supervision as sv
from rfdetr import RFDETRSegPreview

app = Flask(__name__, static_folder="static", static_url_path="/")

# Checkpoint URL & local path
CHECKPOINT_URL = "https://huggingface.co/Subh775/Seg-Basil-rfdetr/resolve/main/checkpoint_best_total.pth"
CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")

MODEL_LOCK = threading.Lock()
MODEL = None


def download_file(url: str, dst: str, chunk_size: int = 8192):
    """Download the file to dst if it does not already exist."""
    if os.path.exists(dst) and os.path.getsize(dst) > 0:
        print(f"[INFO] Checkpoint already exists at {dst}")
        return dst
    print(f"[INFO] Downloading weights from {url} -> {dst}")
    try:
        r = requests.get(url, stream=True, timeout=180)
        r.raise_for_status()
        with open(dst, "wb") as fh:
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:
                    fh.write(chunk)
        print("[INFO] Download complete.")
        return dst
    except Exception as e:
        print(f"[ERROR] Download failed: {e}")
        raise


def init_model():
    """Lazily initialize the RF-DETR model and cache it in the global MODEL."""
    global MODEL
    with MODEL_LOCK:
        if MODEL is not None:
            print("[INFO] Model already loaded, returning cached instance")
            return MODEL
        try:
            # Ensure the checkpoint is present
            if not os.path.exists(CHECKPOINT_PATH):
                print("[INFO] Checkpoint not found, downloading...")
                download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
            else:
                print(f"[INFO] Using existing checkpoint at {CHECKPOINT_PATH}")

            print("[INFO] Loading RF-DETR model (CPU mode)...")
            MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH)

            # Try to optimize for inference
            try:
                print("[INFO] Optimizing model for inference...")
                MODEL.optimize_for_inference()
                print("[INFO] Model optimization complete")
            except Exception as e:
                print(f"[WARN] optimize_for_inference() skipped/failed: {e}")

            print("[INFO] Model ready for inference")
            return MODEL
        except Exception as e:
            print(f"[ERROR] Model initialization failed: {e}")
            traceback.print_exc()
            raise


def decode_data_url(data_url: str) -> Image.Image:
    """Decode a data URL (or raw base64 string) to a PIL Image."""
    if data_url.startswith("data:"):
        _, b64 = data_url.split(",", 1)
        data = base64.b64decode(b64)
    else:
        try:
            data = base64.b64decode(data_url)
        except Exception:
            raise ValueError("Invalid image data")
    return Image.open(io.BytesIO(data)).convert("RGB")


def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
    """Encode a PIL Image to a data URL."""
    buf = io.BytesIO()
    pil_img.save(buf, format=fmt, optimize=False)
    buf.seek(0)
    return "data:image/{};base64,".format(fmt.lower()) + base64.b64encode(buf.getvalue()).decode("ascii")
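
# A minimal sketch of the data-URL round trip handled by the two helpers above
# (the 1x1 sample image is purely illustrative):
#
#   tiny = Image.new("RGB", (1, 1), "green")
#   as_url = encode_pil_to_dataurl(tiny)      # -> "data:image/png;base64,..."
#   restored = decode_data_url(as_url)        # -> PIL RGB image again
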
def annotate_segmentation(image: Image.Image,
                          detections: sv.Detections,
                          show_labels: bool = True,
                          show_confidence: bool = True) -> Image.Image:
    """
    Annotate image with segmentation masks using the supervision library.
    This matches the visualization from the rfdetr_seg_infer.py script.

    Args:
        image: Input PIL Image
        detections: Supervision Detections object
        show_labels: Whether to show the "Tulsi" label text
        show_confidence: Whether to show confidence scores
    """
    try:
        # Define color palette
        palette = sv.ColorPalette.from_hex([
            "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
            "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
        ])

        # Calculate optimal text scale based on image resolution
        text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
        print(f"[INFO] Creating annotators with text_scale={text_scale}")

        # Create annotators
        mask_annotator = sv.MaskAnnotator(color=palette)
        polygon_annotator = sv.PolygonAnnotator(color=sv.Color.WHITE)

        # Apply base annotations (masks and polygons are always shown)
        out = image.copy()
        print("[INFO] Applying mask annotation...")
        out = mask_annotator.annotate(out, detections)
        print("[INFO] Applying polygon annotation...")
        out = polygon_annotator.annotate(out, detections)

        # Only add labels if at least one option is enabled
        if show_labels or show_confidence:
            label_annotator = sv.LabelAnnotator(
                color=palette,
                text_color=sv.Color.BLACK,
                text_scale=text_scale,
                text_position=sv.Position.CENTER_OF_MASS
            )

            # Create labels based on options
            labels = []
            for conf in detections.confidence:
                label_parts = []
                if show_labels:
                    label_parts.append("Tulsi")
                if show_confidence:
                    label_parts.append(f"{float(conf):.2f}")
                labels.append(" ".join(label_parts))

            print(f"[INFO] Applying label annotation with {len(labels)} labels...")
            out = label_annotator.annotate(out, detections, labels)
        else:
            print("[INFO] Skipping label annotation (both labels and confidence disabled)")

        print("[INFO] Annotation complete")
        return out
    except Exception as e:
        print(f"[ERROR] Annotation failed: {e}")
        traceback.print_exc()
        # Return the original image if annotation fails
        return image


@app.route("/", methods=["GET"])
def index():
    """Serve the static UI."""
    index_path = os.path.join(app.static_folder or "static", "index.html")
    if os.path.exists(index_path):
        return send_from_directory(app.static_folder, "index.html")
    return jsonify({"message": "RF-DETR Segmentation API is running.", "status": "ready"})


@app.route("/health", methods=["GET"])
def health():
    """Health check endpoint."""
    model_loaded = MODEL is not None
    return jsonify({
        "status": "healthy",
        "model_loaded": model_loaded,
        "checkpoint_exists": os.path.exists(CHECKPOINT_PATH)
    })
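
# Illustrative /health response once the checkpoint is on disk and the model is
# warmed up (field names come from the handler above):
#
#   GET /health  ->  {"status": "healthy", "model_loaded": true, "checkpoint_exists": true}
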
@app.route("/predict", methods=["POST"])
def predict():
    """
    Accepts:
      - multipart/form-data with file field "file"
      - or JSON {"image": "<data URL>", "conf": 0.05, "show_labels": true, "show_confidence": true}

    Returns JSON: {"annotated": "<data URL>", "confidences": [..], "count": N}
    """
    print("\n[INFO] ========== New prediction request ==========")

    try:
        print("[INFO] Initializing model...")
        model = init_model()
        print("[INFO] Model ready")
    except Exception as e:
        error_msg = f"Model initialization failed: {e}"
        print(f"[ERROR] {error_msg}")
        return jsonify({"error": error_msg}), 500

    # Parse input
    img: Optional[Image.Image] = None
    conf_threshold = 0.05
    show_labels = True
    show_confidence = True

    # Check if a file was uploaded
    if "file" in request.files:
        file = request.files["file"]
        print(f"[INFO] Processing uploaded file: {file.filename}")
        try:
            img = Image.open(file.stream).convert("RGB")
        except Exception as e:
            error_msg = f"Invalid uploaded image: {e}"
            print(f"[ERROR] {error_msg}")
            return jsonify({"error": error_msg}), 400
        conf_threshold = float(request.form.get("conf", conf_threshold))
        show_labels = request.form.get("show_labels", "true").lower() == "true"
        show_confidence = request.form.get("show_confidence", "true").lower() == "true"
    else:
        # Try JSON payload
        payload = request.get_json(silent=True)
        if not payload or "image" not in payload:
            return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
        try:
            print("[INFO] Decoding image from data URL...")
            img = decode_data_url(payload["image"])
        except Exception as e:
            error_msg = f"Invalid image data: {e}"
            print(f"[ERROR] {error_msg}")
            return jsonify({"error": error_msg}), 400
        conf_threshold = float(payload.get("conf", conf_threshold))
        show_labels = payload.get("show_labels", True)
        show_confidence = payload.get("show_confidence", True)

    print(f"[INFO] Image size: {img.size}, Confidence threshold: {conf_threshold}")
    print(f"[INFO] Display options - Labels: {show_labels}, Confidence: {show_confidence}")

    # Optionally downscale large images to reduce memory usage
    MAX_SIZE = 1024
    if max(img.size) > MAX_SIZE:
        w, h = img.size
        scale = MAX_SIZE / float(max(w, h))
        new_w, new_h = int(round(w * scale)), int(round(h * scale))
        print(f"[INFO] Resizing image from {w}x{h} to {new_w}x{new_h}")
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)

    # Run inference with no_grad for memory efficiency
    try:
        print("[INFO] Running inference...")
        with torch.no_grad():
            detections = model.predict(img, threshold=conf_threshold)

        print(f"[INFO] Raw detections: {len(detections)} objects")

        # Check if any detections exist
        if len(detections) == 0 or not hasattr(detections, 'confidence') or len(detections.confidence) == 0:
            print("[INFO] No detections above threshold")
            # Return the original image
            data_url = encode_pil_to_dataurl(img, fmt="PNG")
            return jsonify({
                "annotated": data_url,
                "confidences": [],
                "count": 0
            })

        print(f"[INFO] Detections have {len(detections.confidence)} confidence scores")
        print(f"[INFO] Confidence range: {min(detections.confidence):.3f} - {max(detections.confidence):.3f}")

        # Check if masks exist (supervision's Detections stores them on the `mask` attribute)
        if hasattr(detections, 'mask') and detections.mask is not None:
            print(f"[INFO] Masks present: shape={np.array(detections.mask).shape if hasattr(detections.mask, '__len__') else 'unknown'}")
        else:
            print("[WARN] No masks found in detections!")

        # Annotate the image using the supervision library
        print("[INFO] Starting annotation...")
        annotated_pil = annotate_segmentation(img, detections, show_labels, show_confidence)

        # Extract confidence scores
        confidences = [float(conf) for conf in detections.confidence]
        print(f"[INFO] Final confidences: {confidences}")

        # Encode to data URL
        print("[INFO] Encoding annotated image...")
        data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")

        # Clean up
        del detections
        gc.collect()

        print(f"[INFO] ========== Prediction complete: {len(confidences)} leaves detected ==========\n")
        return jsonify({
            "annotated": data_url,
            "confidences": confidences,
            "count": len(confidences)
        })

    except Exception as e:
        error_msg = f"Inference failed: {e}"
        print(f"[ERROR] {error_msg}")
        traceback.print_exc()
        return jsonify({"error": error_msg}), 500


if __name__ == "__main__":
    print("\n" + "="*60)
    print("Starting Tulsi Leaf Segmentation Server")
    print("="*60 + "\n")

    # Warm the model in a background thread
    def warm():
        try:
            print("[INFO] Starting model warmup in background...")
            init_model()
            print("[INFO] ✓ Model warmup complete - ready for predictions")
        except Exception as e:
            print(f"[ERROR] ✗ Model warmup failed: {e}")
            traceback.print_exc()

    threading.Thread(target=warm, daemon=True).start()
    # Run Flask app
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=False
    )
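
# ---------------------------------------------------------------------------
# Example client calls (a minimal sketch; assumes the server is running locally
# on the default port 7860 and that a file named "leaf.jpg" exists):
#
#   curl -X POST http://localhost:7860/predict \
#        -F "file=@leaf.jpg" -F "conf=0.05" \
#        -F "show_labels=true" -F "show_confidence=true"
#
# or, from Python:
#
#   import requests
#   with open("leaf.jpg", "rb") as fh:
#       resp = requests.post(
#           "http://localhost:7860/predict",
#           files={"file": fh},
#           data={"conf": "0.05", "show_labels": "true", "show_confidence": "true"},
#       )
#   result = resp.json()
#   print(result["count"], "leaves detected, confidences:", result["confidences"])
# ---------------------------------------------------------------------------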