NeerjaK commited on
Commit
745d6a0
·
1 Parent(s): b3cb8ec

first commit

Browse files
Files changed (3) hide show
  1. app.py +83 -0
  2. requirements.txt +8 -0
  3. utils.py +111 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from utils import get_patch_embeddings, compute_patch_similarity, overlay_similarity, device
3
+
4
+ selected_patch = {"row": 0, "col": 0}
5
+
6
+ def init_states(img):
7
+ if img is None:
8
+ return gr.update(value=None), None
9
+ patch_embs, patch_embs_norm, rows, cols = get_patch_embeddings(img, ps=16, device=device)
10
+
11
+ sim_map = compute_patch_similarity(patch_embs, patch_embs_norm, 0, 0)
12
+ result_img = overlay_similarity(img, sim_map, alpha=0.6, cmap="hot")
13
+
14
+ state = {
15
+ "img": img,
16
+ "patch_embs": patch_embs,
17
+ "patch_embs_norm": patch_embs_norm,
18
+ "grid_size": rows,
19
+ "alpha": 0.6,
20
+ "overlay_img":result_img,
21
+ }
22
+
23
+ return state, result_img
24
+
25
+ def store_patch(evt, state):
26
+ if state is None or evt is None:
27
+ return state
28
+
29
+ rows = state["grid_size"] # e.g., (14, 14)
30
+ cols = rows
31
+ overlay_img = state["overlay_img"]
32
+ overlay_W, overlay_H = overlay_img.size
33
+ x_click, y_click = evt.index # coordinates from click event
34
+
35
+ # Map click coordinates to original patch grid
36
+ col = int(x_click / overlay_W * cols)
37
+ row = int(y_click / overlay_H * rows)
38
+
39
+ # Clamp to valid range
40
+ col = min(max(col, 0), cols - 1)
41
+ row = min(max(row, 0), rows - 1)
42
+
43
+ # Store in global or state dictionary
44
+ selected_patch["row"] = row
45
+ selected_patch["col"] = col
46
+
47
+
48
+ return state
49
+
50
+
51
+ def reload_overlay(evt: gr.SelectData,state):
52
+ if state is None:
53
+ return None
54
+ store_patch(evt, state)
55
+ row, col = selected_patch["row"], selected_patch["col"]
56
+ img = state["img"]
57
+ patch_embs = state["patch_embs"]
58
+ patch_embs_norm = state["patch_embs_norm"]
59
+ alpha = state["alpha"]
60
+ sim_map = compute_patch_similarity(patch_embs, patch_embs_norm, row, col)
61
+ result_img = overlay_similarity(img, sim_map, alpha=alpha, cmap="hot")
62
+ return result_img
63
+
64
+ with gr.Blocks() as demo:
65
+ state_store = gr.State()
66
+ gr.Markdown("""
67
+ <h1 style="font-size:36px; font-weight:bold;">Patch Similarity Visualizer</h1>
68
+ <ul style="font-size:18px;">
69
+ <li>Upload an image in the <strong>left box</strong>.</li>
70
+ <li>Click anywhere in the <strong>right box</strong> to select a patch.</li>
71
+ <li>View the similarity of the selected patch with all other patches in the image.</li>
72
+ </ul>
73
+ """)
74
+
75
+ with gr.Row():
76
+ img_input = gr.Image(type="pil", label="Upload image")
77
+ output_img = gr.Image(type="pil", label="Similarity overlay",interactive=True)
78
+
79
+ img_input.change(fn=init_states, inputs=[img_input], outputs=[state_store, output_img])
80
+
81
+ output_img.select(fn=reload_overlay, inputs=[state_store], outputs=[output_img])
82
+
83
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ timm
5
+ gradio==5.49.1
6
+ numpy
7
+ Pillow
8
+ matplotlib
utils.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.image_utils import load_image
2
+ import numpy as np
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+ import io
6
+ from transformers import AutoImageProcessor, AutoModel
7
+ import torch
8
+
9
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
+
11
+ processor = AutoImageProcessor.from_pretrained("facebook/dinov3-vits16-pretrain-lvd1689m")
12
+ model = AutoModel.from_pretrained(
13
+ "facebook/dinov3-vits16-pretrain-lvd1689m",
14
+ torch_dtype=torch.float16,
15
+ device_map="auto",
16
+ attn_implementation="sdpa"
17
+ )
18
+ model.eval()
19
+
20
+
21
+ def display_image(img, rows,cols):
22
+ W, H = img.size
23
+ patch_w = W / rows
24
+ patch_h = H / cols
25
+
26
+ plt.figure(figsize=(8,8))
27
+ plt.imshow(img)
28
+
29
+ # Draw vertical lines
30
+ for i in range(1, rows):
31
+ plt.axvline(i * patch_w, color='white', linestyle='--', linewidth=0.8)
32
+
33
+ # Draw horizontal lines
34
+ for i in range(1, cols):
35
+ plt.axhline(i * patch_h, color='white', linestyle='--', linewidth=0.8)
36
+
37
+ plt.axis('off')
38
+ plt.show()
39
+
40
+
41
+ def get_patch_embeddings(img, ps=16, device="cuda"):
42
+ inputs = processor(images=img, return_tensors="pt").to(device, torch.float16)
43
+ B, C, H, W = inputs["pixel_values"].shape
44
+ rows, cols = H // ps, W // ps
45
+
46
+ with torch.no_grad():
47
+ out = model(**inputs)
48
+
49
+ hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy()
50
+
51
+ # remove CLS + register tokens
52
+ n_patches = rows * cols
53
+ patch_embs = hs[-n_patches:, :].reshape(rows, cols, -1)
54
+
55
+ # flatten and normalize
56
+ X = patch_embs.reshape(-1, patch_embs.shape[-1])
57
+ Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-8)
58
+
59
+ return patch_embs, Xn, rows, cols
60
+
61
+ def compute_patch_similarity(patch_embs, patch_embs_norm, row, col):
62
+ rows, cols, dim = patch_embs.shape
63
+ patch_idx = row * cols + col # flatten index
64
+
65
+ # cosine similarity via dot product
66
+ sim = patch_embs_norm @ patch_embs_norm[patch_idx]
67
+ sim_map = sim.reshape(rows, cols)
68
+ sim_map = (sim_map - sim_map.min()) / (sim_map.max() - sim_map.min() + 1e-8)
69
+
70
+
71
+ return sim_map
72
+
73
+ def overlay_similarity(img, sim_map, alpha=0.5, cmap="hot"):
74
+ """Draw heatmap overlay with grid and return as PIL image (for Gradio)."""
75
+ W, H = img.size
76
+
77
+ # Expand sim_map (14x14) to full resolution via Kronecker upsampling
78
+ sim_map_resized = np.kron(sim_map, np.ones((H // sim_map.shape[0], W // sim_map.shape[1])))
79
+
80
+ # Plot to figure (no plt.show())
81
+ fig, ax = plt.subplots(figsize=(8, 8))
82
+ ax.imshow(img)
83
+ ax.imshow(sim_map_resized, cmap=cmap, alpha=alpha)
84
+
85
+ # Draw patch grid
86
+ patch_w = W / sim_map.shape[1]
87
+ patch_h = H / sim_map.shape[0]
88
+ for i in range(1, sim_map.shape[1]):
89
+ ax.axvline(i * patch_w, color='white', linestyle='--', linewidth=0.8)
90
+ for i in range(1, sim_map.shape[0]):
91
+ ax.axhline(i * patch_h, color='white', linestyle='--', linewidth=0.8)
92
+
93
+ ax.axis('off')
94
+
95
+ # Convert figure to PIL image (so Gradio can show it)
96
+ buf = io.BytesIO()
97
+ fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
98
+ plt.close(fig)
99
+ buf.seek(0)
100
+ overlay_img = Image.open(buf)
101
+
102
+ return overlay_img
103
+
104
+ # img = Image.open("two-cats.jpg")
105
+ # patch_embs,patch_embs_norm,rows,cols= get_patch_embeddings(img,ps=16, device=device)
106
+ # display_image(img,rows,cols)
107
+ # sim_map = compute_patch_similarity(patch_embs, patch_embs_norm, 7, 7)
108
+ # result_img = overlay_similarity(img,sim_map)
109
+ # plt.imshow(result_img)
110
+ # plt.savefig("overlay_result.png")
111
+ # plt.show()