File size: 4,227 Bytes
769e1e0
92dd353
769e1e0
 
 
 
 
 
 
92dd353
769e1e0
92dd353
 
 
769e1e0
92dd353
 
 
769e1e0
92dd353
 
769e1e0
 
 
92dd353
 
769e1e0
 
 
 
 
 
92dd353
769e1e0
92dd353
769e1e0
92dd353
769e1e0
 
 
 
 
 
 
92dd353
769e1e0
 
 
92dd353
769e1e0
92dd353
 
769e1e0
 
 
92dd353
769e1e0
 
92dd353
 
769e1e0
 
92dd353
769e1e0
 
 
92dd353
769e1e0
 
92dd353
769e1e0
 
 
92dd353
769e1e0
 
92dd353
769e1e0
 
 
 
 
 
 
92dd353
769e1e0
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"  # force PyTorch-only pipelines

import numpy as np
from PIL import Image, ImageFilter
import gradio as gr
import torch
from transformers import pipeline

# ---- Config ----
# Device index handed to both transformers pipelines: 0 when CUDA is available,
# otherwise -1 (the pipeline API's value for running on CPU).
DEVICE = 0 if torch.cuda.is_available() else -1
# Hugging Face model ids used for segmentation (person mask) and depth estimation.
SEG_MODEL = "nvidia/segformer-b0-finetuned-ade-512-512"
DEPTH_MODEL = "Intel/dpt-hybrid-midas"
# (width, height) every uploaded image is resized/cropped to before processing.
SIZE = (512, 512)

# ---- Pipelines (loaded once) ----
# Built at import time so each request reuses the same pipeline objects
# instead of constructing them again.
seg_pipe = pipeline("image-segmentation", model=SEG_MODEL, device=DEVICE, framework="pt")
depth_pipe = pipeline("depth-estimation", model=DEPTH_MODEL, device=DEVICE, framework="pt")

# ---- Helpers ----
def resize_center_crop(img: Image.Image, size=SIZE) -> Image.Image:
    """Return an RGB copy of *img* scaled to cover *size* and center-cropped to it.

    Args:
        img: Source image; it is converted to RGB first.
        size: Target ``(width, height)``; defaults to the module-level ``SIZE``.

    Returns:
        A new image whose dimensions are exactly *size*.
    """
    rgb = img.convert("RGB")
    src_w, src_h = rgb.size
    target_w, target_h = size

    # Take the larger ratio so that BOTH dimensions reach at least the target;
    # the overshoot on the other axis is what gets cropped away below.
    scale = max(target_w / src_w, target_h / src_h)
    scaled_w = int(round(src_w * scale))
    scaled_h = int(round(src_h * scale))
    scaled = rgb.resize((scaled_w, scaled_h), Image.BICUBIC)

    # Center the crop window inside the scaled image.
    x0 = (scaled_w - target_w) // 2
    y0 = (scaled_h - target_h) // 2
    crop_box = (x0, y0, x0 + target_w, y0 + target_h)
    return scaled.crop(crop_box)

def person_mask(img_512: Image.Image) -> Image.Image:
    """Return a binary "L" mask that is 255 on the detected person and 0 elsewhere.

    The segmentation pipeline is run on *img_512* and its results are searched
    for a segment labelled exactly "person" (case-insensitive); if none exists,
    the first segment whose label merely contains "person" is used instead.
    When nothing matches, an all-black mask of the input's size is returned.
    """
    segments = seg_pipe(img_512)

    def first_segment(label_matches):
        # First segment whose lower-cased label satisfies the predicate, else None.
        for segment in segments:
            if label_matches(segment.get("label", "").lower()):
                return segment
        return None

    chosen = first_segment(lambda label: label == "person")
    if chosen is None:
        chosen = first_segment(lambda label: "person" in label)
    if chosen is None:
        return Image.new("L", img_512.size, 0)  # no person detected

    # Threshold at mid-gray so the mask is strictly 0 or 255.
    gray = np.array(chosen["mask"].convert("L"))
    binary = np.where(gray > 127, 255, 0).astype(np.uint8)
    return Image.fromarray(binary, mode="L")

def gaussian_bg_blur(img_512: Image.Image, sigma: int = 15) -> Image.Image:
    """Blur the whole image, then paste the sharp person back on top.

    Args:
        img_512: Preprocessed input image.
        sigma: Blur strength; truncated to ``int`` and used as the
            ``GaussianBlur`` radius.

    Returns:
        The composite image: original pixels where the person mask is white,
        blurred pixels everywhere else.
    """
    subject = person_mask(img_512)
    blur_filter = ImageFilter.GaussianBlur(radius=int(sigma))
    background = img_512.filter(blur_filter)
    # composite() takes the first image where the mask is white (the person)
    # and the second image where it is black (the background).
    return Image.composite(img_512, background, subject)

