#!/usr/bin/env python3
"""
ASI V2.5 - HuggingFace Spaces Compatible Version
Optimized for the CPU-only environment and its 16 GB RAM limit

Fixed all dimension errors and optimized for Spaces hardware
"""

import torch
import torch.nn as nn

class HFCompatibleASIAttention(nn.Module):
    """
    ASI V2.5 Compatible with HuggingFace Spaces
    
    Key fixes:
    - Proper dimension handling for CPU environment
    - Memory optimized for 16GB RAM limit
    - No GPU dependencies
    - Fixed matrix multiplication errors
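
    How it adapts: sequences up to `threshold` tokens go through exact softmax
    attention; longer sequences use a linear (kernel-feature) approximation so
    that no L x L score matrix is ever materialized.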
    """
    
    def __init__(self, hidden_size=768, num_heads=12, threshold=8, feature_dim=4):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.threshold = threshold
        self.feature_dim = feature_dim
        
        # Validation
        assert hidden_size % num_heads == 0, f"hidden_size {hidden_size} not divisible by num_heads {num_heads}"
        
        # Standard attention projections
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        
        # ASI feature mapping - FIXED dimensions
        # Map from head_dim to feature_dim for each head
        self.feature_map = nn.Linear(self.head_dim, feature_dim, bias=False)
        
        self.scale = (self.head_dim ** -0.5)
        
    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """
        Fixed forward pass with proper dimension handling
        """
        batch_size, seq_len, _ = hidden_states.shape
        
        # Project to Q, K, V
        q = self.q_proj(hidden_states)  # [B, L, H]
        k = self.k_proj(hidden_states)  # [B, L, H]
        v = self.v_proj(hidden_states)  # [B, L, H]
        
        # Reshape for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L, D]
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L, D]
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L, D]
        
        # ASI adaptive attention
        if seq_len <= self.threshold:
            # Exact attention for short sequences
            attn_output = self._exact_attention(q, k, v, attention_mask)
        else:
            # Linear attention for long sequences - FIXED VERSION
            attn_output = self._linear_attention_fixed(q, k, v, attention_mask)
        
        # Reshape back and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.hidden_size
        )
        attn_output = self.o_proj(attn_output)
        
        return attn_output, None, None  # Match expected HF signature
    
    def _exact_attention(self, q, k, v, attention_mask=None):
        """Standard O(L²) attention"""
        # q, k, v: [B, H, L, D]
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # [B, H, L, L]
        
        if attention_mask is not None:
            # attention_mask is expected as [B, L] with 1 for tokens to keep, 0 for padding
            mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, L]
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn_weights = torch.softmax(scores, dim=-1)  # [B, H, L, L]
        attn_output = torch.matmul(attn_weights, v)  # [B, H, L, D]
        
        return attn_output
    
    def _linear_attention_fixed(self, q, k, v, attention_mask=None):
        """
        FIXED Linear attention for O(L) complexity
        Properly handles dimensions for HuggingFace Spaces
        """
        # q, k, v: [B, H, L, D] where D = head_dim
        batch_size, num_heads, seq_len, head_dim = q.shape
        
        # Apply feature mapping to reduce dimension
        # Reshape for feature mapping: [B*H*L, D] -> [B*H*L, F]
        q_reshaped = q.reshape(-1, head_dim)  # [B*H*L, D]
        k_reshaped = k.reshape(-1, head_dim)  # [B*H*L, D]
        
        q_feat = self.feature_map(q_reshaped)  # [B*H*L, F]
        k_feat = self.feature_map(k_reshaped)  # [B*H*L, F]
        
        # Reshape back: [B*H*L, F] -> [B, H, L, F]
        q_feat = q_feat.view(batch_size, num_heads, seq_len, self.feature_dim)
        k_feat = k_feat.view(batch_size, num_heads, seq_len, self.feature_dim)
        
        # Apply attention mask to keys if provided
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(1).unsqueeze(-1)  # [B, 1, L, 1]
            k_feat = k_feat * mask.float()  # masked keys drop out of the K^T V sum below
        
        # Linear attention computation - FIXED DIMENSIONS
        # Step 1: K^T @ V  
        # k_feat: [B, H, L, F], v: [B, H, L, D] -> kv: [B, H, F, D]
        kv = torch.matmul(k_feat.transpose(-2, -1), v)  # [B, H, F, D]
        
        # Step 2: Q @ (K^T @ V)
        # q_feat: [B, H, L, F], kv: [B, H, F, D] -> attn_output: [B, H, L, D]
        attn_output = torch.matmul(q_feat, kv)  # [B, H, L, D]
        
        # Step 3: Normalization - FIXED
        # k_feat: [B, H, L, F] -> k_sum: [B, H, 1, F]
        k_sum = k_feat.sum(dim=-2, keepdim=True)  # [B, H, 1, F]
        
        # q_feat: [B, H, L, F], k_sum: [B, H, 1, F] -> normalization: [B, H, L, 1]
        # Use einsum for clearer dimension handling
        normalization = torch.einsum('bhlf,bhf->bhl', q_feat, k_sum.squeeze(-2))  # [B, H, L]
        normalization = normalization.unsqueeze(-1)  # [B, H, L, 1]
        
        # Prevent division by zero and normalize
        attn_output = attn_output / (normalization + 1e-8)
        
        return attn_output

def create_hf_asi_attention(dim=768, num_heads=12, threshold=8, feature_dim=4):
    """Factory function for HF Spaces compatible ASI"""
    return HFCompatibleASIAttention(
        hidden_size=dim,
        num_heads=num_heads,
        threshold=threshold,
        feature_dim=feature_dim
    )

# Test function
def test_hf_asi():
    """Test the HF compatible ASI implementation"""
    batch_size, seq_len, hidden_size = 1, 512, 768
    device = "cpu"  # HF Spaces is CPU-only
    
    # Create test data
    hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device)
    
    # Create ASI attention
    asi_attention = create_hf_asi_attention(dim=hidden_size, threshold=8, feature_dim=4)
    asi_attention.to(device)
    
    # Test forward pass
    with torch.no_grad():
        output, _, _ = asi_attention(hidden_states)
    
    print(f"✅ Input shape: {hidden_states.shape}")
    print(f"✅ Output shape: {output.shape}")
    print(f"✅ ASI test passed!")
    
    return True

if __name__ == "__main__":
    test_hf_asi()
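
    # Minimal extra check: a sequence shorter than the threshold should take
    # the exact-attention path and preserve the input shape.
    short_attn = create_hf_asi_attention(dim=768, num_heads=12, threshold=8, feature_dim=4)
    short_x = torch.randn(1, 4, 768)
    with torch.no_grad():
        short_out, _, _ = short_attn(short_x)
    assert short_out.shape == short_x.shape
    print(f"✅ Exact-attention path output shape: {short_out.shape}")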