Update app.py
app.py CHANGED
@@ -171,7 +171,10 @@ class MultiScaleRetention(nn.Module):
             batch_size, seq_len, self.hidden_size
         )
 
-        # Group norm
+        # ✅ Group norm - ensure it's on the correct device
+        if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
+            self.group_norm = self.group_norm.to(retention_states.device)
+
         retention_states = self.group_norm(
             retention_states.transpose(1, 2)
         ).transpose(1, 2)
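The guard added above moves self.group_norm to the input's device only when the activations are on CUDA and the layer's parameters are not, before the usual transpose round-trip (GroupNorm normalizes over the channel dimension, which sits at dim 1). A minimal, self-contained sketch of that pattern; the sizes and variable names here are illustrative and not taken from app.py:

import torch
import torch.nn as nn

# Illustrative sizes; the real hidden size and head count in app.py may differ.
hidden_size, num_heads = 64, 4
group_norm = nn.GroupNorm(num_groups=num_heads, num_channels=hidden_size)

retention_states = torch.randn(2, 16, hidden_size)  # (batch, seq_len, hidden)
if torch.cuda.is_available():
    retention_states = retention_states.cuda()

# Same lazy device fix as the hunk above: move the norm only if the input
# is on CUDA but the layer's parameters are not.
if not next(group_norm.parameters()).is_cuda and retention_states.is_cuda:
    group_norm = group_norm.to(retention_states.device)

# GroupNorm expects channels at dim 1, hence the transpose round-trip.
retention_states = group_norm(retention_states.transpose(1, 2)).transpose(1, 2)
print(retention_states.shape)  # torch.Size([2, 16, 64])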
@@ -274,6 +277,15 @@ class HierarchicalRetention(nn.Module):
         if past_key_values is not None:
             past_key_value = past_key_values
 
+        # ✅ Ensure all submodules are on correct device
+        target_device = hidden_states.device
+        if not next(self.short_proj.parameters()).is_cuda and hidden_states.is_cuda:
+            self.short_proj = self.short_proj.to(target_device)
+            self.medium_proj = self.medium_proj.to(target_device)
+            self.long_proj = self.long_proj.to(target_device)
+            self.fusion = self.fusion.to(target_device)
+            self.norm = self.norm.to(target_device)
+
         # Base Retention
         retention_output, attn_weights, past_kv = self.base_retention(
             hidden_states, attention_mask, position_ids,
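Note that the hunk above inspects only short_proj and then moves all five submodules together, which assumes they always share a device; in most setups, moving the whole parent module once with .to(device) (or constructing it on the target device) makes a per-forward check unnecessary. A small sketch of the same idea as a reusable helper; ensure_on_device and the Linear/LayerNorm stand-ins below are illustrative placeholders, not part of app.py:

import torch
import torch.nn as nn

def ensure_on_device(module: nn.Module, target_device: torch.device) -> nn.Module:
    # Hypothetical helper (not in app.py): move a module only when its
    # parameters are not already on target_device.
    param = next(module.parameters(), None)
    if param is not None and param.device != target_device:
        module = module.to(target_device)
    return module

# Stand-ins for short_proj / medium_proj / long_proj / fusion / norm.
proj = nn.Linear(64, 32)
norm = nn.LayerNorm(32)

hidden_states = torch.randn(2, 16, 64)
target_device = hidden_states.device

proj = ensure_on_device(proj, target_device)
norm = ensure_on_device(norm, target_device)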
@@ -281,9 +293,9 @@ class HierarchicalRetention(nn.Module):
         )
 
         # Hierarchical states
-        short_state = torch.zeros(batch_size, self.d_state)
-        medium_state = torch.zeros(batch_size, self.d_state)
-        long_state = torch.zeros(batch_size, self.d_state * 2)
+        short_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
+        medium_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
+        long_state = torch.zeros(batch_size, self.d_state * 2, dtype=hidden_states.dtype, device=target_device)
 
         hierarchical_outputs = []
 
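The rewritten state initialization allocates the recurrent buffers with the activations' dtype and device, which avoids CPU/GPU mismatches and keeps them in half precision when hidden_states is fp16/bf16. A minimal sketch of the pattern; batch_size and d_state below are made-up values, not the ones in app.py:

import torch

# Illustrative values; the real batch_size and d_state come from app.py.
batch_size, d_state = 2, 32
device = "cuda" if torch.cuda.is_available() else "cpu"
hidden_states = torch.randn(batch_size, 16, 64, dtype=torch.float16, device=device)
target_device = hidden_states.device

# States created to match the activations' dtype and device, as in the new lines.
short_state = torch.zeros(batch_size, d_state, dtype=hidden_states.dtype, device=target_device)
medium_state = torch.zeros(batch_size, d_state, dtype=hidden_states.dtype, device=target_device)
long_state = torch.zeros(batch_size, d_state * 2, dtype=hidden_states.dtype, device=target_device)

assert short_state.dtype == hidden_states.dtype
assert short_state.device == hidden_states.device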