def depth_lens_blur(img_512: Image.Image, max_radius: int = 15, keep_subject: bool = True) -> Image.Image:
    """Blur each pixel by an amount chosen from the estimated depth map.

    The depth pipeline's output is normalized to 0..1 and inverted, so pixels
    with the lowest depth value receive the largest blur radius. This assumes
    the model reports larger values for nearer pixels -- TODO confirm for any
    replacement depth model.

    Args:
        img_512: Preprocessed input image. The depth map is resized to this
            image's own size, so any size works (not only the global ``SIZE``).
        max_radius: Largest Gaussian radius to apply; clamped to 0..30.
        keep_subject: When true, the blur amount on pixels covered by the
            person mask is scaled down to 15% of its depth-derived value.

    Returns:
        A new image of the same size as *img_512* with per-pixel variable blur.
    """
    out = depth_pipe(img_512)
    # Resize to the input's actual size rather than the global SIZE so the
    # radius index always lines up with the blur stack and the person mask.
    d = out["depth"].resize(img_512.size, Image.BICUBIC)
    dnp = np.array(d).astype(np.float32)
    # The epsilon keeps a perfectly flat depth map from dividing by zero.
    d01 = (dnp - dnp.min()) / (dnp.max() - dnp.min() + 1e-8)  # 0..1
    far = 1.0 - d01  # larger=farther -> more blur
    if keep_subject:
        m = person_mask(img_512)
        m01 = (np.array(m) > 127).astype(np.float32)
        far = far * (1.0 - 0.85 * m01)  # suppress blur on detected subject
    max_radius = int(max(0, min(30, max_radius)))
    # Quantize the continuous blur amount into an integer radius per pixel.
    idx = np.clip(np.rint(far * max_radius).astype(np.int32), 0, max_radius)
    # One pre-blurred copy of the image per radius; radius 0 is the original.
    stack = [img_512 if r == 0 else img_512.filter(ImageFilter.GaussianBlur(radius=r))
             for r in range(max_radius + 1)]
    stack_np = np.stack([np.array(im) for im in stack], axis=0)  # [R+1,H,W,3]
    H, W = idx.shape
    rows = np.arange(H)[:, None]
    cols = np.arange(W)[None, :]
    # Integer-array indexing: for every (row, col) pick the pixel from the
    # stack layer whose radius equals idx[row, col].
    out_np = stack_np[idx, rows, cols]
    return Image.fromarray(out_np.astype(np.uint8))

def run(image, effect, sigma, max_radius, keep_subject):
    """Gradio click handler: preprocess the upload and apply the chosen effect.

    Returns a ``(preprocessed, result)`` pair of images, or ``(None, None)``
    when no image was supplied.
    """
    if image is None:
        return None, None

    prepared = resize_center_crop(image, SIZE)
    wants_gaussian = effect == "Gaussian Background Blur (subject sharp)"
    if wants_gaussian:
        result = gaussian_bg_blur(prepared, sigma=int(sigma))
    else:
        # Any other selection falls through to the depth-based lens blur.
        result = depth_lens_blur(
            prepared,
            max_radius=int(max_radius),
            keep_subject=bool(keep_subject),
        )
    return prepared, result

# ---- UI ----
# Two-column layout: controls on the left, preprocessed input and result on the right.
with gr.Blocks(title="Gaussian & Lens Blur Lab") as demo:
    gr.Markdown("# Gaussian & Lens Blur Lab\nUpload an image and compare effects.")
    with gr.Row():
        with gr.Column():
            # type="pil" so run() receives a PIL image it can pass to resize_center_crop().
            in_img = gr.Image(type="pil", label="Upload image")
            # NOTE: the first choice string is compared literally inside run();
            # keep the two in sync if the label is ever changed.
            effect = gr.Radio(
                ["Gaussian Background Blur (subject sharp)", "Depth-based Lens Blur"],
                value="Gaussian Background Blur (subject sharp)", label="Effect"
            )
            # Only used by the Gaussian background blur effect.
            sigma = gr.Slider(1, 40, value=15, step=1, label="Gaussian sigma")
            # Only used by the lens blur effect; its upper bound equals the
            # clamp of 30 applied inside depth_lens_blur().
            max_r = gr.Slider(4, 30, value=15, step=1, label="Max blur radius (lens blur)")
            keep = gr.Checkbox(True, label="Keep detected subject sharper (lens blur)")
            btn = gr.Button("Run")
        with gr.Column():
            out_a = gr.Image(label="Preprocessed 512×512")
            out_b = gr.Image(label="Result")
    # Input order must match run()'s parameter order; outputs receive its 2-tuple.
    btn.click(run, inputs=[in_img, effect, sigma, max_r, keep], outputs=[out_a, out_b])

# Start the Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()