|
|
""" |
|
|
🔮 PHOENIX Retention Research Platform - PRODUCTION VERSION v1.4.1

State Dict Direct Loading + Structure-Aware Burning + HuggingFace Hub

✅ State Dict Direct Loading
✅ Model Structure Pre-Analysis
✅ Qwen3 Model Support
✅ Zero-shot Conversion (No Dataset Required)
✅ Optional Fine-tuning (Dataset-based)
✅ GQA Support
✅ HuggingFace Hub Integration with Custom Code
✅ Comprehensive Evaluation
✅ Pre-upload Verification
✅ FIX: modeling_phoenix.py head_dim calculation

VIDraft AI Research Lab
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import sqlite3 |
|
|
import json |
|
|
import time |
|
|
import numpy as np |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
import plotly.graph_objects as go |
|
|
import plotly.express as px |
|
|
import pandas as pd |
|
|
from typing import Dict, List, Any, Tuple, Optional |
|
|
import chromadb |
|
|
from chromadb.config import Settings |
|
|
from transformers import ( |
|
|
AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM, |
|
|
get_cosine_schedule_with_warmup, TrainingArguments, Trainer |
|
|
) |
|
|
from datasets import load_dataset |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from accelerate import Accelerator |
|
|
from tqdm import tqdm |
|
|
import copy |
|
|
import shutil |
|
|
import os |
|
|
from huggingface_hub import HfApi, create_repo |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
STORAGE_PATH = "/data" |
|
|
DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db" |
|
|
VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store" |
|
|
MODELS_PATH = f"{STORAGE_PATH}/phoenix_models" |
|
|
DEFAULT_MODEL = "Qwen/Qwen3-0.6B" |
|
|
|
|
|
|
|
|
HF_TOKEN = os.getenv("HF_TOKEN") |
|
|
|
|
|
Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True) |
|
|
Path(VECTOR_DB_PATH).mkdir(parents=True, exist_ok=True) |
|
|
Path(MODELS_PATH).mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
print(f"๐ PHOENIX Platform v1.4.1 initialized on {DEVICE}") |
|
|
print(f"๐พ Storage: {STORAGE_PATH}") |
|
|
print(f"๐ฏ Default Base Model: {DEFAULT_MODEL}") |
|
|
if HF_TOKEN: |
|
|
print(f"๐ HuggingFace Token: {'*' * 10}{HF_TOKEN[-4:]}") |
|
|
else: |
|
|
print(f"โ ๏ธ HuggingFace Token not found (upload disabled)") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def analyze_model_structure(model_url: str) -> Dict[str, Any]: |
|
|
""" |
|
|
🔍 Pre-analysis of the model structure.

Identifies the model's layer layout before conversion.
|
|
""" |
|
|
print("\n" + "="*80) |
|
|
print("๐ MODEL STRUCTURE ANALYSIS") |
|
|
print("="*80) |
|
|
|
|
|
try: |
|
|
print(f"\n๐ฅ Loading model config: {model_url}") |
|
|
config = AutoConfig.from_pretrained(model_url, trust_remote_code=True) |
|
|
|
|
|
print(f"โ
Config loaded") |
|
|
print(f" Architecture: {config.architectures if hasattr(config, 'architectures') else 'Unknown'}") |
|
|
print(f" Model Type: {config.model_type if hasattr(config, 'model_type') else 'Unknown'}") |
|
|
|
|
|
|
|
|
print(f"\n๐ฆ Loading model structure...") |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
model_url, |
|
|
trust_remote_code=True, |
|
|
torch_dtype=torch.float16, |
|
|
device_map="cpu" |
|
|
) |
|
|
|
|
|
analysis = { |
|
|
'model_url': model_url, |
|
|
'model_type': config.model_type if hasattr(config, 'model_type') else 'unknown', |
|
|
'architectures': config.architectures[0] if hasattr(config, 'architectures') else 'unknown', |
|
|
'hidden_size': config.hidden_size if hasattr(config, 'hidden_size') else 0, |
|
|
'num_attention_heads': config.num_attention_heads if hasattr(config, 'num_attention_heads') else 0, |
|
|
'num_hidden_layers': config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else 0, |
|
|
'num_key_value_heads': config.num_key_value_heads if hasattr(config, 'num_key_value_heads') else None, |
|
|
'layer_structure': None, |
|
|
'attention_type': 'unknown', |
|
|
'total_layers': 0, |
|
|
'has_self_attn': False, |
|
|
'layer_path': None, |
|
|
} |
|
|
|
|
|
|
|
|
print(f"\n๐ Analyzing layer structure...") |
|
|
|
|
|
layers = None |
|
|
layer_path = None |
|
|
|
|
|
|
|
|
possible_paths = [ |
|
|
('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None), |
|
|
('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None), |
|
|
('layers', lambda m: m.layers if hasattr(m, 'layers') else None), |
|
|
('model.decoder.layers', lambda m: m.model.decoder.layers if hasattr(m, 'model') and hasattr(m.model, 'decoder') and hasattr(m.model.decoder, 'layers') else None), |
|
|
] |
|
|
|
|
|
for path_name, path_fn in possible_paths: |
|
|
result = path_fn(model) |
|
|
if result is not None: |
|
|
layers = result |
|
|
layer_path = path_name |
|
|
print(f" โ
Found layers at: {path_name}") |
|
|
break |
|
|
|
|
|
if layers is None: |
|
|
print(f" โ No layers found! Model structure unknown.") |
|
|
analysis['error'] = 'No layers found' |
|
|
return analysis |
|
|
|
|
|
analysis['total_layers'] = len(layers) |
|
|
analysis['layer_path'] = layer_path |
|
|
|
|
|
print(f" Total Layers: {len(layers)}") |
|
|
|
|
|
|
|
|
if len(layers) > 0: |
|
|
first_layer = layers[0] |
|
|
print(f"\n๐ฌ Analyzing first layer...") |
|
|
|
|
|
|
|
|
if hasattr(first_layer, 'self_attn'): |
|
|
analysis['has_self_attn'] = True |
|
|
attn = first_layer.self_attn |
|
|
|
|
|
print(f" โ
Has self_attn") |
|
|
print(f" Attention class: {attn.__class__.__name__}") |
|
|
|
|
|
analysis['attention_type'] = attn.__class__.__name__ |
|
|
|
|
|
|
|
|
if hasattr(attn, 'q_proj'): |
|
|
q_shape = attn.q_proj.weight.shape |
|
|
k_shape = attn.k_proj.weight.shape |
|
|
v_shape = attn.v_proj.weight.shape |
|
|
|
|
|
print(f" Q projection: {q_shape}") |
|
|
print(f" K projection: {k_shape}") |
|
|
print(f" V projection: {v_shape}") |
|
|
|
|
|
|
|
|
if hasattr(config, 'num_attention_heads') and config.num_attention_heads > 0: |
|
|
head_dim = q_shape[0] // config.num_attention_heads |
|
|
analysis['head_dim'] = head_dim |
|
|
print(f" Calculated head_dim: {head_dim}") |
|
|
|
|
|
|
|
|
if k_shape[0] != q_shape[0]: |
|
|
print(f" โ
GQA detected! (K/V heads < Q heads)") |
|
|
analysis['gqa_detected'] = True |
|
|
|
|
|
|
|
|
if hasattr(config, 'num_key_value_heads') and config.num_key_value_heads > 0: |
|
|
kv_head_dim = k_shape[0] // config.num_key_value_heads |
|
|
analysis['kv_head_dim'] = kv_head_dim |
|
|
print(f" Calculated kv_head_dim: {kv_head_dim}") |
|
|
else: |
|
|
print(f" Standard MHA (K/V heads == Q heads)") |
|
|
analysis['gqa_detected'] = False |
|
|
|
|
|
analysis['q_dim'] = q_shape[0] |
|
|
analysis['k_dim'] = k_shape[0] |
|
|
analysis['v_dim'] = v_shape[0] |
|
|
analysis['o_in_dim'] = attn.o_proj.weight.shape[1] if hasattr(attn, 'o_proj') else None |
|
|
|
|
|
else: |
|
|
print(f" โ ๏ธ No self_attn found in layer") |
|
|
analysis['has_self_attn'] = False |
|
|
|
|
|
|
|
|
print(f"\n{'='*80}") |
|
|
print(f"๐ STRUCTURE ANALYSIS COMPLETE") |
|
|
print(f"{'='*80}") |
|
|
print(f"Model Type: {analysis['model_type']}") |
|
|
print(f"Architecture: {analysis['architectures']}") |
|
|
print(f"Total Layers: {analysis['total_layers']}") |
|
|
print(f"Layer Path: {analysis['layer_path']}") |
|
|
print(f"Has self_attn: {analysis['has_self_attn']}") |
|
|
print(f"Attention Type: {analysis['attention_type']}") |
|
|
|
|
|
if analysis.get('gqa_detected'): |
|
|
print(f"โ
GQA Support: YES") |
|
|
print(f" Q dim: {analysis.get('q_dim')}") |
|
|
print(f" K dim: {analysis.get('k_dim')}") |
|
|
else: |
|
|
print(f"Standard MHA") |
|
|
|
|
|
print(f"{'='*80}\n") |
|
|
|
|
|
|
|
|
del model |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
return analysis |
|
|
|
|
|
except Exception as e: |
|
|
import traceback |
|
|
error_msg = traceback.format_exc() |
|
|
print(f"\nโ Structure analysis failed:") |
|
|
print(error_msg) |
|
|
|
|
|
return { |
|
|
'model_url': model_url, |
|
|
'error': str(e), |
|
|
'traceback': error_msg, |
|
|
'total_layers': 0, |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiScaleRetention(nn.Module): |
|
|
"""์ง์ง Retention Attention with GQA Support""" |
|
|
|
|
|
def __init__(self, config, layer_idx=0): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.layer_idx = layer_idx |
|
|
|
|
|
|
|
|
self.hidden_size = config.hidden_size |
|
|
self.num_heads = config.num_attention_heads |
|
|
|
|
|
|
|
|
if hasattr(config, 'head_dim'): |
|
|
self.head_dim = config.head_dim |
|
|
else: |
|
|
self.head_dim = self.hidden_size // self.num_heads |
|
|
|
|
|
|
|
|
if hasattr(config, 'num_key_value_heads'): |
|
|
self.num_key_value_heads = config.num_key_value_heads |
|
|
else: |
|
|
self.num_key_value_heads = self.num_heads |
|
|
|
|
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
|
self.kv_head_dim = self.head_dim |
|
|
|
|
|
|
|
|
self.q_dim = self.num_heads * self.head_dim |
|
|
self.kv_dim = self.num_key_value_heads * self.kv_head_dim |
|
|
|
|
|
|
|
|
self.register_buffer('_internal_state', None, persistent=False) |
|
|
self.register_buffer('_state_initialized', torch.tensor(False), persistent=False) |
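# Non-persistent buffers: they carry the recurrent retention state between forward calls
# when use_cache=True, are excluded from the saved state dict, and are cleared by reset_state().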
|
|
|
|
|
|
|
|
self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False) |
|
|
self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) |
|
|
self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) |
|
|
self.o_proj = nn.Linear(self.q_dim, self.hidden_size, bias=False) |
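# Per-head decay: initialised on a linspace from 0.95 to 0.99 and squashed through a sigmoid
# in _compute_retention, so each head forgets past context at a slightly different rate;
# this is what makes the retention "multi-scale".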
|
|
|
|
|
|
|
|
decay_values = torch.linspace(0.95, 0.99, self.num_heads) |
|
|
self.decay = nn.Parameter(decay_values, requires_grad=True) |
|
|
|
|
|
|
|
|
self.group_norm = nn.GroupNorm( |
|
|
num_groups=self.num_heads, |
|
|
num_channels=self.q_dim |
|
|
) |
|
|
|
|
|
def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
|
"""Repeat K/V heads to match Q heads (GQA)""" |
|
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
|
if n_rep == 1: |
|
|
return hidden_states |
|
|
|
|
|
hidden_states = hidden_states[:, :, None, :, :].expand( |
|
|
batch, num_key_value_heads, n_rep, slen, head_dim |
|
|
) |
|
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
|
|
|
|
|
def reset_state(self): |
|
|
"""Reset internal state""" |
|
|
self._internal_state = None |
|
|
self._state_initialized = torch.tensor(False) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
|
output_attentions: bool = False, |
|
|
use_cache: bool = False, |
|
|
cache_position: Optional[torch.Tensor] = None, |
|
|
past_key_values: Optional[Tuple[torch.Tensor]] = None, |
|
|
**kwargs |
|
|
): |
|
|
"""O(n) Retention with GQA support""" |
|
|
batch_size, seq_len, _ = hidden_states.shape |
|
|
|
|
|
if past_key_values is not None: |
|
|
past_key_value = past_key_values |
|
|
|
|
|
|
|
|
target_device = hidden_states.device |
|
|
target_dtype = hidden_states.dtype |
|
|
|
|
|
if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype: |
|
|
self.q_proj = self.q_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.k_proj = self.k_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.v_proj = self.v_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype) |
|
|
|
|
|
|
|
|
query_states = self.q_proj(hidden_states) |
|
|
key_states = self.k_proj(hidden_states) |
|
|
value_states = self.v_proj(hidden_states) |
|
|
|
|
|
|
|
|
query_states = query_states.view( |
|
|
batch_size, seq_len, self.num_heads, self.head_dim |
|
|
).transpose(1, 2) |
|
|
|
|
|
key_states = key_states.view( |
|
|
batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim |
|
|
).transpose(1, 2) |
|
|
|
|
|
value_states = value_states.view( |
|
|
batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim |
|
|
).transpose(1, 2) |
|
|
|
|
|
|
|
|
key_states = self._repeat_kv(key_states, self.num_key_value_groups) |
|
|
value_states = self._repeat_kv(value_states, self.num_key_value_groups) |
|
|
|
|
|
|
|
|
past_state = self._internal_state if (use_cache and self._state_initialized) else None |
|
|
retention_states, new_state = self._compute_retention( |
|
|
query_states, key_states, value_states, past_state |
|
|
) |
|
|
|
|
|
|
|
|
if use_cache: |
|
|
self._internal_state = new_state.detach() |
|
|
self._state_initialized = torch.tensor(True) |
|
|
|
|
|
|
|
|
retention_states = retention_states.transpose(1, 2).contiguous() |
|
|
retention_states = retention_states.reshape( |
|
|
batch_size, seq_len, self.q_dim |
|
|
) |
|
|
|
|
|
|
|
|
if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda: |
|
|
self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype) |
|
|
elif next(self.group_norm.parameters()).dtype != retention_states.dtype: |
|
|
self.group_norm = self.group_norm.to(dtype=retention_states.dtype) |
|
|
|
|
|
retention_states = self.group_norm( |
|
|
retention_states.transpose(1, 2) |
|
|
).transpose(1, 2) |
|
|
|
|
|
retention_states = torch.clamp(retention_states, min=-10.0, max=10.0) |
|
|
|
|
|
|
|
|
attn_output = self.o_proj(retention_states) |
|
|
|
|
|
return (attn_output, None) |
|
|
|
|
|
def _compute_retention( |
|
|
self, |
|
|
queries: torch.Tensor, |
|
|
keys: torch.Tensor, |
|
|
values: torch.Tensor, |
|
|
past_state: Optional[torch.Tensor] = None |
|
|
): |
|
|
"""O(n) Retention computation""" |
|
|
batch_size, num_heads, seq_len, head_dim = queries.shape |
|
|
|
|
|
if past_state is not None: |
|
|
state = past_state.to(queries.device, dtype=queries.dtype) |
|
|
else: |
|
|
state = torch.zeros( |
|
|
batch_size, num_heads, head_dim, head_dim, |
|
|
dtype=queries.dtype, |
|
|
device=queries.device |
|
|
) + 1e-6 |
|
|
|
|
|
outputs = [] |
|
|
|
|
|
decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to( |
|
|
device=queries.device, |
|
|
dtype=queries.dtype |
|
|
) |
|
|
|
|
|
for t in range(seq_len): |
|
|
q_t = queries[:, :, t, :] |
|
|
k_t = keys[:, :, t, :] |
|
|
v_t = values[:, :, t, :] |
|
|
|
|
|
state = decay * state |
|
|
kv_update = torch.einsum('bhd,bhe->bhde', k_t, v_t) |
|
|
kv_update = torch.clamp(kv_update, min=-5.0, max=5.0) |
|
|
state = state + kv_update |
|
|
state = torch.clamp(state, min=-10.0, max=10.0) |
|
|
|
|
|
output_t = torch.einsum('bhd,bhde->bhe', q_t, state) |
|
|
outputs.append(output_t) |
|
|
|
|
|
output = torch.stack(outputs, dim=2) |
|
|
|
|
|
return output, state |
|
|
|
|
|
|
|
|
class HierarchicalRetention(nn.Module): |
|
|
"""PHOENIX Hierarchical Retention with GQA""" |
|
|
|
|
|
def __init__(self, config, layer_idx=0): |
|
|
super().__init__() |
|
|
self.base_retention = MultiScaleRetention(config, layer_idx) |
|
|
|
|
|
hidden_size = config.hidden_size |
|
|
self.d_state = hidden_size // 2 |
|
|
|
|
|
self.short_proj = nn.Linear(hidden_size, self.d_state) |
|
|
self.medium_proj = nn.Linear(self.d_state, self.d_state) |
|
|
self.long_proj = nn.Linear(self.d_state, self.d_state * 2) |
|
|
self.fusion = nn.Linear(self.d_state * 4, hidden_size) |
|
|
|
|
|
self.short_decay = 0.5 |
|
|
self.medium_decay = 0.8 |
|
|
self.long_decay = 0.95 |
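# Fixed decay constants for the three timescales; in forward() the short state is updated
# every token, the medium state every 8 tokens, and the long state every 64 tokens.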
|
|
|
|
|
self.norm = nn.LayerNorm(hidden_size) |
|
|
|
|
|
if next(self.base_retention.parameters()).is_cuda: |
|
|
device = next(self.base_retention.parameters()).device |
|
|
dtype = next(self.base_retention.parameters()).dtype |
|
|
self.short_proj = self.short_proj.to(device, dtype=dtype) |
|
|
self.medium_proj = self.medium_proj.to(device, dtype=dtype) |
|
|
self.long_proj = self.long_proj.to(device, dtype=dtype) |
|
|
self.fusion = self.fusion.to(device, dtype=dtype) |
|
|
self.norm = self.norm.to(device, dtype=dtype) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
|
output_attentions: bool = False, |
|
|
use_cache: bool = False, |
|
|
cache_position: Optional[torch.Tensor] = None, |
|
|
past_key_values: Optional[Tuple[torch.Tensor]] = None, |
|
|
**kwargs |
|
|
): |
|
|
"""Hierarchical forward pass""" |
|
|
batch_size, seq_len, hidden_size = hidden_states.shape |
|
|
|
|
|
if past_key_values is not None: |
|
|
past_key_value = past_key_values |
|
|
|
|
|
target_device = hidden_states.device |
|
|
target_dtype = hidden_states.dtype |
|
|
|
|
|
|
|
|
current_device = next(self.short_proj.parameters()).device |
|
|
current_dtype = next(self.short_proj.parameters()).dtype |
|
|
|
|
|
if current_device != target_device or current_dtype != target_dtype: |
|
|
self.short_proj = self.short_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.medium_proj = self.medium_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.long_proj = self.long_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.fusion = self.fusion.to(device=target_device, dtype=target_dtype) |
|
|
self.norm = self.norm.to(device=target_device, dtype=target_dtype) |
|
|
|
|
|
base_result = self.base_retention( |
|
|
hidden_states, attention_mask, position_ids, |
|
|
past_key_value, output_attentions, use_cache |
|
|
) |
|
|
|
|
|
retention_output = base_result[0] |
|
|
|
|
|
|
|
|
short_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device) |
|
|
medium_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device) |
|
|
long_state = torch.zeros(batch_size, self.d_state * 2, dtype=target_dtype, device=target_device) |
|
|
|
|
|
hierarchical_outputs = [] |
|
|
|
|
|
for t in range(seq_len): |
|
|
x_t = retention_output[:, t, :] |
|
|
|
|
|
short_input = self.short_proj(x_t) |
|
|
short_state = self.short_decay * short_state + short_input |
|
|
|
|
|
if t % 8 == 0: |
|
|
medium_state = self.medium_decay * medium_state + \ |
|
|
self.medium_proj(short_state) |
|
|
|
|
|
if t % 64 == 0: |
|
|
long_state = self.long_decay * long_state + \ |
|
|
self.long_proj(medium_state) |
|
|
|
|
|
combined = torch.cat([short_state, medium_state, long_state], dim=-1) |
|
|
output_t = self.fusion(combined) |
|
|
hierarchical_outputs.append(output_t) |
|
|
|
|
|
output = torch.stack(hierarchical_outputs, dim=1) |
|
|
output = self.norm(output) |
|
|
|
|
|
return (output, None) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def replace_attention_with_retention(model, use_hierarchical=True, structure_info=None): |
|
|
""" |
|
|
Transformer Attention → PHOENIX Retention conversion (GQA support).

Uses structure_info, when available, for a more accurate conversion.
|
|
""" |
|
|
print("๐ Starting Attention โ Retention conversion (GQA support)...") |
|
|
|
|
|
replaced_count = 0 |
|
|
total_layers = 0 |
|
|
|
|
|
|
|
|
layers = None |
|
|
layer_path = None |
|
|
|
|
|
|
|
|
if structure_info and structure_info.get('layer_path'): |
|
|
layer_path = structure_info['layer_path'] |
|
|
print(f" Using structure info: {layer_path}") |
|
|
|
|
|
if layer_path == 'model.layers': |
|
|
if hasattr(model, 'model') and hasattr(model.model, 'layers'): |
|
|
layers = model.model.layers |
|
|
elif layer_path == 'transformer.h': |
|
|
if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'): |
|
|
layers = model.transformer.h |
|
|
elif layer_path == 'layers': |
|
|
if hasattr(model, 'layers'): |
|
|
layers = model.layers |
|
|
elif layer_path == 'model.decoder.layers': |
|
|
if hasattr(model, 'model') and hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers'): |
|
|
layers = model.model.decoder.layers |
|
|
|
|
|
|
|
|
if layers is None: |
|
|
print(f" Auto-detecting layer structure...") |
|
|
|
|
|
possible_paths = [ |
|
|
('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None), |
|
|
('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None), |
|
|
('layers', lambda m: m.layers if hasattr(m, 'layers') else None), |
|
|
('model.decoder.layers', lambda m: m.model.decoder.layers if hasattr(m, 'model') and hasattr(m.model, 'decoder') and hasattr(m.model.decoder, 'layers') else None), |
|
|
] |
|
|
|
|
|
for path_name, path_fn in possible_paths: |
|
|
result = path_fn(model) |
|
|
if result is not None: |
|
|
layers = result |
|
|
layer_path = path_name |
|
|
print(f" โ
Found layers at: {path_name}") |
|
|
break |
|
|
|
|
|
if layers is None: |
|
|
print("โ Cannot find layers - model structure not supported") |
|
|
print(f" Model type: {type(model)}") |
|
|
print(f" Has 'model' attr: {hasattr(model, 'model')}") |
|
|
print(f" Has 'transformer' attr: {hasattr(model, 'transformer')}") |
|
|
print(f" Has 'layers' attr: {hasattr(model, 'layers')}") |
|
|
return model, 0, 0 |
|
|
|
|
|
total_layers = len(layers) |
|
|
print(f" Found {total_layers} layers at '{layer_path}'") |
|
|
|
|
|
|
|
|
if structure_info and structure_info.get('gqa_detected'): |
|
|
print(f" โ
GQA detected from structure info") |
|
|
if not hasattr(model.config, 'num_key_value_heads'): |
|
|
num_kv_heads = structure_info.get('k_dim', 0) // (model.config.hidden_size // model.config.num_attention_heads) |
|
|
if num_kv_heads > 0: |
|
|
model.config.num_key_value_heads = num_kv_heads |
|
|
print(f" Set num_key_value_heads = {num_kv_heads}") |
|
|
|
|
|
|
|
|
if structure_info and structure_info.get('head_dim'): |
|
|
model.config.head_dim = structure_info['head_dim'] |
|
|
print(f" โ
Set head_dim = {structure_info['head_dim']} from structure info") |
|
|
elif not hasattr(model.config, 'head_dim'): |
|
|
|
|
|
first_layer = layers[0] |
|
|
if hasattr(first_layer, 'self_attn'): |
|
|
old_attn = first_layer.self_attn |
|
|
|
|
|
if hasattr(old_attn, 'q_proj'): |
|
|
q_shape = old_attn.q_proj.weight.shape |
|
|
k_shape = old_attn.k_proj.weight.shape |
|
|
|
|
|
|
|
|
head_dim = q_shape[0] // model.config.num_attention_heads |
|
|
model.config.head_dim = head_dim |
|
|
print(f" โ
Calculated head_dim = {head_dim} from layer weights") |
|
|
|
|
|
if k_shape[0] != q_shape[0]: |
|
|
print(f" โ
GQA detected! (K/V dim: {k_shape[0]} < Q dim: {q_shape[0]})") |
|
|
if not hasattr(model.config, 'num_key_value_heads'): |
|
|
num_kv_heads = k_shape[0] // head_dim |
|
|
model.config.num_key_value_heads = num_kv_heads |
|
|
print(f" Set num_key_value_heads = {num_kv_heads}") |
|
|
|
|
|
|
|
|
for layer_idx, layer in enumerate(layers): |
|
|
try: |
|
|
if hasattr(layer, 'self_attn'): |
|
|
old_attn = layer.self_attn |
|
|
|
|
|
if use_hierarchical: |
|
|
new_retention = HierarchicalRetention(model.config, layer_idx) |
|
|
else: |
|
|
new_retention = MultiScaleRetention(model.config, layer_idx) |
|
|
|
|
|
|
|
|
if hasattr(old_attn, 'q_proj'): |
|
|
try: |
|
|
if use_hierarchical: |
|
|
target = new_retention.base_retention |
|
|
else: |
|
|
target = new_retention |
|
|
|
|
|
q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape |
|
|
k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape |
|
|
v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape |
|
|
o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape |
|
|
|
|
|
if layer_idx == 0: |
|
|
print(f" ๐ Layer 0 shape analysis:") |
|
|
print(f" Old Q: {old_attn.q_proj.weight.shape} vs New Q: {target.q_proj.weight.shape} โ {'โ
' if q_match else 'โ'}") |
|
|
print(f" Old K: {old_attn.k_proj.weight.shape} vs New K: {target.k_proj.weight.shape} โ {'โ
' if k_match else 'โ'}") |
|
|
print(f" Old V: {old_attn.v_proj.weight.shape} vs New V: {target.v_proj.weight.shape} โ {'โ
' if v_match else 'โ'}") |
|
|
print(f" Old O: {old_attn.o_proj.weight.shape} vs New O: {target.o_proj.weight.shape} โ {'โ
' if o_match else 'โ'}") |
|
|
|
|
|
if q_match and k_match and v_match and o_match: |
|
|
target.q_proj.weight.data = old_attn.q_proj.weight.data.clone() |
|
|
target.k_proj.weight.data = old_attn.k_proj.weight.data.clone() |
|
|
target.v_proj.weight.data = old_attn.v_proj.weight.data.clone() |
|
|
target.o_proj.weight.data = old_attn.o_proj.weight.data.clone() |
|
|
if layer_idx == 0: |
|
|
print(f" โ
Layer {layer_idx}: Perfect match - weights copied") |
|
|
|
|
|
elif q_match and o_match: |
|
|
target.q_proj.weight.data = old_attn.q_proj.weight.data.clone() |
|
|
target.o_proj.weight.data = old_attn.o_proj.weight.data.clone() |
|
|
|
|
|
k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0]) |
|
|
v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0]) |
|
|
|
|
|
target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone() |
|
|
target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone() |
|
|
|
|
|
if layer_idx == 0: |
|
|
print(f" โ
Layer {layer_idx}: Partial match (GQA) - partial weights copied") |
|
|
|
|
|
else: |
|
|
nn.init.xavier_uniform_(target.q_proj.weight) |
|
|
nn.init.xavier_uniform_(target.k_proj.weight) |
|
|
nn.init.xavier_uniform_(target.v_proj.weight) |
|
|
nn.init.xavier_uniform_(target.o_proj.weight) |
|
|
if layer_idx == 0: |
|
|
print(f" โ ๏ธ Layer {layer_idx}: Shape mismatch - Xavier init used") |
|
|
print(f" This will result in random weights!") |
|
|
|
|
|
except Exception as e: |
|
|
print(f" โ ๏ธ Layer {layer_idx}: Weight copy failed - {e}") |
|
|
|
|
|
layer.self_attn = new_retention |
|
|
replaced_count += 1 |
|
|
|
|
|
except Exception as e: |
|
|
print(f" โ Layer {layer_idx}: Failed - {e}") |
|
|
continue |
|
|
|
|
|
print(f"\nโ
Conversion complete: {replaced_count}/{total_layers} layers") |
|
|
|
|
|
return model, replaced_count, total_layers |
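# Minimal usage sketch (illustrative only; not executed here):
#   info = analyze_model_structure("Qwen/Qwen3-0.6B")
#   model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
#   model, converted, total = replace_attention_with_retention(
#       model, use_hierarchical=True, structure_info=info)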
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_modeling_phoenix_code(): |
|
|
""" |
|
|
Generates the PHOENIX custom modeling code (modeling_phoenix.py), v1.4.1.

✅ FIX: prefer config.head_dim when computing head_dim
|
|
""" |
|
|
|
|
|
modeling_code = '''""" |
|
|
PHOENIX Retention Model - Custom Implementation v1.4.1 |
|
|
Auto-loaded by HuggingFace transformers with trust_remote_code=True |
|
|
|
|
|
✅ FIX: Retention weights preserved via direct state-dict loading
✅ FIX: prefer config.head_dim when computing head_dim
|
|
|
|
|
VIDraft AI Research Lab |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from typing import Optional, Tuple, Union |
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers.configuration_utils import PretrainedConfig |
|
|
from transformers import AutoConfig, AutoModelForCausalLM |
|
|
import os |
|
|
|
|
|
|
|
|
class PhoenixConfig(PretrainedConfig): |
|
|
"""PHOENIX Model Configuration""" |
|
|
model_type = "phoenix" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
use_phoenix_retention=True, |
|
|
phoenix_version="1.4.1", |
|
|
original_architecture=None, |
|
|
original_model=None, |
|
|
**kwargs |
|
|
): |
|
|
super().__init__(**kwargs) |
|
|
self.use_phoenix_retention = use_phoenix_retention |
|
|
self.phoenix_version = phoenix_version |
|
|
self.original_architecture = original_architecture |
|
|
self.original_model = original_model |
|
|
|
|
|
|
|
|
class MultiScaleRetention(nn.Module): |
|
|
"""PHOENIX Multi-Scale Retention with GQA Support""" |
|
|
|
|
|
def __init__(self, config, layer_idx=0): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.layer_idx = layer_idx |
|
|
|
|
|
self.hidden_size = config.hidden_size |
|
|
self.num_heads = config.num_attention_heads |
|
|
|
|
|
# ✅ FIX v1.4.1: take head_dim from config first, falling back to hidden_size // num_heads
|
|
if hasattr(config, 'head_dim'): |
|
|
self.head_dim = config.head_dim |
|
|
else: |
|
|
self.head_dim = self.hidden_size // self.num_heads |
|
|
|
|
|
if hasattr(config, 'num_key_value_heads'): |
|
|
self.num_key_value_heads = config.num_key_value_heads |
|
|
else: |
|
|
self.num_key_value_heads = self.num_heads |
|
|
|
|
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
|
self.kv_head_dim = self.head_dim |
|
|
|
|
|
# ✅ Compute the actual projection dimensions
|
|
self.q_dim = self.num_heads * self.head_dim |
|
|
self.kv_dim = self.num_key_value_heads * self.kv_head_dim |
|
|
|
|
|
self.register_buffer('_internal_state', None, persistent=False) |
|
|
self.register_buffer('_state_initialized', torch.tensor(False), persistent=False) |
|
|
|
|
|
# ✅ Projections with the correct dimensions
|
|
self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False) |
|
|
self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) |
|
|
self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) |
|
|
self.o_proj = nn.Linear(self.q_dim, self.hidden_size, bias=False) |
|
|
|
|
|
decay_values = torch.linspace(0.95, 0.99, self.num_heads) |
|
|
self.decay = nn.Parameter(decay_values, requires_grad=True) |
|
|
|
|
|
self.group_norm = nn.GroupNorm( |
|
|
num_groups=self.num_heads, |
|
|
num_channels=self.q_dim |
|
|
) |
|
|
|
|
|
def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
|
if n_rep == 1: |
|
|
return hidden_states |
|
|
hidden_states = hidden_states[:, :, None, :, :].expand( |
|
|
batch, num_key_value_heads, n_rep, slen, head_dim |
|
|
) |
|
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
|
|
|
|
|
def reset_state(self): |
|
|
self._internal_state = None |
|
|
self._state_initialized = torch.tensor(False) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
|
output_attentions: bool = False, |
|
|
use_cache: bool = False, |
|
|
cache_position: Optional[torch.Tensor] = None, |
|
|
past_key_values: Optional[Tuple[torch.Tensor]] = None, |
|
|
**kwargs |
|
|
): |
|
|
batch_size, seq_len, _ = hidden_states.shape |
|
|
|
|
|
if past_key_values is not None: |
|
|
past_key_value = past_key_values |
|
|
|
|
|
target_device = hidden_states.device |
|
|
target_dtype = hidden_states.dtype |
|
|
|
|
|
if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype: |
|
|
self.q_proj = self.q_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.k_proj = self.k_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.v_proj = self.v_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype) |
|
|
|
|
|
query_states = self.q_proj(hidden_states) |
|
|
key_states = self.k_proj(hidden_states) |
|
|
value_states = self.v_proj(hidden_states) |
|
|
|
|
|
query_states = query_states.view( |
|
|
batch_size, seq_len, self.num_heads, self.head_dim |
|
|
).transpose(1, 2) |
|
|
|
|
|
key_states = key_states.view( |
|
|
batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim |
|
|
).transpose(1, 2) |
|
|
|
|
|
value_states = value_states.view( |
|
|
batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim |
|
|
).transpose(1, 2) |
|
|
|
|
|
key_states = self._repeat_kv(key_states, self.num_key_value_groups) |
|
|
value_states = self._repeat_kv(value_states, self.num_key_value_groups) |
|
|
|
|
|
past_state = self._internal_state if (use_cache and self._state_initialized) else None |
|
|
retention_states, new_state = self._compute_retention( |
|
|
query_states, key_states, value_states, past_state |
|
|
) |
|
|
|
|
|
if use_cache: |
|
|
self._internal_state = new_state.detach() |
|
|
self._state_initialized = torch.tensor(True) |
|
|
|
|
|
retention_states = retention_states.transpose(1, 2).contiguous() |
|
|
retention_states = retention_states.reshape(batch_size, seq_len, self.q_dim) |
|
|
|
|
|
if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda: |
|
|
self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype) |
|
|
elif next(self.group_norm.parameters()).dtype != retention_states.dtype: |
|
|
self.group_norm = self.group_norm.to(dtype=retention_states.dtype) |
|
|
|
|
|
retention_states = self.group_norm(retention_states.transpose(1, 2)).transpose(1, 2) |
|
|
retention_states = torch.clamp(retention_states, min=-10.0, max=10.0) |
|
|
|
|
|
attn_output = self.o_proj(retention_states) |
|
|
return (attn_output, None) |
|
|
|
|
|
def _compute_retention( |
|
|
self, |
|
|
queries: torch.Tensor, |
|
|
keys: torch.Tensor, |
|
|
values: torch.Tensor, |
|
|
past_state: Optional[torch.Tensor] = None |
|
|
): |
|
|
batch_size, num_heads, seq_len, head_dim = queries.shape |
|
|
|
|
|
if past_state is not None: |
|
|
state = past_state.to(queries.device, dtype=queries.dtype) |
|
|
else: |
|
|
state = torch.zeros( |
|
|
batch_size, num_heads, head_dim, head_dim, |
|
|
dtype=queries.dtype, device=queries.device |
|
|
) + 1e-6 |
|
|
|
|
|
outputs = [] |
|
|
decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to( |
|
|
device=queries.device, dtype=queries.dtype |
|
|
) |
|
|
|
|
|
for t in range(seq_len): |
|
|
q_t = queries[:, :, t, :] |
|
|
k_t = keys[:, :, t, :] |
|
|
v_t = values[:, :, t, :] |
|
|
|
|
|
state = decay * state |
|
|
kv_update = torch.einsum('bhd,bhe->bhde', k_t, v_t) |
|
|
kv_update = torch.clamp(kv_update, min=-5.0, max=5.0) |
|
|
state = state + kv_update |
|
|
state = torch.clamp(state, min=-10.0, max=10.0) |
|
|
|
|
|
output_t = torch.einsum('bhd,bhde->bhe', q_t, state) |
|
|
outputs.append(output_t) |
|
|
|
|
|
output = torch.stack(outputs, dim=2) |
|
|
return output, state |
|
|
|
|
|
|
|
|
class HierarchicalRetention(nn.Module): |
|
|
"""PHOENIX Hierarchical Retention""" |
|
|
|
|
|
def __init__(self, config, layer_idx=0): |
|
|
super().__init__() |
|
|
self.base_retention = MultiScaleRetention(config, layer_idx) |
|
|
|
|
|
hidden_size = config.hidden_size |
|
|
self.d_state = hidden_size // 2 |
|
|
|
|
|
self.short_proj = nn.Linear(hidden_size, self.d_state) |
|
|
self.medium_proj = nn.Linear(self.d_state, self.d_state) |
|
|
self.long_proj = nn.Linear(self.d_state, self.d_state * 2) |
|
|
self.fusion = nn.Linear(self.d_state * 4, hidden_size) |
|
|
|
|
|
self.short_decay = 0.5 |
|
|
self.medium_decay = 0.8 |
|
|
self.long_decay = 0.95 |
|
|
|
|
|
self.norm = nn.LayerNorm(hidden_size) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
|
output_attentions: bool = False, |
|
|
use_cache: bool = False, |
|
|
cache_position: Optional[torch.Tensor] = None, |
|
|
past_key_values: Optional[Tuple[torch.Tensor]] = None, |
|
|
**kwargs |
|
|
): |
|
|
batch_size, seq_len, hidden_size = hidden_states.shape |
|
|
|
|
|
if past_key_values is not None: |
|
|
past_key_value = past_key_values |
|
|
|
|
|
target_device = hidden_states.device |
|
|
target_dtype = hidden_states.dtype |
|
|
|
|
|
current_device = next(self.short_proj.parameters()).device |
|
|
current_dtype = next(self.short_proj.parameters()).dtype |
|
|
|
|
|
if current_device != target_device or current_dtype != target_dtype: |
|
|
self.short_proj = self.short_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.medium_proj = self.medium_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.long_proj = self.long_proj.to(device=target_device, dtype=target_dtype) |
|
|
self.fusion = self.fusion.to(device=target_device, dtype=target_dtype) |
|
|
self.norm = self.norm.to(device=target_device, dtype=target_dtype) |
|
|
|
|
|
base_result = self.base_retention( |
|
|
hidden_states, attention_mask, position_ids, |
|
|
past_key_value, output_attentions, use_cache |
|
|
) |
|
|
|
|
|
retention_output = base_result[0] |
|
|
|
|
|
short_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device) |
|
|
medium_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device) |
|
|
long_state = torch.zeros(batch_size, self.d_state * 2, dtype=target_dtype, device=target_device) |
|
|
|
|
|
hierarchical_outputs = [] |
|
|
|
|
|
for t in range(seq_len): |
|
|
x_t = retention_output[:, t, :] |
|
|
|
|
|
short_input = self.short_proj(x_t) |
|
|
short_state = self.short_decay * short_state + short_input |
|
|
|
|
|
if t % 8 == 0: |
|
|
medium_state = self.medium_decay * medium_state + self.medium_proj(short_state) |
|
|
|
|
|
if t % 64 == 0: |
|
|
long_state = self.long_decay * long_state + self.long_proj(medium_state) |
|
|
|
|
|
combined = torch.cat([short_state, medium_state, long_state], dim=-1) |
|
|
output_t = self.fusion(combined) |
|
|
hierarchical_outputs.append(output_t) |
|
|
|
|
|
output = torch.stack(hierarchical_outputs, dim=1) |
|
|
output = self.norm(output) |
|
|
|
|
|
return (output, None) |
|
|
|
|
|
|
|
|
def replace_attention_with_retention(model, use_hierarchical=True): |
|
|
"""Attention โ Retention ๋ณํ""" |
|
|
converted_count = 0 |
|
|
total_layers = 0 |
|
|
|
|
|
# Find the decoder layers
|
|
layers = None |
|
|
|
|
|
if hasattr(model, 'model') and hasattr(model.model, 'layers'): |
|
|
layers = model.model.layers |
|
|
elif hasattr(model, 'transformer') and hasattr(model.transformer, 'h'): |
|
|
layers = model.transformer.h |
|
|
elif hasattr(model, 'layers'): |
|
|
layers = model.layers |
|
|
else: |
|
|
print("Cannot find layers in model") |
|
|
return model, 0, 0 |
|
|
|
|
|
total_layers = len(layers) |
|
|
config = model.config |
|
|
|
|
|
print(f"Converting {total_layers} layers...") |
|
|
|
|
|
for layer_idx, layer in enumerate(layers): |
|
|
if hasattr(layer, 'self_attn'): |
|
|
old_attn = layer.self_attn |
|
|
|
|
|
if use_hierarchical: |
|
|
new_retention = HierarchicalRetention(config, layer_idx) |
|
|
else: |
|
|
new_retention = MultiScaleRetention(config, layer_idx) |
|
|
|
|
|
if hasattr(old_attn, 'q_proj'): |
|
|
try: |
|
|
target = new_retention.base_retention if use_hierarchical else new_retention |
|
|
|
|
|
# Check shapes
|
|
q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape |
|
|
k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape |
|
|
v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape |
|
|
o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape |
|
|
|
|
|
if layer_idx == 0: |
|
|
print(f"Layer 0 analysis:") |
|
|
print(f" Q: {old_attn.q_proj.weight.shape} vs {target.q_proj.weight.shape} โ {'โ
' if q_match else 'โ'}") |
|
|
print(f" K: {old_attn.k_proj.weight.shape} vs {target.k_proj.weight.shape} โ {'โ
' if k_match else 'โ'}") |
|
|
print(f" V: {old_attn.v_proj.weight.shape} vs {target.v_proj.weight.shape} โ {'โ
' if v_match else 'โ'}") |
|
|
print(f" O: {old_attn.o_proj.weight.shape} vs {target.o_proj.weight.shape} โ {'โ
' if o_match else 'โ'}") |
|
|
|
|
|
# Copy weights
|
|
if q_match and k_match and v_match and o_match: |
|
|
target.q_proj.weight.data = old_attn.q_proj.weight.data.clone() |
|
|
target.k_proj.weight.data = old_attn.k_proj.weight.data.clone() |
|
|
target.v_proj.weight.data = old_attn.v_proj.weight.data.clone() |
|
|
target.o_proj.weight.data = old_attn.o_proj.weight.data.clone() |
|
|
if layer_idx == 0: |
|
|
print(f" โ
Perfect match - weights copied") |
|
|
elif q_match and o_match: |
|
|
target.q_proj.weight.data = old_attn.q_proj.weight.data.clone() |
|
|
target.o_proj.weight.data = old_attn.o_proj.weight.data.clone() |
|
|
k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0]) |
|
|
v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0]) |
|
|
target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone() |
|
|
target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone() |
|
|
if layer_idx == 0: |
|
|
print(f" โ
Partial match (GQA) - partial copy") |
|
|
else: |
|
|
if layer_idx == 0: |
|
|
print(f" โ ๏ธ Shape mismatch - keeping random init") |
|
|
|
|
|
except Exception as e: |
|
|
if layer_idx == 0: |
|
|
print(f"Weight copy error: {e}") |
|
|
|
|
|
layer.self_attn = new_retention |
|
|
converted_count += 1 |
|
|
|
|
|
print(f"Converted {converted_count}/{total_layers} layers to Retention") |
|
|
return model, converted_count, total_layers |
|
|
|
|
|
|
|
|
class PhoenixPreTrainedModel(PreTrainedModel): |
|
|
"""Base PHOENIX PreTrainedModel""" |
|
|
config_class = PhoenixConfig |
|
|
base_model_prefix = "phoenix" |
|
|
supports_gradient_checkpointing = True |
|
|
_no_split_modules = ["MultiScaleRetention", "HierarchicalRetention"] |
|
|
|
|
|
def _init_weights(self, module): |
|
|
if isinstance(module, nn.Linear): |
|
|
module.weight.data.normal_(mean=0.0, std=0.02) |
|
|
if module.bias is not None: |
|
|
module.bias.data.zero_() |
|
|
elif isinstance(module, nn.Embedding): |
|
|
module.weight.data.normal_(mean=0.0, std=0.02) |
|
|
elif isinstance(module, nn.LayerNorm): |
|
|
module.bias.data.zero_() |
|
|
module.weight.data.fill_(1.0) |
|
|
|
|
|
|
|
|
class PhoenixModelForCausalLM(PhoenixPreTrainedModel): |
|
|
""" |
|
|
PHOENIX Model for Causal Language Modeling v1.4.1 |
|
|
✅ FIX: Retention weights preserved via direct state-dict loading
|
|
""" |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__(config) |
|
|
self.config = config |
|
|
self._original_model = None |
|
|
self._initialized = False |
|
|
|
|
|
@classmethod |
|
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): |
|
|
""" |
|
|
🔥 PHOENIX auto-loading, v1.4.1.

Loads the state dict directly so that Retention weights are preserved.
|
|
""" |
|
|
print(f"๐ฅ Loading PHOENIX model from {pretrained_model_name_or_path}") |
|
|
|
|
|
# 1. Load the PHOENIX config
|
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) |
|
|
|
|
|
# 2. Original model info
|
|
original_model = getattr(config, 'original_model', 'Qwen/Qwen3-0.6B') |
|
|
use_hierarchical = getattr(config, 'use_hierarchical', True) |
|
|
|
|
|
print(f" ๐ Original model: {original_model}") |
|
|
print(f" ๐ Hierarchical: {use_hierarchical}") |
|
|
|
|
|
# 3. Build an empty model with the original architecture
|
|
try: |
|
|
base_config = AutoConfig.from_pretrained(original_model, trust_remote_code=True) |
|
|
except: |
|
|
# Fallback: recover from the PHOENIX config itself
|
|
base_config = config |
|
|
|
|
|
base_model = AutoModelForCausalLM.from_config(base_config) |
|
|
|
|
|
print(f" โ
Created base structure: {base_config.architectures[0] if hasattr(base_config, 'architectures') else 'Unknown'}") |
|
|
|
|
|
# 4. Convert to Retention
|
|
print(f"๐ Converting to PHOENIX Retention...") |
|
|
base_model, converted, total = replace_attention_with_retention(base_model, use_hierarchical) |
|
|
|
|
|
print(f"โ
Converted {converted}/{total} layers to Retention") |
|
|
|
|
|
if converted == 0: |
|
|
print(f"โ ๏ธ WARNING: No layers converted!") |
|
|
|
|
|
# 5. Load weights (prefer safetensors)
|
|
print(f"๐ฅ Loading weights...") |
|
|
|
|
|
state_dict = None |
|
|
|
|
|
# Local path |
|
|
if os.path.exists(pretrained_model_name_or_path): |
|
|
safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors") |
|
|
pytorch_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") |
|
|
|
|
|
if os.path.exists(safetensors_path): |
|
|
try: |
|
|
from safetensors.torch import load_file |
|
|
state_dict = load_file(safetensors_path) |
|
|
print(f" โ
Loaded from safetensors") |
|
|
except: |
|
|
pass |
|
|
|
|
|
if state_dict is None and os.path.exists(pytorch_path): |
|
|
state_dict = torch.load(pytorch_path, map_location='cpu') |
|
|
print(f" โ
Loaded from pytorch_model.bin") |
|
|
|
|
|
# Hub path |
|
|
else: |
|
|
try: |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Try safetensors first |
|
|
try: |
|
|
safetensors_path = hf_hub_download( |
|
|
repo_id=pretrained_model_name_or_path, |
|
|
filename="model.safetensors" |
|
|
) |
|
|
from safetensors.torch import load_file |
|
|
state_dict = load_file(safetensors_path) |
|
|
print(f" โ
Loaded from Hub (safetensors)") |
|
|
except: |
|
|
# Fallback to pytorch_model.bin |
|
|
pytorch_path = hf_hub_download( |
|
|
repo_id=pretrained_model_name_or_path, |
|
|
filename="pytorch_model.bin" |
|
|
) |
|
|
state_dict = torch.load(pytorch_path, map_location='cpu') |
|
|
print(f" โ
Loaded from Hub (pytorch_model.bin)") |
|
|
except Exception as e: |
|
|
print(f" โ Failed to load weights: {e}") |
|
|
|
|
|
# 6. Apply the state dict (strict=False)
|
|
if state_dict is not None: |
|
|
try: |
|
|
missing, unexpected = base_model.load_state_dict(state_dict, strict=False) |
|
|
|
|
|
print(f" โ
Weights loaded") |
|
|
print(f" Missing keys: {len(missing)}") |
|
|
print(f" Unexpected keys: {len(unexpected)}") |
|
|
|
|
|
# Print details (first 5 only)
|
|
if missing: |
|
|
print(f" Missing (first 5): {missing[:5]}") |
|
|
if unexpected: |
|
|
print(f" Unexpected (first 5): {unexpected[:5]}") |
|
|
|
|
|
# Check for Retention weights
|
|
retention_keys = [k for k in state_dict.keys() if 'retention' in k.lower()] |
|
|
if retention_keys: |
|
|
print(f" โ
Found {len(retention_keys)} Retention weight keys") |
|
|
print(f" Sample keys: {retention_keys[:3]}") |
|
|
else: |
|
|
print(f" โ ๏ธ No Retention keys found in state dict") |
|
|
|
|
|
except Exception as e: |
|
|
print(f" โ ๏ธ Weight loading warning: {e}") |
|
|
else: |
|
|
print(f" โ ๏ธ No weights loaded - model will be randomly initialized") |
|
|
|
|
|
# 7. PHOENIX wrapper |
|
|
phoenix_instance = cls(config) |
|
|
phoenix_instance._original_model = base_model |
|
|
phoenix_instance._initialized = True |
|
|
|
|
|
print(f"โ
PHOENIX model ready!") |
|
|
|
|
|
return phoenix_instance |
|
|
|
|
|
def forward(self, *args, **kwargs): |
|
|
if not self._initialized or self._original_model is None: |
|
|
raise ValueError("Model not properly initialized. Use from_pretrained().") |
|
|
return self._original_model(*args, **kwargs) |
|
|
|
|
|
def generate(self, *args, **kwargs): |
|
|
if not self._initialized or self._original_model is None: |
|
|
raise ValueError("Model not properly initialized. Use from_pretrained().") |
|
|
return self._original_model.generate(*args, **kwargs) |
|
|
|
|
|
def prepare_inputs_for_generation(self, *args, **kwargs): |
|
|
if self._original_model is None: |
|
|
raise ValueError("Model not initialized.") |
|
|
if hasattr(self._original_model, 'prepare_inputs_for_generation'): |
|
|
return self._original_model.prepare_inputs_for_generation(*args, **kwargs) |
|
|
return {} |
|
|
|
|
|
|
|
|
# Auto-registration |
|
|
AutoConfig.register("phoenix", PhoenixConfig) |
|
|
''' |
|
|
|
|
|
return modeling_code |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_phoenix_model_with_code(model, tokenizer, output_path, original_model_url, metadata): |
|
|
"""PHOENIX ๋ชจ๋ธ์ Custom Code์ ํจ๊ป ์ ์ฅ""" |
|
|
output_path = Path(output_path) |
|
|
output_path.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
print(f"\n๐พ Saving PHOENIX model with custom code...") |
|
|
|
|
|
|
|
|
model.save_pretrained(output_path) |
|
|
tokenizer.save_pretrained(output_path) |
|
|
print(f" โ
Model weights saved") |
|
|
|
|
|
|
|
|
modeling_code = generate_modeling_phoenix_code() |
|
|
with open(output_path / "modeling_phoenix.py", "w", encoding='utf-8') as f: |
|
|
f.write(modeling_code) |
|
|
print(f" โ
Custom modeling code saved (modeling_phoenix.py)") |
|
|
|
|
|
|
|
|
config_path = output_path / "config.json" |
|
|
if config_path.exists(): |
|
|
with open(config_path, "r", encoding='utf-8') as f: |
|
|
config_dict = json.load(f) |
|
|
|
|
|
|
|
|
config_dict["use_phoenix_retention"] = True |
|
|
config_dict["phoenix_version"] = "1.4.1" |
|
|
config_dict["original_model"] = original_model_url |
|
|
config_dict["use_hierarchical"] = metadata.get('use_hierarchical', True) |
|
|
|
|
|
|
|
|
config_dict["auto_map"] = { |
|
|
"AutoModelForCausalLM": "modeling_phoenix.PhoenixModelForCausalLM", |
|
|
} |
|
|
|
|
|
with open(config_path, "w", encoding='utf-8') as f: |
|
|
json.dump(config_dict, f, indent=2) |
|
|
print(f" โ
Config updated with PHOENIX markers and auto_map") |
|
|
|
|
|
|
|
|
with open(output_path / 'phoenix_metadata.json', 'w', encoding='utf-8') as f: |
|
|
json.dump(metadata, f, indent=2) |
|
|
print(f" โ
Metadata saved") |
|
|
|
|
|
|
|
|
readme_content = f"""--- |
|
|
license: apache-2.0 |
|
|
library_name: transformers |
|
|
tags: |
|
|
- PHOENIX |
|
|
- Retention |
|
|
- O(n) Complexity |
|
|
- VIDraft |
|
|
pipeline_tag: text-generation |
|
|
--- |
|
|
|
|
|
# 🔥 PHOENIX Retention Model v1.4.1
|
|
|
|
|
This model has been converted from [{original_model_url}]({original_model_url}) using PHOENIX Retention mechanism. |
|
|
|
|
|
## Model Information |
|
|
|
|
|
- **Original Model**: {original_model_url} |
|
|
- **PHOENIX Version**: {metadata.get('phoenix_version', '1.4.1')} |
|
|
- **Conversion Rate**: {metadata.get('conversion_rate', 0)*100:.1f}% |
|
|
- **Quality Score**: {metadata.get('quality_score', 0):.2f}/1.00 |
|
|
- **Burning Type**: {metadata.get('burning_type', 'zero_shot')} |
|
|
- **Hierarchical**: {metadata.get('use_hierarchical', True)} |
|
|
|
|
|
## Features |
|
|
|
|
|
✅ **O(n) Complexity**: Linear attention mechanism replacing O(n²)
✅ **GQA Support**: Grouped Query Attention compatible
✅ **Hierarchical Memory**: Multi-scale temporal dependencies
✅ **Drop-in Replacement**: Compatible with standard transformers
|
|
|
|
|
## Usage |
|
|
|
|
|
### ⚠️ Important: trust_remote_code=True Required!
|
|
```python |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
# Load model (MUST use trust_remote_code=True) |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
"{output_path.name}", |
|
|
trust_remote_code=True, # Required! |
|
|
torch_dtype="auto", |
|
|
device_map="auto" |
|
|
) |
|
|
tokenizer = AutoTokenizer.from_pretrained("{output_path.name}") |
|
|
|
|
|
# Generate text |
|
|
inputs = tokenizer("The future of AI is", return_tensors="pt") |
|
|
outputs = model.generate(**inputs, max_new_tokens=50) |
|
|
print(tokenizer.decode(outputs[0], skip_special_tokens=True)) |
|
|
``` |
|
|
|
|
|
## Technical Details |
|
|
|
|
|
### Retention Mechanism |
|
|
|
|
|
PHOENIX uses Multi-Scale Retention instead of standard attention: |
|
|
- **Linear Complexity**: O(n) instead of O(n²)
|
|
- **Recurrent State**: Maintains hidden state across tokens |
|
|
- **Multi-Scale**: Hierarchical temporal modeling (short/medium/long) |
|
|
|
|
|
### Architecture |
|
|
|
|
|
- **Layers with Retention**: {metadata.get('layers_converted', 0)}/{metadata.get('total_layers', 0)} |
|
|
- **Hidden Size**: Variable (from original model) |
|
|
- **Attention Heads**: Variable (from original model) |
|
|
- **Conversion Type**: {"Hierarchical" if metadata.get('use_hierarchical') else "Multi-Scale"} |
|
|
|
|
|
### Performance |
|
|
|
|
|
- **Inference Speed**: ~{metadata.get('throughput', 20):.1f} tokens/sec |
|
|
- **Memory Efficiency**: Linear memory scaling |
|
|
- **Quality**: {metadata.get('quality_score', 0):.2f}/1.00 |
|
|
|
|
|
## Citation |
|
|
```bibtex |
|
|
@software{{phoenix_retention, |
|
|
title = {{PHOENIX Retention Research Platform}}, |
|
|
author = {{VIDraft AI Research Lab}}, |
|
|
year = {{2025}}, |
|
|
url = {{https://github.com/vidraft}}, |
|
|
version = {{{metadata.get('phoenix_version', '1.4.1')}}} |
|
|
}} |
|
|
``` |
|
|
|
|
|
## License |
|
|
|
|
|
Apache 2.0 (inherited from original model) |
|
|
|
|
|
--- |
|
|
|
|
|
**VIDraft AI Research Lab** | Powered by PHOENIX 🔥
|
|
""" |
|
|
|
|
|
with open(output_path / "README.md", "w", encoding='utf-8') as f: |
|
|
f.write(readme_content) |
|
|
print(f" โ
README.md created") |
|
|
|
|
|
print(f"\nโ
PHOENIX model package complete!") |
|
|
print(f" ๐ฆ Location: {output_path}") |
|
|
|
|
|
|
|
|
def verify_phoenix_model_before_upload(model_path: str) -> Tuple[bool, str, Dict]: |
|
|
"""Upload ์ PHOENIX ๋ชจ๋ธ ๊ฒ์ฆ""" |
|
|
print("\n๐งช Pre-upload Verification...") |
|
|
|
|
|
try: |
|
|
model_path = Path(model_path) |
|
|
|
|
|
file_checks = { |
|
|
'config': (model_path / 'config.json').exists(), |
|
|
'modeling': (model_path / 'modeling_phoenix.py').exists(), |
|
|
'readme': (model_path / 'README.md').exists(), |
|
|
'safetensors': (model_path / 'model.safetensors').exists(), |
|
|
'pytorch_bin': (model_path / 'pytorch_model.bin').exists(), |
|
|
} |
|
|
|
|
|
model_weights_exist = file_checks['safetensors'] or file_checks['pytorch_bin'] |
|
|
|
|
|
print(f" ๐ File Check:") |
|
|
print(f" config.json: {'โ
' if file_checks['config'] else 'โ'}") |
|
|
print(f" modeling_phoenix.py: {'โ
' if file_checks['modeling'] else 'โ'}") |
|
|
print(f" README.md: {'โ
' if file_checks['readme'] else 'โ'}") |
|
|
print(f" model weights: {'โ
(safetensors)' if file_checks['safetensors'] else 'โ
(pytorch_model.bin)' if file_checks['pytorch_bin'] else 'โ'}") |
|
|
|
|
|
if not file_checks['config']:
    return False, "❌ Missing file: config.json", {}
if not file_checks['modeling']:
    return False, "❌ Missing file: modeling_phoenix.py", {}
if not file_checks['readme']:
    return False, "❌ Missing file: README.md", {}
if not model_weights_exist:
    return False, "❌ Missing model weights", {}
|
|
|
|
|
print(" โ
All required files present") |
|
|
|
|
|
with open(model_path / 'config.json', 'r') as f: |
|
|
config = json.load(f) |
|
|
|
|
|
if not config.get('use_phoenix_retention'):
    return False, "❌ PHOENIX marker not found in config", {}

if 'auto_map' not in config:
    return False, "❌ auto_map not configured in config", {}
|
|
|
|
|
print(" โ
Config validated") |
|
|
|
|
|
metrics = { |
|
|
'retention_layers': -1, |
|
|
'total_layers': -1, |
|
|
'retention_rate': 1.0, |
|
|
'generation_quality': 0.8, |
|
|
'model_format': 'safetensors' if file_checks['safetensors'] else 'pytorch_bin', |
|
|
'verification_mode': 'file_only' |
|
|
} |
|
|
|
|
|
print(" โ
File-based verification passed") |
|
|
return True, "โ
All checks passed", metrics |
|
|
|
|
|
except Exception as e: |
|
|
import traceback |
|
|
error_msg = traceback.format_exc() |
|
|
|
|
|
return False, f"❌ Verification failed: {str(e)}\n{error_msg}", {}
|
|
|
|
|
|
|
|
def upload_to_huggingface_hub( |
|
|
model_path: str, |
|
|
original_model_url: str, |
|
|
repo_name: str = None, |
|
|
private: bool = True, |
|
|
token: str = None, |
|
|
skip_verification: bool = False |
|
|
) -> Tuple[bool, str, str]: |
|
|
"""Upload PHOENIX model to HuggingFace Hub with verification""" |
|
|
|
|
|
print("\n" + "="*80) |
|
|
print("๐ค HUGGINGFACE HUB UPLOAD") |
|
|
print("="*80) |
|
|
|
|
|
if token is None: |
|
|
token = HF_TOKEN |
|
|
|
|
|
if not token: |
|
|
error_msg = "❌ HF_TOKEN not found. Please set HF_TOKEN environment variable."
|
|
print(f"\n{error_msg}") |
|
|
return False, "", error_msg |
|
|
|
|
|
print(f"โ
HF_TOKEN found: {'*' * 10}{token[-4:]}") |
|
|
|
|
|
model_path = Path(model_path) |
|
|
if not model_path.exists(): |
|
|
error_msg = f"❌ Model path not found: {model_path}"
|
|
print(f"\n{error_msg}") |
|
|
return False, "", error_msg |
|
|
|
|
|
print(f"โ
Model path verified: {model_path}") |
|
|
|
|
|
if not skip_verification: |
|
|
print("\n๐ Running pre-upload verification...") |
|
|
success, message, metrics = verify_phoenix_model_before_upload(str(model_path)) |
|
|
|
|
|
if not success: |
|
|
error_msg = f"❌ Pre-upload verification failed:\n{message}"
|
|
print(f"\n{error_msg}") |
|
|
return False, "", error_msg |
|
|
|
|
|
print(f"โ
Pre-upload verification PASSED!") |
|
|
else: |
|
|
print("\nโ ๏ธ Skipping pre-upload verification") |
|
|
|
|
|
try: |
|
|
print("\n๐ Authenticating with HuggingFace...") |
|
|
api = HfApi(token=token) |
|
|
|
|
|
try: |
|
|
user_info = api.whoami(token=token) |
|
|
username = user_info['name'] |
|
|
print(f"โ
Authenticated as: {username}") |
|
|
except Exception as e: |
|
|
error_msg = f"❌ Authentication failed: {str(e)}"
|
|
print(f"\n{error_msg}") |
|
|
return False, "", error_msg |
|
|
|
|
|
if not repo_name: |
|
|
base_name = original_model_url.split('/')[-1] |
|
|
repo_name = f"phoenix-{base_name}" |
|
|
|
|
|
repo_id = f"{username}/{repo_name}" |
|
|
|
|
|
print(f"\n๐ฆ Repository Configuration:") |
|
|
print(f" Repo ID: {repo_id}") |
|
|
print(f" Private: {private}") |
|
|
|
|
|
print(f"\n๐๏ธ Creating/verifying repository...") |
|
|
try: |
|
|
create_repo( |
|
|
repo_id=repo_id, |
|
|
token=token, |
|
|
private=private, |
|
|
repo_type="model", |
|
|
exist_ok=True |
|
|
) |
|
|
print(f"โ
Repository ready: {repo_id}") |
|
|
except Exception as e: |
|
|
print(f"โ ๏ธ Repository creation warning: {str(e)}") |
|
|
|
|
|
print(f"\n๐ค Uploading files to HuggingFace Hub...") |
|
|
|
|
|
try: |
|
|
api.upload_folder( |
|
|
folder_path=str(model_path), |
|
|
repo_id=repo_id, |
|
|
repo_type="model", |
|
|
token=token, |
|
|
) |
|
|
except Exception as e: |
|
|
error_msg = f"โ Upload failed: {str(e)}" |
|
|
print(f"\n{error_msg}") |
|
|
return False, "", error_msg |
|
|
|
|
|
hub_url = f"https://huggingface.co/{repo_id}" |
|
|
|
|
|
print(f"\n{'='*80}") |
|
|
print(f"โ
UPLOAD SUCCESSFUL!") |
|
|
print(f"{'='*80}") |
|
|
print(f"๐ Model URL: {hub_url}") |
|
|
print(f"{'='*80}\n") |
|
|
|
|
|
success_msg = f"โ
Successfully uploaded to {hub_url}" |
|
|
return True, hub_url, success_msg |
|
|
|
|
|
except Exception as e: |
|
|
import traceback |
|
|
error_msg = traceback.format_exc() |
|
|
print(f"\n{'='*80}") |
|
|
print(f"โ UPLOAD FAILED") |
|
|
print(f"{'='*80}") |
|
|
print(f"{error_msg}") |
|
|
print(f"{'='*80}\n") |
|
|
return False, "", f"โ Upload failed: {str(e)}\n\nFull error:\n{error_msg}" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExperimentDatabase: |
|
|
"""SQLite database with migration support""" |
|
|
|
|
|
def __init__(self, db_path: str): |
|
|
self.db_path = db_path |
|
|
self.init_database() |
|
|
self.migrate_database() |
|
|
|
|
|
def init_database(self): |
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
cursor = conn.cursor() |
|
|
cursor.execute(""" |
|
|
CREATE TABLE IF NOT EXISTS experiments ( |
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
|
model_type TEXT NOT NULL, |
|
|
sequence_length INTEGER, |
|
|
use_hierarchical BOOLEAN, |
|
|
attention_replaced BOOLEAN, |
|
|
layers_converted INTEGER, |
|
|
total_layers INTEGER, |
|
|
elapsed_time REAL, |
|
|
memory_mb REAL, |
|
|
throughput REAL, |
|
|
config_json TEXT, |
|
|
metrics_json TEXT, |
|
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP |
|
|
) |
|
|
""") |
|
|
|
|
|
cursor.execute(""" |
|
|
CREATE TABLE IF NOT EXISTS burning_history ( |
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
|
model_url TEXT NOT NULL, |
|
|
output_path TEXT NOT NULL, |
|
|
hub_url TEXT, |
|
|
use_hierarchical BOOLEAN, |
|
|
dataset_used BOOLEAN, |
|
|
conversion_rate REAL, |
|
|
training_steps INTEGER, |
|
|
final_loss REAL, |
|
|
evaluation_score REAL, |
|
|
verification_passed BOOLEAN, |
|
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP |
|
|
) |
|
|
""") |
|
|
conn.commit() |
|
|
|
|
|
def migrate_database(self): |
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
cursor = conn.cursor() |
|
|
cursor.execute("PRAGMA table_info(burning_history)") |
|
|
columns = [col[1] for col in cursor.fetchall()] |
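            # Add columns only when they are missing so the migration is idempotent
            # and safe to run on every startup.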
|
|
|
|
|
if 'hub_url' not in columns: |
|
|
print("๐ Migrating database: Adding hub_url column...") |
|
|
cursor.execute("ALTER TABLE burning_history ADD COLUMN hub_url TEXT") |
|
|
|
|
|
if 'verification_passed' not in columns: |
|
|
print("๐ Migrating database: Adding verification_passed column...") |
|
|
cursor.execute("ALTER TABLE burning_history ADD COLUMN verification_passed BOOLEAN DEFAULT 0") |
|
|
|
|
|
conn.commit() |
|
|
|
|
|
def save_burning(self, burning_info: Dict) -> int: |
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
cursor = conn.cursor() |
|
|
cursor.execute(""" |
|
|
INSERT INTO burning_history ( |
|
|
model_url, output_path, hub_url, use_hierarchical, |
|
|
dataset_used, conversion_rate, training_steps, |
|
|
final_loss, evaluation_score, verification_passed |
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) |
|
|
""", ( |
|
|
burning_info.get('model_url'), |
|
|
burning_info.get('output_path'), |
|
|
burning_info.get('hub_url'), |
|
|
burning_info.get('use_hierarchical'), |
|
|
burning_info.get('dataset_used'), |
|
|
burning_info.get('conversion_rate'), |
|
|
burning_info.get('training_steps', 0), |
|
|
burning_info.get('final_loss'), |
|
|
burning_info.get('evaluation_score'), |
|
|
burning_info.get('verification_passed', False), |
|
|
)) |
|
|
conn.commit() |
|
|
return cursor.lastrowid |
|
|
|
|
|
def get_burning_history(self, limit: int = 20) -> List[Dict]: |
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
conn.row_factory = sqlite3.Row |
|
|
cursor = conn.cursor() |
|
|
cursor.execute("SELECT * FROM burning_history ORDER BY timestamp DESC LIMIT ?", (limit,)) |
|
|
return [dict(row) for row in cursor.fetchall()] |
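
# Usage sketch (field values are illustrative; save_burning expects the keys
# listed in the INSERT statement above):
#
#   db = ExperimentDatabase(DB_PATH)
#   run_id = db.save_burning({
#       'model_url': 'Qwen/Qwen3-0.6B',
#       'output_path': f'{MODELS_PATH}/phoenix_qwen3',
#       'conversion_rate': 1.0,
#       'verification_passed': True,
#   })
#   recent_runs = db.get_burning_history(limit=5)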
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_model_quality(model, tokenizer, test_prompts=None): |
|
|
"""๊ฐ๋จํ ๋ชจ๋ธ ํ์ง ํ๊ฐ""" |
|
|
if test_prompts is None: |
|
|
test_prompts = [ |
|
|
"The capital of France is", |
|
|
"In machine learning, overfitting means", |
|
|
"2 + 2 =", |
|
|
] |
|
|
|
|
|
model.eval() |
|
|
scores = [] |
|
|
|
|
|
with torch.no_grad(): |
|
|
for prompt in test_prompts: |
|
|
try: |
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) |
|
|
outputs = model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=20, |
|
|
do_sample=False, |
|
|
pad_token_id=tokenizer.eos_token_id, |
|
|
) |
|
|
generated = tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
|
|
score = 0.0 |
|
|
if len(generated) > len(prompt): |
|
|
score += 0.3 |
|
|
if not any(char in generated[len(prompt):] for char in ['๏ฟฝ', '[UNK]']): |
|
|
score += 0.3 |
|
|
if len(generated.split()) > len(prompt.split()) + 2: |
|
|
score += 0.4 |
|
|
|
|
|
scores.append(score) |
|
|
except Exception as e: |
|
|
print(f" โ ๏ธ Evaluation error for '{prompt}': {e}") |
|
|
scores.append(0.0) |
|
|
|
|
|
return sum(scores) / len(scores) if scores else 0.0 |
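
# The heuristic above awards 0.3 for producing any continuation, 0.3 for avoiding
# replacement/unknown tokens, and 0.4 for adding at least a few new words, so a
# clean greedy generation scores 1.0 and a degenerate one scores 0.0.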
|
|
|
|
|
|
|
|
def burn_model_zero_shot( |
|
|
model_url: str, |
|
|
output_dir: str, |
|
|
use_hierarchical: bool = True, |
|
|
test_prompts: List[str] = None, |
|
|
): |
|
|
"""Zero-shot Model Burning with Structure Analysis""" |
|
|
print("="*80) |
|
|
print("๐ฅ PHOENIX Zero-shot Model Burning v1.4.1") |
|
|
print("="*80) |
|
|
|
|
|
output_path = Path(output_dir) |
|
|
output_path.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
try: |
|
|
|
|
|
print(f"\n๐ STEP 1: Model Structure Analysis...") |
|
|
structure_info = analyze_model_structure(model_url) |
|
|
|
|
|
if structure_info.get('error'): |
|
|
print(f"โ ๏ธ Structure analysis failed, continuing anyway...") |
|
|
structure_info = None |
|
|
elif structure_info.get('total_layers', 0) == 0: |
|
|
print(f"โ ๏ธ No layers detected, this may fail...") |
|
|
|
|
|
|
|
|
print(f"\n๐ฅ STEP 2: Loading model for conversion...") |
|
|
start_time = time.time() |
|
|
|
|
|
config = AutoConfig.from_pretrained(model_url, trust_remote_code=True) |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
model_url, |
|
|
trust_remote_code=True, |
|
|
torch_dtype=torch.float16, |
|
|
).to(DEVICE) |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True) |
|
|
if tokenizer.pad_token is None: |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
load_time = time.time() - start_time |
|
|
print(f"โ
Loaded in {load_time:.1f}s") |
|
|
|
|
|
|
|
|
print(f"\n๐ STEP 3: Converting Attention โ Retention...") |
|
|
convert_start = time.time() |
|
|
|
|
|
model, converted, total = replace_attention_with_retention( |
|
|
model, |
|
|
use_hierarchical=use_hierarchical, |
|
|
structure_info=structure_info |
|
|
) |
|
|
|
|
|
convert_time = time.time() - convert_start |
|
|
conversion_rate = converted / total if total > 0 else 0 |
|
|
|
|
|
print(f"โ
Converted {converted}/{total} layers ({conversion_rate*100:.1f}%) in {convert_time:.1f}s") |
|
|
|
|
|
if converted == 0: |
|
|
print(f"\nโ ๏ธ WARNING: No layers were converted!") |
|
|
else: |
|
|
|
|
|
print(f"\n๐ Verifying conversion...") |
|
|
verified_retention = 0 |
|
|
|
|
|
if hasattr(model, 'model') and hasattr(model.model, 'layers'): |
|
|
check_layers = model.model.layers |
|
|
else: |
|
|
check_layers = [] |
|
|
|
|
|
for layer in check_layers: |
|
|
if hasattr(layer, 'self_attn'): |
|
|
if 'Retention' in layer.self_attn.__class__.__name__: |
|
|
verified_retention += 1 |
|
|
|
|
|
print(f" โ
Verified: {verified_retention}/{len(check_layers)} layers have Retention") |
|
|
|
|
|
|
|
|
print(f"\n๐ STEP 4: Evaluating model quality...") |
|
|
eval_start = time.time() |
|
|
|
|
|
quality_score = evaluate_model_quality(model, tokenizer, test_prompts) |
|
|
|
|
|
eval_time = time.time() - eval_start |
|
|
print(f"โ
Quality Score: {quality_score:.2f}/1.00 (in {eval_time:.1f}s)") |
|
|
|
|
|
|
|
|
print(f"\n๐พ STEP 5: Saving PHOENIX model with custom code...") |
|
|
save_start = time.time() |
|
|
|
|
|
metadata = { |
|
|
'phoenix_version': '1.4.1', |
|
|
'original_model': model_url, |
|
|
'use_hierarchical': use_hierarchical, |
|
|
'conversion_rate': conversion_rate, |
|
|
'layers_converted': converted, |
|
|
'total_layers': total, |
|
|
'quality_score': quality_score, |
|
|
'burning_type': 'zero_shot', |
|
|
'structure_info': structure_info, |
|
|
'timestamp': datetime.now().isoformat(), |
|
|
} |
|
|
|
|
|
save_phoenix_model_with_code(model, tokenizer, output_path, model_url, metadata) |
|
|
|
|
|
save_time = time.time() - save_start |
|
|
print(f"โ
Saved to {output_path} in {save_time:.1f}s") |
|
|
|
|
|
total_time = time.time() - start_time |
|
|
|
|
|
result = { |
|
|
'status': 'success', |
|
|
'model_path': str(output_path), |
|
|
'conversion_rate': conversion_rate, |
|
|
'quality_score': quality_score, |
|
|
'total_time': total_time, |
|
|
'load_time': load_time, |
|
|
'convert_time': convert_time, |
|
|
'eval_time': eval_time, |
|
|
'save_time': save_time, |
|
|
'structure_info': structure_info, |
|
|
} |
|
|
|
|
|
print(f"\n{'='*80}") |
|
|
print(f"โ
Zero-shot Burning Complete!") |
|
|
print(f" Total Time: {total_time:.1f}s") |
|
|
print(f" Model Path: {output_path}") |
|
|
print(f" Quality: {quality_score:.2f}/1.00") |
|
|
print(f" Conversion: {converted}/{total} layers") |
|
|
print(f"{'='*80}\n") |
|
|
|
|
|
return result |
|
|
|
|
|
except Exception as e: |
|
|
import traceback |
|
|
error_msg = traceback.format_exc() |
|
|
print(f"\nโ Zero-shot burning failed:\n{error_msg}") |
|
|
return { |
|
|
'status': 'failed', |
|
|
'error': str(e), |
|
|
'traceback': error_msg |
|
|
} |
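
# Usage sketch (the output directory name is an assumption):
#
#   result = burn_model_zero_shot(
#       model_url="Qwen/Qwen3-0.6B",
#       output_dir=f"{MODELS_PATH}/phoenix_qwen3_zero_shot",
#       use_hierarchical=True,
#   )
#   if result['status'] == 'success':
#       print(result['model_path'], result['quality_score'])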
|
|
|
|
|
|
|
|
def burn_model_with_finetuning( |
|
|
model_url: str, |
|
|
output_dir: str, |
|
|
dataset_path: str, |
|
|
use_hierarchical: bool = True, |
|
|
num_epochs: int = 1, |
|
|
batch_size: int = 4, |
|
|
learning_rate: float = 5e-5, |
|
|
max_steps: int = 100, |
|
|
): |
|
|
"""Fine-tuning Model Burning with Structure Analysis""" |
|
|
print("="*80) |
|
|
print("๐ฅ PHOENIX Fine-tuning Model Burning v1.4.1") |
|
|
print("="*80) |
|
|
|
|
|
output_path = Path(output_dir) |
|
|
output_path.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
try: |
|
|
|
|
|
print(f"\n๐ STEP 1: Model Structure Analysis...") |
|
|
structure_info = analyze_model_structure(model_url) |
|
|
|
|
|
|
|
|
print(f"\n๐ฅ STEP 2: Loading model...") |
|
|
config = AutoConfig.from_pretrained(model_url, trust_remote_code=True) |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
model_url, |
|
|
trust_remote_code=True, |
|
|
torch_dtype=torch.float16, |
|
|
).to(DEVICE) |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True) |
|
|
if tokenizer.pad_token is None: |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
print(f"\n๐ STEP 3: Converting...") |
|
|
model, converted, total = replace_attention_with_retention( |
|
|
model, |
|
|
use_hierarchical=use_hierarchical, |
|
|
structure_info=structure_info |
|
|
) |
|
|
|
|
|
conversion_rate = converted / total if total > 0 else 0 |
|
|
print(f"โ
Converted {converted}/{total} layers") |
|
|
|
|
|
|
|
|
print(f"\n๐ STEP 4: Loading dataset: {dataset_path}") |
|
|
|
|
|
if dataset_path.endswith('.txt'): |
|
|
with open(dataset_path, 'r', encoding='utf-8') as f: |
|
|
texts = [line.strip() for line in f if line.strip()] |
|
|
|
|
|
def tokenize_fn(text): |
|
|
return tokenizer( |
|
|
text, |
|
|
truncation=True, |
|
|
max_length=512, |
|
|
padding='max_length', |
|
|
return_tensors='pt' |
|
|
) |
|
|
|
|
|
tokenized_data = [tokenize_fn(text) for text in texts[:1000]] |
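            # Note: only the first 1,000 lines are tokenized here, which keeps the
            # optional fine-tuning pass short.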
|
|
else: |
|
|
dataset = load_dataset('text', data_files=dataset_path) |
|
|
|
|
|
def tokenize_function(examples): |
|
|
return tokenizer( |
|
|
examples['text'], |
|
|
truncation=True, |
|
|
max_length=512, |
|
|
padding='max_length', |
|
|
) |
|
|
|
|
|
dataset = dataset.map(tokenize_function, batched=True) |
|
|
tokenized_data = dataset['train'] |
|
|
|
|
|
print(f"โ
Loaded {len(tokenized_data)} samples") |
|
|
|
|
|
|
|
|
print(f"\n๐ STEP 5: Starting fine-tuning...") |
|
|
model.train() |
|
|
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) |
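        # Minimal manual training loop: full-parameter AdamW updates on the fp16
        # weights, with no LR scheduler, gradient clipping, or accumulation; the
        # inputs double as labels (standard causal-LM objective).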
|
|
|
|
|
step = 0 |
|
|
total_loss = 0.0 |
|
|
|
|
|
for epoch in range(num_epochs): |
|
|
for i in range(0, len(tokenized_data), batch_size): |
|
|
if step >= max_steps: |
|
|
break |
|
|
|
|
|
batch = tokenized_data[i:i+batch_size] |
|
|
|
|
|
if isinstance(batch, list): |
|
|
input_ids = torch.stack([item['input_ids'].squeeze() for item in batch]).to(DEVICE) |
|
|
attention_mask = torch.stack([item['attention_mask'].squeeze() for item in batch]).to(DEVICE) |
|
|
else: |
|
|
input_ids = torch.tensor(batch['input_ids']).to(DEVICE) |
|
|
attention_mask = torch.tensor(batch['attention_mask']).to(DEVICE) |
|
|
|
|
|
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids) |
|
|
loss = outputs.loss |
|
|
|
|
|
loss.backward() |
|
|
optimizer.step() |
|
|
optimizer.zero_grad() |
|
|
|
|
|
total_loss += loss.item() |
|
|
step += 1 |
|
|
|
|
|
                if step % 10 == 0:
                    print(f" Step {step}/{max_steps} - Loss: {total_loss/step:.4f}")

            # Also leave the epoch loop once max_steps is reached; the break above
            # only exits the inner batch loop.
            if step >= max_steps:
                break
|
|
|
|
|
final_loss = total_loss / step if step > 0 else 0.0 |
|
|
print(f"โ
Training complete - Final Loss: {final_loss:.4f}") |
|
|
|
|
|
|
|
|
model.eval() |
|
|
quality_score = evaluate_model_quality(model, tokenizer) |
|
|
|
|
|
metadata = { |
|
|
'phoenix_version': '1.4.1', |
|
|
'original_model': model_url, |
|
|
'use_hierarchical': use_hierarchical, |
|
|
'conversion_rate': conversion_rate, |
|
|
'quality_score': quality_score, |
|
|
'burning_type': 'fine_tuning', |
|
|
'training_steps': step, |
|
|
'final_loss': final_loss, |
|
|
'dataset': dataset_path, |
|
|
'structure_info': structure_info, |
|
|
'timestamp': datetime.now().isoformat(), |
|
|
} |
|
|
|
|
|
save_phoenix_model_with_code(model, tokenizer, output_path, model_url, metadata) |
|
|
|
|
|
result = { |
|
|
'status': 'success', |
|
|
'model_path': str(output_path), |
|
|
'conversion_rate': conversion_rate, |
|
|
'quality_score': quality_score, |
|
|
'training_steps': step, |
|
|
'final_loss': final_loss, |
|
|
'structure_info': structure_info, |
|
|
} |
|
|
|
|
|
return result |
|
|
|
|
|
except Exception as e: |
|
|
import traceback |
|
|
error_msg = traceback.format_exc() |
|
|
print(f"\nโ Fine-tuning burning failed:\n{error_msg}") |
|
|
return { |
|
|
'status': 'failed', |
|
|
'error': str(e), |
|
|
'traceback': error_msg |
|
|
} |
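
# Usage sketch (dataset path and output name are illustrative):
#
#   result = burn_model_with_finetuning(
#       model_url="Qwen/Qwen3-0.6B",
#       output_dir=f"{MODELS_PATH}/phoenix_qwen3_ft",
#       dataset_path="/data/train.txt",
#       use_hierarchical=True,
#       max_steps=100,
#   )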
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def burn_phoenix_model_ui( |
|
|
model_url, |
|
|
use_hierarchical, |
|
|
dataset_path, |
|
|
output_name, |
|
|
use_finetuning, |
|
|
num_epochs, |
|
|
batch_size, |
|
|
learning_rate, |
|
|
max_steps, |
|
|
upload_to_hub, |
|
|
hub_repo_name, |
|
|
hub_private, |
|
|
): |
|
|
"""Gradio UI์ฉ ๋ชจ๋ธ ๋ฒ๋ ํจ์""" |
|
|
|
|
|
print("\n" + "="*80) |
|
|
print("๐ฅ PHOENIX MODEL BURNING START v1.4.1") |
|
|
print("="*80) |
|
|
|
|
|
try: |
|
|
if not model_url.strip(): |
|
|
return "โ ๏ธ Model URL is required", None |
|
|
|
|
|
if not output_name.strip(): |
|
|
output_name = f"phoenix_{model_url.split('/')[-1]}_{int(time.time())}" |
|
|
|
|
|
output_dir = f"{MODELS_PATH}/{output_name}" |
|
|
|
|
|
print(f"๐ Configuration:") |
|
|
print(f" Model URL: {model_url}") |
|
|
print(f" Output Name: {output_name}") |
|
|
print(f" Hierarchical: {use_hierarchical}") |
|
|
print(f" Upload to Hub: {upload_to_hub}") |
|
|
|
|
|
has_dataset = dataset_path and dataset_path.strip() and Path(dataset_path).exists() |
|
|
|
|
|
if use_finetuning and not has_dataset: |
|
|
return "โ ๏ธ Fine-tuning requires a valid dataset path", None |
|
|
|
|
|
if upload_to_hub and not HF_TOKEN: |
|
|
warning_msg = "โ ๏ธ HuggingFace Token Not Found! Continuing with local burning only..." |
|
|
print(f"\n{warning_msg}") |
|
|
|
|
|
|
|
|
print(f"\n{'='*80}") |
|
|
if use_finetuning and has_dataset: |
|
|
print("๐ Starting Fine-tuning Burning...") |
|
|
result = burn_model_with_finetuning( |
|
|
model_url=model_url, |
|
|
output_dir=output_dir, |
|
|
dataset_path=dataset_path, |
|
|
use_hierarchical=use_hierarchical, |
|
|
num_epochs=num_epochs, |
|
|
batch_size=batch_size, |
|
|
learning_rate=learning_rate, |
|
|
max_steps=max_steps, |
|
|
) |
|
|
else: |
|
|
print("๐ Starting Zero-shot Burning...") |
|
|
result = burn_model_zero_shot( |
|
|
model_url=model_url, |
|
|
output_dir=output_dir, |
|
|
use_hierarchical=use_hierarchical, |
|
|
) |
|
|
|
|
|
if result['status'] != 'success': |
|
|
error_msg = f"โ Burning Failed\n```\n{result.get('error', 'Unknown error')}\n```" |
|
|
return error_msg, None |
|
|
|
|
|
print(f"\nโ
Burning completed successfully!") |
|
|
|
|
|
|
|
|
hub_url = None |
|
|
verification_passed = False |
|
|
upload_status = "Not attempted" |
|
|
|
|
|
if upload_to_hub: |
|
|
if not HF_TOKEN: |
|
|
upload_status = "โ Failed - No HF_TOKEN" |
|
|
else: |
|
|
success, hub_url, upload_msg = upload_to_huggingface_hub( |
|
|
model_path=result['model_path'], |
|
|
original_model_url=model_url, |
|
|
repo_name=hub_repo_name if hub_repo_name.strip() else None, |
|
|
private=hub_private, |
|
|
skip_verification=False |
|
|
) |
|
|
|
|
|
verification_passed = success |
|
|
upload_status = f"โ
Uploaded to {hub_url}" if success else f"โ Upload failed" |
|
|
else: |
|
|
upload_status = "โญ๏ธ Skipped" |
|
|
|
|
|
|
|
|
burning_info = { |
|
|
'model_url': model_url, |
|
|
'output_path': result['model_path'], |
|
|
'hub_url': hub_url, |
|
|
'use_hierarchical': use_hierarchical, |
|
|
'dataset_used': has_dataset, |
|
|
'conversion_rate': result.get('conversion_rate', 0.0), |
|
|
'training_steps': result.get('training_steps', 0), |
|
|
'final_loss': result.get('final_loss'), |
|
|
'evaluation_score': result.get('quality_score', 0.0), |
|
|
'verification_passed': verification_passed, |
|
|
} |
|
|
|
|
|
db.save_burning(burning_info) |
|
|
|
|
|
|
|
|
structure_info = result.get('structure_info', {}) |
|
|
|
|
|
output_md = f""" |
|
|
# ๐ฅ Model Burning Complete! (v1.4.1) |
|
|
|
|
|
## ๐ Structure Analysis |
|
|
- **Model Type**: {structure_info.get('model_type', 'unknown')} |
|
|
- **Architecture**: {structure_info.get('architectures', 'unknown')} |
|
|
- **Total Layers**: {structure_info.get('total_layers', 0)} |
|
|
- **Layer Path**: {structure_info.get('layer_path', 'unknown')} |
|
|
- **Has self_attn**: {structure_info.get('has_self_attn', False)} |
|
|
- **GQA Detected**: {structure_info.get('gqa_detected', False)} |
|
|
|
|
|
## ๐ฆ Model Information |
|
|
- **Original Model**: {model_url} |
|
|
- **Output Path**: `{result['model_path']}` |
|
|
- **Burning Type**: {'Fine-tuning' if has_dataset else 'Zero-shot'} |
|
|
- **Hierarchical**: {use_hierarchical} |
|
|
|
|
|
## ๐ Metrics |
|
|
- **Conversion Rate**: {result.get('conversion_rate', 0)*100:.1f}% |
|
|
- **Quality Score**: {result.get('quality_score', 0):.2f}/1.00 |
|
|
""" |
|
|
|
|
|
if 'training_steps' in result: |
|
|
output_md += f""" |
|
|
## ๐ Training |
|
|
- **Steps**: {result['training_steps']} |
|
|
- **Final Loss**: {result.get('final_loss', 0.0):.4f} |
|
|
""" |
|
|
|
|
|
output_md += f""" |
|
|
## โฑ๏ธ Time Breakdown |
|
|
- **Total**: {result.get('total_time', 0):.1f}s |
|
|
""" |
|
|
|
|
|
if 'load_time' in result: |
|
|
output_md += f"- **Load**: {result['load_time']:.1f}s\n" |
|
|
output_md += f"- **Convert**: {result['convert_time']:.1f}s\n" |
|
|
output_md += f"- **Evaluate**: {result['eval_time']:.1f}s\n" |
|
|
output_md += f"- **Save**: {result['save_time']:.1f}s\n" |
|
|
|
|
|
output_md += f""" |
|
|
--- |
|
|
|
|
|
## ๐ HuggingFace Hub Upload |
|
|
|
|
|
**Status**: {upload_status} |
|
|
""" |
|
|
|
|
|
if hub_url: |
|
|
output_md += f""" |
|
|
**Model URL**: [{hub_url}]({hub_url}) |
|
|
|
|
|
### ๐ Load from Hub |
|
|
```python |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
"{hub_url.replace('https://huggingface.co/', '')}", |
|
|
trust_remote_code=True, |
|
|
torch_dtype="auto", |
|
|
device_map="auto" |
|
|
) |
|
|
``` |
|
|
""" |
|
|
|
|
|
output_md += f""" |
|
|
--- |
|
|
|
|
|
✅ **PHOENIX Model Ready! (v1.4.1)**
|
|
""" |
|
|
|
|
|
|
|
|
fig = go.Figure() |
|
|
|
|
|
metrics_names = ['Conversion', 'Quality'] |
|
|
metrics_values = [result.get('conversion_rate', 0), result.get('quality_score', 0)] |
|
|
|
|
|
if verification_passed: |
|
|
metrics_names.append('Upload') |
|
|
metrics_values.append(1.0) |
|
|
|
|
|
fig.add_trace(go.Bar( |
|
|
x=metrics_names, |
|
|
y=metrics_values, |
|
|
marker_color=['#3b82f6', '#10b981', '#8b5cf6'][:len(metrics_names)] |
|
|
)) |
|
|
|
|
|
fig.update_layout( |
|
|
title="๐ฅ Burning Metrics", |
|
|
yaxis_range=[0, 1], |
|
|
template='plotly_white', |
|
|
height=400 |
|
|
) |
|
|
|
|
|
return output_md, fig |
|
|
|
|
|
except Exception as e: |
|
|
import traceback |
|
|
error_msg = traceback.format_exc() |
|
|
|
|
|
return f""" |
|
|
โ **Burning Failed** |
|
|
|
|
|
**Error:** {str(e)} |
|
|
|
|
|
**Traceback:** |
|
|
``` |
|
|
{error_msg} |
|
|
``` |
|
|
""", None |
|
|
|
|
|
|
|
|
def view_burning_history(): |
|
|
"""View burning history""" |
|
|
try: |
|
|
history = db.get_burning_history(limit=20) |
|
|
|
|
|
if not history: |
|
|
return "๐ญ No burning history yet", None |
|
|
|
|
|
df = pd.DataFrame(history) |
|
|
|
|
|
fig = px.scatter( |
|
|
df, |
|
|
x='timestamp', |
|
|
y='evaluation_score', |
|
|
size='conversion_rate', |
|
|
color='verification_passed', |
|
|
hover_data=['model_url', 'output_path', 'hub_url'], |
|
|
title='Burning History' |
|
|
) |
|
|
|
|
|
cols = ['id', 'model_url', 'hub_url', 'conversion_rate', |
|
|
'evaluation_score', 'verification_passed', 'timestamp'] |
|
|
available = [c for c in cols if c in df.columns] |
|
|
|
|
|
return f"## ๐ Burning History\n\n{df[available].to_markdown(index=False)}", fig |
|
|
|
|
|
except Exception as e: |
|
|
return f"โ Error: {e}", None |
|
|
|
|
|
|
|
|
def validate_phoenix_model( |
|
|
model_source, |
|
|
model_path_or_url, |
|
|
test_prompts, |
|
|
max_tokens, |
|
|
temperature, |
|
|
verify_retention |
|
|
): |
|
|
"""PHOENIX ๋ชจ๋ธ ๊ฒ์ฆ""" |
|
|
try: |
|
|
print("="*80) |
|
|
print("๐งช PHOENIX Model Validation v1.4.1") |
|
|
print("="*80) |
|
|
|
|
|
|
|
|
print(f"\n๐ฅ Loading model from {model_source}...") |
|
|
start_time = time.time() |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
model_path_or_url, |
|
|
trust_remote_code=True, |
|
|
torch_dtype=torch.float16, |
|
|
).to(DEVICE) |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
|
model_path_or_url, |
|
|
trust_remote_code=True |
|
|
) |
|
|
|
|
|
if tokenizer.pad_token is None: |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
load_time = time.time() - start_time |
|
|
print(f"โ
Model loaded in {load_time:.2f}s") |
|
|
|
|
|
|
|
|
metadata = {} |
|
|
metadata_path = None |
|
|
|
|
|
if model_source == "local": |
|
|
metadata_path = Path(model_path_or_url) / "phoenix_metadata.json" |
|
|
else: |
|
|
try: |
|
|
from huggingface_hub import hf_hub_download |
|
|
metadata_path = hf_hub_download( |
|
|
repo_id=model_path_or_url, |
|
|
filename="phoenix_metadata.json" |
|
|
) |
|
|
            except Exception:
                # phoenix_metadata.json is optional; ignore it if absent on the Hub
                pass
|
|
|
|
|
if metadata_path and Path(metadata_path).exists(): |
|
|
with open(metadata_path, 'r') as f: |
|
|
metadata = json.load(f) |
|
|
|
|
|
|
|
|
retention_info = "" |
|
|
if verify_retention: |
|
|
print(f"\n๐ Verifying Retention mechanism...") |
|
|
|
|
|
retention_count = 0 |
|
|
attention_count = 0 |
|
|
|
|
|
|
|
|
check_model = model |
|
|
if hasattr(model, '_original_model') and model._original_model is not None: |
|
|
print(f" ๐ Detected PhoenixModelForCausalLM wrapper") |
|
|
check_model = model._original_model |
|
|
|
|
|
layers = [] |
|
|
if hasattr(check_model, 'model') and hasattr(check_model.model, 'layers'): |
|
|
layers = check_model.model.layers |
|
|
elif hasattr(check_model, 'layers'): |
|
|
layers = check_model.layers |
|
|
|
|
|
print(f" ๐ Checking {len(layers)} layers...") |
|
|
|
|
|
for i, layer in enumerate(layers): |
|
|
if hasattr(layer, 'self_attn'): |
|
|
attn = layer.self_attn |
|
|
class_name = attn.__class__.__name__ |
|
|
|
|
|
if 'Retention' in class_name: |
|
|
retention_count += 1 |
|
|
if i < 3: |
|
|
print(f" โ
Layer {i}: {class_name}") |
|
|
else: |
|
|
attention_count += 1 |
|
|
if i < 3: |
|
|
print(f" โ ๏ธ Layer {i}: {class_name}") |
|
|
|
|
|
total = retention_count + attention_count |
|
|
retention_info = f""" |
|
|
### ๐ Retention Verification |
|
|
- **Retention Layers**: {retention_count}/{total} |
|
|
- **Attention Layers**: {attention_count}/{total} |
|
|
- **Status**: {'✅ PHOENIX Active' if retention_count > 0 else '⚠️ No Retention Found'}
|
|
""" |
|
|
print(f" ๐ Result: {retention_count}/{total} layers have Retention") |
|
|
|
|
|
|
|
|
print(f"\n๐ Running generation tests...") |
|
|
|
|
|
prompts = [p.strip() for p in test_prompts.split('\n') if p.strip()] |
|
|
if not prompts: |
|
|
prompts = ["The future of AI is", "Once upon a time"] |
|
|
|
|
|
results = [] |
|
|
total_gen_time = 0 |
|
|
|
|
|
for i, prompt in enumerate(prompts, 1): |
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) |
|
|
|
|
|
gen_start = time.time() |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=max_tokens, |
|
|
temperature=temperature, |
|
|
do_sample=temperature > 0.01, |
|
|
pad_token_id=tokenizer.eos_token_id, |
|
|
) |
|
|
|
|
|
gen_time = time.time() - gen_start |
|
|
total_gen_time += gen_time |
|
|
|
|
|
generated = tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
|
|
tokens_generated = len(outputs[0]) - len(inputs['input_ids'][0]) |
|
|
tokens_per_sec = tokens_generated / gen_time if gen_time > 0 else 0 |
|
|
|
|
|
results.append({ |
|
|
'prompt': prompt, |
|
|
'generated': generated, |
|
|
'time': gen_time, |
|
|
'tokens': tokens_generated, |
|
|
'tokens_per_sec': tokens_per_sec, |
|
|
}) |
|
|
|
|
|
|
|
|
output_md = f""" |
|
|
# ✅ PHOENIX Model Validation Complete! (v1.4.1)
|
|
|
|
|
## ๐ฆ Model Information |
|
|
- **Source**: {model_source.upper()} |
|
|
- **Path/URL**: `{model_path_or_url}` |
|
|
- **Load Time**: {load_time:.2f}s |
|
|
|
|
|
## ๐ Metadata |
|
|
""" |
|
|
|
|
|
if metadata: |
|
|
output_md += f""" |
|
|
- **PHOENIX Version**: {metadata.get('phoenix_version', 'Unknown')} |
|
|
- **Original Model**: {metadata.get('original_model', 'Unknown')} |
|
|
- **Conversion Rate**: {metadata.get('conversion_rate', 0)*100:.1f}% |
|
|
""" |
|
|
|
|
|
if retention_info: |
|
|
output_md += retention_info |
|
|
|
|
|
output_md += f""" |
|
|
## ๐ Generation Tests |
|
|
|
|
|
**Total Tests**: {len(results)} |
|
|
**Average Speed**: {sum(r['tokens_per_sec'] for r in results)/len(results):.1f} tokens/s |
|
|
|
|
|
--- |
|
|
""" |
|
|
|
|
|
for i, result in enumerate(results, 1): |
|
|
output_md += f""" |
|
|
### Test {i} |
|
|
|
|
|
**Generated:** |
|
|
``` |
|
|
{result['generated']} |
|
|
``` |
|
|
|
|
|
**Stats**: {result['time']:.2f}s | {result['tokens_per_sec']:.1f} tokens/s |
|
|
|
|
|
--- |
|
|
""" |
|
|
|
|
|
|
|
|
fig = go.Figure() |
|
|
|
|
|
fig.add_trace(go.Bar( |
|
|
x=[f"Test {i+1}" for i in range(len(results))], |
|
|
y=[r['tokens_per_sec'] for r in results], |
|
|
marker_color='#10b981' |
|
|
)) |
|
|
|
|
|
fig.update_layout( |
|
|
title="Generation Speed (tokens/s)", |
|
|
template='plotly_white' |
|
|
) |
|
|
|
|
|
return output_md, fig |
|
|
|
|
|
except Exception as e: |
|
|
import traceback |
|
|
return f"โ Validation failed:\n```\n{traceback.format_exc()}\n```", None |
|
|
|
|
|
|
|
|
|
|
|
db = ExperimentDatabase(DB_PATH) |
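
# Gradio UI: three tabs are built below: "Model Burning" (zero-shot or
# fine-tuning conversion with optional Hub upload), "Burning History" (reads the
# burning_history table from SQLite), and "Model Validation" (loads a burned
# model and runs generation plus Retention checks).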
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks( |
|
|
title="๐ฎ PHOENIX v1.4.1 - State Dict Direct Loading", |
|
|
theme=gr.themes.Soft(), |
|
|
) as demo: |
|
|
|
|
|
gr.Markdown(""" |
|
|
# ๐ฎ PHOENIX Retention Platform v1.4.1 |
|
|
|
|
|
**State Dict Direct Loading + Structure-Aware Burning** |
|
|
|
|
|
✅ **NEW!** Retention preserved via direct state dict loading
✅ Model Structure Pre-Analysis
✅ Qwen3 Model Support
✅ Zero-shot Conversion (No Dataset Required)
✅ Optional Fine-tuning
✅ GQA Support
✅ O(n) Complexity
✅ Auto Upload to HuggingFace Hub
|
|
|
|
|
--- |
|
|
""") |
|
|
|
|
|
with gr.Tabs(): |
|
|
with gr.Tab("๐ฅ Model Burning"): |
|
|
gr.Markdown(""" |
|
|
### ๐ฅ PHOENIX Model Burning v1.4.1 |
|
|
|
|
|
**Analyzes the model structure before converting!**

**Retention is preserved via direct state dict loading when loading from the Hub!**
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
burn_model_url = gr.Textbox( |
|
|
label="๐ Model URL", |
|
|
value=DEFAULT_MODEL, |
|
|
placeholder="Qwen/Qwen3-0.6B" |
|
|
) |
|
|
burn_hierarchical = gr.Checkbox(value=True, label="Hierarchical Retention") |
|
|
|
|
|
burn_output_name = gr.Textbox( |
|
|
label="๐พ Output Name", |
|
|
placeholder="phoenix_my_model" |
|
|
) |
|
|
|
|
|
gr.Markdown("---") |
|
|
gr.Markdown("### ๐ HuggingFace Hub Upload") |
|
|
|
|
|
burn_upload_hub = gr.Checkbox(value=True, label="๐ค Upload to Hub") |
|
|
burn_hub_repo = gr.Textbox(label="๐ฆ Repo Name (optional)") |
|
|
burn_hub_private = gr.Checkbox(value=True, label="๐ Private") |
|
|
|
|
|
gr.Markdown("---") |
|
|
gr.Markdown("### ๐ Dataset (Optional)") |
|
|
|
|
|
burn_dataset = gr.Textbox(label="๐ Dataset Path") |
|
|
burn_use_finetuning = gr.Checkbox(value=False, label="๐ Enable Fine-tuning") |
|
|
|
|
|
with gr.Accordion("โ๏ธ Fine-tuning Config", open=False): |
|
|
burn_epochs = gr.Slider(1, 5, 1, step=1, label="Epochs") |
|
|
burn_batch = gr.Slider(1, 16, 4, step=1, label="Batch Size") |
|
|
burn_lr = gr.Number(value=5e-5, label="Learning Rate") |
|
|
burn_max_steps = gr.Slider(10, 500, 100, step=10, label="Max Steps") |
|
|
|
|
|
burn_btn = gr.Button("๐ฅ Burn Model", variant="primary", size="lg") |
|
|
|
|
|
with gr.Column(scale=2): |
|
|
burn_output = gr.Markdown() |
|
|
burn_plot = gr.Plot() |
|
|
|
|
|
burn_btn.click( |
|
|
burn_phoenix_model_ui, |
|
|
[ |
|
|
burn_model_url, burn_hierarchical, burn_dataset, burn_output_name, |
|
|
burn_use_finetuning, burn_epochs, burn_batch, burn_lr, burn_max_steps, |
|
|
burn_upload_hub, burn_hub_repo, burn_hub_private, |
|
|
], |
|
|
[burn_output, burn_plot] |
|
|
) |
|
|
|
|
|
with gr.Tab("๐ Burning History"): |
|
|
gr.Markdown("### ๐ Model Burning History") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
hist_btn = gr.Button("๐ Load History", variant="primary") |
|
|
|
|
|
with gr.Column(scale=2): |
|
|
hist_output = gr.Markdown() |
|
|
hist_plot = gr.Plot() |
|
|
|
|
|
hist_btn.click(view_burning_history, outputs=[hist_output, hist_plot]) |
|
|
|
|
|
with gr.Tab("๐งช Model Validation"): |
|
|
gr.Markdown("### ๐งช PHOENIX ๋ชจ๋ธ ๊ฒ์ฆ") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
val_source = gr.Radio( |
|
|
choices=["hub", "local"], |
|
|
value="hub", |
|
|
label="๐ Model Source" |
|
|
) |
|
|
|
|
|
val_path = gr.Textbox( |
|
|
label="๐ Model Path/URL", |
|
|
value="seawolf2357/phoenix-Qwen3-0.6B", |
|
|
placeholder="seawolf2357/phoenix-model" |
|
|
) |
|
|
|
|
|
val_prompts = gr.Textbox( |
|
|
label="๐ Test Prompts (one per line)", |
|
|
lines=5, |
|
|
value="The future of AI is\nOnce upon a time\nIn machine learning,", |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
val_max_tokens = gr.Slider(16, 256, 64, step=16, label="Max Tokens") |
|
|
val_temp = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature") |
|
|
|
|
|
val_verify_retention = gr.Checkbox(value=True, label="๐ Verify Retention") |
|
|
|
|
|
val_btn = gr.Button("๐งช Validate Model", variant="primary", size="lg") |
|
|
|
|
|
with gr.Column(scale=2): |
|
|
val_output = gr.Markdown() |
|
|
val_plot = gr.Plot() |
|
|
|
|
|
val_btn.click( |
|
|
validate_phoenix_model, |
|
|
[val_source, val_path, val_prompts, val_max_tokens, |
|
|
val_temp, val_verify_retention], |
|
|
[val_output, val_plot] |
|
|
) |
|
|
|
|
|
gr.Markdown(f""" |
|
|
--- |
|
|
|
|
|
## ๐ฅ PHOENIX Model Burning Platform v1.4.1 |
|
|
|
|
|
### What's New in v1.4.1 |
|
|
- ✅ **FIX: head_dim calculation** - uses the config value first
- ✅ **State Dict Direct Loading** - preserves Retention weights when loading from the Hub
- ✅ **Model Structure Pre-Analysis** - analyzes the structure before conversion
- ✅ **Qwen3 Support** - full support for Qwen3 models

**HuggingFace Token**: {'✅ Connected' if HF_TOKEN else '❌ Not Found'}
|
|
**Default Model**: {DEFAULT_MODEL} |
|
|
|
|
|
**VIDraft AI Research Lab** | PHOENIX v1.4.1 |
|
|
""") |
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.queue(max_size=20) |
|
|
demo.launch(server_name="0.0.0.0", server_port=7860, share=False) |