seawolf2357 committed
Commit 0d2bdda · verified · 1 Parent(s): a01c0f9

Update app.py

Files changed (1)
  app.py  +83 -21
app.py CHANGED
@@ -635,12 +635,20 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
                 v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
                 o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape

+                if layer_idx == 0:  # detailed output for the first layer only
+                    print(f"   🔍 Layer 0 shape analysis:")
+                    print(f"      Old Q: {old_attn.q_proj.weight.shape} vs New Q: {target.q_proj.weight.shape} → {'✅' if q_match else '❌'}")
+                    print(f"      Old K: {old_attn.k_proj.weight.shape} vs New K: {target.k_proj.weight.shape} → {'✅' if k_match else '❌'}")
+                    print(f"      Old V: {old_attn.v_proj.weight.shape} vs New V: {target.v_proj.weight.shape} → {'✅' if v_match else '❌'}")
+                    print(f"      Old O: {old_attn.o_proj.weight.shape} vs New O: {target.o_proj.weight.shape} → {'✅' if o_match else '❌'}")
+
                 if q_match and k_match and v_match and o_match:
                     target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
                     target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
                     target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
                     target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
-                    print(f"   ✅ Layer {layer_idx}: Perfect match")
+                    if layer_idx == 0:
+                        print(f"   ✅ Layer {layer_idx}: Perfect match - weights copied")

                 elif q_match and o_match:
                     target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
@@ -652,14 +660,17 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
                     target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone()
                     target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone()

-                    print(f"   ✅ Layer {layer_idx}: Partial (GQA)")
+                    if layer_idx == 0:
+                        print(f"   ✅ Layer {layer_idx}: Partial match (GQA) - partial weights copied")

                 else:
                     nn.init.xavier_uniform_(target.q_proj.weight)
                     nn.init.xavier_uniform_(target.k_proj.weight)
                     nn.init.xavier_uniform_(target.v_proj.weight)
                     nn.init.xavier_uniform_(target.o_proj.weight)
-                    print(f"   ⚠️ Layer {layer_idx}: Xavier init")
+                    if layer_idx == 0:
+                        print(f"   ⚠️ Layer {layer_idx}: Shape mismatch - Xavier init used")
+                        print(f"      This will result in random weights!")

             except Exception as e:
                 print(f"   ⚠️ Layer {layer_idx}: Weight copy failed - {e}")
@@ -989,6 +1000,8 @@ def replace_attention_with_retention(model, use_hierarchical=True):
     total_layers = len(layers)
     config = model.config

+    print(f"Converting {total_layers} layers...")
+
     for layer_idx, layer in enumerate(layers):
         if hasattr(layer, 'self_attn'):
             old_attn = layer.self_attn
@@ -1002,16 +1015,43 @@ def replace_attention_with_retention(model, use_hierarchical=True):
             try:
                 target = new_retention.base_retention if use_hierarchical else new_retention

-                if old_attn.q_proj.weight.shape == target.q_proj.weight.shape:
+                # Check shapes
+                q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape
+                k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape
+                v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
+                o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape
+
+                if layer_idx == 0:
+                    print(f"Layer 0 analysis:")
+                    print(f"   Q: {old_attn.q_proj.weight.shape} vs {target.q_proj.weight.shape} → {'✅' if q_match else '❌'}")
+                    print(f"   K: {old_attn.k_proj.weight.shape} vs {target.k_proj.weight.shape} → {'✅' if k_match else '❌'}")
+                    print(f"   V: {old_attn.v_proj.weight.shape} vs {target.v_proj.weight.shape} → {'✅' if v_match else '❌'}")
+                    print(f"   O: {old_attn.o_proj.weight.shape} vs {target.o_proj.weight.shape} → {'✅' if o_match else '❌'}")
+
+                # Copy weights
+                if q_match and k_match and v_match and o_match:
                     target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
-                if old_attn.k_proj.weight.shape == target.k_proj.weight.shape:
                     target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
-                if old_attn.v_proj.weight.shape == target.v_proj.weight.shape:
                     target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
-                if old_attn.o_proj.weight.shape == target.o_proj.weight.shape:
                     target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
+                    if layer_idx == 0:
+                        print(f"   ✅ Perfect match - weights copied")
+                elif q_match and o_match:
+                    target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
+                    target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
+                    k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0])
+                    v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0])
+                    target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone()
+                    target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone()
+                    if layer_idx == 0:
+                        print(f"   ✅ Partial match (GQA) - partial copy")
+                else:
+                    if layer_idx == 0:
+                        print(f"   ⚠️ Shape mismatch - keeping random init")
+
             except Exception as e:
-                print(f"Weight copy warning for layer {layer_idx}: {e}")
+                if layer_idx == 0:
+                    print(f"Weight copy error: {e}")

             layer.self_attn = new_retention
             converted_count += 1
@@ -1150,10 +1190,17 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
         print(f"   Missing keys: {len(missing)}")
         print(f"   Unexpected keys: {len(unexpected)}")

+        # Print details (first 5 only)
+        if missing:
+            print(f"   Missing (first 5): {missing[:5]}")
+        if unexpected:
+            print(f"   Unexpected (first 5): {unexpected[:5]}")
+
         # Check Retention weights
         retention_keys = [k for k in state_dict.keys() if 'retention' in k.lower()]
         if retention_keys:
             print(f"   ✅ Found {len(retention_keys)} Retention weight keys")
+            print(f"      Sample keys: {retention_keys[:3]}")
         else:
             print(f"   ⚠️ No Retention keys found in state dict")

@@ -2314,18 +2361,33 @@ def validate_phoenix_model(
     retention_count = 0
     attention_count = 0

-    if hasattr(model, 'model'):
-        layers = model.model.layers if hasattr(model.model, 'layers') else []
-
-        for layer in layers:
-            if hasattr(layer, 'self_attn'):
-                attn = layer.self_attn
-                class_name = attn.__class__.__name__
-
-                if 'Retention' in class_name:
-                    retention_count += 1
-                else:
-                    attention_count += 1
+    # For a PhoenixModelForCausalLM wrapper, check _original_model
+    check_model = model
+    if hasattr(model, '_original_model') and model._original_model is not None:
+        print(f"   📋 Detected PhoenixModelForCausalLM wrapper")
+        check_model = model._original_model
+
+    layers = []
+    if hasattr(check_model, 'model') and hasattr(check_model.model, 'layers'):
+        layers = check_model.model.layers
+    elif hasattr(check_model, 'layers'):
+        layers = check_model.layers
+
+    print(f"   🔍 Checking {len(layers)} layers...")
+
+    for i, layer in enumerate(layers):
+        if hasattr(layer, 'self_attn'):
+            attn = layer.self_attn
+            class_name = attn.__class__.__name__
+
+            if 'Retention' in class_name:
+                retention_count += 1
+                if i < 3:  # print only the first 3
+                    print(f"   ✅ Layer {i}: {class_name}")
+            else:
+                attention_count += 1
+                if i < 3:
+                    print(f"   ⚠️ Layer {i}: {class_name}")

     total = retention_count + attention_count
     retention_info = f"""
@@ -2334,7 +2396,7 @@ def validate_phoenix_model(
     - **Attention Layers**: {attention_count}/{total}
     - **Status**: {'✅ PHOENIX Active' if retention_count > 0 else '⚠️ No Retention Found'}
     """
-    print(f"   Retention: {retention_count}/{total} layers")
+    print(f"   📊 Result: {retention_count}/{total} layers have Retention")

     # 4. Generation tests
     print(f"\n🚀 Running generation tests...")
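Note on the change above: both conversion paths in app.py now share the same weight-transfer rule — copy all four projections when shapes match exactly, copy Q/O plus the overlapping leading rows of K/V when the source model uses grouped-query attention, and otherwise fall back to Xavier (or the existing random) init. A minimal, self-contained sketch of that rule follows; the helper name copy_attention_weights and the _Proj dummy module are illustrative only, not part of app.py.

import torch.nn as nn

def copy_attention_weights(old_attn: nn.Module, target: nn.Module) -> str:
    """Transfer q/k/v/o projection weights from an attention module into a
    retention module, mirroring the full / partial-GQA / Xavier cases above."""
    q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape
    k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape
    v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
    o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape

    if q_match and k_match and v_match and o_match:
        for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            getattr(target, name).weight.data = getattr(old_attn, name).weight.data.clone()
        return "full"
    if q_match and o_match:
        # Grouped-query attention: K/V have fewer output rows, so copy only the overlap.
        target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
        target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
        k_rows = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0])
        v_rows = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0])
        target.k_proj.weight.data[:k_rows] = old_attn.k_proj.weight.data[:k_rows].clone()
        target.v_proj.weight.data[:v_rows] = old_attn.v_proj.weight.data[:v_rows].clone()
        return "partial-gqa"
    # Shape mismatch: re-initialize rather than copy.
    for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
        nn.init.xavier_uniform_(getattr(target, name).weight)
    return "xavier"

if __name__ == "__main__":
    class _Proj(nn.Module):
        # Hypothetical stand-in exposing the q/k/v/o projections the helper expects.
        def __init__(self, q_out, kv_out, hidden=64):
            super().__init__()
            self.q_proj = nn.Linear(hidden, q_out, bias=False)
            self.k_proj = nn.Linear(hidden, kv_out, bias=False)
            self.v_proj = nn.Linear(hidden, kv_out, bias=False)
            self.o_proj = nn.Linear(q_out, hidden, bias=False)

    gqa_attn, retention = _Proj(64, 16), _Proj(64, 64)
    print(copy_attention_weights(gqa_attn, retention))  # -> partial-gqa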
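Similarly, the validate_phoenix_model hunks make the layer audit work through the PhoenixModelForCausalLM wrapper instead of assuming model.model.layers exists. A minimal sketch of the counting logic, assuming only the attribute names used in the diff (the helper name count_retention_layers is illustrative, not from app.py):

def count_retention_layers(model):
    """Count Retention vs. ordinary attention layers, unwrapping _original_model
    when the model is a PhoenixModelForCausalLM-style wrapper."""
    check_model = getattr(model, "_original_model", None) or model

    if hasattr(check_model, "model") and hasattr(check_model.model, "layers"):
        layers = check_model.model.layers
    elif hasattr(check_model, "layers"):
        layers = check_model.layers
    else:
        layers = []

    retention_count, attention_count = 0, 0
    for layer in layers:
        if hasattr(layer, "self_attn"):
            if "Retention" in layer.self_attn.__class__.__name__:
                retention_count += 1
            else:
                attention_count += 1
    return retention_count, attention_count

# Usage: retention, attention = count_retention_layers(model)
# "PHOENIX Active" in the status line above corresponds to retention > 0.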