joaogante HF Staff ischemist commited on
Commit
a52f0cd
·
verified ·
1 Parent(s): 98da73e

downcast the scores after upcasting to prevent runtime errors (#1)

Browse files

- downcast the scores after upcasting to prevent runtime errors (595f9946de2ee21bf6f0c7e489b9b816aad7ca14)


Co-authored-by: Anton Morgunov <[email protected]>

Files changed (1) hide show
  1. custom_generate/generate.py +2 -1
custom_generate/generate.py CHANGED
@@ -379,7 +379,8 @@ def _group_beam_search(
379
  next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
380
 
381
  if output_scores:
382
- processed_score[batch_group_indices] = next_token_scores_processed
 
383
 
384
  # reshape for beam search
385
  next_token_scores = next_token_scores.view(
 
379
  next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
380
 
381
  if output_scores:
382
+ processed_score[batch_group_indices] = next_token_scores_processed.to(processed_score.dtype)
383
+
384
 
385
  # reshape for beam search
386
  next_token_scores = next_token_scores.view(