seawolf2357 committed
Commit 41f8d59 · verified · Parent(s): 57a0735

Update app.py

Files changed (1)
  app.py  +16 -4
app.py CHANGED
@@ -171,7 +171,10 @@ class MultiScaleRetention(nn.Module):
             batch_size, seq_len, self.hidden_size
         )
 
-        # Group norm
+        # Group norm - ensure it's on the correct device
+        if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
+            self.group_norm = self.group_norm.to(retention_states.device)
+
         retention_states = self.group_norm(
             retention_states.transpose(1, 2)
         ).transpose(1, 2)
@@ -274,6 +277,15 @@ class HierarchicalRetention(nn.Module):
         if past_key_values is not None:
             past_key_value = past_key_values
 
+        # ✅ Ensure all submodules are on correct device
+        target_device = hidden_states.device
+        if not next(self.short_proj.parameters()).is_cuda and hidden_states.is_cuda:
+            self.short_proj = self.short_proj.to(target_device)
+            self.medium_proj = self.medium_proj.to(target_device)
+            self.long_proj = self.long_proj.to(target_device)
+            self.fusion = self.fusion.to(target_device)
+            self.norm = self.norm.to(target_device)
+
         # Base Retention
         retention_output, attn_weights, past_kv = self.base_retention(
             hidden_states, attention_mask, position_ids,
@@ -281,9 +293,9 @@
         )
 
         # Hierarchical states
-        short_state = torch.zeros(batch_size, self.d_state).to(hidden_states.device)
-        medium_state = torch.zeros(batch_size, self.d_state).to(hidden_states.device)
-        long_state = torch.zeros(batch_size, self.d_state * 2).to(hidden_states.device)
+        short_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
+        medium_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
+        long_state = torch.zeros(batch_size, self.d_state * 2, dtype=hidden_states.dtype, device=target_device)
 
         hierarchical_outputs = []
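
The whole commit applies one pattern: inside forward, check whether a submodule's parameters are still on CPU while the incoming tensor is on GPU, move the submodule lazily with .to(...), and allocate fresh state tensors with explicit dtype=/device= arguments instead of creating them on CPU and calling .to() afterwards. The following is a minimal, self-contained sketch of that pattern; ToyRetention, its layer sizes, and the input shapes are illustrative stand-ins, not the actual classes in app.py.

import torch
import torch.nn as nn


class ToyRetention(nn.Module):
    """Illustrative module using the same device-alignment pattern as the commit."""

    def __init__(self, hidden_size: int = 64, d_state: int = 16):
        super().__init__()
        self.d_state = d_state
        self.group_norm = nn.GroupNorm(num_groups=4, num_channels=hidden_size)
        self.short_proj = nn.Linear(hidden_size, d_state)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape
        target_device = hidden_states.device

        # Lazily move submodules that were left on CPU while the input is on GPU,
        # mirroring the checks added in the commit.
        if not next(self.group_norm.parameters()).is_cuda and hidden_states.is_cuda:
            self.group_norm = self.group_norm.to(target_device)
            self.short_proj = self.short_proj.to(target_device)

        # GroupNorm normalizes over channels, so transpose to (batch, channels, seq).
        normed = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # Allocate recurrent state directly with the right dtype and device,
        # rather than creating it on CPU and calling .to() afterwards.
        short_state = torch.zeros(
            batch_size, self.d_state,
            dtype=hidden_states.dtype, device=target_device,
        )
        return self.short_proj(normed) + short_state.unsqueeze(1)


if __name__ == "__main__":
    module = ToyRetention()
    x = torch.randn(2, 8, 64)
    if torch.cuda.is_available():
        x = x.cuda()          # module stays on CPU; forward() moves it as needed
    print(module(x).shape)    # torch.Size([2, 8, 16])

In most setups it is simpler to move the whole model once with model.to(device) before inference; the in-forward checks in the commit act as a defensive fallback for submodules that end up on the wrong device after that initial move.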