Spaces:
Sleeping
Sleeping
| from PIL import Image | |
| import cv2 | |
| import numpy as np | |
| import gradio as gr | |
| from inference import CoralSegModel, id2label, label2color, create_segmentation_overlay | |
| model = CoralSegModel() | |
| # ---- helpers ---- | |
| def _safe_read(cap): | |
| ok, frame = cap.read() | |
| return frame if ok and frame is not None else None | |
def build_annotations(pred_map: np.ndarray, selected: list[str]) -> list[tuple[np.ndarray, str]]:
    """Build hover annotations for gr.AnnotatedImage.

    Args:
        pred_map: HxW integer class-id map (one id per pixel), or None.
        selected: label names to include (values of ``id2label``).

    Returns:
        ``[(mask, label), ...]`` where each mask is a 0/1 float32 HxW
        array. Unknown label names and classes absent from ``pred_map``
        are silently skipped.
    """
    if pred_map is None or not selected:
        return []
    # Reverse mapping: label name -> integer class id.
    # NOTE(review): assumes id2label keys are string ids (hence int()).
    label2id = {label: int(id_str) for id_str, label in id2label.items()}
    anns: list[tuple[np.ndarray, str]] = []
    for label_name in selected:
        class_id = label2id.get(label_name)  # single lookup instead of `in` + `[]`
        if class_id is None:
            continue  # unknown label — ignore rather than crash
        hit = pred_map == class_id
        # Check the cheap boolean mask first; only cast when non-empty
        # (avoids building a float32 array for absent classes).
        if hit.any():
            anns.append((hit.astype(np.float32), label_name))
    return anns
| # ============================== | |
| # STREAMING EVENT FUNCTIONS | |
| # ============================== | |
| # IMPORTANT: make the event functions themselves generators. | |
| # Also: include the States as outputs so we can update them every frame. | |
def remote_start(url: str, n: int, pred_state, base_state):
    """Stream a remote RTSP/HTTP feed, yielding segmented frames.

    Generator event handler: each yield is ``(overlay_rgb, pred_map,
    base_rgb)`` so Gradio updates the live image and both States per frame.

    Args:
        url: stream URL; empty/None aborts immediately (no yields).
        n: process every Nth frame (``n <= 1`` processes all frames).
        pred_state, base_state: incoming State *values* (unused — the
            States are wired as outputs and refreshed from the yields).
    """
    if not url:
        return
    cap = cv2.VideoCapture(url)
    if not cap.isOpened():
        return
    idx = 0
    try:
        while True:
            frame = _safe_read(cap)
            if frame is None:
                break
            if n > 1 and (idx % n) != 0:
                idx += 1
                continue
            # cv2 delivers BGR; convert to RGB before inference so this
            # path matches upload_start (which converts its frames).
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pred_map, overlay_rgb, base_rgb = model.predict_map_and_overlay(frame)
            # Yield the live image plus the new State values.
            yield overlay_rgb, pred_map, base_rgb
            idx += 1
    finally:
        cap.release()
def upload_start(video_file: str, n: int):
    """Process an uploaded video, yielding segmented frames.

    Generator event handler: yields ``(overlay_rgb, pred_map, base_rgb)``
    for every Nth frame so the live image and both States refresh as
    processing advances. Yields nothing if the file is missing/unreadable.
    """
    if not video_file:
        return
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        return
    frame_no = 0
    try:
        while (frame := _safe_read(cap)) is not None:
            if n > 1 and frame_no % n != 0:
                frame_no += 1
                continue
            # OpenCV frames are BGR; the model expects RGB.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pred_map, overlay_rgb, base_rgb = model.predict_map_and_overlay(rgb)
            yield overlay_rgb, pred_map, base_rgb
            frame_no += 1
    finally:
        cap.release()
| # ============================== | |
| # SNAPSHOT / TOGGLES (non-streaming) | |
| # ============================== | |
| # NOTE: When you pass gr.State as an input, you receive the *value*, not the wrapper. | |
def make_snapshot(selected_labels, pred_map, base_rgb, alpha=0.25):
    """Render a hover-able snapshot from the latest captured frame.

    Receives State *values* (Gradio unwraps gr.State inputs). Returns
    ``(overlay_image, [(mask, label), ...])`` for gr.AnnotatedImage, or a
    no-op ``gr.update()`` when no frame has been captured yet.
    """
    if pred_map is None or base_rgb is None:
        return gr.update()
    base_img = Image.fromarray(base_rgb)
    # Rebuild the overlay so the snapshot matches the live stream's look.
    snapshot = create_segmentation_overlay(pred_map, id2label, label2color, base_img, alpha=alpha)
    masks = build_annotations(pred_map, selected_labels or [])
    return (snapshot, masks)
