Update app.py
Browse files
app.py
CHANGED
|
@@ -301,18 +301,16 @@ class HierarchicalRetention(nn.Module):
|
|
| 301 |
target_device = hidden_states.device
|
| 302 |
target_dtype = hidden_states.dtype
|
| 303 |
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
self.
|
| 310 |
-
|
| 311 |
-
self.
|
| 312 |
-
self.
|
| 313 |
-
self.
|
| 314 |
-
self.fusion = self.fusion.to(dtype=target_dtype)
|
| 315 |
-
self.norm = self.norm.to(dtype=target_dtype)
|
| 316 |
|
| 317 |
base_result = self.base_retention(
|
| 318 |
hidden_states, attention_mask, position_ids,
|
|
@@ -322,9 +320,9 @@ class HierarchicalRetention(nn.Module):
|
|
| 322 |
retention_output = base_result[0]
|
| 323 |
|
| 324 |
# Hierarchical states
|
| 325 |
-
short_state = torch.zeros(batch_size, self.d_state, dtype=
|
| 326 |
-
medium_state = torch.zeros(batch_size, self.d_state, dtype=
|
| 327 |
-
long_state = torch.zeros(batch_size, self.d_state * 2, dtype=
|
| 328 |
|
| 329 |
hierarchical_outputs = []
|
| 330 |
|
|
@@ -686,12 +684,16 @@ class HierarchicalRetention(nn.Module):
|
|
| 686 |
target_device = hidden_states.device
|
| 687 |
target_dtype = hidden_states.dtype
|
| 688 |
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 695 |
|
| 696 |
base_result = self.base_retention(
|
| 697 |
hidden_states, attention_mask, position_ids,
|
|
@@ -871,7 +873,8 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
|
|
| 871 |
|
| 872 |
# Auto-registration
|
| 873 |
AutoConfig.register("phoenix", PhoenixConfig)
|
| 874 |
-
'''
|
|
|
|
| 875 |
return modeling_code
|
| 876 |
|
| 877 |
|
|
|
|
| 301 |
target_device = hidden_states.device
|
| 302 |
target_dtype = hidden_states.dtype
|
| 303 |
|
| 304 |
+
# ✅ 개선된 dtype/device 체크
|
| 305 |
+
current_device = next(self.short_proj.parameters()).device
|
| 306 |
+
current_dtype = next(self.short_proj.parameters()).dtype
|
| 307 |
+
|
| 308 |
+
if current_device != target_device or current_dtype != target_dtype:
|
| 309 |
+
self.short_proj = self.short_proj.to(device=target_device, dtype=target_dtype)
|
| 310 |
+
self.medium_proj = self.medium_proj.to(device=target_device, dtype=target_dtype)
|
| 311 |
+
self.long_proj = self.long_proj.to(device=target_device, dtype=target_dtype)
|
| 312 |
+
self.fusion = self.fusion.to(device=target_device, dtype=target_dtype)
|
| 313 |
+
self.norm = self.norm.to(device=target_device, dtype=target_dtype)
|
|
|
|
|
|
|
| 314 |
|
| 315 |
base_result = self.base_retention(
|
| 316 |
hidden_states, attention_mask, position_ids,
|
|
|
|
| 320 |
retention_output = base_result[0]
|
| 321 |
|
| 322 |
# Hierarchical states
|
| 323 |
+
short_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device)
|
| 324 |
+
medium_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device)
|
| 325 |
+
long_state = torch.zeros(batch_size, self.d_state * 2, dtype=target_dtype, device=target_device)
|
| 326 |
|
| 327 |
hierarchical_outputs = []
|
| 328 |
|
|
|
|
| 684 |
target_device = hidden_states.device
|
| 685 |
target_dtype = hidden_states.dtype
|
| 686 |
|
| 687 |
+
# ✅ 개선된 dtype/device 체크
|
| 688 |
+
current_device = next(self.short_proj.parameters()).device
|
| 689 |
+
current_dtype = next(self.short_proj.parameters()).dtype
|
| 690 |
+
|
| 691 |
+
if current_device != target_device or current_dtype != target_dtype:
|
| 692 |
+
self.short_proj = self.short_proj.to(device=target_device, dtype=target_dtype)
|
| 693 |
+
self.medium_proj = self.medium_proj.to(device=target_device, dtype=target_dtype)
|
| 694 |
+
self.long_proj = self.long_proj.to(device=target_device, dtype=target_dtype)
|
| 695 |
+
self.fusion = self.fusion.to(device=target_device, dtype=target_dtype)
|
| 696 |
+
self.norm = self.norm.to(device=target_device, dtype=target_dtype)
|
| 697 |
|
| 698 |
base_result = self.base_retention(
|
| 699 |
hidden_states, attention_mask, position_ids,
|
|
|
|
| 873 |
|
| 874 |
# Auto-registration
|
| 875 |
AutoConfig.register("phoenix", PhoenixConfig)
|
| 876 |
+
'''
|
| 877 |
+
|
| 878 |
return modeling_code
|
| 879 |
|
| 880 |
|