seawolf2357 committed
Commit c9f844d (verified) · Parent: ad81847

Update app.py

Files changed (1):
  app.py: +34 -12
app.py CHANGED
@@ -80,6 +80,10 @@ class MultiScaleRetention(nn.Module):
         self.kv_head_dim = self.head_dim  # Same as Q head_dim
         self.kv_dim = self.num_key_value_heads * self.kv_head_dim

+        # ✅ Internal state storage for KV cache simulation
+        self.register_buffer('_internal_state', None, persistent=False)
+        self.register_buffer('_state_initialized', torch.tensor(False), persistent=False)
+
         print(f" 📐 Layer {layer_idx} Retention (GQA) initialized:")
         print(f" - hidden_size: {self.hidden_size}")
         print(f" - num_heads (Q): {self.num_heads}")
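Note (background, not part of the commit): PyTorch accepts register_buffer(name, None) to reserve a slot that is filled in later, and persistent=False keeps that cached state out of checkpoints, which is the trick this hunk relies on. A minimal stand-alone illustration, with nn.Linear as a throwaway placeholder module:

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
m.register_buffer('cache', None, persistent=False)  # reserve the name; nothing stored yet
m.cache = torch.zeros(4)                            # later assignment still lands in the buffer dict
assert 'cache' in dict(m.named_buffers())           # tracked like any other buffer
assert 'cache' not in m.state_dict()                # excluded from checkpoints (persistent=False)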
@@ -117,6 +121,11 @@ class MultiScaleRetention(nn.Module):
             batch, num_key_value_heads, n_rep, slen, head_dim
         )
         return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+    def reset_state(self):
+        """Reset internal state (call at start of new sequence)"""
+        self._internal_state = None
+        self._state_initialized = torch.tensor(False)

     def forward(
         self,
@@ -163,11 +172,17 @@ class MultiScaleRetention(nn.Module):

         # Now all have shape [B, num_heads, L, head_dim]

-        # Retention computation
+        # Retention computation with internal state
+        past_state = self._internal_state if (use_cache and self._state_initialized) else None
         retention_states, new_state = self._compute_retention(
-            query_states, key_states, value_states, past_key_value
+            query_states, key_states, value_states, past_state
         )

+        # ✅ Store state internally for next iteration
+        if use_cache:
+            self._internal_state = new_state.detach()
+            self._state_initialized = torch.tensor(True)
+
         # Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden_size]
         retention_states = retention_states.transpose(1, 2).contiguous()
         retention_states = retention_states.reshape(
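Note (illustration, not part of the commit): taken together, the hunks above replace the past_key_value plumbing with per-layer buffers: reuse the state when use_cache is set, detach it so the cache does not keep the autograd graph alive, and clear it with reset_state(). A minimal sketch of that pattern follows; ToyRecurrentState and its running-sum update are illustrative stand-ins, not the real retention recurrence in app.py.

import torch
import torch.nn as nn

class ToyRecurrentState(nn.Module):
    """Illustrative stand-in for the pattern above: a None buffer that is reused
    when use_cache=True, detached before caching, and cleared by reset_state().
    The running-sum update is a toy; the real module computes retention."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.register_buffer('_internal_state', None, persistent=False)
        self.register_buffer('_state_initialized', torch.tensor(False), persistent=False)

    def reset_state(self):
        self._internal_state = None
        self._state_initialized = torch.tensor(False)

    def forward(self, x: torch.Tensor, use_cache: bool = True) -> torch.Tensor:
        # x: [batch, seq_len, dim]
        past = self._internal_state if (use_cache and bool(self._state_initialized)) else None
        new_state = x.sum(dim=1)                       # toy per-chunk summary, shape [batch, dim]
        if past is not None:
            new_state = new_state + past               # fold in state carried over from earlier calls
        if use_cache:
            self._internal_state = new_state.detach()  # detach: cached state must not hold the graph
            self._state_initialized = torch.tensor(True)
        return self.proj(x + new_state.unsqueeze(1))

m = ToyRecurrentState(dim=8)
_ = m(torch.randn(1, 4, 8))   # "prompt" pass initialises the buffer
_ = m(torch.randn(1, 1, 8))   # "decode" step reuses the cached state
m.reset_state()               # call before an unrelated sequence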
@@ -187,11 +202,11 @@ class MultiScaleRetention(nn.Module):
         # Output projection
         attn_output = self.o_proj(retention_states)

-        # ✅ Return output and state for KV cache - Always return 3 values
-        if use_cache:
-            return (attn_output, None, new_state)
-        else:
-            return (attn_output, None, None)
+        # ✅ Return format for compatibility
+        # Granite expects: (hidden_states, attn_weights)
+        # We return: (output, None) - no past_key_values in return signature
+        # State is stored internally but not returned
+        return (attn_output, None)

     def _compute_retention(
         self,
@@ -354,11 +369,9 @@ class HierarchicalRetention(nn.Module):
         output = torch.stack(hierarchical_outputs, dim=1)
         output = self.norm(output)

-        # ✅ Return with state for KV cache - Always return 3 values
-        if use_cache and new_state is not None:
-            return (output, None, new_state)
-        else:
-            return (output, None, None)
+        # ✅ Return format for compatibility with Granite
+        # Granite expects: (hidden_states, attn_weights)
+        return (output, None)


 # =====================================================
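Note (illustration, not part of the commit): both return-path hunks collapse the old 3-tuple into the (hidden_states, attn_weights) pair that, per the committed comments, the surrounding Granite decoder layer unpacks. A small smoke test of that contract is sketched below; the call signature (hidden states plus a use_cache flag) and the hidden_size argument are assumptions about the converted module, not something the diff guarantees.

import torch

def check_attn_contract(attn_module, hidden_size, device="cpu"):
    """Hypothetical check: the replacement attention should hand back a
    2-tuple whose first element keeps the [B, L, hidden_size] shape."""
    x = torch.randn(1, 4, hidden_size, device=device)
    out = attn_module(x, use_cache=True)  # assumed signature; adjust to the real forward()
    assert isinstance(out, tuple) and len(out) == 2, f"expected a 2-tuple, got {out!r}"
    hidden_states, attn_weights = out
    assert hidden_states.shape == x.shape, "output must keep the input shape"
    return hidden_states, attn_weights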
@@ -736,6 +749,15 @@ def generate_text_phoenix(

     print(f"✅ Converted {converted}/{total} layers")

+    # ✅ Reset all retention states before generation
+    print(f"🔄 Resetting retention states...")
+    for layer in model.model.layers:
+        if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, 'reset_state'):
+            layer.self_attn.reset_state()
+        elif hasattr(layer, 'self_attn') and hasattr(layer.self_attn, 'base_retention'):
+            if hasattr(layer.self_attn.base_retention, 'reset_state'):
+                layer.self_attn.base_retention.reset_state()
+
     # 3. Load tokenizer
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
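Note (illustration, not part of the commit): because the recurrent state now lives in module buffers rather than in the returned past_key_values, it silently survives across calls, which is why the loop above has to run before every independent prompt. A hedged helper that mirrors that loop (model.model.layers, self_attn, base_retention and reset_state are taken from the diff; everything else is assumed):

def reset_retention_states(model):
    """Reset retention buffers on every converted layer; returns how many were reset."""
    reset = 0
    for layer in model.model.layers:
        attn = getattr(layer, 'self_attn', None)
        if attn is None:
            continue
        if hasattr(attn, 'reset_state'):
            attn.reset_state()
            reset += 1
        elif hasattr(getattr(attn, 'base_retention', None), 'reset_state'):
            attn.base_retention.reset_state()
            reset += 1
    return reset

# Typical call order before each new prompt (hypothetical names):
# reset_retention_states(model)
# outputs = model.generate(**inputs, use_cache=True)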