import base64
import io
import math
from typing import Any, Dict, Tuple

import numpy as np
import torch
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend for server environments
import matplotlib.pyplot as plt
from PIL import Image


class PatchAttentionAnalyzer:
    """Utility class for computing and visualizing patch-level attention between images."""

    def __init__(self, embedding_model):
        self.embedding_model = embedding_model
        self.supports_native_attention = (
            hasattr(embedding_model, 'supports_native_attention')
            and embedding_model.supports_native_attention()
        )

    def compute_patch_similarities(self, query_image: Image.Image,
                                   candidate_image: Image.Image) -> Dict[str, Any]:
        """
        Compute patch-level similarities between query and candidate images.
        Automatically uses native attention if the model supports it.

        Returns:
            Dictionary containing attention matrix, top correspondences, and metadata.
        """
        # Use native attention if available
        if self.supports_native_attention:
            return self.compute_native_attention_similarities(query_image, candidate_image)

        # Fall back to the cosine-similarity approach
        try:
            # Get patch features for both images
            query_patches = self.embedding_model.encode_image_patches(query_image)
            candidate_patches = self.embedding_model.encode_image_patches(candidate_image)

            # Compute attention matrix
            attention_matrix = self.embedding_model.compute_patch_attention(query_patches, candidate_patches)

            # Get grid dimensions (assuming square patch grids for ViT models)
            query_grid_size = int(math.sqrt(query_patches.shape[0]))
            candidate_grid_size = int(math.sqrt(candidate_patches.shape[0]))

            # Find top correspondences for each query patch
            top_correspondences = []
            for i in range(attention_matrix.shape[0]):
                patch_similarities = attention_matrix[i]
                top_indices = torch.topk(patch_similarities, k=min(5, patch_similarities.shape[0]))
                top_correspondences.append({
                    'query_patch_idx': i,
                    'query_patch_coord': self._patch_idx_to_coord(i, query_grid_size),
                    'top_candidate_indices': top_indices.indices.tolist(),
                    'top_candidate_coords': [self._patch_idx_to_coord(idx.item(), candidate_grid_size)
                                             for idx in top_indices.indices],
                    'similarity_scores': top_indices.values.tolist()
                })

            return {
                'attention_matrix': attention_matrix.cpu().numpy(),
                'query_grid_size': query_grid_size,
                'candidate_grid_size': candidate_grid_size,
                'top_correspondences': top_correspondences,
                'query_patches_shape': query_patches.shape,
                'candidate_patches_shape': candidate_patches.shape,
                'overall_similarity': torch.mean(attention_matrix).item()
            }
        except NotImplementedError:
            raise ValueError(f"Patch-level encoding not supported for {self.embedding_model.get_model_name()}")
        except Exception as e:
            raise RuntimeError(f"Error computing patch similarities: {e}")

    def _patch_idx_to_coord(self, patch_idx: int, grid_size: int) -> Tuple[int, int]:
        """Convert flat patch index to (row, col) coordinate."""
        row = patch_idx // grid_size
        col = patch_idx % grid_size
        return (row, col)

    def visualize_attention_heatmap(self, query_image: Image.Image, candidate_image: Image.Image,
                                    similarity_data: Dict[str, Any],
                                    figsize: Tuple[int, int] = (15, 10)) -> str:
        """
        Create a visualization showing the attention heatmap between patches.

        Returns:
            Base64-encoded PNG image.
""" attention_matrix = similarity_data['attention_matrix'] query_grid_size = similarity_data['query_grid_size'] candidate_grid_size = similarity_data['candidate_grid_size'] fig, axes = plt.subplots(2, 2, figsize=figsize) fig.suptitle(f'Patch Attention Analysis - Overall Similarity: {similarity_data["overall_similarity"]:.3f}', fontsize=14, fontweight='bold') # Plot original images axes[0, 0].imshow(query_image) axes[0, 0].set_title('Query Image') axes[0, 0].axis('off') self._overlay_patch_grid(axes[0, 0], query_image.size, query_grid_size) axes[0, 1].imshow(candidate_image) axes[0, 1].set_title('Candidate Image') axes[0, 1].axis('off') self._overlay_patch_grid(axes[0, 1], candidate_image.size, candidate_grid_size) # Plot attention matrix im = axes[1, 0].imshow(attention_matrix, cmap='viridis', aspect='auto') axes[1, 0].set_title('Attention Matrix') axes[1, 0].set_xlabel('Candidate Patches') axes[1, 0].set_ylabel('Query Patches') plt.colorbar(im, ax=axes[1, 0], fraction=0.046, pad=0.04) # Plot attention summary (max attention per query patch) max_attention_per_query = np.max(attention_matrix, axis=1) attention_grid = max_attention_per_query.reshape(query_grid_size, query_grid_size) im2 = axes[1, 1].imshow(attention_grid, cmap='hot', interpolation='nearest') axes[1, 1].set_title('Max Attention per Query Patch') axes[1, 1].set_xlabel('Patch Column') axes[1, 1].set_ylabel('Patch Row') plt.colorbar(im2, ax=axes[1, 1], fraction=0.046, pad=0.04) plt.tight_layout() # Convert to base64 buffer = io.BytesIO() plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight') buffer.seek(0) plot_data = buffer.getvalue() buffer.close() plt.close() return base64.b64encode(plot_data).decode() def visualize_top_correspondences(self, query_image: Image.Image, candidate_image: Image.Image, similarity_data: Dict[str, Any], num_top_patches: int = 6) -> str: """ Visualize the top corresponding patches between query and candidate images. Returns base64 encoded PNG image. 
""" top_correspondences = similarity_data['top_correspondences'] query_grid_size = similarity_data['query_grid_size'] candidate_grid_size = similarity_data['candidate_grid_size'] # Sort by best similarity score sorted_correspondences = sorted( top_correspondences, key=lambda x: max(x['similarity_scores']), reverse=True )[:num_top_patches] fig, axes = plt.subplots(2, num_top_patches, figsize=(3*num_top_patches, 6)) fig.suptitle('Top Patch Correspondences', fontsize=14, fontweight='bold') for i, correspondence in enumerate(sorted_correspondences): query_coord = correspondence['query_patch_coord'] best_candidate_coord = correspondence['top_candidate_coords'][0] best_score = correspondence['similarity_scores'][0] # Extract and show query patch query_patch = self._extract_patch_from_image(query_image, query_coord, query_grid_size) axes[0, i].imshow(query_patch) axes[0, i].set_title(f'Q-Patch {query_coord}\nScore: {best_score:.3f}') axes[0, i].axis('off') # Extract and show best matching candidate patch candidate_patch = self._extract_patch_from_image(candidate_image, best_candidate_coord, candidate_grid_size) axes[1, i].imshow(candidate_patch) axes[1, i].set_title(f'C-Patch {best_candidate_coord}') axes[1, i].axis('off') plt.tight_layout() # Convert to base64 buffer = io.BytesIO() plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight') buffer.seek(0) plot_data = buffer.getvalue() buffer.close() plt.close() return base64.b64encode(plot_data).decode() def _overlay_patch_grid(self, ax, image_size: Tuple[int, int], grid_size: int): """Overlay patch grid lines on image.""" width, height = image_size patch_width = width / grid_size patch_height = height / grid_size # Draw vertical lines for i in range(1, grid_size): x = i * patch_width ax.axvline(x=x, color='white', alpha=0.5, linewidth=1) # Draw horizontal lines for i in range(1, grid_size): y = i * patch_height ax.axhline(y=y, color='white', alpha=0.5, linewidth=1) def _extract_patch_from_image(self, image: Image.Image, patch_coord: Tuple[int, int], grid_size: int) -> Image.Image: """Extract a specific patch from an image based on grid coordinates.""" row, col = patch_coord width, height = image.size patch_width = width // grid_size patch_height = height // grid_size left = col * patch_width top = row * patch_height right = min((col + 1) * patch_width, width) bottom = min((row + 1) * patch_height, height) return image.crop((left, top, right, bottom)) def compute_native_attention_similarities(self, query_image: Image.Image, candidate_image: Image.Image) -> Dict[str, Any]: """ Compute patch-level similarities using native attention mechanism. Only available for models with native attention support (e.g., DINOv2 with registers). 

        Returns:
            Dictionary containing attention matrix, top correspondences, and metadata.
        """
        try:
            # Use the model's cross-attention computation
            attention_matrix = self.embedding_model.compute_cross_attention(query_image, candidate_image)
            attention_matrix_np = attention_matrix.cpu().numpy()

            # Get patch counts (attention_matrix is already query_patches x candidate_patches)
            num_query_patches = attention_matrix.shape[0]
            num_candidate_patches = attention_matrix.shape[1]

            # Get grid dimensions (assuming square patch grids)
            query_grid_size = int(math.sqrt(num_query_patches))
            candidate_grid_size = int(math.sqrt(num_candidate_patches))

            # Find top correspondences for each query patch
            top_correspondences = []
            for i in range(num_query_patches):
                patch_similarities = attention_matrix[i]
                top_indices = torch.topk(patch_similarities, k=min(5, num_candidate_patches))
                top_correspondences.append({
                    'query_patch_idx': i,
                    'query_patch_coord': self._patch_idx_to_coord(i, query_grid_size),
                    'top_candidate_indices': top_indices.indices.tolist(),
                    'top_candidate_coords': [self._patch_idx_to_coord(idx.item(), candidate_grid_size)
                                             for idx in top_indices.indices],
                    'similarity_scores': top_indices.values.tolist()
                })

            return {
                'attention_matrix': attention_matrix_np,
                'query_grid_size': query_grid_size,
                'candidate_grid_size': candidate_grid_size,
                'top_correspondences': top_correspondences,
                # No feature dimension is available here, so the second element of each
                # shape tuple is just the attention matrix's trailing dimension; only the
                # patch counts (first element) are meaningful.
                'query_patches_shape': (num_query_patches, attention_matrix.shape[-1]),
                'candidate_patches_shape': (num_candidate_patches, attention_matrix.shape[-1]),
                'overall_similarity': torch.mean(attention_matrix).item(),
                'use_native_attention': True
            }
        except Exception as e:
            raise RuntimeError(f"Error computing native attention similarities: {e}")

    def get_similarity_summary(self, similarity_data: Dict[str, Any]) -> Dict[str, Any]:
        """Get a summary of similarity statistics."""
        attention_matrix = similarity_data['attention_matrix']

        summary = {
            'overall_similarity': similarity_data['overall_similarity'],
            'max_similarity': float(np.max(attention_matrix)),
            'min_similarity': float(np.min(attention_matrix)),
            'std_similarity': float(np.std(attention_matrix)),
            'query_patches_count': similarity_data['query_patches_shape'][0],
            'candidate_patches_count': similarity_data['candidate_patches_shape'][0],
            'high_attention_patches': int(np.sum(attention_matrix > (np.mean(attention_matrix) + np.std(attention_matrix)))),
            'model_name': self.embedding_model.get_model_name()
        }

        # Add native attention flag if present
        if 'use_native_attention' in similarity_data:
            summary['use_native_attention'] = similarity_data['use_native_attention']

        return summary

    def visualize_multihead_attention(self, image: Image.Image, layer_idx: int = -1,
                                      figsize: Tuple[int, int] = (20, 12)) -> str:
        """
        Visualize attention from multiple heads for a single image.
        Only available for models with native attention support.

        Args:
            image: Input image to visualize attention for.
            layer_idx: Which transformer layer to visualize (-1 for the last layer).
            figsize: Figure size for the plot.

        Returns:
            Base64-encoded PNG image showing multi-head attention patterns.
        """
        if not self.supports_native_attention:
            raise ValueError("Multi-head attention visualization requires native attention support")

        try:
            # Get attention maps from the model
            attention_maps = self.embedding_model.get_attention_maps(image)
            # Shape: (num_layers, num_heads, num_tokens, num_tokens)

            # Select the specified layer
            layer_attention = attention_maps[layer_idx]  # (num_heads, num_tokens, num_tokens)
            num_heads = layer_attention.shape[0]

            # Extract patch-to-patch attention (exclude the CLS token and register tokens).
            # Token sequence structure varies by model:
            #   DINOv2 with registers: [CLS] + 4 register tokens + 256 spatial patches (16x16) = 261 tokens
            #   DINOv3:                [CLS] + 4 register tokens + 196 spatial patches (14x14) = 201 tokens
            model_name = self.embedding_model.get_model_name().lower()
            if 'dinov3' in model_name:
                num_register_tokens = 4
                expected_patches = 196  # For a 224x224 image with patch size 16 (14*14 = 196)
            else:
                num_register_tokens = 4
                expected_patches = 256  # For a 224x224 image with patch size 14 (16*16 = 256)

            # Skip the CLS token (position 0) and register tokens (positions 1-4)
            start_idx = 1 + num_register_tokens
            end_idx = start_idx + expected_patches
            patch_attention = layer_attention[:, start_idx:end_idx, start_idx:end_idx]

            # Convert to numpy
            patch_attention_np = patch_attention.cpu().numpy()

            # Get grid size
            num_patches = patch_attention.shape[1]
            grid_size = int(math.sqrt(num_patches))

            # Create subplot grid
            num_cols = 4
            num_rows = (num_heads + num_cols - 1) // num_cols  # Ceiling division
            fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
            axes = np.atleast_1d(axes).ravel()  # Handles both the single-Axes and array cases

            layer_name = f"Layer {layer_idx}" if layer_idx >= 0 else f"Last Layer ({len(attention_maps)})"
            fig.suptitle(f'Multi-Head Attention Patterns - {layer_name}', fontsize=16, fontweight='bold')

            # Plot each head's average attention
            for head_idx in range(num_heads):
                # Average attention from all query patches to all key patches
                head_attn = patch_attention_np[head_idx]
                avg_attention = np.mean(head_attn, axis=0).reshape(grid_size, grid_size)

                im = axes[head_idx].imshow(avg_attention, cmap='viridis', interpolation='nearest')
                axes[head_idx].set_title(f'Head {head_idx + 1}')
                axes[head_idx].axis('off')
                plt.colorbar(im, ax=axes[head_idx], fraction=0.046, pad=0.04)

            # Hide unused subplots
            for idx in range(num_heads, len(axes)):
                axes[idx].axis('off')

            plt.tight_layout()

            # Convert to base64
            buffer = io.BytesIO()
            plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
            buffer.seek(0)
            plot_data = buffer.getvalue()
            buffer.close()
            plt.close()

            return base64.b64encode(plot_data).decode()
        except Exception as e:
            raise RuntimeError(f"Error visualizing multi-head attention: {e}")

    def visualize_attention_comparison(self, query_image: Image.Image, candidate_image: Image.Image,
                                       figsize: Tuple[int, int] = (20, 10)) -> str:
        """
        Compare native attention and computed cosine similarity side by side.
        Only available for models with native attention support.

        Args:
            query_image: Query image.
            candidate_image: Candidate image.
            figsize: Figure size for the plot.

        Returns:
            Base64-encoded PNG showing both attention methods.
        """
        if not self.supports_native_attention:
            raise ValueError("Attention comparison requires native attention support")

        try:
            # Compute native attention
            native_data = self.compute_native_attention_similarities(query_image, candidate_image)

            # Compute cosine similarity for comparison
            query_patches = self.embedding_model.encode_image_patches(query_image)
            candidate_patches = self.embedding_model.encode_image_patches(candidate_image)
            cosine_attention = self.embedding_model.compute_patch_attention(query_patches, candidate_patches)
            cosine_attention_np = cosine_attention.cpu().numpy()

            # Create comparison visualization
            fig, axes = plt.subplots(2, 3, figsize=figsize)
            fig.suptitle('Native Attention vs Cosine Similarity Comparison', fontsize=16, fontweight='bold')

            # Row 1: Native attention
            axes[0, 0].imshow(query_image)
            axes[0, 0].set_title('Query Image')
            axes[0, 0].axis('off')

            im1 = axes[0, 1].imshow(native_data['attention_matrix'], cmap='viridis', aspect='auto')
            axes[0, 1].set_title(f'Native Attention\n(Avg: {native_data["overall_similarity"]:.3f})')
            axes[0, 1].set_xlabel('Candidate Patches')
            axes[0, 1].set_ylabel('Query Patches')
            plt.colorbar(im1, ax=axes[0, 1], fraction=0.046, pad=0.04)

            # Max attention heatmap for native attention
            max_native = np.max(native_data['attention_matrix'], axis=1)
            native_grid = max_native.reshape(native_data['query_grid_size'], native_data['query_grid_size'])
            im2 = axes[0, 2].imshow(native_grid, cmap='hot', interpolation='nearest')
            axes[0, 2].set_title('Max Native Attention per Patch')
            plt.colorbar(im2, ax=axes[0, 2], fraction=0.046, pad=0.04)

            # Row 2: Cosine similarity
            axes[1, 0].imshow(candidate_image)
            axes[1, 0].set_title('Candidate Image')
            axes[1, 0].axis('off')

            cosine_mean = float(np.mean(cosine_attention_np))
            im3 = axes[1, 1].imshow(cosine_attention_np, cmap='viridis', aspect='auto')
            axes[1, 1].set_title(f'Cosine Similarity\n(Avg: {cosine_mean:.3f})')
            axes[1, 1].set_xlabel('Candidate Patches')
            axes[1, 1].set_ylabel('Query Patches')
            plt.colorbar(im3, ax=axes[1, 1], fraction=0.046, pad=0.04)

            # Max similarity heatmap for cosine similarity
            max_cosine = np.max(cosine_attention_np, axis=1)
            query_grid_size = int(math.sqrt(query_patches.shape[0]))
            cosine_grid = max_cosine.reshape(query_grid_size, query_grid_size)
            im4 = axes[1, 2].imshow(cosine_grid, cmap='hot', interpolation='nearest')
            axes[1, 2].set_title('Max Cosine Similarity per Patch')
            plt.colorbar(im4, ax=axes[1, 2], fraction=0.046, pad=0.04)

            plt.tight_layout()

            # Convert to base64
            buffer = io.BytesIO()
            plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
            buffer.seek(0)
            plot_data = buffer.getvalue()
            buffer.close()
            plt.close()

            return base64.b64encode(plot_data).decode()
        except Exception as e:
            raise RuntimeError(f"Error comparing attention methods: {e}")
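

# ---------------------------------------------------------------------------
# Illustrative sketch of the embedding-model interface this analyzer expects.
#
# PatchAttentionAnalyzer duck-types its model: it only calls the methods used
# above. The Protocol below (Python 3.8+) simply restates those methods in one
# place as a typing/documentation aid. It is an assumption about the wrapper's
# shape, including the torch.Tensor return types, and is not defined elsewhere
# in the codebase. The "native attention" methods are only required when
# supports_native_attention() returns True.
from typing import Protocol


class PatchEmbeddingModelProtocol(Protocol):
    def get_model_name(self) -> str: ...

    # Cosine-similarity fallback path
    def encode_image_patches(self, image: Image.Image) -> torch.Tensor: ...
    def compute_patch_attention(self, query_patches: torch.Tensor,
                                candidate_patches: torch.Tensor) -> torch.Tensor: ...

    # Native attention path (optional)
    def supports_native_attention(self) -> bool: ...
    def compute_cross_attention(self, query_image: Image.Image,
                                candidate_image: Image.Image) -> torch.Tensor: ...
    def get_attention_maps(self, image: Image.Image) -> torch.Tensor: ...


# Minimal usage sketch (the model wrapper and file paths below are hypothetical
# placeholders, not part of this module):
#
#     model = SomeDinoV2Wrapper()                      # any PatchEmbeddingModelProtocol
#     analyzer = PatchAttentionAnalyzer(model)
#     query = Image.open("query.jpg").convert("RGB")
#     candidate = Image.open("candidate.jpg").convert("RGB")
#
#     data = analyzer.compute_patch_similarities(query, candidate)
#     print(analyzer.get_similarity_summary(data))
#
#     png_b64 = analyzer.visualize_attention_heatmap(query, candidate, data)
#     with open("attention_heatmap.png", "wb") as f:
#         f.write(base64.b64decode(png_b64))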