Upload folder using huggingface_hub
Browse files
custom_generate/generate.py
CHANGED
|
@@ -208,7 +208,7 @@ def _group_beam_search(
|
|
| 208 |
|
| 209 |
# define beam scorer
|
| 210 |
beam_scorer = BeamSearchScorer(
|
| 211 |
-
batch_size=input_ids.shape[0],
|
| 212 |
num_beams=generation_config.num_beams,
|
| 213 |
device=input_ids.device,
|
| 214 |
length_penalty=generation_config.length_penalty,
|
|
|
|
| 208 |
|
| 209 |
# define beam scorer
|
| 210 |
beam_scorer = BeamSearchScorer(
|
| 211 |
+
batch_size=input_ids.shape[0] // generation_config.num_beams,
|
| 212 |
num_beams=generation_config.num_beams,
|
| 213 |
device=input_ids.device,
|
| 214 |
length_penalty=generation_config.length_penalty,
|