seawolf2357 committed (verified)
Commit f42a5e2 · 1 Parent(s): ca4042c

Update app.py

Files changed (1):
  1. app.py +33 -12
app.py CHANGED
@@ -25,7 +25,7 @@ import pandas as pd
 from typing import Dict, List, Any, Tuple, Optional
 import chromadb
 from chromadb.config import Settings
-from transformers import AutoModel, AutoTokenizer, AutoConfig
+from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM
 import copy
 
 # =====================================================
@@ -693,21 +693,35 @@ def generate_text_phoenix(
     if not convert_attention or not model_url.strip():
         return "⚠️ Enable 'Attention Replace' and provide model URL", ""
 
-    # 1. Convert the model
-    model_info, msg = convert_model_to_phoenix(model_url, use_hierarchical, "L40S")
+    # 1. ✅ Load the CausalLM model (includes lm_head)
+    print(f"📥 Loading CausalLM model: {model_url}")
+    config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
 
-    if model_info is None:
-        return msg, ""
+    # Load full causal LM model
+    model = AutoModelForCausalLM.from_pretrained(
+        model_url,
+        trust_remote_code=True,
+        torch_dtype=torch.float16
+    ).to(DEVICE)
 
-    model = model_info['model']
+    # 2. Convert attention → retention
+    print(f"🔄 Converting attention to retention...")
+    model.model, converted, total = replace_attention_with_retention(
+        model.model,  # Convert the base model, keep lm_head
+        use_hierarchical=use_hierarchical
+    )
+
+    print(f"✅ Converted {converted}/{total} layers")
 
-    # 2. Load tokenizer
+    # 3. Load tokenizer
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
     except Exception as e:
         return f"❌ Tokenizer load failed: {e}", ""
 
-    # 3. Tokenize input
+    # 4. Tokenize input
     inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
     input_ids = inputs["input_ids"]
 
@@ -716,15 +730,17 @@ def generate_text_phoenix(
     print(f" Input tokens: {input_ids.shape[1]}")
     print(f" Max new tokens: {max_new_tokens}")
 
-    # 4. Generate
+    # 5. Generate
     start_time = time.time()
     generated_ids = []
 
     with torch.no_grad():
-        for _ in range(max_new_tokens):
-            # Forward pass
+        for step in range(max_new_tokens):
+            # Forward pass (now with lm_head)
             outputs = model(input_ids=input_ids)
-            logits = outputs.logits[:, -1, :]
+
+            # Get logits from lm_head
+            logits = outputs.logits[:, -1, :]  # [B, vocab_size]
 
             # Temperature sampling
             if temperature > 0:
@@ -739,7 +755,12 @@ def generate_text_phoenix(
 
             # Stop at EOS
             if next_token.item() == tokenizer.eos_token_id:
+                print(f" Stopped at EOS token")
                 break
+
+            # Progress
+            if (step + 1) % 10 == 0:
+                print(f" Generated {step + 1}/{max_new_tokens} tokens...")
 
     elapsed = time.time() - start_time
 
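
For context, the diff cuts off inside the sampling branch: the "# Temperature sampling" block and the code that feeds next_token back into input_ids are not shown. Below is a minimal, self-contained sketch of how such a loop is typically written; the helper sample_next_token, the growing input_ids, and the example checkpoint in the usage comment are illustrative assumptions, not code from app.py.

# Sketch only: a self-contained version of the sampling loop edited above.
# sample_next_token and the re-feeding of input_ids are illustrative
# assumptions; they are not taken from app.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def sample_next_token(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Temperature sampling over the last-position logits [B, vocab_size]."""
    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)  # stochastic
    return logits.argmax(dim=-1, keepdim=True)          # greedy

def generate(model, tokenizer, prompt: str,
             max_new_tokens: int = 50, temperature: float = 0.7) -> str:
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
    generated_ids = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Full forward pass over the growing prefix (no KV cache here)
            logits = model(input_ids=input_ids).logits[:, -1, :]
            next_token = sample_next_token(logits, temperature)
            if next_token.item() == tokenizer.eos_token_id:
                break
            generated_ids.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token], dim=-1)
    return tokenizer.decode(generated_ids, skip_special_tokens=True)

# Example usage (any causal LM checkpoint works; "gpt2" is just an example):
# model = AutoModelForCausalLM.from_pretrained("gpt2").to(DEVICE)
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
# print(generate(model, tokenizer, "Hello", max_new_tokens=20))

Because each step re-runs the model over the full prefix, per-token cost grows with sequence length, which is presumably why the commit also adds a progress print every 10 tokens.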