#!/usr/bin/env python3
"""
ASI V2.5 - HuggingFace Spaces Compatible Version

Optimized for a CPU-only environment with a 16 GB RAM limit.
Fixes all dimension errors and is tuned for Spaces hardware.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class HFCompatibleASIAttention(nn.Module):
    """
    ASI V2.5 compatible with HuggingFace Spaces.

    Key fixes:
    - Proper dimension handling for the CPU environment
    - Memory optimized for the 16 GB RAM limit
    - No GPU dependencies
    - Fixed matrix multiplication errors
    """

    def __init__(self, hidden_size=768, num_heads=12, threshold=8, feature_dim=4):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.threshold = threshold
        self.feature_dim = feature_dim

        # Validation
        assert hidden_size % num_heads == 0, \
            f"hidden_size {hidden_size} not divisible by num_heads {num_heads}"

        # Standard attention projections
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        # ASI feature mapping - FIXED dimensions
        # Map from head_dim to feature_dim for each head
        self.feature_map = nn.Linear(self.head_dim, feature_dim, bias=False)

        self.scale = self.head_dim ** -0.5

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Fixed forward pass with proper dimension handling."""
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V
        q = self.q_proj(hidden_states)  # [B, L, H]
        k = self.k_proj(hidden_states)  # [B, L, H]
        v = self.v_proj(hidden_states)  # [B, L, H]

        # Reshape for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L, D]
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L, D]
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L, D]

        # ASI adaptive attention
        if seq_len <= self.threshold:
            # Exact attention for short sequences
            attn_output = self._exact_attention(q, k, v, attention_mask)
        else:
            # Linear attention for long sequences - FIXED VERSION
            attn_output = self._linear_attention_fixed(q, k, v, attention_mask)

        # Reshape back and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.hidden_size
        )
        attn_output = self.o_proj(attn_output)

        return attn_output, None, None  # Match the expected HF signature

    def _exact_attention(self, q, k, v, attention_mask=None):
        """Standard O(L²) attention."""
        # q, k, v: [B, H, L, D]
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # [B, H, L, L]

        if attention_mask is not None:
            # Apply the padding mask
            mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, L]
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = torch.softmax(scores, dim=-1)  # [B, H, L, L]
        attn_output = torch.matmul(attn_weights, v)   # [B, H, L, D]

        return attn_output

    def _linear_attention_fixed(self, q, k, v, attention_mask=None):
        """
        FIXED linear attention for O(L) complexity.
        Properly handles dimensions for HuggingFace Spaces.
        """
        # q, k, v: [B, H, L, D] where D = head_dim
        batch_size, num_heads, seq_len, head_dim = q.shape

        # Apply the feature mapping to reduce the dimension
        # Reshape for the feature map: [B*H*L, D] -> [B*H*L, F]
        q_reshaped = q.reshape(-1, head_dim)  # [B*H*L, D]
        k_reshaped = k.reshape(-1, head_dim)  # [B*H*L, D]

        q_feat = self.feature_map(q_reshaped)  # [B*H*L, F]
        k_feat = self.feature_map(k_reshaped)  # [B*H*L, F]

        # Reshape back: [B*H*L, F] -> [B, H, L, F]
        q_feat = q_feat.view(batch_size, num_heads, seq_len, self.feature_dim)
        k_feat = k_feat.view(batch_size, num_heads, seq_len, self.feature_dim)

        # Apply the attention mask to the keys if provided
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(1).unsqueeze(-1)  # [B, 1, L, 1]
            k_feat = k_feat * mask.float()

        # Linear attention computation - FIXED DIMENSIONS
        # Step 1: K^T @ V
        # k_feat: [B, H, L, F], v: [B, H, L, D] -> kv: [B, H, F, D]
        kv = torch.matmul(k_feat.transpose(-2, -1), v)  # [B, H, F, D]

        # Step 2: Q @ (K^T @ V)
        # q_feat: [B, H, L, F], kv: [B, H, F, D] -> attn_output: [B, H, L, D]
        attn_output = torch.matmul(q_feat, kv)  # [B, H, L, D]

        # Step 3: Normalization - FIXED
        # k_feat: [B, H, L, F] -> k_sum: [B, H, 1, F]
        k_sum = k_feat.sum(dim=-2, keepdim=True)  # [B, H, 1, F]

        # q_feat: [B, H, L, F], k_sum: [B, H, 1, F] -> normalization: [B, H, L, 1]
        # Use einsum for clearer dimension handling
        normalization = torch.einsum('bhlf,bhf->bhl', q_feat, k_sum.squeeze(-2))  # [B, H, L]
        normalization = normalization.unsqueeze(-1)  # [B, H, L, 1]

        # Prevent division by zero and normalize
        attn_output = attn_output / (normalization + 1e-8)

        return attn_output
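
# ---------------------------------------------------------------------------
# Illustrative sanity check (a sketch, not part of the original module): the
# linear path above relies on the associativity identity
#     (Q_f @ K_f^T) @ V == Q_f @ (K_f^T @ V)
# together with the matching denominator (Q_f @ K_f^T) @ 1 == Q_f @ (K_f^T @ 1),
# which is why the [L, L] weight matrix is never materialized. The helper name
# and shapes below are chosen for illustration only; non-negative features are
# used here so the shared denominator stays safely away from zero.
def _check_linear_attention_identity(seq_len=32, feature_dim=4, head_dim=64):
    """Compare the explicit O(L²) formulation against the O(L) factorization."""
    torch.manual_seed(0)
    q_f = torch.rand(seq_len, feature_dim)   # stand-in for feature-mapped queries
    k_f = torch.rand(seq_len, feature_dim)   # stand-in for feature-mapped keys
    v = torch.randn(seq_len, head_dim)

    # Explicit formulation: materialize the full [L, L] weight matrix
    weights = q_f @ k_f.t()                                        # [L, L]
    explicit = (weights @ v) / (weights.sum(-1, keepdim=True) + 1e-8)

    # Factorized formulation, as in _linear_attention_fixed
    kv = k_f.t() @ v                                               # [F, D]
    k_sum = k_f.sum(dim=0)                                         # [F]
    factorized = (q_f @ kv) / ((q_f @ k_sum).unsqueeze(-1) + 1e-8)

    assert torch.allclose(explicit, factorized, atol=1e-5)
    return True
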
def create_hf_asi_attention(dim=768, num_heads=12, threshold=8, feature_dim=4):
    """Factory function for the HF Spaces compatible ASI attention."""
    return HFCompatibleASIAttention(
        hidden_size=dim,
        num_heads=num_heads,
        threshold=threshold,
        feature_dim=feature_dim,
    )


# Test function
def test_hf_asi():
    """Test the HF compatible ASI implementation."""
    batch_size, seq_len, hidden_size = 1, 512, 768
    device = "cpu"  # HF Spaces is CPU-only

    # Create test data
    hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device)

    # Create the ASI attention module
    asi_attention = create_hf_asi_attention(dim=hidden_size, threshold=8, feature_dim=4)
    asi_attention.to(device)

    # Test the forward pass
    with torch.no_grad():
        output, _, _ = asi_attention(hidden_states)

    print(f"✅ Input shape: {hidden_states.shape}")
    print(f"✅ Output shape: {output.shape}")
    print("✅ ASI test passed!")

    return True


if __name__ == "__main__":
    test_hf_asi()
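
# ---------------------------------------------------------------------------
# Rough memory comparison (a sketch, not part of the original module): the
# exact path materializes a [B, H, L, L] score tensor, while the linear path
# only builds a [B, H, F, D] summary plus [B, H, L, F] features, which is what
# keeps long sequences inside the 16 GB RAM budget mentioned above. The helper
# name below is illustrative; figures assume float32 (4 bytes per element).
def estimate_attention_memory_mb(seq_len, batch_size=1, num_heads=12,
                                 head_dim=64, feature_dim=4):
    """Return (exact_scores_mib, linear_summary_mib) for float32 tensors."""
    bytes_per_float = 4
    exact_scores = batch_size * num_heads * seq_len * seq_len * bytes_per_float
    linear_summary = batch_size * num_heads * feature_dim * head_dim * bytes_per_float
    return exact_scores / 2**20, linear_summary / 2**20


# Example: at L = 4096 the exact score tensor alone is 768 MiB, while the
# linear-attention summary is well under 1 MiB.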