CARLEXsX committed
Commit 708d36e · verified · 1 Parent(s): 5faf292

Upload teacache_ltx.py

Files changed (1)
  1. teacache_ltx.py +204 -0
teacache_ltx.py ADDED
@@ -0,0 +1,204 @@
import torch
from diffusers import LTXPipeline
from diffusers.models.transformers import LTXVideoTransformer3DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils import export_to_video
from typing import Any, Dict, Optional, Tuple
import numpy as np

logger = logging.get_logger(__name__)


def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_attention_mask: torch.Tensor,
    num_frames: int,
    height: int,
    width: int,
    rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> torch.Tensor:
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

    image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    batch_size = hidden_states.size(0)
    hidden_states = self.proj_in(hidden_states)

    temb, embedded_timestep = self.time_embed(
        timestep.flatten(),
        batch_size=batch_size,
        hidden_dtype=hidden_states.dtype,
    )

    temb = temb.view(batch_size, -1, temb.size(-1))
    embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

    encoder_hidden_states = self.caption_projection(encoder_hidden_states)
    encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

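    # TeaCache step gate: measure how much the timestep-modulated input to the first
    # transformer block has drifted since the last fully computed step. The raw relative
    # L1 change is rescaled with a fitted polynomial and accumulated; a full pass is run
    # only when the accumulation crosses `rel_l1_thresh`, or on the first/last step.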
    if self.enable_teacache:
        inp = hidden_states.clone()
        temb_ = temb.clone()
        inp = self.transformer_blocks[0].norm1(inp)
        num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0]
        ada_values = self.transformer_blocks[0].scale_shift_table[None, None] + temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
        modulated_inp = inp * (1 + scale_msa) + shift_msa
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.cnt = 0

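    # Either reuse the residual cached from the last full pass (skip), or run all
    # transformer blocks and cache the new residual (post-norm output minus block input)
    # so that later steps can skip.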
    if self.enable_teacache:
        if not should_calc:
            hidden_states += self.previous_residual
        else:
            ori_hidden_states = hidden_states.clone()
            for block in self.transformer_blocks:
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        image_rotary_emb,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        encoder_attention_mask=encoder_attention_mask,
                    )

            scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
            shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

            hidden_states = self.norm_out(hidden_states)
            hidden_states = hidden_states * (1 + scale) + shift
            self.previous_residual = hidden_states - ori_hidden_states
    else:
        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    encoder_attention_mask=encoder_attention_mask,
                )

        scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states * (1 + scale) + shift

    output = self.proj_out(hidden_states)

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)


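# Monkey-patch LTX-Video's transformer with the TeaCache-aware forward; the cache
# state it reads (cnt, num_steps, rel_l1_thresh, ...) is attached to the class below.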
LTXVideoTransformer3DModel.forward = teacache_forward

prompt = "A clear, turquoise river flows through a rocky canyon, cascading over a small waterfall and forming a pool of water at the bottom. The river is the main focus of the scene, with its clear water reflecting the surrounding trees and rocks. The canyon walls are steep and rocky, with some vegetation growing on them. The trees are mostly pine trees, with their green needles contrasting with the brown and gray rocks. The overall tone of the scene is one of peace and tranquility."
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
seed = 42
num_inference_steps = 50
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)

# TeaCache
pipe.transformer.__class__.enable_teacache = True
pipe.transformer.__class__.cnt = 0
pipe.transformer.__class__.num_steps = num_inference_steps
pipe.transformer.__class__.rel_l1_thresh = 0.05  # 0.03 for 1.6x speedup, 0.05 for 2.1x speedup
pipe.transformer.__class__.accumulated_rel_l1_distance = 0
pipe.transformer.__class__.previous_modulated_input = None
pipe.transformer.__class__.previous_residual = None

pipe.to("cuda")
video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=768,
    height=512,
    num_frames=161,
    decode_timestep=0.03,
    decode_noise_scale=0.025,
    num_inference_steps=num_inference_steps,
    generator=torch.Generator("cuda").manual_seed(seed),
).frames[0]
export_to_video(video, "teacache_ltx_{}.mp4".format(prompt[:50]), fps=24)
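
# For an uncached baseline under identical settings, caching can simply be switched off
# before rerunning the pipeline (every step then executes all transformer blocks):
#   pipe.transformer.__class__.enable_teacache = False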