CARLEXsX committed
Commit f8aad81 · verified · 1 Parent(s): 0e05fab

Upload ai_studio_code - 2025-08-16T134813.673.py

ltx_video/models/transformers/ai_studio_code - 2025-08-16T134813.673.py ADDED
@@ -0,0 +1,472 @@
# --- START OF MODIFIED FILE app_fluxContext_Ltx/ltx_video/models/transformers/transformer3d.py ---
# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import os
import json
import glob
from pathlib import Path

import torch
import numpy as np
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import PixArtAlphaTextProjection
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormSingle
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils import logging
from torch import nn
from safetensors import safe_open


from ltx_video.models.transformers.attention import BasicTransformerBlock
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

from ltx_video.utils.diffusers_config_mapping import (
    diffusers_and_ours_config_mapping,
    make_hashable_key,
    TRANSFORMER_KEYS_RENAME_DICT,
)


logger = logging.get_logger(__name__)

@dataclass
class Transformer3DModelOutput(BaseOutput):
    """
    The output of [`Transformer3DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch_size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer3DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor

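# Transformer3DModel is the LTX-Video denoising backbone: a stack of BasicTransformerBlock
# modules over patchified video latents with RoPE positional embeddings and AdaLN-single
# timestep conditioning. This copy additionally wires a TeaCache fast path into `forward`,
# which can reuse the previous step's transformer residual when the modulated input barely
# changes between steps.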
class Transformer3DModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        adaptive_norm: str = "single_scale_shift",  # 'single_scale_shift' or 'single_scale'
        standardization_norm: str = "layer_norm",  # 'layer_norm' or 'rms_norm'
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        attention_type: str = "default",
        caption_channels: int = None,
        use_tpu_flash_attention: bool = False,  # if True uses the TPU attention offload ('flash attention')
        qk_norm: Optional[str] = None,
        positional_embedding_type: str = "rope",
        positional_embedding_theta: Optional[float] = None,
        positional_embedding_max_pos: Optional[List[int]] = None,
        timestep_scale_multiplier: Optional[float] = None,
        causal_temporal_positioning: bool = False,  # For backward compatibility, will be deprecated
    ):
        super().__init__()
        self.use_tpu_flash_attention = (
            use_tpu_flash_attention  # FIXME: push config down to the attention modules
        )
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True)
        self.positional_embedding_type = positional_embedding_type
        self.positional_embedding_theta = positional_embedding_theta
        self.positional_embedding_max_pos = positional_embedding_max_pos
        self.use_rope = self.positional_embedding_type == "rope"
        self.timestep_scale_multiplier = timestep_scale_multiplier

        if self.positional_embedding_type == "absolute":
            raise ValueError("Absolute positional embedding is no longer supported")
        elif self.positional_embedding_type == "rope":
            if positional_embedding_theta is None:
                raise ValueError(
                    "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined"
                )
            if positional_embedding_max_pos is None:
                raise ValueError(
                    "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined"
                )

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    double_self_attention=double_self_attention,
                    upcast_attention=upcast_attention,
                    adaptive_norm=adaptive_norm,
                    standardization_norm=standardization_norm,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    attention_type=attention_type,
                    use_tpu_flash_attention=use_tpu_flash_attention,
                    qk_norm=qk_norm,
                    use_rope=self.use_rope,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = nn.Parameter(
            torch.randn(2, inner_dim) / inner_dim**0.5
        )
        self.proj_out = nn.Linear(inner_dim, self.out_channels)

        self.adaln_single = AdaLayerNormSingle(
            inner_dim, use_additional_conditions=False
        )
        if adaptive_norm == "single_scale":
            self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)

        self.caption_projection = None
        if caption_channels is not None:
            self.caption_projection = PixArtAlphaTextProjection(
                in_features=caption_channels, hidden_size=inner_dim
            )

        self.gradient_checkpointing = False

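    # Module layout: `patchify_proj` lifts the patchified latent channels to `inner_dim`,
    # `transformer_blocks` holds `num_layers` BasicTransformerBlock instances, and
    # `norm_out` + `scale_shift_table` + `proj_out` modulate and project back to
    # `out_channels`. Timestep conditioning comes from `adaln_single`; an optional
    # `caption_projection` maps text-encoder features to `inner_dim`.
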
    def set_use_tpu_flash_attention(self):
        r"""
        Sets the flag on this object and propagates it down to its children. The flag enforces
        the use of the TPU attention kernel.
        """
        logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
        self.use_tpu_flash_attention = True
        # push config down to the attention modules
        for block in self.transformer_blocks:
            block.set_use_tpu_flash_attention()

    def create_skip_layer_mask(
        self,
        batch_size: int,
        num_conds: int,
        ptb_index: int,
        skip_block_list: Optional[List[int]] = None,
    ):
        if skip_block_list is None or len(skip_block_list) == 0:
            return None
        num_layers = len(self.transformer_blocks)
        mask = torch.ones(
            (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype
        )
        for block_idx in skip_block_list:
            mask[block_idx, ptb_index::num_conds] = 0
        return mask

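    # Example for `create_skip_layer_mask` above: with batch_size=1, num_conds=3, ptb_index=2
    # and skip_block_list=[31], the mask has shape (num_layers, 3) and row 31 equals [1, 1, 0],
    # i.e. block 31 is zeroed out only for the batch entries at `ptb_index`.
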
    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def get_fractional_positions(self, indices_grid):
        fractional_positions = torch.stack(
            [
                indices_grid[:, i] / self.positional_embedding_max_pos[i]
                for i in range(3)
            ],
            dim=-1,
        )
        return fractional_positions

    def precompute_freqs_cis(self, indices_grid, spacing="exp"):
        dtype = torch.float32  # We need full precision in the freqs_cis computation.
        dim = self.inner_dim
        theta = self.positional_embedding_theta

        fractional_positions = self.get_fractional_positions(indices_grid)

        start = 1
        end = theta
        device = fractional_positions.device
        if spacing == "exp":
            indices = theta ** (
                torch.linspace(
                    math.log(start, theta),
                    math.log(end, theta),
                    dim // 6,
                    device=device,
                    dtype=dtype,
                )
            )
            indices = indices.to(dtype=dtype)
        elif spacing == "exp_2":
            indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim)
            indices = indices.to(dtype=dtype)
        elif spacing == "linear":
            indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
        elif spacing == "sqrt":
            indices = torch.linspace(
                start**2, end**2, dim // 6, device=device, dtype=dtype
            ).sqrt()

        indices = indices * math.pi / 2

        if spacing == "exp_2":
            freqs = (
                (indices * fractional_positions.unsqueeze(-1))
                .transpose(-1, -2)
                .flatten(2)
            )
        else:
            freqs = (
                (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
                .transpose(-1, -2)
                .flatten(2)
            )

        cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
        sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
        if dim % 6 != 0:
            cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
            sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
            cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
            sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
        return cos_freq.to(self.dtype), sin_freq.to(self.dtype)

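    # `precompute_freqs_cis` above builds the RoPE tables: each of the three grid axes gets
    # `dim // 6` frequencies, every frequency is repeated twice for the (cos, sin) channel
    # pairing, and when `inner_dim` is not divisible by 6 the leading channels are padded with
    # cos=1 / sin=0 so the rotation leaves them unchanged.
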
    def load_state_dict(
        self,
        state_dict: Dict,
        *args,
        **kwargs,
    ):
        if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]):
            state_dict = {
                key.replace("model.diffusion_model.", ""): value
                for key, value in state_dict.items()
                if key.startswith("model.diffusion_model.")
            }
        super().load_state_dict(state_dict, *args, **kwargs)

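    # The `load_state_dict` override above strips the `model.diffusion_model.` prefix used by
    # combined single-file checkpoints and drops any keys outside that namespace before
    # delegating to the standard loader.
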
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        pretrained_model_path = Path(pretrained_model_path)
        if pretrained_model_path.is_dir():
            config_path = pretrained_model_path / "transformer" / "config.json"
            with open(config_path, "r") as f:
                config = make_hashable_key(json.load(f))

            assert config in diffusers_and_ours_config_mapping, (
                "Provided diffusers checkpoint config for transformer is not supported. "
                "We only support diffusers configs found in Lightricks/LTX-Video."
            )

            config = diffusers_and_ours_config_mapping[config]
            state_dict = {}
            ckpt_paths = (
                pretrained_model_path
                / "transformer"
                / "diffusion_pytorch_model*.safetensors"
            )
            dict_list = glob.glob(str(ckpt_paths))
            for dict_path in dict_list:
                part_dict = {}
                with safe_open(dict_path, framework="pt", device="cpu") as f:
                    for k in f.keys():
                        part_dict[k] = f.get_tensor(k)
                state_dict.update(part_dict)

            for key in list(state_dict.keys()):
                new_key = key
                for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
                    new_key = new_key.replace(replace_key, rename_key)
                state_dict[new_key] = state_dict.pop(key)

            with torch.device("meta"):
                transformer = cls.from_config(config)
            transformer.load_state_dict(state_dict, assign=True, strict=True)
        elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
            ".safetensors"
        ):
            comfy_single_file_state_dict = {}
            with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
                metadata = f.metadata()
                for k in f.keys():
                    comfy_single_file_state_dict[k] = f.get_tensor(k)
            configs = json.loads(metadata["config"])
            transformer_config = configs["transformer"]
            with torch.device("meta"):
                transformer = Transformer3DModel.from_config(transformer_config)
            transformer.load_state_dict(comfy_single_file_state_dict, assign=True)
        return transformer

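    # `from_pretrained` above accepts either:
    #   * a directory with `transformer/config.json` and one or more
    #     `transformer/diffusion_pytorch_model*.safetensors` shards (diffusers layout), or
    #   * a single `.safetensors` file whose metadata carries a JSON "config" entry with a
    #     "transformer" section.
    # In both cases the module is built on the meta device and the weights are assigned in place.
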
    def forward(
        self,
        hidden_states: torch.Tensor,
        indices_grid: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        skip_layer_mask: Optional[torch.Tensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
        return_dict: bool = True,
    ):
        if not self.use_tpu_flash_attention:
            if attention_mask is not None and attention_mask.ndim == 2:
                attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
                attention_mask = attention_mask.unsqueeze(1)

            if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
                encoder_attention_mask = (
                    1 - encoder_attention_mask.to(hidden_states.dtype)
                ) * -10000.0
                encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        hidden_states = self.patchify_proj(hidden_states)

        if self.timestep_scale_multiplier:
            timestep = self.timestep_scale_multiplier * timestep

        freqs_cis = self.precompute_freqs_cis(indices_grid)

        batch_size = hidden_states.shape[0]
        timestep, embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(
            batch_size, -1, embedded_timestep.shape[-1]
        )

        if self.caption_projection is not None:
            batch_size = hidden_states.shape[0]
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(
                batch_size, -1, hidden_states.shape[-1]
            )

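        # The TeaCache branch below depends on state that is not initialized in this file:
        # `enable_teacache`, `cnt`, `num_steps`, `rel_l1_thresh`, `previous_modulated_input`,
        # `previous_residual` and `accumulated_rel_l1_distance` must be attached to the module
        # by the calling pipeline before sampling starts (see the usage sketch at the end of
        # this page).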
        # TeaCache Integration
        if hasattr(self, 'enable_teacache') and self.enable_teacache:
            ori_hidden_states = hidden_states.clone()
            temb_ = embedded_timestep.clone()
            inp = self.transformer_blocks[0].norm1(hidden_states.clone())

            first_block = self.transformer_blocks[0]
            modulated_inp = inp
            if first_block.adaptive_norm in ["single_scale_shift", "single_scale"]:
                num_ada_params = first_block.scale_shift_table.shape[0]
                ada_values = first_block.scale_shift_table[None, None] + temb_.reshape(
                    batch_size, temb_.shape[1], num_ada_params, -1
                )
                if first_block.adaptive_norm == "single_scale_shift":
                    shift_msa, scale_msa, _, _, _, _ = ada_values.unbind(dim=2)
                    modulated_inp = inp * (1 + scale_msa) + shift_msa
                else:
                    scale_msa, _, _, _ = ada_values.unbind(dim=2)
                    modulated_inp = inp * (1 + scale_msa)

            should_calc = False
            if self.cnt == 0 or self.cnt == self.num_steps - 1 or self.previous_modulated_input is None:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
            else:
                coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
                rescale_func = np.poly1d(coefficients)
                rel_l1_dist = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
                self.accumulated_rel_l1_distance += rescale_func(rel_l1_dist)

                if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                    should_calc = False
                else:
                    should_calc = True
                    self.accumulated_rel_l1_distance = 0

            self.previous_modulated_input = modulated_inp
            self.cnt += 1
            if self.cnt == self.num_steps:
                self.cnt = 0

            if not should_calc and self.previous_residual is not None:
                hidden_states = ori_hidden_states + self.previous_residual
            else:
                # Execute original logic if cache is missed
                temp_hidden_states = hidden_states
                for block_idx, block in enumerate(self.transformer_blocks):
                    temp_hidden_states = block(
                        temp_hidden_states,
                        freqs_cis=freqs_cis,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_hidden_states,
                        encoder_attention_mask=encoder_attention_mask,
                        timestep=timestep,
                        cross_attention_kwargs=cross_attention_kwargs,
                        class_labels=class_labels,
                        skip_layer_mask=(skip_layer_mask[block_idx] if skip_layer_mask is not None else None),
                        skip_layer_strategy=skip_layer_strategy,
                    )
                self.previous_residual = temp_hidden_states - ori_hidden_states
                hidden_states = temp_hidden_states
        else:
            # Original path if TeaCache is disabled
            for block_idx, block in enumerate(self.transformer_blocks):
                hidden_states = block(
                    hidden_states,
                    freqs_cis=freqs_cis,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                    skip_layer_mask=(skip_layer_mask[block_idx] if skip_layer_mask is not None else None),
                    skip_layer_strategy=skip_layer_strategy,
                )

        # Final modulation and output
        scale_shift_values = (self.scale_shift_table[None, None] + embedded_timestep[:, :, None])
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.proj_out(hidden_states)

        if not return_dict:
            return (hidden_states,)

        return Transformer3DModelOutput(sample=hidden_states)
# --- END OF MODIFIED FILE app_fluxContext_Ltx/ltx_video/models/transformers/transformer3d.py ---
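A minimal usage sketch for the TeaCache path (illustrative only: the checkpoint path, step count, and threshold below are placeholder values, and the attribute names are simply the ones read by `forward` above):

from ltx_video.models.transformers.transformer3d import Transformer3DModel

# Load from a diffusers-style directory or a single .safetensors file (placeholder path).
transformer = Transformer3DModel.from_pretrained("path/to/ltx-video-checkpoint")

# Attach the TeaCache state that forward() expects; none of it is created in transformer3d.py.
transformer.enable_teacache = True
transformer.num_steps = 40                  # counter wraps back to 0 after this many forward calls
transformer.rel_l1_thresh = 0.05            # higher threshold => more steps reuse the cached residual
transformer.cnt = 0
transformer.accumulated_rel_l1_distance = 0
transformer.previous_modulated_input = None
transformer.previous_residual = None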