seawolf2357 committed
Commit 8863ba4 · verified · 1 Parent(s): cb3c4bf

Update app.py

Files changed (1)
  1. app.py +113 -56
app.py CHANGED
@@ -534,38 +534,57 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
     replaced_count = 0
     total_layers = 0
 
-    # Use structure_info
+    # Layer lookup (try multiple paths)
+    layers = None
+    layer_path = None
+
+    # 1. Use structure_info
     if structure_info and structure_info.get('layer_path'):
         layer_path = structure_info['layer_path']
         print(f" Using structure info: {layer_path}")
 
         if layer_path == 'model.layers':
-            layers = model.model.layers if hasattr(model, 'model') and hasattr(model.model, 'layers') else None
+            if hasattr(model, 'model') and hasattr(model.model, 'layers'):
+                layers = model.model.layers
         elif layer_path == 'transformer.h':
-            layers = model.transformer.h if hasattr(model, 'transformer') and hasattr(model.transformer, 'h') else None
+            if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
+                layers = model.transformer.h
         elif layer_path == 'layers':
-            layers = model.layers if hasattr(model, 'layers') else None
+            if hasattr(model, 'layers'):
+                layers = model.layers
         elif layer_path == 'model.decoder.layers':
-            layers = model.model.decoder.layers if hasattr(model, 'model') and hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers') else None
-        else:
-            layers = None
-    else:
-        # Fall back to the original search order
-        if hasattr(model, 'transformer'):
-            layers = model.transformer.h
-        elif hasattr(model, 'model') and hasattr(model.model, 'layers'):
-            layers = model.model.layers
-        elif hasattr(model, 'layers'):
-            layers = model.layers
-        else:
-            layers = None
+            if hasattr(model, 'model') and hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers'):
+                layers = model.model.decoder.layers
+
+    # 2. Auto-detection (when structure_info is missing or did not match)
+    if layers is None:
+        print(f" Auto-detecting layer structure...")
+
+        possible_paths = [
+            ('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None),
+            ('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None),
+            ('layers', lambda m: m.layers if hasattr(m, 'layers') else None),
+            ('model.decoder.layers', lambda m: m.model.decoder.layers if hasattr(m, 'model') and hasattr(m.model, 'decoder') and hasattr(m.model.decoder, 'layers') else None),
+        ]
+
+        for path_name, path_fn in possible_paths:
+            result = path_fn(model)
+            if result is not None:
+                layers = result
+                layer_path = path_name
+                print(f" ✅ Found layers at: {path_name}")
+                break
 
     if layers is None:
-        print("⚠️ Unknown model structure - cannot find layers")
+        print(" Cannot find layers - model structure not supported")
+        print(f" Model type: {type(model)}")
+        print(f" Has 'model' attr: {hasattr(model, 'model')}")
+        print(f" Has 'transformer' attr: {hasattr(model, 'transformer')}")
+        print(f" Has 'layers' attr: {hasattr(model, 'layers')}")
         return model, 0, 0
 
     total_layers = len(layers)
-    print(f" Found {total_layers} layers")
+    print(f" Found {total_layers} layers at '{layer_path}'")
 
     # GQA detection (structure_info takes priority)
     if structure_info and structure_info.get('gqa_detected'):
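
The hunk above makes the layer lookup two-stage: structure_info is tried first, and if it is missing or does not match, the converter probes a fixed list of candidate attribute paths (model.layers, transformer.h, layers, model.decoder.layers). The same probing can be exercised on its own; the sketch below assumes the transformers package, uses gpt2 only as a small test model, and find_decoder_layers is an illustrative name rather than a function defined in app.py.

# Minimal sketch of the multi-path layer lookup added in this hunk.
# Assumptions: `transformers` is installed; `find_decoder_layers` is illustrative.
from transformers import AutoModelForCausalLM

def find_decoder_layers(model):
    candidates = [
        ('model.layers', lambda m: getattr(getattr(m, 'model', None), 'layers', None)),
        ('transformer.h', lambda m: getattr(getattr(m, 'transformer', None), 'h', None)),
        ('layers', lambda m: getattr(m, 'layers', None)),
        ('model.decoder.layers', lambda m: getattr(getattr(getattr(m, 'model', None), 'decoder', None), 'layers', None)),
    ]
    for name, probe in candidates:
        layers = probe(model)
        if layers is not None:
            return name, layers
    return None, None

model = AutoModelForCausalLM.from_pretrained("gpt2")  # GPT-2 exposes its blocks at transformer.h
path, layers = find_decoder_layers(model)
print(path, len(layers))  # expected: transformer.h 12 (GPT-2 small has 12 blocks)
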
@@ -944,43 +963,54 @@ class HierarchicalRetention(nn.Module):
 
 
 def replace_attention_with_retention(model, use_hierarchical=True):
-    """Convert Attention → Retention"""
+    """Convert Attention → Retention (improved)"""
    converted_count = 0
    total_layers = 0
 
+    # Find layers (try multiple paths)
+    layers = None
+
     if hasattr(model, 'model') and hasattr(model.model, 'layers'):
         layers = model.model.layers
-        total_layers = len(layers)
-
-        config = model.config
-
-        for layer_idx, layer in enumerate(layers):
-            if hasattr(layer, 'self_attn'):
-                old_attn = layer.self_attn
-
-                if use_hierarchical:
-                    new_retention = HierarchicalRetention(config, layer_idx)
-                else:
-                    new_retention = MultiScaleRetention(config, layer_idx)
-
-                if hasattr(old_attn, 'q_proj'):
-                    try:
-                        target = new_retention.base_retention if use_hierarchical else new_retention
-
-                        if old_attn.q_proj.weight.shape == target.q_proj.weight.shape:
-                            target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
-                        if old_attn.k_proj.weight.shape == target.k_proj.weight.shape:
-                            target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
-                        if old_attn.v_proj.weight.shape == target.v_proj.weight.shape:
-                            target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
-                        if old_attn.o_proj.weight.shape == target.o_proj.weight.shape:
-                            target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
-                    except:
-                        pass
-
-                layer.self_attn = new_retention
-                converted_count += 1
+    elif hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
+        layers = model.transformer.h
+    elif hasattr(model, 'layers'):
+        layers = model.layers
+    else:
+        print("Cannot find layers in model")
+        return model, 0, 0
 
+    total_layers = len(layers)
+    config = model.config
+
+    for layer_idx, layer in enumerate(layers):
+        if hasattr(layer, 'self_attn'):
+            old_attn = layer.self_attn
+
+            if use_hierarchical:
+                new_retention = HierarchicalRetention(config, layer_idx)
+            else:
+                new_retention = MultiScaleRetention(config, layer_idx)
+
+            if hasattr(old_attn, 'q_proj'):
+                try:
+                    target = new_retention.base_retention if use_hierarchical else new_retention
+
+                    if old_attn.q_proj.weight.shape == target.q_proj.weight.shape:
+                        target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
+                    if old_attn.k_proj.weight.shape == target.k_proj.weight.shape:
+                        target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
+                    if old_attn.v_proj.weight.shape == target.v_proj.weight.shape:
+                        target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
+                    if old_attn.o_proj.weight.shape == target.o_proj.weight.shape:
+                        target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
+                except Exception as e:
+                    print(f"Weight copy warning for layer {layer_idx}: {e}")
+
+            layer.self_attn = new_retention
+            converted_count += 1
+
+    print(f"Converted {converted_count}/{total_layers} layers to Retention")
     return model, converted_count, total_layers
 
 
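
In the rewritten converter above, each projection weight is copied only when its shape matches the corresponding projection of the new Retention module, so any mismatch (for example the narrower k/v projections of a GQA model) simply leaves that projection at its fresh initialization instead of raising. A minimal standalone sketch of that guard, with illustrative sizes and plain torch.nn modules:

# Shape-guarded weight copy, as used per projection in the converter above.
# The 1024-dim Linear layers are illustrative stand-ins for q/k/v/o projections.
import torch.nn as nn

src = nn.Linear(1024, 1024, bias=False)  # e.g. the old attention q_proj
dst = nn.Linear(1024, 1024, bias=False)  # e.g. the new retention q_proj

if src.weight.shape == dst.weight.shape:
    dst.weight.data = src.weight.data.clone()  # reuse the pretrained weights
else:
    print("shape mismatch - keeping the freshly initialized projection")
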
 
@@ -1042,6 +1072,10 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
 
         print(f"✅ Converted {converted}/{total} layers to Retention")
 
+        if converted == 0:
+            print(f"⚠️ WARNING: No layers were converted!")
+            print(f" Model may not have Retention active.")
+
         phoenix_instance = cls(config)
         phoenix_instance._original_model = base_model
         phoenix_instance._initialized = True
@@ -1605,8 +1639,9 @@ def burn_model_zero_shot(
     print(f"\n🔄 STEP 3: Converting Attention → Retention...")
     convert_start = time.time()
 
-    model.model, converted, total = replace_attention_with_retention(
-        model.model,
+    # FIX: pass the full model and let the converter locate the layers itself
+    model, converted, total = replace_attention_with_retention(
+        model,
         use_hierarchical=use_hierarchical,
         structure_info=structure_info
     )
@@ -1618,8 +1653,30 @@ def burn_model_zero_shot(
 
     if converted == 0:
         print(f"\n⚠️ WARNING: No layers were converted!")
-        print(f" This model may not work correctly.")
-        print(f" Structure info: {structure_info}")
+        print(f" This indicates a structural mismatch.")
+        print(f" Model type: {type(model)}")
+        if structure_info:
+            print(f" Structure info: {structure_info.get('layer_path', 'unknown')}")
+        print(f" Please check the model architecture.")
+    else:
+        # Verify the conversion
+        print(f"\n🔍 Verifying conversion...")
+        verified_retention = 0
+
+        if hasattr(model, 'model') and hasattr(model.model, 'layers'):
+            check_layers = model.model.layers
+        else:
+            check_layers = []
+
+        for layer in check_layers:
+            if hasattr(layer, 'self_attn'):
+                if 'Retention' in layer.self_attn.__class__.__name__:
+                    verified_retention += 1
+
+        print(f" ✅ Verified: {verified_retention}/{len(check_layers)} layers have Retention")
+
+        if verified_retention == 0 and converted > 0:
+            print(f" ⚠️ WARNING: Conversion reported success but verification failed!")
 
     # 4. Evaluation
     print(f"\n📊 STEP 4: Evaluating model quality...")
@@ -1725,8 +1782,8 @@ def burn_model_with_finetuning(
     tokenizer.pad_token = tokenizer.eos_token
 
     print(f"\n🔄 STEP 3: Converting...")
-    model.model, converted, total = replace_attention_with_retention(
-        model.model,
+    model, converted, total = replace_attention_with_retention(
+        model,
         use_hierarchical=use_hierarchical,
         structure_info=structure_info
     )
 