seawolf2357 committed on
Commit
de99383
·
verified ·
1 Parent(s): 238a77b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -0
app.py CHANGED
@@ -149,6 +149,17 @@ class MultiScaleRetention(nn.Module):
149
  if past_key_values is not None:
150
  past_key_value = past_key_values
151
 
 
 
 
 
 
 
 
 
 
 
 
152
  # Q, K, V projections
153
  query_states = self.q_proj(hidden_states)
154
  key_states = self.k_proj(hidden_states)
@@ -561,6 +572,17 @@ class MultiScaleRetention(nn.Module):
561
  if past_key_values is not None:
562
  past_key_value = past_key_values
563
 
 
 
 
 
 
 
 
 
 
 
 
564
  query_states = self.q_proj(hidden_states)
565
  key_states = self.k_proj(hidden_states)
566
  value_states = self.v_proj(hidden_states)
 
149
  if past_key_values is not None:
150
  past_key_value = past_key_values
151
 
152
+ # ✅ FIX: Ensure all projection layers match input dtype/device
153
+ target_device = hidden_states.device
154
+ target_dtype = hidden_states.dtype
155
+
156
+ if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype:
157
+ self.q_proj = self.q_proj.to(device=target_device, dtype=target_dtype)
158
+ self.k_proj = self.k_proj.to(device=target_device, dtype=target_dtype)
159
+ self.v_proj = self.v_proj.to(device=target_device, dtype=target_dtype)
160
+ self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype)
161
+ self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype)
162
+
163
  # Q, K, V projections
164
  query_states = self.q_proj(hidden_states)
165
  key_states = self.k_proj(hidden_states)
 
572
  if past_key_values is not None:
573
  past_key_value = past_key_values
574
 
575
+ # ✅ FIX: Ensure all projection layers match input dtype/device
576
+ target_device = hidden_states.device
577
+ target_dtype = hidden_states.dtype
578
+
579
+ if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype:
580
+ self.q_proj = self.q_proj.to(device=target_device, dtype=target_dtype)
581
+ self.k_proj = self.k_proj.to(device=target_device, dtype=target_dtype)
582
+ self.v_proj = self.v_proj.to(device=target_device, dtype=target_dtype)
583
+ self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype)
584
+ self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype)
585
+
586
  query_states = self.q_proj(hidden_states)
587
  key_states = self.k_proj(hidden_states)
588
  value_states = self.v_proj(hidden_states)