File size: 3,446 Bytes
8af588d
745c9b6
89215aa
c70c9b1
745c9b6
 
b22dfe5
89215aa
b22dfe5
 
 
 
 
 
 
 
745c9b6
89215aa
b22dfe5
 
 
89215aa
745c9b6
4dcf554
89215aa
 
df71871
745c9b6
 
 
 
 
 
4dcf554
df71871
 
4dcf554
b22dfe5
 
 
c70c9b1
 
745c9b6
 
 
 
 
c70c9b1
745c9b6
c70c9b1
 
 
72165a4
745c9b6
a1d1400
c70c9b1
745c9b6
 
 
 
 
c70c9b1
745c9b6
c70c9b1
745c9b6
c70c9b1
 
 
b22dfe5
 
 
 
 
 
 
 
 
 
745c9b6
b22dfe5
 
 
 
 
 
 
c70c9b1
b22dfe5
745c9b6
 
 
 
 
d576613
b22dfe5
745c9b6
 
 
 
8af588d
745c9b6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
from ultralytics import YOLO
import numpy as np
import cv2
from PIL import Image
import random
from transformers import pipeline

# ---------------------------
# Load Models
# ---------------------------
# Both models are loaded once at import time so every Gradio request
# reuses the same instances (first launch downloads the weights).

# Text model (tiny LLM) — a random-weight GPT-2 used for demo output only;
# generations are not expected to be meaningful.
text_gen = pipeline("text-generation", model="tiny-random-gpt2")

# YOLOv8 segmentation (nano version for speed)
yolo_model = YOLO("yolov8n-seg.pt")  # change to yolov8s-seg.pt for more accuracy

# ---------------------------
# Image Segmentation
# ---------------------------
def segment_image(image: Image.Image):
    """Run YOLOv8 instance segmentation on a single image.

    Args:
        image: Uploaded PIL image (any mode; normalized to RGB). May be
            None when the user clicks Run without uploading.

    Returns:
        Tuple of (overlay image with each instance mask filled in a
        random color, list of (polygon_points, class_name) annotations).
        Returns (None, []) for a missing input and (original, []) when
        nothing is detected.
    """
    # Gradio passes None when no image was uploaded.
    if image is None:
        return None, []

    # Force 3-channel RGB so cv2.fillPoly gets the expected shape even
    # for RGBA or grayscale uploads.
    frame = np.array(image.convert("RGB"))
    results = yolo_model.predict(frame)[0]

    overlay = frame.copy()
    annotations = []

    # results.masks is None when the model found no instances.
    if results.masks is not None:
        names = yolo_model.names  # hoist class-id -> label lookup
        for mask, cls in zip(results.masks.xy, results.boxes.cls):
            pts = np.array(mask, dtype=np.int32)
            color = [random.randint(0, 255) for _ in range(3)]
            cv2.fillPoly(overlay, [pts], color)
            annotations.append((mask.tolist(), names[int(cls)]))

    overlay_img = Image.fromarray(overlay)
    return (overlay_img, annotations)

# ---------------------------
# Video Segmentation
# ---------------------------
def segment_video(video):
    """Run YOLOv8 segmentation frame-by-frame over a video file.

    Args:
        video: Filesystem path to the input video (as given by gr.Video).
            May be None when nothing was uploaded.

    Returns:
        Path to the annotated "output.mp4", or None if the input is
        missing or cannot be opened.
    """
    if video is None:
        return None

    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        # Unreadable/unsupported container — bail out instead of writing
        # an empty output file.
        cap.release()
        return None

    # Some containers report 0 FPS; fall back to a sane default so the
    # VideoWriter does not produce an unplayable file.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_path = "output.mp4"
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

    try:
        while True:
            ret, frame = cap.read()
            if not ret:  # end of stream (or read error)
                break

            results = yolo_model.predict(frame)[0]
            overlay = frame.copy()

            if results.masks is not None:
                for mask, cls in zip(results.masks.xy, results.boxes.cls):
                    pts = np.array(mask, dtype=np.int32)
                    color = [random.randint(0, 255) for _ in range(3)]
                    cv2.fillPoly(overlay, [pts], color)

            out.write(overlay)
    finally:
        # Always release handles so the partially written file is flushed
        # even if prediction raises mid-stream.
        cap.release()
        out.release()

    return out_path

# ---------------------------
# Text Generation
# ---------------------------
def generate_text(prompt):
    """Return a single generated continuation of *prompt*.

    Uses the module-level text-generation pipeline, capped at 100 total
    tokens (prompt included).
    """
    candidates = text_gen(prompt, max_length=100, num_return_sequences=1)
    best = candidates[0]
    return best["generated_text"]

# ---------------------------
# Gradio UI
# ---------------------------
# Build the three-tab UI. Each tab wires one button click to the matching
# handler defined above; components created inside a `with gr.Tab(...)`
# context belong to that tab.
with gr.Blocks() as demo:
    gr.Markdown("# 🔥 Multi-Modal Playground\nTry out **Text + Image + Video Segmentation** in one app!")

    with gr.Tab("💬 Text Generation"):
        inp_text = gr.Textbox(label="Enter your prompt")
        out_text = gr.Textbox(label="Generated text")
        btn_text = gr.Button("Generate")
        btn_text.click(generate_text, inputs=inp_text, outputs=out_text)

    with gr.Tab("🖼️ Image Segmentation"):
        # type="pil" makes Gradio hand segment_image a PIL.Image.
        inp_img = gr.Image(type="pil", label="Upload Image")
        out_img = gr.Image(type="pil", label="Segmented Image")
        out_ann = gr.JSON(label="Annotations")
        btn_img = gr.Button("Run Segmentation")
        # segment_image returns a (image, annotations) tuple, mapped in
        # order onto [out_img, out_ann].
        btn_img.click(segment_image, inputs=inp_img, outputs=[out_img, out_ann])

    with gr.Tab("🎥 Video Segmentation"):
        inp_vid = gr.Video(label="Upload Video")
        out_vid = gr.Video(label="Segmented Video")
        btn_vid = gr.Button("Run Segmentation")
        # segment_video returns a file path, which gr.Video plays back.
        btn_vid.click(segment_video, inputs=inp_vid, outputs=out_vid)

# Bind to all interfaces on port 7860 (typical for containerized deploys,
# e.g. Hugging Face Spaces); share=False keeps it off the public tunnel.
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, ssr_mode=False)