import copy import inspect import sys from abc import ABC, abstractmethod from collections import OrderedDict from dataclasses import dataclass, field from typing import Any, Dict, Optional, Tuple, Union import hydra.utils import torch from hydra.errors import InstantiationException from transformers import ( AutoTokenizer, DynamicCache, GenerationConfig, LogitsProcessorList, PretrainedConfig, PreTrainedModel, StoppingCriteriaList, ) from transformers.cache_utils import Cache from transformers.generation.utils import GenerateOutput from transformers.modeling_outputs import ModelOutput # Local imports not used, but added here so that HF push_to_hub adds them to model repo # noinspection PyUnresolvedReferences from .backbone_automodel import AutoModelFromPreTrained # noqa: F401 from .backbone_encoder_decoder import ( # noqa: F401 LLMasEncoderDecoder, LLMasEncoderDecoderShareKV, ) from .noise_schedule_noise_schedules import ( # noqa: F401 CosineNoise, ExponentialNoise, LinearNoise, LogarithmicNoise, ) @dataclass class DenoiserInput(OrderedDict): """Input to the denoiser model.""" xt: torch.LongTensor # (B, L) token_ids x0: Optional[torch.LongTensor] = None # (B, L) token_ids (not used in gen.) attention_mask: Optional[torch.FloatTensor] = None past_key_values: Optional[Union[torch.FloatTensor, Cache]] = None context_mask: Optional[torch.FloatTensor] = None tokens_mask: Optional[torch.FloatTensor] = None # (B, L) t: Optional[torch.FloatTensor] = None # (B,) | # (B, L) alpha_t: Optional[torch.FloatTensor] = None # (B,) | (B, 1|L) | (B, 1|L, 1) alpha_t_prime: Optional[torch.FloatTensor] = None # (B,) | (B, 1|L) | (B, 1|L, 1) backbone_kwargs: dict[str, Any] = field(default_factory=dict) @dataclass class LossAndNllOutput(OrderedDict): """Loss output for denoiser models.""" loss: torch.FloatTensor nlls: torch.FloatTensor other_loss_terms: dict = field(default_factory=dict) @dataclass class DenoiserOutput(ModelOutput): """Output of the denoiser model.""" denoiser_output: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None tokens_mask: Optional[torch.FloatTensor] = None # Which tokens contribute to loss past_key_values: Optional[Cache] = None loss: Optional[torch.FloatTensor] = None nlls: Optional[torch.FloatTensor] = None other_loss_terms: Optional[dict[str, Any]] = None class DenoiserConfig(PretrainedConfig): """Configuration class for Denoiser models. This class is used to initialize the model and contains all the necessary parameters for the model's architecture. """ model_type = "denoiser" def __init__( self, length: Optional[int] = None, backbone_config: Optional[Dict[str, Any]] = None, noise_config: Optional[Dict[str, Any]] = None, tokenization_config: Optional[Dict[str, Any]] = None, time_conditioned_backbone: Optional[bool] = None, attn_backend: str = "sdpa", # "sdpa", "flash_attention_2", "flex_attention" train_on_context: bool = False, **kwargs, ): super().__init__(**kwargs) for v in [ "vocab_size", "mask_token_id", "pad_token_id", "bos_token_id", "eos_token_id", "pad_vocab_size_multiple", ]: if tokenization_config is not None and ( getattr(self, v, None) is None or v in tokenization_config ): setattr(self, v, tokenization_config.get(v, None)) else: setattr(self, v, None) self.backbone_config = backbone_config self.noise_config = noise_config self.tokenization_config = tokenization_config self.length = length self.time_conditioned_backbone = time_conditioned_backbone self.attn_backend = attn_backend self.train_on_context = train_on_context class Denoiser(ABC, PreTrainedModel): """Abstract base class for denoising models. This class defines the interface for AR, Diffusion, and Flow-based parametrizations. """ config_class = DenoiserConfig def __init__( self, config: DenoiserConfig, **kwargs, ): """ Initialize the Denoiser with a configuration and optional dataset type. Parameters: config (Any): Configuration object for the model. """ super().__init__(config) self.config = config self.vocab_size = config.vocab_size self.mask_token_id = config.mask_token_id self.pad_token_id = config.pad_token_id self.bos_token_id = config.bos_token_id self.eos_token_id = config.eos_token_id try: self.backbone = hydra.utils.instantiate(config.backbone_config) except InstantiationException: # When using HF and `from_pretrained`, the modules specified in `_target_` # fields in our configs are already being imported under a name with the # following format: transformers_modules... # When hydra attempts to instantiate and calls importlib under the hood, the # desired module is not found. # The snippet below aliases the desired module, enabling seamless use of # `hydra.utils.instantiate`. sys_modules = copy.deepcopy(list(sys.modules.keys())) repo_root_module = ".".join(__name__.split(".")[:-1]) for name in sys_modules: if name.startswith(repo_root_module): short = name.split(".")[-1] if short not in sys.modules: sys.modules[short] = sys.modules[name] del sys_modules self.backbone = hydra.utils.instantiate(config.backbone_config) self.tokenizer = AutoTokenizer.from_pretrained( config.tokenizer_name, trust_remote_code=True, ) self.noise_schedule = ( hydra.utils.instantiate(config.noise_config) if config.noise_config is not None else None ) self.time_conditioned_backbone = ( config.time_conditioned_backbone if config.time_conditioned_backbone is not None else "noise" in inspect.getfullargspec(self.backbone.forward).args ) # List that can contain any parameters that should not be pushed to HF, # e.g., registered buffers for static attention masks self.skip_params_for_push = [] @abstractmethod def _prepare_inputs( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.FloatTensor] = None, context_mask: Optional[torch.FloatTensor] = None, t: Optional[torch.FloatTensor] = None, past_key_values: Optional[Cache] = None, ) -> DenoiserInput: """ Prepare inputs for the model. Parameters: input_ids (LongTensor): Input tensor to the model. attention_mask (Optional[FloatTensor]): Attention mask for the model. t (Optional[FloatTensor]): Time step for the model. past_key_values (Optional[Cache]): Past key values for the model. Returns: Denoiser inputs. """ raise NotImplementedError("Denoiser subclasses must implement _prepare_inputs") def _prepare_inputs_inference( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, context: Optional[torch.LongTensor] = None, context_mask: Optional[torch.FloatTensor] = None, cache: Optional[Dict[str, Any]] = None, **backbone_kwargs: Any, ) -> Tuple[DenoiserInput, Dict[str, Any]]: raise NotImplementedError( "Denoiser subclasses must implement _prepare_inputs_inference" ) # assert input_ids is not None or context is not None, ( # "Must provide either input_ids or context." # ) # cache = cache if cache is not None else {} # past_key_values = cache.pop("past_key_values", DynamicCache()) # if context is not None: # if input_ids is not None: # if context_mask is None: # context_mask = torch.cat( # [torch.ones_like(context), torch.zeros_like(input_ids)], dim=-1 # ) # input_ids = torch.cat([context, input_ids], dim=-1) # else: # input_ids = context # context_mask = torch.ones_like(input_ids) # if attention_mask is None: # cache_length = self._get_past_key_values_seq_length(past_key_values) # full_seq_length = cache_length + input_ids.shape[-1] # attention_mask = torch.ones( # (input_ids.shape[0], 1, input_ids.shape[1], full_seq_length), # device=input_ids.device, # ) # Make attention mask 4D # attention_mask = self._preprocess_attention_mask( # attention_mask, dtype=torch.float # ) # return DenoiserInput( # xt=input_ids, # attention_mask=attention_mask, # past_key_values=past_key_values, # context_mask=context_mask, # backbone_kwargs=backbone_kwargs, # ), cache @abstractmethod def _compute_loss( self, model_output: torch.FloatTensor, denoiser_inputs: DenoiserInput, **kwargs: Any, ) -> LossAndNllOutput: """ Compute the loss for the denoising model. Parameters: model_output (FloatTensor): Output tensor from self.forward. denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model. Returns: LossAndNllOutput: loss (FloatTensor) and nlls (FloatTensor). """ raise NotImplementedError("Denoiser subclasses must implement _compute_loss") def _forward( self, backbone_output: torch.FloatTensor, denoiser_inputs: DenoiserInput, **kwargs: Any, ) -> torch.FloatTensor: """ Forward pass for the denoiser model returns probabilities over denoised sequence. Some classes may need to override this method. Parameters: backbone_output (FloatTensor): Output tensor from the backbone model. denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model. Returns: Model outputs (FloatTensor). """ return torch.log_softmax(backbone_output, dim=-1) # type: ignore def _backbone_forward( self, denoiser_inputs: DenoiserInput, **backbone_kwargs: Any, ) -> ModelOutput: """Forward pass for the backbone model (should return logits). Some classes may need to override this method. Parameters: denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model. return_updated_cache (bool): If True, return past_key_values instead of logits. Returns: Backbone output (ModelOutput instance). """ if self.time_conditioned_backbone: return self.backbone( denoiser_inputs.xt, attention_mask=denoiser_inputs.attention_mask, past_key_values=denoiser_inputs.past_key_values, noise=denoiser_inputs.alpha_t, **denoiser_inputs.backbone_kwargs, **backbone_kwargs, ) return self.backbone( denoiser_inputs.xt, attention_mask=denoiser_inputs.attention_mask, past_key_values=denoiser_inputs.past_key_values, **denoiser_inputs.backbone_kwargs, **backbone_kwargs, ) def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.FloatTensor] = None, context_mask: Optional[torch.FloatTensor] = None, t: Optional[torch.FloatTensor] = None, past_key_values: Optional[Cache] = None, compute_loss: Optional[bool] = True, **kwargs, ) -> DenoiserOutput: """ Perform a forward pass through the denoising model and (optionally) compute the loss. Parameters: input_ids (LongTensor): Input tensor to the model. attention_mask (Optional[FloatTensor]): Attention mask for the model. context_mask (Optional[FloatTensor]): Indicator for context tokens. t (Optional[FloatTensor]): Denoising time step for the model. past_key_values (Optional[Cache]): KV cache. compute_loss (Optional[bool]): Flag to compute loss. Returns: DenoiserOutput """ denoiser_inputs = self._prepare_inputs( input_ids=input_ids, attention_mask=attention_mask, context_mask=context_mask, past_key_values=past_key_values, t=t, ) backbone_output = self._backbone_forward(denoiser_inputs, **kwargs) new_past_key_values = getattr(backbone_output, "past_key_values", None) backbone_output = getattr(backbone_output, "logits", backbone_output[0]) denoiser_output = self._forward( backbone_output, denoiser_inputs, **kwargs, ) if compute_loss: loss_and_nll = self._compute_loss( model_output=denoiser_output, denoiser_inputs=denoiser_inputs, **kwargs ) loss = loss_and_nll.loss nlls = loss_and_nll.nlls other_loss_terms = loss_and_nll.other_loss_terms else: loss, nlls = None, None other_loss_terms = {} return DenoiserOutput( denoiser_output=denoiser_output, logits=backbone_output, past_key_values=new_past_key_values, tokens_mask=denoiser_inputs.tokens_mask, loss=loss, nlls=nlls, other_loss_terms=other_loss_terms, ) @staticmethod def _sample_categorical(categorical_probs, do_sample=True): """Helper function to sample from a categorical distribution.""" categorical_probs = categorical_probs.to(torch.float64) if not do_sample: return categorical_probs.argmax(dim=-1) gumbel_norm = (1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()).to( categorical_probs.dtype ) return (categorical_probs / gumbel_norm).argmax(dim=-1) @staticmethod def _preprocess_attention_mask(attention_mask, dtype): min_dtype = torch.finfo(dtype).min attention_mask = torch.where( (attention_mask == 0.0).bool(), # type: ignore min_dtype, 0.0, ).to(dtype) return attention_mask @staticmethod def _get_past_key_values_seq_length(past_key_values: DynamicCache): seq_length = 0 for i in range(len(past_key_values)): if past_key_values[i][0].shape[0] > 0: # type: ignore seq_length = max( past_key_values[i][0].shape[-2], # type: ignore seq_length, ) return seq_length def update_cache( self, inputs: torch.LongTensor, cache: Optional[Dict[str, Any]] = None, **backbone_kwargs: Any, ) -> Dict[str, Any]: """ Cache the key-value pairs for the context. Args: inputs (torch.LongTensor): The context tensor. cache (Dict[str, Any | None): Cache objects, e.g., past_key_values. Returns: Dict: Updated cache objects, e.g., past_key_values. """ context_input, cache = self._prepare_inputs_inference( input_ids=inputs, cache=cache, return_updated_cache=True, **backbone_kwargs ) backbone_output = self._backbone_forward( context_input, return_updated_cache=True, # Will get absorbed in backbone_kwargs **cache, ) backbone_output = {k: v for k, v in backbone_output.items()} backbone_output.pop("logits", None) # Do not store logits in cache cache = cache | backbone_output return cache @torch.no_grad() def generate( self, inputs: Optional[torch.LongTensor] = None, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, max_new_tokens: Optional[int] = None, batch_size: Optional[int] = None, device: Optional[str] = None, **kwargs: Any, ) -> Union[GenerateOutput, torch.LongTensor]: """Generates sample from denoising model. Follows signature of transformers.GenerationMixin. """ raise NotImplementedError("Denoiser subclasses must implement generate")