import os

# Set environment variables for CPU-only operation. These are applied before
# torch is imported so the thread limits and device mask actually take effect.
os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
os.environ.setdefault("FONTCONFIG_FILE", "/etc/fonts/fonts.conf")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "4")

# Create writable fontconfig and matplotlib caches
os.makedirs("/tmp/.fontconfig", exist_ok=True)
os.makedirs("/tmp/.matplotlib", exist_ok=True)

import io
import base64
import threading
import traceback
import gc
from typing import Optional

from flask import Flask, request, jsonify, send_from_directory
from PIL import Image
import numpy as np
import requests
import torch

# Limit torch threads
try:
    torch.set_num_threads(4)
except Exception:
    pass

import supervision as sv
from rfdetr import RFDETRSegPreview
app = Flask(__name__, static_folder="static", static_url_path="/")
# Checkpoint URL & local path
CHECKPOINT_URL = "https://huggingface.co/Subh775/Seg-Basil-rfdetr/resolve/main/checkpoint_best_total.pth"
CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
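
# The model is loaded lazily (on first request, or by the warmup thread in
# __main__) and cached in MODEL; MODEL_LOCK prevents two requests from
# initializing it concurrently.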
MODEL_LOCK = threading.Lock()
MODEL = None
def download_file(url: str, dst: str, chunk_size: int = 8192):
"""Download file if not exists"""
if os.path.exists(dst) and os.path.getsize(dst) > 0:
print(f"[INFO] Checkpoint already exists at {dst}")
return dst
print(f"[INFO] Downloading weights from {url} -> {dst}")
try:
r = requests.get(url, stream=True, timeout=180)
r.raise_for_status()
with open(dst, "wb") as fh:
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk:
fh.write(chunk)
print("[INFO] Download complete.")
return dst
except Exception as e:
print(f"[ERROR] Download failed: {e}")
raise
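
# Note: download_file is idempotent — a non-empty file at `dst` short-circuits
# the download, so restarts reuse the cached checkpoint in /tmp.
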
def init_model():
"""Lazily initialize the RF-DETR model and cache it in global MODEL."""
global MODEL
with MODEL_LOCK:
if MODEL is not None:
print("[INFO] Model already loaded, returning cached instance")
return MODEL
try:
# Ensure checkpoint present
if not os.path.exists(CHECKPOINT_PATH):
print("[INFO] Checkpoint not found, downloading...")
download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
else:
print(f"[INFO] Using existing checkpoint at {CHECKPOINT_PATH}")
print("[INFO] Loading RF-DETR model (CPU mode)...")
MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH)
# Try to optimize for inference
try:
print("[INFO] Optimizing model for inference...")
MODEL.optimize_for_inference()
print("[INFO] Model optimization complete")
except Exception as e:
print(f"[WARN] optimize_for_inference() skipped/failed: {e}")
print("[INFO] Model ready for inference")
return MODEL
except Exception as e:
print(f"[ERROR] Model initialization failed: {e}")
traceback.print_exc()
raise
def decode_data_url(data_url: str) -> Image.Image:
"""Decode data URL to PIL Image"""
if data_url.startswith("data:"):
_, b64 = data_url.split(",", 1)
data = base64.b64decode(b64)
else:
try:
data = base64.b64decode(data_url)
except Exception:
raise ValueError("Invalid image data")
return Image.open(io.BytesIO(data)).convert("RGB")
def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
"""Encode PIL Image to data URL"""
buf = io.BytesIO()
pil_img.save(buf, format=fmt, optimize=False)
buf.seek(0)
return "data:image/{};base64,".format(fmt.lower()) + base64.b64encode(buf.getvalue()).decode("ascii")
def annotate_segmentation(image: Image.Image, detections: sv.Detections,
                           show_labels: bool = True, show_confidence: bool = True) -> Image.Image:
    """
    Annotate the image with segmentation masks using the supervision library.
    This matches the visualization from the rfdetr_seg_infer.py script.

    Args:
        image: Input PIL Image
        detections: Supervision Detections object
        show_labels: Whether to show the "Tulsi" label text
        show_confidence: Whether to show confidence scores
    """
    try:
        # Define color palette
        palette = sv.ColorPalette.from_hex([
            "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
            "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
        ])
        # Calculate optimal text scale based on image resolution
        text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
        print(f"[INFO] Creating annotators with text_scale={text_scale}")
        # Create annotators
        mask_annotator = sv.MaskAnnotator(color=palette)
        polygon_annotator = sv.PolygonAnnotator(color=sv.Color.WHITE)
        # Apply base annotations (masks and polygons always shown)
        out = image.copy()
        print("[INFO] Applying mask annotation...")
        out = mask_annotator.annotate(out, detections)
        print("[INFO] Applying polygon annotation...")
        out = polygon_annotator.annotate(out, detections)
        # Only add labels if at least one option is enabled
        if show_labels or show_confidence:
            label_annotator = sv.LabelAnnotator(
                color=palette,
                text_color=sv.Color.BLACK,
                text_scale=text_scale,
                text_position=sv.Position.CENTER_OF_MASS,
            )
            # Create labels based on options
            labels = []
            for conf in detections.confidence:
                label_parts = []
                if show_labels:
                    label_parts.append("Tulsi")
                if show_confidence:
                    label_parts.append(f"{float(conf):.2f}")
                labels.append(" ".join(label_parts))
            print(f"[INFO] Applying label annotation with {len(labels)} labels...")
            out = label_annotator.annotate(out, detections, labels)
        else:
            print("[INFO] Skipping label annotation (both labels and confidence disabled)")
        print("[INFO] Annotation complete")
        return out
    except Exception as e:
        print(f"[ERROR] Annotation failed: {e}")
        traceback.print_exc()
        # Return original image if annotation fails
        return image
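
# With both toggles on, each mask is captioned like "Tulsi 0.87"; with both
# off, only the colored masks and white polygon outlines are drawn.
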
@app.route("/", methods=["GET"])
def index():
"""Serve the static UI"""
index_path = os.path.join(app.static_folder or "static", "index.html")
if os.path.exists(index_path):
return send_from_directory(app.static_folder, "index.html")
return jsonify({"message": "RF-DETR Segmentation API is running.", "status": "ready"})
@app.route("/health", methods=["GET"])
def health():
"""Health check endpoint"""
model_loaded = MODEL is not None
return jsonify({
"status": "healthy",
"model_loaded": model_loaded,
"checkpoint_exists": os.path.exists(CHECKPOINT_PATH)
})
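
# Expected shape of the health response (a sketch; values depend on runtime
# state): {"status": "healthy", "model_loaded": true, "checkpoint_exists": true}
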
@app.route("/predict", methods=["POST"])
def predict():
"""
Accepts:
- multipart/form-data with file field "file"
- or JSON {"image": "<data:url...>", "conf": 0.05, "show_labels": true, "show_confidence": true}
Returns JSON:
{"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}
"""
print("\n[INFO] ========== New prediction request ==========")
try:
print("[INFO] Initializing model...")
model = init_model()
print("[INFO] Model ready")
except Exception as e:
error_msg = f"Model initialization failed: {e}"
print(f"[ERROR] {error_msg}")
return jsonify({"error": error_msg}), 500
# Parse input
img: Optional[Image.Image] = None
conf_threshold = 0.05
show_labels = True
show_confidence = True
# Check if file uploaded
if "file" in request.files:
file = request.files["file"]
print(f"[INFO] Processing uploaded file: {file.filename}")
try:
img = Image.open(file.stream).convert("RGB")
except Exception as e:
error_msg = f"Invalid uploaded image: {e}"
print(f"[ERROR] {error_msg}")
return jsonify({"error": error_msg}), 400
conf_threshold = float(request.form.get("conf", conf_threshold))
show_labels = request.form.get("show_labels", "true").lower() == "true"
show_confidence = request.form.get("show_confidence", "true").lower() == "true"
else:
# Try JSON payload
payload = request.get_json(silent=True)
if not payload or "image" not in payload:
return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
try:
print("[INFO] Decoding image from data URL...")
img = decode_data_url(payload["image"])
except Exception as e:
error_msg = f"Invalid image data: {e}"
print(f"[ERROR] {error_msg}")
return jsonify({"error": error_msg}), 400
conf_threshold = float(payload.get("conf", conf_threshold))
show_labels = payload.get("show_labels", True)
show_confidence = payload.get("show_confidence", True)
print(f"[INFO] Image size: {img.size}, Confidence threshold: {conf_threshold}")
print(f"[INFO] Display options - Labels: {show_labels}, Confidence: {show_confidence}")
# Optionally downscale large images to reduce memory usage
MAX_SIZE = 1024
if max(img.size) > MAX_SIZE:
w, h = img.size
scale = MAX_SIZE / float(max(w, h))
new_w, new_h = int(round(w * scale)), int(round(h * scale))
print(f"[INFO] Resizing image from {w}x{h} to {new_w}x{new_h}")
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
    # Run inference with no_grad for memory efficiency
    try:
        print("[INFO] Running inference...")
        with torch.no_grad():
            detections = model.predict(img, threshold=conf_threshold)
        print(f"[INFO] Raw detections: {len(detections)} objects")
        # Check if detections exist (confidence may be None on empty results)
        if len(detections) == 0 or detections.confidence is None or len(detections.confidence) == 0:
            print("[INFO] No detections above threshold")
            # Return original image
            data_url = encode_pil_to_dataurl(img, fmt="PNG")
            return jsonify({
                "annotated": data_url,
                "confidences": [],
                "count": 0
            })
        print(f"[INFO] Detections have {len(detections.confidence)} confidence scores")
        print(f"[INFO] Confidence range: {min(detections.confidence):.3f} - {max(detections.confidence):.3f}")
        # Check if masks exist (supervision's Detections stores them on `mask`)
        if getattr(detections, "mask", None) is not None:
            print(f"[INFO] Masks present: shape={np.asarray(detections.mask).shape}")
        else:
            print("[WARN] No masks found in detections!")
        # Annotate image using supervision library
        print("[INFO] Starting annotation...")
        annotated_pil = annotate_segmentation(img, detections, show_labels, show_confidence)
        # Extract confidence scores
        confidences = [float(conf) for conf in detections.confidence]
        print(f"[INFO] Final confidences: {confidences}")
        # Encode to data URL
        print("[INFO] Encoding annotated image...")
        data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
        # Clean up
        del detections
        gc.collect()
        print(f"[INFO] ========== Prediction complete: {len(confidences)} leaves detected ==========\n")
        return jsonify({
            "annotated": data_url,
            "confidences": confidences,
            "count": len(confidences)
        })
    except Exception as e:
        error_msg = f"Inference failed: {e}"
        print(f"[ERROR] {error_msg}")
        traceback.print_exc()
        return jsonify({"error": error_msg}), 500
if __name__ == "__main__":
print("\n" + "="*60)
print("Starting Tulsi Leaf Segmentation Server")
print("="*60 + "\n")
# Warm model in background thread
def warm():
try:
print("[INFO] Starting model warmup in background...")
init_model()
print("[INFO] βœ“ Model warmup complete - ready for predictions")
except Exception as e:
print(f"[ERROR] βœ— Model warmup failed: {e}")
traceback.print_exc()
threading.Thread(target=warm, daemon=True).start()
# Run Flask app
app.run(
host="0.0.0.0",
port=int(os.environ.get("PORT", 7860)),
debug=False
)