seawolf2357 committed on
Commit
1fa5f7c
·
verified ·
1 Parent(s): 3198863

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -3
app.py CHANGED
@@ -687,6 +687,7 @@ def generate_modeling_phoenix_code():
687
  return '''"""
688
  PHOENIX Retention Model v1.4.3
689
  ✅ v1.4.3 CRITICAL FIX: forward() 시그니처 Transformers 호환
 
690
  ✅ PhoenixPreTrainedModel 베이스 클래스 포함
691
  ✅ 모든 Retention 클래스 완전 구현
692
  """
@@ -748,7 +749,8 @@ class MultiScaleRetention(nn.Module):
748
  b, s, _ = hidden_states.shape
749
  device, dtype = hidden_states.device, hidden_states.dtype
750
 
751
- if self.q_proj.weight.device != device:
 
752
  self.to(device=device, dtype=dtype)
753
 
754
  q = self.q_proj(hidden_states).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
@@ -801,7 +803,9 @@ class HierarchicalRetention(nn.Module):
801
  ):
802
  b, s, h = hidden_states.shape
803
  device, dtype = hidden_states.device, hidden_states.dtype
804
- if next(self.short_proj.parameters()).device != device:
 
 
805
  self.to(device=device, dtype=dtype)
806
 
807
  ret_out = self.base_retention(hidden_states)[0]
@@ -824,10 +828,23 @@ def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
824
  layers = getattr(layers, 'layers', getattr(layers, 'h', getattr(layers, 'layers', None)))
825
  if layers is None: return model, 0, 0
826
 
 
 
 
 
 
 
827
  cnt = 0
828
  for i, layer in enumerate(layers):
829
  if hasattr(layer, 'self_attn'):
830
- layer.self_attn = HierarchicalRetention(model.config, i) if use_hierarchical else MultiScaleRetention(model.config, i)
 
 
 
 
 
 
 
831
  cnt += 1
832
  return model, cnt, len(layers)
833
 
@@ -1871,6 +1888,7 @@ with gr.Blocks(
1871
  **Complete Integrated Version with All Fixes**
1872
 
1873
  ✅ **NEW v1.4.3!** forward() 시그니처 Transformers 호환 - 완벽 수정!
 
1874
  ✅ Embedding Tying 저장 시점 처리
1875
  ✅ State Dict 직접 로드로 Retention 보존
1876
  ✅ Model Structure Pre-Analysis
@@ -2003,6 +2021,7 @@ with gr.Blocks(
2003
 
2004
  ### What's New in v1.4.3 (Complete Integrated Version)
2005
  - ✅ **CRITICAL FIX: forward() Signature** - Transformers 호환성 완벽 수정
 
2006
  - ✅ **Embedding Tying** - 저장 시점에 자동 처리
2007
  - ✅ **Qwen3-0.6B Generation Fixed** - 정상적인 텍스트 생성
2008
  - ✅ **완전 통합** - 모든 수정사항 포함
 
687
  return '''"""
688
  PHOENIX Retention Model v1.4.3
689
  ✅ v1.4.3 CRITICAL FIX: forward() 시그니처 Transformers 호환
690
+ ✅ v1.4.3 HOTFIX: dtype 불일치 수정 (bfloat16 지원)
691
  ✅ PhoenixPreTrainedModel 베이스 클래스 포함
692
  ✅ 모든 Retention 클래스 완전 구현
693
  """
 
749
  b, s, _ = hidden_states.shape
750
  device, dtype = hidden_states.device, hidden_states.dtype
751
 
752
+ # ✅ FIX: dtype과 device 모두 일치시킴
753
+ if self.q_proj.weight.device != device or self.q_proj.weight.dtype != dtype:
754
  self.to(device=device, dtype=dtype)
755
 
756
  q = self.q_proj(hidden_states).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
 
803
  ):
804
  b, s, h = hidden_states.shape
805
  device, dtype = hidden_states.device, hidden_states.dtype
806
+
807
+ # ✅ FIX: dtype과 device 모두 일치시킴
808
+ if next(self.short_proj.parameters()).device != device or next(self.short_proj.parameters()).dtype != dtype:
809
  self.to(device=device, dtype=dtype)
810
 
811
  ret_out = self.base_retention(hidden_states)[0]
 
828
  layers = getattr(layers, 'layers', getattr(layers, 'h', getattr(layers, 'layers', None)))
829
  if layers is None: return model, 0, 0
830
 
831
+ # ✅ FIX: 원본 모델의 dtype 감지
832
+ original_dtype = None
833
+ for param in model.parameters():
834
+ original_dtype = param.dtype
835
+ break
836
+
837
  cnt = 0
838
  for i, layer in enumerate(layers):
839
  if hasattr(layer, 'self_attn'):
840
+ # Retention 생성
841
+ new_retention = HierarchicalRetention(model.config, i) if use_hierarchical else MultiScaleRetention(model.config, i)
842
+
843
+ # ✅ FIX: 원본 dtype으로 변환
844
+ if original_dtype is not None:
845
+ new_retention = new_retention.to(dtype=original_dtype)
846
+
847
+ layer.self_attn = new_retention
848
  cnt += 1
849
  return model, cnt, len(layers)
850
 
 
1888
  **Complete Integrated Version with All Fixes**
1889
 
1890
  ✅ **NEW v1.4.3!** forward() 시그니처 Transformers 호환 - 완벽 수정!
1891
+ ✅ **NEW v1.4.3!** dtype 불일치 수정 - bfloat16 완벽 지원!
1892
  ✅ Embedding Tying 저장 시점 처리
1893
  ✅ State Dict 직접 로드로 Retention 보존
1894
  ✅ Model Structure Pre-Analysis
 
2021
 
2022
  ### What's New in v1.4.3 (Complete Integrated Version)
2023
  - ✅ **CRITICAL FIX: forward() Signature** - Transformers 호환성 완벽 수정
2024
+ - ✅ **HOTFIX: dtype 불일치** - bfloat16 완벽 지원
2025
  - ✅ **Embedding Tying** - 저장 시점에 자동 처리
2026
  - ✅ **Qwen3-0.6B Generation Fixed** - 정상적인 텍스트 생성
2027
  - ✅ **완전 통합** - 모든 수정사항 포함