Update app.py
app.py
CHANGED
@@ -681,25 +681,17 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
 # =====================================================

 def generate_modeling_phoenix_code():
-    """
-    Generates the PHOENIX custom modeling code, v1.4.3
-    ✅ FIX: includes the Retention conversion
-    """

-PHOENIX Retention Model
-
-
-✅ FIX v1.4.3: Retention conversion runs automatically
-✅ FIX v1.4.2: embedding tying improved
-✅ FIX v1.4.1: state dict loaded directly
-
-VIDraft AI Research Lab
 """

 import torch
 import torch.nn as nn
-from typing import Optional, Tuple
 from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from transformers import AutoConfig, AutoModelForCausalLM
@@ -707,181 +699,203 @@ import os

 class PhoenixConfig(PretrainedConfig):
-    """PHOENIX Model Configuration"""
     model_type = "phoenix"
-
-    def __init__(
-        self,
-        use_phoenix_retention=True,
-        phoenix_version="1.4.3",
-        original_architecture=None,
-        original_model=None,
-        **kwargs
-    ):
         super().__init__(**kwargs)
         self.use_phoenix_retention = use_phoenix_retention
         self.phoenix_version = phoenix_version
-        self.original_architecture = original_architecture
         self.original_model = original_model

...

-    for ...
         if hasattr(layer, 'self_attn'):
-            if use_hierarchical:
...

 class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
-    """
-    PHOENIX Model for Causal Language Modeling v1.4.3
-    ✅ FIX: includes the automatic Retention conversion
-    """
-
     def __init__(self, config):
         super().__init__(config)
-        self._original_model = None
-        self._initialized = False

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
...
-        original_model = getattr(config, 'original_model', 'Qwen/Qwen3-0.6B')
-        use_hierarchical = getattr(config, 'use_hierarchical', True)
-
-        print(f"   📋 Original model: {original_model}")
-        print(f"   🔄 Hierarchical: {use_hierarchical}")

         try:
...
         except:
...
-        if os.path.exists(safetensors_path):
-            try:
-                from safetensors.torch import load_file
-                state_dict = load_file(safetensors_path)
-                print(f"   ✅ Loaded from safetensors")
-            except:
-                pass
-
-        if state_dict is None and os.path.exists(pytorch_path):
-            state_dict = torch.load(pytorch_path, map_location='cpu')
-            print(f"   ✅ Loaded from pytorch_model.bin")
         else:
...
         try:
...
-            except:
-                pytorch_path = hf_hub_download(
-                    repo_id=pretrained_model_name_or_path,
-                    filename="pytorch_model.bin"
-                )
-                state_dict = torch.load(pytorch_path, map_location='cpu')
-                print(f"   ✅ Loaded from Hub (pytorch_model.bin)")
-        except Exception as e:
-            print(f"   ❌ Failed to load weights: {e}")
-
-        if state_dict is not None:
-            try:
-                missing, unexpected = base_model.load_state_dict(state_dict, strict=False)
-
-                print(f"   ✅ Weights loaded")
-                print(f"      Missing keys: {len(missing)}")
-                print(f"      Unexpected keys: {len(unexpected)}")
-
-                # ✅ Embedding Tying
-                if 'lm_head.weight' in missing:
-                    print(f"   ⚠️ lm_head.weight missing - checking tie_word_embeddings...")
-
-                    tie_embeddings = getattr(config, 'tie_word_embeddings', False)
-                    print(f"      tie_word_embeddings: {tie_embeddings}")
-
-                    if tie_embeddings and hasattr(base_model, 'lm_head') and hasattr(base_model, 'model'):
-                        if hasattr(base_model.model, 'embed_tokens'):
-                            print(f"   🔗 Tying lm_head.weight to embed_tokens.weight...")
-                            base_model.lm_head.weight = base_model.model.embed_tokens.weight
-                            print(f"   ✅ Embedding tying applied!")
-
-                retention_keys = [k for k in state_dict.keys() if 'retention' in k.lower()]
-                if retention_keys:
-                    print(f"   ✅ Found {len(retention_keys)} Retention weight keys")
-
-            except Exception as e:
-                print(f"   ⚠️ Weight loading warning: {e}")
-
-        phoenix_instance = cls(config)
-        phoenix_instance._original_model = base_model
-        phoenix_instance._initialized = True

-    def forward(self, *args, **kwargs):
-        if not self._initialized:
...
-        return self._original_model(*args, **kwargs)

-    def generate(self, *args, **kwargs):
-        if not self._initialized:
...
-        return self._original_model.generate(*args, **kwargs)


 AutoConfig.register("phoenix", PhoenixConfig)
 # =====================================================

 def generate_modeling_phoenix_code():
+    """PHOENIX Custom Modeling Code v1.4.3 - COMPLETE"""

+    return '''"""
+PHOENIX Retention Model v1.4.3
+✅ Includes the PhoenixPreTrainedModel base class
+✅ All Retention classes fully implemented
 """

 import torch
 import torch.nn as nn
+from typing import Optional, Tuple
 from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from transformers import AutoConfig, AutoModelForCausalLM


 class PhoenixConfig(PretrainedConfig):
     model_type = "phoenix"
+    def __init__(self, use_phoenix_retention=True, phoenix_version="1.4.3",
+                 original_model=None, use_hierarchical=True, **kwargs):
         super().__init__(**kwargs)
         self.use_phoenix_retention = use_phoenix_retention
         self.phoenix_version = phoenix_version
         self.original_model = original_model
+        self.use_hierarchical = use_hierarchical
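`PhoenixConfig` adds only plain, JSON-serializable fields on top of `PretrainedConfig`, so it round-trips through `config.json` like any other config. A minimal sketch (the checkpoint directory is hypothetical):

```python
# Save and reload the config; model_type "phoenix" is written to config.json.
cfg = PhoenixConfig(original_model="Qwen/Qwen3-0.6B", use_hierarchical=True)
cfg.save_pretrained("./phoenix-ckpt")                  # hypothetical directory
reloaded = PhoenixConfig.from_pretrained("./phoenix-ckpt")
assert reloaded.use_hierarchical and reloaded.phoenix_version == "1.4.3"
```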

+class MultiScaleRetention(nn.Module):
+    def __init__(self, config, layer_idx=0):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
+        self.num_key_value_heads = getattr(config, 'num_key_value_heads', self.num_heads)
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.q_dim = self.num_heads * self.head_dim
+        self.kv_dim = self.num_key_value_heads * self.head_dim
+
+        self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
+        self.o_proj = nn.Linear(self.q_dim, self.hidden_size, bias=False)
+        self.decay = nn.Parameter(torch.linspace(0.95, 0.99, self.num_heads))
+        self.group_norm = nn.GroupNorm(self.num_heads, self.q_dim)
+
+    def _repeat_kv(self, x, n):
+        b, h, s, d = x.shape
+        if n == 1: return x
+        return x[:, :, None, :, :].expand(b, h, n, s, d).reshape(b, h*n, s, d)
+
+    def forward(self, hidden_states, **kwargs):
+        b, s, _ = hidden_states.shape
+        device, dtype = hidden_states.device, hidden_states.dtype
+
+        if self.q_proj.weight.device != device:
+            self.to(device=device, dtype=dtype)
+
+        q = self.q_proj(hidden_states).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.k_proj(hidden_states).view(b, s, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        v = self.v_proj(hidden_states).view(b, s, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        k = self._repeat_kv(k, self.num_key_value_groups)
+        v = self._repeat_kv(v, self.num_key_value_groups)
+
+        out = self._retention(q, k, v)
+        out = out.transpose(1, 2).reshape(b, s, self.q_dim)
+        out = self.group_norm(out.transpose(1, 2)).transpose(1, 2)
+        return (self.o_proj(torch.clamp(out, -10, 10)), None)
+
+    def _retention(self, q, k, v):
+        b, h, s, d = q.shape
+        state = torch.zeros(b, h, d, d, dtype=q.dtype, device=q.device) + 1e-6
+        decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to(q)
+        outs = []
+        for t in range(s):
+            state = decay * state + torch.clamp(torch.einsum('bhd,bhe->bhde', k[:,:,t], v[:,:,t]), -5, 5)
+            state = torch.clamp(state, -10, 10)
+            outs.append(torch.einsum('bhd,bhde->bhe', q[:,:,t], state))
+        return torch.stack(outs, dim=2)
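The per-token loop in `_retention` is the recurrent form of retention: for each head, `state_t = γ · state_{t-1} + k_t v_tᵀ` and `out_t = q_t · state_t`. Ignoring the stabilizing clamps and the 1e-6 state offset, that recurrence equals the closed form `out_t = Σ_{i≤t} γ^{t-i} (q_t · k_i) v_i`, i.e. causal attention with exponentially decayed scores. A quick standalone check (a fixed scalar decay is assumed for brevity; the module learns one decay per head):

```python
import torch

b, h, s, d = 1, 2, 5, 4
q, k, v = torch.randn(3, b, h, s, d).unbind(0)
gamma = 0.9

# Recurrent form, as in _retention (clamps omitted).
state = torch.zeros(b, h, d, d)
rec = []
for t in range(s):
    state = gamma * state + torch.einsum('bhd,bhe->bhde', k[:, :, t], v[:, :, t])
    rec.append(torch.einsum('bhd,bhde->bhe', q[:, :, t], state))
rec = torch.stack(rec, dim=2)

# Closed form: a causal, decay-weighted attention matrix.
idx = torch.arange(s)
dist = (idx[:, None] - idx[None, :]).clamp(min=0).float()
mask = (gamma ** dist) * (idx[:, None] >= idx[None, :])
closed = torch.einsum('bhtd,bhsd,ts,bhse->bhte', q, k, mask, v)

print(torch.allclose(rec, closed, atol=1e-5))  # True
```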

+class HierarchicalRetention(nn.Module):
+    def __init__(self, config, layer_idx=0):
+        super().__init__()
+        self.base_retention = MultiScaleRetention(config, layer_idx)
+        h = config.hidden_size
+        self.d_state = h // 2
+        self.short_proj = nn.Linear(h, self.d_state)
+        self.medium_proj = nn.Linear(self.d_state, self.d_state)
+        self.long_proj = nn.Linear(self.d_state, self.d_state*2)
+        self.fusion = nn.Linear(self.d_state*4, h)
+        self.norm = nn.LayerNorm(h)
+        self.decays = [0.5, 0.8, 0.95]
+
+    def forward(self, x, **kwargs):
+        b, s, h = x.shape
+        device, dtype = x.device, x.dtype
+        if next(self.short_proj.parameters()).device != device:
+            self.to(device=device, dtype=dtype)
+
+        ret_out = self.base_retention(x)[0]
+        short = torch.zeros(b, self.d_state, dtype=dtype, device=device)
+        med = torch.zeros(b, self.d_state, dtype=dtype, device=device)
+        long = torch.zeros(b, self.d_state*2, dtype=dtype, device=device)
+        outs = []
+
+        for t in range(s):
+            short = self.decays[0]*short + self.short_proj(ret_out[:,t])
+            if t % 8 == 0: med = self.decays[1]*med + self.medium_proj(short)
+            if t % 64 == 0: long = self.decays[2]*long + self.long_proj(med)
+            outs.append(self.fusion(torch.cat([short, med, long], -1)))
+
+        return (self.norm(torch.stack(outs, 1)), None)
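The three states advance at different rates: `short` every token, `med` every 8 tokens, `long` every 64, so the fused output mixes fast-moving and slow-moving context. A shape-level smoke test (the lightweight config stub is an assumption; any object with `hidden_size` and `num_attention_heads` works):

```python
import torch
from types import SimpleNamespace

# Hypothetical minimal config: only the fields the retention classes read.
cfg = SimpleNamespace(hidden_size=64, num_attention_heads=4)
block = HierarchicalRetention(cfg)

x = torch.randn(2, 100, 64)   # (batch, seq, hidden); 100 steps hits the %8 and %64 updates
y, _ = block(x)
print(y.shape)                # torch.Size([2, 100, 64]) — same shape as an attention output
```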

+def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
+    layers = getattr(model, 'model', model)
+    layers = getattr(layers, 'layers', getattr(layers, 'h', None))
+    if layers is None: return model, 0, 0
+
+    cnt = 0
+    for i, layer in enumerate(layers):
         if hasattr(layer, 'self_attn'):
+            layer.self_attn = HierarchicalRetention(model.config, i) if use_hierarchical else MultiScaleRetention(model.config, i)
+            cnt += 1
+    return model, cnt, len(layers)
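Because the function only looks for a `model.layers` / `model.h` container and a `self_attn` attribute, it should work on Llama/Qwen-style checkpoints without model-specific code. Hypothetical usage, with the default model id from the code above:

```python
from transformers import AutoModelForCausalLM

m = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
m, converted, total = replace_attention_with_retention_for_loading(m, use_hierarchical=True)
print(f"{converted}/{total} self_attn modules replaced")   # e.g. 28/28 on this architecture
```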

+# ✅ CRITICAL: the PhoenixPreTrainedModel base class
+class PhoenixPreTrainedModel(PreTrainedModel):
+    config_class = PhoenixConfig
+    base_model_prefix = "phoenix"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["MultiScaleRetention", "HierarchicalRetention"]
+
+    def _init_weights(self, m):
+        std = getattr(self.config, 'initializer_range', 0.02)
+        if isinstance(m, nn.Linear):
+            m.weight.data.normal_(0, std)
+            if m.bias is not None: m.bias.data.zero_()
+        elif isinstance(m, nn.Embedding):
+            m.weight.data.normal_(0, std)
+            if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
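`_no_split_modules` keeps accelerate's automatic device map from sharding a single retention block (and its recurrent state) across devices, while `_init_weights` gives any freshly added module a sane default init. An illustrative check of the init path (the probe class is not part of the commit):

```python
import torch.nn as nn

# Standalone probe: post_init() applies _init_weights to every submodule.
class _Probe(PhoenixPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(16, 16)
        self.post_init()

probe = _Probe(PhoenixConfig())
print(probe.proj.weight.std())   # ≈ 0.02, the default initializer_range
```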

 class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
+        self._model = None
+        self._ready = False

     @classmethod
+    def from_pretrained(cls, path, *args, **kwargs):
+        print(f"🔥 PHOENIX v1.4.3 loading from {path}")
+        config = AutoConfig.from_pretrained(path, trust_remote_code=True)
+        orig = getattr(config, 'original_model', 'Qwen/Qwen3-0.6B')
+        hier = getattr(config, 'use_hierarchical', True)

         try:
+            base_cfg = AutoConfig.from_pretrained(orig, trust_remote_code=True)
         except:
+            base_cfg = config
+
+        model = AutoModelForCausalLM.from_config(base_cfg)
+        model, conv, tot = replace_attention_with_retention_for_loading(model, hier)
+        print(f"   ✅ Converted {conv}/{tot} layers")
+
+        # Load the weights
+        sd = None
+        if os.path.exists(path):
+            for fname in ["model.safetensors", "pytorch_model.bin"]:
+                fpath = os.path.join(path, fname)
+                if os.path.exists(fpath):
+                    if fname.endswith('.safetensors'):
+                        from safetensors.torch import load_file
+                        sd = load_file(fpath)
+                    else:
+                        sd = torch.load(fpath, map_location='cpu')
+                    break
         else:
+            from huggingface_hub import hf_hub_download
+            for fname in ["model.safetensors", "pytorch_model.bin"]:
                 try:
+                    fpath = hf_hub_download(path, fname)
+                    if fname.endswith('.safetensors'):
+                        from safetensors.torch import load_file
+                        sd = load_file(fpath)
+                    else:
+                        sd = torch.load(fpath, map_location='cpu')
+                    break
+                except Exception: pass

+        if sd:
+            miss, unex = model.load_state_dict(sd, strict=False)
+            print(f"   📦 Weights: {len(miss)} missing, {len(unex)} unexpected")
+
+            if 'lm_head.weight' in miss and getattr(config, 'tie_word_embeddings', False):
+                if hasattr(model, 'lm_head') and hasattr(model.model, 'embed_tokens'):
+                    model.lm_head.weight = model.model.embed_tokens.weight
+                    print(f"   🔗 Tied embeddings")

+        inst = cls(config)
+        inst._model = model
+        inst._ready = True
+        print(f"✅ PHOENIX v1.4.3 ready!")
+        return inst

+    def forward(self, *a, **k):
+        if not self._ready: raise ValueError("Not initialized")
+        return self._model(*a, **k)

+    def generate(self, *a, **k):
+        if not self._ready: raise ValueError("Not initialized")
+        return self._model.generate(*a, **k)


 AutoConfig.register("phoenix", PhoenixConfig)
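With the config type registered, a published checkpoint that ships this file (plus the usual `auto_map` entries in `config.json`) loads through the stock Auto classes. A hypothetical end-to-end run, with an assumed repo id:

```python
# Hypothetical usage; trust_remote_code=True lets transformers import the
# generated modeling file from the repo before resolving PhoenixModelForCausalLM.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "your-org/phoenix-qwen3-0.6b",      # assumed repo id
    trust_remote_code=True,
)
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
out = model.generate(**tok("Hello", return_tensors="pt"), max_new_tokens=20)
print(tok.decode(out[0], skip_special_tokens=True))
```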