Update utils.py
Browse files
utils.py
CHANGED
|
@@ -45,16 +45,16 @@ def display_image(img, rows,cols):
|
|
| 45 |
|
| 46 |
|
| 47 |
def get_patch_embeddings(img, ps=16, device="cuda"):
|
| 48 |
-
inputs = processor(images=img, return_tensors="pt").to(device, torch.float16)
|
| 49 |
B, C, H, W = inputs["pixel_values"].shape
|
| 50 |
-
rows, cols = H // ps, W // ps
|
| 51 |
|
| 52 |
with torch.no_grad():
|
| 53 |
out = model(**inputs)
|
| 54 |
|
| 55 |
hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy()
|
| 56 |
|
| 57 |
-
# remove CLS +
|
| 58 |
n_patches = rows * cols
|
| 59 |
patch_embs = hs[-n_patches:, :].reshape(rows, cols, -1)
|
| 60 |
|
|
@@ -62,22 +62,18 @@ def get_patch_embeddings(img, ps=16, device="cuda"):
|
|
| 62 |
X = patch_embs.reshape(-1, patch_embs.shape[-1])
|
| 63 |
Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-8)
|
| 64 |
|
| 65 |
-
return patch_embs, Xn, rows, cols
|
| 66 |
|
| 67 |
def compute_patch_similarity(patch_embs, patch_embs_norm, row, col):
|
| 68 |
rows, cols, dim = patch_embs.shape
|
| 69 |
patch_idx = row * cols + col # flatten index
|
| 70 |
|
| 71 |
-
# cosine similarity via dot product
|
| 72 |
-
sim = patch_embs_norm @ patch_embs_norm[patch_idx]
|
| 73 |
sim_map = sim.reshape(rows, cols)
|
| 74 |
sim_map = (sim_map - sim_map.min()) / (sim_map.max() - sim_map.min() + 1e-8)
|
| 75 |
-
|
| 76 |
-
|
| 77 |
return sim_map
|
| 78 |
|
| 79 |
def overlay_similarity(img, sim_map, alpha=0.5, cmap="hot"):
|
| 80 |
-
"""Draw heatmap overlay with grid and return as PIL image (for Gradio)."""
|
| 81 |
W, H = img.size
|
| 82 |
|
| 83 |
# Expand sim_map (14x14) to full resolution via Kronecker upsampling
|
|
@@ -87,7 +83,6 @@ def overlay_similarity(img, sim_map, alpha=0.5, cmap="hot"):
|
|
| 87 |
ax.imshow(img)
|
| 88 |
ax.imshow(sim_map_resized, cmap=cmap, alpha=alpha)
|
| 89 |
|
| 90 |
-
# Draw patch grid
|
| 91 |
patch_w = W / sim_map.shape[1]
|
| 92 |
patch_h = H / sim_map.shape[0]
|
| 93 |
for i in range(1, sim_map.shape[1]):
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
def get_patch_embeddings(img, ps=16, device="cuda"):
|
| 48 |
+
inputs = processor(images=img, return_tensors="pt").to(device, torch.float16) # preprocessing for image include scaling, normalization etc
|
| 49 |
B, C, H, W = inputs["pixel_values"].shape
|
| 50 |
+
rows, cols = H // ps, W // ps # image of size 224x224, patch size = 16x16, hence image has 14x14 patches
|
| 51 |
|
| 52 |
with torch.no_grad():
|
| 53 |
out = model(**inputs)
|
| 54 |
|
| 55 |
hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy()
|
| 56 |
|
| 57 |
+
# remove CLS + any non-patch token
|
| 58 |
n_patches = rows * cols
|
| 59 |
patch_embs = hs[-n_patches:, :].reshape(rows, cols, -1)
|
| 60 |
|
|
|
|
| 62 |
X = patch_embs.reshape(-1, patch_embs.shape[-1])
|
| 63 |
Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-8)
|
| 64 |
|
| 65 |
+
return patch_embs, Xn, rows, cols # list of normalized patch vectors
|
| 66 |
|
| 67 |
def compute_patch_similarity(patch_embs, patch_embs_norm, row, col):
    """Return a min-max scaled similarity map between one patch and all patches.

    Args:
        patch_embs: Array of shape ``(rows, cols, dim)``. Only its shape is
            read here, to recover the patch grid dimensions.
        patch_embs_norm: Array of shape ``(rows * cols, dim)`` holding one
            vector per patch in row-major order. The dot product below equals
            cosine similarity only if these rows are L2-normalised.
        row: Zero-based grid row of the query patch.
        col: Zero-based grid column of the query patch.

    Returns:
        Array of shape ``(rows, cols)`` rescaled so its minimum is 0 and its
        maximum is (approximately) 1. A constant map comes back as all zeros.

    Raises:
        IndexError: If ``(row, col)`` lies outside the patch grid.
    """
    rows, cols, _ = patch_embs.shape

    # Validate each axis separately: an out-of-range ``col`` (or a negative
    # index) would otherwise wrap through the flattened index and silently
    # select a different patch.
    if not (0 <= row < rows and 0 <= col < cols):
        raise IndexError(
            f"patch ({row}, {col}) is outside the {rows}x{cols} patch grid"
        )
    patch_idx = row * cols + col  # row-major flattened index

    # Work in at least float32: in float16 the 1e-8 epsilon underflows to 0,
    # so a constant similarity map would become 0 / 0 = NaN.
    normed = np.asarray(patch_embs_norm)
    normed = normed.astype(np.result_type(normed.dtype, np.float32), copy=False)

    # Dot product of the query vector with every patch vector.
    sim_map = (normed @ normed[patch_idx]).reshape(rows, cols)

    lowest = sim_map.min()
    highest = sim_map.max()
    return (sim_map - lowest) / (highest - lowest + 1e-8)
|
| 75 |
|
| 76 |
def overlay_similarity(img, sim_map, alpha=0.5, cmap="hot"):
|
|
|
|
| 77 |
W, H = img.size
|
| 78 |
|
| 79 |
# Expand sim_map (14x14) to full resolution via Kronecker upsampling
|
|
|
|
| 83 |
ax.imshow(img)
|
| 84 |
ax.imshow(sim_map_resized, cmap=cmap, alpha=alpha)
|
| 85 |
|
|
|
|
| 86 |
patch_w = W / sim_map.shape[1]
|
| 87 |
patch_h = H / sim_map.shape[0]
|
| 88 |
for i in range(1, sim_map.shape[1]):
|