Subh775 commited on
Commit
2313bae
·
verified ·
1 Parent(s): dc804ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +348 -3
app.py CHANGED
@@ -1,3 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import io
3
  import base64
@@ -37,7 +376,7 @@ from rfdetr import RFDETRSegPreview
37
  app = Flask(__name__, static_folder="static", static_url_path="/")
38
 
39
  # Checkpoint URL & local path
40
- CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs/resolve/main/checkpoint_best_total.pth"
41
  CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
42
 
43
  MODEL_LOCK = threading.Lock()
@@ -198,7 +537,7 @@ def predict():
198
  """
199
  Accepts:
200
  - multipart/form-data with file field "file"
201
- - or JSON {"image": "<data:url...>", "conf": 0.05}
202
  Returns JSON:
203
  {"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}
204
  """
@@ -215,7 +554,9 @@ def predict():
215
 
216
  # Parse input
217
  img: Optional[Image.Image] = None
218
- conf_threshold = 0.05
 
 
219
 
220
  # Check if file uploaded
221
  if "file" in request.files:
@@ -228,6 +569,8 @@ def predict():
228
  print(f"[ERROR] {error_msg}")
229
  return jsonify({"error": error_msg}), 400
230
  conf_threshold = float(request.form.get("conf", conf_threshold))
 
 
231
  else:
232
  # Try JSON payload
233
  payload = request.get_json(silent=True)
@@ -241,6 +584,8 @@ def predict():
241
  print(f"[ERROR] {error_msg}")
242
  return jsonify({"error": error_msg}), 400
243
  conf_threshold = float(payload.get("conf", conf_threshold))
 
 
244
 
245
  print(f"[INFO] Image size: {img.size}, Confidence threshold: {conf_threshold}")
