Update app.py
Browse files
app.py
CHANGED
|
@@ -687,6 +687,7 @@ def generate_modeling_phoenix_code():
|
|
| 687 |
return '''"""
|
| 688 |
PHOENIX Retention Model v1.4.3
|
| 689 |
✅ v1.4.3 CRITICAL FIX: forward() 시그니처 Transformers 호환
|
|
|
|
| 690 |
✅ PhoenixPreTrainedModel 베이스 클래스 포함
|
| 691 |
✅ 모든 Retention 클래스 완전 구현
|
| 692 |
"""
|
|
@@ -748,7 +749,8 @@ class MultiScaleRetention(nn.Module):
|
|
| 748 |
b, s, _ = hidden_states.shape
|
| 749 |
device, dtype = hidden_states.device, hidden_states.dtype
|
| 750 |
|
| 751 |
-
|
|
|
|
| 752 |
self.to(device=device, dtype=dtype)
|
| 753 |
|
| 754 |
q = self.q_proj(hidden_states).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
|
|
@@ -801,7 +803,9 @@ class HierarchicalRetention(nn.Module):
|
|
| 801 |
):
|
| 802 |
b, s, h = hidden_states.shape
|
| 803 |
device, dtype = hidden_states.device, hidden_states.dtype
|
| 804 |
-
|
|
|
|
|
|
|
| 805 |
self.to(device=device, dtype=dtype)
|
| 806 |
|
| 807 |
ret_out = self.base_retention(hidden_states)[0]
|
|
@@ -824,10 +828,23 @@ def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
|
|
| 824 |
layers = getattr(layers, 'layers', getattr(layers, 'h', getattr(layers, 'layers', None)))
|
| 825 |
if layers is None: return model, 0, 0
|
| 826 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 827 |
cnt = 0
|
| 828 |
for i, layer in enumerate(layers):
|
| 829 |
if hasattr(layer, 'self_attn'):
|
| 830 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 831 |
cnt += 1
|
| 832 |
return model, cnt, len(layers)
|
| 833 |
|
|
@@ -1871,6 +1888,7 @@ with gr.Blocks(
|
|
| 1871 |
**Complete Integrated Version with All Fixes**
|
| 1872 |
|
| 1873 |
✅ **NEW v1.4.3!** forward() 시그니처 Transformers 호환 - 완벽 수정!
|
|
|
|
| 1874 |
✅ Embedding Tying 저장 시점 처리
|
| 1875 |
✅ State Dict 직접 로드로 Retention 보존
|
| 1876 |
✅ Model Structure Pre-Analysis
|
|
@@ -2003,6 +2021,7 @@ with gr.Blocks(
|
|
| 2003 |
|
| 2004 |
### What's New in v1.4.3 (Complete Integrated Version)
|
| 2005 |
- ✅ **CRITICAL FIX: forward() Signature** - Transformers 호환성 완벽 수정
|
|
|
|
| 2006 |
- ✅ **Embedding Tying** - 저장 시점에 자동 처리
|
| 2007 |
- ✅ **Qwen3-0.6B Generation Fixed** - 정상적인 텍스트 생성
|
| 2008 |
- ✅ **완전 통합** - 모든 수정사항 포함
|
|
|
|
| 687 |
return '''"""
|
| 688 |
PHOENIX Retention Model v1.4.3
|
| 689 |
✅ v1.4.3 CRITICAL FIX: forward() 시그니처 Transformers 호환
|
| 690 |
+
✅ v1.4.3 HOTFIX: dtype 불일치 수정 (bfloat16 지원)
|
| 691 |
✅ PhoenixPreTrainedModel 베이스 클래스 포함
|
| 692 |
✅ 모든 Retention 클래스 완전 구현
|
| 693 |
"""
|
|
|
|
| 749 |
b, s, _ = hidden_states.shape
|
| 750 |
device, dtype = hidden_states.device, hidden_states.dtype
|
| 751 |
|
| 752 |
+
# ✅ FIX: dtype과 device 모두 일치시킴
|
| 753 |
+
if self.q_proj.weight.device != device or self.q_proj.weight.dtype != dtype:
|
| 754 |
self.to(device=device, dtype=dtype)
|
| 755 |
|
| 756 |
q = self.q_proj(hidden_states).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
|
|
|
|
| 803 |
):
|
| 804 |
b, s, h = hidden_states.shape
|
| 805 |
device, dtype = hidden_states.device, hidden_states.dtype
|
| 806 |
+
|
| 807 |
+
# ✅ FIX: dtype과 device 모두 일치시킴
|
| 808 |
+
if next(self.short_proj.parameters()).device != device or next(self.short_proj.parameters()).dtype != dtype:
|
| 809 |
self.to(device=device, dtype=dtype)
|
| 810 |
|
| 811 |
ret_out = self.base_retention(hidden_states)[0]
|
|
|
|
| 828 |
layers = getattr(layers, 'layers', getattr(layers, 'h', getattr(layers, 'layers', None)))
|
| 829 |
if layers is None: return model, 0, 0
|
| 830 |
|
| 831 |
+
# ✅ FIX: 원본 모델의 dtype 감지
|
| 832 |
+
original_dtype = None
|
| 833 |
+
for param in model.parameters():
|
| 834 |
+
original_dtype = param.dtype
|
| 835 |
+
break
|
| 836 |
+
|
| 837 |
cnt = 0
|
| 838 |
for i, layer in enumerate(layers):
|
| 839 |
if hasattr(layer, 'self_attn'):
|
| 840 |
+
# 새 Retention 생성
|
| 841 |
+
new_retention = HierarchicalRetention(model.config, i) if use_hierarchical else MultiScaleRetention(model.config, i)
|
| 842 |
+
|
| 843 |
+
# ✅ FIX: 원본 dtype으로 변환
|
| 844 |
+
if original_dtype is not None:
|
| 845 |
+
new_retention = new_retention.to(dtype=original_dtype)
|
| 846 |
+
|
| 847 |
+
layer.self_attn = new_retention
|
| 848 |
cnt += 1
|
| 849 |
return model, cnt, len(layers)
|
| 850 |
|
|
|
|
| 1888 |
**Complete Integrated Version with All Fixes**
|
| 1889 |
|
| 1890 |
✅ **NEW v1.4.3!** forward() 시그니처 Transformers 호환 - 완벽 수정!
|
| 1891 |
+
✅ **NEW v1.4.3!** dtype 불일치 수정 - bfloat16 완벽 지원!
|
| 1892 |
✅ Embedding Tying 저장 시점 처리
|
| 1893 |
✅ State Dict 직접 로드로 Retention 보존
|
| 1894 |
✅ Model Structure Pre-Analysis
|
|
|
|
| 2021 |
|
| 2022 |
### What's New in v1.4.3 (Complete Integrated Version)
|
| 2023 |
- ✅ **CRITICAL FIX: forward() Signature** - Transformers 호환성 완벽 수정
|
| 2024 |
+
- ✅ **HOTFIX: dtype 불일치** - bfloat16 완벽 지원
|
| 2025 |
- ✅ **Embedding Tying** - 저장 시점에 자동 처리
|
| 2026 |
- ✅ **Qwen3-0.6B Generation Fixed** - 정상적인 텍스트 생성
|
| 2027 |
- ✅ **완전 통합** - 모든 수정사항 포함
|