seawolf2357 committed
Commit 7e1dc71 · verified · 1 Parent(s): cce66a2

Update app.py

Files changed (1)
  1. app.py +175 -161
app.py CHANGED
@@ -681,25 +681,17 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
681   # =====================================================
682
683   def generate_modeling_phoenix_code():
684 -     """
685 -     Generates the PHOENIX custom modeling code, v1.4.3
686 -     ✅ FIX: includes the Retention conversion
687 -     """
688
689 -     modeling_code = '''"""
690 - PHOENIX Retention Model - Custom Implementation v1.4.3
691 - Auto-loaded by HuggingFace transformers with trust_remote_code=True
692 -
693 - ✅ FIX v1.4.3: automatic Retention conversion on load
694 - ✅ FIX v1.4.2: improved embedding tying
695 - ✅ FIX v1.4.1: direct state dict loading
696 -
697 - VIDraft AI Research Lab
698   """
699
700   import torch
701   import torch.nn as nn
702 - from typing import Optional, Tuple, Union
703   from transformers.modeling_utils import PreTrainedModel
704   from transformers.configuration_utils import PretrainedConfig
705   from transformers import AutoConfig, AutoModelForCausalLM
@@ -707,181 +699,203 @@ import os
707
708
709   class PhoenixConfig(PretrainedConfig):
710 -     """PHOENIX Model Configuration"""
711       model_type = "phoenix"
712 -
713 -     def __init__(
714 -         self,
715 -         use_phoenix_retention=True,
716 -         phoenix_version="1.4.3",
717 -         original_architecture=None,
718 -         original_model=None,
719 -         **kwargs
720 -     ):
721           super().__init__(**kwargs)
722           self.use_phoenix_retention = use_phoenix_retention
723           self.phoenix_version = phoenix_version
724 -         self.original_architecture = original_architecture
725           self.original_model = original_model
726
727
728 - # ✅ CRITICAL: the Retention classes must be included here!
729 - # (insert the full MultiScaleRetention and HierarchicalRetention code here)
730
731
732 - def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
733 -     """
734 -     Automatically converts Attention → Retention when loading from the Hub
735 -     """
736 -     print("🔄 Converting Attention → Retention for loaded model...")
737 -
738 -     layers = None
739 -     if hasattr(model, 'model') and hasattr(model.model, 'layers'):
740 -         layers = model.model.layers
741
742 -     if layers is None:
743 -         print("❌ Cannot find layers")
744 -         return model, 0, 0
745
746 -     replaced_count = 0
747 -     for layer_idx, layer in enumerate(layers):
748           if hasattr(layer, 'self_attn'):
749 -             if use_hierarchical:
750 -                 new_retention = HierarchicalRetention(model.config, layer_idx)
751 -             else:
752 -                 new_retention = MultiScaleRetention(model.config, layer_idx)
753 -
754 -             layer.self_attn = new_retention
755 -             replaced_count += 1
756
757 -     print(f"✅ Converted {replaced_count}/{len(layers)} layers")
758 -     return model, replaced_count, len(layers)
759
760
761   class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
762 -     """
763 -     PHOENIX Model for Causal Language Modeling v1.4.3
764 -     ✅ FIX: includes automatic Retention conversion
765 -     """
766 -
767       def __init__(self, config):
768           super().__init__(config)
769 -         self.config = config
770 -         self._original_model = None
771 -         self._initialized = False
772 -
773       @classmethod
774 -     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
775 -         """🔥 PHOENIX auto-loading! v1.4.3"""
776 -         print(f"🔥 Loading PHOENIX model from {pretrained_model_name_or_path}")
777 -
778 -         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
779 -
780 -         original_model = getattr(config, 'original_model', 'Qwen/Qwen3-0.6B')
781 -         use_hierarchical = getattr(config, 'use_hierarchical', True)
782 -
783 -         print(f" 📋 Original model: {original_model}")
784 -         print(f" 🔄 Hierarchical: {use_hierarchical}")
785
786           try:
787 -             base_config = AutoConfig.from_pretrained(original_model, trust_remote_code=True)
788           except:
789 -             base_config = config
790 -
791 -         base_model = AutoModelForCausalLM.from_config(base_config)
792 -
793 -         print(f" ✅ Created base structure")
794 -
795 -         # CRITICAL FIX: run the Retention conversion!
796 -         base_model, converted, total = replace_attention_with_retention_for_loading(
797 -             base_model, use_hierarchical
798 -         )
799 -
800 -         # load the state_dict
801 -         state_dict = None
802 -
803 -         if os.path.exists(pretrained_model_name_or_path):
804 -             safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
805 -             pytorch_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
806 -
807 -             if os.path.exists(safetensors_path):
808 -                 try:
809 -                     from safetensors.torch import load_file
810 -                     state_dict = load_file(safetensors_path)
811 -                     print(f" ✅ Loaded from safetensors")
812 -                 except:
813 -                     pass
814 -
815 -             if state_dict is None and os.path.exists(pytorch_path):
816 -                 state_dict = torch.load(pytorch_path, map_location='cpu')
817 -                 print(f" ✅ Loaded from pytorch_model.bin")
818           else:
819 -             try:
820 -                 from huggingface_hub import hf_hub_download
821 -
822                   try:
823 -                     safetensors_path = hf_hub_download(
824 -                         repo_id=pretrained_model_name_or_path,
825 -                         filename="model.safetensors"
826 -                     )
827 -                     from safetensors.torch import load_file
828 -                     state_dict = load_file(safetensors_path)
829 -                     print(f" ✅ Loaded from Hub (safetensors)")
830 -                 except:
831 -                     pytorch_path = hf_hub_download(
832 -                         repo_id=pretrained_model_name_or_path,
833 -                         filename="pytorch_model.bin"
834 -                     )
835 -                     state_dict = torch.load(pytorch_path, map_location='cpu')
836 -                     print(f" ✅ Loaded from Hub (pytorch_model.bin)")
837 -             except Exception as e:
838 -                 print(f" ❌ Failed to load weights: {e}")
839 -
840 -         if state_dict is not None:
841 -             try:
842 -                 missing, unexpected = base_model.load_state_dict(state_dict, strict=False)
843 -
844 -                 print(f" ✅ Weights loaded")
845 -                 print(f" Missing keys: {len(missing)}")
846 -                 print(f" Unexpected keys: {len(unexpected)}")
847 -
848 -                 # ✅ Embedding Tying
849 -                 if 'lm_head.weight' in missing:
850 -                     print(f" ⚠️ lm_head.weight missing - checking tie_word_embeddings...")
851 -
852 -                     tie_embeddings = getattr(config, 'tie_word_embeddings', False)
853 -                     print(f" tie_word_embeddings: {tie_embeddings}")
854 -
855 -                     if tie_embeddings and hasattr(base_model, 'lm_head') and hasattr(base_model, 'model'):
856 -                         if hasattr(base_model.model, 'embed_tokens'):
857 -                             print(f" 🔗 Tying lm_head.weight to embed_tokens.weight...")
858 -                             base_model.lm_head.weight = base_model.model.embed_tokens.weight
859 -                             print(f" ✅ Embedding tying applied!")
860 -
861 -                 retention_keys = [k for k in state_dict.keys() if 'retention' in k.lower()]
862 -                 if retention_keys:
863 -                     print(f" ✅ Found {len(retention_keys)} Retention weight keys")
864 -
865 -             except Exception as e:
866 -                 print(f" ⚠️ Weight loading warning: {e}")
867 -
868 -         phoenix_instance = cls(config)
869 -         phoenix_instance._original_model = base_model
870 -         phoenix_instance._initialized = True
871
872 -         print(f"✅ PHOENIX model ready!")
873
874 -         return phoenix_instance
875
876 -     def forward(self, *args, **kwargs):
877 -         if not self._initialized or self._original_model is None:
878 -             raise ValueError("Model not properly initialized. Use from_pretrained().")
879 -         return self._original_model(*args, **kwargs)
880
881 -     def generate(self, *args, **kwargs):
882 -         if not self._initialized or self._original_model is None:
883 -             raise ValueError("Model not properly initialized. Use from_pretrained().")
884 -         return self._original_model.generate(*args, **kwargs)
885
886
887   AutoConfig.register("phoenix", PhoenixConfig)
 
681   # =====================================================
682
683   def generate_modeling_phoenix_code():
684 +     """PHOENIX Custom Modeling Code v1.4.3 - COMPLETE"""
685
686 +     return '''"""
687 + PHOENIX Retention Model v1.4.3
688 + Includes the PhoenixPreTrainedModel base class
689 + ✅ All Retention classes fully implemented
690   """
691
692   import torch
693   import torch.nn as nn
694 + from typing import Optional, Tuple
695   from transformers.modeling_utils import PreTrainedModel
696   from transformers.configuration_utils import PretrainedConfig
697   from transformers import AutoConfig, AutoModelForCausalLM
699
700
701   class PhoenixConfig(PretrainedConfig):
702       model_type = "phoenix"
703 +     def __init__(self, use_phoenix_retention=True, phoenix_version="1.4.3",
704 +                  original_model=None, use_hierarchical=True, **kwargs):
705           super().__init__(**kwargs)
706           self.use_phoenix_retention = use_phoenix_retention
707           self.phoenix_version = phoenix_version
708           self.original_model = original_model
709 +         self.use_hierarchical = use_hierarchical
710
711
712 + class MultiScaleRetention(nn.Module):
713 +     def __init__(self, config, layer_idx=0):
714 +         super().__init__()
715 +         self.hidden_size = config.hidden_size
716 +         self.num_heads = config.num_attention_heads
717 +         self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
718 +         self.num_key_value_heads = getattr(config, 'num_key_value_heads', self.num_heads)
719 +         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
720 +         self.q_dim = self.num_heads * self.head_dim
721 +         self.kv_dim = self.num_key_value_heads * self.head_dim
722 +
723 +         self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False)
724 +         self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
725 +         self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
726 +         self.o_proj = nn.Linear(self.q_dim, self.hidden_size, bias=False)
727 +         self.decay = nn.Parameter(torch.linspace(0.95, 0.99, self.num_heads))
728 +         self.group_norm = nn.GroupNorm(self.num_heads, self.q_dim)
729 +
730 +     def _repeat_kv(self, x, n):
731 +         b, h, s, d = x.shape
732 +         if n == 1: return x
733 +         return x[:, :, None, :, :].expand(b, h, n, s, d).reshape(b, h*n, s, d)
734 +
735 +     def forward(self, hidden_states, **kwargs):
736 +         b, s, _ = hidden_states.shape
737 +         device, dtype = hidden_states.device, hidden_states.dtype
738 +
739 +         if self.q_proj.weight.device != device:
740 +             self.to(device=device, dtype=dtype)
741 +
742 +         q = self.q_proj(hidden_states).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
743 +         k = self.k_proj(hidden_states).view(b, s, self.num_key_value_heads, self.head_dim).transpose(1, 2)
744 +         v = self.v_proj(hidden_states).view(b, s, self.num_key_value_heads, self.head_dim).transpose(1, 2)
745 +
746 +         k = self._repeat_kv(k, self.num_key_value_groups)
747 +         v = self._repeat_kv(v, self.num_key_value_groups)
748 +
749 +         out = self._retention(q, k, v)
750 +         out = out.transpose(1, 2).reshape(b, s, self.q_dim)
751 +         out = self.group_norm(out.transpose(1, 2)).transpose(1, 2)
752 +         return (self.o_proj(torch.clamp(out, -10, 10)), None)
753 +
754 +     def _retention(self, q, k, v):
755 +         b, h, s, d = q.shape
756 +         state = torch.zeros(b, h, d, d, dtype=q.dtype, device=q.device) + 1e-6
757 +         decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to(q)
758 +         outs = []
759 +         for t in range(s):
760 +             state = decay * state + torch.clamp(torch.einsum('bhd,bhe->bhde', k[:,:,t], v[:,:,t]), -5, 5)
761 +             state = torch.clamp(state, -10, 10)
762 +             outs.append(torch.einsum('bhd,bhde->bhe', q[:,:,t], state))
763 +         return torch.stack(outs, dim=2)
764
765
766 + class HierarchicalRetention(nn.Module):
767 +     def __init__(self, config, layer_idx=0):
768 +         super().__init__()
769 +         self.base_retention = MultiScaleRetention(config, layer_idx)
770 +         h = config.hidden_size
771 +         self.d_state = h // 2
772 +         self.short_proj = nn.Linear(h, self.d_state)
773 +         self.medium_proj = nn.Linear(self.d_state, self.d_state)
774 +         self.long_proj = nn.Linear(self.d_state, self.d_state*2)
775 +         self.fusion = nn.Linear(self.d_state*4, h)
776 +         self.norm = nn.LayerNorm(h)
777 +         self.decays = [0.5, 0.8, 0.95]
778
779 +     def forward(self, x, **kwargs):
780 +         b, s, h = x.shape
781 +         device, dtype = x.device, x.dtype
782 +         if next(self.short_proj.parameters()).device != device:
783 +             self.to(device=device, dtype=dtype)
784 +
785 +         ret_out = self.base_retention(x)[0]
786 +         short = torch.zeros(b, self.d_state, dtype=dtype, device=device)
787 +         med = torch.zeros(b, self.d_state, dtype=dtype, device=device)
788 +         long = torch.zeros(b, self.d_state*2, dtype=dtype, device=device)
789 +         outs = []
790 +
791 +         for t in range(s):
792 +             short = self.decays[0]*short + self.short_proj(ret_out[:,t])
793 +             if t % 8 == 0: med = self.decays[1]*med + self.medium_proj(short)
794 +             if t % 64 == 0: long = self.decays[2]*long + self.long_proj(med)
795 +             outs.append(self.fusion(torch.cat([short, med, long], -1)))
796 +
797 +         return (self.norm(torch.stack(outs, 1)), None)
798 +
799 +
800 + def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
801 +     layers = getattr(model, 'model', model)
802 +     layers = getattr(layers, 'layers', getattr(layers, 'h', None))
803 +     if layers is None: return model, 0, 0
804
805 +     cnt = 0
806 +     for i, layer in enumerate(layers):
807           if hasattr(layer, 'self_attn'):
808 +             layer.self_attn = HierarchicalRetention(model.config, i) if use_hierarchical else MultiScaleRetention(model.config, i)
809 +             cnt += 1
810 +     return model, cnt, len(layers)
811 +
812 +
813 + # CRITICAL: PhoenixPreTrainedModel base class
814 + class PhoenixPreTrainedModel(PreTrainedModel):
815 +     config_class = PhoenixConfig
816 +     base_model_prefix = "phoenix"
817 +     supports_gradient_checkpointing = True
818 +     _no_split_modules = ["MultiScaleRetention", "HierarchicalRetention"]
819
820 +     def _init_weights(self, m):
821 +         std = getattr(self.config, 'initializer_range', 0.02)
822 +         if isinstance(m, nn.Linear):
823 +             m.weight.data.normal_(0, std)
824 +             if m.bias is not None: m.bias.data.zero_()
825 +         elif isinstance(m, nn.Embedding):
826 +             m.weight.data.normal_(0, std)
827 +             if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
828
829
830   class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
831       def __init__(self, config):
832           super().__init__(config)
833 +         self._model = None
834 +         self._ready = False
835 +
836       @classmethod
837 +     def from_pretrained(cls, path, *args, **kwargs):
838 +         print(f"🔥 PHOENIX v1.4.3 loading from {path}")
839 +         config = AutoConfig.from_pretrained(path, trust_remote_code=True)
840 +         orig = getattr(config, 'original_model', 'Qwen/Qwen3-0.6B')
841 +         hier = getattr(config, 'use_hierarchical', True)
842
843           try:
844 +             base_cfg = AutoConfig.from_pretrained(orig, trust_remote_code=True)
845           except:
846 +             base_cfg = config
847 +
848 +         model = AutoModelForCausalLM.from_config(base_cfg)
849 +         model, conv, tot = replace_attention_with_retention_for_loading(model, hier)
850 +         print(f" ✅ Converted {conv}/{tot} layers")
851 +
852 +         # Load weights
853 +         sd = None
854 +         if os.path.exists(path):
855 +             for fname in ["model.safetensors", "pytorch_model.bin"]:
856 +                 fpath = os.path.join(path, fname)
857 +                 if os.path.exists(fpath):
858 +                     if fname.endswith('.safetensors'):
859 +                         from safetensors.torch import load_file
860 +                         sd = load_file(fpath)
861 +                     else:
862 +                         sd = torch.load(fpath, map_location='cpu')
863 +                     break
864           else:
865 +             from huggingface_hub import hf_hub_download
866 +             for fname in ["model.safetensors", "pytorch_model.bin"]:
867                   try:
868 +                     fpath = hf_hub_download(path, fname)
869 +                     if fname.endswith('.safetensors'):
870 +                         from safetensors.torch import load_file
871 +                         sd = load_file(fpath)
872 +                     else:
873 +                         sd = torch.load(fpath, map_location='cpu')
874 +                     break
875 +                 except: pass
876
877 +         if sd:
878 +             miss, unex = model.load_state_dict(sd, strict=False)
879 +             print(f" 📦 Weights: {len(miss)} missing, {len(unex)} unexpected")
880 +
881 +             if 'lm_head.weight' in miss and getattr(config, 'tie_word_embeddings', False):
882 +                 if hasattr(model, 'lm_head') and hasattr(model.model, 'embed_tokens'):
883 +                     model.lm_head.weight = model.model.embed_tokens.weight
884 +                     print(f" 🔗 Tied embeddings")
885
886 +         inst = cls(config)
887 +         inst._model = model
888 +         inst._ready = True
889 +         print(f"✅ PHOENIX v1.4.3 ready!")
890 +         return inst
891
892 +     def forward(self, *a, **k):
893 +         if not self._ready: raise ValueError("Not initialized")
894 +         return self._model(*a, **k)
895
896 +     def generate(self, *a, **k):
897 +         if not self._ready: raise ValueError("Not initialized")
898 +         return self._model.generate(*a, **k)
899
900
901   AutoConfig.register("phoenix", PhoenixConfig)
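For reference, the generated modeling file is meant to be consumed through the standard trust_remote_code path: PhoenixModelForCausalLM.from_pretrained rebuilds the base architecture from config.original_model, swaps every self_attn module for a Retention module, and then loads the checkpoint with strict=False. A minimal loading sketch follows, assuming the generated modeling_phoenix.py and a config.json with the usual auto_map entries have been pushed to a Hub repo; "your-org/phoenix-model" is a hypothetical placeholder, not a repo named in this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/phoenix-model"  # hypothetical placeholder repo id

# trust_remote_code=True lets transformers import the repo's custom
# modeling_phoenix.py, so the PHOENIX from_pretrained path runs
# (rebuild base model -> swap Attention for Retention -> strict=False weight load).
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hello PHOENIX", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))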