#!/usr/bin/env python3
"""
ASI V2.5 - HuggingFace Spaces compatible version

Runs on CPU-only Spaces hardware within the 16 GB RAM limit, with the
earlier dimension errors in the attention paths fixed.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class HFCompatibleASIAttention(nn.Module):
"""
    ASI V2.5 attention compatible with HuggingFace Spaces.

    Key fixes:
    - Proper dimension handling for the CPU environment
    - Memory use kept within the 16 GB RAM limit
    - No GPU dependencies
    - Fixed matrix-multiplication dimension errors
"""
def __init__(self, hidden_size=768, num_heads=12, threshold=8, feature_dim=4):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.threshold = threshold
self.feature_dim = feature_dim
# Validation
assert hidden_size % num_heads == 0, f"hidden_size {hidden_size} not divisible by num_heads {num_heads}"
# Standard attention projections
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
# ASI feature mapping - FIXED dimensions
# Map from head_dim to feature_dim for each head
self.feature_map = nn.Linear(self.head_dim, feature_dim, bias=False)
self.scale = (self.head_dim ** -0.5)
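        # Memory note (added comment): the exact path materializes an [L, L]
        # score matrix per head, whereas the linear path reduces this to
        # [L, F] features plus an [F, D] summary (4 x 64 = 256 floats per head
        # with the defaults), which is what keeps long sequences inside the
        # 16 GB Spaces RAM budget.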
def forward(self, hidden_states, attention_mask=None, **kwargs):
"""
Fixed forward pass with proper dimension handling
"""
batch_size, seq_len, _ = hidden_states.shape
# Project to Q, K, V
q = self.q_proj(hidden_states) # [B, L, H]
k = self.k_proj(hidden_states) # [B, L, H]
v = self.v_proj(hidden_states) # [B, L, H]
# Reshape for multi-head attention
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # [B, H, L, D]
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # [B, H, L, D]
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # [B, H, L, D]
# ASI adaptive attention
if seq_len <= self.threshold:
# Exact attention for short sequences
attn_output = self._exact_attention(q, k, v, attention_mask)
else:
# Linear attention for long sequences - FIXED VERSION
attn_output = self._linear_attention_fixed(q, k, v, attention_mask)
# Reshape back and project
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, seq_len, self.hidden_size
)
attn_output = self.o_proj(attn_output)
return attn_output, None, None # Match expected HF signature
def _exact_attention(self, q, k, v, attention_mask=None):
"""Standard O(L²) attention"""
# q, k, v: [B, H, L, D]
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale # [B, H, L, L]
if attention_mask is not None:
# Apply mask
mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, L]
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = torch.softmax(scores, dim=-1) # [B, H, L, L]
attn_output = torch.matmul(attn_weights, v) # [B, H, L, D]
return attn_output
def _linear_attention_fixed(self, q, k, v, attention_mask=None):
"""
FIXED Linear attention for O(L) complexity
Properly handles dimensions for HuggingFace Spaces
"""
# q, k, v: [B, H, L, D] where D = head_dim
batch_size, num_heads, seq_len, head_dim = q.shape
# Apply feature mapping to reduce dimension
# Reshape for feature mapping: [B*H*L, D] -> [B*H*L, F]
q_reshaped = q.reshape(-1, head_dim) # [B*H*L, D]
k_reshaped = k.reshape(-1, head_dim) # [B*H*L, D]
        q_feat = self.feature_map(q_reshaped)  # [B*H*L, F]
        k_feat = self.feature_map(k_reshaped)  # [B*H*L, F]
        # Keep features non-negative (ELU + 1) so the normalization term below
        # stays strictly positive; a raw linear map can make the denominator
        # negative or near zero, which the small epsilon alone cannot fix.
        q_feat = F.elu(q_feat) + 1.0
        k_feat = F.elu(k_feat) + 1.0
# Reshape back: [B*H*L, F] -> [B, H, L, F]
q_feat = q_feat.view(batch_size, num_heads, seq_len, self.feature_dim)
k_feat = k_feat.view(batch_size, num_heads, seq_len, self.feature_dim)
# Apply attention mask to keys if provided
if attention_mask is not None:
mask = attention_mask.unsqueeze(1).unsqueeze(-1) # [B, 1, L, 1]
k_feat = k_feat * mask.float()
# Linear attention computation - FIXED DIMENSIONS
# Step 1: K^T @ V
# k_feat: [B, H, L, F], v: [B, H, L, D] -> kv: [B, H, F, D]
kv = torch.matmul(k_feat.transpose(-2, -1), v) # [B, H, F, D]
# Step 2: Q @ (K^T @ V)
# q_feat: [B, H, L, F], kv: [B, H, F, D] -> attn_output: [B, H, L, D]
attn_output = torch.matmul(q_feat, kv) # [B, H, L, D]
# Step 3: Normalization - FIXED
# k_feat: [B, H, L, F] -> k_sum: [B, H, 1, F]
k_sum = k_feat.sum(dim=-2, keepdim=True) # [B, H, 1, F]
# q_feat: [B, H, L, F], k_sum: [B, H, 1, F] -> normalization: [B, H, L, 1]
# Use einsum for clearer dimension handling
normalization = torch.einsum('bhlf,bhf->bhl', q_feat, k_sum.squeeze(-2)) # [B, H, L]
normalization = normalization.unsqueeze(-1) # [B, H, L, 1]
# Prevent division by zero and normalize
attn_output = attn_output / (normalization + 1e-8)
return attn_output
def create_hf_asi_attention(dim=768, num_heads=12, threshold=8, feature_dim=4):
"""Factory function for HF Spaces compatible ASI"""
return HFCompatibleASIAttention(
hidden_size=dim,
num_heads=num_heads,
threshold=threshold,
feature_dim=feature_dim
)
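# Hedged sketch (not part of the original API): forces each path by adjusting
# the threshold, to show how `threshold` selects exact vs. linear attention.
# The sequence length below is an arbitrary illustrative choice.
def _compare_attention_paths(seq_len=256, hidden_size=768):
    import time
    x = torch.randn(1, seq_len, hidden_size)
    paths = {
        "exact": create_hf_asi_attention(dim=hidden_size, threshold=seq_len),       # seq_len <= threshold
        "linear": create_hf_asi_attention(dim=hidden_size, threshold=seq_len - 1),  # seq_len > threshold
    }
    with torch.no_grad():
        for name, module in paths.items():
            start = time.perf_counter()
            out, _, _ = module(x)
            print(f"{name}: output {tuple(out.shape)} in {time.perf_counter() - start:.4f}s")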
# Test function
def test_hf_asi():
"""Test the HF compatible ASI implementation"""
batch_size, seq_len, hidden_size = 1, 512, 768
device = "cpu" # HF Spaces is CPU-only
# Create test data
hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device)
# Create ASI attention
asi_attention = create_hf_asi_attention(dim=hidden_size, threshold=8, feature_dim=4)
asi_attention.to(device)
# Test forward pass
with torch.no_grad():
output, _, _ = asi_attention(hidden_states)
print(f"✅ Input shape: {hidden_states.shape}")
print(f"✅ Output shape: {output.shape}")
print(f"✅ ASI test passed!")
return True
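# Hedged sketch: exercising the optional attention_mask. The 1 = keep / 0 = pad
# convention is an assumption inferred from the masked_fill(mask == 0, ...)
# logic above; the padding length here is illustrative only.
def _test_hf_asi_with_mask():
    x = torch.randn(2, 32, 768)
    mask = torch.ones(2, 32)
    mask[:, 24:] = 0  # treat the last 8 positions as padding
    asi_attention = create_hf_asi_attention()
    with torch.no_grad():
        output, _, _ = asi_attention(x, attention_mask=mask)
    print(f"✅ Masked output shape: {output.shape}")  # [2, 32, 768]
    return True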
if __name__ == "__main__":
test_hf_asi()