seawolf2357 committed (verified)
Commit cce66a2 · 1 Parent(s): 7916437

Update app.py

Files changed (1):
  1. app.py +44 -30
app.py CHANGED
@@ -682,16 +682,17 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
 
 def generate_modeling_phoenix_code():
     """
-    Generate PHOENIX custom modeling code, v1.4.2
-    ✅ FIX: Improved embedding tying
+    Generate PHOENIX custom modeling code, v1.4.3
+    ✅ FIX: Include the Retention conversion
     """
 
     modeling_code = '''"""
-PHOENIX Retention Model - Custom Implementation v1.4.2
+PHOENIX Retention Model - Custom Implementation v1.4.3
 Auto-loaded by HuggingFace transformers with trust_remote_code=True
 
-✅ FIX v1.4.2: Improved embedding tying - handled at save time
-✅ FIX v1.4.1: Preserve Retention weights by loading the state dict directly
+✅ FIX v1.4.3: Run the Retention conversion automatically
+✅ FIX v1.4.2: Improved embedding tying
+✅ FIX v1.4.1: Direct state dict load
 
 VIDraft AI Research Lab
 """
@@ -712,7 +713,7 @@ class PhoenixConfig(PretrainedConfig):
     def __init__(
         self,
         use_phoenix_retention=True,
-        phoenix_version="1.4.2",
+        phoenix_version="1.4.3",
         original_architecture=None,
         original_model=None,
         **kwargs
@@ -724,32 +725,43 @@ class PhoenixConfig(PretrainedConfig):
         self.original_model = original_model
 
 
-# [MultiScaleRetention and HierarchicalRetention classes would be here - same as in main code]
+# CRITICAL: the Retention classes must be included here!
+# (insert the full MultiScaleRetention and HierarchicalRetention code at this point)
 
 
-class PhoenixPreTrainedModel(PreTrainedModel):
-    """Base PHOENIX PreTrainedModel"""
-    config_class = PhoenixConfig
-    base_model_prefix = "phoenix"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["MultiScaleRetention", "HierarchicalRetention"]
+def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
+    """
+    Automatically convert Attention → Retention when the model is loaded from the Hub
+    """
+    print("🔄 Converting Attention → Retention for loaded model...")
+
+    layers = None
+    if hasattr(model, 'model') and hasattr(model.model, 'layers'):
+        layers = model.model.layers
+
+    if layers is None:
+        print("❌ Cannot find layers")
+        return model, 0, 0
 
-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=0.02)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=0.02)
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+    replaced_count = 0
+    for layer_idx, layer in enumerate(layers):
+        if hasattr(layer, 'self_attn'):
+            if use_hierarchical:
+                new_retention = HierarchicalRetention(model.config, layer_idx)
+            else:
+                new_retention = MultiScaleRetention(model.config, layer_idx)
+
+            layer.self_attn = new_retention
+            replaced_count += 1
+
+    print(f"✅ Converted {replaced_count}/{len(layers)} layers")
+    return model, replaced_count, len(layers)
 
 
 class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
     """
-    PHOENIX Model for Causal Language Modeling v1.4.2
-    ✅ FIX: Improved embedding tying
+    PHOENIX Model for Causal Language Modeling v1.4.3
+    ✅ FIX: Includes automatic Retention conversion
     """
 
     def __init__(self, config):
@@ -760,7 +772,7 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
 
     @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        """🔥 PHOENIX auto-loading! v1.4.2"""
+        """🔥 PHOENIX auto-loading! v1.4.3"""
         print(f"🔥 Loading PHOENIX model from {pretrained_model_name_or_path}")
 
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
@@ -780,9 +792,12 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
 
         print(f" ✅ Created base structure")
 
-        # Retention conversion (would require an import in the actual code)
-        # base_model, converted, total = replace_attention_with_retention(base_model, use_hierarchical)
+        # CRITICAL FIX: run the Retention conversion!
+        base_model, converted, total = replace_attention_with_retention_for_loading(
+            base_model, use_hierarchical
+        )
 
+        # Load the state_dict
         state_dict = None
 
         if os.path.exists(pretrained_model_name_or_path):
@@ -830,7 +845,7 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
         print(f" Missing keys: {len(missing)}")
         print(f" Unexpected keys: {len(unexpected)}")
 
-        # ✅ FIX v1.4.2: Handle embedding tying
+        # ✅ Embedding tying
         if 'lm_head.weight' in missing:
             print(f" ⚠️ lm_head.weight missing - checking tie_word_embeddings...")
 
@@ -842,7 +857,6 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
                 print(f" 🔗 Tying lm_head.weight to embed_tokens.weight...")
                 base_model.lm_head.weight = base_model.model.embed_tokens.weight
                 print(f" ✅ Embedding tying applied!")
-                print(f" Verification: {base_model.lm_head.weight is base_model.model.embed_tokens.weight}")
 
         retention_keys = [k for k in state_dict.keys() if 'retention' in k.lower()]
        if retention_keys:
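
For context, a minimal usage sketch of how the generated modeling code is meant to be consumed, assuming the generated file and a PhoenixConfig are pushed to a Hub repo alongside the weights; the repo id below is a placeholder, not a real model.

```python
# Minimal sketch (assumptions: the generated modeling code and config were
# uploaded with the weights; "your-org/phoenix-retention-model" is hypothetical).
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/phoenix-retention-model"  # placeholder repo id

# trust_remote_code=True lets transformers import the custom PhoenixModelForCausalLM,
# whose from_pretrained() rebuilds the base structure, swaps self_attn for Retention
# modules, loads the state dict, and re-ties lm_head to embed_tokens when needed.
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Optional sanity check on the embedding tying performed during loading
print(model.lm_head.weight is model.model.embed_tokens.weight)
```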