Update app.py
Browse files
app.py
CHANGED
|
@@ -1901,8 +1901,630 @@ def burn_model_zero_shot(
|
|
| 1901 |
}
|
| 1902 |
|
| 1903 |
|
| 1904 |
-
|
| 1905 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1906 |
|
| 1907 |
# 전역 초기화
|
| 1908 |
db = ExperimentDatabase(DB_PATH)
|
|
@@ -1912,12 +2534,12 @@ db = ExperimentDatabase(DB_PATH)
|
|
| 1912 |
# =====================================================
|
| 1913 |
|
| 1914 |
with gr.Blocks(
|
| 1915 |
-
title="🔮 PHOENIX v1.4 - State Dict Direct Loading",
|
| 1916 |
theme=gr.themes.Soft(),
|
| 1917 |
) as demo:
|
| 1918 |
|
| 1919 |
gr.Markdown("""
|
| 1920 |
-
# 🔮 PHOENIX Retention Platform v1.4
|
| 1921 |
|
| 1922 |
**State Dict Direct Loading + Structure-Aware Burning**
|
| 1923 |
|
|
@@ -1936,7 +2558,7 @@ with gr.Blocks(
|
|
| 1936 |
with gr.Tabs():
|
| 1937 |
with gr.Tab("🔥 Model Burning"):
|
| 1938 |
gr.Markdown("""
|
| 1939 |
-
### 🔥 PHOENIX Model Burning v1.4
|
| 1940 |
|
| 1941 |
**모델 구조를 먼저 분석한 후 변환합니다!**
|
| 1942 |
**Hub 로드 시 State Dict 직접 로드로 Retention 보존!**
|
|
@@ -2049,18 +2671,18 @@ with gr.Blocks(
|
|
| 2049 |
gr.Markdown(f"""
|
| 2050 |
---
|
| 2051 |
|
| 2052 |
-
## 🔥 PHOENIX Model Burning Platform v1.4
|
| 2053 |
|
| 2054 |
-
### What's New in v1.4
|
|
|
|
| 2055 |
- ✅ **State Dict Direct Loading** - Hub 로드 시 Retention 가중치 보존
|
| 2056 |
-
- ✅ **Fixed Hub Loading** - Custom Code에서 올바른 가중치 로드
|
| 2057 |
- ✅ **Model Structure Pre-Analysis** - 변환 전 구조 파악
|
| 2058 |
- ✅ **Qwen3 Support** - Qwen3 모델 완벽 지원
|
| 2059 |
|
| 2060 |
**HuggingFace Token**: {'✅ Connected' if HF_TOKEN else '❌ Not Found'}
|
| 2061 |
**Default Model**: {DEFAULT_MODEL}
|
| 2062 |
|
| 2063 |
-
**VIDraft AI Research Lab** | PHOENIX v1.4
|
| 2064 |
""")
|
| 2065 |
|
| 2066 |
if __name__ == "__main__":
|
|
|
|
| 1901 |
}
|
| 1902 |
|
| 1903 |
|
| 1904 |
+
def burn_model_with_finetuning(
|
| 1905 |
+
model_url: str,
|
| 1906 |
+
output_dir: str,
|
| 1907 |
+
dataset_path: str,
|
| 1908 |
+
use_hierarchical: bool = True,
|
| 1909 |
+
num_epochs: int = 1,
|
| 1910 |
+
batch_size: int = 4,
|
| 1911 |
+
learning_rate: float = 5e-5,
|
| 1912 |
+
max_steps: int = 100,
|
| 1913 |
+
):
|
| 1914 |
+
"""Fine-tuning Model Burning with Structure Analysis"""
|
| 1915 |
+
print("="*80)
|
| 1916 |
+
print("🔥 PHOENIX Fine-tuning Model Burning v1.4.1")
|
| 1917 |
+
print("="*80)
|
| 1918 |
+
|
| 1919 |
+
output_path = Path(output_dir)
|
| 1920 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 1921 |
+
|
| 1922 |
+
try:
|
| 1923 |
+
# 1. 구조 분석
|
| 1924 |
+
print(f"\n🔍 STEP 1: Model Structure Analysis...")
|
| 1925 |
+
structure_info = analyze_model_structure(model_url)
|
| 1926 |
+
|
| 1927 |
+
# 2. 로드 & 변환
|
| 1928 |
+
print(f"\n📥 STEP 2: Loading model...")
|
| 1929 |
+
config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
|
| 1930 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 1931 |
+
model_url,
|
| 1932 |
+
trust_remote_code=True,
|
| 1933 |
+
torch_dtype=torch.float16,
|
| 1934 |
+
).to(DEVICE)
|
| 1935 |
+
|
| 1936 |
+
tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
|
| 1937 |
+
if tokenizer.pad_token is None:
|
| 1938 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 1939 |
+
|
| 1940 |
+
print(f"\n🔄 STEP 3: Converting...")
|
| 1941 |
+
model, converted, total = replace_attention_with_retention(
|
| 1942 |
+
model,
|
| 1943 |
+
use_hierarchical=use_hierarchical,
|
| 1944 |
+
structure_info=structure_info
|
| 1945 |
+
)
|
| 1946 |
+
|
| 1947 |
+
conversion_rate = converted / total if total > 0 else 0
|
| 1948 |
+
print(f"✅ Converted {converted}/{total} layers")
|
| 1949 |
+
|
| 1950 |
+
# 3. 데이터셋 로드
|
| 1951 |
+
print(f"\n📊 STEP 4: Loading dataset: {dataset_path}")
|
| 1952 |
+
|
| 1953 |
+
if dataset_path.endswith('.txt'):
|
| 1954 |
+
with open(dataset_path, 'r', encoding='utf-8') as f:
|
| 1955 |
+
texts = [line.strip() for line in f if line.strip()]
|
| 1956 |
+
|
| 1957 |
+
def tokenize_fn(text):
|
| 1958 |
+
return tokenizer(
|
| 1959 |
+
text,
|
| 1960 |
+
truncation=True,
|
| 1961 |
+
max_length=512,
|
| 1962 |
+
padding='max_length',
|
| 1963 |
+
return_tensors='pt'
|
| 1964 |
+
)
|
| 1965 |
+
|
| 1966 |
+
tokenized_data = [tokenize_fn(text) for text in texts[:1000]]
|
| 1967 |
+
else:
|
| 1968 |
+
dataset = load_dataset('text', data_files=dataset_path)
|
| 1969 |
+
|
| 1970 |
+
def tokenize_function(examples):
|
| 1971 |
+
return tokenizer(
|
| 1972 |
+
examples['text'],
|
| 1973 |
+
truncation=True,
|
| 1974 |
+
max_length=512,
|
| 1975 |
+
padding='max_length',
|
| 1976 |
+
)
|
| 1977 |
+
|
| 1978 |
+
dataset = dataset.map(tokenize_function, batched=True)
|
| 1979 |
+
tokenized_data = dataset['train']
|
| 1980 |
+
|
| 1981 |
+
print(f"✅ Loaded {len(tokenized_data)} samples")
|
| 1982 |
+
|
| 1983 |
+
# 4. Fine-tuning
|
| 1984 |
+
print(f"\n🚀 STEP 5: Starting fine-tuning...")
|
| 1985 |
+
model.train()
|
| 1986 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
| 1987 |
+
|
| 1988 |
+
step = 0
|
| 1989 |
+
total_loss = 0.0
|
| 1990 |
+
|
| 1991 |
+
for epoch in range(num_epochs):
|
| 1992 |
+
for i in range(0, len(tokenized_data), batch_size):
|
| 1993 |
+
if step >= max_steps:
|
| 1994 |
+
break
|
| 1995 |
+
|
| 1996 |
+
batch = tokenized_data[i:i+batch_size]
|
| 1997 |
+
|
| 1998 |
+
if isinstance(batch, list):
|
| 1999 |
+
input_ids = torch.stack([item['input_ids'].squeeze() for item in batch]).to(DEVICE)
|
| 2000 |
+
attention_mask = torch.stack([item['attention_mask'].squeeze() for item in batch]).to(DEVICE)
|
| 2001 |
+
else:
|
| 2002 |
+
input_ids = torch.tensor(batch['input_ids']).to(DEVICE)
|
| 2003 |
+
attention_mask = torch.tensor(batch['attention_mask']).to(DEVICE)
|
| 2004 |
+
|
| 2005 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
|
| 2006 |
+
loss = outputs.loss
|
| 2007 |
+
|
| 2008 |
+
loss.backward()
|
| 2009 |
+
optimizer.step()
|
| 2010 |
+
optimizer.zero_grad()
|
| 2011 |
+
|
| 2012 |
+
total_loss += loss.item()
|
| 2013 |
+
step += 1
|
| 2014 |
+
|
| 2015 |
+
if step % 10 == 0:
|
| 2016 |
+
print(f" Step {step}/{max_steps} - Loss: {total_loss/step:.4f}")
|
| 2017 |
+
|
| 2018 |
+
final_loss = total_loss / step if step > 0 else 0.0
|
| 2019 |
+
print(f"✅ Training complete - Final Loss: {final_loss:.4f}")
|
| 2020 |
+
|
| 2021 |
+
# 5. 평가 & 저장
|
| 2022 |
+
model.eval()
|
| 2023 |
+
quality_score = evaluate_model_quality(model, tokenizer)
|
| 2024 |
+
|
| 2025 |
+
metadata = {
|
| 2026 |
+
'phoenix_version': '1.4.1',
|
| 2027 |
+
'original_model': model_url,
|
| 2028 |
+
'use_hierarchical': use_hierarchical,
|
| 2029 |
+
'conversion_rate': conversion_rate,
|
| 2030 |
+
'quality_score': quality_score,
|
| 2031 |
+
'burning_type': 'fine_tuning',
|
| 2032 |
+
'training_steps': step,
|
| 2033 |
+
'final_loss': final_loss,
|
| 2034 |
+
'dataset': dataset_path,
|
| 2035 |
+
'structure_info': structure_info,
|
| 2036 |
+
'timestamp': datetime.now().isoformat(),
|
| 2037 |
+
}
|
| 2038 |
+
|
| 2039 |
+
save_phoenix_model_with_code(model, tokenizer, output_path, model_url, metadata)
|
| 2040 |
+
|
| 2041 |
+
result = {
|
| 2042 |
+
'status': 'success',
|
| 2043 |
+
'model_path': str(output_path),
|
| 2044 |
+
'conversion_rate': conversion_rate,
|
| 2045 |
+
'quality_score': quality_score,
|
| 2046 |
+
'training_steps': step,
|
| 2047 |
+
'final_loss': final_loss,
|
| 2048 |
+
'structure_info': structure_info,
|
| 2049 |
+
}
|
| 2050 |
+
|
| 2051 |
+
return result
|
| 2052 |
+
|
| 2053 |
+
except Exception as e:
|
| 2054 |
+
import traceback
|
| 2055 |
+
error_msg = traceback.format_exc()
|
| 2056 |
+
print(f"\n❌ Fine-tuning burning failed:\n{error_msg}")
|
| 2057 |
+
return {
|
| 2058 |
+
'status': 'failed',
|
| 2059 |
+
'error': str(e),
|
| 2060 |
+
'traceback': error_msg
|
| 2061 |
+
}
|
| 2062 |
+
|
| 2063 |
+
|
| 2064 |
+
# =====================================================
|
| 2065 |
+
# Gradio UI Functions
|
| 2066 |
+
# =====================================================
|
| 2067 |
+
|
| 2068 |
+
def burn_phoenix_model_ui(
|
| 2069 |
+
model_url,
|
| 2070 |
+
use_hierarchical,
|
| 2071 |
+
dataset_path,
|
| 2072 |
+
output_name,
|
| 2073 |
+
use_finetuning,
|
| 2074 |
+
num_epochs,
|
| 2075 |
+
batch_size,
|
| 2076 |
+
learning_rate,
|
| 2077 |
+
max_steps,
|
| 2078 |
+
upload_to_hub,
|
| 2079 |
+
hub_repo_name,
|
| 2080 |
+
hub_private,
|
| 2081 |
+
):
|
| 2082 |
+
"""Gradio UI용 모델 버닝 함수"""
|
| 2083 |
+
|
| 2084 |
+
print("\n" + "="*80)
|
| 2085 |
+
print("🔥 PHOENIX MODEL BURNING START v1.4.1")
|
| 2086 |
+
print("="*80)
|
| 2087 |
+
|
| 2088 |
+
try:
|
| 2089 |
+
if not model_url.strip():
|
| 2090 |
+
return "⚠️ Model URL is required", None
|
| 2091 |
+
|
| 2092 |
+
if not output_name.strip():
|
| 2093 |
+
output_name = f"phoenix_{model_url.split('/')[-1]}_{int(time.time())}"
|
| 2094 |
+
|
| 2095 |
+
output_dir = f"{MODELS_PATH}/{output_name}"
|
| 2096 |
+
|
| 2097 |
+
print(f"📋 Configuration:")
|
| 2098 |
+
print(f" Model URL: {model_url}")
|
| 2099 |
+
print(f" Output Name: {output_name}")
|
| 2100 |
+
print(f" Hierarchical: {use_hierarchical}")
|
| 2101 |
+
print(f" Upload to Hub: {upload_to_hub}")
|
| 2102 |
+
|
| 2103 |
+
has_dataset = dataset_path and dataset_path.strip() and Path(dataset_path).exists()
|
| 2104 |
+
|
| 2105 |
+
if use_finetuning and not has_dataset:
|
| 2106 |
+
return "⚠️ Fine-tuning requires a valid dataset path", None
|
| 2107 |
+
|
| 2108 |
+
if upload_to_hub and not HF_TOKEN:
|
| 2109 |
+
warning_msg = "⚠️ HuggingFace Token Not Found! Continuing with local burning only..."
|
| 2110 |
+
print(f"\n{warning_msg}")
|
| 2111 |
+
|
| 2112 |
+
# Burning 실행
|
| 2113 |
+
print(f"\n{'='*80}")
|
| 2114 |
+
if use_finetuning and has_dataset:
|
| 2115 |
+
print("🚀 Starting Fine-tuning Burning...")
|
| 2116 |
+
result = burn_model_with_finetuning(
|
| 2117 |
+
model_url=model_url,
|
| 2118 |
+
output_dir=output_dir,
|
| 2119 |
+
dataset_path=dataset_path,
|
| 2120 |
+
use_hierarchical=use_hierarchical,
|
| 2121 |
+
num_epochs=num_epochs,
|
| 2122 |
+
batch_size=batch_size,
|
| 2123 |
+
learning_rate=learning_rate,
|
| 2124 |
+
max_steps=max_steps,
|
| 2125 |
+
)
|
| 2126 |
+
else:
|
| 2127 |
+
print("🚀 Starting Zero-shot Burning...")
|
| 2128 |
+
result = burn_model_zero_shot(
|
| 2129 |
+
model_url=model_url,
|
| 2130 |
+
output_dir=output_dir,
|
| 2131 |
+
use_hierarchical=use_hierarchical,
|
| 2132 |
+
)
|
| 2133 |
+
|
| 2134 |
+
if result['status'] != 'success':
|
| 2135 |
+
error_msg = f"❌ Burning Failed\n```\n{result.get('error', 'Unknown error')}\n```"
|
| 2136 |
+
return error_msg, None
|
| 2137 |
+
|
| 2138 |
+
print(f"\n✅ Burning completed successfully!")
|
| 2139 |
+
|
| 2140 |
+
# HuggingFace Hub 업로드
|
| 2141 |
+
hub_url = None
|
| 2142 |
+
verification_passed = False
|
| 2143 |
+
upload_status = "Not attempted"
|
| 2144 |
+
|
| 2145 |
+
if upload_to_hub:
|
| 2146 |
+
if not HF_TOKEN:
|
| 2147 |
+
upload_status = "❌ Failed - No HF_TOKEN"
|
| 2148 |
+
else:
|
| 2149 |
+
success, hub_url, upload_msg = upload_to_huggingface_hub(
|
| 2150 |
+
model_path=result['model_path'],
|
| 2151 |
+
original_model_url=model_url,
|
| 2152 |
+
repo_name=hub_repo_name if hub_repo_name.strip() else None,
|
| 2153 |
+
private=hub_private,
|
| 2154 |
+
skip_verification=False
|
| 2155 |
+
)
|
| 2156 |
+
|
| 2157 |
+
verification_passed = success
|
| 2158 |
+
upload_status = f"✅ Uploaded to {hub_url}" if success else f"❌ Upload failed"
|
| 2159 |
+
else:
|
| 2160 |
+
upload_status = "⏭️ Skipped"
|
| 2161 |
+
|
| 2162 |
+
# 데이터베이스 저장
|
| 2163 |
+
burning_info = {
|
| 2164 |
+
'model_url': model_url,
|
| 2165 |
+
'output_path': result['model_path'],
|
| 2166 |
+
'hub_url': hub_url,
|
| 2167 |
+
'use_hierarchical': use_hierarchical,
|
| 2168 |
+
'dataset_used': has_dataset,
|
| 2169 |
+
'conversion_rate': result.get('conversion_rate', 0.0),
|
| 2170 |
+
'training_steps': result.get('training_steps', 0),
|
| 2171 |
+
'final_loss': result.get('final_loss'),
|
| 2172 |
+
'evaluation_score': result.get('quality_score', 0.0),
|
| 2173 |
+
'verification_passed': verification_passed,
|
| 2174 |
+
}
|
| 2175 |
+
|
| 2176 |
+
db.save_burning(burning_info)
|
| 2177 |
+
|
| 2178 |
+
# 결과 포맷팅
|
| 2179 |
+
structure_info = result.get('structure_info', {})
|
| 2180 |
+
|
| 2181 |
+
output_md = f"""
|
| 2182 |
+
# 🔥 Model Burning Complete! (v1.4.1)
|
| 2183 |
+
|
| 2184 |
+
## 🔍 Structure Analysis
|
| 2185 |
+
- **Model Type**: {structure_info.get('model_type', 'unknown')}
|
| 2186 |
+
- **Architecture**: {structure_info.get('architectures', 'unknown')}
|
| 2187 |
+
- **Total Layers**: {structure_info.get('total_layers', 0)}
|
| 2188 |
+
- **Layer Path**: {structure_info.get('layer_path', 'unknown')}
|
| 2189 |
+
- **Has self_attn**: {structure_info.get('has_self_attn', False)}
|
| 2190 |
+
- **GQA Detected**: {structure_info.get('gqa_detected', False)}
|
| 2191 |
+
|
| 2192 |
+
## 📦 Model Information
|
| 2193 |
+
- **Original Model**: {model_url}
|
| 2194 |
+
- **Output Path**: `{result['model_path']}`
|
| 2195 |
+
- **Burning Type**: {'Fine-tuning' if has_dataset else 'Zero-shot'}
|
| 2196 |
+
- **Hierarchical**: {use_hierarchical}
|
| 2197 |
+
|
| 2198 |
+
## 📊 Metrics
|
| 2199 |
+
- **Conversion Rate**: {result.get('conversion_rate', 0)*100:.1f}%
|
| 2200 |
+
- **Quality Score**: {result.get('quality_score', 0):.2f}/1.00
|
| 2201 |
+
"""
|
| 2202 |
+
|
| 2203 |
+
if 'training_steps' in result:
|
| 2204 |
+
output_md += f"""
|
| 2205 |
+
## 🚀 Training
|
| 2206 |
+
- **Steps**: {result['training_steps']}
|
| 2207 |
+
- **Final Loss**: {result.get('final_loss', 0.0):.4f}
|
| 2208 |
+
"""
|
| 2209 |
+
|
| 2210 |
+
output_md += f"""
|
| 2211 |
+
## ⏱️ Time Breakdown
|
| 2212 |
+
- **Total**: {result.get('total_time', 0):.1f}s
|
| 2213 |
+
"""
|
| 2214 |
+
|
| 2215 |
+
if 'load_time' in result:
|
| 2216 |
+
output_md += f"- **Load**: {result['load_time']:.1f}s\n"
|
| 2217 |
+
output_md += f"- **Convert**: {result['convert_time']:.1f}s\n"
|
| 2218 |
+
output_md += f"- **Evaluate**: {result['eval_time']:.1f}s\n"
|
| 2219 |
+
output_md += f"- **Save**: {result['save_time']:.1f}s\n"
|
| 2220 |
+
|
| 2221 |
+
output_md += f"""
|
| 2222 |
+
---
|
| 2223 |
+
|
| 2224 |
+
## 🌐 HuggingFace Hub Upload
|
| 2225 |
+
|
| 2226 |
+
**Status**: {upload_status}
|
| 2227 |
+
"""
|
| 2228 |
+
|
| 2229 |
+
if hub_url:
|
| 2230 |
+
output_md += f"""
|
| 2231 |
+
**Model URL**: [{hub_url}]({hub_url})
|
| 2232 |
+
|
| 2233 |
+
### 🚀 Load from Hub
|
| 2234 |
+
```python
|
| 2235 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 2236 |
+
|
| 2237 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 2238 |
+
"{hub_url.replace('https://huggingface.co/', '')}",
|
| 2239 |
+
trust_remote_code=True,
|
| 2240 |
+
torch_dtype="auto",
|
| 2241 |
+
device_map="auto"
|
| 2242 |
+
)
|
| 2243 |
+
```
|
| 2244 |
+
"""
|
| 2245 |
+
|
| 2246 |
+
output_md += f"""
|
| 2247 |
+
---
|
| 2248 |
+
|
| 2249 |
+
✅ **PHOENIX Model Ready! (v1.4.1)**
|
| 2250 |
+
"""
|
| 2251 |
+
|
| 2252 |
+
# 플롯
|
| 2253 |
+
fig = go.Figure()
|
| 2254 |
+
|
| 2255 |
+
metrics_names = ['Conversion', 'Quality']
|
| 2256 |
+
metrics_values = [result.get('conversion_rate', 0), result.get('quality_score', 0)]
|
| 2257 |
+
|
| 2258 |
+
if verification_passed:
|
| 2259 |
+
metrics_names.append('Upload')
|
| 2260 |
+
metrics_values.append(1.0)
|
| 2261 |
+
|
| 2262 |
+
fig.add_trace(go.Bar(
|
| 2263 |
+
x=metrics_names,
|
| 2264 |
+
y=metrics_values,
|
| 2265 |
+
marker_color=['#3b82f6', '#10b981', '#8b5cf6'][:len(metrics_names)]
|
| 2266 |
+
))
|
| 2267 |
+
|
| 2268 |
+
fig.update_layout(
|
| 2269 |
+
title="🔥 Burning Metrics",
|
| 2270 |
+
yaxis_range=[0, 1],
|
| 2271 |
+
template='plotly_white',
|
| 2272 |
+
height=400
|
| 2273 |
+
)
|
| 2274 |
+
|
| 2275 |
+
return output_md, fig
|
| 2276 |
+
|
| 2277 |
+
except Exception as e:
|
| 2278 |
+
import traceback
|
| 2279 |
+
error_msg = traceback.format_exc()
|
| 2280 |
+
|
| 2281 |
+
return f"""
|
| 2282 |
+
❌ **Burning Failed**
|
| 2283 |
+
|
| 2284 |
+
**Error:** {str(e)}
|
| 2285 |
+
|
| 2286 |
+
**Traceback:**
|
| 2287 |
+
```
|
| 2288 |
+
{error_msg}
|
| 2289 |
+
```
|
| 2290 |
+
""", None
|
| 2291 |
+
|
| 2292 |
+
|
| 2293 |
+
def view_burning_history():
|
| 2294 |
+
"""View burning history"""
|
| 2295 |
+
try:
|
| 2296 |
+
history = db.get_burning_history(limit=20)
|
| 2297 |
+
|
| 2298 |
+
if not history:
|
| 2299 |
+
return "📭 No burning history yet", None
|
| 2300 |
+
|
| 2301 |
+
df = pd.DataFrame(history)
|
| 2302 |
+
|
| 2303 |
+
fig = px.scatter(
|
| 2304 |
+
df,
|
| 2305 |
+
x='timestamp',
|
| 2306 |
+
y='evaluation_score',
|
| 2307 |
+
size='conversion_rate',
|
| 2308 |
+
color='verification_passed',
|
| 2309 |
+
hover_data=['model_url', 'output_path', 'hub_url'],
|
| 2310 |
+
title='Burning History'
|
| 2311 |
+
)
|
| 2312 |
+
|
| 2313 |
+
cols = ['id', 'model_url', 'hub_url', 'conversion_rate',
|
| 2314 |
+
'evaluation_score', 'verification_passed', 'timestamp']
|
| 2315 |
+
available = [c for c in cols if c in df.columns]
|
| 2316 |
+
|
| 2317 |
+
return f"## 📊 Burning History\n\n{df[available].to_markdown(index=False)}", fig
|
| 2318 |
+
|
| 2319 |
+
except Exception as e:
|
| 2320 |
+
return f"❌ Error: {e}", None
|
| 2321 |
+
|
| 2322 |
+
|
| 2323 |
+
def validate_phoenix_model(
|
| 2324 |
+
model_source,
|
| 2325 |
+
model_path_or_url,
|
| 2326 |
+
test_prompts,
|
| 2327 |
+
max_tokens,
|
| 2328 |
+
temperature,
|
| 2329 |
+
verify_retention
|
| 2330 |
+
):
|
| 2331 |
+
"""PHOENIX 모델 검증"""
|
| 2332 |
+
try:
|
| 2333 |
+
print("="*80)
|
| 2334 |
+
print("🧪 PHOENIX Model Validation v1.4.1")
|
| 2335 |
+
print("="*80)
|
| 2336 |
+
|
| 2337 |
+
# 1. 모델 로드
|
| 2338 |
+
print(f"\n📥 Loading model from {model_source}...")
|
| 2339 |
+
start_time = time.time()
|
| 2340 |
+
|
| 2341 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 2342 |
+
model_path_or_url,
|
| 2343 |
+
trust_remote_code=True,
|
| 2344 |
+
torch_dtype=torch.float16,
|
| 2345 |
+
).to(DEVICE)
|
| 2346 |
+
|
| 2347 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 2348 |
+
model_path_or_url,
|
| 2349 |
+
trust_remote_code=True
|
| 2350 |
+
)
|
| 2351 |
+
|
| 2352 |
+
if tokenizer.pad_token is None:
|
| 2353 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 2354 |
+
|
| 2355 |
+
load_time = time.time() - start_time
|
| 2356 |
+
print(f"✅ Model loaded in {load_time:.2f}s")
|
| 2357 |
+
|
| 2358 |
+
# 2. 메타데이터
|
| 2359 |
+
metadata = {}
|
| 2360 |
+
metadata_path = None
|
| 2361 |
+
|
| 2362 |
+
if model_source == "local":
|
| 2363 |
+
metadata_path = Path(model_path_or_url) / "phoenix_metadata.json"
|
| 2364 |
+
else:
|
| 2365 |
+
try:
|
| 2366 |
+
from huggingface_hub import hf_hub_download
|
| 2367 |
+
metadata_path = hf_hub_download(
|
| 2368 |
+
repo_id=model_path_or_url,
|
| 2369 |
+
filename="phoenix_metadata.json"
|
| 2370 |
+
)
|
| 2371 |
+
except:
|
| 2372 |
+
pass
|
| 2373 |
+
|
| 2374 |
+
if metadata_path and Path(metadata_path).exists():
|
| 2375 |
+
with open(metadata_path, 'r') as f:
|
| 2376 |
+
metadata = json.load(f)
|
| 2377 |
+
|
| 2378 |
+
# 3. Retention 검증
|
| 2379 |
+
retention_info = ""
|
| 2380 |
+
if verify_retention:
|
| 2381 |
+
print(f"\n🔍 Verifying Retention mechanism...")
|
| 2382 |
+
|
| 2383 |
+
retention_count = 0
|
| 2384 |
+
attention_count = 0
|
| 2385 |
+
|
| 2386 |
+
# PhoenixModelForCausalLM인 경우 _original_model 확인
|
| 2387 |
+
check_model = model
|
| 2388 |
+
if hasattr(model, '_original_model') and model._original_model is not None:
|
| 2389 |
+
print(f" 📋 Detected PhoenixModelForCausalLM wrapper")
|
| 2390 |
+
check_model = model._original_model
|
| 2391 |
+
|
| 2392 |
+
layers = []
|
| 2393 |
+
if hasattr(check_model, 'model') and hasattr(check_model.model, 'layers'):
|
| 2394 |
+
layers = check_model.model.layers
|
| 2395 |
+
elif hasattr(check_model, 'layers'):
|
| 2396 |
+
layers = check_model.layers
|
| 2397 |
+
|
| 2398 |
+
print(f" 🔍 Checking {len(layers)} layers...")
|
| 2399 |
+
|
| 2400 |
+
for i, layer in enumerate(layers):
|
| 2401 |
+
if hasattr(layer, 'self_attn'):
|
| 2402 |
+
attn = layer.self_attn
|
| 2403 |
+
class_name = attn.__class__.__name__
|
| 2404 |
+
|
| 2405 |
+
if 'Retention' in class_name:
|
| 2406 |
+
retention_count += 1
|
| 2407 |
+
if i < 3: # 처음 3개만 출력
|
| 2408 |
+
print(f" ✅ Layer {i}: {class_name}")
|
| 2409 |
+
else:
|
| 2410 |
+
attention_count += 1
|
| 2411 |
+
if i < 3:
|
| 2412 |
+
print(f" ⚠️ Layer {i}: {class_name}")
|
| 2413 |
+
|
| 2414 |
+
total = retention_count + attention_count
|
| 2415 |
+
retention_info = f"""
|
| 2416 |
+
### 🔍 Retention Verification
|
| 2417 |
+
- **Retention Layers**: {retention_count}/{total}
|
| 2418 |
+
- **Attention Layers**: {attention_count}/{total}
|
| 2419 |
+
- **Status**: {'✅ PHOENIX Active' if retention_count > 0 else '⚠️ No Retention Found'}
|
| 2420 |
+
"""
|
| 2421 |
+
print(f" 📊 Result: {retention_count}/{total} layers have Retention")
|
| 2422 |
+
|
| 2423 |
+
# 4. 생성 테스트
|
| 2424 |
+
print(f"\n🚀 Running generation tests...")
|
| 2425 |
+
|
| 2426 |
+
prompts = [p.strip() for p in test_prompts.split('\n') if p.strip()]
|
| 2427 |
+
if not prompts:
|
| 2428 |
+
prompts = ["The future of AI is", "Once upon a time"]
|
| 2429 |
+
|
| 2430 |
+
results = []
|
| 2431 |
+
total_gen_time = 0
|
| 2432 |
+
|
| 2433 |
+
for i, prompt in enumerate(prompts, 1):
|
| 2434 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
|
| 2435 |
+
|
| 2436 |
+
gen_start = time.time()
|
| 2437 |
+
|
| 2438 |
+
with torch.no_grad():
|
| 2439 |
+
outputs = model.generate(
|
| 2440 |
+
**inputs,
|
| 2441 |
+
max_new_tokens=max_tokens,
|
| 2442 |
+
temperature=temperature,
|
| 2443 |
+
do_sample=temperature > 0.01,
|
| 2444 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 2445 |
+
)
|
| 2446 |
+
|
| 2447 |
+
gen_time = time.time() - gen_start
|
| 2448 |
+
total_gen_time += gen_time
|
| 2449 |
+
|
| 2450 |
+
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 2451 |
+
|
| 2452 |
+
tokens_generated = len(outputs[0]) - len(inputs['input_ids'][0])
|
| 2453 |
+
tokens_per_sec = tokens_generated / gen_time if gen_time > 0 else 0
|
| 2454 |
+
|
| 2455 |
+
results.append({
|
| 2456 |
+
'prompt': prompt,
|
| 2457 |
+
'generated': generated,
|
| 2458 |
+
'time': gen_time,
|
| 2459 |
+
'tokens': tokens_generated,
|
| 2460 |
+
'tokens_per_sec': tokens_per_sec,
|
| 2461 |
+
})
|
| 2462 |
+
|
| 2463 |
+
# 5. 결과
|
| 2464 |
+
output_md = f"""
|
| 2465 |
+
# ✅ PHOENIX Model Validation Complete! (v1.4.1)
|
| 2466 |
+
|
| 2467 |
+
## 📦 Model Information
|
| 2468 |
+
- **Source**: {model_source.upper()}
|
| 2469 |
+
- **Path/URL**: `{model_path_or_url}`
|
| 2470 |
+
- **Load Time**: {load_time:.2f}s
|
| 2471 |
+
|
| 2472 |
+
## 📋 Metadata
|
| 2473 |
+
"""
|
| 2474 |
+
|
| 2475 |
+
if metadata:
|
| 2476 |
+
output_md += f"""
|
| 2477 |
+
- **PHOENIX Version**: {metadata.get('phoenix_version', 'Unknown')}
|
| 2478 |
+
- **Original Model**: {metadata.get('original_model', 'Unknown')}
|
| 2479 |
+
- **Conversion Rate**: {metadata.get('conversion_rate', 0)*100:.1f}%
|
| 2480 |
+
"""
|
| 2481 |
+
|
| 2482 |
+
if retention_info:
|
| 2483 |
+
output_md += retention_info
|
| 2484 |
+
|
| 2485 |
+
output_md += f"""
|
| 2486 |
+
## 🚀 Generation Tests
|
| 2487 |
+
|
| 2488 |
+
**Total Tests**: {len(results)}
|
| 2489 |
+
**Average Speed**: {sum(r['tokens_per_sec'] for r in results)/len(results):.1f} tokens/s
|
| 2490 |
+
|
| 2491 |
+
---
|
| 2492 |
+
"""
|
| 2493 |
+
|
| 2494 |
+
for i, result in enumerate(results, 1):
|
| 2495 |
+
output_md += f"""
|
| 2496 |
+
### Test {i}
|
| 2497 |
+
|
| 2498 |
+
**Generated:**
|
| 2499 |
+
```
|
| 2500 |
+
{result['generated']}
|
| 2501 |
+
```
|
| 2502 |
+
|
| 2503 |
+
**Stats**: {result['time']:.2f}s | {result['tokens_per_sec']:.1f} tokens/s
|
| 2504 |
+
|
| 2505 |
+
---
|
| 2506 |
+
"""
|
| 2507 |
+
|
| 2508 |
+
# 6. 그래프
|
| 2509 |
+
fig = go.Figure()
|
| 2510 |
+
|
| 2511 |
+
fig.add_trace(go.Bar(
|
| 2512 |
+
x=[f"Test {i+1}" for i in range(len(results))],
|
| 2513 |
+
y=[r['tokens_per_sec'] for r in results],
|
| 2514 |
+
marker_color='#10b981'
|
| 2515 |
+
))
|
| 2516 |
+
|
| 2517 |
+
fig.update_layout(
|
| 2518 |
+
title="Generation Speed (tokens/s)",
|
| 2519 |
+
template='plotly_white'
|
| 2520 |
+
)
|
| 2521 |
+
|
| 2522 |
+
return output_md, fig
|
| 2523 |
+
|
| 2524 |
+
except Exception as e:
|
| 2525 |
+
import traceback
|
| 2526 |
+
return f"❌ Validation failed:\n```\n{traceback.format_exc()}\n```", None
|
| 2527 |
+
|
| 2528 |
|
| 2529 |
# 전역 초기화
|
| 2530 |
db = ExperimentDatabase(DB_PATH)
|
|
|
|
| 2534 |
# =====================================================
|
| 2535 |
|
| 2536 |
with gr.Blocks(
|
| 2537 |
+
title="🔮 PHOENIX v1.4.1 - State Dict Direct Loading",
|
| 2538 |
theme=gr.themes.Soft(),
|
| 2539 |
) as demo:
|
| 2540 |
|
| 2541 |
gr.Markdown("""
|
| 2542 |
+
# 🔮 PHOENIX Retention Platform v1.4.1
|
| 2543 |
|
| 2544 |
**State Dict Direct Loading + Structure-Aware Burning**
|
| 2545 |
|
|
|
|
| 2558 |
with gr.Tabs():
|
| 2559 |
with gr.Tab("🔥 Model Burning"):
|
| 2560 |
gr.Markdown("""
|
| 2561 |
+
### 🔥 PHOENIX Model Burning v1.4.1
|
| 2562 |
|
| 2563 |
**모델 구조를 먼저 분석한 후 변환합니다!**
|
| 2564 |
**Hub 로드 시 State Dict 직접 로드로 Retention 보존!**
|
|
|
|
| 2671 |
gr.Markdown(f"""
|
| 2672 |
---
|
| 2673 |
|
| 2674 |
+
## 🔥 PHOENIX Model Burning Platform v1.4.1
|
| 2675 |
|
| 2676 |
+
### What's New in v1.4.1
|
| 2677 |
+
- ✅ **FIX: head_dim calculation** - Config 우선 사용
|
| 2678 |
- ✅ **State Dict Direct Loading** - Hub 로드 시 Retention 가중치 보존
|
|
|
|
| 2679 |
- ✅ **Model Structure Pre-Analysis** - 변환 전 구조 파악
|
| 2680 |
- ✅ **Qwen3 Support** - Qwen3 모델 완벽 지원
|
| 2681 |
|
| 2682 |
**HuggingFace Token**: {'✅ Connected' if HF_TOKEN else '❌ Not Found'}
|
| 2683 |
**Default Model**: {DEFAULT_MODEL}
|
| 2684 |
|
| 2685 |
+
**VIDraft AI Research Lab** | PHOENIX v1.4.1
|
| 2686 |
""")
|
| 2687 |
|
| 2688 |
if __name__ == "__main__":
|