246
 
 
1
+ # import os
2
+ # import io
3
+ # import base64
4
+ # import threading
5
+ # import traceback
6
+ # import gc
7
+ # from typing import Optional
8
+
9
+ # from flask import Flask, request, jsonify, send_from_directory
10
+ # from PIL import Image
11
+ # import numpy as np
12
+ # import requests
13
+ # import torch
14
+
15
+ # # Set environment variables for CPU-only operation
16
+ # os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
17
+ # os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
18
+ # os.environ.setdefault("FONTCONFIG_FILE", "/etc/fonts/fonts.conf")
19
+ # os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
20
+ # os.environ.setdefault("OMP_NUM_THREADS", "4")
21
+ # os.environ.setdefault("MKL_NUM_THREADS", "4")
22
+ # os.environ.setdefault("OPENBLAS_NUM_THREADS", "4")
23
+
24
+ # # Create writable fontconfig cache
25
+ # os.makedirs("/tmp/.fontconfig", exist_ok=True)
26
+ # os.makedirs("/tmp/.matplotlib", exist_ok=True)
27
+
28
+ # # Limit torch threads
29
+ # try:
30
+ # torch.set_num_threads(4)
31
+ # except Exception:
32
+ # pass
33
+
34
+ # import supervision as sv
35
+ # from rfdetr import RFDETRSegPreview
36
+
37
+ # app = Flask(__name__, static_folder="static", static_url_path="/")
38
+
39
+ # # Checkpoint URL & local path
40
+ # CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs/resolve/main/checkpoint_best_total.pth"
41
+ # CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
42
+
43
+ # MODEL_LOCK = threading.Lock()
44
+ # MODEL = None
45
+
46
+
47
+ # def download_file(url: str, dst: str, chunk_size: int = 8192):
48
+ # """Download file if not exists"""
49
+ # if os.path.exists(dst) and os.path.getsize(dst) > 0:
50
+ # print(f"[INFO] Checkpoint already exists at {dst}")
51
+ # return dst
52
+ # print(f"[INFO] Downloading weights from {url} -> {dst}")
53
+ # try:
54
+ # r = requests.get(url, stream=True, timeout=180)
55
+ # r.raise_for_status()
56
+ # with open(dst, "wb") as fh:
57
+ # for chunk in r.iter_content(chunk_size=chunk_size):
58
+ # if chunk:
59
+ # fh.write(chunk)
60
+ # print("[INFO] Download complete.")
61
+ # return dst
62
+ # except Exception as e:
63
+ # print(f"[ERROR] Download failed: {e}")
64
+ # raise
65
+
66
+
67
+ # def init_model():
68
+ # """Lazily initialize the RF-DETR model and cache it in global MODEL."""
69
+ # global MODEL
70
+ # with MODEL_LOCK:
71
+ # if MODEL is not None:
72
+ # print("[INFO] Model already loaded, returning cached instance")
73
+ # return MODEL
74
+ # try:
75
+ # # Ensure checkpoint present
76
+ # if not os.path.exists(CHECKPOINT_PATH):
77
+ # print("[INFO] Checkpoint not found, downloading...")
78
+ # download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
79
+ # else:
80
+ # print(f"[INFO] Using existing checkpoint at {CHECKPOINT_PATH}")
81
+
82
+ # print("[INFO] Loading RF-DETR model (CPU mode)...")
83
+ # MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH)
84
+
85
+ # # Try to optimize for inference
86
+ # try:
87
+ # print("[INFO] Optimizing model for inference...")
88
+ # MODEL.optimize_for_inference()
89
+ # print("[INFO] Model optimization complete")
90
+ # except Exception as e:
91
+ # print(f"[WARN] optimize_for_inference() skipped/failed: {e}")
92
+
93
+ # print("[INFO] Model ready for inference")
94
+ # return MODEL
95
+ # except Exception as e:
96
+ # print(f"[ERROR] Model initialization failed: {e}")
97
+ # traceback.print_exc()
98
+ # raise
99
+
100
+
101
+ # def decode_data_url(data_url: str) -> Image.Image:
102
+ # """Decode data URL to PIL Image"""
103
+ # if data_url.startswith("data:"):
104
+ # _, b64 = data_url.split(",", 1)
105
+ # data = base64.b64decode(b64)
106
+ # else:
107
+ # try:
108
+ # data = base64.b64decode(data_url)
109
+ # except Exception:
110
+ # raise ValueError("Invalid image data")
111
+ # return Image.open(io.BytesIO(data)).convert("RGB")
112
+
113
+
114
+ # def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
115
+ # """Encode PIL Image to data URL"""
116
+ # buf = io.BytesIO()
117
+ # pil_img.save(buf, format=fmt, optimize=False)
118
+ # buf.seek(0)
119
+ # return "data:image/{};base64,".format(fmt.lower()) + base64.b64encode(buf.getvalue()).decode("ascii")
120
+
121
+
122
+ # def annotate_segmentation(image: Image.Image, detections: sv.Detections) -> Image.Image:
123
+ # """
124
+ # Annotate image with segmentation masks using supervision library.
125
+ # This matches the visualization from rfdetr_seg_infer.py script.
126
+ # """
127
+ # try:
128
+ # # Define color palette
129
+ # palette = sv.ColorPalette.from_hex([
130
+ # "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
131
+ # "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
132
+ # ])
133
+
134
+ # # Calculate optimal text scale based on image resolution
135
+ # text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
136
+
137
+ # print(f"[INFO] Creating annotators with text_scale={text_scale}")
138
+
139
+ # # Create annotators
140
+ # mask_annotator = sv.MaskAnnotator(color=palette)
141
+ # polygon_annotator = sv.PolygonAnnotator(color=sv.Color.WHITE)
142
+ # label_annotator = sv.LabelAnnotator(
143
+ # color=palette,
144
+ # text_color=sv.Color.BLACK,
145
+ # text_scale=text_scale,
146
+ # text_position=sv.Position.CENTER_OF_MASS
147
+ # )
148
+
149
+ # # Create labels with confidence scores
150
+ # labels = [
151
+ # f"Tulsi {float(conf):.2f}"
152
+ # for conf in detections.confidence
153
+ # ]
154
+
155
+ # print(f"[INFO] Annotating {len(labels)} detections")
156
+
157
+ # # Apply annotations step by step
158
+ # out = image.copy()
159
+ # print("[INFO] Applying mask annotation...")
160
+ # out = mask_annotator.annotate(out, detections)
161
+ # print("[INFO] Applying polygon annotation...")
162
+ # out = polygon_annotator.annotate(out, detections)
163
+ # print("[INFO] Applying label annotation...")
164
+ # out = label_annotator.annotate(out, detections, labels)
165
+
166
+ # print("[INFO] Annotation complete")
167
+ # return out
168
+
169
+ # except Exception as e:
170
+ # print(f"[ERROR] Annotation failed: {e}")
171
+ # traceback.print_exc()
172
+ # # Return original image if annotation fails
173
+ # return image
174
+
175
+
176
+ # @app.route("/", methods=["GET"])
177
+ # def index():
178
+ # """Serve the static UI"""
179
+ # index_path = os.path.join(app.static_folder or "static", "index.html")
180
+ # if os.path.exists(index_path):
181
+ # return send_from_directory(app.static_folder, "index.html")
182
+ # return jsonify({"message": "RF-DETR Segmentation API is running.", "status": "ready"})
183
+
184
+
185
+ # @app.route("/health", methods=["GET"])
186
+ # def health():
187
+ # """Health check endpoint"""
188
+ # model_loaded = MODEL is not None
189
+ # return jsonify({
190
+ # "status": "healthy",
191
+ # "model_loaded": model_loaded,
192
+ # "checkpoint_exists": os.path.exists(CHECKPOINT_PATH)
193
+ # })
194
+
195
+
196
+ # @app.route("/predict", methods=["POST"])
197
+ # def predict():
198
+ # """
199
+ # Accepts:
200
+ # - multipart/form-data with file field "file"
201
+ # - or JSON {"image": "<data:url...>", "conf": 0.05}
202
+ # Returns JSON:
203
+ # {"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}
204
+ # """
205
+ # print("\n[INFO] ========== New prediction request ==========")
206
+
207
+ # try:
208
+ # print("[INFO] Initializing model...")
209
+ # model = init_model()
210
+ # print("[INFO] Model ready")
211
+ # except Exception as e:
212
+ # error_msg = f"Model initialization failed: {e}"
213
+ # print(f"[ERROR] {error_msg}")
214
+ # return jsonify({"error": error_msg}), 500
215
+
216
+ # # Parse input
217
+ # img: Optional[Image.Image] = None
218
+ # conf_threshold = 0.05
219
+
220
+ # # Check if file uploaded
221
+ # if "file" in request.files:
222
+ # file = request.files["file"]
223
+ # print(f"[INFO] Processing uploaded file: {file.filename}")
224
+ # try:
225
+ # img = Image.open(file.stream).convert("RGB")
226
+ # except Exception as e:
227
+ # error_msg = f"Invalid uploaded image: {e}"
228
+ # print(f"[ERROR] {error_msg}")
229
+ # return jsonify({"error": error_msg}), 400
230
+ # conf_threshold = float(request.form.get("conf", conf_threshold))
231
+ # else:
232
+ # # Try JSON payload
233
+ # payload = request.get_json(silent=True)
234
+ # if not payload or "image" not in payload:
235
+ # return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
236
+ # try:
237
+ # print("[INFO] Decoding image from data URL...")
238
+ # img = decode_data_url(payload["image"])
239
+ # except Exception as e:
240
+ # error_msg = f"Invalid image data: {e}"
241
+ # print(f"[ERROR] {error_msg}")
242
+ # return jsonify({"error": error_msg}), 400
243
+ # conf_threshold = float(payload.get("conf", conf_threshold))
244
+
245
+ # print(f"[INFO] Image size: {img.size}, Confidence threshold: {conf_threshold}")
246
+
247
+ # # Optionally downscale large images to reduce memory usage
248
+ # MAX_SIZE = 1024
249
+ # if max(img.size) > MAX_SIZE:
250
+ # w, h = img.size
251
+ # scale = MAX_SIZE / float(max(w, h))
252
+ # new_w, new_h = int(round(w * scale)), int(round(h * scale))
253
+ # print(f"[INFO] Resizing image from {w}x{h} to {new_w}x{new_h}")
254
+ # img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
255
+
256
+ # # Run inference with no_grad for memory efficiency
257
+ # try:
258
+ # print("[INFO] Running inference...")
259
+ # with torch.no_grad():
260
+ # detections = model.predict(img, threshold=conf_threshold)
261
+
262
+ # print(f"[INFO] Raw detections: {len(detections)} objects")
263
+
264
+ # # Check if detections exist
265
+ # if len(detections) == 0 or not hasattr(detections, 'confidence') or len(detections.confidence) == 0:
266
+ # print("[INFO] No detections above threshold")
267
+ # # Return original image
268
+ # data_url = encode_pil_to_dataurl(img, fmt="PNG")
269
+ # return jsonify({
270
+ # "annotated": data_url,
271
+ # "confidences": [],
272
+ # "count": 0
273
+ # })
274
+
275
+ # print(f"[INFO] Detections have {len(detections.confidence)} confidence scores")
276
+ # print(f"[INFO] Confidence range: {min(detections.confidence):.3f} - {max(detections.confidence):.3f}")
277
+
278
+ # # Check if masks exist
279
+ # if hasattr(detections, 'masks') and detections.masks is not None:
280
+ # print(f"[INFO] Masks present: shape={np.array(detections.masks).shape if hasattr(detections.masks, '__len__') else 'unknown'}")
281
+ # else:
282
+ # print("[WARN] No masks found in detections!")
283
+
284
+ # # Annotate image using supervision library
285
+ # print("[INFO] Starting annotation...")
286
+ # annotated_pil = annotate_segmentation(img, detections)
287
+
288
+ # # Extract confidence scores
289
+ # confidences = [float(conf) for conf in detections.confidence]
290
+ # print(f"[INFO] Final confidences: {confidences}")
291
+
292
+ # # Encode to data URL
293
+ # print("[INFO] Encoding annotated image...")
294
+ # data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
295
+
296
+ # # Clean up
297
+ # del detections
298
+ # gc.collect()
299
+
300
+ # print(f"[INFO] ========== Prediction complete: {len(confidences)} leaves detected ==========\n")
301
+
302
+ # return jsonify({
303
+ # "annotated": data_url,
304
+ # "confidences": confidences,
305
+ # "count": len(confidences)
306
+ # })
307
+
308
+ # except Exception as e:
309
+ # error_msg = f"Inference failed: {e}"
310
+ # print(f"[ERROR] {error_msg}")
311
+ # traceback.print_exc()
312
+ # return jsonify({"error": error_msg}), 500
313
+
314
+
315
+ # if __name__ == "__main__":
316
+ # print("\n" + "="*60)
317
+ # print("Starting Tulsi Leaf Segmentation Server")
318
+ # print("="*60 + "\n")
319
+
320
+ # # Warm model in background thread
321
+ # def warm():
322
+ # try:
323
+ # print("[INFO] Starting model warmup in background...")
324
+ # init_model()
325
+ # print("[INFO] ✓ Model warmup complete - ready for predictions")
326
+ # except Exception as e:
327
+ # print(f"[ERROR] ✗ Model warmup failed: {e}")
328
+ # traceback.print_exc()
329
+
330
+ # threading.Thread(target=warm, daemon=True).start()
331
+
332
+ # # Run Flask app
333
+ # app.run(
334
+ # host="0.0.0.0",
335
+ # port=int(os.environ.get("PORT", 7860)),
336
+ # debug=False
337
+ # )
338
+
339
+
340
  import os
