#!/usr/bin/env python3
"""
ASI V2.5 - HuggingFace Spaces Compatible Version
Optimized for the CPU environment with its 16GB RAM limitation.
Fixed all dimension errors and optimized for Spaces hardware.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class HFCompatibleASIAttention(nn.Module):
    """
    ASI V2.5 compatible with HuggingFace Spaces.

    Key fixes:
    - Proper dimension handling for the CPU environment
    - Memory optimized for the 16GB RAM limit
    - No GPU dependencies
    - Fixed matrix multiplication errors
    """

    def __init__(self, hidden_size=768, num_heads=12, threshold=8, feature_dim=4):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.threshold = threshold
        self.feature_dim = feature_dim

        # Validation
        assert hidden_size % num_heads == 0, f"hidden_size {hidden_size} not divisible by num_heads {num_heads}"

        # Standard attention projections
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        # ASI feature mapping with fixed dimensions:
        # map from head_dim to feature_dim for each head
        self.feature_map = nn.Linear(self.head_dim, feature_dim, bias=False)

        self.scale = self.head_dim ** -0.5

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Forward pass with proper dimension handling."""
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V
        q = self.q_proj(hidden_states)  # [B, L, H]
        k = self.k_proj(hidden_states)  # [B, L, H]
        v = self.v_proj(hidden_states)  # [B, L, H]

        # Reshape for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L, D]
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L, D]
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L, D]

        # ASI adaptive attention: exact attention for short sequences,
        # linear attention for long sequences
        if seq_len <= self.threshold:
            attn_output = self._exact_attention(q, k, v, attention_mask)
        else:
            attn_output = self._linear_attention_fixed(q, k, v, attention_mask)

        # Reshape back and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.hidden_size
        )
        attn_output = self.o_proj(attn_output)

        return attn_output, None, None  # Match the expected HF attention signature

    def _exact_attention(self, q, k, v, attention_mask=None):
        """Standard O(L²) attention."""
        # q, k, v: [B, H, L, D]
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # [B, H, L, L]

        if attention_mask is not None:
            # Mask out padded positions
            mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, L]
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = torch.softmax(scores, dim=-1)  # [B, H, L, L]
        attn_output = torch.matmul(attn_weights, v)  # [B, H, L, D]
        return attn_output

    def _linear_attention_fixed(self, q, k, v, attention_mask=None):
        """
        Linear attention with O(L) complexity.
        Properly handles dimensions for HuggingFace Spaces.
        """
        # q, k, v: [B, H, L, D] where D = head_dim
        batch_size, num_heads, seq_len, head_dim = q.shape

        # Apply the feature mapping to reduce the dimension.
        # Reshape for feature mapping: [B*H*L, D] -> [B*H*L, F]
        q_reshaped = q.reshape(-1, head_dim)  # [B*H*L, D]
        k_reshaped = k.reshape(-1, head_dim)  # [B*H*L, D]
        q_feat = self.feature_map(q_reshaped)  # [B*H*L, F]
        k_feat = self.feature_map(k_reshaped)  # [B*H*L, F]

        # Keep the mapped features positive (ELU + 1, as is standard for linear
        # attention) so the normalization below cannot vanish or change sign.
        q_feat = F.elu(q_feat) + 1.0
        k_feat = F.elu(k_feat) + 1.0

        # Reshape back: [B*H*L, F] -> [B, H, L, F]
        q_feat = q_feat.view(batch_size, num_heads, seq_len, self.feature_dim)
        k_feat = k_feat.view(batch_size, num_heads, seq_len, self.feature_dim)

        # Apply the attention mask to the keys if provided
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(1).unsqueeze(-1)  # [B, 1, L, 1]
            k_feat = k_feat * mask.float()

        # Linear attention computation
        # Step 1: K^T @ V
        # k_feat: [B, H, L, F], v: [B, H, L, D] -> kv: [B, H, F, D]
        kv = torch.matmul(k_feat.transpose(-2, -1), v)  # [B, H, F, D]

        # Step 2: Q @ (K^T @ V)
        # q_feat: [B, H, L, F], kv: [B, H, F, D] -> attn_output: [B, H, L, D]
        attn_output = torch.matmul(q_feat, kv)  # [B, H, L, D]

        # Step 3: Normalization
        # k_feat: [B, H, L, F] -> k_sum: [B, H, 1, F]
        k_sum = k_feat.sum(dim=-2, keepdim=True)  # [B, H, 1, F]
        # q_feat: [B, H, L, F], k_sum: [B, H, 1, F] -> normalization: [B, H, L, 1]
        # Use einsum for clearer dimension handling
        normalization = torch.einsum('bhlf,bhf->bhl', q_feat, k_sum.squeeze(-2))  # [B, H, L]
        normalization = normalization.unsqueeze(-1)  # [B, H, L, 1]

        # Prevent division by zero and normalize
        attn_output = attn_output / (normalization + 1e-8)

        return attn_output


def create_hf_asi_attention(dim=768, num_heads=12, threshold=8, feature_dim=4):
    """Factory function for the HF Spaces compatible ASI attention."""
    return HFCompatibleASIAttention(
        hidden_size=dim,
        num_heads=num_heads,
        threshold=threshold,
        feature_dim=feature_dim,
    )
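

# Illustrative usage example (an addition, not part of the original module): it
# exercises both attention paths with a padding mask. Sequence lengths at or
# below `threshold` take the exact O(L²) path; longer ones take the linear O(L)
# path. The function name and sizes chosen here are arbitrary.
def demo_masked_attention():
    """Run the adaptive attention on a short and a long masked batch."""
    attn = create_hf_asi_attention(dim=768, num_heads=12, threshold=8, feature_dim=4)
    attn.eval()

    for seq_len in (8, 256):  # 8 -> exact path, 256 -> linear path
        hidden_states = torch.randn(2, seq_len, 768)
        attention_mask = torch.ones(2, seq_len)
        attention_mask[1, seq_len // 2:] = 0  # pad out half of the second sample

        with torch.no_grad():
            output, _, _ = attn(hidden_states, attention_mask=attention_mask)

        assert output.shape == (2, seq_len, 768)
        print(f"seq_len={seq_len}: output shape {tuple(output.shape)}")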


# Test function
def test_hf_asi():
    """Test the HF compatible ASI implementation."""
    batch_size, seq_len, hidden_size = 1, 512, 768
    device = "cpu"  # free-tier HF Spaces hardware is CPU-only

    # Create test data
    hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device)

    # Create ASI attention
    asi_attention = create_hf_asi_attention(dim=hidden_size, threshold=8, feature_dim=4)
    asi_attention.to(device)

    # Test forward pass
    with torch.no_grad():
        output, _, _ = asi_attention(hidden_states)

    print(f"✅ Input shape: {hidden_states.shape}")
    print(f"✅ Output shape: {output.shape}")
    print("✅ ASI test passed!")

    return True


if __name__ == "__main__":
    test_hf_asi()
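    # Also run the illustrative masked-attention demo defined above
    # (an addition, not part of the original test).
    demo_masked_attention()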