Update app.py
app.py CHANGED

@@ -171,9 +171,11 @@ class MultiScaleRetention(nn.Module):
             batch_size, seq_len, self.hidden_size
         )
 
-        # ✅ Group norm - ensure it's on the correct device
+        # ✅ Group norm - ensure it's on the correct device AND dtype
         if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
-            self.group_norm = self.group_norm.to(retention_states.device)
+            self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype)
+        elif next(self.group_norm.parameters()).dtype != retention_states.dtype:
+            self.group_norm = self.group_norm.to(dtype=retention_states.dtype)
 
         retention_states = self.group_norm(
             retention_states.transpose(1, 2)
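The hunk above applies a lazy-alignment pattern: at forward time, move or cast `group_norm` so its parameters match the incoming activations, since a device or dtype mismatch (common with mixed-precision or device-mapped loading) makes `nn.GroupNorm` raise. Below is a minimal self-contained sketch of that pattern under stated assumptions: the module is hypothetical (not the class in app.py), and it compares devices directly where the diff checks `is_cuda`.

import torch
import torch.nn as nn

class LazyAlignedNorm(nn.Module):
    """Hypothetical stand-in for the group-norm path in MultiScaleRetention."""

    def __init__(self, hidden_size: int, num_groups: int = 8):
        super().__init__()
        self.group_norm = nn.GroupNorm(num_groups, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Align the norm's parameters with the input before using them:
        # fix a device mismatch (setting dtype in the same call), else
        # fix the dtype alone, mirroring the if/elif in the diff.
        param = next(self.group_norm.parameters())
        if param.device != x.device:
            self.group_norm = self.group_norm.to(x.device, dtype=x.dtype)
        elif param.dtype != x.dtype:
            self.group_norm = self.group_norm.to(dtype=x.dtype)
        # GroupNorm normalizes over dim 1, so swap (batch, seq, hidden)
        # to (batch, hidden, seq) and back, as the diff's code does.
        return self.group_norm(x.transpose(1, 2)).transpose(1, 2)

# Usage: a half-precision input no longer trips a dtype mismatch.
# x = torch.randn(2, 16, 64, dtype=torch.float16)
# y = LazyAlignedNorm(64)(x)

Note that `Module.to` modifies the module in place and returns it, so the reassignment mirrors the diff's style but is not strictly required.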
@@ -277,14 +279,22 @@ class HierarchicalRetention(nn.Module):
         if past_key_values is not None:
             past_key_value = past_key_values
 
-        # ✅ Ensure all submodules are on correct device
+        # ✅ Ensure all submodules are on correct device AND dtype
         target_device = hidden_states.device
+        target_dtype = hidden_states.dtype
+
         if not next(self.short_proj.parameters()).is_cuda and hidden_states.is_cuda:
-            self.short_proj = self.short_proj.to(target_device)
-            self.medium_proj = self.medium_proj.to(target_device)
-            self.long_proj = self.long_proj.to(target_device)
-            self.fusion = self.fusion.to(target_device)
-            self.norm = self.norm.to(target_device)
+            self.short_proj = self.short_proj.to(target_device, dtype=target_dtype)
+            self.medium_proj = self.medium_proj.to(target_device, dtype=target_dtype)
+            self.long_proj = self.long_proj.to(target_device, dtype=target_dtype)
+            self.fusion = self.fusion.to(target_device, dtype=target_dtype)
+            self.norm = self.norm.to(target_device, dtype=target_dtype)
+        elif next(self.short_proj.parameters()).dtype != target_dtype:
+            self.short_proj = self.short_proj.to(dtype=target_dtype)
+            self.medium_proj = self.medium_proj.to(dtype=target_dtype)
+            self.long_proj = self.long_proj.to(dtype=target_dtype)
+            self.fusion = self.fusion.to(dtype=target_dtype)
+            self.norm = self.norm.to(dtype=target_dtype)
 
         # Base Retention
         retention_output, attn_weights, past_kv = self.base_retention(