seawolf2357 committed (verified)
Commit 8c55f6d · Parent(s): 41f8d59

Update app.py

Files changed (1): app.py (+18 −8)
app.py CHANGED

@@ -171,9 +171,11 @@ class MultiScaleRetention(nn.Module):
             batch_size, seq_len, self.hidden_size
         )
 
-        # ✅ Group norm - ensure it's on the correct device
+        # ✅ Group norm - ensure it's on the correct device AND dtype
         if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
-            self.group_norm = self.group_norm.to(retention_states.device)
+            self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype)
+        elif next(self.group_norm.parameters()).dtype != retention_states.dtype:
+            self.group_norm = self.group_norm.to(dtype=retention_states.dtype)
 
         retention_states = self.group_norm(
             retention_states.transpose(1, 2)
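Note: before this patch, the guard only handled the device move, so fp32 GroupNorm weights meeting an fp16 input already on the same device would still fail with a dtype mismatch; the new elif branch casts the weights in that case. Below is a minimal sketch of the same lazy-sync pattern factored into a helper. The helper name and tensor shapes are illustrative only, not part of app.py, and it compares devices directly rather than using the patch's is_cuda check:

import torch
import torch.nn as nn

def sync_module(module: nn.Module, ref: torch.Tensor) -> nn.Module:
    # Move/cast `module` to match `ref` only when they actually differ,
    # so the common case stays a cheap no-op.
    param = next(module.parameters())
    if param.device != ref.device or param.dtype != ref.dtype:
        module = module.to(ref.device, dtype=ref.dtype)
    return module

# Illustrative shapes: (batch, channels, seq), as fed to GroupNorm after transpose.
x = torch.randn(2, 16, 8, dtype=torch.float16)
group_norm = nn.GroupNorm(num_groups=4, num_channels=16)  # fp32 by default
group_norm = sync_module(group_norm, x)
assert next(group_norm.parameters()).dtype == torch.float16

The same effect can usually be achieved once at load time with model.to(device, dtype=...); the lazy check here just guards against submodules that were left behind.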
 
@@ -277,14 +279,22 @@ class HierarchicalRetention(nn.Module):
         if past_key_values is not None:
             past_key_value = past_key_values
 
-        # ✅ Ensure all submodules are on correct device
+        # ✅ Ensure all submodules are on correct device AND dtype
         target_device = hidden_states.device
+        target_dtype = hidden_states.dtype
+
         if not next(self.short_proj.parameters()).is_cuda and hidden_states.is_cuda:
-            self.short_proj = self.short_proj.to(target_device)
-            self.medium_proj = self.medium_proj.to(target_device)
-            self.long_proj = self.long_proj.to(target_device)
-            self.fusion = self.fusion.to(target_device)
-            self.norm = self.norm.to(target_device)
+            self.short_proj = self.short_proj.to(target_device, dtype=target_dtype)
+            self.medium_proj = self.medium_proj.to(target_device, dtype=target_dtype)
+            self.long_proj = self.long_proj.to(target_device, dtype=target_dtype)
+            self.fusion = self.fusion.to(target_device, dtype=target_dtype)
+            self.norm = self.norm.to(target_device, dtype=target_dtype)
+        elif next(self.short_proj.parameters()).dtype != target_dtype:
+            self.short_proj = self.short_proj.to(dtype=target_dtype)
+            self.medium_proj = self.medium_proj.to(dtype=target_dtype)
+            self.long_proj = self.long_proj.to(dtype=target_dtype)
+            self.fusion = self.fusion.to(dtype=target_dtype)
+            self.norm = self.norm.to(dtype=target_dtype)
 
         # Base Retention
         retention_output, attn_weights, past_kv = self.base_retention(
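Note: this second hunk unrolls the same device-and-dtype sync across five submodules. A loop-based variant would have the same effect with less repetition; the class body below is a hypothetical stand-in (the diff does not show app.py's actual submodule types or the fusion input width), but the _sync_submodules loop mirrors the unrolled branches:

import torch
import torch.nn as nn

class HierarchicalRetentionSketch(nn.Module):
    # Minimal stand-in with the same submodule names as the diff;
    # the layer types and sizes are assumptions for illustration.
    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.short_proj = nn.Linear(hidden_size, hidden_size)
        self.medium_proj = nn.Linear(hidden_size, hidden_size)
        self.long_proj = nn.Linear(hidden_size, hidden_size)
        self.fusion = nn.Linear(3 * hidden_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

    def _sync_submodules(self, hidden_states: torch.Tensor) -> None:
        # Same effect as the unrolled branches in the diff, written as a
        # loop so a newly added submodule cannot be silently skipped.
        target_device = hidden_states.device
        target_dtype = hidden_states.dtype
        for name in ("short_proj", "medium_proj", "long_proj", "fusion", "norm"):
            mod = getattr(self, name)
            param = next(mod.parameters())
            if param.device != target_device or param.dtype != target_dtype:
                setattr(self, name, mod.to(target_device, dtype=target_dtype))

m = HierarchicalRetentionSketch()
x = torch.randn(2, 4, 32, dtype=torch.float16)
m._sync_submodules(x)
assert next(m.short_proj.parameters()).dtype == torch.float16

Keeping the submodule names in one tuple puts the sync list in a single place, at the cost of a getattr/setattr indirection per call.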