manueldeprada HF Staff commited on
Commit
eff00e5
·
verified ·
1 Parent(s): b60e744

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +1 -1
custom_generate/generate.py CHANGED
@@ -208,7 +208,7 @@ def _group_beam_search(
208
 
209
  # define beam scorer
210
  beam_scorer = BeamSearchScorer(
211
- batch_size=input_ids.shape[0],
212
  num_beams=generation_config.num_beams,
213
  device=input_ids.device,
214
  length_penalty=generation_config.length_penalty,
 
208
 
209
  # define beam scorer
210
  beam_scorer = BeamSearchScorer(
211
+ batch_size=input_ids.shape[0] // generation_config.num_beams,
212
  num_beams=generation_config.num_beams,
213
  device=input_ids.device,
214
  length_penalty=generation_config.length_penalty,