Update app.py
app.py CHANGED
@@ -534,38 +534,57 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
     replaced_count = 0
     total_layers = 0
 
+    # Find the decoder layers (try several known paths)
+    layers = None
+    layer_path = None
+
+    # 1. Use structure_info if available
     if structure_info and structure_info.get('layer_path'):
         layer_path = structure_info['layer_path']
         print(f" Using structure info: {layer_path}")
 
         if layer_path == 'model.layers':
+            if hasattr(model, 'model') and hasattr(model.model, 'layers'):
+                layers = model.model.layers
         elif layer_path == 'transformer.h':
+            if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
+                layers = model.transformer.h
         elif layer_path == 'layers':
+            if hasattr(model, 'layers'):
+                layers = model.layers
         elif layer_path == 'model.decoder.layers':
+            if hasattr(model, 'model') and hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers'):
+                layers = model.model.decoder.layers
+
+    # 2. Auto-detect (when structure_info is missing or did not resolve)
+    if layers is None:
+        print(f" Auto-detecting layer structure...")
+
+        possible_paths = [
+            ('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None),
+            ('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None),
+            ('layers', lambda m: m.layers if hasattr(m, 'layers') else None),
+            ('model.decoder.layers', lambda m: m.model.decoder.layers if hasattr(m, 'model') and hasattr(m.model, 'decoder') and hasattr(m.model.decoder, 'layers') else None),
+        ]
+
+        for path_name, path_fn in possible_paths:
+            result = path_fn(model)
+            if result is not None:
+                layers = result
+                layer_path = path_name
+                print(f" ✅ Found layers at: {path_name}")
+                break
 
     if layers is None:
+        print("❌ Cannot find layers - model structure not supported")
+        print(f" Model type: {type(model)}")
+        print(f" Has 'model' attr: {hasattr(model, 'model')}")
+        print(f" Has 'transformer' attr: {hasattr(model, 'transformer')}")
+        print(f" Has 'layers' attr: {hasattr(model, 'layers')}")
         return model, 0, 0
 
     total_layers = len(layers)
-    print(f" Found {total_layers} layers")
+    print(f" Found {total_layers} layers at '{layer_path}'")
 
     # Detect GQA (structure_info takes priority)
     if structure_info and structure_info.get('gqa_detected'):
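The four hardcoded path checks above amount to resolving a dotted attribute path with a fallback list. The sketch below is not part of app.py; find_decoder_layers, _resolve, and CANDIDATE_PATHS are illustrative names used only to make the lookup logic explicit.

from typing import Any, Optional, Sequence, Tuple

# Known decoder-layer locations, mirroring the paths the diff checks explicitly.
CANDIDATE_PATHS: Sequence[str] = (
    "model.layers",          # LLaMA / Mistral style
    "transformer.h",         # GPT-2 style
    "layers",                # bare decoder stacks
    "model.decoder.layers",  # OPT style
)

def _resolve(obj: Any, dotted: str) -> Optional[Any]:
    """Walk a dotted attribute path, returning None as soon as a hop is missing."""
    for name in dotted.split("."):
        obj = getattr(obj, name, None)
        if obj is None:
            return None
    return obj

def find_decoder_layers(model: Any, hint: Optional[str] = None) -> Tuple[Optional[Any], Optional[str]]:
    """Try the hinted path first (e.g. structure_info['layer_path']), then the candidates."""
    paths = ([hint] if hint else []) + [p for p in CANDIDATE_PATHS if p != hint]
    for path in paths:
        layers = _resolve(model, path)
        if layers is not None:
            return layers, path
    return None, None

For these four layouts the behaviour matches the diff; the explicit hasattr chains in app.py simply spell the same walk out per path.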
@@ -944,43 +963,54 @@ class HierarchicalRetention(nn.Module):
 
 
 def replace_attention_with_retention(model, use_hierarchical=True):
-    """Convert Attention → Retention"""
+    """Convert Attention → Retention (improved)"""
     converted_count = 0
     total_layers = 0
 
+    # Find the layers (try several paths)
+    layers = None
+
     if hasattr(model, 'model') and hasattr(model.model, 'layers'):
         layers = model.model.layers
+    elif hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
+        layers = model.transformer.h
+    elif hasattr(model, 'layers'):
+        layers = model.layers
+    else:
+        print("Cannot find layers in model")
+        return model, 0, 0
 
+    total_layers = len(layers)
+    config = model.config
+
+    for layer_idx, layer in enumerate(layers):
+        if hasattr(layer, 'self_attn'):
+            old_attn = layer.self_attn
+
+            if use_hierarchical:
+                new_retention = HierarchicalRetention(config, layer_idx)
+            else:
+                new_retention = MultiScaleRetention(config, layer_idx)
+
+            if hasattr(old_attn, 'q_proj'):
+                try:
+                    target = new_retention.base_retention if use_hierarchical else new_retention
+
+                    if old_attn.q_proj.weight.shape == target.q_proj.weight.shape:
+                        target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
+                    if old_attn.k_proj.weight.shape == target.k_proj.weight.shape:
+                        target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
+                    if old_attn.v_proj.weight.shape == target.v_proj.weight.shape:
+                        target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
+                    if old_attn.o_proj.weight.shape == target.o_proj.weight.shape:
+                        target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
+                except Exception as e:
+                    print(f"Weight copy warning for layer {layer_idx}: {e}")
+
+            layer.self_attn = new_retention
+            converted_count += 1
+
+    print(f"Converted {converted_count}/{total_layers} layers to Retention")
     return model, converted_count, total_layers
 
 
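The per-projection shape checks in the loop above repeat one pattern four times. A compact equivalent, shown only as a sketch under the assumption that both the old attention module and the retention target expose q_proj/k_proj/v_proj/o_proj as nn.Linear layers, loops over the names; copy_matching_projections is a hypothetical helper, not a function in app.py.

import torch
import torch.nn as nn

@torch.no_grad()
def copy_matching_projections(old_attn: nn.Module, target: nn.Module) -> int:
    """Copy q/k/v/o projection weights whose shapes match; return how many were copied."""
    copied = 0
    for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
        old_proj = getattr(old_attn, name, None)
        new_proj = getattr(target, name, None)
        if old_proj is None or new_proj is None:
            continue  # projection missing on one side, e.g. fused qkv layouts
        if old_proj.weight.shape == new_proj.weight.shape:
            new_proj.weight.copy_(old_proj.weight)
            copied += 1
    return copied

A return value of 0 would flag layers where nothing transferred, which the diff currently only surfaces through the try/except warning.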
@@ -1042,6 +1072,10 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
 
         print(f"✅ Converted {converted}/{total} layers to Retention")
 
+        if converted == 0:
+            print(f"⚠️ WARNING: No layers were converted!")
+            print(f" Model may not have Retention active.")
+
         phoenix_instance = cls(config)
         phoenix_instance._original_model = base_model
         phoenix_instance._initialized = True
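The added check only warns when converted == 0 and then continues to build the Phoenix instance. A stricter variant, offered purely as a sketch of an alternative and not as app.py's behaviour, would abort instead; ensure_converted is an illustrative name.

def ensure_converted(converted: int, total: int, model) -> None:
    """Hypothetical guard: raise instead of only warning when no layers were converted."""
    if converted == 0:
        raise RuntimeError(
            f"Attention → Retention conversion failed: 0/{total} layers converted "
            f"(model type: {type(model).__name__})"
        )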
@@ -1605,8 +1639,9 @@ def burn_model_zero_shot(
     print(f"\n🔄 STEP 3: Converting Attention → Retention...")
     convert_start = time.time()
 
+    # ✅ FIX: pass the whole model and let the converter locate the layers itself
+    model, converted, total = replace_attention_with_retention(
+        model,
         use_hierarchical=use_hierarchical,
         structure_info=structure_info
     )
@@ -1618,8 +1653,30 @@ def burn_model_zero_shot(
 
     if converted == 0:
         print(f"\n⚠️ WARNING: No layers were converted!")
+        print(f" This indicates a structural mismatch.")
+        print(f" Model type: {type(model)}")
+        if structure_info:
+            print(f" Structure info: {structure_info.get('layer_path', 'unknown')}")
+        print(f" Please check the model architecture.")
+    else:
+        # Verify the conversion
+        print(f"\n🔍 Verifying conversion...")
+        verified_retention = 0
+
+        if hasattr(model, 'model') and hasattr(model.model, 'layers'):
+            check_layers = model.model.layers
+        else:
+            check_layers = []
+
+        for layer in check_layers:
+            if hasattr(layer, 'self_attn'):
+                if 'Retention' in layer.self_attn.__class__.__name__:
+                    verified_retention += 1
+
+        print(f" ✅ Verified: {verified_retention}/{len(check_layers)} layers have Retention")
+
+        if verified_retention == 0 and converted > 0:
+            print(f" ⚠️ WARNING: Conversion reported success but verification failed!")
 
     # 4. Evaluate
     print(f"\n📊 STEP 4: Evaluating model quality...")
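The verification pass above hardcodes model.model.layers, so for architectures the converter itself reaches via transformer.h or model.decoder.layers it would report 0 and trigger the false-failure warning. A path-agnostic count, offered only as a sketch and not as code from app.py (count_retention_attn is an illustrative name), could scan named modules instead.

import torch.nn as nn

def count_retention_attn(model: nn.Module) -> int:
    """Count self_attn submodules whose class name contains 'Retention', regardless of layer path."""
    return sum(
        1
        for name, module in model.named_modules()
        if name.endswith("self_attn") and "Retention" in module.__class__.__name__
    )

Substituting verified_retention = count_retention_attn(model) would keep the verification consistent with all four layouts the converter supports.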
@@ -1725,8 +1782,8 @@ def burn_model_with_finetuning(
         tokenizer.pad_token = tokenizer.eos_token
 
     print(f"\n🔄 STEP 3: Converting...")
+    model, converted, total = replace_attention_with_retention(
+        model,
         use_hierarchical=use_hierarchical,
         structure_info=structure_info
     )