downcast the scores after upcasting to prevent runtime errors (#1)
Browse files- downcast the scores after upcasting to prevent runtime errors (595f9946de2ee21bf6f0c7e489b9b816aad7ca14)
Co-authored-by: Anton Morgunov <[email protected]>
custom_generate/generate.py
CHANGED
|
@@ -379,7 +379,8 @@ def _group_beam_search(
|
|
| 379 |
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
| 380 |
|
| 381 |
if output_scores:
|
| 382 |
-
processed_score[batch_group_indices] = next_token_scores_processed
|
|
|
|
| 383 |
|
| 384 |
# reshape for beam search
|
| 385 |
next_token_scores = next_token_scores.view(
|
|
|
|
| 379 |
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
| 380 |
|
| 381 |
if output_scores:
|
| 382 |
+
processed_score[batch_group_indices] = next_token_scores_processed.to(processed_score.dtype)
|
| 383 |
+
|
| 384 |
|
| 385 |
# reshape for beam search
|
| 386 |
next_token_scores = next_token_scores.view(
|