Update app.py
app.py (CHANGED)
@@ -25,7 +25,7 @@ import pandas as pd
 from typing import Dict, List, Any, Tuple, Optional
 import chromadb
 from chromadb.config import Settings
-from transformers import AutoModel, AutoTokenizer, AutoConfig
+from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM
 import copy
 
 # =====================================================
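The import change is the heart of this commit: AutoModel loads only the base transformer, whose forward pass returns hidden states, while AutoModelForCausalLM adds the lm_head that projects those states to vocabulary logits, which the generation loop below needs. A minimal sketch of the difference, using "gpt2" purely as a stand-in checkpoint:

import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")              # stand-in checkpoint
ids = tok("hello", return_tensors="pt").input_ids

base = AutoModel.from_pretrained("gpt2")                 # old import path
with torch.no_grad():
    print(base(input_ids=ids).last_hidden_state.shape)   # [1, seq, hidden], no logits

lm = AutoModelForCausalLM.from_pretrained("gpt2")        # new import path
with torch.no_grad():
    print(lm(input_ids=ids).logits.shape)                # [1, seq, vocab], ready to sample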
@@ -693,21 +693,35 @@ def generate_text_phoenix(
     if not convert_attention or not model_url.strip():
         return "⚠️ Enable 'Attention Replace' and provide model URL", ""
 
-    # 1. Model
+    # 1. ✅ Load CausalLM model (includes lm_head)
+    print(f"📥 Loading CausalLM model: {model_url}")
+    config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
 
+    # Load full causal LM model
+    model = AutoModelForCausalLM.from_pretrained(
+        model_url,
+        trust_remote_code=True,
+        torch_dtype=torch.float16
+    ).to(DEVICE)
 
+    # 2. Convert Attention → Retention
+    print(f"🔄 Converting attention to retention...")
+    model.model, converted, total = replace_attention_with_retention(
+        model.model,  # Convert the base model, keep lm_head
+        use_hierarchical=use_hierarchical
+    )
+
+    print(f"✅ Converted {converted}/{total} layers")
 
-    #
+    # 3. Load tokenizer
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
     except Exception as e:
         return f"❌ Tokenizer load failed: {e}", ""
 
-    #
+    # 4. Tokenize input
     inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
     input_ids = inputs["input_ids"]
 
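replace_attention_with_retention is defined elsewhere in app.py; only its call site appears in this diff. A minimal sketch of the module swap the call implies, assuming a Llama-style decoder layout (base_model.layers[i].self_attn); the make_retention factory is a hypothetical stand-in for the app's actual Retention classes:

import torch.nn as nn

def replace_attention_with_retention(base_model: nn.Module,
                                     use_hierarchical: bool = False,
                                     make_retention=None):
    # Walk the decoder stack and swap each attention block for a
    # retention block built by the caller-supplied factory.
    converted = total = 0
    for layer in getattr(base_model, "layers", []):   # assumed Llama-style layout
        total += 1
        if make_retention is not None and hasattr(layer, "self_attn"):
            layer.self_attn = make_retention(layer.self_attn, use_hierarchical)
            converted += 1
    return base_model, converted, total

Passing model.model rather than model keeps lm_head out of the walk, so the output projection survives the conversion performed in step 2 above.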
@@ -716,15 +730,17 @@ def generate_text_phoenix(
     print(f" Input tokens: {input_ids.shape[1]}")
     print(f" Max new tokens: {max_new_tokens}")
 
-    #
+    # 5. Generate
     start_time = time.time()
     generated_ids = []
 
     with torch.no_grad():
-        for
-            # Forward pass
+        for step in range(max_new_tokens):
+            # Forward pass (now with lm_head)
             outputs = model(input_ids=input_ids)
+
+            # Get logits from lm_head
+            logits = outputs.logits[:, -1, :]  # [B, vocab_size]
 
             # Temperature sampling
             if temperature > 0:
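The sampling body between this hunk and the next is unchanged context, so the diff omits it. A conventional version of what sits there, for orientation only (the actual app.py lines may differ):

# Slots into the loop after logits = outputs.logits[:, -1, :]
if temperature > 0:
    probs = torch.softmax(logits / temperature, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)     # [B, 1], sampled
else:
    next_token = torch.argmax(logits, dim=-1, keepdim=True)  # [B, 1], greedy

generated_ids.append(next_token.item())
# No KV/retention cache is used here, so the full sequence is re-run
# each step, matching model(input_ids=input_ids) above.
input_ids = torch.cat([input_ids, next_token], dim=-1)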
@@ -739,7 +755,12 @@ def generate_text_phoenix(
 
             # Stop at EOS
             if next_token.item() == tokenizer.eos_token_id:
+                print(f" Stopped at EOS token")
                 break
+
+            # Progress
+            if (step + 1) % 10 == 0:
+                print(f" Generated {step + 1}/{max_new_tokens} tokens...")
 
     elapsed = time.time() - start_time
 
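The diff stops at the timing line. A typical epilogue, sketched under the assumption that the function returns generated text plus a stats string, mirroring the two-string error returns above (the actual tail of generate_text_phoenix is not shown):

generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
tokens_per_sec = len(generated_ids) / max(elapsed, 1e-6)   # guard zero elapsed
stats = f"{len(generated_ids)} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
return generated_text, stats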