CARLEXsX committed on
Commit 5faf292 · verified · 1 Parent(s): 0f55e23

Delete teacache_helpers.py

Files changed (1):
  1. teacache_helpers.py +0 -153
teacache_helpers.py DELETED
@@ -1,153 +0,0 @@
- # teacache_helpers.py
-
- import torch
- import numpy as np
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
- from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
- from typing import Any, Dict, Optional, Tuple
-
- logger = logging.get_logger(__name__)
-
- def teacache_forward(
-     self,
-     hidden_states: torch.Tensor,
-     indices_grid: torch.Tensor,
-     encoder_hidden_states: Optional[torch.Tensor] = None,
-     timestep: Optional[torch.LongTensor] = None,
-     class_labels: Optional[torch.LongTensor] = None,
-     cross_attention_kwargs: Dict[str, Any] = None,
-     attention_mask: Optional[torch.Tensor] = None,
-     encoder_attention_mask: Optional[torch.Tensor] = None,
-     skip_layer_mask: Optional[torch.Tensor] = None,
-     skip_layer_strategy: Optional[Any] = None,  # Using Any for compatibility
-     return_dict: bool = True,
- ) -> torch.Tensor:
-
-     # TeaCache control logic
-     if not hasattr(self, 'enable_teacache') or not self.enable_teacache:
-         # If TeaCache is disabled, run the original forward logic.
-         # (For simplicity the default logic is replicated here. In a real scenario, you could
-         # have saved the original function before patching.)
-         # This part replicates the logic of 'ltx_video/models/transformers/transformer3d.py'
-         if attention_mask is not None and attention_mask.ndim == 2:
-             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
-             attention_mask = attention_mask.unsqueeze(1)
-         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
-             encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
-             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
-
-         hidden_states = self.patchify_proj(hidden_states)
-         if self.timestep_scale_multiplier:
-             timestep = self.timestep_scale_multiplier * timestep
-
-         freqs_cis = self.precompute_freqs_cis(indices_grid)
-         batch_size = hidden_states.shape[0]
-         timestep, embedded_timestep = self.adaln_single(
-             timestep.flatten(), {"resolution": None, "aspect_ratio": None},
-             batch_size=batch_size, hidden_dtype=hidden_states.dtype,
-         )
-         timestep = timestep.view(batch_size, -1, timestep.shape[-1])
-         embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
-
-         if self.caption_projection is not None:
-             encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-             encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
-         for block_idx, block in enumerate(self.transformer_blocks):
-             hidden_states = block(
-                 hidden_states, freqs_cis=freqs_cis, attention_mask=attention_mask,
-                 encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask,
-                 timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels,
-                 skip_layer_mask=(skip_layer_mask[block_idx] if skip_layer_mask is not None else None),
-                 skip_layer_strategy=skip_layer_strategy,
-             )
-
-         scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
-         shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
-         hidden_states = self.norm_out(hidden_states)
-         hidden_states = hidden_states * (1 + scale) + shift
-         hidden_states = self.proj_out(hidden_states)
-
-         if not return_dict: return (hidden_states,)
-         return Transformer2DModelOutput(sample=hidden_states)
-
-     # Main TeaCache logic
-     lora_scale = 1.0
-
-     # Prepare embeddings and masks
-     image_rotary_emb = self.precompute_freqs_cis(indices_grid)
-     if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
-         encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
-         encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
-
-     batch_size = hidden_states.size(0)
-     hidden_states = self.patchify_proj(hidden_states)
-
-     temb, embedded_timestep = self.adaln_single(
-         timestep.flatten(), {"resolution": None, "aspect_ratio": None},
-         batch_size=batch_size, hidden_dtype=hidden_states.dtype,
-     )
-     temb = temb.view(batch_size, -1, temb.size(-1))
-     embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
-
-     if self.caption_projection is not None:
-         encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-         encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
-
-     # TeaCache decision logic (recompute the blocks or reuse the cached residual)
-     inp = hidden_states.clone()
-     temb_ = temb.clone()
-     inp = self.transformer_blocks[0].norm1(inp)
-     num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0]
-     ada_values = self.transformer_blocks[0].scale_shift_table[None, None] + temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1)
-
-     if self.transformer_blocks[0].adaptive_norm == "single_scale_shift":
-         shift_msa, scale_msa, _, _, _, _ = ada_values.unbind(dim=2)
-         modulated_inp = inp * (1 + scale_msa) + shift_msa
-     else:  # single_scale
-         scale_msa, _, _, _ = ada_values.unbind(dim=2)
-         modulated_inp = inp * (1 + scale_msa)
-
-     if self.cnt == 0 or self.cnt == self.num_steps - 1:
-         should_calc = True
-         self.accumulated_rel_l1_distance = 0
-     else:
-         coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
-         rescale_func = np.poly1d(coefficients)
-         self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
-         if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
-             should_calc = False
-         else:
-             should_calc = True
-             self.accumulated_rel_l1_distance = 0
-
-     self.previous_modulated_input = modulated_inp
-     self.cnt += 1
-     if self.cnt == self.num_steps:
-         self.cnt = 0
-
-     # Run the transformer (or skip it and reuse the cached residual)
-     if not should_calc:
-         hidden_states += self.previous_residual
-     else:
-         ori_hidden_states = hidden_states.clone()
-         for block_idx, block in enumerate(self.transformer_blocks):
-             hidden_states = block(
-                 hidden_states=hidden_states,
-                 freqs_cis=image_rotary_emb,
-                 encoder_hidden_states=encoder_hidden_states,
-                 timestep=temb,
-                 encoder_attention_mask=encoder_attention_mask
-             )
-
-         scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
-         shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
-         hidden_states = self.norm_out(hidden_states)
-         hidden_states = hidden_states * (1 + scale) + shift
-         self.previous_residual = hidden_states - ori_hidden_states
-
-     output = self.proj_out(hidden_states)
-
-     if not return_dict:
-         return (output,)
-     return Transformer2DModelOutput(sample=output)
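
Note: teacache_forward keeps its state on the transformer instance itself (enable_teacache, cnt, num_steps, rel_l1_thresh, accumulated_rel_l1_distance, previous_modulated_input, previous_residual), so it only works after those attributes are set and the method is bound to the model. The sketch below shows one way such a patch is typically wired up before sampling; the enable_teacache helper name, the default threshold value, and the types.MethodType binding are illustrative assumptions rather than code from this repository, and the import no longer resolves after this commit.

# Sketch only, assuming a transformer whose forward signature matches the deleted helper.
import types

from teacache_helpers import teacache_forward  # removed by this commit

def enable_teacache(transformer, num_inference_steps, rel_l1_thresh=0.05):
    # Bind the patched forward to this instance only.
    transformer.forward = types.MethodType(teacache_forward, transformer)
    # State read and updated by the TeaCache decision logic above.
    transformer.enable_teacache = True
    transformer.cnt = 0                           # current denoising step counter
    transformer.num_steps = num_inference_steps   # used to reset the counter after the last step
    transformer.rel_l1_thresh = rel_l1_thresh     # higher threshold -> more reuse, faster but less accurate
    transformer.accumulated_rel_l1_distance = 0.0
    transformer.previous_modulated_input = None
    transformer.previous_residual = None
    return transformer

Binding with types.MethodType keeps the patch local to a single model instance rather than replacing the class-level forward.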