CARLEXsX committed (verified)
Commit 0e05fab · 1 Parent(s): 03b1898

Delete ltx_video/models/transformers/transformer3d.py

ltx_video/models/transformers/transformer3d.py DELETED
@@ -1,472 +0,0 @@
- #-- START OF MODIFIED FILE app_fluxContext_Ltx/ltx_video/models/transformers/transformer3d.py ---
- # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
- import math
- from dataclasses import dataclass
- from typing import Any, Dict, List, Optional, Union
- import os
- import json
- import glob
- from pathlib import Path
-
- import torch
- import numpy as np
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.models.embeddings import PixArtAlphaTextProjection
- from diffusers.models.modeling_utils import ModelMixin
- from diffusers.models.normalization import AdaLayerNormSingle
- from diffusers.utils import BaseOutput, is_torch_version
- from diffusers.utils import logging
- from torch import nn
- from safetensors import safe_open
-
-
- from ltx_video.models.transformers.attention import BasicTransformerBlock
- from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
-
- from ltx_video.utils.diffusers_config_mapping import (
-     diffusers_and_ours_config_mapping,
-     make_hashable_key,
-     TRANSFORMER_KEYS_RENAME_DICT,
- )
-
-
- logger = logging.get_logger(__name__)
-
-
- @dataclass
- class Transformer3DModelOutput(BaseOutput):
-     """
-     The output of [`Transformer3DModel`].
-
-     Args:
-         sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer3DModel`] is discrete):
-             The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
-             distributions for the unnoised latent pixels.
-     """
-
-     sample: torch.FloatTensor
-
-
- class Transformer3DModel(ModelMixin, ConfigMixin):
-     _supports_gradient_checkpointing = True
-
-     @register_to_config
-     def __init__(
-         self,
-         num_attention_heads: int = 16,
-         attention_head_dim: int = 88,
-         in_channels: Optional[int] = None,
-         out_channels: Optional[int] = None,
-         num_layers: int = 1,
-         dropout: float = 0.0,
-         norm_num_groups: int = 32,
-         cross_attention_dim: Optional[int] = None,
-         attention_bias: bool = False,
-         num_vector_embeds: Optional[int] = None,
-         activation_fn: str = "geglu",
-         num_embeds_ada_norm: Optional[int] = None,
-         use_linear_projection: bool = False,
-         only_cross_attention: bool = False,
-         double_self_attention: bool = False,
-         upcast_attention: bool = False,
-         adaptive_norm: str = "single_scale_shift",  # 'single_scale_shift' or 'single_scale'
-         standardization_norm: str = "layer_norm",  # 'layer_norm' or 'rms_norm'
-         norm_elementwise_affine: bool = True,
-         norm_eps: float = 1e-5,
-         attention_type: str = "default",
-         caption_channels: Optional[int] = None,
-         use_tpu_flash_attention: bool = False,  # if True, uses the TPU attention offload ('flash attention')
-         qk_norm: Optional[str] = None,
-         positional_embedding_type: str = "rope",
-         positional_embedding_theta: Optional[float] = None,
-         positional_embedding_max_pos: Optional[List[int]] = None,
-         timestep_scale_multiplier: Optional[float] = None,
-         causal_temporal_positioning: bool = False,  # For backward compatibility, will be deprecated
-     ):
-         super().__init__()
-         self.use_tpu_flash_attention = (
-             use_tpu_flash_attention  # FIXME: push config down to the attention modules
-         )
-         self.use_linear_projection = use_linear_projection
-         self.num_attention_heads = num_attention_heads
-         self.attention_head_dim = attention_head_dim
-         inner_dim = num_attention_heads * attention_head_dim
-         self.inner_dim = inner_dim
-         self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True)
-         self.positional_embedding_type = positional_embedding_type
-         self.positional_embedding_theta = positional_embedding_theta
-         self.positional_embedding_max_pos = positional_embedding_max_pos
-         self.use_rope = self.positional_embedding_type == "rope"
-         self.timestep_scale_multiplier = timestep_scale_multiplier
-
-         if self.positional_embedding_type == "absolute":
-             raise ValueError("Absolute positional embedding is no longer supported")
-         elif self.positional_embedding_type == "rope":
-             if positional_embedding_theta is None:
-                 raise ValueError(
-                     "If `positional_embedding_type` is rope, `positional_embedding_theta` must also be defined"
-                 )
-             if positional_embedding_max_pos is None:
-                 raise ValueError(
-                     "If `positional_embedding_type` is rope, `positional_embedding_max_pos` must also be defined"
-                 )
-
-         # 3. Define transformers blocks
-         self.transformer_blocks = nn.ModuleList(
-             [
-                 BasicTransformerBlock(
-                     inner_dim,
-                     num_attention_heads,
-                     attention_head_dim,
-                     dropout=dropout,
-                     cross_attention_dim=cross_attention_dim,
-                     activation_fn=activation_fn,
-                     num_embeds_ada_norm=num_embeds_ada_norm,
-                     attention_bias=attention_bias,
-                     only_cross_attention=only_cross_attention,
-                     double_self_attention=double_self_attention,
-                     upcast_attention=upcast_attention,
-                     adaptive_norm=adaptive_norm,
-                     standardization_norm=standardization_norm,
-                     norm_elementwise_affine=norm_elementwise_affine,
-                     norm_eps=norm_eps,
-                     attention_type=attention_type,
-                     use_tpu_flash_attention=use_tpu_flash_attention,
-                     qk_norm=qk_norm,
-                     use_rope=self.use_rope,
-                 )
-                 for d in range(num_layers)
-             ]
-         )
-
-         # 4. Define output layers
-         self.out_channels = in_channels if out_channels is None else out_channels
-         self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-         self.scale_shift_table = nn.Parameter(
-             torch.randn(2, inner_dim) / inner_dim**0.5
-         )
-         self.proj_out = nn.Linear(inner_dim, self.out_channels)
-
-         self.adaln_single = AdaLayerNormSingle(
-             inner_dim, use_additional_conditions=False
-         )
-         if adaptive_norm == "single_scale":
-             self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
-
-         self.caption_projection = None
-         if caption_channels is not None:
-             self.caption_projection = PixArtAlphaTextProjection(
-                 in_features=caption_channels, hidden_size=inner_dim
-             )
-
-         self.gradient_checkpointing = False
-
-     def set_use_tpu_flash_attention(self):
-         r"""
-         Sets the flag on this object and propagates it down to the children. The flag enforces the use of the TPU
-         attention kernel.
-         """
-         logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
-         self.use_tpu_flash_attention = True
-         # push config down to the attention modules
-         for block in self.transformer_blocks:
-             block.set_use_tpu_flash_attention()
-
-     def create_skip_layer_mask(
-         self,
-         batch_size: int,
-         num_conds: int,
-         ptb_index: int,
-         skip_block_list: Optional[List[int]] = None,
-     ):
-         if skip_block_list is None or len(skip_block_list) == 0:
-             return None
-         num_layers = len(self.transformer_blocks)
-         mask = torch.ones(
-             (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype
-         )
-         for block_idx in skip_block_list:
-             mask[block_idx, ptb_index::num_conds] = 0
-         return mask
-
-     def _set_gradient_checkpointing(self, module, value=False):
-         if hasattr(module, "gradient_checkpointing"):
-             module.gradient_checkpointing = value
-
-     def get_fractional_positions(self, indices_grid):
-         fractional_positions = torch.stack(
-             [
-                 indices_grid[:, i] / self.positional_embedding_max_pos[i]
-                 for i in range(3)
-             ],
-             dim=-1,
-         )
-         return fractional_positions
-
-     def precompute_freqs_cis(self, indices_grid, spacing="exp"):
-         dtype = torch.float32  # We need full precision in the freqs_cis computation.
-         dim = self.inner_dim
-         theta = self.positional_embedding_theta
-
-         fractional_positions = self.get_fractional_positions(indices_grid)
-
-         start = 1
-         end = theta
-         device = fractional_positions.device
-         if spacing == "exp":
-             indices = theta ** (
-                 torch.linspace(
-                     math.log(start, theta),
-                     math.log(end, theta),
-                     dim // 6,
-                     device=device,
-                     dtype=dtype,
-                 )
-             )
-             indices = indices.to(dtype=dtype)
-         elif spacing == "exp_2":
-             indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim)
-             indices = indices.to(dtype=dtype)
-         elif spacing == "linear":
-             indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
-         elif spacing == "sqrt":
-             indices = torch.linspace(
-                 start**2, end**2, dim // 6, device=device, dtype=dtype
-             ).sqrt()
-
-         indices = indices * math.pi / 2
-
-         if spacing == "exp_2":
-             freqs = (
-                 (indices * fractional_positions.unsqueeze(-1))
-                 .transpose(-1, -2)
-                 .flatten(2)
-             )
-         else:
-             freqs = (
-                 (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
-                 .transpose(-1, -2)
-                 .flatten(2)
-             )
-
-         cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
-         sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
-         if dim % 6 != 0:
-             cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
-             sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
-             cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
-             sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
-         return cos_freq.to(self.dtype), sin_freq.to(self.dtype)
-
-     def load_state_dict(
-         self,
-         state_dict: Dict,
-         *args,
-         **kwargs,
-     ):
-         if any(key.startswith("model.diffusion_model.") for key in state_dict.keys()):
-             state_dict = {
-                 key.replace("model.diffusion_model.", ""): value
-                 for key, value in state_dict.items()
-                 if key.startswith("model.diffusion_model.")
-             }
-         super().load_state_dict(state_dict, *args, **kwargs)
-
-     @classmethod
-     def from_pretrained(
-         cls,
-         pretrained_model_path: Optional[Union[str, os.PathLike]],
-         *args,
-         **kwargs,
-     ):
-         pretrained_model_path = Path(pretrained_model_path)
-         if pretrained_model_path.is_dir():
-             config_path = pretrained_model_path / "transformer" / "config.json"
-             with open(config_path, "r") as f:
-                 config = make_hashable_key(json.load(f))
-
-             assert config in diffusers_and_ours_config_mapping, (
-                 "Provided diffusers checkpoint config for transformer is not supported. "
-                 "We only support diffusers configs found in Lightricks/LTX-Video."
-             )
-
-             config = diffusers_and_ours_config_mapping[config]
-             state_dict = {}
-             ckpt_paths = (
-                 pretrained_model_path
-                 / "transformer"
-                 / "diffusion_pytorch_model*.safetensors"
-             )
-             dict_list = glob.glob(str(ckpt_paths))
-             for dict_path in dict_list:
-                 part_dict = {}
-                 with safe_open(dict_path, framework="pt", device="cpu") as f:
-                     for k in f.keys():
-                         part_dict[k] = f.get_tensor(k)
-                 state_dict.update(part_dict)
-
-             for key in list(state_dict.keys()):
-                 new_key = key
-                 for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
-                     new_key = new_key.replace(replace_key, rename_key)
-                 state_dict[new_key] = state_dict.pop(key)
-
-             with torch.device("meta"):
-                 transformer = cls.from_config(config)
-             transformer.load_state_dict(state_dict, assign=True, strict=True)
-         elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
-             ".safetensors"
-         ):
-             comfy_single_file_state_dict = {}
-             with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
-                 metadata = f.metadata()
-                 for k in f.keys():
-                     comfy_single_file_state_dict[k] = f.get_tensor(k)
-             configs = json.loads(metadata["config"])
-             transformer_config = configs["transformer"]
-             with torch.device("meta"):
-                 transformer = Transformer3DModel.from_config(transformer_config)
-             transformer.load_state_dict(comfy_single_file_state_dict, assign=True)
-         return transformer
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         indices_grid: torch.Tensor,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         timestep: Optional[torch.LongTensor] = None,
-         class_labels: Optional[torch.LongTensor] = None,
-         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         encoder_attention_mask: Optional[torch.Tensor] = None,
-         skip_layer_mask: Optional[torch.Tensor] = None,
-         skip_layer_strategy: Optional[SkipLayerStrategy] = None,
-         return_dict: bool = True,
-     ):
-         if not self.use_tpu_flash_attention:
-             if attention_mask is not None and attention_mask.ndim == 2:
-                 attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
-                 attention_mask = attention_mask.unsqueeze(1)
-
-             if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
-                 encoder_attention_mask = (
-                     1 - encoder_attention_mask.to(hidden_states.dtype)
-                 ) * -10000.0
-                 encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
-
-         # 1. Input
-         hidden_states = self.patchify_proj(hidden_states)
-
-         if self.timestep_scale_multiplier:
-             timestep = self.timestep_scale_multiplier * timestep
-
-         freqs_cis = self.precompute_freqs_cis(indices_grid)
-
-         batch_size = hidden_states.shape[0]
-         timestep, embedded_timestep = self.adaln_single(
-             timestep.flatten(),
-             {"resolution": None, "aspect_ratio": None},
-             batch_size=batch_size,
-             hidden_dtype=hidden_states.dtype,
-         )
-         timestep = timestep.view(batch_size, -1, timestep.shape[-1])
-         embedded_timestep = embedded_timestep.view(
-             batch_size, -1, embedded_timestep.shape[-1]
-         )
-
-         if self.caption_projection is not None:
-             batch_size = hidden_states.shape[0]
-             encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-             encoder_hidden_states = encoder_hidden_states.view(
-                 batch_size, -1, hidden_states.shape[-1]
-             )
-
-         # TeaCache Integration
-         if hasattr(self, 'enable_teacache') and self.enable_teacache:
-             ori_hidden_states = hidden_states.clone()
-             temb_ = embedded_timestep.clone()
-             inp = self.transformer_blocks[0].norm1(hidden_states.clone())
-
-             first_block = self.transformer_blocks[0]
-             modulated_inp = inp
-             if first_block.adaptive_norm in ["single_scale_shift", "single_scale"]:
-                 num_ada_params = first_block.scale_shift_table.shape[0]
-                 ada_values = first_block.scale_shift_table[None, None] + temb_.reshape(
-                     batch_size, temb_.shape[1], num_ada_params, -1
-                 )
-                 if first_block.adaptive_norm == "single_scale_shift":
-                     shift_msa, scale_msa, _, _, _, _ = ada_values.unbind(dim=2)
-                     modulated_inp = inp * (1 + scale_msa) + shift_msa
-                 else:
-                     scale_msa, _, _, _ = ada_values.unbind(dim=2)
-                     modulated_inp = inp * (1 + scale_msa)
-
-             should_calc = False
-             if self.cnt == 0 or self.cnt == self.num_steps - 1 or self.previous_modulated_input is None:
-                 should_calc = True
-                 self.accumulated_rel_l1_distance = 0
-             else:
-                 coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
-                 rescale_func = np.poly1d(coefficients)
-                 rel_l1_dist = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
-                 self.accumulated_rel_l1_distance += rescale_func(rel_l1_dist)
-
-                 if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
-                     should_calc = False
-                 else:
-                     should_calc = True
-                     self.accumulated_rel_l1_distance = 0
-
-             self.previous_modulated_input = modulated_inp
-             self.cnt += 1
-             if self.cnt == self.num_steps:
-                 self.cnt = 0
-
-             if not should_calc and self.previous_residual is not None:
-                 hidden_states = ori_hidden_states + self.previous_residual
-             else:
-                 # Execute original logic if cache is missed
-                 temp_hidden_states = hidden_states
-                 for block_idx, block in enumerate(self.transformer_blocks):
-                     temp_hidden_states = block(
-                         temp_hidden_states,
-                         freqs_cis=freqs_cis,
-                         attention_mask=attention_mask,
-                         encoder_hidden_states=encoder_hidden_states,
-                         encoder_attention_mask=encoder_attention_mask,
-                         timestep=timestep,
-                         cross_attention_kwargs=cross_attention_kwargs,
-                         class_labels=class_labels,
-                         skip_layer_mask=(skip_layer_mask[block_idx] if skip_layer_mask is not None else None),
-                         skip_layer_strategy=skip_layer_strategy,
-                     )
-                 self.previous_residual = temp_hidden_states - ori_hidden_states
-                 hidden_states = temp_hidden_states
-         else:
-             # Original path if TeaCache is disabled
-             for block_idx, block in enumerate(self.transformer_blocks):
-                 hidden_states = block(
-                     hidden_states,
-                     freqs_cis=freqs_cis,
-                     attention_mask=attention_mask,
-                     encoder_hidden_states=encoder_hidden_states,
-                     encoder_attention_mask=encoder_attention_mask,
-                     timestep=timestep,
-                     cross_attention_kwargs=cross_attention_kwargs,
-                     class_labels=class_labels,
-                     skip_layer_mask=(skip_layer_mask[block_idx] if skip_layer_mask is not None else None),
-                     skip_layer_strategy=skip_layer_strategy,
-                 )
-
-         # Final modulation and output
-         scale_shift_values = (self.scale_shift_table[None, None] + embedded_timestep[:, :, None])
-         shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
-         hidden_states = self.norm_out(hidden_states)
-         hidden_states = hidden_states * (1 + scale) + shift
-         hidden_states = self.proj_out(hidden_states)
-
-         if not return_dict:
-             return (hidden_states,)
-
-         return Transformer3DModelOutput(sample=hidden_states)
- #-- END OF MODIFIED FILE app_fluxContext_Ltx/ltx_video/models/transformers/transformer3d.py ---
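
The TeaCache branch in the deleted forward() reads enable_teacache, cnt, num_steps, rel_l1_thresh, previous_modulated_input, and previous_residual from the module without initializing them, so the caller was expected to attach that state before sampling. A minimal sketch of that setup, assuming a local checkpoint path and illustrative step/threshold values (only the attribute names come from the file above):

# Hypothetical TeaCache setup for the deleted Transformer3DModel.forward();
# attribute names are taken from the code above, all values are illustrative.
transformer = Transformer3DModel.from_pretrained(
    "/path/to/ltx-video-checkpoint"  # diffusers-style directory or single .safetensors file
)
transformer.enable_teacache = True
transformer.num_steps = 30          # number of denoising steps in the sampling loop (assumed)
transformer.rel_l1_thresh = 0.05    # higher -> more cached reuse, lower fidelity (assumed)
transformer.cnt = 0
transformer.previous_modulated_input = None
transformer.previous_residual = None
# Later forward() calls reuse the cached residual whenever the accumulated
# relative L1 distance of the modulated input stays below rel_l1_thresh.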