NeerjaK committed on
Commit
3789875
·
verified ·
1 Parent(s): 5492757

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +5 -10
utils.py CHANGED
@@ -45,16 +45,16 @@ def display_image(img, rows,cols):
45
 
46
 
47
  def get_patch_embeddings(img, ps=16, device="cuda"):
48
- inputs = processor(images=img, return_tensors="pt").to(device, torch.float16)
49
  B, C, H, W = inputs["pixel_values"].shape
50
- rows, cols = H // ps, W // ps
51
 
52
  with torch.no_grad():
53
  out = model(**inputs)
54
 
55
  hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy()
56
 
57
- # remove CLS + register tokens
58
  n_patches = rows * cols
59
  patch_embs = hs[-n_patches:, :].reshape(rows, cols, -1)
60
 
@@ -62,22 +62,18 @@ def get_patch_embeddings(img, ps=16, device="cuda"):
62
  X = patch_embs.reshape(-1, patch_embs.shape[-1])
63
  Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-8)
64
 
65
- return patch_embs, Xn, rows, cols
66
 
67
  def compute_patch_similarity(patch_embs, patch_embs_norm, row, col):
68
  rows, cols, dim = patch_embs.shape
69
  patch_idx = row * cols + col # flatten index
70
 
71
- # cosine similarity via dot product
72
- sim = patch_embs_norm @ patch_embs_norm[patch_idx]
73
  sim_map = sim.reshape(rows, cols)
74
  sim_map = (sim_map - sim_map.min()) / (sim_map.max() - sim_map.min() + 1e-8)
75
-
76
-
77
  return sim_map
78
 
79
  def overlay_similarity(img, sim_map, alpha=0.5, cmap="hot"):
80
- """Draw heatmap overlay with grid and return as PIL image (for Gradio)."""
81
  W, H = img.size
82
 
83
  # Expand sim_map (14x14) to full resolution via Kronecker upsampling
@@ -87,7 +83,6 @@ def overlay_similarity(img, sim_map, alpha=0.5, cmap="hot"):
87
  ax.imshow(img)
88
  ax.imshow(sim_map_resized, cmap=cmap, alpha=alpha)
89
 
90
- # Draw patch grid
91
  patch_w = W / sim_map.shape[1]
92
  patch_h = H / sim_map.shape[0]
93
  for i in range(1, sim_map.shape[1]):
 
45
 
46
 
47
  def get_patch_embeddings(img, ps=16, device="cuda"):
48
+ inputs = processor(images=img, return_tensors="pt").to(device, torch.float16) # preprocessing for image include scaling, normalization etc
49
  B, C, H, W = inputs["pixel_values"].shape
50
+ rows, cols = H // ps, W // ps # image of size 224x224, patch size = 16x16, hence image has 14x14 patches
51
 
52
  with torch.no_grad():
53
  out = model(**inputs)
54
 
55
  hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy()
56
 
57
+ # remove CLS + any non-patch token
58
  n_patches = rows * cols
59
  patch_embs = hs[-n_patches:, :].reshape(rows, cols, -1)
60
 
 
62
  X = patch_embs.reshape(-1, patch_embs.shape[-1])
63
  Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-8)
64
 
65
+ return patch_embs, Xn, rows, cols # list of normalized patch vectors
66
 
67
  def compute_patch_similarity(patch_embs, patch_embs_norm, row, col):
68
  rows, cols, dim = patch_embs.shape
69
  patch_idx = row * cols + col # flatten index
70
 
71
+ sim = patch_embs_norm @ patch_embs_norm[patch_idx] # cosine similarity via dot product
 
72
  sim_map = sim.reshape(rows, cols)
73
  sim_map = (sim_map - sim_map.min()) / (sim_map.max() - sim_map.min() + 1e-8)
 
 
74
  return sim_map
75
 
76
  def overlay_similarity(img, sim_map, alpha=0.5, cmap="hot"):
 
77
  W, H = img.size
78
 
79
  # Expand sim_map (14x14) to full resolution via Kronecker upsampling
 
83
  ax.imshow(img)
84
  ax.imshow(sim_map_resized, cmap=cmap, alpha=alpha)
85
 
 
86
  patch_w = W / sim_map.shape[1]
87
  patch_h = H / sim_map.shape[0]
88
  for i in range(1, sim_map.shape[1]):