seawolf2357 committed on
Commit
238a77b
·
verified ·
1 Parent(s): caf3990

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -22
app.py CHANGED
@@ -301,18 +301,16 @@ class HierarchicalRetention(nn.Module):
301
  target_device = hidden_states.device
302
  target_dtype = hidden_states.dtype
303
 
304
- if not next(self.short_proj.parameters()).is_cuda and hidden_states.is_cuda:
305
- self.short_proj = self.short_proj.to(target_device, dtype=target_dtype)
306
- self.medium_proj = self.medium_proj.to(target_device, dtype=target_dtype)
307
- self.long_proj = self.long_proj.to(target_device, dtype=target_dtype)
308
- self.fusion = self.fusion.to(target_device, dtype=target_dtype)
309
- self.norm = self.norm.to(target_device, dtype=target_dtype)
310
- elif next(self.short_proj.parameters()).dtype != target_dtype:
311
- self.short_proj = self.short_proj.to(dtype=target_dtype)
312
- self.medium_proj = self.medium_proj.to(dtype=target_dtype)
313
- self.long_proj = self.long_proj.to(dtype=target_dtype)
314
- self.fusion = self.fusion.to(dtype=target_dtype)
315
- self.norm = self.norm.to(dtype=target_dtype)
316
 
317
  base_result = self.base_retention(
318
  hidden_states, attention_mask, position_ids,
@@ -322,9 +320,9 @@ class HierarchicalRetention(nn.Module):
322
  retention_output = base_result[0]
323
 
324
  # Hierarchical states
325
- short_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
326
- medium_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
327
- long_state = torch.zeros(batch_size, self.d_state * 2, dtype=hidden_states.dtype, device=target_device)
328
 
329
  hierarchical_outputs = []
330
 
@@ -686,12 +684,16 @@ class HierarchicalRetention(nn.Module):
686
  target_device = hidden_states.device
687
  target_dtype = hidden_states.dtype
688
 
689
- if not next(self.short_proj.parameters()).is_cuda and hidden_states.is_cuda:
690
- self.short_proj = self.short_proj.to(target_device, dtype=target_dtype)
691
- self.medium_proj = self.medium_proj.to(target_device, dtype=target_dtype)
692
- self.long_proj = self.long_proj.to(target_device, dtype=target_dtype)
693
- self.fusion = self.fusion.to(target_device, dtype=target_dtype)
694
- self.norm = self.norm.to(target_device, dtype=target_dtype)
 
 
 
 
695
 
696
  base_result = self.base_retention(
697
  hidden_states, attention_mask, position_ids,
@@ -871,7 +873,8 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
871
 
872
  # Auto-registration
873
  AutoConfig.register("phoenix", PhoenixConfig)
874
- '''
 
875
  return modeling_code
876
 
877
 
 
301
  target_device = hidden_states.device
302
  target_dtype = hidden_states.dtype
303
 
304
+ # 개선된 dtype/device 체크
305
+ current_device = next(self.short_proj.parameters()).device
306
+ current_dtype = next(self.short_proj.parameters()).dtype
307
+
308
+ if current_device != target_device or current_dtype != target_dtype:
309
+ self.short_proj = self.short_proj.to(device=target_device, dtype=target_dtype)
310
+ self.medium_proj = self.medium_proj.to(device=target_device, dtype=target_dtype)
311
+ self.long_proj = self.long_proj.to(device=target_device, dtype=target_dtype)
312
+ self.fusion = self.fusion.to(device=target_device, dtype=target_dtype)
313
+ self.norm = self.norm.to(device=target_device, dtype=target_dtype)
 
 
314
 
315
  base_result = self.base_retention(
316
  hidden_states, attention_mask, position_ids,
 
320
  retention_output = base_result[0]
321
 
322
  # Hierarchical states
323
+ short_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device)
324
+ medium_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device)
325
+ long_state = torch.zeros(batch_size, self.d_state * 2, dtype=target_dtype, device=target_device)
326
 
327
  hierarchical_outputs = []
328
 
 
684
  target_device = hidden_states.device
685
  target_dtype = hidden_states.dtype
686
 
687
+ # 개선된 dtype/device 체크
688
+ current_device = next(self.short_proj.parameters()).device
689
+ current_dtype = next(self.short_proj.parameters()).dtype
690
+
691
+ if current_device != target_device or current_dtype != target_dtype:
692
+ self.short_proj = self.short_proj.to(device=target_device, dtype=target_dtype)
693
+ self.medium_proj = self.medium_proj.to(device=target_device, dtype=target_dtype)
694
+ self.long_proj = self.long_proj.to(device=target_device, dtype=target_dtype)
695
+ self.fusion = self.fusion.to(device=target_device, dtype=target_dtype)
696
+ self.norm = self.norm.to(device=target_device, dtype=target_dtype)
697
 
698
  base_result = self.base_retention(
699
  hidden_states, attention_mask, position_ids,
 
873
 
874
  # Auto-registration
875
  AutoConfig.register("phoenix", PhoenixConfig)
876
+ '''
877
+
878
  return modeling_code
879
 
880