Update app.py
app.py
CHANGED

@@ -80,6 +80,10 @@ class MultiScaleRetention(nn.Module):
         self.kv_head_dim = self.head_dim # Same as Q head_dim
         self.kv_dim = self.num_key_value_heads * self.kv_head_dim

+        # ✅ Internal state storage for KV cache simulation
+        self.register_buffer('_internal_state', None, persistent=False)
+        self.register_buffer('_state_initialized', torch.tensor(False), persistent=False)
+
         print(f" 📐 Layer {layer_idx} Retention (GQA) initialized:")
         print(f" - hidden_size: {self.hidden_size}")
         print(f" - num_heads (Q): {self.num_heads}")

@@ -117,6 +121,11 @@ class MultiScaleRetention(nn.Module):
             batch, num_key_value_heads, n_rep, slen, head_dim
         )
         return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+    def reset_state(self):
+        """Reset internal state (call at start of new sequence)"""
+        self._internal_state = None
+        self._state_initialized = torch.tensor(False)

     def forward(
         self,

@@ -163,11 +172,17 @@ class MultiScaleRetention(nn.Module):

         # Now all have shape [B, num_heads, L, head_dim]

-        # Retention computation
+        # Retention computation with internal state
+        past_state = self._internal_state if (use_cache and self._state_initialized) else None
         retention_states, new_state = self._compute_retention(
-            query_states, key_states, value_states,
+            query_states, key_states, value_states, past_state
         )

+        # ✅ Store state internally for next iteration
+        if use_cache:
+            self._internal_state = new_state.detach()
+            self._state_initialized = torch.tensor(True)
+
         # Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden_size]
         retention_states = retention_states.transpose(1, 2).contiguous()
         retention_states = retention_states.reshape(
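
To make the role of past_state / new_state concrete, the sketch below walks one decoding step of the standard retention recurrence, S_t = decay * S_{t-1} + k_t^T v_t with output q_t S_t. It is not taken from app.py: the function name, the fixed decay, and the shapes are illustrative assumptions, and _compute_retention may differ in detail (e.g. per-head decays).

import torch

def retention_step(q_t, k_t, v_t, past_state, decay: float = 0.9):
    """One-token retention step: q_t, k_t, v_t are [B, H, 1, D]; past_state is [B, H, D, D] or None."""
    B, H, _, D = q_t.shape
    if past_state is None:
        past_state = q_t.new_zeros(B, H, D, D)
    new_state = decay * past_state + k_t.transpose(-1, -2) @ v_t   # S_t = decay * S_{t-1} + k^T v
    out = q_t @ new_state                                          # o_t = q_t S_t -> [B, H, 1, D]
    return out, new_state

# The state is threaded across steps the same way _internal_state is reused above
state = None
for _ in range(3):
    q, k, v = (torch.randn(1, 2, 1, 8) for _ in range(3))
    out, state = retention_step(q, k, v, state)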

@@ -187,11 +202,11 @@ class MultiScaleRetention(nn.Module):
         # Output projection
         attn_output = self.o_proj(retention_states)

-        # ✅ Return
-
-
-
-
+        # ✅ Return format for compatibility
+        # Granite expects: (hidden_states, attn_weights)
+        # We return: (output, None) - no past_key_values in return signature
+        # State is stored internally but not returned
+        return (attn_output, None)

     def _compute_retention(
         self,
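
Caller-side view of the 2-tuple convention (hypothetical code, not Granite's actual source): because the recurrent state now lives on the module rather than in the return value, the wrapping decoder layer only ever unpacks two items.

# Illustrative only: how a layer expecting (hidden_states, attn_weights) would consume
# the module; the call signature here is an assumption.
attn_output, attn_weights = layer.self_attn(hidden_states, use_cache=True)
# attn_weights is None; the retention state stays in layer.self_attn._internal_state
# until reset_state() clears it - it is never returned as a past_key_value.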

@@ -354,11 +369,9 @@ class HierarchicalRetention(nn.Module):
         output = torch.stack(hierarchical_outputs, dim=1)
         output = self.norm(output)

-        # ✅ Return
-
-
-        else:
-            return (output, None, None)
+        # ✅ Return format for compatibility with Granite
+        # Granite expects: (hidden_states, attn_weights)
+        return (output, None)


 # =====================================================

@@ -736,6 +749,15 @@ def generate_text_phoenix(

     print(f"✅ Converted {converted}/{total} layers")

+    # ✅ Reset all retention states before generation
+    print(f"🔄 Resetting retention states...")
+    for layer in model.model.layers:
+        if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, 'reset_state'):
+            layer.self_attn.reset_state()
+        elif hasattr(layer, 'self_attn') and hasattr(layer.self_attn, 'base_retention'):
+            if hasattr(layer.self_attn.base_retention, 'reset_state'):
+                layer.self_attn.base_retention.reset_state()
+
     # 3. Load tokenizer
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
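
Putting the pieces together, the intended calling pattern is roughly the sketch below. Only reset_state() and the internal-state buffers come from this commit; the model loading, the placeholder model_url value, and the generate() arguments are illustrative assumptions about how app.py drives the model.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_url = "..."  # placeholder; app.py supplies its own model_url
model = AutoModelForCausalLM.from_pretrained(model_url, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)

# (attention layers are converted to MultiScaleRetention modules here, as in app.py)

# Start every new prompt from a clean state, mirroring the reset loop added above
for layer in model.model.layers:
    attn = getattr(layer, 'self_attn', None)
    if attn is not None and hasattr(attn, 'reset_state'):
        attn.reset_state()

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    # use_cache=True lets each retention layer keep updating its _internal_state
    output_ids = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))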
|