allow Output subclasses in contrastive search #4
by jood-canva · opened
custom_generate/generate.py
CHANGED
@@ -1,3 +1,4 @@
+from dataclasses import replace
 import logging
 from typing import TYPE_CHECKING, Optional, Union
 
@@ -14,7 +15,6 @@ from transformers.generation.utils import (
     GenerateNonBeamOutput,
     GenerationMixin,
 )
-from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
 from transformers.utils import ModelOutput
 
 
@@ -414,7 +414,8 @@ def _contrastive_search(
                 for layer in outputs.decoder_attentions:
                     layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_decoder_attentions += (layer,)
-            outputs = Seq2SeqLMOutput(
+            outputs = replace(
+                outputs,
                 past_key_values=next_past_key_values,
                 decoder_hidden_states=next_decoder_hidden_states,
                 decoder_attentions=next_step_decoder_attentions or None,
@@ -426,11 +427,13 @@ def _contrastive_search(
                 for layer in outputs.attentions:
                     layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_attentions += (layer,)
-            outputs = CausalLMOutputWithPast(
+            outputs = replace(
+                outputs,
                 past_key_values=next_past_key_values,
                 hidden_states=next_decoder_hidden_states,
                 attentions=next_step_attentions or None,
             )
+
         # contrastive_search main logic end
 
         # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
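Why the change works: dataclasses.replace rebuilds the object via type(outputs), so a model that returns a subclass of CausalLMOutputWithPast or Seq2SeqLMOutput keeps that subclass (and any extra fields it defines) across contrastive-search steps, whereas the hard-coded base-class constructors coerced every output to the base class. A minimal, self-contained sketch of the difference, with plain dataclasses standing in for transformers' ModelOutput classes (MyModelOutput and router_logits are illustrative, not from the PR):

from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class CausalLMOutputWithPast:
    # Stand-in for transformers.modeling_outputs.CausalLMOutputWithPast.
    past_key_values: Optional[object] = None
    hidden_states: Optional[tuple] = None
    attentions: Optional[tuple] = None


@dataclass
class MyModelOutput(CausalLMOutputWithPast):
    # Hypothetical user-defined Output subclass carrying an extra field.
    router_logits: Optional[tuple] = None


outputs = MyModelOutput(past_key_values="cache@step0", router_logits=("aux",))

# Before this PR: rebuilding through the base-class constructor loses the
# subclass and silently drops router_logits.
rebuilt = CausalLMOutputWithPast(past_key_values="cache@step1")
assert type(rebuilt) is CausalLMOutputWithPast

# After this PR: dataclasses.replace constructs type(outputs), so the subclass
# and its extra fields survive; only the named fields are updated.
updated = replace(outputs, past_key_values="cache@step1")
assert type(updated) is MyModelOutput
assert updated.router_logits == ("aux",)

One behavioral difference worth noting in review: unlike the old constructors, replace carries over every field that is not explicitly overwritten (e.g. logits), instead of resetting it to None.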