| # ============================== | |
| # UI | |
| # ============================== | |
with gr.Blocks(title="CoralScapes Streaming Segmentation") as demo:
    gr.Markdown("# CoralScapes Streaming Segmentation")
    gr.Markdown(
        "Left: **live stream** (fast). Right: **snapshot** with **hover labels** and **per-class toggles**."
    )
    with gr.Tab("Remote Stream (RTSP/HTTP)"):
        with gr.Row():
            with gr.Column(scale=2):
                # States start as None; every streamed frame refreshes them
                # because they are wired as outputs of remote_start below.
                pred_state_remote = gr.State(None)  # last pred_map (HxW np.uint8)
                base_state_remote = gr.State(None)  # last base_rgb (HxWx3 uint8)
                live_remote = gr.Image(label="Live segmented stream")
                start_btn = gr.Button("Start")
                snap_btn_remote = gr.Button("📸 Snapshot (hover-able)")
                hover_remote = gr.AnnotatedImage(label="Snapshot (hover to see label)")
            with gr.Column(scale=1):
                url = gr.Textbox(label="Stream URL", placeholder="rtsp://user:pass@ip:port/…")
                skip = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame")
                toggles_remote = gr.CheckboxGroup(
                    choices=list(id2label.values()), value=list(id2label.values()),
                    label="Toggle classes in snapshot",
                )
        # Generator handler: streams frames into live_remote and refreshes
        # both States on every yield.
        start_btn.click(
            remote_start,
            inputs=[url, skip, pred_state_remote, base_state_remote],
            outputs=[live_remote, pred_state_remote, base_state_remote],
            queue=True,  # be explicit; required for generator streaming
        )
        # Snapshot button and toggle changes both rebuild the AnnotatedImage
        # from the latest State values (handlers receive the *values*).
        snap_btn_remote.click(
            make_snapshot,
            inputs=[toggles_remote, pred_state_remote, base_state_remote],
            outputs=[hover_remote],
        )
        toggles_remote.change(
            make_snapshot,
            inputs=[toggles_remote, pred_state_remote, base_state_remote],
            outputs=[hover_remote],
        )
    with gr.Tab("Upload Video"):
        with gr.Row():
            # Left column: live output, snapshot button, hover view.
            with gr.Column(scale=2):
                # States live alongside live_upload, mirroring the remote tab.
                pred_state_upload = gr.State(None)
                base_state_upload = gr.State(None)
                live_upload = gr.Image(label="Live segmented output")
                start_btn2 = gr.Button("Process")
                snap_btn_upload = gr.Button("📸 Snapshot (hover-able)")
                hover_upload = gr.AnnotatedImage(label="Snapshot (hover to see label)")
            # Right column: inputs (video file, frame-skip slider, toggles).
            with gr.Column(scale=1):
                vid_in = gr.Video(sources=["upload"], format="mp4", label="Input Video")
                skip2 = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame")
                toggles_upload = gr.CheckboxGroup(
                    choices=list(id2label.values()), value=list(id2label.values()),
                    label="Toggle classes in snapshot",
                )
        # Same wiring pattern as the remote tab.
        start_btn2.click(
            upload_start,
            inputs=[vid_in, skip2],
            outputs=[live_upload, pred_state_upload, base_state_upload],
            queue=True,
        )
        snap_btn_upload.click(
            make_snapshot,
            inputs=[toggles_upload, pred_state_upload, base_state_upload],
            outputs=[hover_upload],
        )
        toggles_upload.change(
            make_snapshot,
            inputs=[toggles_upload, pred_state_upload, base_state_upload],
            outputs=[hover_upload],
        )
if __name__ == "__main__":
    # queue() enables the generator (streaming) event handlers above;
    # share=True exposes a public Gradio link.
    demo.queue().launch(share=True)