import os
import io
import base64
import threading
import traceback
from typing import Optional

from flask import Flask, request, jsonify, send_from_directory
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import requests
# Set writable cache dirs to avoid matplotlib/fontconfig warnings in containers
os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
# Ensure CPU-only (do not accidentally use GPU)
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")

# --- Imports that may trigger the above warnings ---
try:
    import supervision as sv  # noqa: F401  (imported so a missing dependency fails fast at startup)
    from rfdetr import RFDETRSegPreview
except Exception as e:
    # Provide a clearer error at startup if imports fail
    raise RuntimeError(f"Required library import failed: {e}") from e
app = Flask(__name__, static_folder="static", static_url_path="/")

# Checkpoint URL & local path
CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")

MODEL_LOCK = threading.Lock()
MODEL = None
def download_file(url: str, dst: str, chunk_size: int = 8192):
    """Download url to dst; skip the download if a non-empty file already exists."""
    if os.path.exists(dst) and os.path.getsize(dst) > 0:
        return dst
    print(f"[INFO] Downloading weights from {url} -> {dst}")
    r = requests.get(url, stream=True, timeout=60)
    r.raise_for_status()
    # Write to a temporary name first so an interrupted download never leaves a
    # truncated file behind that would pass the size check on the next start.
    tmp_dst = dst + ".part"
    with open(tmp_dst, "wb") as fh:
        for chunk in r.iter_content(chunk_size=chunk_size):
            if chunk:
                fh.write(chunk)
    os.replace(tmp_dst, dst)
    print("[INFO] Download complete.")
    return dst
def init_model():
    """
    Lazily initialize the RF-DETR model and cache it in the global MODEL.
    Thread-safe.
    """
    global MODEL
    with MODEL_LOCK:
        if MODEL is not None:
            return MODEL
        try:
            # ensure checkpoint present (best-effort)
            try:
                download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
            except Exception as e:
                print("[WARN] Failed to download checkpoint:", e)
            print("[INFO] Loading RF-DETR model (CPU mode)...")
            MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH if os.path.exists(CHECKPOINT_PATH) else None)
            try:
                MODEL.optimize_for_inference()
            except Exception as e:
                print("[WARN] optimize_for_inference() skipped/failed:", e)
            print("[INFO] Model ready.")
            return MODEL
        except Exception:
            traceback.print_exc()
            raise
def decode_data_url(data_url: str) -> Image.Image:
    """
    Accepts a data URL (data:image/png;base64,...) or raw base64 and returns a PIL.Image (RGB).
    """
    if data_url.startswith("data:"):
        _, b64 = data_url.split(",", 1)
        data = base64.b64decode(b64)
    else:
        # assume raw base64
        try:
            data = base64.b64decode(data_url)
        except Exception:
            raise ValueError("Invalid image data")
    return Image.open(io.BytesIO(data)).convert("RGB")
def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
    buf = io.BytesIO()
    pil_img.save(buf, format=fmt)
    return "data:image/{};base64,".format(fmt.lower()) + base64.b64encode(buf.getvalue()).decode("ascii")
def overlay_mask_on_image(pil_img: Image.Image, detections, threshold: float = 0.25,
                          mask_color=(255, 77, 166), alpha=0.45):
    """
    Create an annotated PIL image by overlaying per-instance masks (pink) and polygon borders,
    and add a confidence label (best confidence) to the image.
    Expects supervision-style detections (``mask``/``confidence`` attributes), with a fallback
    to ``masks``/``scores`` for other return types.
    Returns (annotated_pil_rgb, kept_confidences_list).
    """
    base = pil_img.convert("RGBA")
    W, H = base.size

    # supervision's Detections exposes the boolean instance masks as `mask`;
    # keep `masks` as a fallback for other detection containers.
    masks = getattr(detections, "mask", None)
    if masks is None:
        masks = getattr(detections, "masks", None)

    confidences = []
    try:
        confidences = [float(x) for x in getattr(detections, "confidence", [])]
    except Exception:
        # fallback to 'scores' or empty
        try:
            confidences = [float(x) for x in getattr(detections, "scores", [])]
        except Exception:
            confidences = []

    if masks is None:
        # no masks -> return the original image and an empty list
        return pil_img.convert("RGB"), []
    # Normalize the mask array to (N, H, W)
    if isinstance(masks, list):
        masks_arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
    else:
        masks_arr = np.asarray(masks)
    # some outputs might be (H, W, N)
    if masks_arr.ndim == 3 and masks_arr.shape[0] == H and masks_arr.shape[1] == W:
        masks_arr = masks_arr.transpose(2, 0, 1)

    # overlay image we will composite
    overlay = Image.new("RGBA", (W, H), (0, 0, 0, 0))
    kept_confidences = []
    for i in range(masks_arr.shape[0]):
        conf = confidences[i] if i < len(confidences) else 1.0
        if conf < threshold:
            continue
        mask = masks_arr[i].astype(np.uint8) * 255
        mask_img = Image.fromarray(mask).convert("L")
        # if the mask size doesn't match the image, resize
        if mask_img.size != (W, H):
            mask_img = mask_img.resize((W, H), resample=Image.NEAREST)
        # color layer with per-pixel alpha taken from the mask (0..255) scaled by alpha
        color_layer = Image.new("RGBA", (W, H), mask_color + (0,))
        alpha_mask = mask_img.point(lambda p: int(p * alpha))
        color_layer.putalpha(alpha_mask)
        overlay = Image.alpha_composite(overlay, color_layer)
        kept_confidences.append(float(conf))

    # draw polygon outlines for visual crispness (optional; requires scikit-image)
    try:
        from skimage import measure

        draw = ImageDraw.Draw(overlay)
        for i in range(masks_arr.shape[0]):
            conf = confidences[i] if i < len(confidences) else 1.0
            if conf < threshold:
                continue
            mask = masks_arr[i].astype(np.uint8)
            # resize the mask for contour extraction if needed
            if mask.shape[1] != W or mask.shape[0] != H:
                mask_pil = Image.fromarray((mask * 255).astype(np.uint8)).resize((W, H), resample=Image.NEAREST)
                mask = np.asarray(mask_pil).astype(np.uint8) // 255
            contours = measure.find_contours(mask, 0.5)
            for contour in contours:
                # contour is an array of (row, col) -> convert to (x, y)
                pts = [(float(c[1]), float(c[0])) for c in contour]
                if len(pts) >= 3:
                    # draw white outline
                    draw.line(pts + [pts[0]], fill=(255, 255, 255, 255), width=2)
    except Exception:
        # ignore if scikit-image is not available; outlines are optional
        pass
    annotated = Image.alpha_composite(base, overlay).convert("RGBA")

    # annotate best confidence text (top-left)
    if kept_confidences:
        best = max(kept_confidences)
        draw = ImageDraw.Draw(annotated)
        try:
            font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(14, W // 32))
        except Exception:
            font = ImageFont.load_default()
        text = f"Confidence: {best:.2f}"
        # ImageDraw.textsize() was removed in Pillow 10; measure via textbbox instead.
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        tw, th = right - left, bottom - top
        pad = 6
        rect = [6, 6, 6 + tw + pad, 6 + th + pad]
        draw.rectangle(rect, fill=(0, 0, 0, 180))
        draw.text((6 + pad // 2, 6 + pad // 2), text, font=font, fill=(255, 255, 255, 255))
    return annotated.convert("RGB"), kept_confidences
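# Minimal usage sketch for overlay_mask_on_image (assumes the supervision-style
# layout described above: boolean masks shaped (N, H, W) plus a parallel
# confidence array; the dummy class below is purely illustrative):
#
#   class _FakeDetections:
#       mask = np.zeros((1, 480, 640), dtype=bool)        # one all-background instance mask
#       confidence = np.array([0.90], dtype=np.float32)
#
#   img = Image.new("RGB", (640, 480), (20, 120, 20))
#   annotated, kept = overlay_mask_on_image(img, _FakeDetections(), threshold=0.25)
#   # annotated is an RGB PIL image; kept == [0.9] because the instance passed the threshold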
@app.route("/")
def index():
    # serve the static UI file if present
    index_path = os.path.join(app.static_folder or "static", "index.html")
    if os.path.exists(index_path):
        return send_from_directory(app.static_folder, "index.html")
    return jsonify({"message": "RF-DETR Segmentation API is running."})
@app.route("/predict", methods=["POST"])
def predict():
    """
    Accepts:
      - multipart/form-data with file field "file"
      - or JSON {"image": "<data:url...>", "conf": 0.25}
    Returns JSON:
      {"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}
    """
    try:
        model = init_model()
    except Exception as e:
        return jsonify({"error": f"Model initialization failed: {e}"}), 500

    # parse input
    img: Optional[Image.Image] = None
    conf_threshold = 0.25

    if "file" in request.files:
        # multipart/form-data upload
        file = request.files["file"]
        try:
            img = Image.open(file.stream).convert("RGB")
        except Exception as e:
            return jsonify({"error": f"Invalid uploaded image: {e}"}), 400
        conf_threshold = float(request.form.get("conf", conf_threshold))
    else:
        # try JSON payload
        payload = request.get_json(silent=True)
        if not payload or "image" not in payload:
            return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
        try:
            img = decode_data_url(payload["image"])
        except Exception as e:
            return jsonify({"error": f"Invalid image data: {e}"}), 400
        conf_threshold = float(payload.get("conf", conf_threshold))

    # run inference
    try:
        # use threshold=0.0 in model.predict since we filter by conf_threshold ourselves
        detections = model.predict(img, threshold=0.0)
    except Exception as e:
        traceback.print_exc()
        return jsonify({"error": f"Inference failed: {e}"}), 500

    # overlay masks and keep only confidences above the threshold
    annotated_pil, kept_conf = overlay_mask_on_image(img, detections, threshold=conf_threshold,
                                                     mask_color=(255, 77, 166), alpha=0.45)
    data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
    return jsonify({"annotated": data_url, "confidences": kept_conf, "count": len(kept_conf)})
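# Example client calls (illustrative sketch; host, port, and file names are
# assumptions, but the field names match the /predict handler above):
#
#   # multipart upload
#   #   curl -X POST -F "file=@leaf.jpg" -F "conf=0.25" http://localhost:7860/predict
#
#   # JSON with a base64 data URL
#   import base64, requests
#   with open("leaf.jpg", "rb") as fh:
#       b64 = base64.b64encode(fh.read()).decode("ascii")
#   resp = requests.post("http://localhost:7860/predict",
#                        json={"image": "data:image/jpeg;base64," + b64, "conf": 0.25})
#   print(resp.json()["count"], resp.json()["confidences"])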
if __name__ == "__main__":
    # Warm the model in a background thread so container startup is not blocked
    def warm():
        try:
            init_model()
        except Exception as e:
            print("Model warmup failed:", e)

    threading.Thread(target=warm, daemon=True).start()
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
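# For a production deployment (e.g. on Hugging Face Spaces) a WSGI server can be
# used instead of Flask's built-in server; one possible invocation, assuming this
# file is saved as app.py, is:
#
#   gunicorn -w 1 -b 0.0.0.0:7860 app:app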