Subh775 committed on
Commit
c8d1052
·
verified ·
1 Parent(s): ecee7e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +253 -84
app.py CHANGED
@@ -193,110 +193,279 @@
193
  # print("Model init warning:", e)
194
  # app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
195
 
196
-
197
  import os
198
  import io
 
 
 
 
 
 
 
 
199
  import numpy as np
200
- from PIL import Image
201
  import requests
202
- import supervision as sv
203
- from flask import Flask, request, jsonify, send_file
204
- from rfdetr import RFDETRSegPreview
205
 
206
- app = Flask(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
- # ---- CONFIG ----
209
- WEIGHTS_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
210
- WEIGHTS_PATH = "/tmp/checkpoint_best_total.pth"
211
 
212
- # ---- HELPERS ----
213
- def download_file(url: str, dst: str):
214
- """Download model weights if not already cached."""
215
- if os.path.exists(dst):
216
- print(f"[INFO] Weights already exist at {dst}")
217
  return dst
218
- print(f"[INFO] Downloading weights from {url} ...")
219
- r = requests.get(url, stream=True)
220
  r.raise_for_status()
221
- with open(dst, "wb") as f:
222
- for chunk in r.iter_content(chunk_size=8192):
223
- f.write(chunk)
 
224
  print("[INFO] Download complete.")
225
  return dst
226
 
227
 
228
- def annotate_segmentation(image: Image.Image, detections: sv.Detections):
229
- """Overlay colored masks and confidence scores."""
230
- palette = sv.ColorPalette.from_hex([
231
- "#ff9b00", "#ff8080", "#ff66b2", "#b266ff",
232
- "#9999ff", "#3399ff", "#33ff99", "#99ff00"
233
- ])
234
- text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
235
-
236
- mask_annotator = sv.MaskAnnotator(color=palette)
237
- polygon_annotator = sv.PolygonAnnotator(color=sv.Color.WHITE)
238
- label_annotator = sv.LabelAnnotator(
239
- color=palette,
240
- text_color=sv.Color.BLACK,
241
- text_scale=text_scale,
242
- text_position=sv.Position.CENTER_OF_MASS
243
- )
244
-
245
- # Only show confidence (no class id)
246
- labels = [f"{conf:.2f}" for conf in detections.confidence]
247
-
248
- annotated = image.copy()
249
- annotated = mask_annotator.annotate(annotated, detections)
250
- annotated = polygon_annotator.annotate(annotated, detections)
251
- annotated = label_annotator.annotate(annotated, detections, labels)
252
- return annotated
253
-
254
-
255
- # ---- MODEL INITIALIZATION ----
256
- print("[INFO] Loading RF-DETR model (CPU mode)...")
257
- download_file(WEIGHTS_URL, WEIGHTS_PATH)
258
- model = RFDETRSegPreview(pretrain_weights=WEIGHTS_PATH)
259
- try:
260
- model.optimize_for_inference()
261
- except Exception as e:
262
- print(f"[WARN] optimize_for_inference() skipped: {e}")
263
- print("[INFO] Model ready.")
264
-
265
-
266
- # ---- ROUTES ----
267
- @app.route("/")
268
- def home():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  return jsonify({"message": "RF-DETR Segmentation API is running."})
270
 
271
 
272
  @app.route("/predict", methods=["POST"])
273
  def predict():
274
- """Accepts an image file and returns annotated segmentation overlay."""
275
- if "file" not in request.files:
276
- return jsonify({"error": "No file uploaded"}), 400
277
-
278
- file = request.files["file"]
279
- image = Image.open(file.stream).convert("RGB")
280
- print(f"[INFO] Image received for inference: {file.filename}")
281
-
282
- detections = model.predict(image, threshold=0.3)
283
- print(f"[INFO] Detections found: {len(getattr(detections, 'boxes', []))}")
284
-
285
- annotated = annotate_segmentation(image, detections)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
- buf = io.BytesIO()
288
- annotated.save(buf, format="PNG")
289
- buf.seek(0)
290
- return send_file(buf, mimetype="image/png")
291
 
 
 
292
 
293
- # if __name__ == "__main__":
294
- # app.run(host="0.0.0.0", port=7860)
295
 
296
  if __name__ == "__main__":
297
- # warm up model on startup (non-blocking)
298
- try:
299
- init_model()
300
- except Exception as e:
301
- print("Model init warning:", e)
302
- app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
 
 
 
193
  # print("Model init warning:", e)
194
  # app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
195
 
 
196
  import os
197
  import io
198
+ import base64
199
+ import threading
200
+ import tempfile
201
+ import traceback
202
+ from typing import Optional
203
+
204
+ from flask import Flask, request, jsonify, send_from_directory, send_file
205
+ from PIL import Image, ImageDraw, ImageFont
206
  import numpy as np
 
207
  import requests
 
 
 
208
 
209
# Point matplotlib and fontconfig at locations under /tmp so they do not warn
# about unwritable cache/config directories when running inside a container.
# setdefault() keeps any value the deployment has already exported.
os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
# Hide CUDA devices (unless the deployment overrides this) so inference stays on CPU.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")

# These imports are deliberately placed after the environment defaults above,
# because importing them may trigger the warnings those defaults avoid.
try:
    import supervision as sv
    from rfdetr import RFDETRSegPreview
except Exception as e:
    # Fail fast with a clearer message, keeping the original error as the
    # explicit cause so the traceback shows what actually went wrong.
    raise RuntimeError(f"Required library import failed: {e}") from e

app = Flask(__name__, static_folder="static", static_url_path="/")

# Remote location of the fine-tuned checkpoint and where it is cached locally.
CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")

# MODEL is created lazily by init_model(); MODEL_LOCK serialises that creation.
MODEL_LOCK = threading.Lock()
MODEL = None
231
 
 
 
 
232
 
233
def download_file(url: str, dst: str, chunk_size: int = 8192):
    """Download ``url`` to ``dst`` unless a non-empty file is already there.

    The body is streamed into a temporary sibling file and moved into place
    only after the transfer has finished. An interrupted download therefore
    never leaves a truncated file at ``dst`` that a later call would mistake
    for a valid cached copy.

    Args:
        url: Location to fetch.
        dst: Destination path on the local filesystem.
        chunk_size: Number of bytes requested per streamed chunk.

    Returns:
        ``dst``.

    Raises:
        requests.RequestException: On connection errors or a non-2xx status.
        OSError: If the file cannot be written or moved into place.
    """
    if os.path.exists(dst) and os.path.getsize(dst) > 0:
        return dst

    print(f"[INFO] Downloading weights from {url} -> {dst}")
    # The PID keeps separate worker processes from sharing one partial file.
    tmp_path = f"{dst}.{os.getpid()}.part"
    try:
        # Using the response as a context manager releases the connection
        # even when the transfer fails midway.
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as fh:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:
                        fh.write(chunk)
        # Atomic within one filesystem: dst is either absent or complete.
        os.replace(tmp_path, dst)
    finally:
        # After a successful replace the temp file no longer exists; after a
        # failure this removes the partial data.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
    print("[INFO] Download complete.")
    return dst
245
 
246
 
247
def init_model():
    """
    Return the shared RF-DETR model, building it on first use.

    Construction happens while holding MODEL_LOCK, so concurrent callers
    end up with the same cached instance. A failed checkpoint download or a
    failed optimize_for_inference() call is only logged; any other failure
    is printed with its traceback and re-raised.
    """
    global MODEL
    with MODEL_LOCK:
        if MODEL is None:
            try:
                # Fetching the checkpoint is best-effort: on failure we fall
                # through and load whatever is (or is not) on disk.
                try:
                    download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
                except Exception as download_err:
                    print("[WARN] Failed to download checkpoint:", download_err)

                print("[INFO] Loading RF-DETR model (CPU mode)...")
                weights = CHECKPOINT_PATH if os.path.exists(CHECKPOINT_PATH) else None
                MODEL = RFDETRSegPreview(pretrain_weights=weights)

                # Optimisation is optional; the model is kept either way.
                try:
                    MODEL.optimize_for_inference()
                except Exception as optimize_err:
                    print("[WARN] optimize_for_inference() skipped/failed:", optimize_err)
                print("[INFO] Model ready.")
            except Exception:
                traceback.print_exc()
                raise
        return MODEL
274
+
275
+
276
def decode_data_url(data_url: str) -> Image.Image:
    """
    Decode a base64 image payload into an RGB PIL image.

    Accepts either a data URL (``data:image/png;base64,...``) or a bare
    base64 string.

    Raises:
        ValueError: If the input is not a string, a data URL has no ``,``
            separator, or the payload is not valid base64.
        OSError: If the decoded bytes are not an image PIL can open
            (raised by ``Image.open``).
    """
    if not isinstance(data_url, str):
        raise ValueError("Invalid image data")

    if data_url.startswith("data:"):
        # Everything after the first comma is the payload; a data URL with
        # no comma is malformed.
        _, sep, b64 = data_url.partition(",")
        if not sep:
            raise ValueError("Invalid image data")
    else:
        # assume raw base64
        b64 = data_url

    try:
        data = base64.b64decode(b64)
    except (ValueError, TypeError) as err:
        # binascii.Error is a ValueError subclass; report one uniform message.
        raise ValueError("Invalid image data") from err
    return Image.open(io.BytesIO(data)).convert("RGB")
290
+
291
+
292
def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
    """Serialise ``pil_img`` in format ``fmt`` and return it as a base64 data URL.

    The MIME subtype in the URL is ``fmt`` lower-cased.
    """
    buffer = io.BytesIO()
    pil_img.save(buffer, format=fmt)
    payload = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/{fmt.lower()};base64,{payload}"
296
+
297
+
298
def _text_extent(draw, text, font):
    """Return (width, height) needed to draw ``text`` at the origin with ``font``."""
    try:
        # textbbox replaces textsize, which newer Pillow releases removed.
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom
    except AttributeError:
        # Older Pillow without textbbox (or without bbox support on the font).
        return draw.textsize(text, font=font)


def overlay_mask_on_image(pil_img: Image.Image, detections, threshold: float = 0.25,
                          mask_color=(255, 77, 166), alpha=0.45):
    """
    Overlay translucent instance masks, white outlines and a confidence label.

    Args:
        pil_img: Source image.
        detections: Object carrying instance masks and per-instance scores.
            Masks are read from ``detections.mask`` (the field name used by
            supervision's ``Detections``) with ``detections.masks`` as a
            fallback. Scores are read from ``confidence``, falling back to
            ``scores``.
        threshold: Instances scoring below this value are skipped.
        mask_color: RGB colour used to fill kept masks.
        alpha: Fill opacity in the range 0..1.

    Returns:
        ``(annotated_rgb_image, kept_confidences)``. When no masks are
        available the image is returned unannotated with an empty list.
    """
    base = pil_img.convert("RGBA")
    W, H = base.size

    masks = getattr(detections, "mask", None)
    if masks is None:
        masks = getattr(detections, "masks", None)

    try:
        confidences = [float(x) for x in getattr(detections, "confidence", [])]
    except Exception:
        # e.g. the attribute is None; fall back to 'scores' or empty
        try:
            confidences = [float(x) for x in getattr(detections, "scores", [])]
        except Exception:
            confidences = []

    # No masks (missing, or an empty list np.stack cannot handle) -> nothing to draw.
    if masks is None or (isinstance(masks, list) and not masks):
        return pil_img.convert("RGB"), []

    # Normalize mask array to (N, h, w)
    if isinstance(masks, list):
        masks_arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
    else:
        masks_arr = np.asarray(masks)
    if masks_arr.ndim == 2:
        # A single (h, w) mask: treat it as one instance, not as h rows.
        masks_arr = masks_arr[np.newaxis, ...]
    elif (masks_arr.ndim == 3 and masks_arr.shape[1:] != (H, W)
          and masks_arr.shape[:2] == (H, W)):
        # Layout is (H, W, N); prefer the (N, H, W) reading when both fit.
        masks_arr = masks_arr.transpose(2, 0, 1)

    overlay = Image.new("RGBA", (W, H), (0, 0, 0, 0))
    fill_rgba = tuple(mask_color[:3]) + (0,)
    kept_confidences = []
    kept_masks = []  # image-sized "L" masks (0/255) of the instances that were drawn

    for i in range(masks_arr.shape[0]):
        conf = confidences[i] if i < len(confidences) else 1.0
        if conf < threshold:
            continue
        # "> 0" after the uint8 cast keeps bool/0-1 masks unchanged and
        # avoids uint8 overflow for masks already stored as 0/255.
        binary = masks_arr[i].astype(np.uint8) > 0
        mask_img = Image.fromarray(binary.astype(np.uint8) * 255).convert("L")
        if mask_img.size != (W, H):
            mask_img = mask_img.resize((W, H), resample=Image.NEAREST)

        # Solid colour layer whose per-pixel alpha is the mask scaled by alpha.
        color_layer = Image.new("RGBA", (W, H), fill_rgba)
        color_layer.putalpha(mask_img.point(lambda p: int(p * alpha)))
        overlay = Image.alpha_composite(overlay, color_layer)
        kept_confidences.append(float(conf))
        kept_masks.append(mask_img)

    # Outlines are optional: they need scikit-image, which may be absent.
    try:
        from skimage import measure
    except Exception:
        measure = None
    if measure is not None and kept_masks:
        draw = ImageDraw.Draw(overlay)
        for mask_img in kept_masks:
            try:
                contours = measure.find_contours(
                    (np.asarray(mask_img) > 0).astype(np.uint8), 0.5)
            except Exception as err:
                # Best-effort: a failed outline must not fail the request.
                print("[WARN] Mask outline skipped:", err)
                continue
            for contour in contours:
                # contour points are (row, col) -> convert to (x, y)
                pts = [(float(c[1]), float(c[0])) for c in contour]
                if len(pts) >= 3:
                    draw.line(pts + [pts[0]], fill=(255, 255, 255, 255), width=2)

    annotated = Image.alpha_composite(base, overlay)

    # Label the best kept confidence in the top-left corner.
    if kept_confidences:
        best = max(kept_confidences)
        draw = ImageDraw.Draw(annotated)
        try:
            font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(14, W // 32))
        except (OSError, ImportError):
            # Font file missing or FreeType support unavailable.
            font = ImageFont.load_default()
        text = f"Confidence: {best:.2f}"
        tw, th = _text_extent(draw, text, font)
        pad = 6
        rect = [6, 6, 6 + tw + pad, 6 + th + pad]
        draw.rectangle(rect, fill=(0, 0, 0, 180))
        draw.text((6 + pad // 2, 6 + pad // 2), text, font=font, fill=(255, 255, 255, 255))

    return annotated.convert("RGB"), kept_confidences
399
+
400
+
401
@app.route("/", methods=["GET"])
def index():
    """Serve static/index.html when it exists, otherwise a JSON status message."""
    static_dir = app.static_folder
    ui_page = os.path.join(static_dir or "static", "index.html")
    if not os.path.exists(ui_page):
        return jsonify({"message": "RF-DETR Segmentation API is running."})
    return send_from_directory(static_dir, "index.html")
408
 
409
 
410
def _parse_conf(raw, default):
    """Convert a client-supplied confidence threshold to float.

    Returns ``default`` when the value is missing or empty. Raises ValueError
    when it is not a number (NaN included, since NaN would disable filtering).
    """
    if raw is None or raw == "":
        return default
    try:
        value = float(raw)
    except (TypeError, ValueError) as err:
        raise ValueError(f"'conf' must be a number, got {raw!r}") from err
    if value != value:  # NaN is the only float unequal to itself
        raise ValueError("'conf' must be a number, got NaN")
    return value


@app.route("/predict", methods=["POST"])
def predict():
    """
    Accepts:
      - multipart/form-data with file field "file" (optional form field "conf")
      - or JSON {"image": "<data:url...>", "conf": 0.25}
    Returns JSON:
      {"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}
    Errors are returned as JSON {"error": "..."} with status 400 for bad
    input and 500 for model, inference or rendering failures.
    """
    try:
        model = init_model()
    except Exception as e:
        return jsonify({"error": f"Model initialization failed: {e}"}), 500

    conf_threshold = 0.25

    if "file" in request.files:
        # multipart upload
        file = request.files["file"]
        try:
            img = Image.open(file.stream).convert("RGB")
        except Exception as e:
            return jsonify({"error": f"Invalid uploaded image: {e}"}), 400
        raw_conf = request.form.get("conf")
    else:
        # JSON payload; must be an object carrying an "image" entry
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict) or "image" not in payload:
            return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
        try:
            img = decode_data_url(payload["image"])
        except Exception as e:
            return jsonify({"error": f"Invalid image data: {e}"}), 400
        raw_conf = payload.get("conf")

    # A malformed threshold is a client error, not a server crash.
    try:
        conf_threshold = _parse_conf(raw_conf, conf_threshold)
    except ValueError as e:
        return jsonify({"error": f"Invalid confidence threshold: {e}"}), 400

    # run inference
    try:
        # threshold=0.0 here because filtering by conf_threshold happens in
        # overlay_mask_on_image
        detections = model.predict(img, threshold=0.0)
    except Exception as e:
        traceback.print_exc()
        return jsonify({"error": f"Inference failed: {e}"}), 500

    # overlay masks, keep confidences at or above the threshold, encode result
    try:
        annotated_pil, kept_conf = overlay_mask_on_image(img, detections, threshold=conf_threshold,
                                                         mask_color=(255, 77, 166), alpha=0.45)
        data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
    except Exception as e:
        traceback.print_exc()
        return jsonify({"error": f"Annotation failed: {e}"}), 500

    return jsonify({"annotated": data_url, "confidences": kept_conf, "count": len(kept_conf)})
461
 
 
 
462
 
463
if __name__ == "__main__":
    def warm():
        """Load the model once, logging (not raising) any failure."""
        try:
            init_model()
        except Exception as e:
            print("Model warmup failed:", e)

    # Start the warm-up on a daemon thread so the server starts without
    # waiting for the model to load.
    warmup_thread = threading.Thread(target=warm, daemon=True)
    warmup_thread.start()

    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port, debug=False)