Subh775's picture
Update app.py
c8d1052 verified
raw
history blame
18 kB
# import os
# import io
# import base64
# import tempfile
# import threading
# from PIL import Image, ImageDraw, ImageFont
# import numpy as np
# from flask import Flask, request, jsonify, send_from_directory
# import requests
# # Force CPU-only (prevents accidental GPU usage); works by hiding CUDA devices
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
# # --- model import (ensure rfdetr package is available in requirements) ---
# try:
# from rfdetr import RFDETRSegPreview
# except Exception as e:
# raise RuntimeError("rfdetr package import failed. Make sure `rfdetr` is in requirements.") from e
# app = Flask(__name__, static_folder="static", static_url_path="/")
# # HF checkpoint raw resolve URL (use the 'resolve/main' raw link)
# CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
# CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
# MODEL_LOCK = threading.Lock()
# MODEL = None
# def download_file(url: str, dst: str):
# if os.path.exists(dst):
# return dst
# print(f"[INFO] Downloading weights from {url} ...")
# r = requests.get(url, stream=True, timeout=60)
# r.raise_for_status()
# with open(dst, "wb") as fh:
# for chunk in r.iter_content(chunk_size=8192):
# if chunk:
# fh.write(chunk)
# print("[INFO] Download complete.")
# return dst
# def init_model():
# global MODEL
# with MODEL_LOCK:
# if MODEL is None:
# # Ensure model checkpoint
# try:
# download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
# except Exception as e:
# print(f"[WARN] Failed to download checkpoint: {e}. Attempting to init model without weights.")
# # continue; model may fallback to default weights
# print("[INFO] Loading RF-DETR model (CPU mode)...")
# MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH if os.path.exists(CHECKPOINT_PATH) else None)
# try:
# MODEL.optimize_for_inference()
# except Exception:
# # optimization may fail on CPU or if not implemented; ignore
# pass
# print("[INFO] Model ready.")
# return MODEL
# @app.route("/")
# def index():
# return send_from_directory("static", "index.html")
# def decode_data_url(data_url: str) -> Image.Image:
# if data_url.startswith("data:"):
# header, b64 = data_url.split(",", 1)
# data = base64.b64decode(b64)
# return Image.open(io.BytesIO(data)).convert("RGB")
# else:
# # assume plain base64 or path
# data = base64.b64decode(data_url)
# return Image.open(io.BytesIO(data)).convert("RGB")
# def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG"):
# buf = io.BytesIO()
# pil_img.save(buf, format=fmt)
# b = base64.b64encode(buf.getvalue()).decode("ascii")
# return f"data:image/{fmt.lower()};base64,{b}"
# def overlay_mask_on_image(pil_img: Image.Image, masks, confidences, threshold=0.01, mask_color=(255,77,166), alpha=0.45):
# """
# masks: either list of HxW bool arrays or numpy array (N,H,W)
# confidences: list of floats
# Returns annotated PIL image and list of kept confidences and count.
# """
# base = pil_img.convert("RGBA")
# W, H = base.size
# # Normalize masks to N,H,W
# if masks is None:
# return base, []
# if isinstance(masks, list):
# masks_arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
# else:
# masks_arr = np.asarray(masks)
# # masks might be (H,W,N) -> transpose
# if masks_arr.ndim == 3 and masks_arr.shape[0] == H and masks_arr.shape[1] == W:
# masks_arr = masks_arr.transpose(2, 0, 1)
# # create overlay
# overlay = Image.new("RGBA", (W, H), (0,0,0,0))
# draw = ImageDraw.Draw(overlay)
# kept_confidences = []
# for i in range(masks_arr.shape[0]):
# conf = float(confidences[i]) if confidences is not None and i < len(confidences) else 1.0
# if conf < threshold:
# continue
# mask = masks_arr[i].astype(np.uint8) * 255
# mask_img = Image.fromarray(mask).convert("L").resize((W, H), resample=Image.NEAREST)
# # create colored mask image
# color_layer = Image.new("RGBA", (W,H), mask_color + (0,))
# # put alpha using mask
# color_layer.putalpha(mask_img.point(lambda p: int(p * alpha)))
# overlay = Image.alpha_composite(overlay, color_layer)
# kept_confidences.append(conf)
# # composite
# annotated = Image.alpha_composite(base, overlay)
# # add confidence text (show highest kept confidence)
# if len(kept_confidences) > 0:
# best = max(kept_confidences)
# draw = ImageDraw.Draw(annotated)
# try:
# # Try to use a builtin font
# font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(16, W//30))
# except Exception:
# font = ImageFont.load_default()
# text = f"Confidence: {best:.2f}"
# # draw background box for text
# tw, th = draw.textsize(text, font=font)
# pad = 8
# draw.rectangle([6,6, 6+tw+pad, 6+th+pad], fill=(0,0,0,180))
# draw.text((6+4,6+2), text, font=font, fill=(255,255,255,255))
# return annotated.convert("RGB"), kept_confidences
# @app.route("/predict", methods=["POST"])
# def predict():
# payload = request.get_json(force=True)
# if not payload or "image" not in payload:
# return jsonify({"error": "Missing image"}), 400
# conf = float(payload.get("conf", 0.25))
# # ensure model ready
# model = init_model()
# # decode image
# try:
# pil = decode_data_url(payload["image"])
# except Exception as e:
# return jsonify({"error": f"Invalid image: {e}"}), 400
# # perform prediction (model.predict expects PIL image)
# try:
# detections = model.predict(pil, threshold=0.0) # we filter using conf manually
# except Exception as e:
# return jsonify({"error": f"Inference failure: {e}"}), 500
# # extract masks and confidences
# masks = getattr(detections, "masks", None)
# confidences = []
# # attempt to read per-instance confidence
# try:
# confidences = [float(x) for x in getattr(detections, "confidence", [])]
# except Exception:
# # fallback: attempt attribute 'scores' or 'scores_' or generate ones
# confidences = []
# try:
# confidences = [float(x) for x in getattr(detections, "scores", [])]
# except Exception:
# confidences = [1.0] * (masks.shape[0] if masks is not None and hasattr(masks, "shape") and masks.shape[0] else 0)
# # overlay mask with pink-red color
# mask_color = (255, 77, 166) # pinkish
# annotated_pil, kept_conf = overlay_mask_on_image(pil, masks, confidences, threshold=conf, mask_color=mask_color, alpha=0.45)
# data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
# return jsonify({
# "annotated": data_url,
# "confidences": kept_conf,
# "count": len(kept_conf)
# })
# if __name__ == "__main__":
# # warm up model on startup (non-blocking)
# try:
# init_model()
# except Exception as e:
# print("Model init warning:", e)
# app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
import os
import io
import base64
import threading
import tempfile
import traceback
from typing import Optional
from flask import Flask, request, jsonify, send_from_directory, send_file
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import requests
# Set writable cache dirs to avoid matplotlib/fontconfig warnings in containers.
# setdefault() only fills in a value when the variable is not already set, so an
# explicit setting from the deployment environment always wins.
os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
# Default to CPU-only by hiding CUDA devices. NOTE: because this is setdefault(),
# a pre-existing CUDA_VISIBLE_DEVICES value is left untouched.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")

