Allow Output subclasses in contrastive search

#4
Files changed (1) hide show
  1. custom_generate/generate.py +6 -3
custom_generate/generate.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import logging
2
  from typing import TYPE_CHECKING, Optional, Union
3
 
@@ -14,7 +15,6 @@ from transformers.generation.utils import (
14
  GenerateNonBeamOutput,
15
  GenerationMixin,
16
  )
17
- from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
18
  from transformers.utils import ModelOutput
19
 
20
 
@@ -414,7 +414,8 @@ def _contrastive_search(
414
  for layer in outputs.decoder_attentions:
415
  layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
416
  next_step_decoder_attentions += (layer,)
417
- outputs = Seq2SeqLMOutput(
 
418
  past_key_values=next_past_key_values,
419
  decoder_hidden_states=next_decoder_hidden_states,
420
  decoder_attentions=next_step_decoder_attentions or None,
@@ -426,11 +427,13 @@ def _contrastive_search(
426
  for layer in outputs.attentions:
427
  layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
428
  next_step_attentions += (layer,)
429
- outputs = CausalLMOutputWithPast(
 
430
  past_key_values=next_past_key_values,
431
  hidden_states=next_decoder_hidden_states,
432
  attentions=next_step_attentions or None,
433
  )
 
434
  # contrastive_search main logic end
435
 
436
  # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
 
1
+ from dataclasses import replace
2
  import logging
3
  from typing import TYPE_CHECKING, Optional, Union
4
 
 
15
  GenerateNonBeamOutput,
16
  GenerationMixin,
17
  )
 
18
  from transformers.utils import ModelOutput
19
 
20
 
 
414
  for layer in outputs.decoder_attentions:
415
  layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
416
  next_step_decoder_attentions += (layer,)
417
+ outputs = replace(
418
+ outputs,
419
  past_key_values=next_past_key_values,
420
  decoder_hidden_states=next_decoder_hidden_states,
421
  decoder_attentions=next_step_decoder_attentions or None,
 
427
  for layer in outputs.attentions:
428
  layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
429
  next_step_attentions += (layer,)
430
+ outputs = replace(
431
+ outputs,
432
  past_key_values=next_past_key_values,
433
  hidden_states=next_decoder_hidden_states,
434
  attentions=next_step_attentions or None,
435
  )
436
+
437
  # contrastive_search main logic end
438
 
439
  # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping