from transformers.image_utils import load_image
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from transformers import AutoImageProcessor, AutoModel
import os
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Gated checkpoint: requires a Hugging Face access token with access to the DINOv3 models
processor = AutoImageProcessor.from_pretrained(
    "facebook/dinov3-vits16-pretrain-lvd1689m",
    token=os.environ.get("HF_TOKEN"),
)
model = AutoModel.from_pretrained(
    "facebook/dinov3-vits16-pretrain-lvd1689m",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa",
    token=os.environ.get("HF_TOKEN"),
)
model.eval()
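
# Note (an assumption based on the DINOv3 release, not spelled out in this script):
# last_hidden_state holds the CLS token and a few register tokens *before* the
# patch tokens, which is why get_patch_embeddings below keeps only the last
# rows*cols rows of the sequence.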
def display_image(img, rows, cols):
    """Show the image with the patch grid overlaid (rows x cols patches)."""
    W, H = img.size
    patch_w = W / cols
    patch_h = H / rows
    plt.figure(figsize=(8, 8))
    plt.imshow(img)
    # Vertical lines: one per patch column
    for i in range(1, cols):
        plt.axvline(i * patch_w, color='white', linestyle='--', linewidth=0.8)
    # Horizontal lines: one per patch row
    for i in range(1, rows):
        plt.axhline(i * patch_h, color='white', linestyle='--', linewidth=0.8)
    plt.axis('off')
    plt.show()
def get_patch_embeddings(img, ps=16, device="cuda"):
    """Return per-patch embeddings, their L2-normalized copies, and the patch grid size."""
    # Preprocessing: resizing, rescaling, normalization, etc.
    inputs = processor(images=img, return_tensors="pt").to(device, torch.float16)
    B, C, H, W = inputs["pixel_values"].shape
    rows, cols = H // ps, W // ps  # e.g. a 224x224 input with 16x16 patches gives a 14x14 grid
    with torch.no_grad():
        out = model(**inputs)
    hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy()
    # Remove the CLS token and any other non-patch tokens: the patch tokens come last
    n_patches = rows * cols
    patch_embs = hs[-n_patches:, :].reshape(rows, cols, -1)
    # Flatten to (n_patches, dim) and L2-normalize for cosine similarity
    X = patch_embs.reshape(-1, patch_embs.shape[-1])
    Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-8)
    return patch_embs, Xn, rows, cols  # grid embeddings + normalized patch vectors
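
# A quick sanity-check sketch (assuming the default processor produces 224x224
# inputs and the ViT-S/16 hidden size of 384):
#   patch_embs, Xn, rows, cols = get_patch_embeddings(img, device=device)
#   assert patch_embs.shape == (14, 14, 384) and Xn.shape == (196, 384)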
def compute_patch_similarity(patch_embs, patch_embs_norm, row, col):
    """Cosine similarity of every patch to the patch at (row, col), min-max scaled to [0, 1]."""
    rows, cols, dim = patch_embs.shape
    patch_idx = row * cols + col  # flatten (row, col) into a sequence index
    sim = patch_embs_norm @ patch_embs_norm[patch_idx]  # dot product of unit vectors = cosine similarity
    sim_map = sim.reshape(rows, cols)
    sim_map = (sim_map - sim_map.min()) / (sim_map.max() - sim_map.min() + 1e-8)  # rescale for display
    return sim_map
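
# Illustrative usage: on a 14x14 grid, compute_patch_similarity(patch_embs, Xn, 7, 7)
# queries the central patch; values near 1 mark the patches most similar to it.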
def overlay_similarity(img, sim_map, alpha=0.5, cmap="hot"):
    """Render the similarity heatmap on top of the image and return the result as a PIL image."""
    W, H = img.size
    # Upsample the patch-grid map (e.g. 14x14) to full resolution by repeating each cell
    sim_map_resized = np.kron(sim_map, np.ones((H // sim_map.shape[0], W // sim_map.shape[1])))
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(img)
    ax.imshow(sim_map_resized, cmap=cmap, alpha=alpha)
    patch_w = W / sim_map.shape[1]
    patch_h = H / sim_map.shape[0]
    for i in range(1, sim_map.shape[1]):
        ax.axvline(i * patch_w, color='white', linestyle='--', linewidth=0.8)
    for i in range(1, sim_map.shape[0]):
        ax.axhline(i * patch_h, color='white', linestyle='--', linewidth=0.8)
    ax.axis('off')
    # Round-trip through an in-memory PNG to return a plain PIL image
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    buf.seek(0)
    overlay_img = Image.open(buf)
    return overlay_img
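
# Note: the np.kron upsampling only lines up exactly when H and W are multiples
# of the grid size; for arbitrary image sizes, a resize-based variant is a safer
# sketch:
#   sim_img = Image.fromarray((sim_map * 255).astype(np.uint8)).resize((W, H), Image.NEAREST)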
if __name__ == "__main__":
    img = load_image("two-cats.jpg")  # load_image accepts a local path or a URL
    patch_embs, patch_embs_norm, rows, cols = get_patch_embeddings(img, ps=16, device=device)
    display_image(img, rows, cols)
    sim_map = compute_patch_similarity(patch_embs, patch_embs_norm, 7, 7)
    result_img = overlay_similarity(img, sim_map)
    result_img.save("overlay_result.png")
    plt.imshow(result_img)
    plt.axis('off')
    plt.show()