# --- Imports that may trigger the above warnings ---
# These are deliberately imported after the environment variables are set.
try:
    import supervision as sv
    from rfdetr import RFDETRSegPreview
except Exception as e:
    # Provide a clearer error at startup if imports fail; chain the original
    # exception explicitly so the traceback shows it as the direct cause.
    raise RuntimeError(f"Required library import failed: {e}") from e

app = Flask(__name__, static_folder="static", static_url_path="/")

# Checkpoint URL & local path
CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")

# MODEL is created lazily by init_model(); MODEL_LOCK serializes that creation.
MODEL_LOCK = threading.Lock()
MODEL = None
def download_file(url: str, dst: str, chunk_size: int = 8192):
    """Download *url* to *dst* unless a non-empty file already exists there.

    The body is streamed into a temporary file created next to *dst* and is
    moved into place with ``os.replace`` only after every chunk was written.
    An interrupted or failed download therefore never leaves a truncated
    file at *dst* that a later call would mistake for a complete one.

    Args:
        url: HTTP(S) location of the file.
        dst: Local destination path.
        chunk_size: Number of bytes requested per streamed chunk.

    Returns:
        *dst*.

    Raises:
        Whatever ``requests`` raises for connection/HTTP errors (including the
        error from ``raise_for_status()``), or ``OSError`` for local file
        problems. In both cases the temporary file is removed.
    """
    if os.path.exists(dst) and os.path.getsize(dst) > 0:
        return dst
    print(f"[INFO] Downloading weights from {url} -> {dst}")
    # The context manager releases the streamed connection even on errors.
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        # Same directory as dst so that os.replace() stays on one filesystem.
        fd, tmp_path = tempfile.mkstemp(
            dir=os.path.dirname(dst) or ".",
            prefix=os.path.basename(dst) + ".",
            suffix=".part",
        )
        try:
            with os.fdopen(fd, "wb") as fh:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:
                        fh.write(chunk)
            os.replace(tmp_path, dst)
        except BaseException:
            # Discard the partial download, then let the caller see the error.
            try:
                os.remove(tmp_path)
            except OSError:
                # Nothing left to clean up (already moved or never created).
                pass
            raise
    print("[INFO] Download complete.")
    return dst
