Update app.py
Browse files
app.py
CHANGED
|
@@ -149,6 +149,17 @@ class MultiScaleRetention(nn.Module):
|
|
| 149 |
if past_key_values is not None:
|
| 150 |
past_key_value = past_key_values
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
# Q, K, V projections
|
| 153 |
query_states = self.q_proj(hidden_states)
|
| 154 |
key_states = self.k_proj(hidden_states)
|
|
@@ -561,6 +572,17 @@ class MultiScaleRetention(nn.Module):
|
|
| 561 |
if past_key_values is not None:
|
| 562 |
past_key_value = past_key_values
|
| 563 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 564 |
query_states = self.q_proj(hidden_states)
|
| 565 |
key_states = self.k_proj(hidden_states)
|
| 566 |
value_states = self.v_proj(hidden_states)
|
|
|
|
| 149 |
if past_key_values is not None:
|
| 150 |
past_key_value = past_key_values
|
| 151 |
|
| 152 |
+
# ✅ FIX: Ensure all projection layers match input dtype/device
|
| 153 |
+
target_device = hidden_states.device
|
| 154 |
+
target_dtype = hidden_states.dtype
|
| 155 |
+
|
| 156 |
+
if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype:
|
| 157 |
+
self.q_proj = self.q_proj.to(device=target_device, dtype=target_dtype)
|
| 158 |
+
self.k_proj = self.k_proj.to(device=target_device, dtype=target_dtype)
|
| 159 |
+
self.v_proj = self.v_proj.to(device=target_device, dtype=target_dtype)
|
| 160 |
+
self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype)
|
| 161 |
+
self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype)
|
| 162 |
+
|
| 163 |
# Q, K, V projections
|
| 164 |
query_states = self.q_proj(hidden_states)
|
| 165 |
key_states = self.k_proj(hidden_states)
|
|
|
|
| 572 |
if past_key_values is not None:
|
| 573 |
past_key_value = past_key_values
|
| 574 |
|
| 575 |
+
# ✅ FIX: Ensure all projection layers match input dtype/device
|
| 576 |
+
target_device = hidden_states.device
|
| 577 |
+
target_dtype = hidden_states.dtype
|
| 578 |
+
|
| 579 |
+
if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype:
|
| 580 |
+
self.q_proj = self.q_proj.to(device=target_device, dtype=target_dtype)
|
| 581 |
+
self.k_proj = self.k_proj.to(device=target_device, dtype=target_dtype)
|
| 582 |
+
self.v_proj = self.v_proj.to(device=target_device, dtype=target_dtype)
|
| 583 |
+
self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype)
|
| 584 |
+
self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype)
|
| 585 |
+
|
| 586 |
query_states = self.q_proj(hidden_states)
|
| 587 |
key_states = self.k_proj(hidden_states)
|
| 588 |
value_states = self.v_proj(hidden_states)
|