seawolf2357 committed on
Commit
ca4042c
·
verified ·
1 Parent(s): 62dbb7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -0
app.py CHANGED
@@ -684,6 +684,107 @@ def convert_model_to_phoenix(model_url, use_hierarchical=True, gpu_type="L40S"):
684
  return None, f"❌ Conversion failed: {str(e)}"
685
 
686
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687
  def run_phoenix_experiment(model_url, use_hierarchical, convert_attention, sequence_length, gpu_type):
688
  """Run PHOENIX experiment"""
689
  try:
@@ -871,6 +972,42 @@ with gr.Blocks(
871
  [convert_url, convert_hierarchical, convert_gpu],
872
  [gr.State(), convert_output])
873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
874
  with gr.Tab("🧪 Experiment"):
875
  with gr.Row():
876
  with gr.Column(scale=1):
 
684
  return None, f"❌ Conversion failed: {str(e)}"
685
 
686
 
687
def generate_text_phoenix(
    model_url, use_hierarchical, convert_attention,
    prompt, max_new_tokens, temperature
):
    """Generate text with a PHOENIX-converted model.

    Converts the model at ``model_url`` (attention replaced by PHOENIX
    retention), then samples autoregressively from the given prompt.

    Args:
        model_url: Hugging Face model id/URL to convert and run.
        use_hierarchical: forwarded to ``convert_model_to_phoenix``.
        convert_attention: gate flag — must be truthy, otherwise the call
            returns an instructional warning instead of generating.
        prompt: input text to continue.
        max_new_tokens: maximum number of tokens to sample. Coerced to
            ``int`` because Gradio sliders may deliver floats, which would
            make ``range()`` raise ``TypeError``.
        temperature: softmax temperature; values <= 0 fall back to greedy
            argmax decoding.

    Returns:
        Tuple ``(output_markdown, stats_markdown)``. On any failure the
        first element carries the error text and the second is ``""``.
    """
    try:
        # Guard path: nothing to do without the attention-replace flag
        # and a non-blank model URL.
        if not convert_attention or not model_url.strip():
            return "⚠️ Enable 'Attention Replace' and provide model URL", ""

        # 1. Convert the model (GPU type fixed to "L40S" for generation).
        model_info, msg = convert_model_to_phoenix(model_url, use_hierarchical, "L40S")

        if model_info is None:
            return msg, ""

        model = model_info['model']

        # 2. Load the tokenizer for the same checkpoint.
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        except Exception as e:
            return f"❌ Tokenizer load failed: {e}", ""

        # 3. Tokenize the prompt. Record the prompt length now, before
        #    input_ids grows during generation.
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        input_ids = inputs["input_ids"]
        prompt_len = input_ids.shape[1]

        # Gradio sliders can return floats; range() needs an int.
        max_new_tokens = int(max_new_tokens)

        print(f"\n📝 Generating text...")
        print(f"   Prompt: {prompt}")
        print(f"   Input tokens: {prompt_len}")
        print(f"   Max new tokens: {max_new_tokens}")

        # 4. Autoregressive sampling loop.
        #    NOTE(review): no KV cache — every step re-runs the full
        #    forward pass over the whole sequence, presumably because the
        #    PHOENIX retention path manages its own state. Confirm.
        start_time = time.time()
        generated_ids = []

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Forward pass; only the last position's logits matter.
                outputs = model(input_ids=input_ids)
                logits = outputs.logits[:, -1, :]

                # Temperature sampling, or greedy when temperature <= 0.
                if temperature > 0:
                    probs = F.softmax(logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = logits.argmax(dim=-1, keepdim=True)

                generated_ids.append(next_token.item())
                input_ids = torch.cat([input_ids, next_token], dim=1)

                # Stop at EOS (comparison is a no-op when the tokenizer
                # defines no eos_token_id — it is None then).
                if next_token.item() == tokenizer.eos_token_id:
                    break

        # Guard against a zero elapsed time so the tokens/s figure below
        # can never divide by zero.
        elapsed = max(time.time() - start_time, 1e-9)

        # 5. Decode only the newly generated ids.
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        full_text = prompt + generated_text

        # 6. Render results as markdown for the Gradio UI.
        output_md = f"""
## 📝 Generated Text

**Prompt**: {prompt}

---

**Generated**:

{generated_text}

---

**Full Text**:

{full_text}
"""

        stats_md = f"""
## 📊 Generation Statistics

- **Input tokens**: {prompt_len}
- **Generated tokens**: {len(generated_ids)}
- **Total tokens**: {input_ids.shape[1]}
- **Time**: {elapsed:.2f}s
- **Speed**: {len(generated_ids) / elapsed:.1f} tokens/s
- **Temperature**: {temperature}
- **Model**: PHOENIX Retention (O(n))
"""

        return output_md, stats_md

    except Exception as e:
        import traceback
        return f"❌ Generation failed:\n```\n{traceback.format_exc()}\n```", ""
788
  def run_phoenix_experiment(model_url, use_hierarchical, convert_attention, sequence_length, gpu_type):
789
  """Run PHOENIX experiment"""
790
  try:
 
972
  [convert_url, convert_hierarchical, convert_gpu],
973
  [gr.State(), convert_output])
974
 
975
    # --- "Text Generation" tab: drive generate_text_phoenix from the UI ---
    with gr.Tab("💬 Text Generation (NEW!)"):
        gr.Markdown("""
        ### PHOENIX 텍스트 생성
        
        변환된 모델로 실제 텍스트를 생성합니다.
        """)
        
        with gr.Row():
            with gr.Column(scale=1):
                # Model selection and conversion switches; defaults mirror
                # the conversion tab (hierarchical on, conversion enabled).
                gen_model_url = gr.Textbox(label="🔗 Model URL", value=DEFAULT_MODEL)
                gen_hierarchical = gr.Checkbox(value=True, label="Hierarchical")
                gen_convert = gr.Checkbox(value=True, label="Enable Conversion")
                
                gen_prompt = gr.Textbox(
                    label="📝 Input Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="The future of AI is"
                )
                
                # Sampling controls: token budget (16–256) and temperature.
                gen_max_tokens = gr.Slider(16, 256, 64, step=16, label="Max New Tokens")
                gen_temperature = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
                
                gen_btn = gr.Button("🚀 Generate Text", variant="primary")
            
            with gr.Column(scale=2):
                # Output panes: generated text plus run statistics, both
                # rendered from the markdown returned by the handler.
                gen_output = gr.Markdown(label="Generated Text")
                gen_stats = gr.Markdown(label="Statistics")
        
        # Wire the button to the generation handler; input order must match
        # generate_text_phoenix's parameter order.
        gen_btn.click(
            fn=generate_text_phoenix,
            inputs=[gen_model_url, gen_hierarchical, gen_convert, gen_prompt,
                    gen_max_tokens, gen_temperature],
            outputs=[gen_output, gen_stats]
        )
1011
  with gr.Tab("🧪 Experiment"):
1012
  with gr.Row():
1013
  with gr.Column(scale=1):