diff --git "a/indextts/gpt/transformers_generation_utils.py" "b/indextts/gpt/transformers_generation_utils.py" deleted file mode 100755--- "a/indextts/gpt/transformers_generation_utils.py" +++ /dev/null @@ -1,4747 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import copy -import inspect -import warnings -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.distributed as dist -from torch import nn -from torch.nn import functional as F - -from transformers.cache_utils import ( - Cache, - DynamicCache, - EncoderDecoderCache, - OffloadedCache, - QuantizedCacheConfig, - StaticCache, -) -from transformers.configuration_utils import PretrainedConfig -from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled -from transformers.integrations.fsdp import is_fsdp_managed_module -from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput -from transformers.pytorch_utils import isin_mps_friendly -from transformers.tokenization_utils import ExtensionsTrie -from transformers.utils import ( - ModelOutput, - is_accelerate_available, - is_hqq_available, - is_optimum_quanto_available, - # is_quanto_available, - is_torchdynamo_compiling, - logging, -) -from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint -from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer -from transformers.generation.candidate_generator import ( - AssistedCandidateGenerator, - AssistedCandidateGeneratorDifferentTokenizers, - CandidateGenerator, - PromptLookupCandidateGenerator, - _crop_past_key_values, - _prepare_attention_mask, - _prepare_token_type_ids, -) -from transformers.generation.configuration_utils import ( - NEED_SETUP_CACHE_CLASSES_MAPPING, - QUANT_BACKEND_CLASSES_MAPPING, - GenerationConfig, - GenerationMode, -) -from transformers.generation.logits_process import ( - EncoderNoRepeatNGramLogitsProcessor, - EncoderRepetitionPenaltyLogitsProcessor, - EpsilonLogitsWarper, - EtaLogitsWarper, - ExponentialDecayLengthPenalty, - ForcedBOSTokenLogitsProcessor, - ForcedEOSTokenLogitsProcessor, - HammingDiversityLogitsProcessor, - InfNanRemoveLogitsProcessor, - LogitNormalization, - LogitsProcessorList, - MinLengthLogitsProcessor, - MinNewTokensLengthLogitsProcessor, - MinPLogitsWarper, - NoBadWordsLogitsProcessor, - NoRepeatNGramLogitsProcessor, - PrefixConstrainedLogitsProcessor, - RepetitionPenaltyLogitsProcessor, - SequenceBiasLogitsProcessor, - SuppressTokensAtBeginLogitsProcessor, - SuppressTokensLogitsProcessor, - TemperatureLogitsWarper, - TopKLogitsWarper, - TopPLogitsWarper, - TypicalLogitsWarper, - UnbatchedClassifierFreeGuidanceLogitsProcessor, -) -from 
transformers.generation.stopping_criteria import ( - ConfidenceCriteria, - EosTokenCriteria, - MaxLengthCriteria, - MaxTimeCriteria, - StoppingCriteria, - StoppingCriteriaList, - StopStringCriteria, -) - - -if TYPE_CHECKING: - from transformers.modeling_utils import PreTrainedModel - from transformers.tokenization_utils_base import PreTrainedTokenizerBase - from transformers.generation.streamers import BaseStreamer - -logger = logging.get_logger(__name__) - -if is_accelerate_available(): - from accelerate.hooks import AlignDevicesHook, add_hook_to_module - - -@dataclass -class GenerateDecoderOnlyOutput(ModelOutput): - """ - Outputs of decoder-only generation models, when using non-beam methods. - - Args: - sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for - each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): - Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for - each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): - Returns the model cache, used to speed up decoding. Different models have a different cache format, check - the model's documentation. Usually, a [`~cache_utils.Cache`] instance. - """ - - sequences: torch.LongTensor = None - scores: Optional[Tuple[torch.FloatTensor]] = None - logits: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None - - -@dataclass -class GenerateEncoderDecoderOutput(ModelOutput): - """ - Outputs of encoder-decoder generation models, when using non-beam methods. - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. 
- scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for - each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): - Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for - each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, - sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Returns the model cache, used to speed up decoding. Different models have a different cache format, check - the model's documentation. Usually, a [`~cache_utils.Cache`] instance. - """ - - sequences: torch.LongTensor = None - scores: Optional[Tuple[torch.FloatTensor]] = None - logits: Optional[Tuple[torch.FloatTensor]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None - - -@dataclass -class GenerateBeamDecoderOnlyOutput(ModelOutput): - """ - Outputs of decoder-only generation models, when using beam methods. - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. 
- sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): - Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting - of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), - with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): - Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for - each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): - Beam indices of generated token id at each generation step. `torch.LongTensor` of shape - `(batch_size*num_return_sequences, sequence_length)`. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): - Returns the model cache, used to speed up decoding. Different models have a different cache format, check - the model's documentation. Usually, a [`~cache_utils.Cache`] instance. - """ - - sequences: torch.LongTensor = None - sequences_scores: Optional[torch.FloatTensor] = None - scores: Optional[Tuple[torch.FloatTensor]] = None - logits: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[torch.LongTensor] = None - attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None - - -@dataclass -class GenerateBeamEncoderDecoderOutput(ModelOutput): - """ - Outputs of encoder-decoder generation models, when using beam methods. - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): - Beam transition scores for each vocabulary token at each generation step. 
Beam transition scores consisting - of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), - with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): - Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for - each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): - Beam indices of generated token id at each generation step. `torch.LongTensor` of shape - `(batch_size*num_return_sequences, sequence_length)`. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, - sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, - sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): - Returns the model cache, used to speed up decoding. Different models have a different cache format, check - the model's documentation. Usually, a [`~cache_utils.Cache`] instance. 
- """ - - sequences: torch.LongTensor = None - sequences_scores: Optional[torch.FloatTensor] = None - scores: Optional[Tuple[torch.FloatTensor]] = None - logits: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[torch.LongTensor] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None - - -# TODO (joao): remove the equivalent classes and typing shortcuts below in v5 -# Equivalent classes (kept for retrocompatibility purposes) -GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput -ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput -SampleDecoderOnlyOutput = GenerateDecoderOnlyOutput - -ContrastiveSearchEncoderDecoderOutput = GenerateEncoderDecoderOutput -GreedySearchEncoderDecoderOutput = GenerateEncoderDecoderOutput -SampleEncoderDecoderOutput = GenerateEncoderDecoderOutput - -BeamSearchDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput -BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput - -BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput -BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput - -GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] -SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] -BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] -BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] -ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput] - -# Typing shortcuts -GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput] -GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput] -GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput] - - -class GenerationMixin: - """ - A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. - - The class exposes [`~generation.GenerationMixin.generate`], which can be used for: - - *greedy decoding* if `num_beams=1` and `do_sample=False` - - *contrastive search* if `penalty_alpha>0` and `top_k>1` - - *multinomial sampling* if `num_beams=1` and `do_sample=True` - - *beam-search decoding* if `num_beams>1` and `do_sample=False` - - *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True` - - *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1` - - *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None` - - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()` - - To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). - """ - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[Cache] = None, - attention_mask: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ): - """ - Prepare the model inputs for generation. 
In includes operations like computing the 4D attention mask or - slicing inputs given the existing cache. - - See the forward pass in the model documentation for expected arguments (different models might have different - requirements for e.g. `past_key_values`). This function should work as is for most LLMs. - """ - - # 1. Handle BC: - model_inputs = {} - # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`) - if self._supports_cache_class: - model_inputs["cache_position"] = cache_position - # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this - # function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly - # (this alternative is not as robust as calling `generate` and letting it create `cache_position`) - elif cache_position is None: - past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device) - - # 2. Generic cache-dependent input preparation - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case - if past_key_values is not None: - model_inputs["past_key_values"] = past_key_values - if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 or Exception 3 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - - # 3. Prepare base model inputs - input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if not self.config.is_encoder_decoder: - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs[input_ids_key] = None - model_inputs["inputs_embeds"] = inputs_embeds - else: - # `clone` calls in this function ensure a consistent stride. See #32227 - model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) - model_inputs["inputs_embeds"] = None - else: - model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) - - # 4. Create missing `position_ids` on the fly - if ( - attention_mask is not None - and kwargs.get("position_ids") is None - and "position_ids" in set(inspect.signature(self.forward).parameters.keys()) - ): - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - kwargs["position_ids"] = position_ids # placed in kwargs for further processing (see below) - - # 5. Slice model inputs if it's an input that should have the same length as `input_ids` - for model_input_name in ["position_ids", "token_type_ids"]: - model_input = kwargs.get(model_input_name) - if model_input is not None: - if past_key_values: - model_input = model_input[:, -input_ids.shape[1] :] - model_input = model_input.clone(memory_format=torch.contiguous_format) - model_inputs[model_input_name] = model_input - - # 6. 
Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass) - if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs[input_ids_key].shape - device = model_inputs[input_ids_key].device - - # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create - # the 4D causal mask exists, it should be present in the base model (XXXModel class). - base_model = getattr(self, self.base_model_prefix, None) - if base_model is None: - causal_mask_creation_function = getattr( - self, "_prepare_4d_causal_attention_mask_with_cache_position", None - ) - else: - causal_mask_creation_function = getattr( - base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None - ) - if causal_mask_creation_function is None: - logger.warning_once( - f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method " - "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're " - "writing code, see Llama for an example implementation. If you're a user, please report this " - "issue on GitHub." - ) - else: - attention_mask = causal_mask_creation_function( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - config=self.config, - past_key_values=past_key_values, - ) - if attention_mask is not None: - model_inputs["attention_mask"] = attention_mask - - # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`). - for key, value in kwargs.items(): - if key not in model_inputs: - model_inputs[key] = value - - # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples) - model_inputs.pop("labels", None) - return model_inputs - - def _prepare_model_inputs( - self, - inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[torch.Tensor] = None, - model_kwargs: Optional[Dict[str, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: - """ - This function extracts the model-specific `inputs` for generation. - """ - # 1. retrieve all kwargs that are non-None or non-model input related. - # some encoder-decoder models have different names for model and encoder - if ( - self.config.is_encoder_decoder - and hasattr(self, "encoder") - and self.encoder.main_input_name != self.main_input_name - ): - input_name = self.encoder.main_input_name - else: - input_name = self.main_input_name - - model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} - - # 2. check whether model_input_name is passed as kwarg - # if yes and `inputs` is None use kwarg inputs - inputs_kwarg = model_kwargs.pop(input_name, None) - if inputs_kwarg is not None and inputs is not None: - raise ValueError( - f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. " - f"Make sure to either pass {inputs} or {input_name}=..." - ) - elif inputs_kwarg is not None: - inputs = inputs_kwarg - - # 3. In the presence of `inputs_embeds` for text models: - # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model - # doesn't have its forwarding implemented. 
`inputs_embeds` is kept in `model_kwargs` and can coexist with - # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`) - # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and - # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states. - if input_name == "input_ids" and "inputs_embeds" in model_kwargs: - if not self.config.is_encoder_decoder: - has_inputs_embeds_forwarding = "inputs_embeds" in set( - inspect.signature(self.prepare_inputs_for_generation).parameters.keys() - ) - if not has_inputs_embeds_forwarding: - raise ValueError( - f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} " - "doesn't have its forwarding implemented. See the GPT2 implementation for an example " - "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!" - ) - # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of - # the attention mask) can rely on the actual model input. - model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( - inputs, bos_token_id, model_kwargs=model_kwargs - ) - else: - if inputs is not None: - raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") - inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" - - # 4. if `inputs` is still None, try to create `input_ids` from BOS token - inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) - return inputs, input_name, model_kwargs - - def _maybe_initialize_input_ids_for_generation( - self, - inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[torch.Tensor] = None, - model_kwargs: Optional[Dict[str, torch.Tensor]] = None, - ) -> torch.LongTensor: - """Initializes input ids for generation, if necessary.""" - if inputs is not None: - return inputs - - encoder_outputs = model_kwargs.get("encoder_outputs") - if self.config.is_encoder_decoder and encoder_outputs is not None: - # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding - shape = encoder_outputs.last_hidden_state.size()[:-1] - return torch.ones(shape, dtype=torch.long, device=self.device) * -100 - - # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with - # soft-prompting or in multimodal implementations built on top of decoder-only language models. 
- batch_size = 1 - for value in model_kwargs.values(): - if isinstance(value, torch.Tensor): - batch_size = value.shape[0] - break - - if "inputs_embeds" in model_kwargs: - return torch.ones((batch_size, 0), dtype=torch.long, device=self.device) - - if bos_token_id is None: - raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") - - return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id - - def _prepare_attention_mask_for_generation( - self, - inputs: torch.Tensor, - pad_token_id: Optional[torch.Tensor], - eos_token_id: Optional[torch.Tensor], - ) -> torch.LongTensor: - # No information for attention mask inference -> return default attention mask - default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) - if pad_token_id is None: - return default_attention_mask - - is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] - if not is_input_ids: - return default_attention_mask - - is_pad_token_in_inputs = (pad_token_id is not None) and ( - isin_mps_friendly(elements=inputs, test_elements=pad_token_id).any() - ) - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~( - isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any() - ) - can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id - attention_mask_from_padding = inputs.ne(pad_token_id).long() - - attention_mask = ( - attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask - ) - return attention_mask - - def _prepare_encoder_decoder_kwargs_for_generation( - self, - inputs_tensor: torch.Tensor, - model_kwargs, - model_input_name: Optional[str], - generation_config: GenerationConfig, - ) -> Dict[str, Any]: - # 1. get encoder - encoder = self.get_encoder() - # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device - # as the inputs. - if hasattr(self, "hf_device_map"): - if hasattr(encoder, "_hf_hook"): - encoder._hf_hook.io_same_device = True - else: - add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True)) - - # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config. - irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] - encoder_kwargs = { - argument: value - for argument, value in model_kwargs.items() - if not any(argument.startswith(p) for p in irrelevant_prefix) - } - encoder_signature = set(inspect.signature(encoder.forward).parameters) - encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature - if not encoder_accepts_wildcard: - encoder_kwargs = { - argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature - } - encoder_kwargs["output_attentions"] = generation_config.output_attentions - encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states - - # 3. 
make sure that encoder returns `ModelOutput` - model_input_name = model_input_name if model_input_name is not None else self.main_input_name - encoder_kwargs["return_dict"] = True - encoder_kwargs[model_input_name] = inputs_tensor - model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) # type: ignore - - return model_kwargs - - def _prepare_decoder_input_ids_for_generation( - self, - batch_size: int, - model_input_name: str, - model_kwargs: Dict[str, torch.Tensor], - decoder_start_token_id: torch.Tensor, - device: torch.device = None, - ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: - """Prepares `decoder_input_ids` for generation with encoder-decoder models""" - # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, - # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. - if model_kwargs is not None and "decoder_input_ids" in model_kwargs: - decoder_input_ids = model_kwargs.pop("decoder_input_ids") - elif "input_ids" in model_kwargs and model_input_name != "input_ids": - decoder_input_ids = model_kwargs.pop("input_ids") - else: - decoder_input_ids = None - - # 2. `decoder_start_token_id` must have shape (batch_size, 1) - if device is None: - device = self.device - if decoder_start_token_id.ndim == 1: - if decoder_start_token_id.shape[0] != batch_size: - raise ValueError( - f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" - ) - decoder_start_token_id = decoder_start_token_id.view(-1, 1) - else: - decoder_start_token_id = ( - torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id - ) - - # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. - # no user input -> use decoder_start_token_id as decoder_input_ids - if decoder_input_ids is None: - decoder_input_ids = decoder_start_token_id - # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the - # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic. - # See: https://github.com/huggingface/transformers/pull/31470 - elif "donut" in self.__class__.__name__.lower() or ( - self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower() - ): - pass - elif self.config.model_type in ["whisper"]: - pass - # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust - # decoder_attention_mask if provided) - elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item(): - decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1) - if "decoder_attention_mask" in model_kwargs: - decoder_attention_mask = model_kwargs["decoder_attention_mask"] - decoder_attention_mask = torch.cat( - (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), - dim=-1, - ) - model_kwargs["decoder_attention_mask"] = decoder_attention_mask - - return decoder_input_ids, model_kwargs - - @staticmethod - def _expand_inputs_for_generation( - expand_size: int = 1, - is_encoder_decoder: bool = False, - input_ids: Optional[torch.LongTensor] = None, - **model_kwargs, - ) -> Tuple[torch.LongTensor, Dict[str, Any]]: - """Expands tensors from [batch_size, ...] 
to [batch_size * expand_size, ...]""" - # Do not call torch.repeat_interleave if expand_size is 1 because it clones - # the input tensor and thus requires more memory although no change is applied - if expand_size == 1: - return input_ids, model_kwargs - - def _expand_dict_for_generation(dict_to_expand): - for key in dict_to_expand: - if ( - key != "cache_position" - and dict_to_expand[key] is not None - and isinstance(dict_to_expand[key], torch.Tensor) - ): - dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) - return dict_to_expand - - if input_ids is not None: - input_ids = input_ids.repeat_interleave(expand_size, dim=0) - - model_kwargs = _expand_dict_for_generation(model_kwargs) - - if is_encoder_decoder: - if model_kwargs.get("encoder_outputs") is None: - raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") - model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) - - return input_ids, model_kwargs - - def _extract_past_from_model_output(self, outputs: ModelOutput): - past_key_values = None - cache_name = "past_key_values" - if "past_key_values" in outputs: - past_key_values = outputs.past_key_values - elif "mems" in outputs: - past_key_values = outputs.mems - elif "past_buckets_states" in outputs: - past_key_values = outputs.past_buckets_states - elif "cache_params" in outputs: - past_key_values = outputs.cache_params - cache_name = "cache_params" - - return cache_name, past_key_values - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - num_new_tokens: int = 1, - ) -> Dict[str, Any]: - # update past_key_values keeping its naming used in model code - cache_name, cache = self._extract_past_from_model_output(outputs) - model_kwargs[cache_name] = cache - if getattr(outputs, "state", None) is not None: - model_kwargs["state"] = outputs.state - - # update token_type_ids with last value - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) - - if not is_encoder_decoder: - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - else: - # update decoder attention mask - if "decoder_attention_mask" in model_kwargs: - decoder_attention_mask = model_kwargs["decoder_attention_mask"] - model_kwargs["decoder_attention_mask"] = torch.cat( - [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], - dim=-1, - ) - - if model_kwargs.get("use_cache", True): - model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens - else: - past_positions = model_kwargs.pop("cache_position") - new_positions = torch.arange( - past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype - ).to(past_positions.device) - model_kwargs["cache_position"] = torch.cat((past_positions, new_positions)) - return model_kwargs - - def _reorder_cache(self, past_key_values, beam_idx): - raise NotImplementedError( - f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to" - f" enable beam search for {self.__class__}" - ) - - def _get_candidate_generator( - 
self, - generation_config: GenerationConfig, - input_ids: torch.LongTensor, - inputs_tensor: torch.Tensor, - assistant_model: "PreTrainedModel", - logits_processor: LogitsProcessorList, - target_tokenizer: "PreTrainedTokenizerBase", - assistant_tokenizer: "PreTrainedTokenizerBase", - model_kwargs: Dict, - ) -> CandidateGenerator: - """ - Returns the candidate generator to be used in `assisted_generation` - """ - different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer)) - - if generation_config.prompt_lookup_num_tokens is not None: - candidate_generator = PromptLookupCandidateGenerator( - eos_token_id=generation_config._eos_token_tensor, - num_output_tokens=generation_config.prompt_lookup_num_tokens, - max_matching_ngram_size=generation_config.max_matching_ngram_size, - max_length=generation_config.max_length, - ) - elif different_tokenizers: - candidate_generator = AssistedCandidateGeneratorDifferentTokenizers( - input_ids=input_ids, - assistant_model=assistant_model, - generation_config=generation_config, - model_kwargs=model_kwargs, - inputs_tensor=inputs_tensor, - logits_processor=logits_processor, - target_tokenizer=target_tokenizer, - assistant_tokenizer=assistant_tokenizer, - ) - else: - candidate_generator = AssistedCandidateGenerator( - input_ids=input_ids, - assistant_model=assistant_model, - generation_config=generation_config, - model_kwargs=model_kwargs, - inputs_tensor=inputs_tensor, - logits_processor=logits_processor, - ) - return candidate_generator - - def _get_logits_processor( - self, - generation_config: GenerationConfig, - input_ids_seq_length: int, - encoder_input_ids: torch.LongTensor, - prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], - logits_processor: Optional[LogitsProcessorList], - device: str = None, - model_kwargs: Optional[Dict[str, Any]] = None, - negative_prompt_ids: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - ) -> LogitsProcessorList: - """ - This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] - instances used to modify the scores of the language model head. 
- """ - # instantiate processors list - processors = LogitsProcessorList() - - if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1: - processors.append( - UnbatchedClassifierFreeGuidanceLogitsProcessor( - generation_config.guidance_scale, - self, - unconditional_ids=negative_prompt_ids, - unconditional_attention_mask=negative_prompt_attention_mask, - use_cache=generation_config.use_cache, - ) - ) - if generation_config.sequence_bias is not None: - processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias)) - - if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0: - processors.append( - HammingDiversityLogitsProcessor( - diversity_penalty=generation_config.diversity_penalty, - num_beams=generation_config.num_beams, - num_beam_groups=generation_config.num_beam_groups, - ) - ) - if ( - generation_config.encoder_repetition_penalty is not None - and generation_config.encoder_repetition_penalty != 1.0 - ): - if len(encoder_input_ids.shape) == 2: - processors.append( - EncoderRepetitionPenaltyLogitsProcessor( - penalty=generation_config.encoder_repetition_penalty, - encoder_input_ids=encoder_input_ids, - ) - ) - else: - warnings.warn( - "Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to " - "`generate`, ignoring the argument.", - UserWarning, - ) - if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: - processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) - if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: - processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) - if ( - generation_config.encoder_no_repeat_ngram_size is not None - and generation_config.encoder_no_repeat_ngram_size > 0 - ): - if len(encoder_input_ids.shape) == 2: - processors.append( - EncoderNoRepeatNGramLogitsProcessor( - generation_config.encoder_no_repeat_ngram_size, - encoder_input_ids, - ) - ) - else: - warnings.warn( - "Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to " - "`generate`, ignoring the argument.", - UserWarning, - ) - if generation_config.bad_words_ids is not None: - processors.append( - NoBadWordsLogitsProcessor( - generation_config.bad_words_ids, - generation_config._eos_token_tensor, - ) - ) - if ( - generation_config.min_length is not None - and generation_config._eos_token_tensor is not None - and generation_config.min_length > 0 - ): - processors.append( - MinLengthLogitsProcessor( - generation_config.min_length, - generation_config._eos_token_tensor, - device=device, - ) - ) - if ( - generation_config.min_new_tokens is not None - and generation_config._eos_token_tensor is not None - and generation_config.min_new_tokens > 0 - ): - processors.append( - MinNewTokensLengthLogitsProcessor( - input_ids_seq_length, - generation_config.min_new_tokens, - generation_config._eos_token_tensor, - device=device, - ) - ) - if prefix_allowed_tokens_fn is not None: - processors.append( - PrefixConstrainedLogitsProcessor( - prefix_allowed_tokens_fn, - generation_config.num_beams // generation_config.num_beam_groups, - ) - ) - if generation_config.forced_bos_token_id is not None: - processors.append( - ForcedBOSTokenLogitsProcessor( - generation_config.forced_bos_token_id, - ) - ) - if generation_config.forced_eos_token_id is not None: - processors.append( - 
ForcedEOSTokenLogitsProcessor( - generation_config.max_length, - generation_config.forced_eos_token_id, - device=device, - ) - ) - if generation_config.remove_invalid_values is True: - processors.append(InfNanRemoveLogitsProcessor()) - if generation_config.exponential_decay_length_penalty is not None: - processors.append( - ExponentialDecayLengthPenalty( - generation_config.exponential_decay_length_penalty, - generation_config._eos_token_tensor, - input_ids_seq_length, - ) - ) - if generation_config.suppress_tokens is not None: - processors.append( - SuppressTokensLogitsProcessor( - generation_config.suppress_tokens, - device=device, - ) - ) - if generation_config.begin_suppress_tokens is not None: - begin_index = input_ids_seq_length - begin_index = ( - begin_index - if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) - else begin_index + 1 - ) - processors.append( - SuppressTokensAtBeginLogitsProcessor( - generation_config.begin_suppress_tokens, - begin_index, - device=device, - ) - ) - if generation_config.forced_decoder_ids is not None: - # TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT - raise ValueError( - "You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument " - "in favour of `input_ids` or `decoder_input_ids` respectively.", - ) - if generation_config.watermarking_config is not None: - processors.append( - generation_config.watermarking_config.construct_processor(self.config.vocab_size, device) - ) - - # TODO (joao): find a strategy to specify the order of the processors - processors = self._merge_criteria_processor_list(processors, logits_processor) - - # Processors previously known as `LogitsWarpers`, only applied with sampling strategies - if generation_config.do_sample: - # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a - # better score (i.e. 
keep len(list(generation_config._eos_token_tensor)) + 1) - if generation_config.num_beams > 1: - if isinstance(generation_config._eos_token_tensor, list): - min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 - elif isinstance(generation_config._eos_token_tensor, torch.Tensor): - min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 - else: - min_tokens_to_keep = 2 - else: - min_tokens_to_keep = 1 - - # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files - # all samplers can be found in `generation_utils_samplers.py` - if generation_config.temperature is not None and generation_config.temperature != 1.0: - processors.append(TemperatureLogitsWarper(generation_config.temperature)) - if generation_config.top_k is not None and generation_config.top_k != 0: - processors.append( - TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep) - ) - if generation_config.top_p is not None and generation_config.top_p < 1.0: - processors.append( - TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep) - ) - if generation_config.min_p is not None: - # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) - processors.append( - MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep) - ) - if generation_config.typical_p is not None and generation_config.typical_p < 1.0: - processors.append( - TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) - ) - if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: - processors.append( - EpsilonLogitsWarper( - epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep - ) - ) - if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: - processors.append( - EtaLogitsWarper( - epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device - ) - ) - - # `LogitNormalization` should always be the last logit processor, when present - if generation_config.renormalize_logits is True: - processors.append(LogitNormalization()) - return processors - - def _get_stopping_criteria( - self, - generation_config: GenerationConfig, - stopping_criteria: Optional[StoppingCriteriaList], - tokenizer: Optional["PreTrainedTokenizerBase"] = None, - **kwargs, - ) -> StoppingCriteriaList: - criteria = StoppingCriteriaList() - if generation_config.max_length is not None: - max_position_embeddings = getattr(self.config, "max_position_embeddings", None) - criteria.append( - MaxLengthCriteria( - max_length=generation_config.max_length, - max_position_embeddings=max_position_embeddings, - ) - ) - if generation_config.max_time is not None: - criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) - if generation_config.stop_strings is not None: - if tokenizer is None: - raise ValueError( - "There are one or more stop strings, either in the arguments to `generate` or in the " - "model's generation config, but we could not locate a tokenizer. When generating with " - "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`." 
- ) - criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer)) - if generation_config._eos_token_tensor is not None: - criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor)) - if ( - generation_config.is_assistant - and generation_config.assistant_confidence_threshold is not None - and generation_config.assistant_confidence_threshold > 0 - ): - criteria.append( - ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold) - ) - criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) - return criteria - - def _merge_criteria_processor_list( - self, - default_list: Union[LogitsProcessorList, StoppingCriteriaList], - custom_list: Union[LogitsProcessorList, StoppingCriteriaList], - ) -> Union[LogitsProcessorList, StoppingCriteriaList]: - if len(custom_list) == 0: - return default_list - for default in default_list: - for custom in custom_list: - if type(custom) is type(default): - object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" - raise ValueError( - f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" - f" `.generate()`, but it has already been created with the values {default}. {default} has been" - " created by passing the corresponding arguments to generate or by the model's config default" - f" values. If you just want to change the default values of {object_type} consider passing" - f" them as arguments to `.generate()` instead of using a custom {object_type}." - ) - default_list.extend(custom_list) - return default_list - - def compute_transition_scores( - self, - sequences: torch.Tensor, - scores: Tuple[torch.Tensor], - beam_indices: Optional[torch.Tensor] = None, - normalize_logits: bool = False, - ) -> torch.Tensor: - """ - Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was - used). This is a convenient method to quicky obtain the scores of the selected tokens at generation time. - - Parameters: - sequences (`torch.LongTensor`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or - shorter if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)`): - Transition scores for each vocabulary token at each generation step. Beam transition scores consisting - of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), - with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. - beam_indices (`torch.LongTensor`, *optional*): - Beam indices of generated token id at each generation step. `torch.LongTensor` of shape - `(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at - generate-time. - normalize_logits (`bool`, *optional*, defaults to `False`): - Whether to normalize the logits (which, for legacy reasons, may be unnormalized). 
- - Return: - `torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing - the transition scores (logits) - - Examples: - - ```python - >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM - >>> import numpy as np - - >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer.pad_token_id = tokenizer.eos_token_id - >>> inputs = tokenizer(["Today is"], return_tensors="pt") - - >>> # Example 1: Print the scores for each token generated with Greedy Search - >>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True) - >>> transition_scores = model.compute_transition_scores( - ... outputs.sequences, outputs.scores, normalize_logits=True - ... ) - >>> # input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for - >>> # encoder-decoder models, like BART or T5. - >>> input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1] - >>> generated_tokens = outputs.sequences[:, input_length:] - >>> for tok, score in zip(generated_tokens[0], transition_scores[0]): - ... # | token | token string | log probability | probability - ... print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}") - | 262 | the | -1.414 | 24.33% - | 1110 | day | -2.609 | 7.36% - | 618 | when | -2.010 | 13.40% - | 356 | we | -1.859 | 15.58% - | 460 | can | -2.508 | 8.14% - - >>> # Example 2: Reconstruct the sequence scores from Beam Search - >>> outputs = model.generate( - ... **inputs, - ... max_new_tokens=5, - ... num_beams=4, - ... num_return_sequences=4, - ... return_dict_in_generate=True, - ... output_scores=True, - ... ) - >>> transition_scores = model.compute_transition_scores( - ... outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False - ... ) - >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores. - >>> # Tip 1: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the - >>> # use case, you might want to recompute it with `normalize_logits=True`. - >>> # Tip 2: the output length does NOT include the input length - >>> output_length = np.sum(transition_scores.numpy() < 0, axis=1) - >>> length_penalty = model.generation_config.length_penalty - >>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty) - >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores)) - True - ```""" - # 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent - # to a beam search approach were the first (and only) beam is always selected - if beam_indices is None: - beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device) - beam_indices = beam_indices.expand(-1, len(scores)) - - # 2. reshape scores as [batch_size*vocab_size, # generation steps] with # generation steps being - # seq_len - input_length - scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1) - - # 3. Optionally normalize the logits (across the vocab dimension) - if normalize_logits: - scores = scores.reshape(-1, self.config.vocab_size, scores.shape[-1]) - scores = torch.nn.functional.log_softmax(scores, dim=1) - scores = scores.reshape(-1, scores.shape[-1]) - - # 4. 
cut beam_indices to longest beam length - beam_indices_mask = beam_indices < 0 - max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max() - beam_indices = beam_indices.clone()[:, :max_beam_length] - beam_indices_mask = beam_indices_mask[:, :max_beam_length] - - # 5. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards - beam_indices[beam_indices_mask] = 0 - - # 6. multiply beam_indices with vocab size to gather correctly from scores - beam_sequence_indices = beam_indices * self.config.vocab_size - - # 7. Define which indices contributed to scores - cut_idx = sequences.shape[-1] - max_beam_length - indices = sequences[:, cut_idx:] + beam_sequence_indices - - # 8. Compute scores - transition_scores = scores.gather(0, indices) - - # 9. Mask out transition_scores of beams that stopped early - transition_scores[beam_indices_mask] = 0 - - return transition_scores - - def _validate_model_class(self): - """ - Confirms that the model class is compatible with generation. If not, raises an exception that points to the - right class to use. - """ - # TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from - # `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can - # safely call `GenerationMixin.generate` - if not is_torchdynamo_compiling() and not self.can_generate(): - terminations_with_generation_support = [ - "ForCausalLM", - "ForConditionalGeneration", - "ForSpeechSeq2Seq", - "ForVision2Seq", - ] - raise TypeError( - f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as " - "it doesn't have a language model head. Classes that support generation often end in one of these " - f"names: {terminations_with_generation_support}." - ) - - def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer): - if assistant_model is None: - return - - if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder: - attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"] - attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check] - are_equal = all( - getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check - ) - if not are_equal: - raise ValueError( - "The main model and the assistant don't have compatible encoder-dependent input shapes. " - "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper." - ) - - doc_reference = ( - "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)" - ) - if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size: - if assistant_tokenizer is not None: - raise ValueError( - f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}." - ) - else: - if tokenizer is None or assistant_tokenizer is None: - raise ValueError( - f"The main and assistant models have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}." - ) - - def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): - """Validates model kwargs for generation.
Generate argument typos will also be caught here.""" - # If a `Cache` instance is passed, checks whether the model is compatible with it - if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class: - raise ValueError( - f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please " - "check the model documentation for supported cache formats." - ) - - # Excludes arguments that are handled before calling any model function - if self.config.is_encoder_decoder: - for key in ["decoder_input_ids"]: - model_kwargs.pop(key, None) - - unused_model_args = [] - model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) - # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If - # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;) - if "kwargs" in model_args or "model_kwargs" in model_args: - model_args |= set(inspect.signature(self.forward).parameters) - - # Encoder-Decoder models may also need Encoder arguments from `model_kwargs` - if self.config.is_encoder_decoder: - base_model = getattr(self, self.base_model_prefix, None) - - # allow encoder kwargs - encoder = getattr(self, "encoder", None) - # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`. - # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder` - # TODO: A better way to handle this. - if encoder is None and base_model is not None: - encoder = getattr(base_model, "encoder", None) - - if encoder is not None: - encoder_model_args = set(inspect.signature(encoder.forward).parameters) - model_args |= encoder_model_args - - # allow decoder kwargs - decoder = getattr(self, "decoder", None) - if decoder is None and base_model is not None: - decoder = getattr(base_model, "decoder", None) - - if decoder is not None: - decoder_model_args = set(inspect.signature(decoder.forward).parameters) - model_args |= {f"decoder_{x}" for x in decoder_model_args} - - # allow assistant_encoder_outputs to be passed if we're doing assisted generating - if "assistant_encoder_outputs" in model_kwargs: - model_args |= {"assistant_encoder_outputs"} - - for key, value in model_kwargs.items(): - if value is not None and key not in model_args: - unused_model_args.append(key) - - if unused_model_args: - raise ValueError( - f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" - " generate arguments will also show up in this list)" - ) - - def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): - """Performs validation related to the resulting generated length""" - - # Can't throw warnings/exceptions during compilation - if is_torchdynamo_compiling(): - return - - # 1. Max length warnings related to poor parameterization - if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: - # 20 is the default max_length of the generation config - warnings.warn( - f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " - "generation length. 
We recommend setting `max_new_tokens` to control the maximum length of the " - "generation.", - UserWarning, - ) - if input_ids_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - raise ValueError( - f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_length` or, better yet, setting `max_new_tokens`." - ) - - # 2. Min length warnings due to unfeasible parameter combinations - min_length_error_suffix = ( - " Generation will stop at the defined maximum length. You should decrease the minimum length and/or " - "increase the maximum length." - ) - if has_default_max_length: - min_length_error_suffix += ( - f" Note that `max_length` is set to {generation_config.max_length}, its default value." - ) - if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: - warnings.warn( - f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than" - f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, - UserWarning, - ) - if generation_config.min_new_tokens is not None: - min_length = generation_config.min_new_tokens + input_ids_length - if min_length > generation_config.max_length: - warnings.warn( - f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when " - f"added to the prompt length ({input_ids_length}), is larger than" - f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, - UserWarning, - ) - - def _prepare_generated_length( - self, - generation_config, - has_default_max_length, - has_default_min_length, - model_input_name, - input_ids_length, - inputs_tensor, - ): - """Prepared max and min length in generation configs to avoid clashes between similar attributes""" - - if generation_config.max_new_tokens is not None: - if not has_default_max_length and generation_config.max_length is not None: - logger.warning( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) - generation_config.max_length = generation_config.max_new_tokens + input_ids_length - - # if both `inputs_embeds` and `input_ids` are passed, we do not correct the length - # otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length`` - elif ( - model_input_name == "inputs_embeds" - and input_ids_length != inputs_tensor.shape[1] - and not self.config.is_encoder_decoder - ): - generation_config.max_length -= inputs_tensor.shape[1] - - # same for min length - if generation_config.min_new_tokens is not None: - if not has_default_min_length: - logger.warning( - f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(=" - f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. " - "Please refer to the documentation for more information. 
" - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) - generation_config.min_length = generation_config.min_new_tokens + input_ids_length - - elif ( - model_input_name == "inputs_embeds" - and input_ids_length != inputs_tensor.shape[1] - and not self.config.is_encoder_decoder - ): - generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0) - - return generation_config - - def _prepare_generation_config( - self, generation_config: Optional[GenerationConfig], **kwargs: Dict - ) -> Tuple[GenerationConfig, Dict]: - """ - Prepares the base generation config, then applies any generation configuration options from kwargs. This - function handles retrocompatibility with respect to configuration files. - """ - # TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400) - # replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with - # the parameterization in `fullgraph=False` so as to enable `fullgraph=True`. - - # priority: `generation_config` argument > `model.generation_config` (the default generation config) - using_model_generation_config = False - if generation_config is None: - # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, - # the following conditions must be met - # 1) the generation config must have been created from the model config (`_from_model_config` field); - # 2) the generation config must have seen no modification since its creation (the hash is the same); - # 3) there are non-default generation parameters in the model config. - # 4) the user must have set new generation parameters in the model config. - # NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation. - if ( - not is_torchdynamo_compiling() - and self.generation_config._from_model_config # 1) - and self.generation_config._original_object_hash == hash(self.generation_config) # 2) - and len(self.config._get_non_default_generation_parameters()) > 0 # 3) - ): - new_generation_config = GenerationConfig.from_model_config(self.config) - if new_generation_config != self.generation_config: # 4) - warnings.warn( - "You have modified the pretrained model configuration to control generation. This is a" - " deprecated strategy to control generation and will be removed in v5." - " Please use and modify the model generation configuration (see" - " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )", - UserWarning, - ) - self.generation_config = new_generation_config - - generation_config = self.generation_config - using_model_generation_config = True - - # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` - # will mutate the object with `.update`. 
As such, passing these arguments through `kwargs` is disabled -- an - # exception will be raised in `_validate_model_kwargs` - if not is_torchdynamo_compiling(): - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model - if not using_model_generation_config: - if generation_config.bos_token_id is None: - generation_config.bos_token_id = self.generation_config.bos_token_id - if generation_config.eos_token_id is None: - generation_config.eos_token_id = self.generation_config.eos_token_id - if generation_config.pad_token_id is None: - generation_config.pad_token_id = self.generation_config.pad_token_id - if generation_config.decoder_start_token_id is None: - generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id - else: - model_kwargs = kwargs - - return generation_config, model_kwargs - - def _get_initial_cache_position(self, input_ids, model_kwargs): - """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" - # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` - if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder: - cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 - elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder: - cache_position = ( - torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 - ) - else: - cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 - - past_length = 0 - if model_kwargs.get("past_key_values") is not None: - cache = model_kwargs["past_key_values"] - past_length = 0 - if not isinstance(cache, Cache): - past_length = cache[0][0].shape[2] - elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: - past_length = cache.get_seq_length() - - # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty, - # end-to-end compilation will yield bad results because `cache_position` will be incorrect. - if not is_torchdynamo_compiling(): - cache_position = cache_position[past_length:] - - model_kwargs["cache_position"] = cache_position - return model_kwargs - - def _get_cache( - self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs - ) -> Cache: - """ - Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a - new `generate` call requires a larger cache or uses a different batch size. - - Returns the resulting cache object. 
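-
-         Example (a minimal sketch; in practice `generate` calls this method for you when `cache_implementation`
-         is set, and the checkpoint name below is only an assumption for a model that supports
-         `cache_implementation="static"`):
-
-         ```python
-         >>> from transformers import AutoModelForCausalLM, AutoTokenizer
-
-         >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-         >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
-         >>> inputs = tokenizer("Hello", return_tensors="pt")
-
-         >>> # The first call builds the static cache; later calls reuse it unless a longer cache
-         >>> # or a different batch size is required, in which case a new one is created.
-         >>> out_1 = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
-         >>> out_2 = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
-         ```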
- """ - cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] - requires_cross_attention_cache = ( - self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None - ) - - if hasattr(self, "_cache"): - cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache - - if cache_implementation == "sliding_window": - max_cache_len = min(self.config.sliding_window, max_cache_len) - - need_new_cache = ( - not hasattr(self, "_cache") - or (not isinstance(cache_to_check, cache_cls)) - or cache_to_check.batch_size != batch_size - ) - if cache_implementation != "mamba": - need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len - - if requires_cross_attention_cache and hasattr(self, "_cache"): - need_new_cache = ( - need_new_cache - or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1] - ) - - if need_new_cache: - if hasattr(self.config, "_pre_quantization_dtype"): - cache_dtype = self.config._pre_quantization_dtype - else: - if not is_torchdynamo_compiling(): - cache_dtype = self.dtype - else: - # NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`. - # Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative - # models. May cause trobles with non-text modalities. - cache_dtype = self.get_output_embeddings().weight.dtype - - def get_layer_device_map(execution_device_map: Optional[dict] = None): - if execution_device_map is None: - return None - elif len(execution_device_map) == 1 and "" in execution_device_map: - return {idx: execution_device_map[""] for idx in range(self.config.num_hidden_layers)} - layer_device_map = {} - for layer in execution_device_map: - for idx in range(self.config.num_hidden_layers): - if f".{idx}." in f"{layer}.": - layer_device_map[idx] = execution_device_map[layer] - break - for idx in range(self.config.num_hidden_layers): - if idx not in layer_device_map: - raise RuntimeError(f"layer {idx} has not been mapped to a device.") - return layer_device_map - - execution_device_map = None - # Taken from dispatch_model from accelerate. - # This is needed here if we don't want to make changes in accelerate in order to save execution_device - # For offloaded case, we need to get the execution device, not just the device where it is offloaded - if hasattr(self, "hf_device_map"): - main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0] - execution_device_map = { - name: main_device if device in ["cpu", "disk"] else device - for name, device in self.hf_device_map.items() - } - layer_device_map = get_layer_device_map(execution_device_map) - - cache_kwargs = { - "config": self.config.get_text_config(), - "batch_size": batch_size, - "max_cache_len": max_cache_len, - "device": device, - "dtype": cache_dtype, - "layer_device_map": layer_device_map, - } - self._cache = cache_cls(**cache_kwargs) - if requires_cross_attention_cache: - encoder_kwargs = cache_kwargs.copy() - encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1] - self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs)) - else: - self._cache.reset() - return self._cache - - def _supports_default_dynamic_cache(self) -> bool: - """ - Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. 
- This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which - uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in - order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed - for `HybridMambaAttentionDynamicCache`). - """ - return ( - self._supports_cache_class - and "jamba" not in self.__class__.__name__.lower() - and "zamba" not in self.__class__.__name__.lower() - ) - - def _prepare_cache_for_generation( - self, - generation_config: GenerationConfig, - model_kwargs: Dict, - assistant_model: "PreTrainedModel", - batch_size: int, - max_cache_length: int, - device: torch.device, - ) -> bool: - """ - Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is - instantiated, writes it to `model_kwargs`, under the name expected by the model. - """ - - cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" - requires_cross_attention_cache = ( - self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None - ) - - # Quick escape route 1: if the user specifies a cache, we only need to: - # a) check for conflicting `generate` arguments - # b) convert to the new cache format (if the user passes a legacy cache and model supports it) - user_defined_cache = model_kwargs.get(cache_name) - if user_defined_cache is not None: - if generation_config.cache_implementation is not None: - raise ValueError( - f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " - "Cache object) is unsupported. Please use only one of the two." - ) - if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache(): - model_kwargs[cache_name] = ( - DynamicCache.from_legacy_cache(user_defined_cache) - if not requires_cross_attention_cache - else EncoderDecoderCache.from_legacy_cache(user_defined_cache) - ) - return - - # Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in - # `generation_config.validate()`) - if generation_config.use_cache is False: - return - - # Quick escape route 3: model that only supports legacy caches = nothing to prepare - if not self._supports_default_dynamic_cache(): - if generation_config.cache_implementation is not None: - warnings.warn( - "This model does not support `Cache` instances, it only supports the legacy cache format (tuple " - f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be " - "ignored.", - UserWarning, - ) - return - - # Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation` - - # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches, - # which is only supported in dynamic caches atm - if assistant_model is not None and generation_config.cache_implementation is not None: - logger.warning_once( - "An assistant model is provided, using a dynamic cache instead of a cache of type=" - f"'{generation_config.cache_implementation}'." - ) - generation_config.cache_implementation = None - - if generation_config.cache_implementation is not None: - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if generation_config.cache_implementation == "static" and not self._supports_static_cache: - raise ValueError( - "This model does not support `cache_implementation='static'`. 
Please check the following " - "issue: https://github.com/huggingface/transformers/issues/28981" - ) - model_kwargs[cache_name] = self._get_cache( - cache_implementation=generation_config.cache_implementation, - batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size, - max_cache_len=max_cache_length, - device=device, - model_kwargs=model_kwargs, - ) - elif generation_config.cache_implementation == "quantized": - if not self._supports_quantized_cache: - raise ValueError( - "This model does not support the quantized cache. If you want your model to support quantized " - "cache, please open an issue and tag @zucchini-nlp." - ) - - cache_config = ( - generation_config.cache_config - if generation_config.cache_config is not None - else QuantizedCacheConfig() - ) - cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] - - # if cache_config.backend == "quanto" and not (is_optimum_quanto_available() or is_quanto_available()): - if cache_config.backend == "quanto" and not is_optimum_quanto_available(): - raise ImportError( - "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. " - "Please install it with `pip install optimum-quanto`" - ) - elif cache_config.backend == "HQQ" and not is_hqq_available(): - raise ImportError( - "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " - "Please install it with `pip install hqq`" - ) - - model_kwargs[cache_name] = cache_class(cache_config) - elif generation_config.cache_implementation == "offloaded": - model_kwargs[cache_name] = OffloadedCache() - - # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that - # keeps copying the cache thus using much more memory - else: - model_kwargs[cache_name] = ( - DynamicCache() - if not requires_cross_attention_cache - else EncoderDecoderCache(DynamicCache(), DynamicCache()) - ) - - def _supports_num_logits_to_keep(self) -> bool: - """ - Return True if the current model supports the keyword argument `num_logits_to_keep` in forward() - to save memory. Checking it in this way allows us to avoid using a new model attribute. - """ - return "num_logits_to_keep" in set(inspect.signature(self.forward).parameters.keys()) - - def _prepare_special_tokens( - self, - generation_config: GenerationConfig, - kwargs_has_attention_mask: Optional[bool] = None, - device: Optional[Union[torch.device, str]] = None, - ): - """ - Prepares the special tokens for generation, overwriting the generation config with their processed versions - converted to tensor. - - Note that `generation_config` is changed in place and stops being serializable after this method is called. - That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the - function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
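-
-         Example (a minimal sketch of the copy-first pattern described above; `model` is assumed to be any
-         generative model with a populated generation config):
-
-         ```python
-         >>> import copy
-
-         >>> generation_config = copy.deepcopy(model.generation_config)
-         >>> model._prepare_special_tokens(generation_config, kwargs_has_attention_mask=True)
-         >>> # `generation_config` now carries tensor versions of the special tokens (e.g. `_eos_token_tensor`),
-         >>> # while `model.generation_config` itself stays untouched and serializable.
-         ```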
- """ - - # Convert special tokens to tensors - def _tensor_or_none(token, device=None): - if token is None: - return token - - device = device if device is not None else self.device - if isinstance(token, torch.Tensor): - return token.to(device) - return torch.tensor(token, device=device, dtype=torch.long) - - bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device) - eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device) - pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device) - decoder_start_token_tensor = _tensor_or_none(generation_config.decoder_start_token_id, device=device) - - # for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892) - if self.config.is_encoder_decoder: - decoder_start_token_tensor = ( - decoder_start_token_tensor if decoder_start_token_tensor is not None else bos_token_tensor - ) - - # We can have more than one eos token. Always treat it as a 1D tensor (when it exists). - if eos_token_tensor is not None and eos_token_tensor.ndim == 0: - eos_token_tensor = eos_token_tensor.unsqueeze(0) - - # Set pad token if unset (and there are conditions to do so) - if pad_token_tensor is None and eos_token_tensor is not None: - if not is_torchdynamo_compiling(): - if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." - ) - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.") - pad_token_tensor = eos_token_tensor[0] - - # Sanity checks/warnings - if self.config.is_encoder_decoder and decoder_start_token_tensor is None: - raise ValueError( - "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." - ) - if not is_torchdynamo_compiling(): # Checks that depend on tensor-dependent control flow - if ( - eos_token_tensor is not None - and isin_mps_friendly(elements=eos_token_tensor, test_elements=pad_token_tensor).any() - ): - if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: - logger.warning_once( - "The attention mask is not set and cannot be inferred from input because pad token is same as " - "eos token. As a consequence, you may observe unexpected behavior. Please pass your input's " - "`attention_mask` to obtain reliable results." - ) - if eos_token_tensor is not None and ( - torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any() - ): - logger.warning( - f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation " - "will not stop until the maximum length is reached. Depending on other flags, it may even crash." - ) - - # Update generation config with the updated special tokens tensors - # NOTE: this must be written into a different attribute name than the one holding the original special tokens - # (in their non-tensor form), in order to enable end-to-end compilation. 
See - # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations - generation_config._bos_token_tensor = bos_token_tensor - generation_config._eos_token_tensor = eos_token_tensor - generation_config._pad_token_tensor = pad_token_tensor - generation_config._decoder_start_token_tensor = decoder_start_token_tensor - - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - streamer: Optional["BaseStreamer"] = None, - negative_prompt_ids: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - r""" - - Generates sequences of token ids for models with a language modeling head. - - - - Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the - model's default generation configuration. You can override any `generation_config` by passing the corresponding - parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. - - For an overview of generation strategies and code examples, check out the [following - guide](../generation_strategies). - - - - Parameters: - inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): - The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the - method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` - should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of - `input_ids`, `input_values`, `input_features`, or `pixel_values`. - generation_config ([`~generation.GenerationConfig`], *optional*): - The generation configuration to be used as base parametrization for the generation call. `**kwargs` - passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, the default will be used, which has the following loading - priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model - configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s - default values, whose documentation should be checked to parameterize generation. - logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and - generation config. If a logit processor is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - stopping_criteria (`StoppingCriteriaList`, *optional*): - Custom stopping criteria that complements the default stopping criteria built from arguments and a - generation config. If a stopping criteria is passed that is already created with the arguments or a - generation config an error is thrown. If your stopping criteria depends on the `scores` input, make - sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is - intended for advanced users. 
- prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): - If provided, this function constrains the beam search to allowed tokens only at each step. If not - provided, no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and - `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned - on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful - for constrained generation conditioned on the prefix, as described in [Autoregressive Entity - Retrieval](https://arxiv.org/abs/2010.00904). - synced_gpus (`bool`, *optional*): - Whether to continue running the while loop until max_length. Unless overridden, this flag will be set - to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid - deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`. - assistant_model (`PreTrainedModel`, *optional*): - An assistant model that can be used to accelerate generation. The assistant model must have the exact - same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model - is much faster than running generation with the model you're calling generate from. As such, the - assistant model should be much smaller. - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - The negative prompt needed for some processors such as CFG. The batch size must match the input batch - size. This is an experimental feature, subject to breaking API changes in future versions. - negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Attention mask for `negative_prompt_ids`. - kwargs (`Dict[str, Any]`, *optional*): - Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be - forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder - specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. - - Return: - [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` - or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. - - If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation.GenerateDecoderOnlyOutput`], - - [`~generation.GenerateBeamDecoderOnlyOutput`] - - If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation.GenerateEncoderDecoderOutput`], - - [`~generation.GenerateBeamEncoderDecoderOutput`] - """ - - # 1.
Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call - self._validate_model_class() - tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria - assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation - - generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) - self._validate_model_kwargs(model_kwargs.copy()) - self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer) - - # 2. Set generation parameters if not already defined - if synced_gpus is None: - synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1 - - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) - requires_attention_mask = "encoder_outputs" not in model_kwargs - kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None - - # 3. Define model inputs - inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( - inputs, generation_config.bos_token_id, model_kwargs - ) - batch_size = inputs_tensor.shape[0] - - device = inputs_tensor.device - self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) - - # decoder-only models must use left-padding for batched generation. - if not self.config.is_encoder_decoder and not is_torchdynamo_compiling(): - # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` - # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. - if ( - generation_config._pad_token_tensor is not None - and batch_size > 1 - and len(inputs_tensor.shape) == 2 - and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0 - ): - logger.warning( - "A decoder-only architecture is being used, but right-padding was detected! For correct " - "generation results, please set `padding_side='left'` when initializing the tokenizer." - ) - - # 4. Define other model kwargs - # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are - # generating the first new token or not, and we only want to use the embeddings for the first new token) - if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": - generation_config.use_cache = True - - if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor - ) - elif kwargs_has_attention_mask: - # TODO (joao): generalize this check with other types of inputs - if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2: - raise ValueError("`attention_mask` passed to `generate` must be 2D.") - - if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: - # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` - model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, model_kwargs, model_input_name, generation_config - ) - - # 5. 
Prepare `input_ids` which will be used for auto-regressive generation - if self.config.is_encoder_decoder: - input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( - batch_size=batch_size, - model_input_name=model_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=generation_config._decoder_start_token_tensor, - device=inputs_tensor.device, - ) - else: - input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") - - if generation_config.token_healing: - input_ids = self.heal_tokens(input_ids, tokenizer) - - if streamer is not None: - streamer.put(input_ids.cpu()) - - # 6. Prepare `max_length` depending on other stopping criteria. - input_ids_length = input_ids.shape[-1] - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None - generation_config = self._prepare_generated_length( - generation_config=generation_config, - has_default_max_length=has_default_max_length, - has_default_min_length=has_default_min_length, - model_input_name=model_input_name, - inputs_tensor=inputs_tensor, - input_ids_length=input_ids_length, - ) - - # If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole - # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding - # dynamically overrides this value as it can need more than the last token logits - if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs: - model_kwargs["num_logits_to_keep"] = 1 - - self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) - - # 7. Prepare the cache. - # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. - # - different models have a different cache name expected by the model (default = "past_key_values") - # - `max_length`, prepared above, is used to determine the maximum cache length - # TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format) - cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" - user_defined_cache = model_kwargs.get(cache_name) - max_cache_length = generation_config.max_length - if ( - inputs_tensor.shape[1] != input_ids_length - and model_input_name == "inputs_embeds" - and not self.config.is_encoder_decoder - ): - max_cache_length += inputs_tensor.shape[1] - self._prepare_cache_for_generation( - generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device - ) - - # 8. determine generation mode - generation_mode = generation_config.get_generation_mode(assistant_model) - - if streamer is not None and (generation_config.num_beams > 1): - raise ValueError( - "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." - ) - - if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type: - warnings.warn( - "You are calling .generate() with the `input_ids` being on a device type different" - f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" - f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." 
- " Please make sure that you have put `input_ids` to the" - f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" - " running `.generate()`.", - UserWarning, - ) - - # 9. prepare logits processors and stopping criteria - prepared_logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_length, - encoder_input_ids=inputs_tensor, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - device=inputs_tensor.device, - model_kwargs=model_kwargs, - negative_prompt_ids=negative_prompt_ids, - negative_prompt_attention_mask=negative_prompt_attention_mask, - ) - prepared_stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs - ) - - # Set model_kwargs `use_cache` so we can use it later in forward runs - model_kwargs["use_cache"] = generation_config.use_cache - - # 10. go into different generation modes - if generation_mode == GenerationMode.ASSISTED_GENERATION: - if generation_config.num_return_sequences > 1: - raise ValueError( - "num_return_sequences has to be 1 when doing assisted generate, " - f"but is {generation_config.num_return_sequences}." - ) - if batch_size > 1: - raise ValueError("assisted generate is only supported for batch_size = 1") - if not model_kwargs["use_cache"]: - raise ValueError("assisted generate requires `use_cache=True`") - if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]: - raise ValueError("assisted generate is not supported with Static cache classes`") - if self._is_stateful: - # In assisted generation we need the ability to confirm whether the model would pick certain tokens, - # which is not possible with stateful models (they can't reset to a previous subset of generated text) - raise ValueError( - f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" - ) - - # 11. Get the candidate generator, given the parameterization - candidate_generator = self._get_candidate_generator( - generation_config=generation_config, - input_ids=input_ids, - inputs_tensor=inputs_tensor, - assistant_model=assistant_model, - logits_processor=logits_processor, - target_tokenizer=tokenizer, - assistant_tokenizer=assistant_tokenizer, - model_kwargs=model_kwargs, - ) - - # 12. 
run assisted generate - result = self._assisted_decoding( - input_ids, - candidate_generator=candidate_generator, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, - generation_config=generation_config, - synced_gpus=synced_gpus, - streamer=streamer, - **model_kwargs, - ) - elif generation_mode == GenerationMode.DOLA_GENERATION: - if self._is_stateful: - # DoLa decoding was not designed for stateful models, and would require some changes - raise ValueError( - f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}" - ) - result = self._dola_decoding( - input_ids, - dola_layers=generation_config.dola_layers, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, - generation_config=generation_config, - synced_gpus=synced_gpus, - streamer=streamer, - **model_kwargs, - ) - - elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH: - if not model_kwargs["use_cache"]: - raise ValueError("Contrastive search requires `use_cache=True`") - if self._is_stateful: - # Just like assisted generation, we need to be able to rollback to a previous state (see comment above) - raise ValueError( - f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" - ) - - result = self._contrastive_search( - input_ids, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, - generation_config=generation_config, - synced_gpus=synced_gpus, - streamer=streamer, - **model_kwargs, - ) - - elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 11. expand input_ids with `num_return_sequences` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_return_sequences, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - - # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) - result = self._sample( - input_ids, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, - generation_config=generation_config, - synced_gpus=synced_gpus, - streamer=streamer, - **model_kwargs, - ) - - elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): - # 11. prepare beam search scorer - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=generation_config.num_beams, - device=inputs_tensor.device, - length_penalty=generation_config.length_penalty, - do_early_stopping=generation_config.early_stopping, - num_beam_hyps_to_keep=generation_config.num_return_sequences, - max_length=generation_config.max_length, - ) - - # 12. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_beams, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - - # 13. run beam sample - result = self._beam_search( - input_ids, - beam_scorer, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, - generation_config=generation_config, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH: - # 11. 
prepare beam search scorer - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=generation_config.num_beams, - device=inputs_tensor.device, - length_penalty=generation_config.length_penalty, - do_early_stopping=generation_config.early_stopping, - num_beam_hyps_to_keep=generation_config.num_return_sequences, - num_beam_groups=generation_config.num_beam_groups, - max_length=generation_config.max_length, - ) - # 12. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_beams, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - # 13. run beam search - result = self._group_beam_search( - input_ids, - beam_scorer, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, - generation_config=generation_config, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH: - final_constraints = [] - if generation_config.constraints is not None: - final_constraints = generation_config.constraints - - if generation_config.force_words_ids is not None: - - def typeerror(): - raise ValueError( - "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]` " - f"of positive integers, but is {generation_config.force_words_ids}." - ) - - if ( - not isinstance(generation_config.force_words_ids, list) - or len(generation_config.force_words_ids) == 0 - ): - typeerror() - - for word_ids in generation_config.force_words_ids: - if isinstance(word_ids[0], list): - if not isinstance(word_ids, list) or len(word_ids) == 0: - typeerror() - if any(not isinstance(token_ids, list) for token_ids in word_ids): - typeerror() - if any( - any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) - for token_ids in word_ids - ): - typeerror() - - constraint = DisjunctiveConstraint(word_ids) - else: - if not isinstance(word_ids, list) or len(word_ids) == 0: - typeerror() - if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): - typeerror() - - constraint = PhrasalConstraint(word_ids) - final_constraints.append(constraint) - - # 11. prepare beam search scorer - constrained_beam_scorer = ConstrainedBeamSearchScorer( - constraints=final_constraints, - batch_size=batch_size, - num_beams=generation_config.num_beams, - device=inputs_tensor.device, - length_penalty=generation_config.length_penalty, - do_early_stopping=generation_config.early_stopping, - num_beam_hyps_to_keep=generation_config.num_return_sequences, - max_length=generation_config.max_length, - ) - # 12. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_beams, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - # 13. 
run beam search - result = self._constrained_beam_search( - input_ids, - constrained_beam_scorer=constrained_beam_scorer, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, - generation_config=generation_config, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - # Convert to legacy cache format if requested - if ( - generation_config.return_legacy_cache is not False # Should check for `True` after v4.47 - and not is_torchdynamo_compiling() - and hasattr(result, "past_key_values") - and hasattr(result.past_key_values, "to_legacy_cache") - and result.past_key_values.to_legacy_cache is not None - ): - # handle BC (convert by default if he user hasn't passed a cache AND the cache is of the default type) - should_convert_cache = generation_config.return_legacy_cache - is_user_defined_cache = user_defined_cache is not None - is_default_cache_type = ( - type(result.past_key_values) == DynamicCache # noqa E721 - or ( - isinstance(result.past_key_values, EncoderDecoderCache) - and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721 - and type(result.past_key_values.cross_attention_cache) == DynamicCache # noqa E721 - ) - ) - if not is_user_defined_cache and is_default_cache_type: - logger.warning_once( - "From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` " - "instance instead by default (as opposed to the legacy tuple of tuples format). If you want to " - "keep returning the legacy format, please set `return_legacy_cache=True`." - ) - should_convert_cache = True - if should_convert_cache: - result.past_key_values = result.past_key_values.to_legacy_cache() - return result - - def _has_unfinished_sequences( - self, - this_peer_finished: bool, - synced_gpus: bool, - device: torch.device, - cur_len: Optional[int] = None, - max_length: Optional[int] = None, - ) -> bool: - """ - Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is - fed through `this_peer_finished`. ZeRO stage 3-friendly. - """ - # torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile, - # although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria) - # TODO (joao): remove this when torch's support for control flow is not experimental (https://pytorch.org/docs/stable/generated/torch.cond.html) - if is_torchdynamo_compiling(): - return cur_len < max_length - else: - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - return False - elif this_peer_finished: - return False - return True - - def heal_tokens( - self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None - ) -> torch.LongTensor: - r""" - Generates sequences of token ids for models with a language modeling head. - Parameters: - input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation. - tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids. 
- Return: - `torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension. - """ - if tokenizer is None: - raise ValueError( - " When generating with token healing, you must pass the model's tokenizer to the `tokenizer` " - "argument of `generate`." - ) - - bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id - vocab_trie = ExtensionsTrie(tokenizer.get_vocab()) - generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id) - - # assumption: leading/trailing whitespace is not meaningful, so the prompts are - # stripped before re-tokenizing to desensitize generation to whitespace artefacts - prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)] - input_ids = tokenizer( - prompts, - return_tensors="pt", - padding=True, - ).input_ids.to(input_ids.device) - - # replace bos with pad to not condition healing on it - input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids) - - # the code below assumes `input_ids` is not empty, so check that it actually contains elements first - if input_ids.numel() == 0: - return input_ids - - tail_ids = input_ids[:, -1].tolist() - - space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0] - # tail tokens are used for a prefix search, thus, whitespaces are replaced with - # their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace - tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids) - - for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)): - batch_ids = input_ids[batch_idx] - if torch.all(batch_ids == pad_token_id).item(): - continue # skip empty sequences (all pad ids) - - # apply bias for alternatives (extensions) to the tail token - # `seq_bias` keys must be tuples of ints, so the tokenizer is used to convert the candidate token strings to ids - seq_bias = { - (tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok) - } - - if len(seq_bias) == 1: - continue # skip if there are no token alternatives to heal with - - # slightly favor original token to limit aggressive healing e.g. 'http' -> 'https' - seq_bias[(tail_id,)] += 1.0 - generation_config.update(sequence_bias=seq_bias) - - trimmed_ids = batch_ids[:-1] - - # the code below assumes `trimmed_ids` is not empty, so check its element count first - if trimmed_ids.numel() == 0: - continue - - # if the prompt is a single (non-pad) token, regenerate from bos - if len(batch_ids[batch_ids != pad_token_id]) == 1: - trimmed_ids[-1] = bos_token_id - - input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config) - - return input_ids - - def _dola_decoding( - self, - input_ids: torch.LongTensor, - dola_layers: Union[str, List[int]], - logits_processor: LogitsProcessorList, - stopping_criteria: StoppingCriteriaList, - generation_config: GenerationConfig, - synced_gpus: bool, - streamer: "BaseStreamer", - **model_kwargs, - ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **DoLa decoding**, which can be - used with decoder-only text models. - The method is based on the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language - Models" (https://arxiv.org/abs/2309.03883) in ICLR 2024.
- - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - dola_layers (`Union[str, List[int]]`): - The candidate layers used in contrasting layers of DoLa. It can be either 1) 'low' or 'high', which - means the lower part or higher part of the model layers, respectively, or 2) a list of layer indices - to be used for candidate layers. The 0-th layer is the word embedding layer of the model. - logits_processor (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - generation_config ([`~generation.GenerationConfig`]): - The generation configuration to be used as parametrization of the decoding method. - synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed to avoid deadlocking with - `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - model_kwargs: - Additional model specific keyword arguments will be forwarded to the `forward` function of the model. - If model is an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] - or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. 
- """ - - if self.config.is_encoder_decoder: - raise ValueError("DoLa decoding is only available for decoder-only models.") - # init values - - pad_token_id = generation_config._pad_token_tensor - output_attentions = generation_config.output_attentions - output_hidden_states = generation_config.output_hidden_states - output_scores = generation_config.output_scores - output_logits = generation_config.output_logits - return_dict_in_generate = generation_config.return_dict_in_generate - has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) - do_sample = generation_config.do_sample - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # keep track of which sequences are already finished - batch_size = input_ids.shape[0] - unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - - this_peer_finished = False - - # prepare layers for DoLa decoding - final_layer = self.config.get_text_config().num_hidden_layers - # if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer, - # as the early exit from word embeddings will become identity function - # if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th - # layer otherwise. Notice that DoLa does not help shallow models much. - if not self.config.tie_word_embeddings: - start_layer = 0 - elif final_layer > 2: - start_layer = 2 - elif final_layer == 2: - start_layer = 1 - else: - start_layer = 0 - - # For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)` - # are used for `'low'` and `'high'` layers, respectively. - # For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for - # `'low'` and `'high'` layers, respectively. - if isinstance(dola_layers, str) and dola_layers == "low": - if start_layer == final_layer // 2: - candidate_premature_layers = [start_layer] - else: - candidate_premature_layers = ( - list(range(start_layer, final_layer // 2, 2)) - if final_layer <= 40 - else list(range(start_layer, 20, 2)) - ) - elif isinstance(dola_layers, str) and dola_layers == "high": - candidate_premature_layers = ( - list(range(final_layer // 2, final_layer, 2)) - if final_layer <= 40 - else list(range(final_layer - 20, final_layer, 2)) - ) - # Set the `dola_layers` to a list of integers for layer indices to contrast manually specified layers. 
- elif isinstance(dola_layers, list): - candidate_premature_layers = [i for i in dola_layers if i < final_layer] - else: - raise ValueError("dola_layers must be either 'low', 'high' or a list of integers.") - - lm_head = self.get_output_embeddings() - if lm_head is None: - raise ValueError("DoLa is not supported for models that don't have output embeddings.") - - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=True, - ) - - # .float() is needed to retain precision for later logits manipulations - final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone().float() - final_logits = outputs.logits[:, -1, :].float() - candidate_premature_logits = {} - for candidate_premature_layer in candidate_premature_layers: - candidate_premature_logits[candidate_premature_layer] = lm_head( - outputs.hidden_states[candidate_premature_layer][:, -1, :] - ).to(final_logits.device) - - # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - if synced_gpus and this_peer_finished: - continue - - next_token_logits = _dola_select_contrast( - candidate_premature_layers, candidate_premature_logits, final_logits - ) - next_token_logits = next_token_logits.to(input_ids.device) - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_token_scores,) - if output_logits: - raw_logits += (final_layer_next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - if do_sample: # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: # argmax - next_tokens = torch.argmax(next_token_scores, dim=-1) - - # finished sentences should have their next token be a padding token - if has_eos_stopping_criteria: - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - if streamer is not None: - streamer.put(next_tokens.cpu()) - - # stop when each sentence is finished - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - this_peer_finished = unfinished_sequences.max() == 0 - - if streamer is not None: - streamer.end() - - if return_dict_in_generate: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids - - @torch.no_grad() 
- def _contrastive_search( - self, - input_ids: torch.LongTensor, - logits_processor: LogitsProcessorList, - stopping_criteria: StoppingCriteriaList, - generation_config: GenerationConfig, - synced_gpus: bool, - streamer: Optional["BaseStreamer"], - **model_kwargs, - ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **contrastive search** and can - be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - logits_processor (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - generation_config ([`~generation.GenerationConfig`]): - The generation configuration to be used as parametrization of the decoding method. - synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed to avoid deadlocking with - `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - model_kwargs: - Additional model specific keyword arguments will be forwarded to the `forward` function of the model. - If model is an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] - or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. 
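The re-ranking step can be summarised as follows (an illustrative sketch mirroring `_ranking_fast`, ignoring the attention-mask handling; `context_hidden` is assumed to be `(batch*top_k, ctx_len, dim)`, `next_hidden` `(batch*top_k, 1, dim)`, `top_k_probs` `(batch, top_k)`, and `alpha` the `penalty_alpha` value):

    cosine_sim = torch.nn.functional.cosine_similarity(context_hidden, next_hidden, dim=-1)
    degeneration_penalty = cosine_sim.max(dim=-1).values.view(-1, top_k)
    contrastive_score = (1 - alpha) * top_k_probs - alpha * degeneration_penalty
    selected_idx = contrastive_score.argmax(dim=-1)  # best candidate per batch item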
- """ - # init values - has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) - top_k = generation_config.top_k - penalty_alpha = generation_config.penalty_alpha - pad_token_id = generation_config._pad_token_tensor - output_attentions = generation_config.output_attentions - output_hidden_states = generation_config.output_hidden_states - output_scores = generation_config.output_scores - output_logits = generation_config.output_logits - return_dict_in_generate = generation_config.return_dict_in_generate - sequential = generation_config.low_memory - - # init attention / hidden states / scores tuples - raw_logits = () if (return_dict_in_generate and output_logits) else None - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - # keep track of which sequences are already finished - batch_size = input_ids.shape[0] - unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - - # Create cosine_matrix_mask based on the attention_mask - cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long) - if self.config.is_encoder_decoder: - if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None: - cosine_matrix_mask = model_kwargs["decoder_attention_mask"] - else: - cosine_matrix_mask = model_kwargs["attention_mask"] - cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0) - - this_peer_finished = False - - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; - # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step - if model_kwargs.get("past_key_values") is None or ( - isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache)) - and model_kwargs["past_key_values"].get_seq_length() == 0 - ): - # prepare inputs - model_kwargs["use_cache"] = True - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save - # the `encoder_outputs` - outputs = self( - **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions - ) - - # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with - # previous tokens) - if self.config.is_encoder_decoder: - last_hidden_states = outputs.decoder_hidden_states[-1] - else: - last_hidden_states = outputs.hidden_states[-1] - - # next logit for contrastive search to select top-k candidate tokens - # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration - # (the clone itself 
is always small) - # .float() is needed to retain precision for later logits manipulations - logit_for_next_step = outputs.logits[:, -1, :].clone().float() - logit_for_next_step = logit_for_next_step.to(input_ids.device) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - - if not sequential: - # Expands model inputs top_k times, for batched forward passes (akin to beam search). - _, model_kwargs = self._expand_inputs_for_generation( - expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs - ) - - past_key_values = model_kwargs.get("past_key_values") - if past_key_values is None: - raise ValueError( - f"{self.__class__.__name__} does not support caching and therefore **can't** be used " - "for contrastive search." - ) - elif ( - not isinstance(past_key_values[0], (tuple, torch.Tensor)) - or past_key_values[0][0].shape[0] != batch_size - ): - raise ValueError( - f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be " - "used for contrastive search without further modifications." - ) - - # contrastive_search main logic start: - # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by - # degeneration penalty - processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step) - next_probs = nn.functional.softmax(processed_logit_for_next_step, dim=-1) - - top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_logits: - raw_logits += (logit_for_next_step,) - if output_scores: - scores += (processed_logit_for_next_step,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # This is needed to properly delete outputs.logits which may be very large for this first iteration - # Otherwise a reference to outputs.logits is kept all along until after the next call to self.forward() - del outputs - - if not sequential: - # Replicates the new past_key_values to match the `top_k` candidates - past = model_kwargs["past_key_values"] - # If it is a static cache, modify it in-place layer after layer to save memory - if isinstance(past, DynamicCache) or ( - isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache) - ): - past.batch_repeat_interleave(top_k) - else: - new_key_values = [] - for layer in past: - items = [] - # item is either the key or the value matrix - for item in layer: - items.append(item.repeat_interleave(top_k, dim=0)) - new_key_values.append(tuple(items)) - - past = tuple(new_key_values) - - model_kwargs["past_key_values"] = past - - if sequential: - all_outputs = [] - for i in range(top_k): - # compute the candidate tokens by the language model and collect their hidden_states - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs) - - outputs = self( - **next_model_inputs, - return_dict=True, - output_hidden_states=True, - output_attentions=output_attentions, - ) - if isinstance(outputs["past_key_values"], DynamicCache) or ( - 
isinstance(outputs["past_key_values"], EncoderDecoderCache) - and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache) - ): - # Remove past K-V from output since we don't need to stack later - outputs["past_key_values"] = None - # Remove last token from past K-V since we don't want to append it at this point - model_kwargs["past_key_values"].crop(-1) - - all_outputs.append(outputs) - outputs = stack_model_outputs(all_outputs, self.config.get_text_config()) - - else: - # compute the candidate tokens by the language model and collect their hidden_states - # assembles top_k_ids into batch of size k - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) - - outputs = self( - **next_model_inputs, - return_dict=True, - output_hidden_states=True, - output_attentions=output_attentions, - ) - - # This is essential to avoid having a last reference to the big past K-V and double the necessary memory - # in the next loop - del next_model_inputs - - # name is different for encoder-decoder and decoder-only models - if self.config.is_encoder_decoder: - next_hidden = outputs.decoder_hidden_states[-1] - full_hidden_states = outputs.decoder_hidden_states - else: - next_hidden = outputs.hidden_states[-1] - full_hidden_states = outputs.hidden_states - - # .float() is needed to retain precision for later logits manipulations - logits = outputs.logits[:, -1, :].float() - context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) - - # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the - # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't - # introduce (noticeable) slowdowns on single-device runs. - selected_idx = _ranking_fast( - context_hidden, next_hidden, top_k_probs, cosine_matrix_mask, penalty_alpha, top_k - ) - cosine_matrix_mask = torch.cat( - [cosine_matrix_mask, cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1))], dim=-1 - ) - selected_idx = selected_idx.to("cpu") - - # This will be used instead of the previous inneficient torch.stack(torch.split()) - augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)]) - - # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing - # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores - # (model confidence minus degeneration penalty); (6) decoder hidden_states - next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx] - next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k)) - next_hidden = next_hidden[range(batch_size), selected_idx, :] - last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) - - next_decoder_hidden_states = () - for layer in full_hidden_states: - layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :] - next_decoder_hidden_states += (layer,) - - # generate past_key_values cache of only the selected token - if sequential: - next_model_input = self.prepare_inputs_for_generation( - top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs - ) - - selected_outputs = self( - **next_model_input, - return_dict=True, - output_hidden_states=False, - output_attentions=False, - ) - next_past_key_values = selected_outputs["past_key_values"] - - else: - _, next_past_key_values = self._extract_past_from_model_output(outputs) - # Do it in-place layer per layer to save memory - if 
isinstance(next_past_key_values, DynamicCache) or ( - isinstance(next_past_key_values, EncoderDecoderCache) - and isinstance(next_past_key_values.self_attention_cache, DynamicCache) - ): - next_past_key_values.batch_select_indices(augmented_idx) - else: - new_key_values = [] - for layer in next_past_key_values: - items = [] - # item is either the key or the value matrix - for item in layer: - items.append(item[augmented_idx, ...]) - new_key_values.append(tuple(items)) - - next_past_key_values = tuple(new_key_values) - - logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] - logit_for_next_step = logit_for_next_step.to(input_ids.device) - - # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration - if self.config.is_encoder_decoder: - next_step_cross_attentions = () - next_step_decoder_attentions = () - if output_attentions: - for layer in outputs.cross_attentions: - layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] - next_step_cross_attentions += (layer,) - for layer in outputs.decoder_attentions: - layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] - next_step_decoder_attentions += (layer,) - outputs = Seq2SeqLMOutput( - past_key_values=next_past_key_values, - decoder_hidden_states=next_decoder_hidden_states, - decoder_attentions=next_step_decoder_attentions or None, - cross_attentions=next_step_cross_attentions or None, - ) - else: - next_step_attentions = () - if output_attentions: - for layer in outputs.attentions: - layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] - next_step_attentions += (layer,) - outputs = CausalLMOutputWithPast( - past_key_values=next_past_key_values, - hidden_states=next_decoder_hidden_states, - attentions=next_step_attentions or None, - ) - # contrastive_search main logic end - - # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - if synced_gpus and this_peer_finished: - continue - - # finished sentences should have their next token be a padding token - if has_eos_stopping_criteria: - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - if streamer is not None: - streamer.put(next_tokens.cpu()) - - # stop when each sentence is finished - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - this_peer_finished = unfinished_sequences.max() == 0 - - if streamer is not None: - streamer.end() - - if return_dict_in_generate: - # Contrastive search works by forward looking at the next token, so we need to exclude it from - # `past_key_values` to be consistent with the other decoding methods - if model_kwargs.get("past_key_values") is not None: - if isinstance(model_kwargs["past_key_values"], DynamicCache) or ( - isinstance(model_kwargs["past_key_values"], EncoderDecoderCache) - and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache) - ): - model_kwargs["past_key_values"].crop(-1) - else: - past_key_values = [] - for layer in model_kwargs["past_key_values"]: - layer_past_key_values = [] - for item in layer: - 
layer_past_key_values.append(item[..., :-1, :]) - past_key_values.append(tuple(layer_past_key_values)) - model_kwargs["past_key_values"] = tuple(past_key_values) - - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids - - def _sample( - self, - input_ids: torch.LongTensor, - logits_processor: LogitsProcessorList, - stopping_criteria: StoppingCriteriaList, - generation_config: GenerationConfig, - synced_gpus: bool, - streamer: Optional["BaseStreamer"], - **model_kwargs, - ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and - can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - logits_processor (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - generation_config ([`~generation.GenerationConfig`]): - The generation configuration to be used as parametrization of the decoding method. - synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed to avoid deadlocking with - `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: - A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. 
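In most cases this method is not called directly; `generate` dispatches to it when sampling or greedy decoding without beam search is requested, e.g. (illustrative):

    output_ids = model.generate(input_ids, do_sample=True, temperature=0.8, top_p=0.9, max_new_tokens=32)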
- """ - # init values - pad_token_id = generation_config._pad_token_tensor - output_attentions = generation_config.output_attentions - output_hidden_states = generation_config.output_hidden_states - output_scores = generation_config.output_scores - output_logits = generation_config.output_logits - return_dict_in_generate = generation_config.return_dict_in_generate - max_length = generation_config.max_length - has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) - do_sample = generation_config.do_sample - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - # keep track of which sequences are already finished - batch_size, cur_len = input_ids.shape - this_peer_finished = False - unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - - while self._has_unfinished_sequences( - this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length - ): - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - # prepare variable output controls (note: some models won't accept all output controls) - model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) - model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) - - # forward pass to get next token - outputs = self(**model_inputs, return_dict=True) - - # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - if synced_gpus and this_peer_finished: - continue - - # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration - # (the clone itself is always small) - next_token_logits = outputs.logits.clone()[:, -1, :].float() - next_token_logits = next_token_logits.to(input_ids.device) - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_token_scores,) - if output_logits: - raw_logits += (next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if 
self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # token selection - if do_sample: - probs = nn.functional.softmax(next_token_scores, dim=-1) - # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(next_token_scores, dim=-1) - - # finished sentences should have their next token be a padding token - if has_eos_stopping_criteria: - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - if streamer is not None: - streamer.put(next_tokens.cpu()) - - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - this_peer_finished = unfinished_sequences.max() == 0 - cur_len += 1 - - # This is needed to properly delete outputs.logits which may be very large for first iteration - # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration - del outputs - - if streamer is not None: - streamer.end() - - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids - - def _temporary_reorder_cache(self, past_key_values, beam_idx): - """ - Temporary function to handle the different types of cache reordering processes while we roll out `Cache`. - - TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need - for this function, with `Cache.reorder_cache` being the sole remaining code path - """ - model_class = self.__class__.__name__.lower() - # Exception 1: code path for models using the legacy cache format - if isinstance(past_key_values, (tuple, list)): - past_key_values = self._reorder_cache(past_key_values, beam_idx) - # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their - # cache format is standardized, to avoid adding complexity to the codebase. - elif "gptbigcode" in model_class: - if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)): - raise ValueError( - f"Using an unsupported cache format with {model_class}. 
Currently, it only supports the " - "legacy tuple format or `DynamicCache`" - ) - past_key_values = self._reorder_cache(past_key_values, beam_idx) - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - # Standard code path: use the `Cache.reorder_cache` - else: - past_key_values.reorder_cache(beam_idx) - return past_key_values - - def _beam_search( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: LogitsProcessorList, - stopping_criteria: StoppingCriteriaList, - generation_config: GenerationConfig, - synced_gpus: bool, - **model_kwargs, - ) -> Union[GenerateBeamOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **beam search decoding** and - can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - beam_scorer (`BeamScorer`): - An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`: - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - generation_config ([`~generation.GenerationConfig`]): - The generation configuration to be used as parametrization of the decoding method. - synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed to avoid deadlocking with - `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - """ - # init values - pad_token_id = generation_config._pad_token_tensor - eos_token_id = generation_config._eos_token_tensor - output_attentions = generation_config.output_attentions - output_hidden_states = generation_config.output_hidden_states - output_scores = generation_config.output_scores - output_logits = generation_config.output_logits - return_dict_in_generate = generation_config.return_dict_in_generate - sequential = generation_config.low_memory - do_sample = generation_config.do_sample - - batch_size = len(beam_scorer._beam_hyps) - num_beams = beam_scorer.num_beams - - batch_beam_size, cur_len = input_ids.shape - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." 
- ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - raw_logits = () if (return_dict_in_generate and output_logits) else None - beam_indices = ( - tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None - ) - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens - # of the first beam are considered to avoid sampling the exact same tokens across all beams. - beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False - - decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder - - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - # prepare variable output controls (note: some models won't accept all output controls) - model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) - model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) - - # if sequential is True, split the input to batches of batch_size and run sequentially - if sequential: - if any( - model_name in self.__class__.__name__.lower() - for model_name in [ - "fsmt", - "reformer", - "ctrl", - "gpt_bigcode", - "transfo_xl", - "xlnet", - "cpm", - "jamba", - ] - ): - raise RuntimeError( - f"Currently generation for {self.__class__.__name__} is not supported " - f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
- ) - - inputs_per_sub_batches = _split_model_inputs( - model_inputs, - split_size=batch_size, - full_batch_size=batch_beam_size, - config=self.config.get_text_config(), - ) - outputs_per_sub_batch = [ - self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches - ] - - outputs = stack_model_outputs(outputs_per_sub_batch, self.config.get_text_config()) - - else: # Unchanged original behavior - outputs = self(**model_inputs, return_dict=True) - - # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue - - # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration - # (the clone itself is always small) - # .float() is needed to retain precision for later logits manipulations - next_token_logits = outputs.logits[:, -1, :].clone().float() - next_token_logits = next_token_logits.to(input_ids.device) - next_token_scores = nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * num_beams, vocab_size) - - next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( - next_token_scores_processed - ) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_token_scores_processed,) - if output_logits: - raw_logits += (next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) - - # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1 - # non eos token per beam. 
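# Worked example (illustrative): with num_beams == 4 and a single eos token, n_tokens_to_keep below is
# max(2, 1 + 1) * 4 == 8 candidates per batch item (2 per beam), so even if eos is selected for every
# beam there is still at least one non-eos candidate per beam to keep the search going.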
- n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 - n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams - if do_sample: - probs = nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep) - next_token_scores = torch.gather(next_token_scores, -1, next_tokens) - next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) - next_tokens = torch.gather(next_tokens, -1, _indices) - else: - next_token_scores, next_tokens = torch.topk( - next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True - ) - - next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - beam_indices=beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) - - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - - # This is needed to properly delete outputs.logits which may be very large for first iteration - # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration - # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory - # (that way the memory peak does not include outputs.logits) - del outputs - - if model_kwargs.get("past_key_values", None) is not None: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - model_kwargs["past_key_values"], beam_idx - ) - - if return_dict_in_generate and output_scores: - beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) - - # increase cur_len - cur_len = cur_len + 1 - - if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): - this_peer_finished = True - - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - beam_indices=beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - - if self.config.is_encoder_decoder: - return GenerateBeamEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateBeamDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) -
else: - return sequence_outputs["sequences"] - - def _group_beam_search( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: LogitsProcessorList, - stopping_criteria: StoppingCriteriaList, - generation_config: GenerationConfig, - synced_gpus: bool, - **model_kwargs, - ): - r""" - Generates sequences of token ids for models with a language modeling head using **diverse beam search - decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - beam_scorer (`BeamScorer`): - An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - generation_config ([`~generation.GenerationConfig`]): - The generation configuration to be used as parametrization of the decoding method. - synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed to avoid deadlocking with - `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). - model_kwargs: - Additional model specific kwargs that will be forwarded to the `forward` function of the model. If - model is an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - """ - # init values - pad_token_id = generation_config._pad_token_tensor - eos_token_id = generation_config._eos_token_tensor - output_attentions = generation_config.output_attentions - output_hidden_states = generation_config.output_hidden_states - output_scores = generation_config.output_scores - output_logits = generation_config.output_logits - return_dict_in_generate = generation_config.return_dict_in_generate - - num_beams = beam_scorer.num_beams - num_beam_groups = beam_scorer.num_beam_groups - num_sub_beams = num_beams // num_beam_groups - batch_size = len(beam_scorer._beam_hyps) // num_beam_groups - device = input_ids.device - - batch_beam_size, cur_len = input_ids.shape - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - - if return_dict_in_generate and output_scores: - beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] - else: - beam_indices = None - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." 
- ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in - # the same group don't produce same tokens every time. - beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) - beam_scores[:, ::num_sub_beams] = 0 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False - - decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - # predicted tokens in cur_len step - current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) - - # indices which will form the beams in the next time step - reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) - - # do one decoder step on all beams of all sentences in batch - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - # prepare variable output controls (note: some models won't accept all output controls) - model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) - model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) - - outputs = self(**model_inputs, return_dict=True) - - # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue - - if output_scores: - processed_score = torch.zeros_like(outputs.logits[:, -1, :]) - if output_logits: - # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration - # (the clone itself is always small) - raw_logit_score = outputs.logits[:, -1, :].clone() - raw_logit_score = raw_logit_score.to(input_ids.device) - - for beam_group_idx in range(num_beam_groups): - group_start_idx = beam_group_idx * num_sub_beams - group_end_idx = min(group_start_idx + num_sub_beams, num_beams) - group_size = group_end_idx - group_start_idx - - # indices of beams of current group among all sentences in batch - batch_group_indices = [] - - for batch_idx in range(batch_size): - batch_group_indices.extend( - [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] - ) - group_input_ids = input_ids[batch_group_indices] - - # select outputs of beams of current group only - # No need to clone() the logits here as they will not retain outputs.logits at the end of the 
loop - # .float() is needed to retain precision for later logits manipulations - next_token_logits = outputs.logits[batch_group_indices, -1, :].float() - next_token_logits = next_token_logits.to(input_ids.device) - - next_token_scores = nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * group_size, vocab_size) - vocab_size = next_token_scores.shape[-1] - - next_token_scores_processed = logits_processor( - group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx - ) - next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) - next_token_scores = next_token_scores.expand_as(next_token_scores_processed) - - if output_scores: - processed_score[batch_group_indices] = next_token_scores_processed - - # reshape for beam search - next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) - - # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. - n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 - next_token_scores, next_tokens = torch.topk( - next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True - ) - - next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") - next_tokens = next_tokens % vocab_size - - # stateless - process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None - beam_outputs = beam_scorer.process( - group_input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - beam_indices=process_beam_indices, - group_index=beam_group_idx, - decoder_prompt_len=decoder_prompt_len, - ) - beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - if return_dict_in_generate and output_scores: - beam_indices[beam_group_idx] = tuple( - beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) - ) - - input_ids[batch_group_indices] = group_input_ids[beam_idx] - group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - current_tokens[batch_group_indices] = group_input_ids[:, -1] - - # (beam_idx // group_size) -> batch_idx - # (beam_idx % group_size) -> offset of idx inside the group - reordering_indices[batch_group_indices] = ( - num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") - + group_start_idx - + (beam_idx % group_size) - ) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (processed_score,) - if output_logits: - raw_logits += (raw_logit_score,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) - - # This is needed to properly delete outputs.logits which may be very large for first iteration - # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration - # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to 
save the maximum memory - # (that way the memory peak does not include outputs.logits) - del outputs - - if model_kwargs.get("past_key_values", None) is not None: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - model_kwargs["past_key_values"], reordering_indices - ) - - # increase cur_len - cur_len = cur_len + 1 - - if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): - this_peer_finished = True - - final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - beam_indices=final_beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - - if self.config.is_encoder_decoder: - return GenerateBeamEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateBeamDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return sequence_outputs["sequences"] - - def _constrained_beam_search( - self, - input_ids: torch.LongTensor, - constrained_beam_scorer: ConstrainedBeamSearchScorer, - logits_processor: LogitsProcessorList, - stopping_criteria: StoppingCriteriaList, - generation_config: GenerationConfig, - synced_gpus: bool, - **model_kwargs, - ) -> Union[GenerateBeamOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **constrained beam search - decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - constrained_beam_scorer (`ConstrainedBeamSearchScorer`): - A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation, while satisfying a list of positive constraints. For more information, the - documentation of [`ConstrainedBeamSearchScorer`] should be read. - logits_processor (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - generation_config ([`~generation.GenerationConfig`]): - The generation configuration to be used as parametrization of the decoding method. 
- synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed to avoid deadlocking with - `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - """ - # init values - pad_token_id = generation_config._pad_token_tensor - eos_token_id = generation_config._eos_token_tensor - output_attentions = generation_config.output_attentions - output_hidden_states = generation_config.output_hidden_states - output_scores = generation_config.output_scores - output_logits = generation_config.output_logits - return_dict_in_generate = generation_config.return_dict_in_generate - - batch_size = len(constrained_beam_scorer._beam_hyps) - num_beams = constrained_beam_scorer.num_beams - - batch_beam_size, cur_len = input_ids.shape - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - raw_logits = () if (return_dict_in_generate and output_logits) else None - beam_indices = ( - tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None - ) - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens - # of the first beam are considered to avoid sampling the exact same tokens across all beams. 
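The `-1e9` initialisation that follows is worth a quick illustration: every beam starts from the same prompt, so without it all beams would carry identical cumulative scores and the first top-k selection would pick the same continuation `num_beams` times. Masking beams 1..num_beams-1 lets only the first beam's candidates survive the first step. A minimal standalone sketch (invented shapes and values, not part of the original file):

```python
import torch

batch_size, num_beams, vocab_size = 1, 3, 10

# On the first step every beam sees the same prompt, hence identical log-probabilities.
step_log_probs = torch.log_softmax(torch.randn(vocab_size), dim=-1).repeat(batch_size * num_beams, 1)

# First beam starts at 0, the remaining beams at -1e9.
beam_scores = torch.zeros((batch_size, num_beams))
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))

# Cumulative scores, flattened over (num_beams * vocab_size) exactly as in the loop below.
next_token_scores = (step_log_probs + beam_scores.unsqueeze(-1)).view(batch_size, num_beams * vocab_size)
_, top_ids = torch.topk(next_token_scores, 2 * num_beams, dim=1)

# All surviving candidates come from beam 0: their flat index is below vocab_size.
print((top_ids < vocab_size).all())  # tensor(True)
```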
- beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False - - decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - # prepare variable output controls (note: some models won't accept all output controls) - model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) - model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) - - outputs = self(**model_inputs, return_dict=True) - - # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue - - # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration - # (the clone itself is always small) - # .float() is needed to retain precision for later logits manipulations - next_token_logits = outputs.logits[:, -1, :].clone().float() - next_token_logits = next_token_logits.to(input_ids.device) - next_token_scores = nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * num_beams, vocab_size) - - next_token_scores_processed = logits_processor(input_ids, next_token_scores) - - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( - next_token_scores_processed - ) - - scores_for_all_vocab = next_token_scores.clone() - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_token_scores,) - if output_logits: - raw_logits += (next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) - - # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
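Just below, the scores are flattened to shape `(batch_size, num_beams * vocab_size)` before `torch.topk`, so each returned index encodes both a beam and a token: integer division by `vocab_size` recovers the beam, the remainder recovers the token id. Drawing `max(2, 1 + n_eos_tokens) * num_beams` candidates guarantees that even if every beam's best continuation is EOS, enough non-EOS candidates remain for the scorer to keep `num_beams` hypotheses open. A small self-contained sketch with made-up scores (not taken from the original file):

```python
import torch

num_beams, vocab_size = 2, 4
# Flattened scores for one batch element: beam 0 occupies columns 0..3, beam 1 columns 4..7.
flat_scores = torch.tensor([[0.1, 0.7, 0.2, 0.0, 0.9, 0.3, 0.05, 0.6]])

top_scores, top_ids = torch.topk(flat_scores, k=2 * num_beams, dim=1, largest=True, sorted=True)

beam_index = torch.div(top_ids, vocab_size, rounding_mode="floor")  # which beam the candidate extends
token_id = top_ids % vocab_size                                     # which vocabulary token it adds

print(top_ids)     # tensor([[4, 1, 7, 5]])
print(beam_index)  # tensor([[1, 0, 1, 1]])
print(token_id)    # tensor([[0, 1, 3, 1]])
```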
- n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 - next_token_scores, next_tokens = torch.topk( - next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True - ) - - next_indices = (next_tokens / vocab_size).long() - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = constrained_beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - scores_for_all_vocab, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - beam_indices=beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - - # This is needed to properly delete outputs.logits which may be very large for first iteration - # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration - # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory - # (that way the memory peak does not include outputs.logits) - del outputs - - if model_kwargs.get("past_key_values", None) is not None: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - model_kwargs["past_key_values"], beam_idx - ) - - if return_dict_in_generate and output_scores: - beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) - - # increase cur_len - cur_len = cur_len + 1 - - if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): - this_peer_finished = True - - sequence_outputs = constrained_beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - beam_indices=beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - if self.config.is_encoder_decoder: - return GenerateBeamEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateBeamDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return sequence_outputs["sequences"] - - def _assisted_decoding( - self, - input_ids: torch.LongTensor, - candidate_generator: CandidateGenerator, - logits_processor: LogitsProcessorList, - stopping_criteria: StoppingCriteriaList, - generation_config: GenerationConfig, - synced_gpus: bool, - streamer: Optional["BaseStreamer"], - **model_kwargs, - ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **greedy decoding** or - **sample** (depending on 
`do_sample`), assisted by candidate sequences. Assisted generation is an example of a - candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text - models. - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - candidate_generator (`CandidateGenerator`): - A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For - more information, the documentation of [`CandidateGenerator`] should be read. - logits_processor (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - generation_config ([`~generation.GenerationConfig`]): - The generation configuration to be used as parametrization of the decoding method. - synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed to avoid deadlocking with - `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - model_kwargs: - Additional model specific keyword arguments will be forwarded to the `forward` function of the model. - If model is an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. 
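For context, this method is normally not called directly; `generate()` dispatches to it when a draft model is supplied. A hedged usage sketch via the public API (the checkpoint names are placeholders, not something this file specifies; any target/draft pair sharing a tokenizer should work):

```python
# Illustrative only: placeholder target and draft checkpoints.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")        # target model
assistant = AutoModelForCausalLM.from_pretrained("gpt2")          # smaller draft model

inputs = tokenizer("Assisted decoding lets a small model draft tokens", return_tensors="pt")
# Passing `assistant_model` routes generation through the assisted-decoding loop defined here.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```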
- """ - # init values - do_sample = generation_config.do_sample - output_attentions = generation_config.output_attentions - output_hidden_states = generation_config.output_hidden_states - output_scores = generation_config.output_scores - output_logits = generation_config.output_logits - return_dict_in_generate = generation_config.return_dict_in_generate - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - # keep track of which sequences are already finished - batch_size = input_ids.shape[0] - unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - - this_peer_finished = False - is_first_iteration = True # to preserve the same API in the output as other generation methods - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - cur_len = input_ids.shape[-1] - - # 1. Fetch candidate sequences from a `CandidateGenerator` - candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) - - if candidate_logits is not None: - candidate_logits = candidate_logits.to(self.device) - - candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] - is_done_candidate = stopping_criteria(candidate_input_ids, None) - - # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain - # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, - # we use this forward pass to also pick the subsequent logits in the original model. - - # 2.1. Prepare the model inputs - candidate_kwargs = copy.copy(model_kwargs) - candidate_kwargs = _prepare_attention_mask( - candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder - ) - candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) - if "cache_position" in candidate_kwargs: - candidate_kwargs["cache_position"] = torch.cat( - ( - candidate_kwargs["cache_position"], - torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long), - ), - dim=0, - ) - - model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) - if "num_logits_to_keep" in model_inputs: - model_inputs["num_logits_to_keep"] = candidate_length + 1 - - # 2.2. Run a forward pass on the candidate sequence - # prepare variable output controls (note: some models won't accept all output controls) - model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) - model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) - - outputs = self(**model_inputs) - - # 2.3. 
Process the new logits - # .float() is needed to retain precision for later logits manipulations - new_logits = outputs.logits[:, -candidate_length - 1 :].float() # excludes the input prompt if present - new_logits = new_logits.to(input_ids.device) - next_token_logits = new_logits.clone() - if len(logits_processor) > 0: - for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) - - # 3. Select the accepted tokens. There are two possible cases: - # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) - # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). - if do_sample and candidate_logits is not None: - valid_tokens, n_matches = _speculative_sampling( - candidate_input_ids, - candidate_logits, - candidate_length, - new_logits, - is_done_candidate, - ) - - # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the - # original model logits with the candidate tokens. We can keep the candidate tokens until the first - # mismatch, or until the max length is reached. - else: - if do_sample: - probs = new_logits.softmax(dim=-1) - selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] - else: - selected_tokens = new_logits.argmax(dim=-1) - - candidate_new_tokens = candidate_input_ids[:, cur_len:] - n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() - - # Ensure we don't generate beyond max_len or an EOS token - if is_done_candidate and n_matches == candidate_length: - n_matches -= 1 - valid_tokens = selected_tokens[:, : n_matches + 1] - - # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated - # by the model after the last candidate match is also valid, as it is generated from a correct sequence. - # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there - # is no match. - - # 4.1. Get the valid continuation, after the matching tokens - input_ids = torch.cat((input_ids, valid_tokens), dim=-1) - if streamer is not None: - streamer.put(valid_tokens.cpu()) - new_cur_len = input_ids.shape[-1] - - # 4.2. Discard past key values relative to unused assistant tokens - new_cache_size = new_cur_len - 1 - outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) - - # 5. Update the candidate generation strategy if needed - candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) - - # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - num_new_tokens=n_matches + 1, - ) - if synced_gpus and this_peer_finished: - continue - - # Store scores, attentions and hidden_states when required - # Assistant: modified to append one tuple element per token, as in the other generation methods. 
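The acceptance count above, `((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()`, is simply the length of the longest prefix on which the draft agrees with the target model: the cumulative sum of mismatches stays below 1 only until the first disagreement. A tiny standalone sketch with invented token ids (not from the original file):

```python
import torch

# Draft tokens proposed by the assistant for one sequence.
candidate_new_tokens = torch.tensor([[11, 42, 7, 99]])
# Tokens the target model would pick at those positions (the last column is the "bonus" token).
selected_tokens = torch.tensor([[11, 42, 8, 55, 23]])

mismatch = ~(candidate_new_tokens == selected_tokens[:, :-1])  # [[False, False, True, True]]
n_matches = (mismatch.cumsum(dim=-1) < 1).sum()                # 2: first mismatch at position 2

# The accepted continuation is the matched prefix plus one token from the target model.
valid_tokens = selected_tokens[:, : n_matches + 1]
print(n_matches.item(), valid_tokens)  # 2 tensor([[11, 42, 8]])
```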
- if return_dict_in_generate: - newly_added_length = n_matches + 1 - if output_scores: - scores += tuple(new_logits[:, i, :] for i in range(newly_added_length)) - if output_logits: - raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length)) - - newly_added_length = new_cur_len if is_first_iteration else newly_added_length - if output_attentions: - if self.config.is_encoder_decoder: - cross_attentions = _split_model_outputs( - cross_attentions, outputs.cross_attentions, cur_len, newly_added_length - ) - decoder_attentions = _split_model_outputs( - decoder_attentions, - outputs.decoder_attentions, - cur_len, - newly_added_length, - is_decoder_attention=True, - ) - # some (V)LLMs have hard requirement on SDPA and thus never return attn - elif outputs.attentions[0] is not None: - decoder_attentions = _split_model_outputs( - decoder_attentions, - outputs.attentions, - cur_len, - newly_added_length, - is_decoder_attention=True, - ) - if output_hidden_states: - if self.config.is_encoder_decoder: - decoder_hidden_states = _split_model_outputs( - decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length - ) - else: - decoder_hidden_states = _split_model_outputs( - decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length - ) - - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - this_peer_finished = unfinished_sequences.max() == 0 - is_first_iteration = False - - if streamer is not None: - streamer.end() - - if ( - hasattr(candidate_generator, "assistant_model") - and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic" - ): - candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( - candidate_generator.num_assistant_tokens - ) - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids - - -def _speculative_sampling( - candidate_input_ids, - candidate_logits, - candidate_length, - new_logits, - is_done_candidate, -): - """ - Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns - the selected tokens, as well as the number of candidate matches. - - NOTE: Unless otherwise stated, the variable names match those in the paper. - """ - new_candidate_input_ids = candidate_input_ids[:, -candidate_length:] - # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens - # selected by the assistant, respectively. - q = candidate_logits.softmax(dim=-1) - q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1) - p = new_logits.softmax(dim=-1) - p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1) - probability_ratio = p_i / q_i - - # When probability_ratio > 1 (i.e. 
q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller - # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio - # (= keep with p = probability_ratio). Keep all the tokens until the first rejection - r_i = torch.rand_like(probability_ratio) - is_accepted = r_i <= probability_ratio - n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1 - - # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior) - if is_done_candidate and n_matches == candidate_length: - # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model - # due to acceptance on EOS we fix `n_matches` - n_matches -= 1 - valid_tokens = new_candidate_input_ids[:, : n_matches + 1] - else: - # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling. - gamma = candidate_logits.shape[1] - p_n_plus_1 = p[:, n_matches, :] - if n_matches < gamma: - q_n_plus_1 = q[:, n_matches, :] - p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0) - p_prime.div_(p_prime.sum()) - else: - p_prime = p_n_plus_1 - t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :] - - # The selected tokens include the matches (if any) plus the next sampled tokens - if n_matches > 0: - valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1) - else: - valid_tokens = t - - return valid_tokens, n_matches - - -def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False): - """ - Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple - where each member corresponds to a single generated token. - """ - # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the - # prompt. - if len(outputs) == 0: - new_tuple = () - for layer in new_outputs: - last_dim_size = cur_len if is_decoder_attention else layer.shape[-1] - new_tuple += (layer[..., :cur_len, :last_dim_size],) - outputs += (new_tuple,) - # The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly - cur_len += 1 - added_len -= cur_len - - for i in range(added_len): - new_tuple = () - for layer in new_outputs: - last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1] - new_tuple += (layer[..., i : i + 1, :last_dim_size],) - outputs += (new_tuple,) - return outputs - - -def _ranking_fast( - context_hidden: torch.FloatTensor, - next_hidden: torch.FloatTensor, - next_top_k_probs: torch.FloatTensor, - cosine_matrix_mask: torch.LongTensor, - alpha: float, - beam_width: int, -) -> torch.FloatTensor: - """ - Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described - in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each - row in the batch. 
- """ - norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) - norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) - cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] - - # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions) - # Using a large negative value for masked positions - cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype) - cosine_matrix_mask = (1 - cosine_matrix_mask) * torch.finfo(cosine_matrix.dtype).min - cosine_matrix = cosine_matrix + cosine_matrix_mask - - degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K] - next_top_k_probs = next_top_k_probs.view(-1) # [B*K] - contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty - contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] - _, selected_idx = contrastive_score.max(dim=-1) # [B] - return selected_idx - - -def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int = None): - """ - Takes care of three cases: - 1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim - 2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and - return a list of tuples - 3. data is a tuple of tuples, e.g. past_key_values. Keep the tuple as it is and split each tuple in it and - return a list of tuples of tuples - (see documentation of ModelOutput) - """ - if data is None: - return [None] * (full_batch_size // split_size) - if isinstance(data, torch.Tensor): - return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)] - # New cache format - elif isinstance(data, DynamicCache) or ( - isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache) - ): - return data.batch_split(full_batch_size, split_size, num_hidden_layers) - elif isinstance(data, tuple): - # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) - if isinstance(data[0], tuple): - return [ - tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data) - for i in range(0, full_batch_size, split_size) - ] - - else: - return [ - tuple(sub_tensor[i : i + split_size] for sub_tensor in data) - for i in range(0, full_batch_size, split_size) - ] - else: - raise TypeError(f"Unexpected attribute type: {type(data)}") - - -def _split_model_inputs( - model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int, config: PretrainedConfig -) -> List[Union[ModelOutput, Dict]]: - """ - Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split - size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from - previous forward pass. 
- """ - # Edge case: if model_input is None, return a list of Nones - # this happens with Whisper where encoder_outputs is None - if model_input is None: - return [model_input] * (full_batch_size // split_size) - # Infer the class from the object - model_output_cls = type(model_input) - if (full_batch_size % split_size) != 0: - raise ValueError("`full_batch_size` must be divisible by `split_size`") - - if split_size > full_batch_size: - raise ValueError("`split_size` must be smaller or equal to `full_batch_size`") - - # Helper function to split tensors or tuples of tensors - - # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them - keys = ( - model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys() - ) - # We only keep keys that are in the model_input - keys = [k for k in keys if k in model_input] - # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a - # ModelOutput object. - # bool should not be split but replicated for each split - bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] - keys_to_ignore = ["cache_position", "encoder_outputs", "num_logits_to_keep"] - non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] - - num_hidden_layers = config.get_text_config().num_hidden_layers - - # we split the tensors and tuples of tensors - data_split_list = [ - {k: _split(model_input[k], full_batch_size, num_hidden_layers, split_size)[i] for k in non_bool_keys} - for i in range(full_batch_size // split_size) - ] - # bool values are the same and replicated for each split - bool_data = {k: model_input[k] for k in bool_keys} - # encoder_outputs is a ModelOutput object and should be split by its own - if "encoder_outputs" in model_input: - encoder_outputs_split = _split_model_inputs( - model_input["encoder_outputs"], split_size, full_batch_size, config.get_text_config() - ) - data_split_list = [ - {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list) - ] - # num_logits_to_keep should be replicated for each split, similar to bool values - if "num_logits_to_keep" in model_input: - data_split_list = [ - {**data_split, "num_logits_to_keep": model_input["num_logits_to_keep"]} for data_split in data_split_list - ] - - # Convert each dictionary in the list to an object of the inferred class - split_model_inputs: List[Union[ModelOutput, Dict]] = [ - model_output_cls(**data_split, **bool_data) for data_split in data_split_list - ] - - return split_model_inputs - - -def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConfig) -> ModelOutput: - """ - Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the - specific ModelOutput subclass from the list provided. - """ - if not model_outputs: - raise ValueError("Input list is empty.") - - # Infer the class from the first object in the list - model_output_cls = type(model_outputs[0]) - num_hidden_layers = config.get_text_config().num_hidden_layers - - # Ensure all objects are of the same type - if not all(isinstance(obj, model_output_cls) for obj in model_outputs): - raise ValueError("All elements in the list should be of the same type.") - - # Helper function to concat tensors or tuples of tensors - def _concat(data): - """ - Reverse of `_split` function above. 
- """ - if any(data is None for data in data): - return None - if isinstance(data[0], torch.Tensor): - return torch.cat(data, dim=0) - # New cache format - elif isinstance(data[0], DynamicCache): - return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) - elif isinstance(data[0], EncoderDecoderCache): - return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) - elif isinstance(data[0], tuple): - # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) - if isinstance(data[0][0], tuple): - return tuple( - tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0]))) - for i in range(len(data[0])) - ) - else: - return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0]))) - elif isinstance(data[0], (int, float)): - # If the elements are integers or floats, return a tensor - return torch.tensor(data) - else: - raise TypeError(f"Unexpected attribute type: {type(data[0])}") - - # Use a dictionary comprehension to gather attributes from all objects and concatenate them - concatenated_data = { - k: _concat([getattr(model_output, k) for model_output in model_outputs]) - for k in model_output_cls.__dataclass_fields__.keys() - } - - # Return a new object of the inferred class with the concatenated attributes - return model_output_cls(**concatenated_data) - - -def _relative_top_filter( - scores: torch.FloatTensor, - baseline_scores: torch.FloatTensor, - relative_top: float = 0.1, - filter_value: float = -float("Inf"), - base_filter_value=-1e-3, - min_tokens_to_keep: int = 1, -) -> torch.FloatTensor: - """ - Reference: https://github.com/XiangLi1999/ContrastiveDecoding/blob/170e9142e92159c1237d731e240f5eb14aabf428/transformers/src/transformers/generation_logits_process.py#L235 - Apply filtering to only keep tokens with a probability above a certain threshold. The threshold is defined as `relative_top` * max probability in the distribution. - """ - scores_normalized = scores.log_softmax(dim=-1) - baseline_scores_normalized = baseline_scores.log_softmax(dim=-1) - sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True) - min_thresh = sorted_logits[..., min_tokens_to_keep - 1] - probs_max = torch.max(scores_normalized, dim=-1).values - probs_thresh = probs_max + np.log(relative_top) - probs_thresh = torch.min(min_thresh, probs_thresh) - probs_thresh = probs_thresh.unsqueeze(-1) - baseline_scores_normalized[scores_normalized < probs_thresh] = base_filter_value - scores_normalized[scores_normalized < probs_thresh] = filter_value - return scores_normalized, baseline_scores_normalized - - -def _dola_select_contrast( - candidate_premature_layers: List[int], - candidate_premature_logits: Dict[int, torch.FloatTensor], - final_logits: torch.FloatTensor, -) -> torch.FloatTensor: - if len(candidate_premature_layers) == 1: - base_logits = candidate_premature_logits[candidate_premature_layers[0]] - final_logits, base_logits = _relative_top_filter(final_logits, base_logits) - logits = final_logits - base_logits - return logits - - # 1. Stacking all premature_layers into a new dimension - stacked_premature_layers = torch.stack([candidate_premature_logits[i] for i in candidate_premature_layers], dim=0) - - # 2. 
Calculate the softmax values for mature_layer and all premature_layers - # shape: (batch_size, vocab_size) - softmax_mature_layer = F.softmax(final_logits, dim=-1) - # shape: (num_premature_layers, batch_size, vocab_size) - softmax_premature_layers = F.softmax(stacked_premature_layers, dim=-1) - - # 3. Calculate the average distribution - # shape: (num_premature_layers, batch_size, vocab_size) - avg_dist = 0.5 * (softmax_mature_layer[None, :, :] + softmax_premature_layers) - - # 4. Calculate log-softmax for the KL divergence - # shape: (batch_size, vocab_size) - log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1) - # shape: (num_premature_layers, batch_size, vocab_size) - log_softmax_premature_layers = F.log_softmax(stacked_premature_layers, dim=-1) - - # 5. Calculate the KL divergences and then the JS divergences - # shape: (num_premature_layers, batch_size) - kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], avg_dist, reduction="none").mean(-1) - # shape: (num_premature_layers, batch_size) - kl2 = F.kl_div(log_softmax_premature_layers, avg_dist, reduction="none").mean(-1) - js_divs = 0.5 * (kl1 + kl2) # shape: (num_premature_layers, batch_size) - - # 6. Reduce the batchmean - js_divs = js_divs.mean(-1) # shape: (num_premature_layers,) - premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())] - - base_logits = candidate_premature_logits[premature_layer] - final_logits, base_logits = _relative_top_filter(final_logits, base_logits) - logits = final_logits - base_logits - return logits
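To close, a compact sketch of the layer-selection rule implemented in steps 2-6 above: each premature layer is scored by a symmetric, Jensen-Shannon-style divergence between its softmax distribution and the final layer's (via the averaged distribution `avg_dist`), and the most divergent layer is taken as the contrast. The logits below are invented toy values, not data from the original file:

```python
import torch
import torch.nn.functional as F

# Toy next-token logits: the "mature" (final) layer and two candidate premature layers.
final_logits = torch.tensor([[2.0, 0.5, -1.0, 0.0]])                      # (batch=1, vocab=4)
stacked_premature_layers = torch.stack([
    torch.tensor([[1.9, 0.6, -0.9, 0.1]]),   # close to the final layer -> small divergence
    torch.tensor([[-2.0, 3.0, 1.0, 0.5]]),   # far from the final layer -> large divergence
])                                                                         # (layers=2, 1, 4)

# Mirror steps 2-5 of _dola_select_contrast on the toy data.
softmax_mature = F.softmax(final_logits, dim=-1)
softmax_premature = F.softmax(stacked_premature_layers, dim=-1)
avg_dist = 0.5 * (softmax_mature[None, :, :] + softmax_premature)

log_mature = F.log_softmax(final_logits, dim=-1)
log_premature = F.log_softmax(stacked_premature_layers, dim=-1)

kl1 = F.kl_div(log_mature[None, :, :], avg_dist, reduction="none").mean(-1)
kl2 = F.kl_div(log_premature, avg_dist, reduction="none").mean(-1)
js_divs = 0.5 * (kl1 + kl2)               # (layers, batch)

# Step 6: the most divergent premature layer is selected as the contrast.
print(js_divs.mean(-1))                   # second layer has the larger value
print(int(js_divs.mean(-1).argmax()))     # 1
```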