Update app.py
app.py (CHANGED)
@@ -635,12 +635,20 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
                 v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
                 o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape
 
+                if layer_idx == 0:  # detailed output for the first layer only
+                    print(f" 🔍 Layer 0 shape analysis:")
+                    print(f" Old Q: {old_attn.q_proj.weight.shape} vs New Q: {target.q_proj.weight.shape} → {'✅' if q_match else '❌'}")
+                    print(f" Old K: {old_attn.k_proj.weight.shape} vs New K: {target.k_proj.weight.shape} → {'✅' if k_match else '❌'}")
+                    print(f" Old V: {old_attn.v_proj.weight.shape} vs New V: {target.v_proj.weight.shape} → {'✅' if v_match else '❌'}")
+                    print(f" Old O: {old_attn.o_proj.weight.shape} vs New O: {target.o_proj.weight.shape} → {'✅' if o_match else '❌'}")
+
                 if q_match and k_match and v_match and o_match:
                     target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
                     target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
                     target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
                     target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
-
+                    if layer_idx == 0:
+                        print(f" ✅ Layer {layer_idx}: Perfect match - weights copied")
 
                 elif q_match and o_match:
                     target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
@@ -652,14 +660,17 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
                     target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone()
                     target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone()
 
-
+                    if layer_idx == 0:
+                        print(f" ✅ Layer {layer_idx}: Partial match (GQA) - partial weights copied")
 
                 else:
                     nn.init.xavier_uniform_(target.q_proj.weight)
                     nn.init.xavier_uniform_(target.k_proj.weight)
                     nn.init.xavier_uniform_(target.v_proj.weight)
                     nn.init.xavier_uniform_(target.o_proj.weight)
-
+                    if layer_idx == 0:
+                        print(f" ⚠️ Layer {layer_idx}: Shape mismatch - Xavier init used")
+                        print(f" This will result in random weights!")
 
             except Exception as e:
                 print(f" ⚠️ Layer {layer_idx}: Weight copy failed - {e}")
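For context on the `elif q_match and o_match` branch above: in a grouped-query attention (GQA) checkpoint, k_proj and v_proj have fewer output rows than q_proj, so only Q and O line up with a retention module that keeps one KV head per query head. A minimal sketch of the shape arithmetic, with made-up dimensions (the real values come from the model config, not from this diff):

def proj_shapes(hidden_size, num_heads, num_kv_heads):
    # (out_features, in_features) of each nn.Linear, Llama-style attention layout
    head_dim = hidden_size // num_heads
    return {
        "q_proj": (num_heads * head_dim, hidden_size),
        "k_proj": (num_kv_heads * head_dim, hidden_size),  # shrinks when num_kv_heads < num_heads
        "v_proj": (num_kv_heads * head_dim, hidden_size),
        "o_proj": (hidden_size, num_heads * head_dim),
    }

old = proj_shapes(hidden_size=2048, num_heads=32, num_kv_heads=8)    # GQA source attention (example numbers)
new = proj_shapes(hidden_size=2048, num_heads=32, num_kv_heads=32)   # retention sized for full heads (assumption)
print(old["k_proj"], "->", new["k_proj"])  # (512, 2048) -> (2048, 2048): only the first 512 rows can be copied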
@@ -989,6 +1000,8 @@ def replace_attention_with_retention(model, use_hierarchical=True):
     total_layers = len(layers)
     config = model.config
 
+    print(f"Converting {total_layers} layers...")
+
     for layer_idx, layer in enumerate(layers):
         if hasattr(layer, 'self_attn'):
             old_attn = layer.self_attn
@@ -1002,16 +1015,43 @@ def replace_attention_with_retention(model, use_hierarchical=True):
             try:
                 target = new_retention.base_retention if use_hierarchical else new_retention
 
-
+                # Shape check
+                q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape
+                k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape
+                v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
+                o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape
+
+                if layer_idx == 0:
+                    print(f"Layer 0 analysis:")
+                    print(f" Q: {old_attn.q_proj.weight.shape} vs {target.q_proj.weight.shape} → {'✅' if q_match else '❌'}")
+                    print(f" K: {old_attn.k_proj.weight.shape} vs {target.k_proj.weight.shape} → {'✅' if k_match else '❌'}")
+                    print(f" V: {old_attn.v_proj.weight.shape} vs {target.v_proj.weight.shape} → {'✅' if v_match else '❌'}")
+                    print(f" O: {old_attn.o_proj.weight.shape} vs {target.o_proj.weight.shape} → {'✅' if o_match else '❌'}")
+
+                # Copy weights
+                if q_match and k_match and v_match and o_match:
                     target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
-                if old_attn.k_proj.weight.shape == target.k_proj.weight.shape:
                     target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
-                if old_attn.v_proj.weight.shape == target.v_proj.weight.shape:
                     target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
-                if old_attn.o_proj.weight.shape == target.o_proj.weight.shape:
                     target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
+                    if layer_idx == 0:
+                        print(f" ✅ Perfect match - weights copied")
+                elif q_match and o_match:
+                    target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
+                    target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
+                    k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0])
+                    v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0])
+                    target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone()
+                    target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone()
+                    if layer_idx == 0:
+                        print(f" ✅ Partial match (GQA) - partial copy")
+                else:
+                    if layer_idx == 0:
+                        print(f" ⚠️ Shape mismatch - keeping random init")
+
             except Exception as e:
-
+                if layer_idx == 0:
+                    print(f"Weight copy error: {e}")
 
             layer.self_attn = new_retention
             converted_count += 1
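The `k_copy_size` / `v_copy_size` lines above copy only the leading rows that both projections share. A toy version of that partial copy, with example tensor sizes not taken from app.py:

import torch

old_k = torch.randn(512, 2048)    # GQA-sized K projection (example size)
new_k = torch.zeros(2048, 2048)   # larger retention K projection (example size)

k_copy_size = min(old_k.shape[0], new_k.shape[0])      # 512
new_k[:k_copy_size] = old_k[:k_copy_size].clone()      # remaining rows keep their existing init
print(k_copy_size, new_k[:k_copy_size].equal(old_k))   # 512 True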
@@ -1150,10 +1190,17 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
         print(f" Missing keys: {len(missing)}")
         print(f" Unexpected keys: {len(unexpected)}")
 
+        # Print details (first 5 only)
+        if missing:
+            print(f" Missing (first 5): {missing[:5]}")
+        if unexpected:
+            print(f" Unexpected (first 5): {unexpected[:5]}")
+
         # Check the Retention weights
         retention_keys = [k for k in state_dict.keys() if 'retention' in k.lower()]
         if retention_keys:
             print(f" ✅ Found {len(retention_keys)} Retention weight keys")
+            print(f" Sample keys: {retention_keys[:3]}")
         else:
             print(f" ⚠️ No Retention keys found in state dict")
 
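For readers unfamiliar with the `missing` / `unexpected` lists printed above: `load_state_dict(strict=False)` returns exactly these two key lists. A generic PyTorch illustration (not the PHOENIX checkpoint):

import torch
import torch.nn as nn

module = nn.Linear(4, 4)
ckpt = {"weight": torch.zeros(4, 4), "retention.decay": torch.zeros(1)}  # no "bias", one extra key
result = module.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['retention.decay']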
@@ -2314,18 +2361,33 @@ def validate_phoenix_model(
     retention_count = 0
     attention_count = 0
 
-
-
-
-
-
-
-
-
-
-
-
-
+    # For PhoenixModelForCausalLM wrappers, check _original_model
+    check_model = model
+    if hasattr(model, '_original_model') and model._original_model is not None:
+        print(f" 📋 Detected PhoenixModelForCausalLM wrapper")
+        check_model = model._original_model
+
+    layers = []
+    if hasattr(check_model, 'model') and hasattr(check_model.model, 'layers'):
+        layers = check_model.model.layers
+    elif hasattr(check_model, 'layers'):
+        layers = check_model.layers
+
+    print(f" 🔍 Checking {len(layers)} layers...")
+
+    for i, layer in enumerate(layers):
+        if hasattr(layer, 'self_attn'):
+            attn = layer.self_attn
+            class_name = attn.__class__.__name__
+
+            if 'Retention' in class_name:
+                retention_count += 1
+                if i < 3:  # print only the first 3
+                    print(f" ✅ Layer {i}: {class_name}")
+            else:
+                attention_count += 1
+                if i < 3:
+                    print(f" ⚠️ Layer {i}: {class_name}")
 
     total = retention_count + attention_count
     retention_info = f"""
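The validation above decides Retention vs. Attention purely from the class name of each layer's `self_attn` module. A toy version of that check; the class names here are invented for illustration:

import torch.nn as nn

class MultiScaleRetention(nn.Module):
    pass

class VanillaAttention(nn.Module):
    pass

class Block(nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.self_attn = attn

layers = nn.ModuleList([Block(MultiScaleRetention()), Block(VanillaAttention())])
retention_count = sum('Retention' in blk.self_attn.__class__.__name__ for blk in layers)
print(f"{retention_count}/{len(layers)} layers use Retention")  # 1/2 layers use Retention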
@@ -2334,7 +2396,7 @@ def validate_phoenix_model(
 - **Attention Layers**: {attention_count}/{total}
 - **Status**: {'✅ PHOENIX Active' if retention_count > 0 else '⚠️ No Retention Found'}
     """
-    print(f"
+    print(f" 📊 Result: {retention_count}/{total} layers have Retention")
 
     # 4. Generation test
    print(f"\n🚀 Running generation tests...")