|
|
from functools import partial |
|
|
from typing import Any, Dict, Literal, Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
from tqdm.auto import tqdm |
|
|
from transformers import ( |
|
|
GenerationConfig, |
|
|
LogitsProcessorList, |
|
|
PreTrainedTokenizer, |
|
|
StoppingCriteriaList, |
|
|
) |
|
|
from transformers.cache_utils import Cache, DynamicCache |
|
|
|
|
|
try: |
|
|
from torch.nn.attention.flex_attention import ( |
|
|
BlockMask, |
|
|
and_masks, |
|
|
create_block_mask, |
|
|
) |
|
|
except ImportError: |
|
|
BlockMask, and_masks, create_block_mask = None, None, None |
|
|
|
|
|
|
|
|
from .denoiser_base import ( |
|
|
Denoiser, |
|
|
DenoiserConfig, |
|
|
DenoiserInput, |
|
|
LossAndNllOutput, |
|
|
) |
|
|
|
|
|
|
|
|
def create_attn_mask(attn_mask): |
|
|
|
|
|
def padding(b, h, q_idx, kv_idx): |
|
|
return attn_mask[b, q_idx] & attn_mask[b, kv_idx] |
|
|
|
|
|
return padding |
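

# Illustrative sketch (not used by the library code): how the `padding` mask_mod
# returned by `create_attn_mask` could be combined with another mask_mod via
# `and_masks` and compiled into a `BlockMask`. Assumes flex_attention imported
# above is available; the trivial `full` mask_mod and this helper's name are
# made up for the example.
def _example_padding_block_mask(attn_mask: torch.Tensor) -> "BlockMask":
    """Builds a BlockMask that keeps only non-padding query/key pairs.

    `attn_mask` is a (batch, seq_len) tensor that is nonzero for real tokens.
    """
    padding_mod = create_attn_mask(attn_mask.bool())

    # A trivial second mask_mod (full attention) just to demonstrate `and_masks`.
    def full(b, h, q_idx, kv_idx):
        return q_idx >= 0

    return create_block_mask(
        and_masks(padding_mod, full),
        B=attn_mask.shape[0],
        H=None,
        Q_LEN=attn_mask.shape[1],
        KV_LEN=attn_mask.shape[1],
        device=str(attn_mask.device),
    )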
|
|
|
|
|
|
|
|
class DiffusionGenerationConfig(GenerationConfig): |
|
|
def __init__( |
|
|
self, |
|
|
num_steps: int = 1000, |
|
|
min_t: float = 1e-5, |
|
|
block_size: Optional[int] = None, |
|
|
first_hitting: bool = False, |
|
|
sampling_strategy: Literal["posterior", "predict_then_noise"] = "posterior", |
|
|
confidence_based_noising: bool = False, |
|
|
confidence_margin_based_noising: bool = False, |
|
|
confidence_threshold: float = 1e6, |
|
|
use_model_output_cache: bool = True, |
|
|
align_inputs_to_blocks: bool = True, |
|
|
**kwargs, |
|
|
): |
|
|
"""Generation config with additional parameters relevant for diffusion model |
|
|
sampling. |
|
|
|
|
|
Args: |
|
|
num_steps (int): Number of diffusion / iterative refinement steps. |
|
|
Defaults to 1000. |
|
|
min_t (float): Minimum time to use. |
|
|
Diffusion models use t=1 for noise and t=0 for signal. |
|
|
Setting t=0 exactly can lead to certain numerical instabilities. |
|
|
Defaults to 1e-5. |
|
|
block_size (int): Block size to use for semi-autoregressive decoding. |
|
|
Defaults to None (in which case block_size is set to max_new_tokens). |
|
|
first_hitting (bool): Whether to use first hitting sampler. |
|
|
When set to true, rather than following the diffusion time and sampling |
|
|
from posterior, which can result in no tokens changing between steps, |
|
|
e.g., for masked diffusion, we explicitly determine the next time step |
|
|
at which a token will be decoded / generated. |
|
|
                Note: this overrides the `num_steps` parameter, since one token is
                decoded per step; when True, `num_steps` is capped at `block_size`
                (or the full sequence length when not decoding
                semi-autoregressively).
                See https://arxiv.org/abs/2409.02908 for details, and
                `_example_first_hitting_times` below for a toy sketch of the time
                recursion.
                Defaults to False.
|
|
sampling_strategy (str): Method for transitioning between latents. |
|
|
Options: |
|
|
- "posterior" - Compute and sample from the posterior |
|
|
q(x_s | x_t, x_theta). |
|
|
- "predict_then_noise" - Sample from the denoising model x_theta, |
|
|
then add back noise to produce x_s. |
|
|
Only implemented for absorbing diffusion. |
|
|
Defaults to "posterior". |
|
|
confidence_based_noising (bool): When using the "predict_then_noise" |
|
|
strategy, whether to add noise to random positions or to those that have |
|
|
the lowest probability under x_theta. |
|
|
Cannot be used in conjunction with confidence_margin_based_noising. |
|
|
Defaults to False. |
|
|
confidence_margin_based_noising (bool): When using the "predict_then_noise" |
|
|
strategy, whether to add noise to random positions or to those that have |
|
|
the lowest probability margins under x_theta, where margin is defined as |
|
|
the absolute difference between the top two probabilities at a given |
|
|
position. |
|
|
See https://arxiv.org/abs/2502.06768 for details. |
|
|
Cannot be used in conjunction with confidence_based_noising. |
|
|
Defaults to False. |
|
|
            confidence_threshold (float): Confidence threshold to use for sampling.
                Any token whose predicted probability meets or exceeds the
                threshold is decoded immediately (only used with the
                "predict_then_noise" strategy).
                See https://arxiv.org/abs/2505.22618 for details.
                Defaults to 1e6.
|
|
            use_model_output_cache (bool): Whether to re-use the model's output
                when the sequence is unchanged between steps: if xt == xs, the
                denoising model's output can be re-used, saving a function
                evaluation.
                Only relevant if model.backbone is not time/noise-conditioned.
                Defaults to True.
|
|
            align_inputs_to_blocks (bool): Whether to align input tokens to the
                block size, e.g., for an input of length C and block size S, the
                context consists of the first (C // S) * S tokens, and generation
                begins with a block whose first C % S tokens come from the input.
                Defaults to True.
|
|
kwargs: Keyword arguments passed to `GenerationConfig`. |
|
|
""" |
|
|
super().__init__(**kwargs) |
|
|
self.num_steps = num_steps |
|
|
self.min_t = min_t |
|
|
|
|
|
self.block_size = block_size if block_size is not None else self.max_new_tokens |
|
|
self.first_hitting = first_hitting |
|
|
if self.first_hitting: |
|
|
|
|
|
self.num_steps = min(num_steps, self.block_size) |
|
|
self.sampling_strategy = sampling_strategy |
|
|
assert not confidence_based_noising or not confidence_margin_based_noising, ( |
|
|
"Cannot use both `confidence_based_noising` and" |
|
|
" `confidence_margin_based_noising`." |
|
|
) |
|
|
self.confidence_based_noising = confidence_based_noising |
|
|
self.confidence_margin_based_noising = confidence_margin_based_noising |
|
|
self.confidence_threshold = confidence_threshold |
|
|
self.use_model_output_cache = use_model_output_cache |
|
|
self.align_inputs_to_blocks = align_inputs_to_blocks |
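

# Illustrative sketch (not called by the classes below): the first-hitting
# sampler described in the `first_hitting` docstring above draws the next
# decoding time as t_k = t_{k-1} * u ** (1 / i), with u ~ Uniform(0, 1) and i
# the number of still-masked tokens. The real implementation lives in
# `D3PM._sample_generation_timesteps`; this helper and its name exist only for
# illustration.
def _example_first_hitting_times(block_size: int = 8) -> torch.Tensor:
    """Toy reproduction of the first-hitting time recursion
    (https://arxiv.org/abs/2409.02908).

    Starting from t = 1.0, exactly one token is revealed per step, so the
    returned tensor has one (strictly decreasing) timestep per token.
    """
    times = [torch.tensor(1.0)]
    for i in range(block_size, 0, -1):
        u = torch.rand(())
        times.append(times[-1] * u ** (1.0 / i))
    # Drop the initial t=1.0; one timestep remains per token in the block.
    return torch.stack(times[1:])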
|
|
|
|
|
|
|
|
class D3PMConfig(DenoiserConfig): |
|
|
"""Configuration class for D3PM models.""" |
|
|
|
|
|
model_type = "d3pm" |
|
|
auto_map = { |
|
|
"AutoConfig": "diffusion.D3PMConfig", |
|
|
"AutoModel": "diffusion.D3PM", |
|
|
"AutoModelForMaskedLM": "diffusion.D3PM", |
|
|
} |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
keep_clean_bos: Optional[bool] = None, |
|
|
T: int = 1000, |
|
|
diffusion_type: Literal["absorbing", "uniform"] = "absorbing", |
|
|
**kwargs, |
|
|
): |
|
|
super().__init__(**kwargs) |
|
|
self.keep_clean_bos = keep_clean_bos |
|
|
self.diffusion_type = diffusion_type |
|
|
self.T = T |
|
|
|
|
|
|
|
|
class D3PM(Denoiser): |
|
|
"""Denoiser class for D3PM models. |
|
|
|
|
|
This class implements the Denoiser interface for D3PM models. |
|
|
""" |
|
|
|
|
|
config_class = D3PMConfig |
|
|
|
|
|
def __init__(self, config: D3PMConfig, **kwargs): |
|
|
super().__init__(config, **kwargs) |
|
|
self.T = config.T |
|
|
self.diffusion_type = config.diffusion_type |
|
|
self._create_static_mask() |
|
|
|
|
|
def _create_static_mask(self) -> None: |
|
|
static_mask = torch.ones( |
|
|
self.config.length, self.config.length, dtype=torch.bool |
|
|
) |
|
|
self.register_buffer( |
|
|
"static_attention_mask", |
|
|
static_mask, |
|
|
) |
|
|
self.skip_params_for_push.append("static_attention_mask") |
|
|
|
|
|
def _sample_q_xt( |
|
|
self, |
|
|
x0: torch.LongTensor, |
|
|
alpha_t: torch.FloatTensor, |
|
|
context_mask: torch.FloatTensor, |
|
|
) -> torch.LongTensor: |
|
|
"""Sample from the pre-defined forward / noising process. |
|
|
|
|
|
        Args:
|
|
x0 (Tensor): Signal / data sample; |
|
|
can potentially include context tokens. |
|
|
alpha_t (Tensor): Amount of signal to retain. |
|
|
context_mask (Tensor): Indicator of context tokens (to remain |
|
|
unchanged). |
|
|
""" |
|
|
move_indices = torch.rand(*x0.shape, device=x0.device) < (1.0 - alpha_t) |
|
|
if self.diffusion_type == "absorbing": |
|
|
xt = torch.where( |
|
|
(move_indices * (1 - context_mask)).bool(), self.mask_token_id, x0 |
|
|
) |
|
|
if self.config.keep_clean_bos: |
|
|
xt[..., 0] = x0[..., 0] |
|
|
return xt |
|
|
if self.diffusion_type == "uniform": |
|
|
xt = torch.randint(0, self.vocab_size, x0.shape, device=x0.device) |
|
|
xt = torch.where(context_mask.bool(), x0, xt) |
|
|
if self.config.keep_clean_bos: |
|
|
xt[..., 0] = x0[..., 0] |
|
|
return xt |
|
|
raise NotImplementedError( |
|
|
f"Diffusion type '{self.diffusion_type}' not implemented." |
|
|
) |
|
|
|
|
|
def _prepare_inputs( |
|
|
self, |
|
|
input_ids: torch.LongTensor, |
|
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
|
context_mask: Optional[torch.FloatTensor] = None, |
|
|
t: Optional[torch.FloatTensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
): |
|
|
|
|
|
if attention_mask is None: |
|
|
attention_mask = torch.ones_like(input_ids) |
|
|
if context_mask is None: |
|
|
context_mask = torch.zeros_like(attention_mask) |
|
|
|
|
|
if torch.is_floating_point(attention_mask): |
|
|
attention_mask = attention_mask.to(torch.int) |
|
|
context_mask = context_mask.to(torch.int) |
|
|
|
|
|
if t is None: |
|
|
t = torch.rand(input_ids.shape[0], device=input_ids.device) |
|
|
alpha_t, alpha_t_prime = self.noise_schedule(t) |
|
|
while alpha_t.ndim < 2: |
|
|
alpha_t = alpha_t[..., None] |
|
|
alpha_t_prime = alpha_t_prime[..., None] |
|
|
xt = self._sample_q_xt( |
|
|
x0=input_ids, |
|
|
alpha_t=alpha_t, |
|
|
context_mask=context_mask, |
|
|
) |
|
|
if ( |
|
|
context_mask is not None |
|
|
and context_mask.sum() == 0 |
|
|
and (attention_mask == 1).all() |
|
|
): |
|
|
processed_attention_mask = None |
|
|
else: |
|
|
processed_attention_mask = ( |
|
|
self.static_attention_mask[None, ...] |
|
|
& attention_mask[:, None, :] |
|
|
& attention_mask[..., None] |
|
|
)[:, None, ...] |
|
|
processed_attention_mask = self._preprocess_attention_mask( |
|
|
processed_attention_mask, dtype=torch.float |
|
|
) |
|
|
if self.training and self.config.train_on_context: |
|
|
tokens_mask = attention_mask |
|
|
else: |
|
|
tokens_mask = attention_mask * (1 - context_mask) |
|
|
return DenoiserInput( |
|
|
xt=xt, |
|
|
x0=input_ids, |
|
|
attention_mask=processed_attention_mask, |
|
|
context_mask=context_mask, |
|
|
tokens_mask=tokens_mask, |
|
|
t=t, |
|
|
alpha_t=alpha_t, |
|
|
alpha_t_prime=alpha_t_prime, |
|
|
) |
|
|
|
|
|
def _prepare_inputs_inference( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
|
context: Optional[torch.LongTensor] = None, |
|
|
context_mask: Optional[torch.FloatTensor] = None, |
|
|
cache: Optional[Dict[str, Any]] = None, |
|
|
**backbone_kwargs: Any, |
|
|
) -> Tuple[DenoiserInput, Dict[str, Any]]: |
|
|
assert input_ids is not None or context is not None, ( |
|
|
"Must provide either input_ids or context." |
|
|
) |
|
|
cache = cache if cache is not None else {} |
|
|
past_key_values = cache.pop("past_key_values", DynamicCache()) |
|
|
if context is not None: |
|
|
if input_ids is not None: |
|
|
if context_mask is None: |
|
|
context_mask = torch.cat( |
|
|
[torch.ones_like(context), torch.zeros_like(input_ids)], dim=-1 |
|
|
) |
|
|
input_ids = torch.cat([context, input_ids], dim=-1) |
|
|
else: |
|
|
input_ids = context |
|
|
context_mask = torch.ones_like(input_ids) |
|
|
if attention_mask is None: |
|
|
cache_length = self._get_past_key_values_seq_length(past_key_values) |
|
|
full_seq_length = cache_length + input_ids.shape[-1] |
|
|
attention_mask = torch.ones( |
|
|
(input_ids.shape[0], 1, input_ids.shape[1], full_seq_length), |
|
|
device=input_ids.device, |
|
|
) |
|
|
attention_mask = self._preprocess_attention_mask( |
|
|
attention_mask, dtype=torch.float |
|
|
) |
|
|
return DenoiserInput( |
|
|
xt=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
past_key_values=past_key_values, |
|
|
context_mask=context_mask, |
|
|
backbone_kwargs=backbone_kwargs | {"use_cache": False}, |
|
|
), cache |
|
|
|
|
|
def _forward( |
|
|
self, |
|
|
backbone_output: torch.FloatTensor, |
|
|
denoiser_inputs: DenoiserInput, |
|
|
**kwargs, |
|
|
) -> torch.FloatTensor: |
|
|
return torch.log_softmax(backbone_output, dim=-1) |
|
|
|
|
|
def _compute_loss( |
|
|
self, |
|
|
model_output: torch.FloatTensor, |
|
|
denoiser_inputs: DenoiserInput, |
|
|
**kwargs: Any, |
|
|
) -> LossAndNllOutput: |
|
|
raise NotImplementedError |
|
|
|
|
|
def _sample_prior(self, device, batch_size, length): |
|
|
"""Samples from prior / limiting distribution.""" |
|
|
if self.diffusion_type == "absorbing": |
|
|
return self.mask_token_id * torch.ones( |
|
|
(batch_size, length), dtype=torch.int64, device=device |
|
|
) |
|
|
if self.diffusion_type == "uniform": |
|
|
return torch.randint( |
|
|
0, |
|
|
self.vocab_size, |
|
|
(batch_size, length), |
|
|
device=device, |
|
|
dtype=torch.int64, |
|
|
) |
|
|
raise NotImplementedError( |
|
|
f"Diffusion type '{self.diffusion_type}' not implemented." |
|
|
) |
|
|
|
|
|
def _compute_posterior( |
|
|
self, |
|
|
x: Union[torch.FloatTensor, torch.LongTensor], |
|
|
xt: torch.LongTensor, |
|
|
alpha_t: torch.FloatTensor, |
|
|
alpha_s: torch.FloatTensor, |
|
|
) -> torch.FloatTensor: |
|
|
"""Computes posterior / approximate posterior q(x_s | x_t, x), |
|
|
where x represents clean sequence (as one-hots) or the output of the |
|
|
denoising model. |
|
|
|
|
|
Args: |
|
|
x (Tensor): True (one-hot) / predicted clean signal (B, L, V). |
|
|
xt (Tensor): Noised signal at time t (B, L). |
|
|
alpha_t (Tensor): Noise schedule parameter at time t (B, 1, 1). |
|
|
alpha_s (Tensor): Noise schedule parameter at time s (B, 1, 1). |
|
|
""" |
|
|
if self.diffusion_type == "absorbing": |
|
|
q_xs = x * (alpha_s - alpha_t) |
|
|
q_xs[..., self.mask_token_id] = 1 - alpha_s[..., 0] |
|
|
q_xs /= 1 - alpha_t |
|
|
return q_xs |
|
|
|
|
|
alpha_ts = alpha_t / alpha_s |
|
|
d_alpha = alpha_s - alpha_t |
|
|
        xt_one_hot = torch.nn.functional.one_hot(xt, self.vocab_size)
|
|
limiting_distribution = torch.ones_like(xt_one_hot) / self.vocab_size |
|
|
if self.diffusion_type == "uniform": |
|
|
return ( |
|
|
alpha_t * self.vocab_size * x * xt_one_hot |
|
|
+ (alpha_ts - alpha_t) * xt_one_hot |
|
|
+ d_alpha * x |
|
|
+ (1 - alpha_ts) * (1 - alpha_s) * limiting_distribution |
|
|
) / ( |
|
|
alpha_t * self.vocab_size * torch.gather(x, -1, xt[..., None]) |
|
|
+ (1 - alpha_t) |
|
|
) |
|
|
raise NotImplementedError( |
|
|
f"Diffusion type {self.diffusion_type} not implemented." |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def _sample_generation_timesteps( |
|
|
generation_config: DiffusionGenerationConfig, |
|
|
max_length: Optional[int] = None, |
|
|
device: Optional[str] = None, |
|
|
) -> torch.FloatTensor: |
|
|
"""Sample timesteps for diffusion generation process.""" |
|
|
if device is None: |
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
if max_length is None: |
|
|
max_length = generation_config.max_new_tokens |
|
|
|
|
|
if ( |
|
|
generation_config.first_hitting |
|
|
|
|
|
and generation_config.sampling_strategy == "posterior" |
|
|
): |
|
|
            timesteps = torch.tensor([1.0])
|
|
for i in range(max_length, 0, -1): |
|
|
u = torch.rand(1) |
|
|
next_t = timesteps[-1] * u ** (1 / i) |
|
|
timesteps = torch.cat((timesteps, next_t), dim=0) |
|
|
return timesteps[1:].to(device) |
|
|
return torch.linspace( |
|
|
1.0, |
|
|
generation_config.min_t, |
|
|
generation_config.num_steps + 1, |
|
|
device=device, |
|
|
)[:-1] |
|
|
|
|
|
def _generate_unconditional( |
|
|
self, |
|
|
generation_config: DiffusionGenerationConfig, |
|
|
alpha_t: torch.FloatTensor, |
|
|
alpha_s: torch.FloatTensor, |
|
|
denoiser_inputs: Optional[DenoiserInput] = None, |
|
|
model_output_cache: Optional[Dict[str, torch.FloatTensor]] = None, |
|
|
cache: Optional[Dict[str, Any]] = None, |
|
|
running_generation: Optional[torch.LongTensor] = None, |
|
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
|
**kwargs: Any, |
|
|
) -> Tuple[torch.LongTensor, Dict[str, torch.FloatTensor], Dict[str, Any]]: |
|
|
cache = cache if cache is not None else {} |
|
|
if model_output_cache is None: |
|
|
backbone_output = self._backbone_forward( |
|
|
denoiser_inputs, |
|
|
fix_cache_length=True, |
|
|
**cache, |
|
|
**kwargs, |
|
|
) |
|
|
            # Shallow copy so that popping "logits" below does not mutate the
            # backbone's output object.
            backbone_output = {k: v for k, v in backbone_output.items()}
|
|
logits = backbone_output.pop("logits") |
|
|
cache = cache | backbone_output |
|
|
log_x_theta = self._forward(logits, denoiser_inputs, **kwargs) |
|
|
if logits_processor is not None: |
|
|
for token_idx in range(log_x_theta.shape[1]): |
|
|
|
|
|
|
|
|
|
|
|
log_x_theta[:, token_idx] = logits_processor( |
|
|
input_ids=running_generation, |
|
|
scores=log_x_theta[:, token_idx], |
|
|
) |
|
|
log_x_theta = torch.log_softmax(log_x_theta, dim=-1) |
|
|
x_theta = log_x_theta.exp() |
|
|
else: |
|
|
x_theta = model_output_cache["x_theta"] |
|
|
model_output_cache = {"x_theta": x_theta} |
|
|
prob_check_denom = denoiser_inputs.xt.numel() |
|
|
if generation_config.sampling_strategy == "posterior": |
|
|
q_xs = self._compute_posterior( |
|
|
x_theta, denoiser_inputs.xt, alpha_t, alpha_s |
|
|
) |
|
|
|
|
|
assert abs((q_xs.sum() / prob_check_denom).item() - 1.0) < 1e-6, ( |
|
|
"Posterior probabilities not summing to 1." |
|
|
) |
|
|
assert q_xs.isnan().sum().item() == 0, "NaN found in the posterior." |
|
|
xs = self._sample_categorical(q_xs, generation_config.do_sample) |
|
|
output = torch.where( |
|
|
(denoiser_inputs.xt != self.mask_token_id).bool(), |
|
|
denoiser_inputs.xt, |
|
|
xs, |
|
|
) |
|
|
        elif generation_config.sampling_strategy == "predict_then_noise":
            assert self.config.diffusion_type == "absorbing", (
                "predict_then_noise decoding strategy only supports absorbing"
                " diffusion."
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
xs = self._sample_categorical(x_theta, generation_config.do_sample) |
|
|
xs_probs = x_theta.gather(-1, xs[..., None]).squeeze(dim=-1) |
|
|
output = xs.clone() |
|
|
|
|
|
|
|
|
num_noise_indices = torch.minimum( |
|
|
((1 - alpha_s) * generation_config.block_size).to(torch.int), |
|
|
(denoiser_inputs.xt == self.mask_token_id).sum() - 1, |
|
|
) |
|
|
if generation_config.confidence_based_noising: |
|
|
conf = x_theta.gather(-1, xs[..., None]).squeeze(-1) |
|
|
conf = torch.where( |
|
|
(denoiser_inputs.xt == self.mask_token_id).bool(), |
|
|
conf, |
|
|
torch.inf, |
|
|
) |
|
|
noise_indices = conf.argsort(dim=-1)[..., :num_noise_indices] |
|
|
elif generation_config.confidence_margin_based_noising: |
|
|
top2 = torch.topk(x_theta, k=2, dim=-1).values |
|
|
conf = (top2[..., 0] - top2[..., 1]).abs() |
|
|
conf = torch.where( |
|
|
(denoiser_inputs.xt == self.mask_token_id).bool(), |
|
|
conf, |
|
|
torch.inf, |
|
|
) |
|
|
noise_indices = conf.argsort(dim=-1)[..., :num_noise_indices] |
|
|
else: |
|
|
|
|
|
raise NotImplementedError |
|
|
            # NOTE: this fancy indexing assumes batch_size == 1 (enforced in `generate`).
            output[..., noise_indices] = self.mask_token_id
|
|
output = torch.where( |
|
|
xs_probs >= generation_config.confidence_threshold, xs, output |
|
|
) |
|
|
else: |
|
|
raise NotImplementedError( |
|
|
f"Sampling strategy {generation_config.sampling_strategy} not" |
|
|
" implemented." |
|
|
) |
|
|
return output, model_output_cache, cache |
|
|
|
|
|
@torch.no_grad() |
|
|
def generate( |
|
|
self, |
|
|
inputs: Optional[torch.LongTensor] = None, |
|
|
generation_config: Optional[DiffusionGenerationConfig] = None, |
|
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
|
max_length: Optional[int] = None, |
|
|
max_new_tokens: Optional[int] = None, |
|
|
batch_size: Optional[int] = None, |
|
|
device: Optional[str] = None, |
|
|
tokenizer: Optional[PreTrainedTokenizer] = None, |
|
|
disable_pbar: bool = False, |
|
|
**kwargs: Any, |
|
|
) -> torch.LongTensor: |
|
|
|
|
|
if generation_config is None: |
|
|
assert getattr(self, "generation_config", None) is not None, ( |
|
|
"Generation config must be provided if not present in the model." |
|
|
) |
|
|
generation_config = self.generation_config |
|
|
if inputs is None: |
|
|
            inputs = torch.full(
                (batch_size, 1), self.bos_token_id, dtype=torch.long, device=device
            )
|
|
if max_length is None: |
|
|
            if getattr(generation_config, "max_length", None) is not None:
|
|
max_length = generation_config.max_length |
|
|
else: |
|
|
max_length = self.max_length |
|
|
if max_new_tokens is None: |
|
|
            if getattr(generation_config, "max_new_tokens", None) is not None:
|
|
max_new_tokens = generation_config.max_new_tokens |
|
|
else: |
|
|
max_new_tokens = max_length - inputs.shape[-1] |
|
|
batch_size = batch_size if batch_size is not None else inputs.shape[0] |
|
|
assert batch_size == 1, "Batched sampling not supported yet" |
|
|
if device is None: |
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
block_size = generation_config.block_size |
|
|
max_blocks = max_new_tokens // block_size |
|
|
|
|
|
|
|
|
accumulated_samples = self._sample_prior( |
|
|
device=device, |
|
|
batch_size=batch_size, |
|
|
length=max_blocks * block_size, |
|
|
) |
|
|
accumulated_samples = torch.cat([inputs, accumulated_samples], dim=-1) |
|
|
if generation_config.use_cache and inputs.numel() > 0: |
|
|
cache = self.update_cache( |
|
|
inputs=inputs[:, : block_size * (inputs.shape[-1] // block_size)] |
|
|
if generation_config.align_inputs_to_blocks |
|
|
else inputs, |
|
|
cache={}, |
|
|
) |
|
|
else: |
|
|
cache = None |
|
|
|
|
|
if generation_config.align_inputs_to_blocks: |
|
|
inputs_offset = ( |
|
|
block_size * (inputs.shape[-1] // block_size) |
|
|
if inputs.numel() > 0 |
|
|
else 0 |
|
|
) |
|
|
else: |
|
|
inputs_offset = inputs.shape[-1] if inputs.numel() > 0 else 0 |
|
|
|
|
|
total_NFEs = 0 |
|
|
timesteps = self._sample_generation_timesteps( |
|
|
generation_config, max_length=block_size, device=device |
|
|
) |
|
|
dt = (1 - generation_config.min_t) / len(timesteps) |
|
|
block_pbar = tqdm( |
|
|
range(max_blocks), |
|
|
desc="Blocks", |
|
|
leave=True, |
|
|
disable=disable_pbar, |
|
|
) |
|
|
for block_id in block_pbar: |
|
|
block_NFEs = 0 |
|
|
xt = accumulated_samples[ |
|
|
:, |
|
|
inputs_offset + (block_id * block_size) : inputs_offset |
|
|
+ ((block_id + 1) * block_size), |
|
|
] |
|
|
if self.mask_token_id not in xt: |
|
|
continue |
|
|
step_pbar = tqdm( |
|
|
timesteps, |
|
|
desc="T", |
|
|
total=timesteps.shape[0], |
|
|
leave=False, |
|
|
disable=disable_pbar, |
|
|
) |
|
|
model_output_cache = None |
|
|
context = ( |
|
|
accumulated_samples[:, : (block_id * block_size) + inputs_offset] |
|
|
if not generation_config.use_cache |
|
|
else None |
|
|
) |
|
|
|
|
|
running_generation = accumulated_samples[ |
|
|
:, |
|
|
inputs_offset : inputs_offset + (block_id * block_size), |
|
|
] |
|
|
for t in step_pbar: |
|
|
if model_output_cache is None: |
|
|
block_NFEs += 1 |
|
|
total_NFEs += 1 |
|
|
|
|
|
alpha_t, _ = self.noise_schedule(t) |
|
|
alpha_s, _ = self.noise_schedule(t - dt) |
|
|
alpha_t = alpha_t[None, None, None] |
|
|
alpha_s = alpha_s[None, None, None] |
|
|
denoiser_inputs, cache = self._prepare_inputs_inference( |
|
|
input_ids=xt, |
|
|
context=context, |
|
|
cache=cache if generation_config.use_cache else None, |
|
|
) |
|
|
xs, model_output_cache, cache = self._generate_unconditional( |
|
|
generation_config=generation_config, |
|
|
alpha_t=alpha_t, |
|
|
alpha_s=alpha_s, |
|
|
denoiser_inputs=denoiser_inputs, |
|
|
model_output_cache=model_output_cache, |
|
|
cache=cache, |
|
|
running_generation=running_generation, |
|
|
logits_processor=logits_processor, |
|
|
tokenizer=tokenizer, |
|
|
**kwargs, |
|
|
) |
|
|
block_pbar.set_postfix( |
|
|
NFEs=total_NFEs, |
|
|
block_NFEs=block_NFEs, |
|
|
) |
|
|
|
|
|
if ( |
|
|
not torch.allclose(xs, denoiser_inputs.xt) |
|
|
or not generation_config.use_model_output_cache |
|
|
): |
|
|
model_output_cache = None |
|
|
if not generation_config.use_cache: |
|
|
xt[..., -block_size:] = xs[..., -block_size:] |
|
|
else: |
|
|
xt = xs |
|
|
if ( |
|
|
xt == self.mask_token_id |
|
|
).sum().item() == 0 and self.config.diffusion_type == "absorbing": |
|
|
break |
|
|
accumulated_samples[ |
|
|
:, |
|
|
inputs_offset + (block_id * block_size) : inputs_offset |
|
|
+ ((block_id + 1) * block_size), |
|
|
] = xt |
|
|
if tokenizer is not None: |
|
|
print(tokenizer.batch_decode(accumulated_samples)) |
|
|
if stopping_criteria is not None: |
|
|
is_done = stopping_criteria( |
|
|
input_ids=accumulated_samples[ |
|
|
:, |
|
|
inputs_offset : inputs_offset + ((block_id + 1) * block_size), |
|
|
], |
|
|
scores=None, |
|
|
) |
|
|
if torch.any(is_done): |
|
|
accumulated_samples = accumulated_samples[ |
|
|
:, |
|
|
: inputs_offset + ((block_id + 1) * block_size), |
|
|
] |
|
|
break |
|
|
if generation_config.use_cache: |
|
|
cache = self.update_cache( |
|
|
inputs=xt, |
|
|
cache=cache, |
|
|
) |
|
|
return accumulated_samples |
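

# Illustrative sketch of the absorbing-state posterior computed in
# `D3PM._compute_posterior`: for a currently-masked position, q(x_s | x_t, x)
# places mass (alpha_s - alpha_t) / (1 - alpha_t) on the model's prediction and
# mass (1 - alpha_s) / (1 - alpha_t) on staying masked. The vocabulary size,
# mask token id, schedule values, and prediction below are hypothetical.
def _example_absorbing_posterior() -> torch.Tensor:
    """Toy numeric check of the absorbing-diffusion posterior for one masked
    position (vocab size 4, hypothetical mask token id 3)."""
    mask_token_id = 3
    alpha_t, alpha_s = torch.tensor(0.3), torch.tensor(0.7)  # alpha_s > alpha_t
    # Model prediction over the vocabulary; the mask token carries no mass.
    x_theta = torch.tensor([0.5, 0.3, 0.2, 0.0])

    q_xs = x_theta * (alpha_s - alpha_t)
    q_xs[mask_token_id] = 1 - alpha_s
    q_xs = q_xs / (1 - alpha_t)

    # The posterior is a proper distribution over the vocabulary.
    assert torch.isclose(q_xs.sum(), torch.tensor(1.0))
    return q_xs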
|
|
|
|
|
|
|
|
class MDLMConfig(D3PMConfig): |
|
|
"""Configuration class for MDLM models.""" |
|
|
|
|
|
model_type = "mdlm" |
|
|
auto_map = { |
|
|
"AutoConfig": "diffusion.MDLMConfig", |
|
|
"AutoModel": "diffusion.MDLM", |
|
|
"AutoModelForMaskedLM": "diffusion.MDLM", |
|
|
} |
|
|
|
|
|
|
|
|
class MDLM(D3PM): |
|
|
"""Denoiser class for MDLM models.""" |
|
|
|
|
|
config_class = MDLMConfig |
|
|
|
|
|
def __init__(self, config: MDLMConfig, **kwargs): |
|
|
super().__init__(config, **kwargs) |
|
|
self.neg_infinity = -1e12 |
|
|
|
|
|
def _forward( |
|
|
self, |
|
|
backbone_output: torch.FloatTensor, |
|
|
denoiser_inputs: DenoiserInput, |
|
|
**kwargs, |
|
|
) -> torch.FloatTensor: |
|
|
|
|
|
        # Zero masking probabilities: the model is never allowed to predict the
        # mask token itself.
        backbone_output[..., self.mask_token_id] = self.neg_infinity
|
|
log_probs = backbone_output - torch.logsumexp( |
|
|
backbone_output, dim=-1, keepdim=True |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
        # Carry-over unmasking: positions that are already unmasked in xt keep
        # their current token (log-prob 0 for that token, -infinity elsewhere).
        xt = denoiser_inputs.xt
        unmasked_indices = xt != self.mask_token_id
        log_probs[unmasked_indices] = self.neg_infinity
        log_probs[unmasked_indices, xt[unmasked_indices]] = 0
|
|
return log_probs |
|
|
|
|
|
def _compute_loss( |
|
|
self, |
|
|
model_output: torch.FloatTensor, |
|
|
denoiser_inputs: DenoiserInput, |
|
|
**kwargs: Any, |
|
|
) -> LossAndNllOutput: |
|
|
log_p_theta = torch.gather( |
|
|
input=model_output, dim=-1, index=denoiser_inputs.x0[:, :, None] |
|
|
).squeeze(-1) |
|
|
nlls = ( |
|
|
log_p_theta |
|
|
* denoiser_inputs.alpha_t_prime |
|
|
/ (1 - denoiser_inputs.alpha_t) |
|
|
* denoiser_inputs.tokens_mask |
|
|
) |
|
|
if self.training: |
|
|
batch_nll = -(log_p_theta * denoiser_inputs.tokens_mask).sum(dim=-1) |
|
|
else: |
|
|
batch_nll = nlls.sum(dim=-1) |
|
|
count = denoiser_inputs.tokens_mask.sum(dim=-1) |
|
|
token_nll = (batch_nll / count).mean() |
|
|
return LossAndNllOutput( |
|
|
loss=token_nll, |
|
|
nlls=nlls, |
|
|
other_loss_terms={ |
|
|
"masked_tokens": (denoiser_inputs.xt == self.mask_token_id).int() |
|
|
}, |
|
|
) |
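

# Illustrative sketch of the per-token NELBO weight used in
# `MDLM._compute_loss`: under an assumed linear schedule alpha_t = 1 - t (so
# d alpha_t / dt = -1), the weight alpha_t_prime / (1 - alpha_t) reduces to
# -1 / t, i.e. the familiar 1/t-weighted masked cross-entropy. The linear
# schedule is an assumption made for this example; the actual schedule comes
# from `self.noise_schedule`.
def _example_mdlm_loss_weight(t: float = 0.25) -> torch.Tensor:
    """Toy illustration of the NELBO weight under alpha_t = 1 - t."""
    alpha_t = torch.tensor(1.0 - t)
    alpha_t_prime = torch.tensor(-1.0)  # derivative of the assumed schedule
    weight = alpha_t_prime / (1 - alpha_t)
    # log_p_theta * weight == -log_p_theta / t for this schedule.
    assert torch.isclose(weight, torch.tensor(-1.0 / t))
    return weight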
|
|
|
|
|
|
|
|
class BD3LMConfig(MDLMConfig): |
|
|
"""Configuration class for BD3LM models.""" |
|
|
|
|
|
model_type = "bd3lm" |
|
|
auto_map = { |
|
|
"AutoConfig": "diffusion.BD3LMConfig", |
|
|
"AutoModel": "diffusion.BD3LM", |
|
|
"AutoModelForMaskedLM": "diffusion.BD3LM", |
|
|
} |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
block_size: Optional[int] = None, |
|
|
eval_block_size: Optional[int] = None, |
|
|
**kwargs, |
|
|
): |
|
|
super().__init__(**kwargs) |
|
|
self.block_size = block_size |
|
|
self.eval_block_size = ( |
|
|
eval_block_size if eval_block_size is not None else block_size |
|
|
) |
|
|
|
|
|
|
|
|
class BD3LM(MDLM): |
|
|
"""Denoiser class for BD3LM models.""" |
|
|
|
|
|
config_class = BD3LMConfig |
|
|
|
|
|
def __init__(self, config: BD3LMConfig, **kwargs): |
|
|
super().__init__(config, **kwargs) |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
def _block_mask( |
|
|
b, |
|
|
h, |
|
|
q_idx, |
|
|
kv_idx, |
|
|
block_size: Optional[int] = None, |
|
|
seq_length: Optional[int] = None, |
|
|
) -> torch.Tensor: |
|
|
del b, h |
|
|
|
|
|
|
|
|
        # Flags marking whether an index falls in the noisy (xt) half of the
        # concatenated [x0; xt] sequence.
        xt_flag_q = (q_idx >= seq_length).bool()
        xt_flag_kv = (kv_idx >= seq_length).bool()
|
|
|
|
|
|
|
|
block_q = torch.where( |
|
|
xt_flag_q, (q_idx - seq_length) // block_size, q_idx // block_size |
|
|
) |
|
|
block_kv = torch.where( |
|
|
xt_flag_kv, (kv_idx - seq_length) // block_size, kv_idx // block_size |
|
|
) |
|
|
|
|
|
        # Noisy queries attend to strictly earlier clean blocks.
        offset_block_causal = (block_q > block_kv) & ~xt_flag_kv & xt_flag_q
|
|
|
|
|
|
|
|
        # Queries attend within their own block, restricted to the same half
        # (clean-to-clean or noisy-to-noisy).
        block_diagonal = (block_q == block_kv) & (xt_flag_q == xt_flag_kv)
|
|
|
|
|
|
|
|
        # Clean queries attend block-causally (including their own block) to the
        # clean half.
        block_causal = (block_q >= block_kv) & ~xt_flag_kv & ~xt_flag_q
|
|
|
|
|
|
|
|
return block_diagonal | offset_block_causal | block_causal |
|
|
|
|
|
def _create_static_mask(self) -> None: |
|
|
if self.config.attn_backend == "sdpa": |
|
|
static_mask = self._block_mask( |
|
|
b=None, |
|
|
h=None, |
|
|
q_idx=torch.arange(self.config.length * 2)[:, None], |
|
|
kv_idx=torch.arange(self.config.length * 2)[None, :], |
|
|
block_size=self.config.block_size |
|
|
if self.training |
|
|
else self.config.eval_block_size, |
|
|
seq_length=self.config.length, |
|
|
) |
|
|
self.register_buffer( |
|
|
"static_attention_mask", |
|
|
static_mask, |
|
|
) |
|
|
self.skip_params_for_push.append("static_attention_mask") |
|
|
elif self.config.attn_backend == "flex_attention": |
|
|
mask = partial( |
|
|
self._block_mask, |
|
|
block_size=self.config.block_size |
|
|
if self.training |
|
|
else self.config.eval_block_size, |
|
|
seq_length=self.config.length, |
|
|
) |
|
|
self.static_attention_mask = create_block_mask( |
|
|
mask, |
|
|
B=None, |
|
|
H=None, |
|
|
Q_LEN=self.config.length * 2, |
|
|
KV_LEN=self.config.length * 2, |
|
|
) |
|
|
|
|
|
def _ensure_no_unmasked_blocks( |
|
|
self, |
|
|
input_ids: torch.LongTensor, |
|
|
xt: torch.LongTensor, |
|
|
context_mask: Optional[torch.FloatTensor] = None, |
|
|
) -> torch.Tensor: |
|
|
n_blocks = xt.shape[1] // self.config.block_size |
|
|
|
|
|
blocks_without_masks = ((xt == self.mask_token_id) + context_mask).reshape( |
|
|
-1, n_blocks, self.config.block_size |
|
|
).sum(dim=-1) == 0 |
|
|
if blocks_without_masks.sum() > 0: |
|
|
num_remasks_per_block = torch.randint( |
|
|
0, |
|
|
self.config.block_size, |
|
|
blocks_without_masks.shape, |
|
|
device=xt.device, |
|
|
) |
|
|
rand = torch.rand(xt.shape[0], xt.shape[1], device=xt.device) |
|
|
perm_indices = torch.argsort( |
|
|
rand.view(xt.shape[0], n_blocks, self.config.block_size), |
|
|
stable=True, |
|
|
dim=-1, |
|
|
) |
|
|
remask_indices = perm_indices <= num_remasks_per_block[..., None] |
|
|
xt = torch.where( |
|
|
remask_indices.view(xt.shape[0], xt.shape[1]) |
|
|
* blocks_without_masks.repeat_interleave(self.config.block_size, dim=1), |
|
|
self.mask_token_id, |
|
|
xt, |
|
|
) |
|
|
if self.config.keep_clean_bos: |
|
|
xt[..., 0] = input_ids[..., 0] |
|
|
return xt |
|
|
|
|
|
def _prepare_inputs( |
|
|
self, |
|
|
input_ids: torch.LongTensor, |
|
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
|
context_mask: Optional[torch.FloatTensor] = None, |
|
|
t: Optional[torch.FloatTensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
): |
|
|
if attention_mask is None: |
|
|
attention_mask = torch.ones_like(input_ids) |
|
|
if context_mask is None: |
|
|
context_mask = torch.zeros_like(attention_mask) |
|
|
|
|
|
if torch.is_floating_point(attention_mask): |
|
|
attention_mask = attention_mask.to(torch.int) |
|
|
context_mask = context_mask.to(torch.int) |
|
|
|
|
|
        if t is None:
            # Sample one diffusion time per block and broadcast it across the
            # tokens of that block.
            block_size = (
                self.config.block_size
                if self.training
                else self.config.eval_block_size
            )
            t = torch.rand(
                input_ids.shape[0],
                input_ids.shape[1] // block_size,
                device=input_ids.device,
            ).repeat_interleave(block_size, dim=-1)
|
|
alpha_t, alpha_t_prime = self.noise_schedule(t) |
|
|
while alpha_t.ndim < 2: |
|
|
alpha_t = alpha_t[..., None] |
|
|
alpha_t_prime = alpha_t_prime[..., None] |
|
|
xt = self._sample_q_xt(x0=input_ids, alpha_t=alpha_t, context_mask=context_mask) |
|
|
|
|
|
if self.training: |
|
|
xt = self._ensure_no_unmasked_blocks( |
|
|
input_ids, |
|
|
xt, |
|
|
context_mask, |
|
|
) |
|
|
if self.config.attn_backend == "sdpa": |
|
|
decoder_attention_mask = ( |
|
|
self.static_attention_mask[None, ...] |
|
|
& attention_mask.repeat(1, 2)[:, None, :] |
|
|
& attention_mask.repeat(1, 2)[..., None] |
|
|
)[:, None, ...] |
|
|
decoder_attention_mask = self._preprocess_attention_mask( |
|
|
decoder_attention_mask, dtype=torch.float |
|
|
) |
|
|
elif self.config.attn_backend == "flex_attention": |
|
|
if context_mask.any(): |
|
|
raise NotImplementedError( |
|
|
"flex_attention with context_mask not implemented yet." |
|
|
) |
|
|
elif attention_mask is not None and (attention_mask != 1).any(): |
|
|
                padding_mask = create_attn_mask(
                    attention_mask.repeat(1, 2).bool()
                )
|
|
dec_masks = [ |
|
|
partial( |
|
|
self._block_mask, |
|
|
block_size=self.config.block_size |
|
|
if self.training |
|
|
else self.config.eval_block_size, |
|
|
seq_length=self.config.length, |
|
|
), |
|
|
padding_mask, |
|
|
] |
|
|
decoder_attention_mask = create_block_mask( |
|
|
and_masks(*dec_masks), |
|
|
B=input_ids.shape[0], |
|
|
H=None, |
|
|
Q_LEN=input_ids.shape[1] * 2, |
|
|
KV_LEN=input_ids.shape[1] * 2, |
|
|
) |
|
|
else: |
|
|
decoder_attention_mask = self.static_attention_mask |
|
|
else: |
|
|
raise ValueError("Unknown backbone backend") |
|
|
backbone_input_ids = torch.cat((input_ids, xt), dim=-1) |
|
|
position_ids = ( |
|
|
torch.arange(input_ids.shape[1]).repeat(2).to(input_ids.device)[None, :] |
|
|
) |
|
|
if self.training and self.config.train_on_context: |
|
|
tokens_mask = attention_mask |
|
|
else: |
|
|
tokens_mask = attention_mask * (1 - context_mask) |
|
|
return DenoiserInput( |
|
|
xt=backbone_input_ids, |
|
|
x0=input_ids, |
|
|
attention_mask=decoder_attention_mask, |
|
|
tokens_mask=tokens_mask, |
|
|
t=t, |
|
|
alpha_t=alpha_t, |
|
|
alpha_t_prime=alpha_t_prime, |
|
|
backbone_kwargs={ |
|
|
"cache_position": position_ids[0], |
|
|
"position_ids": position_ids, |
|
|
}, |
|
|
) |
|
|
|
|
|
def _prepare_inputs_inference( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
|
context: Optional[torch.LongTensor] = None, |
|
|
context_mask: Optional[torch.FloatTensor] = None, |
|
|
cache: Optional[Dict[str, Any]] = None, |
|
|
return_updated_cache: bool = False, |
|
|
**backbone_kwargs: Dict[str, Any], |
|
|
) -> Tuple[DenoiserInput, Union[Dict[str, Any], None]]: |
|
|
device = input_ids.device if input_ids is not None else context.device |
|
|
assert input_ids is not None or context is not None, ( |
|
|
"Must provide either input_ids or context." |
|
|
) |
|
|
cache = cache if cache is not None else {} |
|
|
past_key_values = cache.pop("past_key_values", DynamicCache()) |
|
|
if context is not None: |
|
|
if input_ids is not None: |
|
|
input_ids = torch.cat([context, input_ids], dim=-1) |
|
|
else: |
|
|
input_ids = context |
|
|
cache_length = self._get_past_key_values_seq_length(past_key_values) |
|
|
full_seq_length = cache_length + input_ids.shape[-1] |
|
|
decoder_attention_mask = self.static_attention_mask[ |
|
|
None, |
|
|
None, |
|
|
cache_length:full_seq_length, |
|
|
:full_seq_length, |
|
|
] |
|
|
decoder_attention_mask = self._preprocess_attention_mask( |
|
|
decoder_attention_mask, dtype=torch.float |
|
|
) |
|
|
position_ids = torch.arange(cache_length, full_seq_length).to(device)[None, :] |
|
|
return DenoiserInput( |
|
|
xt=input_ids, |
|
|
attention_mask=decoder_attention_mask, |
|
|
context_mask=context_mask, |
|
|
past_key_values=past_key_values, |
|
|
backbone_kwargs={ |
|
|
"position_ids": position_ids, |
|
|
} |
|
|
| backbone_kwargs, |
|
|
), cache |
|
|
|
|
|
def _compute_loss( |
|
|
self, |
|
|
model_output: torch.FloatTensor, |
|
|
denoiser_inputs: DenoiserInput, |
|
|
**kwargs: Any, |
|
|
) -> LossAndNllOutput: |
|
|
input_length = denoiser_inputs.xt.shape[1] // 2 |
|
|
model_output = model_output[:, input_length:, ...] |
|
|
return super()._compute_loss( |
|
|
model_output=model_output, |
|
|
denoiser_inputs=denoiser_inputs, |
|
|
**kwargs, |
|
|
) |
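

# Illustrative sketch of the attention pattern produced by `BD3LM._block_mask`
# over the concatenated [x0; xt] sequence: noisy queries attend block-diagonally
# within the noisy half and block-causally (strictly earlier blocks) to the
# clean half, while clean queries attend block-causally within the clean half.
# The tiny sizes and this helper's name are arbitrary.
def _example_bd3lm_block_mask(
    seq_length: int = 4, block_size: int = 2
) -> torch.Tensor:
    """Materializes the BD3LM mask for a tiny example so its structure can be
    inspected, mirroring what `BD3LM._create_static_mask` registers for the
    "sdpa" backend. Returns a (2 * seq_length, 2 * seq_length) boolean tensor.
    """
    return BD3LM._block_mask(
        b=None,
        h=None,
        q_idx=torch.arange(seq_length * 2)[:, None],
        kv_idx=torch.arange(seq_length * 2)[None, :],
        block_size=block_size,
        seq_length=seq_length,
    )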
|
|
|
|
|
|
|
|
class E2D2Config(BD3LMConfig): |
|
|
"""Configuration class for E2D2 models.""" |
|
|
|
|
|
model_type = "e2d2" |
|
|
auto_map = { |
|
|
"AutoConfig": "diffusion.E2D2Config", |
|
|
"AutoModel": "diffusion.E2D2", |
|
|
"AutoModelForMaskedLM": "diffusion.E2D2", |
|
|
} |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
**kwargs, |
|
|
): |
|
|
super().__init__(**kwargs) |
|
|
|
|
|
|
|
|
class E2D2(BD3LM): |
|
|
"""Denoiser class for E2D2 models.""" |
|
|
|
|
|
config_class = E2D2Config |
|
|
|
|
|
def __init__(self, config: E2D2Config, **kwargs): |
|
|
super().__init__(config, **kwargs) |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
def _encoder_block_mask( |
|
|
b, |
|
|
h, |
|
|
q_idx, |
|
|
kv_idx, |
|
|
block_size: Optional[int] = None, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
q_idx (Tensor): Query indices. |
|
|
kv_idx (Tensor): Key indices |
|
|
b (Optional: int): batch size |
|
|
h (Optional: int): number of heads |
|
|
block_size (Optional: int): Defines the block structure. |
|
|
|
|
|
Returns: |
|
|
Encoder block-causal attention mask. |
|
|
""" |
|
|
|
|
|
|
|
|
block_q = q_idx // block_size |
|
|
block_kv = kv_idx // block_size |
|
|
|
|
|
|
|
|
return block_q >= block_kv |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
def _decoder_block_mask( |
|
|
b, |
|
|
h, |
|
|
q_idx, |
|
|
kv_idx, |
|
|
block_size: Optional[int] = None, |
|
|
seq_length: Optional[int] = None, |
|
|
) -> torch.Tensor: |
|
|
|
|
|
        # Flags marking keys that fall in the noisy (xt) half of the
        # concatenated [clean; noisy] key/value sequence.
        xt_flag_kv = (kv_idx >= seq_length).bool()
|
|
|
|
|
|
|
|
block_q = q_idx // block_size |
|
|
block_kv = torch.where( |
|
|
xt_flag_kv, (kv_idx - seq_length) // block_size, kv_idx // block_size |
|
|
) |
|
|
|
|
|
        # Noisy queries attend to strictly earlier clean blocks.
        offset_block_causal = (block_q > block_kv) & ~xt_flag_kv
|
|
|
|
|
|
|
|
        # Noisy queries attend to the noisy keys of their own block.
        block_diagonal = (block_q == block_kv) & xt_flag_kv
|
|
|
|
|
|
|
|
return block_diagonal | offset_block_causal |
|
|
|
|
|
def _create_static_mask(self) -> None: |
|
|
if self.config.attn_backend == "flex_attention": |
|
|
enc_mask = partial( |
|
|
self._encoder_block_mask, |
|
|
block_size=self.config.block_size |
|
|
if self.training |
|
|
else self.config.eval_block_size, |
|
|
) |
|
|
encoder_attention_mask = create_block_mask( |
|
|
enc_mask, |
|
|
B=None, |
|
|
H=None, |
|
|
Q_LEN=self.config.length, |
|
|
KV_LEN=self.config.length, |
|
|
) |
|
|
dec_mask = partial( |
|
|
self._decoder_block_mask, |
|
|
block_size=self.config.block_size |
|
|
if self.training |
|
|
else self.config.eval_block_size, |
|
|
seq_length=self.config.length, |
|
|
) |
|
|
decoder_attention_mask = create_block_mask( |
|
|
dec_mask, |
|
|
B=None, |
|
|
H=None, |
|
|
Q_LEN=self.config.length, |
|
|
KV_LEN=self.config.length * 2, |
|
|
) |
|
|
self.encoder_static_attention_mask = encoder_attention_mask |
|
|
self.static_attention_mask = decoder_attention_mask |
|
|
else: |
|
|
encoder_static_mask = self._encoder_block_mask( |
|
|
b=None, |
|
|
h=None, |
|
|
q_idx=torch.arange(self.config.length)[:, None], |
|
|
kv_idx=torch.arange(self.config.length)[None, :], |
|
|
block_size=self.config.block_size |
|
|
if self.training |
|
|
else self.config.eval_block_size, |
|
|
) |
|
|
decoder_static_mask = self._decoder_block_mask( |
|
|
b=None, |
|
|
h=None, |
|
|
q_idx=torch.arange(self.config.length)[:, None], |
|
|
kv_idx=torch.arange(self.config.length * 2)[None, :], |
|
|
block_size=self.config.block_size |
|
|
if self.training |
|
|
else self.config.eval_block_size, |
|
|
seq_length=self.config.length, |
|
|
) |
|
|
self.register_buffer( |
|
|
"encoder_static_attention_mask", |
|
|
encoder_static_mask, |
|
|
) |
|
|
self.register_buffer( |
|
|
"static_attention_mask", |
|
|
decoder_static_mask, |
|
|
) |
|
|
self.skip_params_for_push.append("encoder_static_attention_mask") |
|
|
self.skip_params_for_push.append("static_attention_mask") |
|
|
|
|
|
def _prepare_inputs( |
|
|
self, |
|
|
input_ids: torch.LongTensor, |
|
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
|
context_mask: Optional[torch.FloatTensor] = None, |
|
|
t: Optional[torch.FloatTensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
): |
|
|
if attention_mask is None: |
|
|
attention_mask = torch.ones_like(input_ids) |
|
|
if context_mask is None: |
|
|
context_mask = torch.zeros_like(attention_mask) |
|
|
|
|
|
if torch.is_floating_point(attention_mask): |
|
|
attention_mask = attention_mask.to(torch.int) |
|
|
context_mask = context_mask.to(torch.int) |
|
|
|
|
|
        if t is None:
            # Sample one diffusion time per block and broadcast it across the
            # tokens of that block.
            block_size = (
                self.config.block_size
                if self.training
                else self.config.eval_block_size
            )
            t = torch.rand(
                input_ids.shape[0],
                input_ids.shape[1] // block_size,
                device=input_ids.device,
            ).repeat_interleave(block_size, dim=-1)
|
|
alpha_t, alpha_t_prime = self.noise_schedule(t) |
|
|
while alpha_t.ndim < 2: |
|
|
alpha_t = alpha_t[..., None] |
|
|
alpha_t_prime = alpha_t_prime[..., None] |
|
|
xt = self._sample_q_xt(x0=input_ids, alpha_t=alpha_t, context_mask=context_mask) |
|
|
|
|
|
if self.training: |
|
|
xt = self._ensure_no_unmasked_blocks( |
|
|
input_ids, |
|
|
xt, |
|
|
context_mask, |
|
|
) |
|
|
if self.config.attn_backend == "sdpa": |
|
|
decoder_attention_mask = ( |
|
|
self.static_attention_mask[None, ...] |
|
|
& attention_mask.repeat(1, 2)[:, None, :] |
|
|
& attention_mask[..., None] |
|
|
)[:, None, ...] |
|
|
encoder_attention_mask = ( |
|
|
( |
|
|
self.encoder_static_attention_mask[None, ...] |
|
|
| context_mask[:, None, :] |
|
|
) |
|
|
& attention_mask[:, None, :] |
|
|
& attention_mask[..., None] |
|
|
)[:, None, ...] |
|
|
encoder_attention_mask = self._preprocess_attention_mask( |
|
|
encoder_attention_mask, dtype=torch.float |
|
|
) |
|
|
decoder_attention_mask = self._preprocess_attention_mask( |
|
|
decoder_attention_mask, dtype=torch.float |
|
|
) |
|
|
elif self.config.attn_backend == "flex_attention": |
|
|
|
|
|
if context_mask.any(): |
|
|
raise NotImplementedError( |
|
|
"flex_attention with context_mask not implemented yet." |
|
|
) |
|
|
elif attention_mask is not None and (attention_mask != 1).any(): |
|
|
padding_mask = create_attn_mask(attention_mask.bool()) |
|
|
dec_padding_mask = create_attn_mask(attention_mask.repeat(1, 2).bool()) |
|
|
enc_masks = [ |
|
|
partial( |
|
|
self._encoder_block_mask, |
|
|
block_size=self.config.block_size |
|
|
if self.training |
|
|
else self.config.eval_block_size, |
|
|
), |
|
|
padding_mask, |
|
|
] |
|
|
encoder_attention_mask = create_block_mask( |
|
|
and_masks(*enc_masks), |
|
|
B=input_ids.shape[0], |
|
|
H=None, |
|
|
Q_LEN=input_ids.shape[1], |
|
|
KV_LEN=input_ids.shape[1], |
|
|
) |
|
|
dec_masks = [ |
|
|
partial( |
|
|
self._decoder_block_mask, |
|
|
block_size=self.config.block_size |
|
|
if self.training |
|
|
else self.config.eval_block_size, |
|
|
seq_length=input_ids.shape[1], |
|
|
), |
|
|
dec_padding_mask, |
|
|
] |
|
|
decoder_attention_mask = create_block_mask( |
|
|
and_masks(*dec_masks), |
|
|
B=input_ids.shape[0], |
|
|
H=None, |
|
|
Q_LEN=input_ids.shape[1], |
|
|
KV_LEN=input_ids.shape[1] * 2, |
|
|
) |
|
|
else: |
|
|
encoder_attention_mask = self.encoder_static_attention_mask |
|
|
decoder_attention_mask = self.static_attention_mask |
|
|
else: |
|
|
            raise ValueError(f"Unknown attention backend: {self.config.attn_backend}")
|
|
position_ids = torch.arange(input_ids.shape[1]).to(input_ids.device)[None, :] |
|
|
if self.training and self.config.train_on_context: |
|
|
tokens_mask = attention_mask |
|
|
else: |
|
|
tokens_mask = attention_mask * (1 - context_mask) |
|
|
return DenoiserInput( |
|
|
xt=xt, |
|
|
x0=input_ids, |
|
|
attention_mask=decoder_attention_mask, |
|
|
tokens_mask=tokens_mask, |
|
|
t=t, |
|
|
alpha_t=alpha_t, |
|
|
alpha_t_prime=alpha_t_prime, |
|
|
backbone_kwargs={ |
|
|
"encoder_input_ids": input_ids, |
|
|
"encoder_attention_mask": encoder_attention_mask, |
|
|
"encoder_position_ids": position_ids, |
|
|
"encoder_cache_position": position_ids[0], |
|
|
}, |
|
|
) |
|
|
|
|
|
def _prepare_inputs_inference( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
|
context: Optional[torch.LongTensor] = None, |
|
|
context_mask: Optional[torch.FloatTensor] = None, |
|
|
cache: Optional[Dict[str, Any]] = None, |
|
|
return_updated_cache: bool = False, |
|
|
**backbone_kwargs: Dict[str, Any], |
|
|
) -> Tuple[DenoiserInput, Union[Dict[str, Any], None]]: |
|
|
device = input_ids.device if input_ids is not None else context.device |
|
|
batch_size = input_ids.shape[0] if input_ids is not None else context.shape[0] |
|
|
assert input_ids is not None or context is not None, ( |
|
|
"Must provide either input_ids or context." |
|
|
) |
|
|
if return_updated_cache: |
|
|
context = input_ids |
|
|
input_ids = None |
|
|
position_ids, encoder_position_ids = None, None |
|
|
if cache is not None: |
|
|
past_key_values = cache.pop("past_key_values", DynamicCache()) |
|
|
encoder_past_key_values = cache.pop( |
|
|
"encoder_past_key_values", DynamicCache() |
|
|
) |
|
|
encoder_last_hidden_state = cache.pop("encoder_last_hidden_state", None) |
|
|
if input_ids is not None: |
|
|
cache_length = self._get_past_key_values_seq_length(past_key_values) |
|
|
if encoder_last_hidden_state is not None: |
|
|
full_seq_length = ( |
|
|
cache_length |
|
|
+ encoder_last_hidden_state.shape[1] |
|
|
+ input_ids.shape[-1] |
|
|
) |
|
|
else: |
|
|
full_seq_length = cache_length + input_ids.shape[-1] |
|
|
encoder_attention_mask = None |
|
|
position_ids = torch.arange( |
|
|
cache_length, full_seq_length, device=device |
|
|
)[None, :] |
|
|
else: |
|
|
encoder_cache_length = self._get_past_key_values_seq_length( |
|
|
encoder_past_key_values |
|
|
if len(encoder_past_key_values) > 0 |
|
|
else past_key_values |
|
|
) |
|
|
encoder_full_seq_length = encoder_cache_length + context.shape[-1] |
|
|
encoder_attention_mask = torch.ones( |
|
|
( |
|
|
1, |
|
|
1, |
|
|
encoder_full_seq_length - encoder_cache_length, |
|
|
encoder_full_seq_length, |
|
|
), |
|
|
device=context.device, |
|
|
) |
|
|
encoder_position_ids = torch.arange( |
|
|
encoder_cache_length, encoder_full_seq_length |
|
|
).to(device)[None, :] |
|
|
encoder_attention_mask = self._preprocess_attention_mask( |
|
|
encoder_attention_mask, dtype=torch.float |
|
|
) |
|
|
full_seq_length = -1 |
|
|
else: |
|
|
past_key_values = None |
|
|
encoder_past_key_values, encoder_last_hidden_state = None, None |
|
|
if context is not None: |
|
|
context_len = context.shape[1] |
|
|
encoder_attention_mask = torch.ones( |
|
|
(1, 1, context_len, context_len), device=context.device |
|
|
) |
|
|
encoder_attention_mask = self._preprocess_attention_mask( |
|
|
encoder_attention_mask, dtype=torch.float |
|
|
) |
|
|
encoder_position_ids = torch.arange(context_len).to(device)[None, :] |
|
|
else: |
|
|
context_len = 0 |
|
|
encoder_attention_mask = None |
|
|
if input_ids is not None: |
|
|
full_seq_length = context_len + input_ids.shape[1] |
|
|
else: |
|
|
full_seq_length = context_len |
|
|
position_ids = torch.arange(context_len, full_seq_length).to(device)[ |
|
|
None, : |
|
|
] |
|
|
if input_ids is not None: |
|
|
decoder_attention_mask = torch.ones( |
|
|
(batch_size, 1, input_ids.shape[1], full_seq_length), |
|
|
device=device, |
|
|
) |
|
|
decoder_attention_mask = self._preprocess_attention_mask( |
|
|
decoder_attention_mask, dtype=torch.float |
|
|
) |
|
|
else: |
|
|
decoder_attention_mask = None |
|
|
return DenoiserInput( |
|
|
xt=input_ids, |
|
|
attention_mask=decoder_attention_mask, |
|
|
context_mask=context_mask, |
|
|
past_key_values=past_key_values, |
|
|
backbone_kwargs={ |
|
|
"position_ids": position_ids, |
|
|
"encoder_input_ids": context, |
|
|
"encoder_position_ids": encoder_position_ids, |
|
|
"encoder_attention_mask": encoder_attention_mask, |
|
|
"encoder_past_key_values": encoder_past_key_values, |
|
|
"encoder_last_hidden_state": encoder_last_hidden_state, |
|
|
} |
|
|
| backbone_kwargs, |
|
|
), cache |
|
|
|
|
|
def _compute_loss( |
|
|
self, |
|
|
model_output: torch.FloatTensor, |
|
|
denoiser_inputs: DenoiserInput, |
|
|
**kwargs: Any, |
|
|
) -> LossAndNllOutput: |
|
|
|
|
|
return super(BD3LM, self)._compute_loss( |
|
|
model_output=model_output, |
|
|
denoiser_inputs=denoiser_inputs, |
|
|
**kwargs, |
|
|
) |
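

# Illustrative sketch of the two masks built in `E2D2._create_static_mask`: the
# encoder sees the clean sequence block-causally (including its own block),
# while decoder queries over the noisy sequence attend to strictly earlier
# clean blocks plus the block-diagonal of the noisy half. The toy sizes and
# this helper's name are arbitrary.
def _example_e2d2_masks(seq_length: int = 4, block_size: int = 2):
    """Materializes the E2D2 encoder and decoder masks for a tiny example,
    mirroring the "sdpa" branch of `E2D2._create_static_mask`.

    Returns a tuple (encoder_mask, decoder_mask) of boolean tensors with shapes
    (seq_length, seq_length) and (seq_length, 2 * seq_length), respectively.
    """
    enc = E2D2._encoder_block_mask(
        b=None,
        h=None,
        q_idx=torch.arange(seq_length)[:, None],
        kv_idx=torch.arange(seq_length)[None, :],
        block_size=block_size,
    )
    dec = E2D2._decoder_block_mask(
        b=None,
        h=None,
        q_idx=torch.arange(seq_length)[:, None],
        kv_idx=torch.arange(seq_length * 2)[None, :],
        block_size=block_size,
        seq_length=seq_length,
    )
    return enc, dec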
|
|
|