seawolf2357 committed
Commit ec1f612 · verified · 1 Parent(s): f42a5e2

Update app.py

Files changed (1):
  1. app.py +85 -38
app.py CHANGED
@@ -734,63 +734,110 @@ def generate_text_phoenix(
     start_time = time.time()
     generated_ids = []

+    model.eval()  # ✅ Set to eval mode
+
     with torch.no_grad():
         for step in range(max_new_tokens):
-            # Forward pass (now with lm_head)
-            outputs = model(input_ids=input_ids)
-
-            # Get logits from lm_head
-            logits = outputs.logits[:, -1, :]  # [B, vocab_size]
-
-            # Temperature sampling
-            if temperature > 0:
-                probs = F.softmax(logits / temperature, dim=-1)
-                next_token = torch.multinomial(probs, num_samples=1)
-            else:
-                next_token = logits.argmax(dim=-1, keepdim=True)
-
-            # Append
-            generated_ids.append(next_token.item())
-            input_ids = torch.cat([input_ids, next_token], dim=1)
-
-            # Stop at EOS
-            if next_token.item() == tokenizer.eos_token_id:
-                print(f"  Stopped at EOS token")
+            try:
+                # Forward pass (now with lm_head)
+                outputs = model(input_ids=input_ids)
+
+                # Get logits from lm_head
+                logits = outputs.logits[:, -1, :]  # [B, vocab_size]
+
+                # Clamp logits to prevent numerical issues
+                logits = torch.clamp(logits, min=-100, max=100)
+
+                # Temperature sampling
+                if temperature > 0.01:
+                    logits = logits / temperature
+                    probs = F.softmax(logits, dim=-1)
+
+                    # Check for NaN/Inf
+                    if torch.isnan(probs).any() or torch.isinf(probs).any():
+                        print(f"  ⚠️ NaN/Inf detected at step {step}, using greedy")
+                        next_token = logits.argmax(dim=-1, keepdim=True)
+                    else:
+                        # ✅ Add small epsilon to avoid zero probabilities
+                        probs = probs + 1e-10
+                        probs = probs / probs.sum(dim=-1, keepdim=True)
+                        next_token = torch.multinomial(probs, num_samples=1)
+                else:
+                    next_token = logits.argmax(dim=-1, keepdim=True)
+
+                next_token_id = next_token.item()
+
+                # ✅ Validate token range
+                if next_token_id < 0 or next_token_id >= model.config.vocab_size:
+                    print(f"  ⚠️ Invalid token {next_token_id}, stopping")
+                    break
+
+                # Append
+                generated_ids.append(next_token_id)
+                input_ids = torch.cat([input_ids, next_token], dim=1)
+
+                # ✅ Limit max sequence length
+                if input_ids.shape[1] > 2048:
+                    print(f"  ⚠️ Max sequence length reached, stopping")
+                    break
+
+                # Stop at EOS
+                if next_token_id == tokenizer.eos_token_id:
+                    print(f"  ✅ Stopped at EOS token")
+                    break
+
+                # Progress
+                if (step + 1) % 10 == 0:
+                    print(f"  Generated {step + 1}/{max_new_tokens} tokens...")
+
+            except RuntimeError as e:
+                print(f"  ❌ Runtime error at step {step}: {e}")
+                if "CUDA" in str(e):
+                    print(f"  Stopping generation due to CUDA error")
+                break
+            except Exception as e:
+                print(f"  ❌ Error at step {step}: {e}")
                 break
-
-            # Progress
-            if (step + 1) % 10 == 0:
-                print(f"  Generated {step + 1}/{max_new_tokens} tokens...")

     elapsed = time.time() - start_time

-    # 5. Decode
-    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
-    full_text = prompt + generated_text
-
-    # 6. Result
+    # 6. Decode
+    if len(generated_ids) == 0:
+        generated_text = "[No tokens generated]"
+        full_text = prompt
+    else:
+        try:
+            generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+            full_text = prompt + " " + generated_text
+        except Exception as e:
+            generated_text = f"[Decode error: {e}]"
+            full_text = prompt
+
+    # 7. Result
     output_md = f"""
 ## 📝 Generated Text

-**Prompt**: {prompt}
-
----
-
-**Generated**:
+**Prompt**:
+```
+{prompt}
+```

+**Generated** ({len(generated_ids)} tokens):
+```
 {generated_text}
-
----
+```

 **Full Text**:
-
+```
 {full_text}
+```
 """

+    initial_tokens = input_ids.shape[1] - len(generated_ids)
     stats_md = f"""
 ## 📊 Generation Statistics

-- **Input tokens**: {input_ids.shape[1] - len(generated_ids)}
+- **Input tokens**: {initial_tokens}
 - **Generated tokens**: {len(generated_ids)}
 - **Total tokens**: {input_ids.shape[1]}
 - **Time**: {elapsed:.2f}s
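
The heart of this change is the numerically guarded sampling step. Below is a minimal standalone sketch of that logic under the same assumptions as the commit (logits clamped to ±100, greedy fallback below temperature 0.01 or when NaN/Inf appears, and epsilon renormalization before `torch.multinomial`); the helper name `sample_next_token` is illustrative and does not exist in app.py.

```python
import torch
import torch.nn.functional as F


def sample_next_token(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Pick the next token id from [B, vocab_size] logits with numerical guards.

    Hypothetical helper mirroring the guards added in this commit; not part of app.py.
    """
    # Clamp logits so softmax cannot overflow.
    logits = torch.clamp(logits, min=-100, max=100)

    # Treat very small temperatures as greedy decoding.
    if temperature <= 0.01:
        return logits.argmax(dim=-1, keepdim=True)

    probs = F.softmax(logits / temperature, dim=-1)

    # Fall back to greedy if the distribution is corrupted.
    if torch.isnan(probs).any() or torch.isinf(probs).any():
        return logits.argmax(dim=-1, keepdim=True)

    # Add a small epsilon and renormalize so multinomial never sees a zero-mass row.
    probs = probs + 1e-10
    probs = probs / probs.sum(dim=-1, keepdim=True)
    return torch.multinomial(probs, num_samples=1)


if __name__ == "__main__":
    # Example: sample from random logits for a batch of 1 and a 32k vocabulary.
    fake_logits = torch.randn(1, 32000)
    print(sample_next_token(fake_logits, temperature=0.7))
```

Factoring the guard into a helper like this would keep the per-step loop in `generate_text_phoenix` focused on appending tokens and checking the stop conditions (invalid token id, maximum sequence length, EOS).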