Feature Extraction
Transformers
PyTorch
e2d2
custom_code
yairschiff committed (verified)
Commit: b978d15 · 1 parent: d0425e9

Update pytorch.bin; Add model and code

Files changed (3)
  1. config.json +59 -0
  2. diffusion.py +1462 -0
  3. pytorch_model.bin +3 -0
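The three files added by this commit can be fetched individually from the Hub, pinned to this revision. A minimal sketch (the repository id below is a placeholder assumption; `hf_hub_download` is from the `huggingface_hub` library):

# Sketch: download the files added in this commit, pinned to revision b978d15.
# The repo_id is a placeholder assumption; substitute the actual repository.
from huggingface_hub import hf_hub_download

repo_id = "yairschiff/e2d2"  # hypothetical repository id
for filename in ["config.json", "diffusion.py", "pytorch_model.bin"]:
    local_path = hf_hub_download(repo_id=repo_id, filename=filename, revision="b978d15")
    print(local_path)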
config.json ADDED
@@ -0,0 +1,59 @@
{
  "T": 0,
  "architectures": [
    "E2D2"
  ],
  "attn_backend": "sdpa",
  "auto_map": {
    "AutoConfig": "diffusion.E2D2Config",
    "AutoModel": "diffusion.E2D2",
    "AutoModelForMaskedLM": "diffusion.E2D2"
  },
  "backbone_config": {
    "_target_": "backbone_encoder_decoder.LLMasEncoderDecoder",
    "attn_backend": "sdpa",
    "freeze_encoder": false,
    "hidden_size": 512,
    "intermediate_size": 1536,
    "keep_top_decoder_layers": false,
    "keep_top_encoder_layers": false,
    "max_length": 256,
    "num_decoder_layers": 4,
    "num_encoder_layers": 28,
    "pretrained_model_name_or_path": "Qwen/Qwen3-0.6B-Base",
    "reinit_decoder": true,
    "reinit_encoder": true,
    "tie_encoder_decoder_weights": false,
    "use_encoder_causal_mask": false,
    "use_gradient_checkpointing": false
  },
  "block_size": 4,
  "bos_token_id": 151643,
  "diffusion_type": "absorbing",
  "eos_token_id": 151643,
  "eval_block_size": 4,
  "keep_clean_bos": true,
  "length": 256,
  "mask_token_id": 151660,
  "model_type": "e2d2",
  "noise_config": {
    "_target_": "noise_schedule_noise_schedules.LinearNoise"
  },
  "pad_token_id": 151643,
  "pad_vocab_size_multiple": 1,
  "shift_logits": false,
  "time_conditioned_backbone": false,
  "tokenization_config": {
    "bos_token_id": 151643,
    "eos_token_id": 151643,
    "mask_token_id": 151660,
    "pad_token_id": 151643,
    "pad_vocab_size_multiple": 1,
    "vocab_size": 151669
  },
  "tokenizer_name": "Qwen/Qwen3-0.6B-Base",
  "torch_dtype": "float32",
  "train_on_context": false,
  "transformers_version": "4.52.4",
  "vocab_size": 151669
}
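The `auto_map` entries route the `transformers` Auto classes to the custom classes defined in `diffusion.py`, so loading requires `trust_remote_code=True`. A minimal loading sketch (the repository id is a placeholder assumption, and the custom code additionally expects its `src.denoiser.base` dependency to be importable):

from transformers import AutoConfig, AutoModel, AutoTokenizer

repo_id = "yairschiff/e2d2"  # hypothetical repository id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)  # -> diffusion.E2D2Config
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)    # -> diffusion.E2D2
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)      # Qwen/Qwen3-0.6B-Base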
diffusion.py ADDED
@@ -0,0 +1,1462 @@
from functools import partial
from typing import Any, Dict, Literal, Optional, Tuple, Union

import torch
from tqdm.auto import tqdm
from transformers import (
    GenerationConfig,
    LogitsProcessorList,
    PreTrainedTokenizer,
    StoppingCriteriaList,
)
from transformers.cache_utils import Cache, DynamicCache

try:
    from torch.nn.attention.flex_attention import (
        BlockMask,
        and_masks,
        create_block_mask,
    )
except ImportError:
    BlockMask, and_masks, create_block_mask = None, None, None


from src.denoiser.base import (
    Denoiser,
    DenoiserConfig,
    DenoiserInput,
    LossAndNllOutput,
)


def create_attn_mask(attn_mask):
    # noinspection PyUnusedLocal
    def padding(b, h, q_idx, kv_idx):
        return attn_mask[b, q_idx] & attn_mask[b, kv_idx]

    return padding


class DiffusionGenerationConfig(GenerationConfig):
    def __init__(
        self,
        num_steps: int = 1000,
        min_t: float = 1e-5,
        block_size: Optional[int] = None,
        first_hitting: bool = False,
        sampling_strategy: Literal["posterior", "predict_then_noise"] = "posterior",
        confidence_based_noising: bool = False,
        confidence_margin_based_noising: bool = False,
        confidence_threshold: float = 1e6,
        use_model_output_cache: bool = True,
        align_inputs_to_blocks: bool = True,
        **kwargs,
    ):
        """Generation config with additional parameters relevant for diffusion model
        sampling.

        Args:
            num_steps (int): Number of diffusion / iterative refinement steps.
                Defaults to 1000.
            min_t (float): Minimum time to use.
                Diffusion models use t=1 for noise and t=0 for signal.
                Setting t=0 exactly can lead to certain numerical instabilities.
                Defaults to 1e-5.
            block_size (int): Block size to use for semi-autoregressive decoding.
                Defaults to None (in which case block_size is set to max_new_tokens).
            first_hitting (bool): Whether to use the first-hitting sampler.
                When set to True, rather than following the diffusion time and sampling
                from the posterior, which can result in no tokens changing between
                steps, e.g., for masked diffusion, we explicitly determine the next
                time step at which a token will be decoded / generated.
                Note: this overrides the `num_steps` parameter, as we decode one
                token at a time; hence, when True, num_steps = seq_length
                (or block_size, for semi-autoregressive).
                See https://arxiv.org/abs/2409.02908 for details.
                Defaults to False.
            sampling_strategy (str): Method for transitioning between latents.
                Options:
                    - "posterior" - Compute and sample from the posterior
                      q(x_s | x_t, x_theta).
                    - "predict_then_noise" - Sample from the denoising model x_theta,
                      then add back noise to produce x_s.
                      Only implemented for absorbing diffusion.
                Defaults to "posterior".
            confidence_based_noising (bool): When using the "predict_then_noise"
                strategy, whether to add noise to random positions or to those that
                have the lowest probability under x_theta.
                Cannot be used in conjunction with confidence_margin_based_noising.
                Defaults to False.
            confidence_margin_based_noising (bool): When using the "predict_then_noise"
                strategy, whether to add noise to random positions or to those that
                have the lowest probability margins under x_theta, where margin is
                defined as the absolute difference between the top two probabilities
                at a given position.
                See https://arxiv.org/abs/2502.06768 for details.
                Cannot be used in conjunction with confidence_based_noising.
                Defaults to False.
            confidence_threshold (float): Confidence threshold to use for sampling.
                Any tokens that exceed the threshold are decoded.
                See https://arxiv.org/abs/2505.22618 for details.
                Defaults to 1e6.
            use_model_output_cache (bool): Whether to re-use the model's output when
                the sequence is unchanged: if xt == xs, we can simply re-use the
                denoising model's outputs and save a function evaluation.
                Relevant if model.backbone is not time/noise-conditioned.
                Defaults to True.
            align_inputs_to_blocks (bool): Whether to align input tokens to the block
                size, e.g., for an input of length C and block size S, the context
                will consist of the first (C // S) * S tokens, and generation will
                begin with a block whose first C % S tokens come from the input.
                Defaults to True.
            kwargs: Keyword arguments passed to `GenerationConfig`.
        """
        super().__init__(**kwargs)
        self.num_steps = num_steps
        self.min_t = min_t
        # TODO: assumes we are setting max_new_tokens, which may not be the case!
        self.block_size = block_size if block_size is not None else self.max_new_tokens
        self.first_hitting = first_hitting
        if self.first_hitting:
            # TODO: log.warn that this is being overridden
            self.num_steps = min(num_steps, self.block_size)
        self.sampling_strategy = sampling_strategy
        assert not confidence_based_noising or not confidence_margin_based_noising, (
            "Cannot use both `confidence_based_noising` and"
            " `confidence_margin_based_noising`."
        )
        self.confidence_based_noising = confidence_based_noising
        self.confidence_margin_based_noising = confidence_margin_based_noising
        self.confidence_threshold = confidence_threshold
        self.use_model_output_cache = use_model_output_cache
        self.align_inputs_to_blocks = align_inputs_to_blocks

class D3PMConfig(DenoiserConfig):
    """Configuration class for D3PM models."""

    model_type = "d3pm"
    auto_map = {
        "AutoConfig": "diffusion.D3PMConfig",
        "AutoModel": "diffusion.D3PM",
        "AutoModelForMaskedLM": "diffusion.D3PM",
    }

    def __init__(
        self,
        keep_clean_bos: Optional[bool] = None,  # Whether to enforce un-noised BOS token
        T: int = 1000,
        diffusion_type: Literal["absorbing", "uniform"] = "absorbing",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.keep_clean_bos = keep_clean_bos
        self.diffusion_type = diffusion_type
        self.T = T


class D3PM(Denoiser):
    """Denoiser class for D3PM models.

    This class implements the Denoiser interface for D3PM models.
    """

    config_class = D3PMConfig

    def __init__(self, config: D3PMConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.T = config.T
        self.diffusion_type = config.diffusion_type
        self._create_static_mask()

    def _create_static_mask(self) -> None:
        static_mask = torch.ones(
            self.config.length, self.config.length, dtype=torch.bool
        )
        self.register_buffer(
            "static_attention_mask",
            static_mask,
        )
        self.skip_params_for_push.append("static_attention_mask")

    def _sample_q_xt(
        self,
        x0: torch.LongTensor,
        alpha_t: torch.FloatTensor,
        context_mask: torch.FloatTensor,
    ) -> torch.LongTensor:
        """Sample from the pre-defined forward / noising process.

        Parameters:
            x0 (Tensor): Signal / data sample;
                can potentially include context tokens.
            alpha_t (Tensor): Amount of signal to retain.
            context_mask (Tensor): Indicator of context tokens (to remain
                unchanged).
        """
        move_indices = torch.rand(*x0.shape, device=x0.device) < (1.0 - alpha_t)
        if self.diffusion_type == "absorbing":
            xt = torch.where(
                (move_indices * (1 - context_mask)).bool(), self.mask_token_id, x0
            )
            if self.config.keep_clean_bos:
                xt[..., 0] = x0[..., 0]
            return xt  # type: ignore
        if self.diffusion_type == "uniform":
            xt = torch.randint(0, self.vocab_size, x0.shape, device=x0.device)
            xt = torch.where(context_mask.bool(), x0, xt)
            if self.config.keep_clean_bos:
                xt[..., 0] = x0[..., 0]
            return xt  # type: ignore
        raise NotImplementedError(
            f"Diffusion type '{self.diffusion_type}' not implemented."
        )

    def _prepare_inputs(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        t: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
    ):
        # Prepare inputs for D3PM model
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if context_mask is None:
            context_mask = torch.zeros_like(attention_mask)

        if torch.is_floating_point(attention_mask):
            attention_mask = attention_mask.to(torch.int)
            context_mask = context_mask.to(torch.int)

        if t is None:
            t = torch.rand(input_ids.shape[0], device=input_ids.device)
        alpha_t, alpha_t_prime = self.noise_schedule(t)
        while alpha_t.ndim < 2:
            alpha_t = alpha_t[..., None]
            alpha_t_prime = alpha_t_prime[..., None]
        xt = self._sample_q_xt(
            x0=input_ids,
            alpha_t=alpha_t,
            context_mask=context_mask,
        )
        if (
            context_mask is not None
            and context_mask.sum() == 0
            and (attention_mask == 1).all()
        ):
            processed_attention_mask = None
        else:
            processed_attention_mask = (
                self.static_attention_mask[None, ...]
                & attention_mask[:, None, :]
                & attention_mask[..., None]
            )[:, None, ...]  # Make attention mask 4D
            processed_attention_mask = self._preprocess_attention_mask(
                processed_attention_mask, dtype=torch.float
            )
        if self.training and self.config.train_on_context:
            tokens_mask = attention_mask
        else:
            tokens_mask = attention_mask * (1 - context_mask)
        return DenoiserInput(
            xt=xt,
            x0=input_ids,
            attention_mask=processed_attention_mask,
            context_mask=context_mask,
            tokens_mask=tokens_mask,
            t=t,
            alpha_t=alpha_t,
            alpha_t_prime=alpha_t_prime,
        )

    def _prepare_inputs_inference(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        context: Optional[torch.LongTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        cache: Optional[Dict[str, Any]] = None,
        **backbone_kwargs: Any,
    ) -> Tuple[DenoiserInput, Dict[str, Any]]:
        assert input_ids is not None or context is not None, (
            "Must provide either input_ids or context."
        )
        cache = cache if cache is not None else {}
        past_key_values = cache.pop("past_key_values", DynamicCache())
        if context is not None:
            if input_ids is not None:
                if context_mask is None:
                    context_mask = torch.cat(
                        [torch.ones_like(context), torch.zeros_like(input_ids)], dim=-1
                    )
                input_ids = torch.cat([context, input_ids], dim=-1)
            else:
                input_ids = context
                context_mask = torch.ones_like(input_ids)
        if attention_mask is None:
            cache_length = self._get_past_key_values_seq_length(past_key_values)
            full_seq_length = cache_length + input_ids.shape[-1]
            attention_mask = torch.ones(
                (input_ids.shape[0], 1, input_ids.shape[1], full_seq_length),
                device=input_ids.device,
            )  # Make attention mask 4D
        attention_mask = self._preprocess_attention_mask(
            attention_mask, dtype=torch.float
        )
        return DenoiserInput(
            xt=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            context_mask=context_mask,
            backbone_kwargs=backbone_kwargs | {"use_cache": False},
        ), cache

    def _forward(
        self,
        backbone_output: torch.FloatTensor,
        denoiser_inputs: DenoiserInput,
        **kwargs,
    ) -> torch.FloatTensor:
        return torch.log_softmax(backbone_output, dim=-1)  # type: ignore

    def _compute_loss(
        self,
        model_output: torch.FloatTensor,
        denoiser_inputs: DenoiserInput,
        **kwargs: Any,
    ) -> LossAndNllOutput:
        raise NotImplementedError

    def _sample_prior(self, device, batch_size, length):
        """Samples from prior / limiting distribution."""
        if self.diffusion_type == "absorbing":
            return self.mask_token_id * torch.ones(
                (batch_size, length), dtype=torch.int64, device=device
            )
        if self.diffusion_type == "uniform":
            return torch.randint(
                0,
                self.vocab_size,
                (batch_size, length),
                device=device,
                dtype=torch.int64,
            )
        raise NotImplementedError(
            f"Diffusion type '{self.diffusion_type}' not implemented."
        )

    def _compute_posterior(
        self,
        x: Union[torch.FloatTensor, torch.LongTensor],
        xt: torch.LongTensor,
        alpha_t: torch.FloatTensor,
        alpha_s: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Computes posterior / approximate posterior q(x_s | x_t, x),
        where x represents clean sequence (as one-hots) or the output of the
        denoising model.

        Args:
            x (Tensor): True (one-hot) / predicted clean signal (B, L, V).
            xt (Tensor): Noised signal at time t (B, L).
            alpha_t (Tensor): Noise schedule parameter at time t (B, 1, 1).
            alpha_s (Tensor): Noise schedule parameter at time s (B, 1, 1).
        """
        if self.diffusion_type == "absorbing":
            q_xs = x * (alpha_s - alpha_t)
            q_xs[..., self.mask_token_id] = 1 - alpha_s[..., 0]
            q_xs /= 1 - alpha_t
            return q_xs  # type: ignore

        alpha_ts = alpha_t / alpha_s
        d_alpha = alpha_s - alpha_t
        xt_one_hot = torch.nn.functional.one_hot(x, self.vocab_size)
        limiting_distribution = torch.ones_like(xt_one_hot) / self.vocab_size
        if self.diffusion_type == "uniform":
            return (
                alpha_t * self.vocab_size * x * xt_one_hot
                + (alpha_ts - alpha_t) * xt_one_hot
                + d_alpha * x
                + (1 - alpha_ts) * (1 - alpha_s) * limiting_distribution
            ) / (
                alpha_t * self.vocab_size * torch.gather(x, -1, xt[..., None])
                + (1 - alpha_t)
            )
        raise NotImplementedError(
            f"Diffusion type {self.diffusion_type} not implemented."
        )

    @staticmethod
    def _sample_generation_timesteps(
        generation_config: DiffusionGenerationConfig,
        max_length: Optional[int] = None,
        device: Optional[str] = None,
    ) -> torch.FloatTensor:
        """Sample timesteps for diffusion generation process."""
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if max_length is None:
            max_length = generation_config.max_new_tokens

        if (
            generation_config.first_hitting
            # TODO: first-hitting does not work with posterior
            and generation_config.sampling_strategy == "posterior"
        ):
            timesteps = torch.FloatTensor([1.0])
            for i in range(max_length, 0, -1):
                u = torch.rand(1)
                next_t = timesteps[-1] * u ** (1 / i)
                timesteps = torch.cat((timesteps, next_t), dim=0)
            return timesteps[1:].to(device)  # type: ignore
        return torch.linspace(  # type: ignore
            1.0,
            generation_config.min_t,
            generation_config.num_steps + 1,
            device=device,
        )[:-1]

    def _generate_unconditional(
        self,
        generation_config: DiffusionGenerationConfig,
        alpha_t: torch.FloatTensor,
        alpha_s: torch.FloatTensor,
        denoiser_inputs: Optional[DenoiserInput] = None,
        model_output_cache: Optional[Dict[str, torch.FloatTensor]] = None,
        cache: Optional[Dict[str, Any]] = None,
        running_generation: Optional[torch.LongTensor] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        **kwargs: Any,
    ) -> Tuple[torch.LongTensor, Dict[str, torch.FloatTensor], Dict[str, Any]]:
        cache = cache if cache is not None else {}
        if model_output_cache is None:  # execute function evaluation
            backbone_output = self._backbone_forward(
                denoiser_inputs,
                fix_cache_length=True,  # Do not let kv cache grow on each forward call
                **cache,
                **kwargs,
            )
            backbone_output = {k: v for k, v in backbone_output.items()}
            logits = backbone_output.pop("logits")
            cache = cache | backbone_output
            log_x_theta = self._forward(logits, denoiser_inputs, **kwargs)
            if logits_processor is not None:
                for token_idx in range(log_x_theta.shape[1]):
                    # TODO: Looping over token positions like this does not allow for
                    #  some processors, e.g. length penalty which could be applied all
                    #  at once to the entire block, to be applied in parallel.
                    log_x_theta[:, token_idx] = logits_processor(
                        input_ids=running_generation,
                        scores=log_x_theta[:, token_idx],  # type: ignore
                    )
                log_x_theta = torch.log_softmax(log_x_theta, dim=-1)  # re-normalize
            x_theta = log_x_theta.exp()
        else:
            x_theta = model_output_cache["x_theta"]
        model_output_cache = {"x_theta": x_theta}
        prob_check_denom = denoiser_inputs.xt.numel()
        if generation_config.sampling_strategy == "posterior":
            q_xs = self._compute_posterior(
                x_theta, denoiser_inputs.xt, alpha_t, alpha_s
            )

            assert abs((q_xs.sum() / prob_check_denom).item() - 1.0) < 1e-6, (
                "Posterior probabilities not summing to 1."
            )
            assert q_xs.isnan().sum().item() == 0, "NaN found in the posterior."
            xs = self._sample_categorical(q_xs, generation_config.do_sample)
            output = torch.where(
                (denoiser_inputs.xt != self.mask_token_id).bool(),  # type: ignore
                denoiser_inputs.xt,
                xs,
            )
        elif generation_config.sampling_strategy == "predict_then_noise":
            assert self.config.diffusion_type == "absorbing", (
                "predict_then_noise decoding strategy only supports absorbing diffusion."
            )
            # assert (
            #     abs((x_theta.sum() / prob_check_denom).item() - 1.0) < 1e-6
            # ), "Denoising output probabilities not summing to 1."
            # assert x_theta.isnan().sum().item() == 0, (
            #     "NaN found in the denoising output."
            # )

            # Predict
            xs = self._sample_categorical(x_theta, generation_config.do_sample)
            xs_probs = x_theta.gather(-1, xs[..., None]).squeeze(dim=-1)
            output = xs.clone()

            # Noise
            num_noise_indices = torch.minimum(
                ((1 - alpha_s) * generation_config.block_size).to(torch.int),
                (denoiser_inputs.xt == self.mask_token_id).sum() - 1,  # type: ignore
            )
            if generation_config.confidence_based_noising:
                conf = x_theta.gather(-1, xs[..., None]).squeeze(-1)
                conf = torch.where(  # already decoded tokens have 'inf' confidence
                    (denoiser_inputs.xt == self.mask_token_id).bool(),  # type: ignore
                    conf,
                    torch.inf,
                )
                noise_indices = conf.argsort(dim=-1)[..., :num_noise_indices]
            elif generation_config.confidence_margin_based_noising:
                top2 = torch.topk(x_theta, k=2, dim=-1).values  # shape: (B, L, 2)
                conf = (top2[..., 0] - top2[..., 1]).abs()
                conf = torch.where(  # already decoded tokens have 'inf' confidence
                    (denoiser_inputs.xt == self.mask_token_id).bool(),  # type: ignore
                    conf,
                    torch.inf,
                )
                noise_indices = conf.argsort(dim=-1)[..., :num_noise_indices]
            else:
                # TODO: implement random noise indices selection
                raise NotImplementedError
            output[..., noise_indices] = self.mask_token_id
            output = torch.where(
                xs_probs >= generation_config.confidence_threshold, xs, output
            )
        else:
            raise NotImplementedError(
                f"Sampling strategy {generation_config.sampling_strategy} not"
                " implemented."
            )
        return output, model_output_cache, cache  # type: ignore

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.LongTensor] = None,
        generation_config: Optional[DiffusionGenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        batch_size: Optional[int] = None,
        device: Optional[str] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        disable_pbar: bool = False,
        **kwargs: Any,
    ) -> torch.LongTensor:
        # Setup sampling variables
        if generation_config is None:
            assert getattr(self, "generation_config", None) is not None, (
                "Generation config must be provided if not present in the model."
            )
            generation_config = self.generation_config
        if inputs is None:
            inputs = torch.ones((batch_size, 1), device=device) * self.bos_token_id
        if max_length is None:
            if hasattr(generation_config, "max_length"):
                max_length = generation_config.max_length
            else:
                max_length = self.max_length
        if max_new_tokens is None:
            if hasattr(generation_config, "max_new_tokens"):
                max_new_tokens = generation_config.max_new_tokens
            else:
                max_new_tokens = max_length - inputs.shape[-1]
        batch_size = batch_size if batch_size is not None else inputs.shape[0]
        assert batch_size == 1, "Batched sampling not supported yet"
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        block_size = generation_config.block_size
        max_blocks = max_new_tokens // block_size

        # Sample max generation length tensor from prior
        accumulated_samples = self._sample_prior(
            device=device,
            batch_size=batch_size,
            length=max_blocks * block_size,
        )
        accumulated_samples = torch.cat([inputs, accumulated_samples], dim=-1)
        if generation_config.use_cache and inputs.numel() > 0:
            cache = self.update_cache(
                inputs=inputs[:, : block_size * (inputs.shape[-1] // block_size)]
                if generation_config.align_inputs_to_blocks
                else inputs,
                cache={},
            )
        else:
            cache = None

        if generation_config.align_inputs_to_blocks:
            inputs_offset = (
                block_size * (inputs.shape[-1] // block_size)
                if inputs.numel() > 0
                else 0
            )
        else:
            inputs_offset = inputs.shape[-1] if inputs.numel() > 0 else 0

        total_NFEs = 0
        timesteps = self._sample_generation_timesteps(  # Re-use in every block
            generation_config, max_length=block_size, device=device
        )
        dt = (1 - generation_config.min_t) / len(timesteps)
        block_pbar = tqdm(
            range(max_blocks),
            desc="Blocks",
            leave=True,
            disable=disable_pbar,
        )
        for block_id in block_pbar:
            block_NFEs = 0
            xt = accumulated_samples[
                :,
                inputs_offset + (block_id * block_size) : inputs_offset
                + ((block_id + 1) * block_size),
            ]
            if self.mask_token_id not in xt:
                continue
            step_pbar = tqdm(
                timesteps,
                desc="T",
                total=timesteps.shape[0],
                leave=False,
                disable=disable_pbar,
            )
            model_output_cache = None
            context = (
                accumulated_samples[:, : (block_id * block_size) + inputs_offset]
                if not generation_config.use_cache
                else None
            )
            # Used for logit processing
            running_generation = accumulated_samples[
                :,
                inputs_offset : inputs_offset + (block_id * block_size),
            ]
            for t in step_pbar:
                if model_output_cache is None:
                    block_NFEs += 1
                    total_NFEs += 1
                # t is 0-dim tensor, reshape to (1, 1, 1) for broadcasting
                alpha_t, _ = self.noise_schedule(t)
                alpha_s, _ = self.noise_schedule(t - dt)
                alpha_t = alpha_t[None, None, None]
                alpha_s = alpha_s[None, None, None]
                denoiser_inputs, cache = self._prepare_inputs_inference(
                    input_ids=xt,
                    context=context,
                    cache=cache if generation_config.use_cache else None,
                )
                xs, model_output_cache, cache = self._generate_unconditional(
                    generation_config=generation_config,
                    alpha_t=alpha_t,
                    alpha_s=alpha_s,
                    denoiser_inputs=denoiser_inputs,
                    model_output_cache=model_output_cache,
                    cache=cache,
                    running_generation=running_generation,  # type: ignore
                    logits_processor=logits_processor,
                    tokenizer=tokenizer,
                    **kwargs,
                )
                block_pbar.set_postfix(
                    NFEs=total_NFEs,
                    block_NFEs=block_NFEs,
                )

                if (
                    not torch.allclose(xs, denoiser_inputs.xt)
                    or not generation_config.use_model_output_cache
                ):
                    model_output_cache = None
                if not generation_config.use_cache:
                    xt[..., -block_size:] = xs[..., -block_size:]
                else:
                    xt = xs
                if (
                    xt == self.mask_token_id
                ).sum().item() == 0 and self.config.diffusion_type == "absorbing":
                    break
            accumulated_samples[
                :,
                inputs_offset + (block_id * block_size) : inputs_offset
                + ((block_id + 1) * block_size),
            ] = xt
            if tokenizer is not None:  # Useful for debugging
                print(tokenizer.batch_decode(accumulated_samples))
            if stopping_criteria is not None:
                is_done = stopping_criteria(
                    input_ids=accumulated_samples[  # type: ignore
                        :,
                        inputs_offset : inputs_offset + ((block_id + 1) * block_size),
                    ],
                    scores=None,  # type: ignore
                )
                if torch.any(is_done):
                    accumulated_samples = accumulated_samples[
                        :,
                        : inputs_offset + ((block_id + 1) * block_size),
                    ]
                    break
            if generation_config.use_cache:
                cache = self.update_cache(
                    inputs=xt,
                    cache=cache,
                )
        return accumulated_samples  # type: ignore

class MDLMConfig(D3PMConfig):
    """Configuration class for MDLM models."""

    model_type = "mdlm"
    auto_map = {
        "AutoConfig": "diffusion.MDLMConfig",
        "AutoModel": "diffusion.MDLM",
        "AutoModelForMaskedLM": "diffusion.MDLM",
    }


class MDLM(D3PM):
    """Denoiser class for MDLM models."""

    config_class = MDLMConfig

    def __init__(self, config: MDLMConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.neg_infinity = -1e12

    def _forward(
        self,
        backbone_output: torch.FloatTensor,
        denoiser_inputs: DenoiserInput,
        **kwargs,
    ) -> torch.FloatTensor:
        # Zero-mask probability
        backbone_output[..., self.mask_token_id] = self.neg_infinity
        log_probs = backbone_output - torch.logsumexp(
            backbone_output, dim=-1, keepdim=True
        )
        # Copy-over unmasked: For the log_probs of the unmasked tokens, set all values
        # to -infinity except for the indices corresponding to
        # the unmasked tokens.
        xt = denoiser_inputs.xt
        unmasked_indices = xt != self.mask_token_id
        log_probs[unmasked_indices] = self.neg_infinity
        log_probs[unmasked_indices, xt[unmasked_indices]] = 0
        return log_probs  # type: ignore

    def _compute_loss(
        self,
        model_output: torch.FloatTensor,
        denoiser_inputs: DenoiserInput,
        **kwargs: Any,
    ) -> LossAndNllOutput:
        log_p_theta = torch.gather(
            input=model_output, dim=-1, index=denoiser_inputs.x0[:, :, None]
        ).squeeze(-1)
        nlls = (
            log_p_theta
            * denoiser_inputs.alpha_t_prime
            / (1 - denoiser_inputs.alpha_t)
            * denoiser_inputs.tokens_mask
        )
        if self.training:
            batch_nll = -(log_p_theta * denoiser_inputs.tokens_mask).sum(dim=-1)
        else:
            batch_nll = nlls.sum(dim=-1)
        count = denoiser_inputs.tokens_mask.sum(dim=-1)
        token_nll = (batch_nll / count).mean()
        return LossAndNllOutput(
            loss=token_nll,  # type: ignore
            nlls=nlls,
            other_loss_terms={
                "masked_tokens": (denoiser_inputs.xt == self.mask_token_id).int()
            },
        )

class BD3LMConfig(MDLMConfig):
    """Configuration class for BD3LM models."""

    model_type = "bd3lm"
    auto_map = {
        "AutoConfig": "diffusion.BD3LMConfig",
        "AutoModel": "diffusion.BD3LM",
        "AutoModelForMaskedLM": "diffusion.BD3LM",
    }

    def __init__(
        self,
        block_size: Optional[int] = None,
        eval_block_size: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.block_size = block_size
        self.eval_block_size = (
            eval_block_size if eval_block_size is not None else block_size
        )


class BD3LM(MDLM):
    """Denoiser class for BD3LM models."""

    config_class = BD3LMConfig

    def __init__(self, config: BD3LMConfig, **kwargs):
        super().__init__(config, **kwargs)

    # noinspection PyUnusedLocal
    @staticmethod
    def _block_mask(
        b,
        h,
        q_idx,
        kv_idx,
        block_size: Optional[int] = None,
        seq_length: Optional[int] = None,
    ) -> torch.Tensor:
        del b, h

        # Indicate whether token belongs to xt or x0:
        xt_flag_q = (q_idx >= seq_length).bool()
        xt_flag_kv = (kv_idx >= seq_length).bool()

        # Compute block indices
        block_q = torch.where(
            xt_flag_q, (q_idx - seq_length) // block_size, q_idx // block_size
        )
        block_kv = torch.where(
            xt_flag_kv, (kv_idx - seq_length) // block_size, kv_idx // block_size
        )
        # **1. Offset Block-Causal Mask (M_OBC) **
        offset_block_causal = (block_q > block_kv) & ~xt_flag_kv & xt_flag_q

        # **2. Block Diagonal Mask (M_BD) **
        block_diagonal = (block_q == block_kv) & (xt_flag_q == xt_flag_kv)

        # **3. Block-Causal Mask (M_BC) **
        block_causal = (block_q >= block_kv) & ~xt_flag_kv & ~xt_flag_q

        # **4. Combine Masks **
        return block_diagonal | offset_block_causal | block_causal

    def _create_static_mask(self) -> None:
        if self.config.attn_backend == "sdpa":
            static_mask = self._block_mask(
                b=None,
                h=None,
                q_idx=torch.arange(self.config.length * 2)[:, None],
                kv_idx=torch.arange(self.config.length * 2)[None, :],
                block_size=self.config.block_size
                if self.training
                else self.config.eval_block_size,
                seq_length=self.config.length,
            )
            self.register_buffer(
                "static_attention_mask",
                static_mask,
            )
            self.skip_params_for_push.append("static_attention_mask")
        elif self.config.attn_backend == "flex_attention":
            mask = partial(
                self._block_mask,
                block_size=self.config.block_size
                if self.training
                else self.config.eval_block_size,
                seq_length=self.config.length,
            )
            self.static_attention_mask = create_block_mask(
                mask,
                B=None,
                H=None,
                Q_LEN=self.config.length * 2,
                KV_LEN=self.config.length * 2,
            )

    def _ensure_no_unmasked_blocks(
        self,
        input_ids: torch.LongTensor,
        xt: torch.LongTensor,
        context_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        n_blocks = xt.shape[1] // self.config.block_size
        # If context overlaps w/block, ignore it
        blocks_without_masks = ((xt == self.mask_token_id) + context_mask).reshape(
            -1, n_blocks, self.config.block_size
        ).sum(dim=-1) == 0
        if blocks_without_masks.sum() > 0:
            num_remasks_per_block = torch.randint(
                0,
                self.config.block_size,
                blocks_without_masks.shape,
                device=xt.device,
            )
            rand = torch.rand(xt.shape[0], xt.shape[1], device=xt.device)
            perm_indices = torch.argsort(
                rand.view(xt.shape[0], n_blocks, self.config.block_size),
                stable=True,
                dim=-1,
            )
            remask_indices = perm_indices <= num_remasks_per_block[..., None]
            xt = torch.where(
                remask_indices.view(xt.shape[0], xt.shape[1])
                * blocks_without_masks.repeat_interleave(self.config.block_size, dim=1),
                self.mask_token_id,
                xt,
            )
        if self.config.keep_clean_bos:
            xt[..., 0] = input_ids[..., 0]
        return xt

    def _prepare_inputs(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        t: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
    ):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if context_mask is None:
            context_mask = torch.zeros_like(attention_mask)

        if torch.is_floating_point(attention_mask):
            attention_mask = attention_mask.to(torch.int)
            context_mask = context_mask.to(torch.int)

        if t is None:
            t = torch.rand(
                input_ids.shape[0],
                input_ids.shape[1] // self.config.block_size
                if self.training
                else self.config.eval_block_size,
                device=input_ids.device,
            ).repeat_interleave(
                self.config.block_size
                if self.training
                else self.config.eval_block_size,
                dim=-1,
            )
        alpha_t, alpha_t_prime = self.noise_schedule(t)
        while alpha_t.ndim < 2:
            alpha_t = alpha_t[..., None]
            alpha_t_prime = alpha_t_prime[..., None]
        xt = self._sample_q_xt(x0=input_ids, alpha_t=alpha_t, context_mask=context_mask)
        # Ensure each block has at least 1 masked token
        if self.training:
            xt = self._ensure_no_unmasked_blocks(
                input_ids,
                xt,
                context_mask,
            )
        if self.config.attn_backend == "sdpa":
            decoder_attention_mask = (
                self.static_attention_mask[None, ...]
                & attention_mask.repeat(1, 2)[:, None, :]
                & attention_mask.repeat(1, 2)[..., None]
            )[:, None, ...]  # Make attention mask 4D
            decoder_attention_mask = self._preprocess_attention_mask(
                decoder_attention_mask, dtype=torch.float
            )
        elif self.config.attn_backend == "flex_attention":
            if context_mask.any():
                raise NotImplementedError(
                    "flex_attention with context_mask not implemented yet."
                )
            elif attention_mask is not None and (attention_mask != 1).any():
                padding_mask = create_attn_mask(
                    attention_mask.bool().repeat(2, 2).bool()
                )
                dec_masks = [
                    partial(
                        self._block_mask,
                        block_size=self.config.block_size
                        if self.training
                        else self.config.eval_block_size,
                        seq_length=self.config.length,
                    ),
                    padding_mask,
                ]
                decoder_attention_mask = create_block_mask(
                    and_masks(*dec_masks),
                    B=input_ids.shape[0],
                    H=None,
                    Q_LEN=input_ids.shape[1] * 2,
                    KV_LEN=input_ids.shape[1] * 2,
                )
            else:
                decoder_attention_mask = self.static_attention_mask
        else:
            raise ValueError("Unknown backbone backend")
        backbone_input_ids = torch.cat((input_ids, xt), dim=-1)
        position_ids = (
            torch.arange(input_ids.shape[1]).repeat(2).to(input_ids.device)[None, :]
        )
        if self.training and self.config.train_on_context:
            tokens_mask = attention_mask
        else:
            tokens_mask = attention_mask * (1 - context_mask)
        return DenoiserInput(
            xt=backbone_input_ids,  # type: ignore
            x0=input_ids,
            attention_mask=decoder_attention_mask,  # type: ignore
            tokens_mask=tokens_mask,
            t=t,
            alpha_t=alpha_t,
            alpha_t_prime=alpha_t_prime,
            backbone_kwargs={
                "cache_position": position_ids[0],
                "position_ids": position_ids,
            },
        )

    def _prepare_inputs_inference(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        context: Optional[torch.LongTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        cache: Optional[Dict[str, Any]] = None,
        return_updated_cache: bool = False,
        **backbone_kwargs: Dict[str, Any],
    ) -> Tuple[DenoiserInput, Union[Dict[str, Any], None]]:
        device = input_ids.device if input_ids is not None else context.device
        assert input_ids is not None or context is not None, (
            "Must provide either input_ids or context."
        )
        cache = cache if cache is not None else {}
        past_key_values = cache.pop("past_key_values", DynamicCache())
        if context is not None:
            if input_ids is not None:
                input_ids = torch.cat([context, input_ids], dim=-1)
            else:
                input_ids = context
        cache_length = self._get_past_key_values_seq_length(past_key_values)
        full_seq_length = cache_length + input_ids.shape[-1]
        decoder_attention_mask = self.static_attention_mask[
            None,
            None,
            cache_length:full_seq_length,
            :full_seq_length,
        ]  # Make attention mask 4D
        decoder_attention_mask = self._preprocess_attention_mask(
            decoder_attention_mask, dtype=torch.float
        )
        position_ids = torch.arange(cache_length, full_seq_length).to(device)[None, :]
        return DenoiserInput(
            xt=input_ids,
            attention_mask=decoder_attention_mask,
            context_mask=context_mask,
            past_key_values=past_key_values,
            backbone_kwargs={
                "position_ids": position_ids,
            }
            | backbone_kwargs,
        ), cache

    def _compute_loss(
        self,
        model_output: torch.FloatTensor,
        denoiser_inputs: DenoiserInput,
        **kwargs: Any,
    ) -> LossAndNllOutput:
        input_length = denoiser_inputs.xt.shape[1] // 2
        model_output = model_output[:, input_length:, ...]
        return super()._compute_loss(
            model_output=model_output,  # type: ignore
            denoiser_inputs=denoiser_inputs,
            **kwargs,
        )

class E2D2Config(BD3LMConfig):
    """Configuration class for E2D2 models."""

    model_type = "e2d2"
    auto_map = {
        "AutoConfig": "diffusion.E2D2Config",
        "AutoModel": "diffusion.E2D2",
        "AutoModelForMaskedLM": "diffusion.E2D2",
    }

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)


class E2D2(BD3LM):
    """Denoiser class for E2D2 models."""

    config_class = E2D2Config

    def __init__(self, config: E2D2Config, **kwargs):
        super().__init__(config, **kwargs)

    # noinspection PyUnusedLocal
    @staticmethod
    def _encoder_block_mask(
        b,
        h,
        q_idx,
        kv_idx,
        block_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            q_idx (Tensor): Query indices.
            kv_idx (Tensor): Key indices
            b (Optional: int): batch size
            h (Optional: int): number of heads
            block_size (Optional: int): Defines the block structure.

        Returns:
            Encoder block-causal attention mask.
        """

        # Compute block indices
        block_q = q_idx // block_size
        block_kv = kv_idx // block_size

        # ** Block-Causal Mask **
        return block_q >= block_kv

    # noinspection PyUnusedLocal
    @staticmethod
    def _decoder_block_mask(
        b,
        h,
        q_idx,
        kv_idx,
        block_size: Optional[int] = None,
        seq_length: Optional[int] = None,
    ) -> torch.Tensor:
        # Indicate whether token belongs to xt or x0:
        xt_flag_kv = (kv_idx >= seq_length).bool()

        # Compute block indices
        block_q = q_idx // block_size
        block_kv = torch.where(
            xt_flag_kv, (kv_idx - seq_length) // block_size, kv_idx // block_size
        )
        # **1. Offset Block-Causal Mask (M_OBC) **
        offset_block_causal = (block_q > block_kv) & ~xt_flag_kv

        # **2. Block Diagonal Mask (M_BD) **
        block_diagonal = (block_q == block_kv) & xt_flag_kv

        # **3. Combine Masks **
        return block_diagonal | offset_block_causal

    def _create_static_mask(self) -> None:
        if self.config.attn_backend == "flex_attention":
            enc_mask = partial(
                self._encoder_block_mask,
                block_size=self.config.block_size
                if self.training
                else self.config.eval_block_size,
            )
            encoder_attention_mask = create_block_mask(
                enc_mask,
                B=None,
                H=None,
                Q_LEN=self.config.length,
                KV_LEN=self.config.length,
            )
            dec_mask = partial(
                self._decoder_block_mask,
                block_size=self.config.block_size
                if self.training
                else self.config.eval_block_size,
                seq_length=self.config.length,
            )
            decoder_attention_mask = create_block_mask(
                dec_mask,
                B=None,
                H=None,
                Q_LEN=self.config.length,
                KV_LEN=self.config.length * 2,
            )
            self.encoder_static_attention_mask = encoder_attention_mask
            self.static_attention_mask = decoder_attention_mask
        else:
            encoder_static_mask = self._encoder_block_mask(
                b=None,  # type: ignore
                h=None,  # type: ignore
                q_idx=torch.arange(self.config.length)[:, None],
                kv_idx=torch.arange(self.config.length)[None, :],
                block_size=self.config.block_size
                if self.training
                else self.config.eval_block_size,
            )
            decoder_static_mask = self._decoder_block_mask(
                b=None,
                h=None,
                q_idx=torch.arange(self.config.length)[:, None],
                kv_idx=torch.arange(self.config.length * 2)[None, :],
                block_size=self.config.block_size
                if self.training
                else self.config.eval_block_size,
                seq_length=self.config.length,
            )
            self.register_buffer(
                "encoder_static_attention_mask",
                encoder_static_mask,
            )
            self.register_buffer(
                "static_attention_mask",
                decoder_static_mask,
            )
            self.skip_params_for_push.append("encoder_static_attention_mask")
            self.skip_params_for_push.append("static_attention_mask")

    def _prepare_inputs(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        t: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
    ):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if context_mask is None:
            context_mask = torch.zeros_like(attention_mask)

        if torch.is_floating_point(attention_mask):
            attention_mask = attention_mask.to(torch.int)
            context_mask = context_mask.to(torch.int)

        if t is None:
            t = torch.rand(
                input_ids.shape[0],
                input_ids.shape[1] // self.config.block_size
                if self.training
                else self.config.eval_block_size,
                device=input_ids.device,
            ).repeat_interleave(
                self.config.block_size
                if self.training
                else self.config.eval_block_size,
                dim=-1,
            )
        alpha_t, alpha_t_prime = self.noise_schedule(t)
        while alpha_t.ndim < 2:
            alpha_t = alpha_t[..., None]
            alpha_t_prime = alpha_t_prime[..., None]
        xt = self._sample_q_xt(x0=input_ids, alpha_t=alpha_t, context_mask=context_mask)
        # Ensure each block has at least 1 masked token
        if self.training:
            xt = self._ensure_no_unmasked_blocks(
                input_ids,
                xt,
                context_mask,
            )
        if self.config.attn_backend == "sdpa":
            decoder_attention_mask = (
                self.static_attention_mask[None, ...]
                & attention_mask.repeat(1, 2)[:, None, :]
                & attention_mask[..., None]
            )[:, None, ...]  # Make attention mask 4D
            encoder_attention_mask = (
                (
                    self.encoder_static_attention_mask[None, ...]
                    | context_mask[:, None, :]
                )
                & attention_mask[:, None, :]
                & attention_mask[..., None]
            )[:, None, ...]  # Make attention mask 4D
            encoder_attention_mask = self._preprocess_attention_mask(
                encoder_attention_mask, dtype=torch.float
            )
            decoder_attention_mask = self._preprocess_attention_mask(
                decoder_attention_mask, dtype=torch.float
            )
        elif self.config.attn_backend == "flex_attention":
            # TODO enable bidirectional attention on context for seq2seq tasks
            if context_mask.any():
                raise NotImplementedError(
                    "flex_attention with context_mask not implemented yet."
                )
            elif attention_mask is not None and (attention_mask != 1).any():
                padding_mask = create_attn_mask(attention_mask.bool())
                dec_padding_mask = create_attn_mask(attention_mask.repeat(1, 2).bool())
                enc_masks = [
                    partial(
                        self._encoder_block_mask,
                        block_size=self.config.block_size
                        if self.training
                        else self.config.eval_block_size,
                    ),
                    padding_mask,
                ]
                encoder_attention_mask = create_block_mask(
                    and_masks(*enc_masks),
                    B=input_ids.shape[0],
                    H=None,
                    Q_LEN=input_ids.shape[1],
                    KV_LEN=input_ids.shape[1],
                )
                dec_masks = [
                    partial(
                        self._decoder_block_mask,
                        block_size=self.config.block_size
                        if self.training
                        else self.config.eval_block_size,
                        seq_length=input_ids.shape[1],
                    ),
                    dec_padding_mask,
                ]
                decoder_attention_mask = create_block_mask(
                    and_masks(*dec_masks),
                    B=input_ids.shape[0],
                    H=None,
                    Q_LEN=input_ids.shape[1],
                    KV_LEN=input_ids.shape[1] * 2,
                )
            else:
                encoder_attention_mask = self.encoder_static_attention_mask
                decoder_attention_mask = self.static_attention_mask
        else:
            raise ValueError("Unknown backbone backend")
        position_ids = torch.arange(input_ids.shape[1]).to(input_ids.device)[None, :]
        if self.training and self.config.train_on_context:
            tokens_mask = attention_mask
        else:
            tokens_mask = attention_mask * (1 - context_mask)
        return DenoiserInput(
            xt=xt,
            x0=input_ids,
            attention_mask=decoder_attention_mask,
            tokens_mask=tokens_mask,
            t=t,
            alpha_t=alpha_t,
            alpha_t_prime=alpha_t_prime,
            backbone_kwargs={
                "encoder_input_ids": input_ids,
                "encoder_attention_mask": encoder_attention_mask,
                "encoder_position_ids": position_ids,
                "encoder_cache_position": position_ids[0],
            },
        )

    def _prepare_inputs_inference(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        context: Optional[torch.LongTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        cache: Optional[Dict[str, Any]] = None,
        return_updated_cache: bool = False,
        **backbone_kwargs: Dict[str, Any],
    ) -> Tuple[DenoiserInput, Union[Dict[str, Any], None]]:
        device = input_ids.device if input_ids is not None else context.device
        batch_size = input_ids.shape[0] if input_ids is not None else context.shape[0]
        assert input_ids is not None or context is not None, (
            "Must provide either input_ids or context."
        )
        if return_updated_cache:  # Indicates this is a cache update step
            context = input_ids
            input_ids = None
        position_ids, encoder_position_ids = None, None
        if cache is not None:
            past_key_values = cache.pop("past_key_values", DynamicCache())
            encoder_past_key_values = cache.pop(
                "encoder_past_key_values", DynamicCache()
            )
            encoder_last_hidden_state = cache.pop("encoder_last_hidden_state", None)
            if input_ids is not None:  # Skip enc: nothing new to cache
                cache_length = self._get_past_key_values_seq_length(past_key_values)
                if encoder_last_hidden_state is not None:
                    full_seq_length = (
                        cache_length
                        + encoder_last_hidden_state.shape[1]  # type: ignore
                        + input_ids.shape[-1]
                    )
                else:
                    full_seq_length = cache_length + input_ids.shape[-1]
                encoder_attention_mask = None
                position_ids = torch.arange(
                    cache_length, full_seq_length, device=device
                )[None, :]
            else:  # Caching new tokens in the enc
                encoder_cache_length = self._get_past_key_values_seq_length(
                    encoder_past_key_values
                    if len(encoder_past_key_values) > 0
                    else past_key_values
                )
                encoder_full_seq_length = encoder_cache_length + context.shape[-1]
                encoder_attention_mask = torch.ones(
                    (
                        1,
                        1,
                        encoder_full_seq_length - encoder_cache_length,
                        encoder_full_seq_length,
                    ),
                    device=context.device,
                )
                encoder_position_ids = torch.arange(
                    encoder_cache_length, encoder_full_seq_length
                ).to(device)[None, :]
                encoder_attention_mask = self._preprocess_attention_mask(
                    encoder_attention_mask, dtype=torch.float
                )
                full_seq_length = -1  # Not used
        else:  # Not using kv-cache
            past_key_values = None
            encoder_past_key_values, encoder_last_hidden_state = None, None
            if context is not None:
                context_len = context.shape[1]
                encoder_attention_mask = torch.ones(
                    (1, 1, context_len, context_len), device=context.device
                )
                encoder_attention_mask = self._preprocess_attention_mask(
                    encoder_attention_mask, dtype=torch.float
                )
                encoder_position_ids = torch.arange(context_len).to(device)[None, :]
            else:
                context_len = 0
                encoder_attention_mask = None
            if input_ids is not None:
                full_seq_length = context_len + input_ids.shape[1]
            else:
                full_seq_length = context_len
            position_ids = torch.arange(context_len, full_seq_length).to(device)[
                None, :
            ]
        if input_ids is not None:
            decoder_attention_mask = torch.ones(
                (batch_size, 1, input_ids.shape[1], full_seq_length),
                device=device,
            )  # Make attention mask 4D
            decoder_attention_mask = self._preprocess_attention_mask(
                decoder_attention_mask, dtype=torch.float
            )
        else:
            decoder_attention_mask = None
        return DenoiserInput(
            xt=input_ids,
            attention_mask=decoder_attention_mask,
            context_mask=context_mask,
            past_key_values=past_key_values,
            backbone_kwargs={
                "position_ids": position_ids,
                "encoder_input_ids": context,
                "encoder_position_ids": encoder_position_ids,
                "encoder_attention_mask": encoder_attention_mask,
                "encoder_past_key_values": encoder_past_key_values,
                "encoder_last_hidden_state": encoder_last_hidden_state,
            }
            | backbone_kwargs,
        ), cache  # TODO: potentially returning cache None, violates return type

    def _compute_loss(
        self,
        model_output: torch.FloatTensor,
        denoiser_inputs: DenoiserInput,
        **kwargs: Any,
    ) -> LossAndNllOutput:
        # Use MDLM `_compute_loss`, since BD3LM method splits model_output
        return super(BD3LM, self)._compute_loss(
            model_output=model_output,
            denoiser_inputs=denoiser_inputs,
            **kwargs,
        )
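A usage sketch of the block-wise sampler defined above (assumptions: `model` and `tokenizer` are loaded as in the config.json example, `DiffusionGenerationConfig` is imported from the downloaded `diffusion.py`, and prompt and model live on the same device):

import torch

gen_config = DiffusionGenerationConfig(
    max_new_tokens=128,
    block_size=4,        # matches "block_size" in config.json
    num_steps=128,
    first_hitting=True,  # decode roughly one token per step within each block
    do_sample=True,
    use_cache=True,
)
prompt_ids = tokenizer("Diffusion language models", return_tensors="pt").input_ids
with torch.no_grad():
    samples = model.generate(inputs=prompt_ids, generation_config=gen_config)
print(tokenizer.batch_decode(samples))

Generation then proceeds as in `generate` above: each block of 4 tokens starts fully masked, is iteratively denoised, and the KV cache is updated before the next block begins.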
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:29daae41ab012a61516d316e7cd54de035044e6c8890395ccc542ba161c07aa1
size 1016097199
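The weights file is stored via Git LFS; the pointer above records its size and sha256. A verification sketch for a locally downloaded copy (the path is an assumption):

import hashlib, os

path = "pytorch_model.bin"  # local path to the downloaded weights
assert os.path.getsize(path) == 1016097199

sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)
assert sha.hexdigest() == "29daae41ab012a61516d316e7cd54de035044e6c8890395ccc542ba161c07aa1"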