341
  import io
342
  import base64
 
376
  app = Flask(__name__, static_folder="static", static_url_path="/")
377
 
378
  # Checkpoint URL & local path
379
+ CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
380
  CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
381
 
382
  MODEL_LOCK = threading.Lock()
 
537
  """
538
  Accepts:
539
  - multipart/form-data with file field "file"
540
+ - or JSON {"image": "<data:url...>", "conf": 0.25, "show_labels": true, "show_confidence": true}
541
  Returns JSON:
542
  {"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}
543
  """
 
554
 
555
  # Parse input
556
  img: Optional[Image.Image] = None
557
+ conf_threshold = 0.25
558
+ show_labels = True
559
+ show_confidence = True
560
 
561
  # Check if file uploaded
562
  if "file" in request.files:
 
569
  print(f"[ERROR] {error_msg}")
570
  return jsonify({"error": error_msg}), 400
571
  conf_threshold = float(request.form.get("conf", conf_threshold))
572
+ show_labels = request.form.get("show_labels", "true").lower() == "true"
573
+ show_confidence = request.form.get("show_confidence", "true").lower() == "true"
574
  else:
575
  # Try JSON payload
576
  payload = request.get_json(silent=True)
 
584
  print(f"[ERROR] {error_msg}")
585
  return jsonify({"error": error_msg}), 400
586
  conf_threshold = float(payload.get("conf", conf_threshold))
587
+ show_labels = payload.get("show_labels", True)
588
+ show_confidence = payload.get("show_confidence", True)
589
 
590
  print(f"[INFO] Image size: {img.size}, Confidence threshold: {conf_threshold}")
591