def init_model():
    """
    Return the shared RF-DETR model, building it on first use.

    The whole check-and-create sequence runs while holding MODEL_LOCK, so
    concurrent callers end up with the same instance in the global MODEL.
    A failed checkpoint download or a failed optimize_for_inference() is only
    logged; any other failure is printed with its traceback and re-raised.
    """
    global MODEL
    with MODEL_LOCK:
        if MODEL is None:
            try:
                # Fetching the checkpoint is best-effort: on failure we fall
                # through and load whatever is (or is not) on disk.
                try:
                    download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
                except Exception as download_err:
                    print("[WARN] Failed to download checkpoint:", download_err)

                print("[INFO] Loading RF-DETR model (CPU mode)...")
                if os.path.exists(CHECKPOINT_PATH):
                    weights = CHECKPOINT_PATH
                else:
                    weights = None
                MODEL = RFDETRSegPreview(pretrain_weights=weights)

                # Optimization is optional; the model is used either way.
                try:
                    MODEL.optimize_for_inference()
                except Exception as optimize_err:
                    print("[WARN] optimize_for_inference() skipped/failed:", optimize_err)

                print("[INFO] Model ready.")
            except Exception:
                traceback.print_exc()
                raise
        return MODEL
def decode_data_url(data_url: str) -> Image.Image:
    """
    Decode a data URL (data:image/png;base64,...) or raw base64 text into a PIL image.

    Args:
        data_url: Either a ``data:`` URL whose payload follows the first comma,
            or a plain base64 string.

    Returns:
        The decoded image converted to RGB.

    Raises:
        ValueError: if *data_url* is not a string, if a ``data:`` URL has no
            comma separating header and payload, or if the base64 text cannot
            be decoded (``binascii.Error`` is a ``ValueError`` subclass).
        Errors raised by ``Image.open`` for bytes that are not a readable image
        propagate unchanged.
    """
    if not isinstance(data_url, str):
        raise ValueError(
            f"Invalid image data: expected a string, got {type(data_url).__name__}"
        )
    if data_url.startswith("data:"):
        # partition() never raises; a missing separator is reported explicitly
        # instead of surfacing as an obscure tuple-unpacking error.
        _, sep, b64 = data_url.partition(",")
        if not sep:
            raise ValueError("Invalid data URL: missing ',' before the base64 payload")
        data = base64.b64decode(b64)
    else:
        # assume raw base64
        try:
            data = base64.b64decode(data_url)
        except ValueError as exc:
            raise ValueError("Invalid image data") from exc
    return Image.open(io.BytesIO(data)).convert("RGB")
def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
    """
    Serialize *pil_img* in format *fmt* and wrap it as a base64 data URL.

    The MIME subtype is the lower-cased *fmt* (e.g. "PNG" -> "image/png").
    """
    stream = io.BytesIO()
    pil_img.save(stream, format=fmt)
    payload = base64.b64encode(stream.getvalue()).decode("ascii")
    prefix = "data:image/{};base64,".format(fmt.lower())
    return prefix + payload
def _text_size(draw, text, font):
    """
    Return (width, height) of the box needed to enclose *text* drawn at an origin.

    ImageDraw.textsize() was removed in Pillow 10, so textbbox() is preferred
    when available; textsize() remains as a fallback for old Pillow versions.
    """
    if hasattr(draw, "textbbox"):
        # Right/bottom of the bbox for an origin of (0, 0) include the font's
        # own offset, so a box of this size fully covers the rendered text.
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom
    return draw.textsize(text, font=font)


