# Hugging Face Spaces file-view residue (not code) — kept as comments:
# LhatMjnk's picture
# Update app.py
# 664f6dd verified
from PIL import Image
import cv2
import numpy as np
import gradio as gr
from inference import CoralSegModel, id2label, label2color, create_segmentation_overlay
model = CoralSegModel()
# ---- helpers ----
def _safe_read(cap):
ok, frame = cap.read()
return frame if ok and frame is not None else None
def build_annotations(pred_map: np.ndarray, selected: list[str]) -> list[tuple[np.ndarray, str]]:
"""Return [(mask,label), ...] where mask is 0/1 float HxW for AnnotatedImage."""
if pred_map is None or not selected:
return []
# Create reverse mapping: label_name -> class_id
label2id = {label: int(id_str) for id_str, label in id2label.items()}
anns = []
for label_name in selected:
if label_name not in label2id:
continue # Skip unknown labels
class_id = label2id[label_name] # Convert label name to class ID
mask = (pred_map == class_id).astype(np.float32)
if mask.sum() > 0:
anns.append((mask, label_name)) # Use the label name for display
return anns
# ==============================
# STREAMING EVENT FUNCTIONS
# ==============================
# IMPORTANT: make the event functions themselves generators.
# Also: include the States as outputs so we can update them every frame.
def remote_start(url: str, n: int, pred_state, base_state):
if not url:
return
cap = cv2.VideoCapture(url)
if not cap.isOpened():
return
idx = 0
try:
while True:
frame = _safe_read(cap)
if frame is None:
break
if n > 1 and (idx % n) != 0:
idx += 1
continue
pred_map, overlay_rgb, base_rgb = model.predict_map_and_overlay(frame)
# yield live image + updated States' *values*
yield overlay_rgb, pred_map, base_rgb
idx += 1
finally:
cap.release()
def upload_start(video_file: str, n: int):
if not video_file:
return
cap = cv2.VideoCapture(video_file)
if not cap.isOpened():
return
idx = 0
try:
while True:
ok, frame = cap.read()
if not ok or frame is None:
break
if n > 1 and (idx % n) != 0:
idx += 1
continue
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pred_map, overlay_rgb, base_rgb = model.predict_map_and_overlay(frame)
yield overlay_rgb, pred_map, base_rgb
idx += 1
finally:
cap.release()
# ==============================
# SNAPSHOT / TOGGLES (non-streaming)
# ==============================
# NOTE: When you pass gr.State as an input, you receive the *value*, not the wrapper.
def make_snapshot(selected_labels, pred_map, base_rgb, alpha=0.25):
if pred_map is None or base_rgb is None:
return gr.update()
# rebuild overlay to match the live look
overlay = create_segmentation_overlay(pred_map, id2label, label2color, Image.fromarray(base_rgb), alpha=alpha)
ann = build_annotations(pred_map, selected_labels or [])
return (overlay, ann) # (base_image, [(mask,label), ...])
# ==============================
# UI
# ==============================
with gr.Blocks(title="CoralScapes Streaming Segmentation") as demo:
gr.Markdown("# CoralScapes Streaming Segmentation")
gr.Markdown(
"Left: **live stream** (fast). Right: **snapshot** with **hover labels** and **per-class toggles**."
)
with gr.Tab("Remote Stream (RTSP/HTTP)"):
with gr.Row():
with gr.Column(scale=2):
# States start as None. We'll UPDATE them on every frame by returning them as outputs.
pred_state_remote = gr.State(None) # holds last pred_map (HxW np.uint8)
base_state_remote = gr.State(None) # holds last base_rgb (HxWx3 uint8)
live_remote = gr.Image(label="Live segmented stream")
start_btn = gr.Button("Start")
snap_btn_remote = gr.Button("📸 Snapshot (hover-able)")
hover_remote = gr.AnnotatedImage(label="Snapshot (hover to see label)")
with gr.Column(scale=1):
url = gr.Textbox(label="Stream URL", placeholder="rtsp://user:pass@ip:port/…")
skip = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame")
toggles_remote = gr.CheckboxGroup(
choices=list(id2label.values()), value=list(id2label.values()),
label="Toggle classes in snapshot",
)
start_btn.click(
remote_start,
inputs=[url, skip, pred_state_remote, base_state_remote],
outputs=[live_remote, pred_state_remote, base_state_remote],
queue=True, # be explicit; required for generator streaming
)
snap_btn_remote.click(
make_snapshot,
inputs=[toggles_remote, pred_state_remote, base_state_remote],
outputs=[hover_remote],
)
toggles_remote.change(
make_snapshot,
inputs=[toggles_remote, pred_state_remote, base_state_remote],
outputs=[hover_remote],
)
with gr.Tab("Upload Video"):
with gr.Row():
# Left column (now contains toggles, snapshot button, and live output)
with gr.Column(scale=2):
# States remain in the same column as live_upload
pred_state_upload = gr.State(None)
base_state_upload = gr.State(None)
live_upload = gr.Image(label="Live segmented output")
start_btn2 = gr.Button("Process")
snap_btn_upload = gr.Button("📸 Snapshot (hover-able)")
hover_upload = gr.AnnotatedImage(label="Snapshot (hover to see label)")
# Right column (now contains video input and slider)
with gr.Column(scale=1):
vid_in = gr.Video(sources=["upload"], format="mp4", label="Input Video")
skip2 = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame")
toggles_upload = gr.CheckboxGroup(
choices=list(id2label.values()), value=list(id2label.values()),
label="Toggle classes in snapshot",
)
# Event handlers remain the same
start_btn2.click(
upload_start,
inputs=[vid_in, skip2],
outputs=[live_upload, pred_state_upload, base_state_upload],
queue=True,
)
snap_btn_upload.click(
make_snapshot,
inputs=[toggles_upload, pred_state_upload, base_state_upload],
outputs=[hover_upload],
)
toggles_upload.change(
make_snapshot,
inputs=[toggles_upload, pred_state_upload, base_state_upload],
outputs=[hover_upload],
)
if __name__ == "__main__":
demo.queue().launch(share=True)