def overlay_mask_on_image(pil_img: Image.Image, detections, threshold: float = 0.25,
                          mask_color=(255, 77, 166), alpha=0.45):
    """
    Overlay the instance masks of *detections* on *pil_img* and label the result.

    Every instance whose confidence is at least *threshold* is painted in
    *mask_color* with opacity *alpha*; if scikit-image is importable a white
    contour is drawn around it as well. The highest kept confidence is written
    in the top-left corner.

    Args:
        pil_img: Source image.
        detections: Object carrying masks in a ``mask`` attribute (the name
            used by supervision's ``Detections``) or, as a fallback, ``masks``;
            confidences are read from ``confidence`` or ``scores``.
            Masks may be a list of 2-D arrays, a single 2-D array, or a 3-D
            array laid out as (N, H, W) or (H, W, N).
        threshold: Minimum confidence for an instance to be drawn.
        mask_color: RGB colour of the mask fill.
        alpha: Fill opacity in the range 0..1.

    Returns:
        (annotated_pil_rgb, kept_confidences_list). When no masks are present
        the image is returned unannotated together with an empty list.
    """
    base = pil_img.convert("RGBA")
    W, H = base.size

    # Prefer `mask` and keep the original `masks` lookup as a fallback
    # for other result objects.
    masks = getattr(detections, "mask", None)
    if masks is None:
        masks = getattr(detections, "masks", None)

    try:
        confidences = [float(x) for x in getattr(detections, "confidence", [])]
    except (TypeError, ValueError):
        # e.g. confidence is None or holds non-numeric entries -> try 'scores'
        try:
            confidences = [float(x) for x in getattr(detections, "scores", [])]
        except (TypeError, ValueError):
            confidences = []

    if masks is None or (isinstance(masks, list) and not masks):
        # no masks -> return original image and empty list
        return pil_img.convert("RGB"), []

    # Normalize mask array to (N, H, W)
    if isinstance(masks, list):
        masks_arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
    else:
        masks_arr = np.asarray(masks)
    if masks_arr.ndim == 2:
        # a single (H, W) mask: treat it as one instance instead of H rows
        masks_arr = masks_arr[np.newaxis, ...]
    # some outputs might be (H, W, N)
    if masks_arr.ndim == 3 and masks_arr.shape[0] == H and masks_arr.shape[1] == W:
        masks_arr = masks_arr.transpose(2, 0, 1)

    fill_rgb = tuple(mask_color)
    overlay = Image.new("RGBA", (W, H), (0, 0, 0, 0))
    kept_confidences = []
    kept_masks = []  # image-sized "L" masks of the kept instances, reused for outlines
    for i, raw_mask in enumerate(masks_arr):
        # instances without a confidence entry are treated as fully confident
        conf = confidences[i] if i < len(confidences) else 1.0
        if conf < threshold:
            continue
        mask_img = Image.fromarray(raw_mask.astype(np.uint8) * 255).convert("L")
        # if mask size doesn't match, resize
        if mask_img.size != (W, H):
            mask_img = mask_img.resize((W, H), resample=Image.NEAREST)
        # colour layer whose per-pixel alpha is the mask (0..255) scaled by alpha
        color_layer = Image.new("RGBA", (W, H), fill_rgb + (0,))
        color_layer.putalpha(mask_img.point(lambda p: int(p * alpha)))
        overlay = Image.alpha_composite(overlay, color_layer)
        kept_confidences.append(float(conf))
        kept_masks.append(mask_img)

    # Outlines are optional: they need scikit-image, and a failure while
    # tracing contours must not prevent the filled masks from being returned.
    try:
        from skimage import measure
    except ImportError:
        measure = None
    if measure is not None and kept_masks:
        try:
            draw = ImageDraw.Draw(overlay)
            for mask_img in kept_masks:
                binary = np.asarray(mask_img).astype(np.uint8) // 255
                for contour in measure.find_contours(binary, 0.5):
                    # contour is list of (row, col) -> convert to (x, y)
                    pts = [(float(c[1]), float(c[0])) for c in contour]
                    if len(pts) >= 3:
                        # draw white outline, closing the polygon
                        draw.line(pts + [pts[0]], fill=(255, 255, 255, 255), width=2)
        except Exception as exc:
            print("[WARN] Mask outline drawing failed:", exc)

    annotated = Image.alpha_composite(base, overlay).convert("RGBA")

    # annotate best confidence text (top-left)
    if kept_confidences:
        best = max(kept_confidences)
        draw = ImageDraw.Draw(annotated)
        try:
            font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(14, W // 32))
        except Exception:
            # font file not available -> Pillow's built-in bitmap font
            font = ImageFont.load_default()
        text = f"Confidence: {best:.2f}"
        tw, th = _text_size(draw, text, font)
        pad = 6
        rect = [6, 6, 6 + tw + pad, 6 + th + pad]
        draw.rectangle(rect, fill=(0, 0, 0, 180))
        draw.text((6 + pad // 2, 6 + pad // 2), text, font=font, fill=(255, 255, 255, 255))
    return annotated.convert("RGB"), kept_confidences
@app.route("/", methods=["GET"])
def index():
    """Serve static/index.html when it exists, otherwise a small JSON status message."""
    static_dir = app.static_folder or "static"
    ui_file = os.path.join(static_dir, "index.html")
    if not os.path.exists(ui_file):
        # no UI bundled -> just report that the API is alive
        return jsonify({"message": "RF-DETR Segmentation API is running."})
    return send_from_directory(app.static_folder, "index.html")
def _parse_conf(raw, default: float) -> float:
    """
    Turn a client-supplied confidence threshold into a finite float.

    A missing or empty value yields *default*. Raises ValueError with a
    client-presentable message when the value is not a finite number.
    """
    if raw is None or raw == "":
        return default
    try:
        value = float(raw)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"Invalid 'conf' value: {raw!r}") from exc
    # NaN compares unequal to itself; NaN/inf would make the filtering meaningless
    if value != value or value in (float("inf"), float("-inf")):
        raise ValueError(f"Invalid 'conf' value: {raw!r}")
    return value


@app.route("/predict", methods=["POST"])
def predict():
    """
    Run segmentation on an uploaded image and return the annotated result.

    Accepts:
    - multipart/form-data with file field "file" (optional form field "conf")
    - or JSON {"image": "<data:url...>", "conf": 0.25}
    Returns JSON:
    {"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}

    Errors are always reported as JSON {"error": "..."}: status 400 for bad
    input (missing/undecodable image, invalid "conf"), status 500 for model
    initialization, inference or annotation failures.
    """
    try:
        model = init_model()
    except Exception as e:
        return jsonify({"error": f"Model initialization failed: {e}"}), 500

    default_conf = 0.25

    # parse input
    if "file" in request.files:
        # multipart upload
        try:
            img = Image.open(request.files["file"].stream).convert("RGB")
        except Exception as e:
            return jsonify({"error": f"Invalid uploaded image: {e}"}), 400
        raw_conf = request.form.get("conf")
    else:
        # try JSON payload; it must be an object to carry "image"/"conf"
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict) or "image" not in payload:
            return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
        try:
            img = decode_data_url(payload["image"])
        except Exception as e:
            return jsonify({"error": f"Invalid image data: {e}"}), 400
        raw_conf = payload.get("conf")

    # a malformed threshold is a client error, not a server crash
    try:
        conf_threshold = _parse_conf(raw_conf, default_conf)
    except ValueError as e:
        return jsonify({"error": str(e)}), 400

    # run inference
    try:
        # set threshold=0.0 in model predict since we'll manually filter by conf_threshold
        detections = model.predict(img, threshold=0.0)
    except Exception as e:
        traceback.print_exc()
        return jsonify({"error": f"Inference failed: {e}"}), 500

    # overlay masks and extract confidences >= threshold; failures here are
    # reported as JSON like every other error of this endpoint
    try:
        annotated_pil, kept_conf = overlay_mask_on_image(img, detections, threshold=conf_threshold,
                                                         mask_color=(255, 77, 166), alpha=0.45)
        data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
    except Exception as e:
        traceback.print_exc()
        return jsonify({"error": f"Annotation failed: {e}"}), 500
    return jsonify({"annotated": data_url, "confidences": kept_conf, "count": len(kept_conf)})
if __name__ == "__main__":
    def warm():
        """Load the model once up front; a failure is only reported, never fatal."""
        try:
            init_model()
        except Exception as warm_err:
            print("Model warmup failed:", warm_err)

    # Warm the model on a daemon thread so the web server can start right away
    # instead of waiting for the model to finish loading.
    warm_thread = threading.Thread(target=warm, daemon=True)
    warm_thread.start()

    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port, debug=False)