README.md CHANGED
@@ -8,6 +8,7 @@ tags:
  - text-to-video
  - video-to-video
  - realtime
+ library_name: diffusers
  ---
  Krea Realtime 14B is distilled from the [Wan 2.1 14B text-to-video model](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) using Self-Forcing, a technique for converting regular video diffusion models into autoregressive models. It achieves a text-to-video inference speed of **11fps** using 4 inference steps on a single NVIDIA B200 GPU. For more details on our training methodology and sampling innovations, refer to our [technical blog post](https://www.krea.ai/blog/krea-realtime-14b).

@@ -97,5 +98,80 @@ Krea realtime allows users to generate videos in a streaming fashion with ~1s ti
  </table>
  </div>

+ # Use it with our inference code
+ 
+ Set up:
+ ```bash
+ sudo apt install ffmpeg  # install it if you haven't already
+ git clone https://github.com/krea-ai/realtime-video
+ cd realtime-video
+ uv sync
+ uv pip install flash_attn --no-build-isolation
+ huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B
+ huggingface-cli download krea/krea-realtime-video krea-realtime-video-14b.safetensors --local-dir-use-symlinks False --local-dir checkpoints/krea-realtime-video-14b.safetensors
+ ```
+ 
+ Run:
+ ```bash
+ export MODEL_FOLDER=Wan-AI
+ export CUDA_VISIBLE_DEVICES=0  # pick the GPU you want to serve on
+ export DO_COMPILE=true
+ 
+ uvicorn release_server:app --host 0.0.0.0 --port 8000
+ ```
+ 
+ Then open the web app at http://localhost:8000/ in your browser.
+ (For more advanced use cases and custom pipelines, check out our GitHub repository: https://github.com/krea-ai/realtime-video.)
+ 
+ # Use it with 🧨 diffusers
+ 
+ Krea Realtime 14B can be used with the `diffusers` library through the new Modular Diffusers structure (text-to-video is supported for now; video-to-video support is coming soon).
+ 
+ ```bash
+ # Install diffusers from main
+ pip install git+https://github.com/huggingface/diffusers.git
+ ```
+ 
+ ```py
+ import torch
+ from collections import deque
+ from diffusers.utils import export_to_video
+ from diffusers import ModularPipelineBlocks
+ from diffusers.modular_pipelines import PipelineState, WanModularPipeline
+ 
+ repo_id = "krea/krea-realtime-video"
+ blocks = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
+ pipe = WanModularPipeline(blocks, repo_id)
+ 
+ pipe.load_components(
+     trust_remote_code=True,
+     device_map="cuda",
+     torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
+ )
+ 
+ num_frames_per_block = 3
+ num_blocks = 9
+ 
+ frames = []
+ state = PipelineState()
+ state.set("frame_cache_context", deque(maxlen=pipe.config.frame_cache_len))
+ 
+ prompt = ["a cat sitting on a boat"]
+ 
+ for block in pipe.transformer.blocks:
+     block.self_attn.fuse_projections()
+ 
+ for block_idx in range(num_blocks):
+     state = pipe(
+         state,
+         prompt=prompt,
+         num_inference_steps=6,
+         num_blocks=num_blocks,
+         num_frames_per_block=num_frames_per_block,
+         block_idx=block_idx,
+         generator=torch.Generator("cuda").manual_seed(42),
+     )
+     frames.extend(state.values["videos"][0])
+ 
+ export_to_video(frames, "output.mp4", fps=16)
+ ```
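+ 
+ The example above collects all frames and writes a single video at the end. If you want to preview output while later blocks are still being generated, a minimal variation (assuming the same pipeline objects as above) is to replace the generation loop with one that exports each block as soon as it is produced:
+ 
+ ```py
+ # Sketch of streaming-style output: write each block to its own clip as it arrives.
+ for block_idx in range(num_blocks):
+     state = pipe(
+         state,
+         prompt=prompt,
+         num_inference_steps=6,
+         num_blocks=num_blocks,
+         num_frames_per_block=num_frames_per_block,
+         block_idx=block_idx,
+         generator=torch.Generator("cuda").manual_seed(42),
+     )
+     block_frames = state.values["videos"][0]
+     # The file naming scheme here is illustrative only.
+     export_to_video(block_frames, f"output_block_{block_idx:03d}.mp4", fps=16)
+ ```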
__init__.py ADDED
File without changes
before_denoise.py ADDED
@@ -0,0 +1,956 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import List, Optional, Union, Dict
17
+
18
+ import torch
19
+
20
+ from diffusers import AutoencoderKLWan
21
+ from diffusers.schedulers import UniPCMultistepScheduler
22
+ from diffusers.utils import logging
23
+ from diffusers.utils.torch_utils import randn_tensor
24
+ from diffusers.modular_pipelines import (
25
+ ModularPipeline,
26
+ ModularPipelineBlocks,
27
+ SequentialPipelineBlocks,
28
+ PipelineState,
29
+ )
30
+ from diffusers.modular_pipelines.modular_pipeline_utils import (
31
+ ComponentSpec,
32
+ ConfigSpec,
33
+ InputParam,
34
+ OutputParam,
35
+ )
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
41
+ def retrieve_timesteps(
42
+ scheduler,
43
+ num_inference_steps: Optional[int] = None,
44
+ device: Optional[Union[str, torch.device]] = None,
45
+ timesteps: Optional[List[int]] = None,
46
+ sigmas: Optional[List[float]] = None,
47
+ **kwargs,
48
+ ):
49
+ r"""
50
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
51
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
52
+
53
+ Args:
54
+ scheduler (`SchedulerMixin`):
55
+ The scheduler to get timesteps from.
56
+ num_inference_steps (`int`):
57
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
58
+ must be `None`.
59
+ device (`str` or `torch.device`, *optional*):
60
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
61
+ timesteps (`List[int]`, *optional*):
62
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
63
+ `num_inference_steps` and `sigmas` must be `None`.
64
+ sigmas (`List[float]`, *optional*):
65
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
66
+ `num_inference_steps` and `timesteps` must be `None`.
67
+
68
+ Returns:
69
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
70
+ second element is the number of inference steps.
71
+ """
72
+ if timesteps is not None and sigmas is not None:
73
+ raise ValueError(
74
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
75
+ )
76
+ if timesteps is not None:
77
+ accepts_timesteps = "timesteps" in set(
78
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
79
+ )
80
+ if not accepts_timesteps:
81
+ raise ValueError(
82
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
83
+ f" timestep schedules. Please check whether you are using the correct scheduler."
84
+ )
85
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
86
+ timesteps = scheduler.timesteps
87
+ num_inference_steps = len(timesteps)
88
+ elif sigmas is not None:
89
+ accept_sigmas = "sigmas" in set(
90
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
91
+ )
92
+ if not accept_sigmas:
93
+ raise ValueError(
94
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
95
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
96
+ )
97
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
98
+ timesteps = scheduler.timesteps
99
+ num_inference_steps = len(timesteps)
100
+ else:
101
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
102
+ timesteps = scheduler.timesteps
103
+ return timesteps, num_inference_steps
104
+
105
+
106
+ def retrieve_latents(
107
+ encoder_output: torch.Tensor,
108
+ generator: Optional[torch.Generator] = None,
109
+ sample_mode: str = "sample",
110
+ ):
111
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
112
+ return encoder_output.latent_dist.sample(generator)
113
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
114
+ return encoder_output.latent_dist.mode()
115
+ elif hasattr(encoder_output, "latents"):
116
+ return encoder_output.latents
117
+ else:
118
+ raise AttributeError("Could not access latents of provided encoder_output")
119
+
120
+
121
+ def _initialize_kv_cache(
122
+ components: ModularPipeline,
123
+ kv_cache_existing: Optional[List[Dict]],
124
+ batch_size: int,
125
+ dtype: torch.dtype,
126
+ device: torch.device,
127
+ local_attn_size: int,
128
+ frame_seq_length: int,
129
+ ):
130
+ """
131
+ Initialize a Per-GPU KV cache for the Wan model.
132
+ Mirrors causal_inference.py:279-313
133
+ """
134
+ kv_cache = []
135
+
136
+ # Calculate KV cache size
137
+ if local_attn_size != -1:
138
+ # Use the local attention size to compute the KV cache size
139
+ kv_cache_size = local_attn_size * frame_seq_length
140
+ else:
141
+ # Use the default KV cache size
142
+ kv_cache_size = 32760
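+ # (Note, as an aid to the reader: 32760 tokens = 21 latent frames x the default frame_seq_length of 1560.)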
143
+
144
+ # Get transformer config
145
+ num_transformer_blocks = len(components.transformer.blocks)
146
+ num_heads = components.transformer.config.num_heads
147
+ dim = components.transformer.config.dim
148
+ k_shape = [batch_size, kv_cache_size, num_heads, dim // num_heads]
149
+ v_shape = [batch_size, kv_cache_size, num_heads, dim // num_heads]
150
+
151
+ # Check if we can reuse existing cache
152
+ if (
153
+ kv_cache_existing
154
+ and len(kv_cache_existing) > 0
155
+ and list(kv_cache_existing[0]["k"].shape) == k_shape
156
+ and list(kv_cache_existing[0]["v"].shape) == v_shape
157
+ ):
158
+ for i in range(num_transformer_blocks):
159
+ kv_cache_existing[i]["k"].zero_()
160
+ kv_cache_existing[i]["v"].zero_()
161
+ kv_cache_existing[i]["global_end_index"] = 0
162
+ kv_cache_existing[i]["local_end_index"] = 0
163
+ return kv_cache_existing
164
+ else:
165
+ # Create new cache
166
+ for _ in range(num_transformer_blocks):
167
+ kv_cache.append(
168
+ {
169
+ "k": torch.zeros(k_shape, dtype=dtype, device=device).contiguous(),
170
+ "v": torch.zeros(v_shape, dtype=dtype, device=device).contiguous(),
171
+ "global_end_index": 0,
172
+ "local_end_index": 0,
173
+ }
174
+ )
175
+ return kv_cache
176
+
177
+
178
+ def _initialize_crossattn_cache(
179
+ components: ModularPipeline,
180
+ crossattn_cache_existing: Optional[List[Dict]],
181
+ batch_size: int,
182
+ dtype: torch.dtype,
183
+ device: torch.device,
184
+ ):
185
+ """
186
+ Initialize a Per-GPU cross-attention cache for the Wan model.
187
+ Mirrors causal_inference.py:315-338
188
+ """
189
+ crossattn_cache = []
190
+
191
+ # Get transformer config
192
+ num_transformer_blocks = len(components.transformer.blocks)
193
+ num_heads = components.transformer.config.num_heads
194
+ dim = components.transformer.config.dim
195
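+ # The sequence length of 512 matches the text encoder's max_sequence_length, so the prompt's
+ # cross-attention keys/values can be computed once and reused for every block.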
+ k_shape = [batch_size, 512, num_heads, dim // num_heads]
196
+ v_shape = [batch_size, 512, num_heads, dim // num_heads]
197
+
198
+ # Check if we can reuse existing cache
199
+ if (
200
+ crossattn_cache_existing
201
+ and len(crossattn_cache_existing) > 0
202
+ and list(crossattn_cache_existing[0]["k"].shape) == k_shape
203
+ and list(crossattn_cache_existing[0]["v"].shape) == v_shape
204
+ ):
205
+ for i in range(num_transformer_blocks):
206
+ crossattn_cache_existing[i]["k"].zero_()
207
+ crossattn_cache_existing[i]["v"].zero_()
208
+ crossattn_cache_existing[i]["is_init"] = False
209
+ return crossattn_cache_existing
210
+ else:
211
+ # Create new cache
212
+ for _ in range(num_transformer_blocks):
213
+ crossattn_cache.append(
214
+ {
215
+ "k": torch.zeros(k_shape, dtype=dtype, device=device).contiguous(),
216
+ "v": torch.zeros(v_shape, dtype=dtype, device=device).contiguous(),
217
+ "is_init": False,
218
+ }
219
+ )
220
+ return crossattn_cache
221
+
222
+
223
+ class WanInputStep(ModularPipelineBlocks):
224
+ model_name = "WanRT"
225
+
226
+ @property
227
+ def description(self) -> str:
228
+ return (
229
+ "Input processing step that:\n"
230
+ " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
231
+ " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n"
232
+ "All input tensors are expected to have either batch_size=1 or match the batch_size\n"
233
+ "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
234
+ "have a final batch_size of batch_size * num_videos_per_prompt."
235
+ )
236
+
237
+ @property
238
+ def inputs(self) -> List[InputParam]:
239
+ return [
240
+ InputParam("num_videos_per_prompt", default=1),
241
+ InputParam(
242
+ "prompt_embeds",
243
+ required=True,
244
+ type_hint=torch.Tensor,
245
+ description="Pre-generated text embeddings. Can be generated from text_encoder step.",
246
+ ),
247
+ InputParam(
248
+ "negative_prompt_embeds",
249
+ type_hint=torch.Tensor,
250
+ description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
251
+ ),
252
+ ]
253
+
254
+ @property
255
+ def intermediate_outputs(self) -> List[str]:
256
+ return [
257
+ OutputParam(
258
+ "batch_size",
259
+ type_hint=int,
260
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt",
261
+ ),
262
+ OutputParam(
263
+ "dtype",
264
+ type_hint=torch.dtype,
265
+ description="Data type of model tensor inputs (determined by `prompt_embeds`)",
266
+ ),
267
+ OutputParam(
268
+ "prompt_embeds",
269
+ type_hint=torch.Tensor,
270
+ kwargs_type="denoiser_input_fields",  # already in intermediates state, but declared here again for denoiser_input_fields
271
+ description="text embeddings used to guide the image generation",
272
+ ),
273
+ OutputParam(
274
+ "negative_prompt_embeds",
275
+ type_hint=torch.Tensor,
276
+ kwargs_type="denoiser_input_fields",  # already in intermediates state, but declared here again for denoiser_input_fields
277
+ description="negative text embeddings used to guide the image generation",
278
+ ),
279
+ ]
280
+
281
+ def check_inputs(self, components, block_state):
282
+ if (
283
+ block_state.prompt_embeds is not None
284
+ and block_state.negative_prompt_embeds is not None
285
+ ):
286
+ if (
287
+ block_state.prompt_embeds.shape
288
+ != block_state.negative_prompt_embeds.shape
289
+ ):
290
+ raise ValueError(
291
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
292
+ f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
293
+ f" {block_state.negative_prompt_embeds.shape}."
294
+ )
295
+
296
+ @torch.no_grad()
297
+ def __call__(
298
+ self, components: ModularPipeline, state: PipelineState
299
+ ) -> PipelineState:
300
+ block_state = self.get_block_state(state)
301
+ self.check_inputs(components, block_state)
302
+
303
+ block_state.batch_size = block_state.prompt_embeds.shape[0]
304
+ block_state.dtype = block_state.prompt_embeds.dtype
305
+
306
+ _, seq_len, _ = block_state.prompt_embeds.shape
307
+ block_state.prompt_embeds = block_state.prompt_embeds.repeat(
308
+ 1, block_state.num_videos_per_prompt, 1
309
+ )
310
+ block_state.prompt_embeds = block_state.prompt_embeds.view(
311
+ block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
312
+ )
313
+
314
+ if block_state.negative_prompt_embeds is not None:
315
+ _, seq_len, _ = block_state.negative_prompt_embeds.shape
316
+ block_state.negative_prompt_embeds = (
317
+ block_state.negative_prompt_embeds.repeat(
318
+ 1, block_state.num_videos_per_prompt, 1
319
+ )
320
+ )
321
+ block_state.negative_prompt_embeds = (
322
+ block_state.negative_prompt_embeds.view(
323
+ block_state.batch_size * block_state.num_videos_per_prompt,
324
+ seq_len,
325
+ -1,
326
+ )
327
+ )
328
+
329
+ self.set_block_state(state, block_state)
330
+
331
+ return components, state
332
+
333
+
334
+ class WanRTStreamingSetTimestepsStep(ModularPipelineBlocks):
335
+ model_name = "WanRT"
336
+
337
+ @property
338
+ def expected_components(self) -> List[ComponentSpec]:
339
+ return [
340
+ ComponentSpec("scheduler", UniPCMultistepScheduler),
341
+ ]
342
+
343
+ @property
344
+ def description(self) -> str:
345
+ return "Step that sets the scheduler's timesteps for inference"
346
+
347
+ @property
348
+ def inputs(self) -> List[InputParam]:
349
+ return [
350
+ InputParam("num_inference_steps", default=4),
351
+ InputParam("timesteps"),
352
+ InputParam("sigmas"),
353
+ ]
354
+
355
+ @property
356
+ def intermediate_outputs(self) -> List[OutputParam]:
357
+ return [
358
+ OutputParam(
359
+ "timesteps",
360
+ type_hint=torch.Tensor,
361
+ description="The timesteps to use for inference",
362
+ ),
363
+ OutputParam(
364
+ "all_timesteps",
365
+ type_hint=torch.Tensor,
366
+ description="The timesteps to use for inference",
367
+ ),
368
+ OutputParam(
369
+ "num_inference_steps",
370
+ type_hint=int,
371
+ description="The number of denoising steps to perform at inference time",
372
+ ),
373
+ ]
374
+
375
+ @torch.no_grad()
376
+ def __call__(
377
+ self, components: ModularPipeline, state: PipelineState
378
+ ) -> PipelineState:
379
+ block_state = self.get_block_state(state)
380
+ block_state.device = components._execution_device
381
+
382
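+ # Build a shifted flow-matching sigma schedule: sigma' = shift * sigma / (1 + (shift - 1) * sigma).
+ # With shift = 5.0 this biases the 1000-entry grid toward higher noise levels.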
+ shift = 5.0
383
+ sigmas = torch.linspace(1.0, 0.0, 1001)[:-1]
384
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
385
+
386
+ timesteps = sigmas.to(components.transformer.device) * 1000.0
387
+ zero_padded_timesteps = torch.cat(
388
+ [
389
+ timesteps,
390
+ torch.tensor([0], device=components.transformer.device),
391
+ ]
392
+ )
393
+ denoising_steps = torch.linspace(
394
+ 1000, 0, block_state.num_inference_steps, dtype=torch.float32
395
+ ).to(torch.long)
396
+
397
+ block_state.timesteps = zero_padded_timesteps[1000 - denoising_steps]
398
+ block_state.all_timesteps = timesteps
399
+ block_state.sigmas = sigmas
400
+
401
+ self.set_block_state(state, block_state)
402
+
403
+ return components, state
404
+
405
+
406
+ class WanRTStreamingPrepareLatentsStep(ModularPipelineBlocks):
407
+ model_name = "WanRT"
408
+
409
+ @property
410
+ def expected_components(self) -> List[ComponentSpec]:
411
+ return [
412
+ ComponentSpec("vae", AutoencoderKLWan),
413
+ ]
414
+
415
+ @property
416
+ def expected_configs(self) -> List[ConfigSpec]:
417
+ return [ConfigSpec("num_frames_per_block", 3)]
418
+
419
+ @property
420
+ def description(self) -> str:
421
+ return "Prepare latents step that prepares the latents for the text-to-video generation process"
422
+
423
+ @property
424
+ def inputs(self) -> List[InputParam]:
425
+ return [
426
+ InputParam("height", type_hint=int),
427
+ InputParam("width", type_hint=int),
428
+ InputParam("num_blocks", type_hint=int),
429
+ InputParam("num_frames_per_block", type_hint=int),
430
+ InputParam("latents", type_hint=Optional[torch.Tensor]),
431
+ InputParam("init_latents", type_hint=Optional[torch.Tensor]),
432
+ InputParam("final_latents", type_hint=Optional[torch.Tensor]),
433
+ InputParam("num_videos_per_prompt", type_hint=int, default=1),
434
+ InputParam("generator"),
435
+ InputParam(
436
+ "dtype",
437
+ type_hint=torch.dtype,
438
+ description="The dtype of the model inputs",
439
+ ),
440
+ ]
441
+
442
+ @property
443
+ def intermediate_outputs(self) -> List[OutputParam]:
444
+ return [
445
+ OutputParam(
446
+ "latents",
447
+ type_hint=torch.Tensor,
448
+ description="The initial latents to use for the denoising process",
449
+ ),
450
+ OutputParam(
451
+ "init_latents",
452
+ type_hint=torch.Tensor,
453
+ description="The initial latents to use for the denoising process",
454
+ ),
455
+ OutputParam(
456
+ "final_latents",
457
+ type_hint=torch.Tensor,
458
+ ),
459
+ ]
460
+
461
+ @staticmethod
462
+ def check_inputs(components, block_state):
463
+ if (
464
+ block_state.height is not None
465
+ and block_state.height % components.vae_scale_factor_spatial != 0
466
+ ) or (
467
+ block_state.width is not None
468
+ and block_state.width % components.vae_scale_factor_spatial != 0
469
+ ):
470
+ raise ValueError(
471
+ f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
472
+ )
473
+
474
+ @staticmethod
475
+ def prepare_latents(
476
+ components,
477
+ batch_size: int,
478
+ num_channels_latents: int = 16,
479
+ height: int = 352,
480
+ width: int = 640,
481
+ num_blocks: int = 9,
482
+ num_frames_per_block: int = 3,
483
+ dtype: Optional[torch.dtype] = None,
484
+ device: Optional[torch.device] = None,
485
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
486
+ latents: Optional[torch.Tensor] = None,
487
+ ) -> torch.Tensor:
488
+ if latents is not None:
489
+ return latents.to(device=device, dtype=dtype)
490
+
491
+ num_latent_frames = num_blocks * num_frames_per_block
492
+ shape = (
493
+ batch_size,
494
+ num_channels_latents,
495
+ num_latent_frames,
496
+ int(height) // components.vae_scale_factor_spatial,
497
+ int(width) // components.vae_scale_factor_spatial,
498
+ )
499
+ if isinstance(generator, list) and len(generator) != batch_size:
500
+ raise ValueError(
501
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
502
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
503
+ )
504
+
505
+ latents = randn_tensor(
506
+ shape,
507
+ generator=generator,
508
+ device=components.transformer.device,
509
+ dtype=dtype,
510
+ )
511
+ return latents
512
+
513
+ @torch.no_grad()
514
+ def __call__(
515
+ self, components: ModularPipeline, state: PipelineState
516
+ ) -> PipelineState:
517
+ block_state = self.get_block_state(state)
518
+
519
+ block_state.height = block_state.height or components.default_height
520
+ block_state.width = block_state.width or components.default_width
521
+ block_state.device = components._execution_device
522
+ block_state.num_channels_latents = components.num_channels_latents
523
+
524
+ self.check_inputs(components, block_state)
525
+
526
+ block_state.init_latents = self.prepare_latents(
527
+ components,
528
+ 1,
529
+ block_state.num_channels_latents,
530
+ block_state.height,
531
+ block_state.width,
532
+ block_state.num_blocks,
533
+ components.config.num_frames_per_block,
534
+ components.transformer.dtype,
535
+ block_state.device,
536
+ block_state.generator,
537
+ block_state.init_latents,
538
+ )
539
+ if block_state.final_latents is None:
540
+ block_state.final_latents = torch.zeros_like(
541
+ block_state.init_latents, device=components.transformer.device
542
+ )
543
+ self.set_block_state(state, block_state)
544
+
545
+ return components, state
546
+
547
+
548
+ class WanRTStreamingExtractBlockLatentsStep(ModularPipelineBlocks):
549
+ """
550
+ Extracts a single block of latents from the full video buffer for streaming generation.
551
+
552
+ This block simply slices the init_latents buffer to get the current block's latents.
+ The init_latents buffer should be created beforehand using WanRTStreamingPrepareLatentsStep.
554
+ """
555
+
556
+ model_name = "WanRT"
557
+
558
+ @property
559
+ def expected_components(self) -> List[ComponentSpec]:
560
+ return []
561
+
562
+ @property
563
+ def description(self) -> str:
564
+ return (
565
+ "Extracts a single block from the full latent buffer for streaming generation. "
566
+ "Slices final_latents based on block_idx to get current block's latents."
567
+ )
568
+
569
+ @property
570
+ def inputs(self) -> List[InputParam]:
571
+ return [
572
+ InputParam(
573
+ "final_latents",
574
+ required=True,
575
+ type_hint=torch.Tensor,
576
+ description="Full latent buffer [B, C, total_frames, H, W]",
577
+ ),
578
+ InputParam(
579
+ "init_latents",
580
+ required=True,
581
+ type_hint=torch.Tensor,
582
+ description="Full latent buffer [B, C, total_frames, H, W]",
583
+ ),
584
+ InputParam(
585
+ "latents",
586
+ type_hint=torch.Tensor,
587
+ description="Full latent buffer [B, C, total_frames, H, W]",
588
+ ),
589
+ InputParam(
590
+ "block_idx",
591
+ required=True,
592
+ type_hint=int,
593
+ default=0,
594
+ description="Current block index to process",
595
+ ),
596
+ InputParam(
597
+ "num_frames_per_block",
598
+ required=True,
599
+ type_hint=int,
600
+ default=3,
601
+ description="Number of frames per block",
602
+ ),
603
+ ]
604
+
605
+ @property
606
+ def intermediate_outputs(self) -> List[OutputParam]:
607
+ return [
608
+ OutputParam(
609
+ "latents",
610
+ type_hint=torch.Tensor,
611
+ description="Latents for current block [B, C, num_frames_per_block, H, W]",
612
+ ),
613
+ OutputParam(
614
+ "current_start_frame",
615
+ type_hint=int,
616
+ description="Starting frame index for current block",
617
+ ),
618
+ ]
619
+
620
+ @torch.no_grad()
621
+ def __call__(
622
+ self, components: ModularPipeline, state: PipelineState
623
+ ) -> PipelineState:
624
+ block_state = self.get_block_state(state)
625
+
626
+ num_frames_per_block = block_state.num_frames_per_block
627
+ block_idx = block_state.block_idx
628
+
629
+ # Calculate frame range for current block
630
+ start_frame = block_idx * num_frames_per_block
631
+ end_frame = start_frame + num_frames_per_block
632
+
633
+ # Extract single block from full latent buffer
634
+ # init_latents shape: [B, C, total_frames, H, W]
635
+ # Extract frames along the time dimension (dim=2)
636
+ block_state.latents = block_state.init_latents[
637
+ :, :, start_frame:end_frame, :, :
638
+ ]
639
+ block_state.current_start_frame = start_frame
640
+
641
+ self.set_block_state(state, block_state)
642
+ return components, state
643
+
644
+
645
+ class WanRTStreamingSetupKVCache(ModularPipelineBlocks):
646
+ """
647
+ Initializes KV cache and cross-attention cache for streaming generation.
648
+
649
+ This block sets up the persistent caches used across all blocks in streaming
650
+ generation. Mirrors the cache initialization logic from causal_inference.py.
651
+ Should be called once at the start of streaming generation.
652
+ """
653
+
654
+ model_name = "WanRT"
655
+
656
+ @property
657
+ def expected_components(self) -> List[ComponentSpec]:
658
+ return [
659
+ ComponentSpec("transformer", torch.nn.Module),
660
+ ]
661
+
662
+ @property
663
+ def expected_configs(self) -> List[ConfigSpec]:
664
+ return [
665
+ ConfigSpec("kv_cache_num_frames", 3),
666
+ ConfigSpec("num_frames_per_block", 3),
667
+ ConfigSpec("frame_seq_length", 1560),
668
+ ConfigSpec("frame_cache_len", 9),
669
+ ]
670
+
671
+ @property
672
+ def description(self) -> str:
673
+ return (
674
+ "Initializes KV cache and cross-attention cache for streaming generation. "
675
+ "Creates persistent caches that will be reused across all blocks."
676
+ )
677
+
678
+ @property
679
+ def inputs(self) -> List[InputParam]:
680
+ return [
681
+ InputParam(
682
+ "kv_cache",
683
+ required=False,
684
+ type_hint=Optional[List[Dict]],
685
+ description="Existing KV cache. If provided and shape matches, will be zeroed instead of recreated.",
686
+ ),
687
+ InputParam(
688
+ "crossattn_cache",
689
+ required=False,
690
+ type_hint=Optional[List[Dict]],
691
+ description="Existing cross-attention cache. If provided and shape matches, will be zeroed.",
692
+ ),
693
+ InputParam(
694
+ "local_attn_size",
695
+ required=False,
696
+ type_hint=int,
697
+ default=-1,
698
+ description="Local attention size for computing KV cache size. -1 uses default (32760).",
699
+ ),
700
+ InputParam(
701
+ "dtype",
702
+ required=False,
703
+ type_hint=torch.dtype,
704
+ description="Data type for caches (defaults to bfloat16)",
705
+ ),
706
+ InputParam(
707
+ "update_prompt_embeds",
708
+ required=False,
709
+ description="Flag to reinitialize prompt embeds if they are updated.",
710
+ default=False,
711
+ ),
712
+ ]
713
+
714
+ @property
715
+ def outputs(self) -> List[OutputParam]:
716
+ return [
717
+ OutputParam(
718
+ "kv_cache",
719
+ type_hint=List[Dict],
720
+ description="Initialized KV cache (list of dicts per transformer block)",
721
+ ),
722
+ OutputParam(
723
+ "crossattn_cache",
724
+ type_hint=List[Dict],
725
+ description="Initialized cross-attention cache",
726
+ ),
727
+ OutputParam(
728
+ "local_attn_size",
729
+ ),
730
+ ]
731
+
732
+ @torch.no_grad()
733
+ def __call__(
734
+ self, components: ModularPipeline, state: PipelineState
735
+ ) -> PipelineState:
736
+ block_state = self.get_block_state(state)
737
+ batch_size = 1 # Streaming always uses batch_size=1
738
+
739
+ # Get existing caches if they exist
740
+ kv_cache = block_state.kv_cache
741
+ crossattn_cache = block_state.crossattn_cache
742
+
743
+ if block_state.crossattn_cache is None or block_state.update_prompt_embeds:
744
+ block_state.crossattn_cache = _initialize_crossattn_cache(
745
+ components,
746
+ crossattn_cache,
747
+ batch_size,
748
+ components.transformer.dtype,
749
+ components.transformer.device,
750
+ )
751
+
752
+ block_state.local_attn_size = (
753
+ components.config.kv_cache_num_frames
754
+ + components.config.num_frames_per_block
755
+ )
756
+ for block in components.transformer.blocks:
757
+ block.self_attn.local_attn_size = -1
758
+ for block in components.transformer.blocks:
759
+ block.self_attn.num_frame_per_block = components.config.num_frames_per_block
760
+
761
+ block_state.kv_cache = _initialize_kv_cache(
762
+ components,
763
+ kv_cache,
764
+ batch_size,
765
+ components.transformer.dtype,
766
+ components.transformer.device,
767
+ block_state.local_attn_size,
768
+ components.config.frame_seq_length,
769
+ )
770
+
771
+ self.set_block_state(state, block_state)
772
+ return components, state
773
+
774
+
775
+ class WanRTStreamingRecomputeKVCache(ModularPipelineBlocks):
776
+ @property
777
+ def inputs(self) -> List[InputParam]:
778
+ return [
779
+ InputParam(
780
+ "latents",
781
+ type_hint=torch.Tensor,
782
+ description="Current block latents [B, C, num_frames_per_block, H, W]",
783
+ ),
784
+ InputParam(
785
+ "num_frames_per_block",
786
+ type_hint=int,
787
+ description="Number of frames per block",
788
+ ),
789
+ InputParam(
790
+ "block_idx",
791
+ type_hint=int,
792
+ description="Current block index to process",
793
+ ),
794
+ InputParam(
795
+ "block_mask",
796
+ description="Block-wise causal attention mask",
797
+ ),
798
+ InputParam(
799
+ "current_start_frame",
800
+ type_hint=int,
801
+ description="Starting frame index for current block",
802
+ ),
803
+ InputParam(
804
+ "videos",
805
+ type_hint=torch.Tensor,
806
+ description="Video frames for context encoding",
807
+ ),
808
+ InputParam(
809
+ "final_latents",
810
+ type_hint=torch.Tensor,
811
+ description="Full latent buffer [B, C, total_frames, H, W]",
812
+ ),
813
+ InputParam(
814
+ "prompt_embeds",
815
+ type_hint=torch.Tensor,
816
+ description="Text embeddings to guide generation",
817
+ ),
818
+ InputParam(
819
+ "kv_cache",
820
+ type_hint=torch.Tensor,
821
+ description="Key-value cache for attention",
822
+ ),
823
+ InputParam(
824
+ "crossattn_cache",
825
+ type_hint=torch.Tensor,
826
+ description="Cross-attention cache",
827
+ ),
828
+ InputParam(
829
+ "encoder_cache",
830
+ description="Encoder feature cache",
831
+ ),
832
+ InputParam(
833
+ "frame_cache_context",
834
+ description="Cached context frames for reencoding",
835
+ ),
836
+ InputParam(
837
+ "local_attn_size",
838
+ ),
839
+ ]
840
+
841
+ @property
842
+ def expected_configs(self) -> List[ConfigSpec]:
843
+ return [ConfigSpec("seq_length", 32760)]
844
+
845
+ def prepare_latents(self, components, block_state):
846
+ frames = block_state.frame_cache_context[0].half()
847
+
848
+ components.vae._enc_feat_map = [None] * 55
849
+ latents = retrieve_latents(components.vae.encode(frames), sample_mode="argmax")
850
+ latents_mean = (
851
+ torch.tensor(components.vae.config.latents_mean)
852
+ .view(1, components.vae.config.z_dim, 1, 1, 1)
853
+ .to(latents.device, latents.dtype)
854
+ )
855
+ latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
856
+ 1, components.vae.config.z_dim, 1, 1, 1
857
+ ).to(latents.device, latents.dtype)
858
+ latents = (latents - latents_mean) * latents_std
859
+
860
+ return latents.to(components.transformer.dtype)
861
+
862
+ def get_context_frames(self, components, block_state):
863
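+ # Build the context used to refill the KV cache: keep the first (anchor) latent frame plus the
+ # most recent kv_cache_num_frames - 1 frames; once the window has moved past the start, the
+ # anchor is re-encoded from the cached decoded frames instead.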
+ current_kv_cache_num_frames = components.config.kv_cache_num_frames
864
+ context_frames = block_state.final_latents[
865
+ :, :, : block_state.current_start_frame
866
+ ]
867
+
868
+ if (
869
+ block_state.block_idx - 1
870
+ ) * block_state.num_frames_per_block < current_kv_cache_num_frames:
871
+ if current_kv_cache_num_frames == 1:
872
+ context_frames = context_frames[:, :, :1]
873
+ else:
874
+ context_frames = torch.cat(
875
+ (
876
+ context_frames[:, :, :1],
877
+ context_frames[:, :, 1:][
878
+ :, :, -current_kv_cache_num_frames + 1 :
879
+ ],
880
+ ),
881
+ dim=2,
882
+ )
883
+ else:
884
+ context_frames = context_frames[:, :, 1:][
885
+ :, :, -current_kv_cache_num_frames + 1 :
886
+ ]
887
+ first_frame_latent = self.prepare_latents(components, block_state)
888
+ first_frame_latent = first_frame_latent.to(block_state.latents)
889
+ context_frames = torch.cat((first_frame_latent, context_frames), dim=2)
890
+
891
+ return context_frames
892
+
893
+ def __call__(self, components, state):
894
+ block_state = self.get_block_state(state)
895
+ if block_state.block_idx == 0:
896
+ return components, state
897
+
898
+ start_frame = min(
899
+ block_state.current_start_frame, components.config.kv_cache_num_frames
900
+ )
901
+ context_frames = self.get_context_frames(components, block_state)
902
+ block_state.block_mask = (
903
+ components.transformer._prepare_blockwise_causal_attn_mask(
904
+ components.transformer.device,
905
+ num_frames=context_frames.shape[2],
906
+ frame_seqlen=components.config.frame_seq_length,
907
+ num_frame_per_block=block_state.num_frames_per_block,
908
+ local_attn_size=-1,
909
+ )
910
+ )
911
+ components.transformer.block_mask = block_state.block_mask
912
+ context_timestep = torch.zeros(
913
+ (context_frames.shape[0], context_frames.shape[2]),
914
+ device=components.transformer.device,
915
+ dtype=torch.int64,
916
+ )
917
+ components.transformer(
918
+ x=context_frames.to(components.transformer.dtype),
919
+ t=context_timestep,
920
+ context=block_state.prompt_embeds.to(components.transformer.dtype),
921
+ kv_cache=block_state.kv_cache,
922
+ seq_len=components.config.seq_length,
923
+ crossattn_cache=block_state.crossattn_cache,
924
+ current_start=start_frame * components.config.frame_seq_length,
925
+ cache_start=None,
926
+ )
927
+ components.transformer.block_mask = None
928
+
929
+ return components, state
930
+
931
+
932
+ class WanRTStreamingBeforeDenoiseStep(SequentialPipelineBlocks):
933
+ block_classes = [
934
+ WanRTStreamingSetTimestepsStep,
935
+ WanRTStreamingPrepareLatentsStep,
936
+ WanRTStreamingExtractBlockLatentsStep,
937
+ WanRTStreamingSetupKVCache,
938
+ WanRTStreamingRecomputeKVCache,
939
+ ]
940
+ block_names = [
941
+ "set_timesteps",
942
+ "prepare_latents",
943
+ "extract_block_init_latents",
944
+ "setup_kv_cache",
945
+ "recompute_kv_cache",
946
+ ]
947
+
948
+ @property
949
+ def description(self):
950
+ return (
951
+ "Before denoise step that prepare the inputs for the denoise step.\n"
952
+ + "This is a sequential pipeline blocks:\n"
953
+ + " - `WanRTInputStep` is used to adjust the batch size of the model inputs\n"
954
+ + " - `WanRTSetTimestepsStep` is used to set the timesteps\n"
955
+ + " - `WanRTPrepareLatentsStep` is used to prepare the latents\n"
956
+ )
decoders.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, List, Tuple, Union
16
+
17
+ import numpy as np
18
+ import PIL
19
+ import torch
20
+
21
+ from diffusers.configuration_utils import FrozenDict
22
+ from diffusers.models import AutoencoderKLWan
23
+ from diffusers.utils import logging
24
+ from diffusers.video_processor import VideoProcessor
25
+ from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState
26
+ from diffusers.modular_pipelines.modular_pipeline_utils import (
27
+ ComponentSpec,
28
+ InputParam,
29
+ OutputParam,
30
+ )
31
+ import types
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ class WanRTDecodeStep(ModularPipelineBlocks):
38
+ model_name = "WanRT"
39
+ decoder_cache = []
40
+
41
+ @property
42
+ def expected_components(self) -> List[ComponentSpec]:
43
+ return [
44
+ ComponentSpec(
45
+ "vae",
46
+ AutoencoderKLWan,
47
+ repo="Wan-AI/Wan2.1-T2V-14B-Diffusers",
48
+ subfolder="vae",
49
+ ),
50
+ ComponentSpec(
51
+ "video_processor",
52
+ VideoProcessor,
53
+ config=FrozenDict({"vae_scale_factor": 8}),
54
+ default_creation_method="from_config",
55
+ ),
56
+ ]
57
+
58
+ @property
59
+ def description(self) -> str:
60
+ return "Step that decodes the denoised latents into images"
61
+
62
+ @property
63
+ def inputs(self) -> List[Tuple[str, Any]]:
64
+ return [
65
+ InputParam("output_type", default="pil"),
66
+ InputParam(
67
+ "latents",
68
+ required=True,
69
+ type_hint=torch.Tensor,
70
+ description="The denoised latents from the denoising step",
71
+ ),
72
+ InputParam(
73
+ "frame_cache_context",
74
+ description="The denoised latents from the denoising step",
75
+ ),
76
+ InputParam(
77
+ "block_idx",
78
+ description="The denoised latents from the denoising step",
79
+ ),
80
+ InputParam(
81
+ "decoder_cache",
82
+ description="The denoised latents from the denoising step",
83
+ ),
84
+ ]
85
+
86
+ @property
87
+ def intermediate_outputs(self) -> List[str]:
88
+ return [
89
+ OutputParam(
90
+ "videos",
91
+ type_hint=Union[
92
+ List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]
93
+ ],
94
+ description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array",
95
+ )
96
+ ]
97
+
98
+ @torch.no_grad()
99
+ def __call__(self, components, state: PipelineState) -> PipelineState:
100
+ block_state = self.get_block_state(state)
101
+ vae_dtype = components.vae.dtype
102
+
103
+ # Disable clearing cache
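+ # The VAE's internal feature cache is kept alive across blocks (clear_cache is replaced with a
+ # no-op), so each new block can be decoded causally without re-decoding earlier frames.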
104
+ if block_state.block_idx == 0:
105
+ components.vae.clear_cache()
106
+ components.vae.clear_cache = lambda: None
107
+ components.vae._feat_map = [None] * 55
108
+
109
+ if block_state.block_idx != 0:
110
+ components.vae._feat_map = block_state.decoder_cache
111
+
112
+ if not block_state.output_type == "latent":
113
+ latents = block_state.latents.to(components.vae.device)
114
+
115
+ # Create tensors directly on target device and dtype to avoid redundant conversions
116
+ latents_mean = torch.tensor(
117
+ components.vae.config.latents_mean,
118
+ device=latents.device,
119
+ dtype=latents.dtype,
120
+ ).view(1, components.vae.config.z_dim, 1, 1, 1)
121
+ latents_std = 1.0 / torch.tensor(
122
+ components.vae.config.latents_std,
123
+ device=latents.device,
124
+ dtype=latents.dtype,
125
+ ).view(1, components.vae.config.z_dim, 1, 1, 1)
126
+
127
+ latents = latents / latents_std + latents_mean
128
+ latents = latents.to(vae_dtype)
129
+
130
+ videos = components.vae.decode(latents, return_dict=False)[0]
131
+
132
+ else:
133
+ block_state.videos = block_state.latents
134
+
135
+ block_state.decoder_cache = components.vae._feat_map
136
+ block_state.frame_cache_context.extend(videos.split(1, dim=2))
137
+
138
+ videos = components.video_processor.postprocess_video(
139
+ videos, output_type=block_state.output_type
140
+ )
141
+ block_state.videos = videos
142
+
143
+ self.set_block_state(state, block_state)
144
+
145
+ return components, state
denoise.py ADDED
@@ -0,0 +1,330 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, List, Tuple
16
+
17
+ import torch
18
+
19
+ from diffusers.configuration_utils import FrozenDict
20
+ from diffusers.guiders import ClassifierFreeGuidance
21
+ from diffusers.models import AutoModel
22
+ from diffusers.schedulers import UniPCMultistepScheduler
23
+ from diffusers.utils import logging
24
+ from diffusers.utils.torch_utils import randn_tensor
25
+ from diffusers.modular_pipelines import (
26
+ BlockState,
27
+ LoopSequentialPipelineBlocks,
28
+ ModularPipelineBlocks,
29
+ PipelineState,
30
+ ModularPipeline,
31
+ )
32
+ from diffusers.modular_pipelines.modular_pipeline_utils import (
33
+ ComponentSpec,
34
+ InputParam,
35
+ OutputParam,
36
+ )
37
+
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ class WanRTStreamingLoopDenoiser(ModularPipelineBlocks):
43
+ model_name = "WanRTStreaming"
44
+
45
+ @property
46
+ def expected_components(self) -> List[ComponentSpec]:
47
+ return [ComponentSpec("transformer", AutoModel)]
48
+
49
+ @property
50
+ def description(self) -> str:
51
+ return (
52
+ "Step within the denoising loop that denoise the latents with guidance. "
53
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
54
+ "object (e.g. `WanRTStreamingDenoiseLoopWrapper`)"
55
+ )
56
+
57
+ @property
58
+ def inputs(self) -> List[Tuple[str, Any]]:
59
+ return [
60
+ InputParam("attention_kwargs"),
61
+ InputParam("block_idx"),
62
+ InputParam(
63
+ "latents",
64
+ required=True,
65
+ type_hint=torch.Tensor,
66
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
67
+ ),
68
+ InputParam(
69
+ "prompt_embeds",
70
+ required=True,
71
+ type_hint=torch.Tensor,
72
+ ),
73
+ InputParam(
74
+ "kv_cache",
75
+ required=True,
76
+ type_hint=torch.Tensor,
77
+ ),
78
+ InputParam(
79
+ "crossattn_cache",
80
+ required=True,
81
+ type_hint=torch.Tensor,
82
+ ),
83
+ InputParam(
84
+ "current_start_frame",
85
+ required=True,
86
+ type_hint=torch.Tensor,
87
+ ),
88
+ InputParam(
89
+ "num_inference_steps",
90
+ required=True,
91
+ type_hint=int,
92
+ default=4,
93
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
94
+ ),
95
+ InputParam(
96
+ kwargs_type="guider_input_fields",
97
+ description=(
98
+ "All conditional model inputs that need to be prepared with guider. "
99
+ "It should contain prompt_embeds/negative_prompt_embeds. "
100
+ "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
101
+ ),
102
+ ),
103
+ ]
104
+
105
+ @torch.no_grad()
106
+ def __call__(
107
+ self,
108
+ components: ModularPipeline,
109
+ block_state: BlockState,
110
+ i: int,
111
+ t: torch.Tensor,
112
+ ) -> PipelineState:
113
+ start_frame = min(
114
+ block_state.current_start_frame, components.config.kv_cache_num_frames
115
+ )
116
+
117
+ block_state.noise_pred = components.transformer(
118
+ x=block_state.latents,
119
+ t=t.expand(block_state.latents.shape[0], block_state.num_frames_per_block),
120
+ context=block_state.prompt_embeds,
121
+ kv_cache=block_state.kv_cache,
122
+ seq_len=components.config.seq_length,
123
+ crossattn_cache=block_state.crossattn_cache,
124
+ current_start=start_frame * components.config.frame_seq_length,
125
+ cache_start=start_frame * components.config.frame_seq_length,
126
+ )
127
+
128
+ return components, block_state
129
+
130
+
131
+ class WanRTStreamingLoopAfterDenoiser(ModularPipelineBlocks):
132
+ model_name = "WanRTStreaming"
133
+
134
+ @property
135
+ def expected_components(self) -> List[ComponentSpec]:
136
+ return [
137
+ ComponentSpec("scheduler", UniPCMultistepScheduler),
138
+ ]
139
+
140
+ @property
141
+ def description(self) -> str:
142
+ return (
143
+ "step within the denoising loop that update the latents. "
144
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
145
+ "object (e.g. `WanRTStreamingDenoiseLoopWrapper`)"
146
+ )
147
+
148
+ @property
149
+ def inputs(self) -> List[Tuple[str, Any]]:
150
+ return []
151
+
152
+ @property
153
+ def intermediate_inputs(self) -> List[str]:
154
+ return [
155
+ InputParam("generator"),
156
+ InputParam("block_id"),
157
+ ]
158
+
159
+ @property
160
+ def intermediate_outputs(self) -> List[OutputParam]:
161
+ return [
162
+ OutputParam(
163
+ "latents", type_hint=torch.Tensor, description="The denoised latents"
164
+ )
165
+ ]
166
+
167
+ @torch.no_grad()
168
+ def __call__(
169
+ self,
170
+ components: ModularPipeline,
171
+ block_state: BlockState,
172
+ i: int,
173
+ t: torch.Tensor,
174
+ ):
175
+ # Perform scheduler step using the predicted output
176
+ latents_dtype = block_state.latents.dtype
177
+ timesteps = block_state.all_timesteps
178
+ sigmas = block_state.sigmas
179
+
180
+ timestep_id = torch.argmin((timesteps - t).abs())
181
+ sigma_t = sigmas[timestep_id]
182
+
183
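+ # Flow-matching update: recover the clean sample from the predicted velocity, x_0 = x_t - sigma_t * v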
+ # Perform computation in double precision, then convert back once
184
+ latents = (
185
+ block_state.latents.double()
186
+ - sigma_t.double() * block_state.noise_pred.double()
187
+ ).to(latents_dtype)
188
+
189
+ block_state.latents = latents
190
+
191
+ return components, block_state
192
+
193
+
194
+ class WanRTStreamingDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
195
+ model_name = "WanRTStreaming"
196
+
197
+ @property
198
+ def description(self) -> str:
199
+ return (
200
+ "Streaming denoising loop that processes a single block with persistent KV cache. "
201
+ "Recomputes cache from context frames, denoises current block, and updates cache."
202
+ )
203
+
204
+ def add_noise(self, components, block_state, sample, noise, timestep, index):
205
+ timesteps = block_state.all_timesteps
206
+ sigmas = block_state.sigmas.to(timesteps.device)
207
+
208
+ if timestep.ndim == 2:
209
+ timestep = timestep.flatten(0, 1)
210
+
211
+ timestep_id = torch.argmin(
212
+ (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1
213
+ )
214
+ sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1)
215
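+ # Re-noise the prediction to the next timestep: x_t = (1 - sigma_t) * x_0 + sigma_t * noise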
+ sample = (
216
+ 1 - sigma.double()
217
+ ) * sample.double() + sigma.double() * noise.double()
218
+ sample = sample.type_as(noise)
219
+
220
+ return sample
221
+
222
+ @property
223
+ def loop_inputs(self) -> List[InputParam]:
224
+ return [
225
+ InputParam(
226
+ "timesteps",
227
+ required=True,
228
+ type_hint=torch.Tensor,
229
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
230
+ ),
231
+ InputParam(
232
+ "all_timesteps",
233
+ required=True,
234
+ type_hint=torch.Tensor,
235
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
236
+ ),
237
+ InputParam(
238
+ "sigmas",
239
+ required=True,
240
+ type_hint=torch.Tensor,
241
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
242
+ ),
243
+ InputParam("final_latents", type_hint=torch.Tensor),
244
+ InputParam(
245
+ "num_inference_steps",
246
+ required=True,
247
+ type_hint=int,
248
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
249
+ ),
250
+ InputParam(
251
+ "num_frames_per_block",
252
+ required=True,
253
+ type_hint=int,
254
+ default=3,
255
+ ),
256
+ InputParam(
257
+ "current_start_frame",
258
+ required=True,
259
+ type_hint=int,
260
+ ),
261
+ InputParam(
262
+ "block_idx",
263
+ ),
264
+ InputParam(
265
+ "generator",
266
+ ),
267
+ ]
268
+
269
+ @torch.no_grad()
270
+ def __call__(
271
+ self, components: ModularPipeline, state: PipelineState
272
+ ) -> PipelineState:
273
+ block_state = self.get_block_state(state)
274
+
275
+ for i, t in enumerate(block_state.timesteps):
276
+ components, block_state = self.loop_step(components, block_state, i=i, t=t)
277
+ if i < (block_state.num_inference_steps - 1):
278
+ t1 = block_state.timesteps[i + 1]
279
+
280
+ block_state.latents = (
281
+ self.add_noise(
282
+ components,
283
+ block_state,
284
+ block_state.latents.transpose(1, 2).squeeze(0),
285
+ randn_tensor(
286
+ block_state.latents.transpose(1, 2).squeeze(0).shape,
287
+ device=block_state.latents.device,
288
+ dtype=block_state.latents.dtype,
289
+ generator=block_state.generator,
290
+ ),
291
+ t1.expand(
292
+ block_state.latents.shape[0],
293
+ block_state.num_frames_per_block,
294
+ ),
295
+ i,
296
+ )
297
+ .unsqueeze(0)
298
+ .transpose(1, 2)
299
+ )
300
+
301
+ # Update the state
302
+ block_state.final_latents[
303
+ :,
304
+ :,
305
+ block_state.current_start_frame : block_state.current_start_frame
306
+ + block_state.num_frames_per_block,
307
+ ] = block_state.latents
308
+
309
+ self.set_block_state(state, block_state)
310
+
311
+ return components, state
312
+
313
+
314
+ class WanRTStreamingDenoiseStep(WanRTStreamingDenoiseLoopWrapper):
315
+ block_classes = [
316
+ WanRTStreamingLoopDenoiser,
317
+ WanRTStreamingLoopAfterDenoiser,
318
+ ]
319
+ block_names = ["denoiser", "after_denoiser"]
320
+
321
+ @property
322
+ def description(self) -> str:
323
+ return (
324
+ "Denoise step that iteratively denoise the latents. \n"
325
+ "Its loop logic is defined in `WanRTStreamingDenoiseLoopWrapper.__call__` method \n"
326
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
327
+ " - `WanRTStreamingLoopDenoiser`\n"
328
+ " - `WanRTStreamingLoopAfterDenoiser`\n"
329
+ "This block supports both text2vid tasks."
330
+ )
encoders.py ADDED
@@ -0,0 +1,282 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ from typing import List, Optional, Union
17
+
18
+ import regex as re
19
+ import torch
20
+ from transformers import AutoTokenizer, UMT5EncoderModel
21
+
22
+ from diffusers.configuration_utils import FrozenDict
23
+ from diffusers.guiders import ClassifierFreeGuidance
24
+ from diffusers.utils import is_ftfy_available, logging
25
+ from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState
26
+ from diffusers.modular_pipelines.modular_pipeline_utils import (
27
+ ComponentSpec,
28
+ ConfigSpec,
29
+ InputParam,
30
+ OutputParam,
31
+ )
32
+ from diffusers.modular_pipelines import WanModularPipeline
33
+
34
+
35
+ if is_ftfy_available():
36
+ import ftfy
37
+
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ def basic_clean(text):
43
+ text = ftfy.fix_text(text)
44
+ text = html.unescape(html.unescape(text))
45
+ return text.strip()
46
+
47
+
48
+ def whitespace_clean(text):
49
+ text = re.sub(r"\s+", " ", text)
50
+ text = text.strip()
51
+ return text
52
+
53
+
54
+ def prompt_clean(text):
55
+ text = whitespace_clean(basic_clean(text))
56
+ return text
57
+
58
+
59
+ class WanRTStreamingTextEncoderStep(ModularPipelineBlocks):
60
+ model_name = "WanRTStreaming"
61
+
62
+ @property
63
+ def description(self) -> str:
64
+ return "Text Encoder step that generate text_embeddings to guide the video generation"
65
+
66
+ @property
67
+ def expected_components(self) -> List[ComponentSpec]:
68
+ return [
69
+ ComponentSpec("text_encoder", UMT5EncoderModel),
70
+ ComponentSpec("tokenizer", AutoTokenizer),
71
+ ComponentSpec(
72
+ "guider",
73
+ ClassifierFreeGuidance,
74
+ config=FrozenDict({"guidance_scale": 5.0}),
75
+ default_creation_method="from_config",
76
+ ),
77
+ ]
78
+
79
+ @property
80
+ def expected_configs(self) -> List[ConfigSpec]:
81
+ return []
82
+
83
+ @property
84
+ def inputs(self) -> List[InputParam]:
85
+ return [
86
+ InputParam("prompt"),
87
+ InputParam("negative_prompt"),
88
+ InputParam(
89
+ "prompt_embeds",
90
+ type_hint=torch.Tensor,
91
+ description="text embeddings used to guide the image generation",
92
+ ),
93
+ InputParam(
94
+ "negative_prompt_embeds",
95
+ type_hint=torch.Tensor,
96
+ description="negative text embeddings used to guide the image generation",
97
+ ),
98
+ InputParam("attention_kwargs"),
99
+ ]
100
+
101
+ @property
102
+ def intermediate_outputs(self) -> List[OutputParam]:
103
+ return [
104
+ OutputParam(
105
+ "prompt_embeds",
106
+ type_hint=torch.Tensor,
107
+ kwargs_type="denoiser_input_fields",
108
+ description="text embeddings used to guide the image generation",
109
+ ),
110
+ OutputParam(
111
+ "negative_prompt_embeds",
112
+ type_hint=torch.Tensor,
113
+ kwargs_type="denoiser_input_fields",
114
+ description="negative text embeddings used to guide the image generation",
115
+ ),
116
+ ]
117
+
118
+ @staticmethod
119
+ def check_inputs(block_state):
120
+ if block_state.prompt is not None and (
121
+ not isinstance(block_state.prompt, str)
122
+ and not isinstance(block_state.prompt, list)
123
+ ):
124
+ raise ValueError(
125
+ f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}"
126
+ )
127
+
128
+ @staticmethod
129
+ def _get_t5_prompt_embeds(
130
+ components,
131
+ prompt: Union[str, List[str]],
132
+ max_sequence_length: int,
133
+ device: torch.device,
134
+ ):
135
+ dtype = components.text_encoder.dtype
136
+ prompt = [prompt] if isinstance(prompt, str) else prompt
137
+ prompt = [prompt_clean(u) for u in prompt]
138
+
139
+ text_inputs = components.tokenizer(
140
+ prompt,
141
+ padding="max_length",
142
+ max_length=max_sequence_length,
143
+ truncation=True,
144
+ add_special_tokens=True,
145
+ return_attention_mask=True,
146
+ return_tensors="pt",
147
+ )
148
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
149
+ seq_lens = mask.gt(0).sum(dim=1).long()
150
+ prompt_embeds = components.text_encoder(
151
+ text_input_ids.to(device), mask.to(device)
152
+ ).last_hidden_state
153
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
154
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
155
+ prompt_embeds = torch.stack(
156
+ [
157
+ torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))])
158
+ for u in prompt_embeds
159
+ ],
160
+ dim=0,
161
+ )
162
+
163
+ return prompt_embeds
164
+
165
+ @staticmethod
166
+ def encode_prompt(
167
+ components,
168
+ prompt: str,
169
+ device: Optional[torch.device] = None,
170
+ num_videos_per_prompt: int = 1,
171
+ prepare_unconditional_embeds: bool = True,
172
+ negative_prompt: Optional[str] = None,
173
+ prompt_embeds: Optional[torch.Tensor] = None,
174
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
175
+ max_sequence_length: int = 512,
176
+ ):
177
+ r"""
178
+ Encodes the prompt into text encoder hidden states.
179
+
180
+ Args:
181
+ prompt (`str` or `List[str]`, *optional*):
182
+ prompt to be encoded
183
+ device: (`torch.device`):
184
+ torch device
185
+ num_videos_per_prompt (`int`):
186
+ number of videos that should be generated per prompt
187
+ prepare_unconditional_embeds (`bool`):
188
+ whether to use prepare unconditional embeddings or not
189
+ negative_prompt (`str` or `List[str]`, *optional*):
190
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
191
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
192
+ less than `1`).
193
+ prompt_embeds (`torch.Tensor`, *optional*):
194
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
195
+ provided, text embeddings will be generated from `prompt` input argument.
196
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
197
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
198
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
199
+ argument.
200
+ max_sequence_length (`int`, defaults to `512`):
201
+ The maximum number of text tokens to be used for the generation process.
202
+ """
203
+ device = device or components._execution_device
204
+ prompt = [prompt] if isinstance(prompt, str) else prompt
205
+ batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
206
+
207
+ if prompt_embeds is None:
208
+ prompt_embeds = WanRTStreamingTextEncoderStep._get_t5_prompt_embeds(
209
+ components, prompt, max_sequence_length, device
210
+ )
211
+
212
+ if prepare_unconditional_embeds and negative_prompt_embeds is None:
213
+ negative_prompt = negative_prompt or ""
214
+ negative_prompt = (
215
+ batch_size * [negative_prompt]
216
+ if isinstance(negative_prompt, str)
217
+ else negative_prompt
218
+ )
219
+
220
+ if prompt is not None and type(prompt) is not type(negative_prompt):
221
+ raise TypeError(
222
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
223
+ f" {type(prompt)}."
224
+ )
225
+ elif batch_size != len(negative_prompt):
226
+ raise ValueError(
227
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
228
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
229
+ " the batch size of `prompt`."
230
+ )
231
+
232
+ negative_prompt_embeds = (
233
+ WanRTStreamingTextEncoderStep._get_t5_prompt_embeds(
234
+ components, negative_prompt, max_sequence_length, device
235
+ )
236
+ )
237
+
238
+ bs_embed, seq_len, _ = prompt_embeds.shape
239
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
240
+ prompt_embeds = prompt_embeds.view(
241
+ bs_embed * num_videos_per_prompt, seq_len, -1
242
+ )
243
+
244
+ if prepare_unconditional_embeds:
245
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
246
+ 1, num_videos_per_prompt, 1
247
+ )
248
+ negative_prompt_embeds = negative_prompt_embeds.view(
249
+ batch_size * num_videos_per_prompt, seq_len, -1
250
+ )
251
+
252
+ return prompt_embeds, negative_prompt_embeds
253
+
254
+ @torch.no_grad()
255
+ def __call__(
256
+ self, components: WanModularPipeline, state: PipelineState
257
+ ) -> PipelineState:
258
+ # Get inputs and intermediates
259
+ block_state = self.get_block_state(state)
260
+ self.check_inputs(block_state)
261
+
262
+ block_state.prepare_unconditional_embeds = False
263
+ block_state.device = components._execution_device
264
+
265
+ # Encode input prompt
266
+ (
267
+ block_state.prompt_embeds,
268
+ block_state.negative_prompt_embeds,
269
+ ) = WanRTStreamingTextEncoderStep.encode_prompt(
270
+ components,
271
+ block_state.prompt,
272
+ block_state.device,
273
+ 1,
274
+ block_state.prepare_unconditional_embeds,
275
+ block_state.negative_prompt,
276
+ prompt_embeds=block_state.prompt_embeds,
277
+ negative_prompt_embeds=block_state.negative_prompt_embeds,
278
+ )
279
+
280
+ # Add outputs
281
+ self.set_block_state(state, block_state)
282
+ return components, state
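
`_get_t5_prompt_embeds` above trims each encoded prompt to its true token length and then zero-pads it back to `max_sequence_length`, so padding positions carry exact zeros rather than whatever UMT5 emitted for pad tokens. A standalone sketch of just that trim-and-repad step, with random tensors standing in for the encoder output:

```py
import torch

max_sequence_length = 512
# stand-ins for the UMT5 last_hidden_state and the tokenizer attention mask
prompt_embeds = torch.randn(2, max_sequence_length, 4096)
mask = torch.zeros(2, max_sequence_length, dtype=torch.long)
mask[0, :7] = 1
mask[1, :23] = 1

seq_lens = mask.gt(0).sum(dim=1).long()
trimmed = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
padded = torch.stack(
    [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in trimmed]
)
assert padded.shape == prompt_embeds.shape
assert torch.all(padded[0, 7:] == 0)  # pad positions are exactly zero
```
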
modular_blocks.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from diffusers.utils import logging
+ from diffusers.modular_pipelines import SequentialPipelineBlocks
+ from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict
+
+ from .before_denoise import WanRTStreamingBeforeDenoiseStep
+ from .decoders import WanRTDecodeStep
+ from .encoders import WanRTStreamingTextEncoderStep
+ from .denoise import WanRTStreamingDenoiseStep
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ TEXT2VIDEO_BLOCKS = InsertableDict(
+     [
+         ("text_encoder", WanRTStreamingTextEncoderStep),
+         ("before_denoise", WanRTStreamingBeforeDenoiseStep),
+         ("denoise", WanRTStreamingDenoiseStep),
+         ("decode", WanRTDecodeStep),
+     ]
+ )
+
+ ALL_BLOCKS = {
+     "text2video": TEXT2VIDEO_BLOCKS,
+ }
+
+
+ class WanStreamingRTBlocks(SequentialPipelineBlocks):
+     block_classes = list(TEXT2VIDEO_BLOCKS.copy().values())
+     block_names = list(TEXT2VIDEO_BLOCKS.copy().keys())
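
Since `TEXT2VIDEO_BLOCKS` is an `InsertableDict` keyed by step name, a custom sequence can be assembled the same way `WanStreamingRTBlocks` does it. The sketch below assumes `InsertableDict` keeps the standard dict interface (the module itself relies on `.copy()`, `.keys()` and `.values()`), drops the decode step for latent-level debugging, and uses a hypothetical local import path that depends on where these files live on disk:

```py
from diffusers.modular_pipelines import SequentialPipelineBlocks

from modular_blocks import TEXT2VIDEO_BLOCKS  # hypothetical: import path when using these files locally

blocks_without_decode = TEXT2VIDEO_BLOCKS.copy()
blocks_without_decode.pop("decode")  # keep text_encoder -> before_denoise -> denoise


class WanStreamingRTLatentBlocks(SequentialPipelineBlocks):
    block_classes = list(blocks_without_decode.values())
    block_names = list(blocks_without_decode.keys())
```
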
modular_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_class_name": "WanRTBlocks",
+   "_diffusers_version": "0.36.0.dev0",
+   "auto_map": {
+     "ModularPipelineBlocks": "modular_blocks.WanStreamingRTBlocks"
+   }
+ }
modular_model_index.json ADDED
@@ -0,0 +1,79 @@
+ {
+   "_blocks_class_name": "WanStreamingRTBlocks",
+   "_class_name": "WanRTStreamingPipeline",
+   "_diffusers_version": "0.36.0.dev0",
+   "frame_seq_length": 1560,
+   "kv_cache_num_frames": 3,
+   "num_frames_per_block": 3,
+   "scheduler": [
+     null,
+     null,
+     {
+       "repo": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+       "revision": null,
+       "subfolder": "scheduler",
+       "type_hint": [
+         "diffusers",
+         "UniPCMultistepScheduler"
+       ],
+       "variant": null
+     }
+   ],
+   "seq_length": 32760,
+   "text_encoder": [
+     null,
+     null,
+     {
+       "repo": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+       "revision": null,
+       "subfolder": "text_encoder",
+       "type_hint": [
+         "transformers",
+         "UMT5EncoderModel"
+       ],
+       "variant": null
+     }
+   ],
+   "tokenizer": [
+     null,
+     null,
+     {
+       "repo": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+       "revision": null,
+       "subfolder": "tokenizer",
+       "type_hint": [
+         "transformers",
+         "T5TokenizerFast"
+       ],
+       "variant": null
+     }
+   ],
+   "transformer": [
+     null,
+     null,
+     {
+       "repo": "diffusers-internal-dev/krt",
+       "revision": null,
+       "subfolder": "transformer",
+       "type_hint": [
+         "diffusers",
+         "AutoModel"
+       ],
+       "variant": null
+     }
+   ],
+   "vae": [
+     null,
+     null,
+     {
+       "repo": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+       "revision": null,
+       "subfolder": "vae",
+       "type_hint": [
+         "diffusers",
+         "AutoencoderKLWan"
+       ],
+       "variant": null
+     }
+   ]
+ }
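
Each component entry in `modular_model_index.json` carries a loading spec (`repo`, `subfolder`, `type_hint`) telling the modular pipeline where that component comes from. As an illustration only (not how the modular loader works internally), the scheduler spec above points at the standard Wan 2.1 scheduler and could equally be loaded directly:

```py
from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="scheduler"
)
```
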
transformer/__init__.py ADDED
(empty file)
transformer/attention.py ADDED
@@ -0,0 +1,326 @@
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+ import torch
+ from typing import Optional
+ import os
+ import warnings
+
+ # Global state for lazy initialization
+ _SAGEATTN_AVAILABLE = None
+ _FLASH_ATTN_3_AVAILABLE = None
+ _FLASH_ATTN_2_AVAILABLE = None
+ _sageattn_func = None
+ _flash_attn_func = None
+ _flash_attn_interface = None
+ _flash_attn = None
+
+
+ def _init_sageattention():
+     """Lazy initialization for SageAttention."""
+     global _SAGEATTN_AVAILABLE, _sageattn_func
+
+     if _SAGEATTN_AVAILABLE is not None:
+         return _SAGEATTN_AVAILABLE
+
+     _SAGEATTN_AVAILABLE = False
+     try:
+         if os.getenv("DISABLE_SAGEATTENTION", "0") != "0":
+             raise Exception("DISABLE_SAGEATTENTION is set")
+
+         from sageattention import sageattn
+
+         @torch.library.custom_op(
+             "mylib::sageattn", mutates_args={"q", "k", "v"}, device_types="cuda"
+         )
+         def sageattn_func(
+             q: torch.Tensor,
+             k: torch.Tensor,
+             v: torch.Tensor,
+             attn_mask: Optional[torch.Tensor] = None,
+             dropout_p: float = 0,
+             is_causal: bool = False,
+         ) -> torch.Tensor:
+             return sageattn(
+                 q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
+             )
+
+         @sageattn_func.register_fake
+         def _sageattn_fake(q, k, v, attn_mask=None, dropout_p=0, is_causal=False):
+             return torch.empty(*q.shape, device=q.device, dtype=q.dtype)
+
+         print("SageAttention loaded successfully")
+         _sageattn_func = sageattn_func
+         _SAGEATTN_AVAILABLE = True
+
+     except Exception as e:
+         print(f"Warning: Could not load sageattention: {str(e)}")
+         if isinstance(e, ModuleNotFoundError):
+             print("sageattention package is not installed")
+         elif isinstance(e, ImportError) and "DLL" in str(e):
+             print("sageattention DLL loading error")
+         _sageattn_func = None
+
+     return _SAGEATTN_AVAILABLE
+
+
+ def _is_hopper_gpu():
+     """Check if the current GPU is a Hopper architecture."""
+     if not torch.cuda.is_available():
+         return False
+     device_name = torch.cuda.get_device_name(0).lower()
+     return "h100" in device_name or "hopper" in device_name
+
+
+ def _init_flash_attention_3():
+     """Lazy initialization for Flash Attention 3."""
+     global _FLASH_ATTN_3_AVAILABLE, _flash_attn_func, _flash_attn_interface
+
+     if _FLASH_ATTN_3_AVAILABLE is not None:
+         return _FLASH_ATTN_3_AVAILABLE
+
+     _FLASH_ATTN_3_AVAILABLE = False
+     try:
+         from flash_attn import flash_attn_func
+         import flash_attn_interface
+
+         # Always set the function reference if flash_attn is available
+         _flash_attn_func = flash_attn_func
+         _flash_attn_interface = flash_attn_interface
+         # FA3 optimizations only available on Hopper GPUs
+         _FLASH_ATTN_3_AVAILABLE = _is_hopper_gpu()
+     except ModuleNotFoundError:
+         _FLASH_ATTN_3_AVAILABLE = False
+         _flash_attn_func = None
+         _flash_attn_interface = None
+
+     return _FLASH_ATTN_3_AVAILABLE
+
+
+ def _init_flash_attention_2():
+     """Lazy initialization for Flash Attention 2."""
+     global _FLASH_ATTN_2_AVAILABLE, _flash_attn
+
+     if _FLASH_ATTN_2_AVAILABLE is not None:
+         return _FLASH_ATTN_2_AVAILABLE
+
+     _FLASH_ATTN_2_AVAILABLE = False
+     try:
+         import flash_attn
+
+         _flash_attn = flash_attn
+         _FLASH_ATTN_2_AVAILABLE = True
+     except ModuleNotFoundError:
+         _FLASH_ATTN_2_AVAILABLE = False
+
+     return _FLASH_ATTN_2_AVAILABLE
+
+ __all__ = ["flash_attention", "attention"]
+
+
+ # Compatibility getters for external code
+ def sageattn_func():
+     """Getter for sageattn_func - initializes if needed."""
+     _init_sageattention()
+     return _sageattn_func
+
+
+ def SAGEATTN_AVAILABLE():
+     """Getter for SAGEATTN_AVAILABLE - initializes if needed."""
+     return _init_sageattention()
+
+
+ def flash_attention(
+     q,
+     k,
+     v,
+     q_lens=None,
+     k_lens=None,
+     dropout_p=0.0,
+     softmax_scale=None,
+     q_scale=None,
+     causal=False,
+     window_size=(-1, -1),
+     deterministic=False,
+     dtype=torch.bfloat16,
+     version=None,
+ ):
+     """
+     q: [B, Lq, Nq, C1].
+     k: [B, Lk, Nk, C1].
+     v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
+     q_lens: [B].
+     k_lens: [B].
+     dropout_p: float. Dropout probability.
+     softmax_scale: float. The scaling of QK^T before applying softmax.
+     causal: bool. Whether to apply causal attention mask.
+     window_size: (left right). If not (-1, -1), apply sliding window local attention.
+     deterministic: bool. If True, slightly slower and uses more memory.
+     dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
+     """
+     # Initialize flash attention modules
+     flash_attn_3_available = _init_flash_attention_3()
+     flash_attn_2_available = _init_flash_attention_2()
+
+     # Early fallback for simple cases when advanced features aren't needed
+     # Only use this path if flash_attn is available but we're not using FA3 features
+     if not flash_attn_3_available and _flash_attn_func is not None and q_lens is None and k_lens is None:
+         return _flash_attn_func(
+             q,
+             k,
+             v,
+         )
+
+     half_dtypes = (torch.float16, torch.bfloat16)
+     assert dtype in half_dtypes
+     assert q.device.type == "cuda" and q.size(-1) <= 256
+
+     # params
+     b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
+
+     def half(x):
+         return x if x.dtype in half_dtypes else x.to(dtype)
+
+     # preprocess query
+     if q_lens is None:
+         q = half(q.flatten(0, 1))
+         q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(
+             device=q.device, non_blocking=True
+         )
+     else:
+         q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
+
+     # preprocess key, value
+     if k_lens is None:
+         k = half(k.flatten(0, 1))
+         v = half(v.flatten(0, 1))
+         k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(
+             device=k.device, non_blocking=True
+         )
+     else:
+         k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
+         v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
+
+     q = q.to(v.dtype)
+     k = k.to(v.dtype)
+
+     if q_scale is not None:
+         q = q * q_scale
+
+     if version is not None and version == 3 and not flash_attn_3_available:
+         warnings.warn(
+             "Flash attention 3 is not available, use flash attention 2 instead."
+         )
+
+     # apply attention
+     if (version is None or version == 3) and flash_attn_3_available:
+         # Note: dropout_p, window_size are not supported in FA3 now.
+         x = _flash_attn_interface.flash_attn_varlen_func(
+             q=q,
+             k=k,
+             v=v,
+             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
+             .cumsum(0, dtype=torch.int32)
+             .to(q.device, non_blocking=True),
+             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
+             .cumsum(0, dtype=torch.int32)
+             .to(q.device, non_blocking=True),
+             max_seqlen_q=lq,
+             max_seqlen_k=lk,
+             softmax_scale=softmax_scale,
+             causal=causal,
+             deterministic=deterministic,
+         ).unflatten(0, (b, lq))
+     else:
+         assert flash_attn_2_available
+         x = _flash_attn.flash_attn_varlen_func(
+             q=q,
+             k=k,
+             v=v,
+             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
+             .cumsum(0, dtype=torch.int32)
+             .to(q.device, non_blocking=True),
+             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
+             .cumsum(0, dtype=torch.int32)
+             .to(q.device, non_blocking=True),
+             max_seqlen_q=lq,
+             max_seqlen_k=lk,
+             dropout_p=dropout_p,
+             softmax_scale=softmax_scale,
+             causal=causal,
+             window_size=window_size,
+             deterministic=deterministic,
+         ).unflatten(0, (b, lq))
+
+     # output
+     return x.type(out_dtype)
+
+
+ def attention(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     q_lens=None,
+     k_lens=None,
+     dropout_p=0.0,
+     softmax_scale=None,
+     q_scale=None,
+     causal=False,
+     window_size=(-1, -1),
+     deterministic=False,
+     dtype=torch.bfloat16,
+     fa_version=None,
+     # og_dtype=torch.bfloat16,
+ ):
+     # Initialize attention modules
+     sageattn_available = _init_sageattention()
+     flash_attn_2_available = _init_flash_attention_2()
+     flash_attn_3_available = _init_flash_attention_3()
+
+     if sageattn_available:
+         # print("Using sageattention")
+         attn_mask = None
+
+         og_dtype = q.dtype
+         q = q.transpose(1, 2).to(dtype)
+         k = k.transpose(1, 2).to(dtype)
+         v = v.transpose(1, 2).to(dtype)
+
+         out = _sageattn_func(
+             q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
+         )
+
+         out = out.transpose(1, 2).contiguous().to(og_dtype)
+         return out
+
+     elif flash_attn_2_available or flash_attn_3_available:
+         return flash_attention(
+             q=q,
+             k=k,
+             v=v,
+             q_lens=q_lens,
+             k_lens=k_lens,
+             dropout_p=dropout_p,
+             softmax_scale=softmax_scale,
+             q_scale=q_scale,
+             causal=causal,
+             window_size=window_size,
+             deterministic=deterministic,
+             dtype=dtype,
+             version=fa_version,
+         )
+     else:
+         if q_lens is not None or k_lens is not None:
+             warnings.warn(
+                 "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
+             )
+         attn_mask = None
+
+         q = q.transpose(1, 2).to(dtype)
+         k = k.transpose(1, 2).to(dtype)
+         v = v.transpose(1, 2).to(dtype)
+
+         out = torch.nn.functional.scaled_dot_product_attention(
+             q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
+         )
+
+         out = out.transpose(1, 2).contiguous()
+         return out
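
`attention` is the single entry point the transformer blocks call; it dispatches to SageAttention, FlashAttention 2/3, or PyTorch SDPA depending on what is installed. A minimal usage sketch (head count and head dimension are illustrative, and the import path assumes these files are used as the `transformer` package added in this repo):

```py
import torch
from transformer.attention import attention  # package path as added in this repo

# [batch, seq_len, num_heads, head_dim], matching the flash_attention docstring above
q = torch.randn(1, 1560, 12, 128)
k = torch.randn(1, 1560, 12, 128)
v = torch.randn(1, 1560, 12, 128)

# Falls back to torch.nn.functional.scaled_dot_product_attention when neither
# sageattention nor flash-attn is installed; that fallback path also runs on CPU.
out = attention(q, k, v, causal=False)
print(out.shape)  # torch.Size([1, 1560, 12, 128])
```
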
transformer/causal_model.py ADDED
@@ -0,0 +1,1402 @@
1
+ import functools
2
+ import math
3
+
4
+ from .attention import attention
5
+ from .model import (
6
+ WanRMSNorm,
7
+ rope_apply,
8
+ WanLayerNorm,
9
+ WAN_CROSSATTENTION_CLASSES,
10
+ rope_params,
11
+ MLPProj,
12
+ sinusoidal_embedding_1d,
13
+ )
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention
17
+ from torch.nn.attention.flex_attention import BlockMask
18
+
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.models.modeling_utils import ModelMixin
21
+
22
+ flex_attention = torch.compile(
23
+ flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
24
+ )
25
+
26
+
27
+ def rope_params_riflex(max_seq_len, dim, theta=10000, k=0, L_test=None):
28
+ assert dim % 2 == 0
29
+ omega = 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
30
+ if k is not None:
31
+ print("Doing riflex w/ ltest", L_test)
32
+ omega[k - 1] = 0.9 * 2 * torch.pi / L_test
33
+ freqs = torch.outer(torch.arange(max_seq_len), omega)
34
+
35
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
36
+ return freqs
37
+
38
+
39
+ @functools.lru_cache(maxsize=32)
40
+ def get_sdpa_mask(
41
+ device: str,
42
+ num_frames: int = 21,
43
+ frame_seqlen: int = 1560,
44
+ num_frame_per_block: int = 1,
45
+ local_attn_size: int = -1,
46
+ dtype: torch.dtype = torch.bool,
47
+ ):
48
+ """
49
+ Create an attention mask tensor for torch.nn.functional.scaled_dot_product_attention
50
+
51
+ Args:
52
+ device: Device to create the mask on
53
+ num_frames: Number of frames
54
+ frame_seqlen: Sequence length per frame
55
+ num_frame_per_block: Number of frames per block
56
+ local_attn_size: Local attention window size (-1 for global)
57
+ dtype: Data type for the mask (torch.bool for masking, torch.float for additive)
58
+
59
+ Returns:
60
+ torch.Tensor: Attention mask of shape (seq_len, seq_len)
61
+ - True/1.0 for allowed attention
62
+ - False/-inf for masked attention
63
+ """
64
+ print("Generating SDPA attention mask")
65
+ total_length = num_frames * frame_seqlen
66
+
67
+ # Right padding to get to a multiple of 128
68
+ padded_length = math.ceil(total_length / 128) * 128 - total_length
69
+ full_length = total_length + padded_length
70
+
71
+ # Create the ends array (same logic as original)
72
+ ends = torch.zeros(full_length, device=device, dtype=torch.long)
73
+
74
+ frame_indices = torch.arange(
75
+ start=0,
76
+ end=total_length,
77
+ step=frame_seqlen * num_frame_per_block,
78
+ device=device,
79
+ )
80
+
81
+ for tmp in frame_indices:
82
+ end_idx = min(tmp + frame_seqlen * num_frame_per_block, full_length)
83
+ ends[tmp:end_idx] = end_idx
84
+
85
+ # Create q_idx and kv_idx coordinate matrices
86
+ q_indices = torch.arange(full_length, device=device).unsqueeze(
87
+ 1
88
+ ) # Shape: (seq_len, 1)
89
+ kv_indices = torch.arange(full_length, device=device).unsqueeze(
90
+ 0
91
+ ) # Shape: (1, seq_len)
92
+
93
+ # Apply the attention logic
94
+ if local_attn_size == -1:
95
+ # Global attention within blocks + diagonal
96
+ mask = (kv_indices < ends[q_indices]) | (q_indices == kv_indices)
97
+ else:
98
+ # Local attention within blocks + diagonal
99
+ local_window_start = ends[q_indices] - local_attn_size * frame_seqlen
100
+ mask = ((kv_indices < ends[q_indices]) & (kv_indices >= local_window_start)) | (
101
+ q_indices == kv_indices
102
+ )
103
+
104
+ if dtype == torch.bool:
105
+ return mask
106
+ elif dtype == torch.float32 or dtype == torch.float16:
107
+ # Convert to additive mask (0.0 for attend, -inf for mask)
108
+ return mask.float() * 0.0 + (~mask).float() * float("-inf")
109
+ else:
110
+ raise ValueError(f"Unsupported dtype: {dtype}")
111
+
112
+
113
+ @functools.lru_cache(maxsize=32)
114
+ def get_block_mask(
115
+ device: str,
116
+ num_frames: int = 21,
117
+ frame_seqlen: int = 1560,
118
+ num_frame_per_block=3,
119
+ local_attn_size=-1,
120
+ ):
121
+ print("Generating block mask")
122
+ total_length = num_frames * frame_seqlen
123
+
124
+ # we do right padding to get to a multiple of 128
125
+ padded_length = math.ceil(total_length / 128) * 128 - total_length
126
+
127
+ ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
128
+
129
+ # Block-wise causal mask will attend to all elements that are before the end of the current chunk
130
+ frame_indices = torch.arange(
131
+ start=0,
132
+ end=total_length,
133
+ step=frame_seqlen * num_frame_per_block,
134
+ device=device,
135
+ )
136
+
137
+ for tmp in frame_indices:
138
+ ends[tmp : tmp + frame_seqlen * num_frame_per_block] = (
139
+ tmp + frame_seqlen * num_frame_per_block
140
+ )
141
+
142
+ def attention_mask(b, h, q_idx, kv_idx):
143
+ if local_attn_size == -1:
144
+ return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
145
+ else:
146
+ return (
147
+ (kv_idx < ends[q_idx])
148
+ & (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen))
149
+ ) | (q_idx == kv_idx)
150
+
151
+ block_mask = create_block_mask(
152
+ attention_mask,
153
+ B=None,
154
+ H=None,
155
+ Q_LEN=total_length + padded_length,
156
+ KV_LEN=total_length + padded_length,
157
+ _compile=False,
158
+ device=device,
159
+ )
160
+ return block_mask
161
+
162
+
163
+ def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
164
+ n, c = x.size(2), x.size(3) // 2
165
+
166
+ # split freqs
167
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
168
+
169
+ # loop over samples
170
+ output = []
171
+
172
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
173
+ seq_len = f * h * w
174
+
175
+ # precompute multipliers
176
+ x_i = torch.view_as_complex(
177
+ x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
178
+ )
179
+ freqs_i = torch.cat(
180
+ [
181
+ freqs[0][start_frame : start_frame + f]
182
+ .view(f, 1, 1, -1)
183
+ .expand(f, h, w, -1),
184
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
185
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
186
+ ],
187
+ dim=-1,
188
+ ).reshape(seq_len, 1, -1)
189
+
190
+ # apply rotary embedding
191
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
192
+ x_i = torch.cat([x_i, x[i, seq_len:]])
193
+
194
+ # append to collection
195
+ output.append(x_i)
196
+ return torch.stack(output).type_as(x)
197
+
198
+
199
+ class CausalWanSelfAttention(nn.Module):
200
+ def __init__(
201
+ self, dim, num_heads, local_attn_size=-1, sink_size=0, qk_norm=True, eps=1e-6
202
+ ):
203
+ assert dim % num_heads == 0
204
+ super().__init__()
205
+ self.dim = dim
206
+ self.num_heads = num_heads
207
+ self.head_dim = dim // num_heads
208
+ self.local_attn_size = local_attn_size
209
+ self.sink_size = sink_size
210
+ self.qk_norm = qk_norm
211
+ self.eps = eps
212
+ self.max_attention_size = (
213
+ 32760 if local_attn_size == -1 else local_attn_size * 1560
214
+ )
215
+ self.fused_projections = False
216
+
217
+ # layers
218
+ self.q = nn.Linear(dim, dim)
219
+ self.k = nn.Linear(dim, dim)
220
+ self.v = nn.Linear(dim, dim)
221
+ self.o = nn.Linear(dim, dim)
222
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
223
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
224
+
225
+ @torch.no_grad()
226
+ def fuse_projections(self):
227
+ # if not self.is_cross_attention:
228
+ if self.fused_projections:
229
+ return
230
+ concatenated_weights = torch.cat(
231
+ [self.q.weight.data, self.k.weight.data, self.v.weight.data]
232
+ )
233
+ concatenated_bias = torch.cat(
234
+ [self.q.bias.data, self.k.bias.data, self.v.bias.data]
235
+ )
236
+ out_features, in_features = concatenated_weights.shape
237
+ with torch.device("meta"):
238
+ self.to_qkv = torch.nn.Linear(in_features, out_features, bias=True)
239
+ self.to_qkv.load_state_dict(
240
+ {"weight": concatenated_weights, "bias": concatenated_bias},
241
+ strict=True,
242
+ assign=True,
243
+ )
244
+ self.fused_projections = True
245
+
246
+ def forward(
247
+ self,
248
+ x,
249
+ seq_lens,
250
+ grid_sizes,
251
+ freqs,
252
+ block_mask,
253
+ kv_cache=None,
254
+ current_start=0,
255
+ cache_start=None,
256
+ ):
257
+ r"""
258
+ Args:
259
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
260
+ seq_lens(Tensor): Shape [B]
261
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
262
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
263
+ block_mask (BlockMask)
264
+ """
265
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
266
+ if cache_start is None:
267
+ cache_start = current_start
268
+
269
+ # query, key, value function
270
+ # @torch.compile(dynamic=True, mode="max-autotune-no-cudagraphs")
271
+ def qkv_fn(x):
272
+ if self.fused_projections:
273
+ # print("Using fused projections")
274
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
275
+ q = self.norm_q(q).view(b, s, n, d)
276
+ k = self.norm_k(k).view(b, s, n, d)
277
+ v = v.view(b, s, n, d)
278
+ else:
279
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
280
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
281
+ v = self.v(x).view(b, s, n, d)
282
+ return q, k, v
283
+
284
+ q, k, v = qkv_fn(x)
285
+
286
+ if kv_cache is None or block_mask is not None:
287
+ # if it is teacher forcing training?
288
+ # is_tf = (s == seq_lens[0].item() * 2)
289
+ is_tf = False
290
+ if is_tf:
291
+ print("Teacher forcing training")
292
+ q_chunk = torch.chunk(q, 2, dim=1)
293
+ k_chunk = torch.chunk(k, 2, dim=1)
294
+ roped_query = []
295
+ roped_key = []
296
+ # rope should be same for clean and noisy parts
297
+ for ii in range(2):
298
+ rq = rope_apply(q_chunk[ii], grid_sizes, freqs).type_as(v)
299
+ rk = rope_apply(k_chunk[ii], grid_sizes, freqs).type_as(v)
300
+ roped_query.append(rq)
301
+ roped_key.append(rk)
302
+
303
+ roped_query = torch.cat(roped_query, dim=1)
304
+ roped_key = torch.cat(roped_key, dim=1)
305
+
306
+ padded_length = math.ceil(q.shape[1] / 128) * 128 - q.shape[1]
307
+ padded_roped_query = torch.cat(
308
+ [
309
+ roped_query,
310
+ torch.zeros(
311
+ [q.shape[0], padded_length, q.shape[2], q.shape[3]],
312
+ device=q.device,
313
+ dtype=v.dtype,
314
+ ),
315
+ ],
316
+ dim=1,
317
+ )
318
+
319
+ padded_roped_key = torch.cat(
320
+ [
321
+ roped_key,
322
+ torch.zeros(
323
+ [k.shape[0], padded_length, k.shape[2], k.shape[3]],
324
+ device=k.device,
325
+ dtype=v.dtype,
326
+ ),
327
+ ],
328
+ dim=1,
329
+ )
330
+
331
+ padded_v = torch.cat(
332
+ [
333
+ v,
334
+ torch.zeros(
335
+ [v.shape[0], padded_length, v.shape[2], v.shape[3]],
336
+ device=v.device,
337
+ dtype=v.dtype,
338
+ ),
339
+ ],
340
+ dim=1,
341
+ )
342
+
343
+ x = flex_attention(
344
+ query=padded_roped_query.transpose(2, 1),
345
+ key=padded_roped_key.transpose(2, 1),
346
+ value=padded_v.transpose(2, 1),
347
+ block_mask=block_mask,
348
+ )[:, :, :-padded_length].transpose(2, 1)
349
+
350
+ else:
351
+ roped_query = rope_apply(q, grid_sizes, freqs).type_as(v)
352
+ roped_key = rope_apply(k, grid_sizes, freqs).type_as(v)
353
+ local_end_index = roped_key.shape[1]
354
+ kv_cache["k"][:, :local_end_index] = roped_key
355
+ kv_cache["v"][:, :local_end_index] = v
356
+
357
+ kv_cache["global_end_index"] = local_end_index
358
+ kv_cache["local_end_index"] = local_end_index
359
+
360
+ padded_length = math.ceil(q.shape[1] / 128) * 128 - q.shape[1]
361
+ padded_roped_query = torch.cat(
362
+ [
363
+ roped_query,
364
+ torch.zeros(
365
+ [q.shape[0], padded_length, q.shape[2], q.shape[3]],
366
+ device=q.device,
367
+ dtype=v.dtype,
368
+ ),
369
+ ],
370
+ dim=1,
371
+ )
372
+
373
+ padded_roped_key = torch.cat(
374
+ [
375
+ roped_key,
376
+ torch.zeros(
377
+ [k.shape[0], padded_length, k.shape[2], k.shape[3]],
378
+ device=k.device,
379
+ dtype=v.dtype,
380
+ ),
381
+ ],
382
+ dim=1,
383
+ )
384
+ # print("shape of padded_roped_query", padded_roped_query.shape)
385
+ # print("shape of padded_roped_key", padded_roped_key.shape)
386
+
387
+ padded_v = torch.cat(
388
+ [
389
+ v,
390
+ torch.zeros(
391
+ [v.shape[0], padded_length, v.shape[2], v.shape[3]],
392
+ device=v.device,
393
+ dtype=v.dtype,
394
+ ),
395
+ ],
396
+ dim=1,
397
+ )
398
+
399
+ x = flex_attention(
400
+ query=padded_roped_query.transpose(2, 1).contiguous(),
401
+ key=padded_roped_key.transpose(2, 1).contiguous(),
402
+ value=padded_v.transpose(2, 1).contiguous(),
403
+ block_mask=block_mask,
404
+ kernel_options={
405
+ "BLOCKS_ARE_CONTIGUOUS": True,
406
+ },
407
+ )[:, :, :-padded_length].transpose(2, 1)
408
+ else:
409
+ # frame_seqlen = math.prod(grid_sizes[0][1:]).item() # torch compile doesn't like this
410
+ frame_seqlen = 1560
411
+ current_start_frame = current_start // frame_seqlen
412
+ roped_query = causal_rope_apply(
413
+ q, grid_sizes, freqs, start_frame=current_start_frame
414
+ ).type_as(v)
415
+ roped_key = causal_rope_apply(
416
+ k, grid_sizes, freqs, start_frame=current_start_frame
417
+ ).type_as(v)
418
+
419
+ current_end = current_start + roped_query.shape[1]
420
+ sink_tokens = self.sink_size * frame_seqlen
421
+ # If we are using local attention and the current KV cache size is larger than the local attention size, we need to truncate the KV cache
422
+ kv_cache_size = kv_cache["k"].shape[1]
423
+ num_new_tokens = roped_query.shape[1]
424
+ if (
425
+ self.local_attn_size != -1
426
+ and (current_end > kv_cache["global_end_index"])
427
+ and (num_new_tokens + kv_cache["local_end_index"] > kv_cache_size)
428
+ ):
429
+ # Calculate the number of new tokens added in this step
430
+ # Shift existing cache content left to discard oldest tokens
431
+ # Clone the source slice to avoid overlapping memory error
432
+ num_evicted_tokens = (
433
+ num_new_tokens + kv_cache["local_end_index"] - kv_cache_size
434
+ )
435
+ num_rolled_tokens = (
436
+ kv_cache["local_end_index"] - num_evicted_tokens - sink_tokens
437
+ )
438
+ kv_cache["k"][:, sink_tokens : sink_tokens + num_rolled_tokens] = (
439
+ kv_cache["k"][
440
+ :,
441
+ sink_tokens + num_evicted_tokens : sink_tokens
442
+ + num_evicted_tokens
443
+ + num_rolled_tokens,
444
+ ].clone()
445
+ )
446
+ kv_cache["v"][:, sink_tokens : sink_tokens + num_rolled_tokens] = (
447
+ kv_cache["v"][
448
+ :,
449
+ sink_tokens + num_evicted_tokens : sink_tokens
450
+ + num_evicted_tokens
451
+ + num_rolled_tokens,
452
+ ].clone()
453
+ )
454
+ # Insert the new keys/values at the end
455
+ local_end_index = (
456
+ kv_cache["local_end_index"]
457
+ + current_end
458
+ - kv_cache["global_end_index"]
459
+ - num_evicted_tokens
460
+ )
461
+ local_start_index = local_end_index - num_new_tokens
462
+ kv_cache["k"][:, local_start_index:local_end_index] = roped_key
463
+ kv_cache["v"][:, local_start_index:local_end_index] = v
464
+ else:
465
+ # Assign new keys/values directly up to current_end
466
+ local_end_index = (
467
+ kv_cache["local_end_index"]
468
+ + current_end
469
+ - kv_cache["global_end_index"]
470
+ )
471
+ local_start_index = local_end_index - num_new_tokens
472
+ kv_cache["k"][:, local_start_index:local_end_index] = roped_key
473
+ kv_cache["v"][:, local_start_index:local_end_index] = v
474
+
475
+ x = attention(
476
+ roped_query,
477
+ kv_cache["k"][
478
+ :,
479
+ max(0, local_end_index - self.max_attention_size) : local_end_index,
480
+ ],
481
+ kv_cache["v"][
482
+ :,
483
+ max(0, local_end_index - self.max_attention_size) : local_end_index,
484
+ ],
485
+ )
486
+ kv_cache["global_end_index"] = current_end
487
+ kv_cache["local_end_index"] = local_end_index
488
+
489
+ # output
490
+ x = x.flatten(2)
491
+ x = self.o(x)
492
+ return x
493
+
494
+
495
+ class CausalWanAttentionBlock(nn.Module):
496
+ def __init__(
497
+ self,
498
+ cross_attn_type,
499
+ dim,
500
+ ffn_dim,
501
+ num_heads,
502
+ local_attn_size=-1,
503
+ sink_size=0,
504
+ qk_norm=True,
505
+ cross_attn_norm=False,
506
+ eps=1e-6,
507
+ ):
508
+ super().__init__()
509
+ self.dim = dim
510
+ self.ffn_dim = ffn_dim
511
+ self.num_heads = num_heads
512
+ self.local_attn_size = local_attn_size
513
+ self.qk_norm = qk_norm
514
+ self.cross_attn_norm = cross_attn_norm
515
+ self.eps = eps
516
+
517
+ # layers
518
+ self.norm1 = WanLayerNorm(dim, eps)
519
+ self.self_attn = CausalWanSelfAttention(
520
+ dim, num_heads, local_attn_size, sink_size, qk_norm, eps
521
+ )
522
+ self.norm3 = (
523
+ WanLayerNorm(dim, eps, elementwise_affine=True)
524
+ if cross_attn_norm
525
+ else nn.Identity()
526
+ )
527
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](
528
+ dim, num_heads, (-1, -1), qk_norm, eps
529
+ )
530
+ self.norm2 = WanLayerNorm(dim, eps)
531
+ self.ffn = nn.Sequential(
532
+ nn.Linear(dim, ffn_dim),
533
+ nn.GELU(approximate="tanh"),
534
+ nn.Linear(ffn_dim, dim),
535
+ )
536
+
537
+ # modulation
538
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
539
+
540
+ def forward(
541
+ self,
542
+ x,
543
+ e,
544
+ seq_lens,
545
+ grid_sizes,
546
+ freqs,
547
+ context,
548
+ context_lens,
549
+ block_mask,
550
+ kv_cache=None,
551
+ crossattn_cache=None,
552
+ current_start=0,
553
+ cache_start=None,
554
+ ):
555
+ r"""
556
+ Args:
557
+ x(Tensor): Shape [B, L, C]
558
+ e(Tensor): Shape [B, F, 6, C]
559
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
560
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
561
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
562
+ """
563
+ num_frames, frame_seqlen = e.shape[1], x.shape[1] // e.shape[1]
564
+ # assert e.dtype == torch.float32
565
+ # with amp.autocast(dtype=torch.float32):
566
+ e = (self.modulation.unsqueeze(1) + e).chunk(6, dim=2)
567
+ # assert e[0].dtype == torch.float32
568
+
569
+ # self-attention
570
+ y = self.self_attn(
571
+ (
572
+ self.norm1(x).unflatten(dim=1, sizes=(num_frames, frame_seqlen))
573
+ * (1 + e[1])
574
+ + e[0]
575
+ ).flatten(1, 2),
576
+ seq_lens,
577
+ grid_sizes,
578
+ freqs,
579
+ block_mask,
580
+ kv_cache,
581
+ current_start,
582
+ cache_start,
583
+ )
584
+
585
+ # with amp.autocast(dtype=torch.float32):
586
+ x = x + (y.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * e[2]).flatten(
587
+ 1, 2
588
+ )
589
+
590
+ # cross-attention & ffn function
591
+ def cross_attn_ffn(x, context, context_lens, e, crossattn_cache=None):
592
+ x = x + self.cross_attn(
593
+ self.norm3(x), context, context_lens, crossattn_cache=crossattn_cache
594
+ )
595
+ y = self.ffn(
596
+ (
597
+ self.norm2(x).unflatten(dim=1, sizes=(num_frames, frame_seqlen))
598
+ * (1 + e[4])
599
+ + e[3]
600
+ ).flatten(1, 2)
601
+ )
602
+ # with amp.autocast(dtype=torch.float32):
603
+ x = x + (
604
+ y.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * e[5]
605
+ ).flatten(1, 2)
606
+ return x
607
+
608
+ x = cross_attn_ffn(x, context, context_lens, e, crossattn_cache)
609
+ return x
610
+
611
+
612
+ class CausalHead(nn.Module):
613
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
614
+ super().__init__()
615
+ self.dim = dim
616
+ self.out_dim = out_dim
617
+ self.patch_size = patch_size
618
+ self.eps = eps
619
+
620
+ # layers
621
+ out_dim = math.prod(patch_size) * out_dim
622
+ self.norm = WanLayerNorm(dim, eps)
623
+ self.head = nn.Linear(dim, out_dim)
624
+
625
+ # modulation
626
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
627
+
628
+ def forward(self, x, e):
629
+ r"""
630
+ Args:
631
+ x(Tensor): Shape [B, L1, C]
632
+ e(Tensor): Shape [B, F, 1, C]
633
+ """
634
+ # assert e.dtype == torch.float32
635
+ # with amp.autocast(dtype=torch.float32):
636
+ num_frames, frame_seqlen = e.shape[1], x.shape[1] // e.shape[1]
637
+ e = (self.modulation.unsqueeze(1) + e).chunk(2, dim=2)
638
+ x = self.head(
639
+ self.norm(x).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * (1 + e[1])
640
+ + e[0]
641
+ )
642
+ return x
643
+
644
+
645
+ class CausalWanModel(ModelMixin, ConfigMixin):
646
+ r"""
647
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
648
+ """
649
+
650
+ ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim"]
651
+ _no_split_modules = ["WanAttentionBlock"]
652
+ _supports_gradient_checkpointing = True
653
+
654
+ @register_to_config
655
+ def __init__(
656
+ self,
657
+ model_type="t2v",
658
+ patch_size=(1, 2, 2),
659
+ text_len=512,
660
+ in_dim=16,
661
+ dim=2048,
662
+ ffn_dim=8192,
663
+ freq_dim=256,
664
+ text_dim=4096,
665
+ out_dim=16,
666
+ num_heads=16,
667
+ num_layers=32,
668
+ local_attn_size=-1,
669
+ sink_size=0,
670
+ qk_norm=True,
671
+ cross_attn_norm=True,
672
+ eps=1e-6,
673
+ ):
674
+ r"""
675
+ Initialize the diffusion model backbone.
676
+
677
+ Args:
678
+ model_type (`str`, *optional*, defaults to 't2v'):
679
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
680
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
681
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
682
+ text_len (`int`, *optional*, defaults to 512):
683
+ Fixed length for text embeddings
684
+ in_dim (`int`, *optional*, defaults to 16):
685
+ Input video channels (C_in)
686
+ dim (`int`, *optional*, defaults to 2048):
687
+ Hidden dimension of the transformer
688
+ ffn_dim (`int`, *optional*, defaults to 8192):
689
+ Intermediate dimension in feed-forward network
690
+ freq_dim (`int`, *optional*, defaults to 256):
691
+ Dimension for sinusoidal time embeddings
692
+ text_dim (`int`, *optional*, defaults to 4096):
693
+ Input dimension for text embeddings
694
+ out_dim (`int`, *optional*, defaults to 16):
695
+ Output video channels (C_out)
696
+ num_heads (`int`, *optional*, defaults to 16):
697
+ Number of attention heads
698
+ num_layers (`int`, *optional*, defaults to 32):
699
+ Number of transformer blocks
700
+ local_attn_size (`int`, *optional*, defaults to -1):
701
+ Window size for temporal local attention (-1 indicates global attention)
702
+ sink_size (`int`, *optional*, defaults to 0):
703
+ Size of the attention sink, we keep the first `sink_size` frames unchanged when rolling the KV cache
704
+ qk_norm (`bool`, *optional*, defaults to True):
705
+ Enable query/key normalization
706
+ cross_attn_norm (`bool`, *optional*, defaults to False):
707
+ Enable cross-attention normalization
708
+ eps (`float`, *optional*, defaults to 1e-6):
709
+ Epsilon value for normalization layers
710
+ """
711
+
712
+ super().__init__()
713
+
714
+ assert model_type in ["t2v", "i2v"]
715
+ self.model_type = model_type
716
+
717
+ self.patch_size = patch_size
718
+ self.text_len = text_len
719
+ self.in_dim = in_dim
720
+ self.dim = dim
721
+ self.ffn_dim = ffn_dim
722
+ self.freq_dim = freq_dim
723
+ self.text_dim = text_dim
724
+ self.out_dim = out_dim
725
+ self.num_heads = num_heads
726
+ self.num_layers = num_layers
727
+ self.local_attn_size = local_attn_size
728
+ self.qk_norm = qk_norm
729
+ self.cross_attn_norm = cross_attn_norm
730
+ self.eps = eps
731
+
732
+ # embeddings
733
+ self.patch_embedding = nn.Conv3d(
734
+ in_dim, dim, kernel_size=patch_size, stride=patch_size
735
+ )
736
+ self.text_embedding = nn.Sequential(
737
+ nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
738
+ )
739
+
740
+ self.time_embedding = nn.Sequential(
741
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
742
+ )
743
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
744
+
745
+ # blocks
746
+ cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn"
747
+ self.blocks = nn.ModuleList(
748
+ [
749
+ CausalWanAttentionBlock(
750
+ cross_attn_type,
751
+ dim,
752
+ ffn_dim,
753
+ num_heads,
754
+ local_attn_size,
755
+ sink_size,
756
+ qk_norm,
757
+ cross_attn_norm,
758
+ eps,
759
+ )
760
+ for _ in range(num_layers)
761
+ ]
762
+ )
763
+
764
+ # head
765
+ self.head = CausalHead(dim, out_dim, patch_size, eps)
766
+
767
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
768
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
769
+ d = dim // num_heads
770
+ self.freqs = torch.cat(
771
+ [
772
+ rope_params(1024, d - 4 * (d // 6)),
773
+ # rope_params_riflex(1024, d - 4 * (d // 6), ),
774
+ rope_params(1024, 2 * (d // 6)),
775
+ rope_params(1024, 2 * (d // 6)),
776
+ ],
777
+ dim=1,
778
+ )
779
+
780
+ if model_type == "i2v":
781
+ self.img_emb = MLPProj(1280, dim)
782
+
783
+ # initialize weights
784
+ self.init_weights()
785
+
786
+ self.gradient_checkpointing = False
787
+
788
+ self.block_mask = None
789
+
790
+ self.num_frame_per_block = 1
791
+ self.independent_first_frame = False
792
+
793
+ def _set_gradient_checkpointing(self, module, value=False):
794
+ self.gradient_checkpointing = value
795
+
796
+ @staticmethod
797
+ def _prepare_blockwise_causal_attn_mask(
798
+ device,
799
+ num_frames: int = 21,
800
+ frame_seqlen: int = 1560,
801
+ num_frame_per_block=1,
802
+ local_attn_size=-1,
803
+ ) -> BlockMask:
804
+ """
805
+ we will divide the token sequence into the following format
806
+ [1 latent frame] [1 latent frame] ... [1 latent frame]
807
+ We use flexattention to construct the attention mask
808
+ """
809
+ block_mask = get_block_mask(
810
+ str(device), num_frames, frame_seqlen, num_frame_per_block, local_attn_size
811
+ )
812
+ return block_mask
813
+
814
+ @staticmethod
815
+ def _prepare_teacher_forcing_mask(
816
+ device: torch.device | str,
817
+ num_frames: int = 21,
818
+ frame_seqlen: int = 1560,
819
+ num_frame_per_block=1,
820
+ ) -> BlockMask:
821
+ """
822
+ we will divide the token sequence into the following format
823
+ [1 latent frame] [1 latent frame] ... [1 latent frame]
824
+ We use flexattention to construct the attention mask
825
+ """
826
+ # debug
827
+ DEBUG = False
828
+ if DEBUG:
829
+ num_frames = 9
830
+ frame_seqlen = 256
831
+
832
+ total_length = num_frames * frame_seqlen * 2
833
+
834
+ # we do right padding to get to a multiple of 128
835
+ padded_length = math.ceil(total_length / 128) * 128 - total_length
836
+
837
+ clean_ends = num_frames * frame_seqlen
838
+ # for clean context frames, we can construct their flex attention mask based on a [start, end] interval
839
+ context_ends = torch.zeros(
840
+ total_length + padded_length, device=device, dtype=torch.long
841
+ )
842
+ # for noisy frames, we need two intervals to construct the flex attention mask [context_start, context_end] [noisy_start, noisy_end]
843
+ noise_context_starts = torch.zeros(
844
+ total_length + padded_length, device=device, dtype=torch.long
845
+ )
846
+ noise_context_ends = torch.zeros(
847
+ total_length + padded_length, device=device, dtype=torch.long
848
+ )
849
+ noise_noise_starts = torch.zeros(
850
+ total_length + padded_length, device=device, dtype=torch.long
851
+ )
852
+ noise_noise_ends = torch.zeros(
853
+ total_length + padded_length, device=device, dtype=torch.long
854
+ )
855
+
856
+ # Block-wise causal mask will attend to all elements that are before the end of the current chunk
857
+ attention_block_size = frame_seqlen * num_frame_per_block
858
+ frame_indices = torch.arange(
859
+ start=0,
860
+ end=num_frames * frame_seqlen,
861
+ step=attention_block_size,
862
+ device=device,
863
+ dtype=torch.long,
864
+ )
865
+
866
+ # attention for clean context frames
867
+ for start in frame_indices:
868
+ context_ends[start : start + attention_block_size] = (
869
+ start + attention_block_size
870
+ )
871
+
872
+ noisy_image_start_list = torch.arange(
873
+ num_frames * frame_seqlen,
874
+ total_length,
875
+ step=attention_block_size,
876
+ device=device,
877
+ dtype=torch.long,
878
+ )
879
+ noisy_image_end_list = noisy_image_start_list + attention_block_size
880
+
881
+ # attention for noisy frames
882
+ for block_index, (start, end) in enumerate(
883
+ zip(noisy_image_start_list, noisy_image_end_list)
884
+ ):
885
+ # attend to noisy tokens within the same block
886
+ noise_noise_starts[start:end] = start
887
+ noise_noise_ends[start:end] = end
888
+ # attend to context tokens in previous blocks
889
+ # noise_context_starts[start:end] = 0
890
+ noise_context_ends[start:end] = block_index * attention_block_size
891
+
892
+ def attention_mask(b, h, q_idx, kv_idx):
893
+ # first design the mask for clean frames
894
+ clean_mask = (q_idx < clean_ends) & (kv_idx < context_ends[q_idx])
895
+ # then design the mask for noisy frames
896
+ # noisy frames will attend to all clean preceeding clean frames + itself
897
+ C1 = (kv_idx < noise_noise_ends[q_idx]) & (
898
+ kv_idx >= noise_noise_starts[q_idx]
899
+ )
900
+ C2 = (kv_idx < noise_context_ends[q_idx]) & (
901
+ kv_idx >= noise_context_starts[q_idx]
902
+ )
903
+ noise_mask = (q_idx >= clean_ends) & (C1 | C2)
904
+
905
+ eye_mask = q_idx == kv_idx
906
+ return eye_mask | clean_mask | noise_mask
907
+
908
+ block_mask = create_block_mask(
909
+ attention_mask,
910
+ B=None,
911
+ H=None,
912
+ Q_LEN=total_length + padded_length,
913
+ KV_LEN=total_length + padded_length,
914
+ _compile=False,
915
+ device=device,
916
+ )
917
+
918
+ if DEBUG:
919
+ print(block_mask)
920
+ import imageio
921
+ import numpy as np
922
+ from torch.nn.attention.flex_attention import create_mask
923
+
924
+ mask = create_mask(
925
+ attention_mask,
926
+ B=None,
927
+ H=None,
928
+ Q_LEN=total_length + padded_length,
929
+ KV_LEN=total_length + padded_length,
930
+ device=device,
931
+ )
932
+ import cv2
933
+
934
+ mask = cv2.resize(mask[0, 0].cpu().float().numpy(), (1024, 1024))
935
+ imageio.imwrite("mask_%d.jpg" % (0), np.uint8(255.0 * mask))
936
+
937
+ return block_mask
938
+
939
+ @staticmethod
940
+ def _prepare_blockwise_causal_attn_mask_i2v(
941
+ device: torch.device | str,
942
+ num_frames: int = 21,
943
+ frame_seqlen: int = 1560,
944
+ num_frame_per_block=4,
945
+ local_attn_size=-1,
946
+ ) -> BlockMask:
947
+ """
948
+ we will divide the token sequence into the following format
949
+ [1 latent frame] [N latent frame] ... [N latent frame]
950
+ The first frame is separated out to support I2V generation
951
+ We use flexattention to construct the attention mask
952
+ """
953
+ total_length = num_frames * frame_seqlen
954
+
955
+ # we do right padding to get to a multiple of 128
956
+ padded_length = math.ceil(total_length / 128) * 128 - total_length
957
+
958
+ ends = torch.zeros(
959
+ total_length + padded_length, device=device, dtype=torch.long
960
+ )
961
+
962
+ # special handling for the first frame
963
+ ends[:frame_seqlen] = frame_seqlen
964
+
965
+ # Block-wise causal mask will attend to all elements that are before the end of the current chunk
966
+ frame_indices = torch.arange(
967
+ start=frame_seqlen,
968
+ end=total_length,
969
+ step=frame_seqlen * num_frame_per_block,
970
+ device=device,
971
+ )
972
+
973
+ for idx, tmp in enumerate(frame_indices):
974
+ ends[tmp : tmp + frame_seqlen * num_frame_per_block] = (
975
+ tmp + frame_seqlen * num_frame_per_block
976
+ )
977
+
978
+ def attention_mask(b, h, q_idx, kv_idx):
979
+ if local_attn_size == -1:
980
+ return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
981
+ else:
982
+ return (
983
+ (kv_idx < ends[q_idx])
984
+ & (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen))
985
+ ) | (q_idx == kv_idx)
986
+
987
+ block_mask = create_block_mask(
988
+ attention_mask,
989
+ B=None,
990
+ H=None,
991
+ Q_LEN=total_length + padded_length,
992
+ KV_LEN=total_length + padded_length,
993
+ _compile=False,
994
+ device=device,
995
+ )
996
+
997
+ # if not dist.is_initialized() or dist.get_rank() == 0:
998
+ # print(
999
+ # f" cache a block wise causal mask with block size of {num_frame_per_block} frames")
1000
+ # print(block_mask)
1001
+
1002
+ # import imageio
1003
+ # import numpy as np
1004
+ # from torch.nn.attention.flex_attention import create_mask
1005
+
1006
+ # mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length +
1007
+ # padded_length, KV_LEN=total_length + padded_length, device=device)
1008
+ # import cv2
1009
+ # mask = cv2.resize(mask[0, 0].cpu().float().numpy(), (1024, 1024))
1010
+ # imageio.imwrite("mask_%d.jpg" % (0), np.uint8(255. * mask))
1011
+
1012
+ return block_mask
1013
+
1014
+ def _forward_inference(
1015
+ self,
1016
+ x,
1017
+ t,
1018
+ context,
1019
+ seq_len,
1020
+ clip_fea=None,
1021
+ y=None,
1022
+ kv_cache: dict = None,
1023
+ crossattn_cache: dict = None,
1024
+ current_start: int = 0,
1025
+ cache_start: int = 0,
1026
+ ):
1027
+ r"""
1028
+ Run the diffusion model with kv caching.
1029
+ See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details.
1030
+ This function is called num_frame times, processing the latent
+ frames one by one (1560 tokens each).
1032
+
1033
+ Args:
1034
+ x (List[Tensor]):
1035
+ List of input video tensors, each with shape [C_in, F, H, W]
1036
+ t (Tensor):
1037
+ Diffusion timesteps tensor of shape [B]
1038
+ context (List[Tensor]):
1039
+ List of text embeddings each with shape [L, C]
1040
+ seq_len (`int`):
1041
+ Maximum sequence length for positional encoding
1042
+ clip_fea (Tensor, *optional*):
1043
+ CLIP image features for image-to-video mode
1044
+ y (List[Tensor], *optional*):
1045
+ Conditional video inputs for image-to-video mode, same shape as x
1046
+
1047
+ Returns:
1048
+ List[Tensor]:
1049
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
1050
+ """
1051
+ if self.model_type == "i2v":
1052
+ assert clip_fea is not None and y is not None
1053
+ # params
1054
+ device = self.patch_embedding.weight.device
1055
+ if self.freqs.device != device:
1056
+ self.freqs = self.freqs.to(device)
1057
+
1058
+ if y is not None:
1059
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
1060
+
1061
+ # embeddings
1062
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
1063
+ grid_sizes = torch.stack(
1064
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
1065
+ )
1066
+ x = [u.flatten(2).transpose(1, 2) for u in x]
1067
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
1068
+ assert seq_lens.max() <= seq_len
1069
+ x = torch.cat(x)
1070
+ """
1071
+ torch.cat([
1072
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
1073
+ dim=1) for u in x
1074
+ ])
1075
+ """
1076
+
1077
+ # time embeddings
1078
+ # with amp.autocast(dtype=torch.float32):
1079
+ e = self.time_embedding(
1080
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x)
1081
+ )
1082
+ e0 = (
1083
+ self.time_projection(e)
1084
+ .unflatten(1, (6, self.dim))
1085
+ .unflatten(dim=0, sizes=t.shape)
1086
+ )
1087
+ # assert e.dtype == torch.float32 and e0.dtype == torch.float32
1088
+
1089
+ # context
1090
+ context_lens = None
1091
+ context = self.text_embedding(
1092
+ torch.stack(
1093
+ [
1094
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
1095
+ for u in context
1096
+ ]
1097
+ )
1098
+ )
1099
+
1100
+ if clip_fea is not None:
1101
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
1102
+ context = torch.concat([context_clip, context], dim=1)
1103
+
1104
+ # arguments
1105
+ kwargs = dict(
1106
+ e=e0,
1107
+ seq_lens=seq_lens,
1108
+ grid_sizes=grid_sizes,
1109
+ freqs=self.freqs,
1110
+ context=context,
1111
+ context_lens=context_lens,
1112
+ block_mask=self.block_mask,
1113
+ )
1114
+ # print("Block mask in forward : ", self.block_mask)
1115
+
1116
+ def create_custom_forward(module):
1117
+ def custom_forward(*inputs, **kwargs):
1118
+ return module(*inputs, **kwargs)
1119
+
1120
+ return custom_forward
1121
+
1122
+ for block_index, block in enumerate(self.blocks):
1123
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1124
+ kwargs.update(
1125
+ {
1126
+ "kv_cache": kv_cache[block_index],
1127
+ "current_start": current_start,
1128
+ "cache_start": cache_start,
1129
+ }
1130
+ )
1131
+ x = torch.utils.checkpoint.checkpoint(
1132
+ create_custom_forward(block),
1133
+ x,
1134
+ **kwargs,
1135
+ use_reentrant=False,
1136
+ )
1137
+ else:
1138
+ kwargs.update(
1139
+ {
1140
+ "kv_cache": kv_cache[block_index],
1141
+ "crossattn_cache": crossattn_cache[block_index],
1142
+ "current_start": current_start,
1143
+ "cache_start": cache_start,
1144
+ }
1145
+ )
1146
+ x = block(x, **kwargs)
1147
+
1148
+ # head
1149
+ x = self.head(x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2))
1150
+ # unpatchify
1151
+ x = self.unpatchify(x, grid_sizes)
1152
+ return torch.stack(x)
1153
+
1154
+ def _forward_train(
1155
+ self,
1156
+ x,
1157
+ t,
1158
+ context,
1159
+ seq_len,
1160
+ clean_x=None,
1161
+ aug_t=None,
1162
+ clip_fea=None,
1163
+ y=None,
1164
+ ):
1165
+ r"""
1166
+ Forward pass through the diffusion model
1167
+
1168
+ Args:
1169
+ x (List[Tensor]):
1170
+ List of input video tensors, each with shape [C_in, F, H, W]
1171
+ t (Tensor):
1172
+ Diffusion timesteps tensor of shape [B]
1173
+ context (List[Tensor]):
1174
+ List of text embeddings each with shape [L, C]
1175
+ seq_len (`int`):
1176
+ Maximum sequence length for positional encoding
1177
+ clip_fea (Tensor, *optional*):
1178
+ CLIP image features for image-to-video mode
1179
+ y (List[Tensor], *optional*):
1180
+ Conditional video inputs for image-to-video mode, same shape as x
1181
+
1182
+ Returns:
1183
+ List[Tensor]:
1184
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
1185
+ """
1186
+ if self.model_type == "i2v":
1187
+ assert clip_fea is not None and y is not None
1188
+ # params
1189
+ device = self.patch_embedding.weight.device
1190
+ if self.freqs.device != device:
1191
+ self.freqs = self.freqs.to(device)
1192
+
1193
+ # Construct blockwise causal attn mask
1194
+ if self.block_mask is None:
1195
+ if clean_x is not None:
1196
+ if self.independent_first_frame:
1197
+ raise NotImplementedError()
1198
+ else:
1199
+ self.block_mask = self._prepare_teacher_forcing_mask(
1200
+ device,
1201
+ num_frames=x.shape[2],
1202
+ frame_seqlen=x.shape[-2]
1203
+ * x.shape[-1]
1204
+ // (self.patch_size[1] * self.patch_size[2]),
1205
+ num_frame_per_block=self.num_frame_per_block,
1206
+ )
1207
+ else:
1208
+ if self.independent_first_frame:
1209
+ self.block_mask = self._prepare_blockwise_causal_attn_mask_i2v(
1210
+ device,
1211
+ num_frames=x.shape[2],
1212
+ frame_seqlen=x.shape[-2]
1213
+ * x.shape[-1]
1214
+ // (self.patch_size[1] * self.patch_size[2]),
1215
+ num_frame_per_block=self.num_frame_per_block,
1216
+ local_attn_size=self.local_attn_size,
1217
+ )
1218
+ else:
1219
+ self.block_mask = self._prepare_blockwise_causal_attn_mask(
1220
+ device,
1221
+ num_frames=x.shape[2],
1222
+ frame_seqlen=x.shape[-2]
1223
+ * x.shape[-1]
1224
+ // (self.patch_size[1] * self.patch_size[2]),
1225
+ num_frame_per_block=self.num_frame_per_block,
1226
+ local_attn_size=self.local_attn_size,
1227
+ )
1228
+
1229
+ if y is not None:
1230
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
1231
+
1232
+ # embeddings
1233
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
1234
+
1235
+ grid_sizes = torch.stack(
1236
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
1237
+ )
1238
+ x = [u.flatten(2).transpose(1, 2) for u in x]
1239
+
1240
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
1241
+ assert seq_lens.max() <= seq_len
1242
+ x = torch.cat(
1243
+ [
1244
+ torch.cat(
1245
+ [u, u.new_zeros(1, seq_lens[0] - u.size(1), u.size(2))], dim=1
1246
+ )
1247
+ for u in x
1248
+ ]
1249
+ )
1250
+
1251
+ # time embeddings
1252
+ # with amp.autocast(dtype=torch.float32):
1253
+ e = self.time_embedding(
1254
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x)
1255
+ )
1256
+ e0 = (
1257
+ self.time_projection(e)
1258
+ .unflatten(1, (6, self.dim))
1259
+ .unflatten(dim=0, sizes=t.shape)
1260
+ )
1261
+ # assert e.dtype == torch.float32 and e0.dtype == torch.float32
1262
+
1263
+ # context
1264
+ context_lens = None
1265
+ context = self.text_embedding(
1266
+ torch.stack(
1267
+ [
1268
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
1269
+ for u in context
1270
+ ]
1271
+ )
1272
+ )
1273
+
1274
+ if clip_fea is not None:
1275
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
1276
+ context = torch.concat([context_clip, context], dim=1)
1277
+
1278
+ if clean_x is not None:
1279
+ clean_x = [self.patch_embedding(u.unsqueeze(0)) for u in clean_x]
1280
+ clean_x = [u.flatten(2).transpose(1, 2) for u in clean_x]
1281
+
1282
+ seq_lens_clean = torch.tensor(
1283
+ [u.size(1) for u in clean_x], dtype=torch.long
1284
+ )
1285
+ assert seq_lens_clean.max() <= seq_len
1286
+ clean_x = torch.cat(
1287
+ [
1288
+ torch.cat(
1289
+ [u, u.new_zeros(1, seq_lens_clean[0] - u.size(1), u.size(2))],
1290
+ dim=1,
1291
+ )
1292
+ for u in clean_x
1293
+ ]
1294
+ )
1295
+
1296
+ x = torch.cat([clean_x, x], dim=1)
1297
+ if aug_t is None:
1298
+ aug_t = torch.zeros_like(t)
1299
+ e_clean = self.time_embedding(
1300
+ sinusoidal_embedding_1d(self.freq_dim, aug_t.flatten()).type_as(x)
1301
+ )
1302
+ e0_clean = (
1303
+ self.time_projection(e_clean)
1304
+ .unflatten(1, (6, self.dim))
1305
+ .unflatten(dim=0, sizes=t.shape)
1306
+ )
1307
+ e0 = torch.cat([e0_clean, e0], dim=1)
1308
+
1309
+ # arguments
1310
+ kwargs = dict(
1311
+ e=e0,
1312
+ seq_lens=seq_lens,
1313
+ grid_sizes=grid_sizes,
1314
+ freqs=self.freqs,
1315
+ context=context,
1316
+ context_lens=context_lens,
1317
+ block_mask=self.block_mask,
1318
+ )
1319
+
1320
+ def create_custom_forward(module):
1321
+ def custom_forward(*inputs, **kwargs):
1322
+ return module(*inputs, **kwargs)
1323
+
1324
+ return custom_forward
1325
+
1326
+ for block in self.blocks:
1327
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1328
+ x = torch.utils.checkpoint.checkpoint(
1329
+ create_custom_forward(block),
1330
+ x,
1331
+ **kwargs,
1332
+ use_reentrant=False,
1333
+ )
1334
+ else:
1335
+ x = block(x, **kwargs)
1336
+
1337
+ if clean_x is not None:
1338
+ x = x[:, x.shape[1] // 2 :]
1339
+
1340
+ # head
1341
+ x = self.head(x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2))
1342
+
1343
+ # unpatchify
1344
+ x = self.unpatchify(x, grid_sizes)
1345
+ return torch.stack(x)
1346
+
1347
+ def forward(self, *args, **kwargs):
1348
+ result = self._forward_inference(*args, **kwargs)
1349
+ # if kwargs.get('kv_cache', None) is not None:
+ #     result = self._forward_inference(*args, **kwargs)
+ # else:
+ #     result = self._forward_train(*args, **kwargs)
1352
+
1353
+ return result
1354
+
1355
+ def unpatchify(self, x, grid_sizes):
1356
+ r"""
1357
+ Reconstruct video tensors from patch embeddings.
1358
+
1359
+ Args:
1360
+ x (List[Tensor]):
1361
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
1362
+ grid_sizes (Tensor):
1363
+ Original spatial-temporal grid dimensions before patching,
1364
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
1365
+
1366
+ Returns:
1367
+ List[Tensor]:
1368
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
1369
+ """
1370
+
1371
+ c = self.out_dim
1372
+ out = []
1373
+ for u, v in zip(x, grid_sizes.tolist()):
1374
+ u = u[: math.prod(v)].view(*v, *self.patch_size, c)
1375
+ u = torch.einsum("fhwpqrc->cfphqwr", u)
1376
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
1377
+ out.append(u)
1378
+ return out
1379
+
1380
+ def init_weights(self):
1381
+ r"""
1382
+ Initialize model parameters using Xavier initialization.
1383
+ """
1384
+
1385
+ # basic init
1386
+ for m in self.modules():
1387
+ if isinstance(m, nn.Linear):
1388
+ nn.init.xavier_uniform_(m.weight)
1389
+ if m.bias is not None:
1390
+ nn.init.zeros_(m.bias)
1391
+
1392
+ # init embeddings
1393
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
1394
+ for m in self.text_embedding.modules():
1395
+ if isinstance(m, nn.Linear):
1396
+ nn.init.normal_(m.weight, std=0.02)
1397
+ for m in self.time_embedding.modules():
1398
+ if isinstance(m, nn.Linear):
1399
+ nn.init.normal_(m.weight, std=0.02)
1400
+
1401
+ # init output layer
1402
+ nn.init.zeros_(self.head.head.weight)
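
The block-wise causal masks in `causal_model.py` above are built with FlexAttention's `create_block_mask`: each query token may attend to every key up to the end of its own frame block (`kv_idx < ends[q_idx]`), plus itself. The sketch below, which is not part of the repository code, materialises that same rule as a dense boolean mask at toy sizes so the pattern is easy to inspect; the frame counts and block size are illustrative only.

```py
import torch

# Toy version of the `local_attn_size == -1` branch above: 4 latent frames of
# 3 tokens each, generated in blocks of 2 frames. ends[i] is the (exclusive)
# last token index that query token i may attend to.
num_frames, frame_seqlen, num_frame_per_block = 4, 3, 2
total = num_frames * frame_seqlen

ends = torch.zeros(total, dtype=torch.long)
for start in range(0, total, frame_seqlen * num_frame_per_block):
    stop = min(start + frame_seqlen * num_frame_per_block, total)
    ends[start:stop] = stop

q_idx = torch.arange(total).unsqueeze(1)   # [total, 1]
kv_idx = torch.arange(total).unsqueeze(0)  # [1, total]
mask = (kv_idx < ends[q_idx]) | (q_idx == kv_idx)

print(mask.int())  # tokens 0-5 attend within frames 0-1; tokens 6-11 attend to frames 0-3
```

At real scale (1560 tokens per latent frame) this mask is never materialised densely; the `create_block_mask` calls above return a sparse `BlockMask` that FlexAttention consumes directly.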
transformer/config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "_class_name": "CausalWanModel",
+   "_diffusers_version": "0.36.0.dev0",
+   "auto_map": {
+     "AutoModel": "causal_model.CausalWanModel"
+   },
+   "dim": 5120,
+   "eps": 1e-06,
+   "ffn_dim": 13824,
+   "freq_dim": 256,
+   "in_dim": 16,
+   "local_attn_size": -1,
+   "model_type": "t2v",
+   "num_heads": 40,
+   "num_layers": 40,
+   "out_dim": 16,
+   "sink_size": 0,
+   "text_len": 512
+ }
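
The `auto_map` entry in this config ties the checkpoint to the custom `CausalWanModel` implementation in `causal_model.py`. Below is a minimal, hypothetical sketch of loading just the transformer through `diffusers.AutoModel`; it assumes a diffusers build recent enough to resolve remote code via `auto_map` (the config above was written against 0.36.0.dev0), and `trust_remote_code=True` is required because the class is defined in this repository rather than in diffusers itself.

```py
import torch
from diffusers import AutoModel

# AutoModel follows the "auto_map" entry above and imports
# causal_model.CausalWanModel from the repository (remote code).
transformer = AutoModel.from_pretrained(
    "krea/krea-realtime-video",
    subfolder="transformer",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
print(f"{sum(p.numel() for p in transformer.parameters()) / 1e9:.1f}B parameters")
```

For end-to-end generation the modular pipeline remains the intended entry point; loading the bare transformer is mainly useful for inspecting the architecture or wiring it into custom inference code.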
transformer/diffusion_pytorch_model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5c179cb7e91005fe6e009bfb42df4ed70316f03bbc35d33e303021b33b564791
+ size 9968228976
transformer/diffusion_pytorch_model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a3e0f19e177dd8d83e244281ffdada88c150e080a9503dfd0f78f5acfe63563a
+ size 9891538864
transformer/diffusion_pytorch_model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:15c15fb4b8b8a181ff23b142566389344fb34973dddc8c49b4ff0dee29db2735
+ size 8717326272
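
The three `.safetensors` entries above are Git LFS pointer files: each stores only the SHA-256 digest (`oid`) and byte size of the real shard, which is fetched separately (for example with `huggingface-cli download`). The snippet below is an illustrative integrity check of a downloaded shard against its pointer, using the values of the first shard; the local path is an example.

```py
import hashlib
from pathlib import Path

# Values copied from the LFS pointer of shard 1 of 3 above.
shard = Path("transformer/diffusion_pytorch_model-00001-of-00003.safetensors")
expected_sha256 = "5c179cb7e91005fe6e009bfb42df4ed70316f03bbc35d33e303021b33b564791"
expected_size = 9968228976

digest = hashlib.sha256()
with shard.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        digest.update(chunk)

assert shard.stat().st_size == expected_size, "size mismatch"
assert digest.hexdigest() == expected_sha256, "checksum mismatch"
print("shard matches its LFS pointer")
```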
transformer/diffusion_pytorch_model.safetensors.index.json ADDED
@@ -0,0 +1,1102 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 28576983168
4
+ },
5
+ "weight_map": {
6
+ "blocks.0.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
7
+ "blocks.0.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
8
+ "blocks.0.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
9
+ "blocks.0.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
10
+ "blocks.0.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
11
+ "blocks.0.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
12
+ "blocks.0.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
13
+ "blocks.0.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
14
+ "blocks.0.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
15
+ "blocks.0.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
16
+ "blocks.0.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
17
+ "blocks.0.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
18
+ "blocks.0.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
19
+ "blocks.0.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
20
+ "blocks.0.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
21
+ "blocks.0.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
22
+ "blocks.0.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
23
+ "blocks.0.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
24
+ "blocks.0.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
25
+ "blocks.0.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
26
+ "blocks.0.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
27
+ "blocks.0.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
28
+ "blocks.0.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
29
+ "blocks.0.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
30
+ "blocks.0.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
31
+ "blocks.0.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
32
+ "blocks.0.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
33
+ "blocks.1.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
34
+ "blocks.1.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
35
+ "blocks.1.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
36
+ "blocks.1.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
37
+ "blocks.1.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
38
+ "blocks.1.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
39
+ "blocks.1.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
40
+ "blocks.1.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
41
+ "blocks.1.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
42
+ "blocks.1.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
43
+ "blocks.1.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
44
+ "blocks.1.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
45
+ "blocks.1.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
46
+ "blocks.1.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
47
+ "blocks.1.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
48
+ "blocks.1.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
49
+ "blocks.1.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
50
+ "blocks.1.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
51
+ "blocks.1.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
52
+ "blocks.1.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
53
+ "blocks.1.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
54
+ "blocks.1.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
55
+ "blocks.1.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
56
+ "blocks.1.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
57
+ "blocks.1.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
58
+ "blocks.1.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
59
+ "blocks.1.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
60
+ "blocks.10.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
61
+ "blocks.10.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
62
+ "blocks.10.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
63
+ "blocks.10.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
64
+ "blocks.10.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
65
+ "blocks.10.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
66
+ "blocks.10.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
67
+ "blocks.10.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
68
+ "blocks.10.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
69
+ "blocks.10.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
70
+ "blocks.10.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
71
+ "blocks.10.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
72
+ "blocks.10.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
73
+ "blocks.10.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
74
+ "blocks.10.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
75
+ "blocks.10.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
76
+ "blocks.10.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
77
+ "blocks.10.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
78
+ "blocks.10.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
79
+ "blocks.10.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
80
+ "blocks.10.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
81
+ "blocks.10.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
82
+ "blocks.10.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
83
+ "blocks.10.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
84
+ "blocks.10.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
85
+ "blocks.10.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
86
+ "blocks.10.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
87
+ "blocks.11.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
88
+ "blocks.11.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
89
+ "blocks.11.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
90
+ "blocks.11.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
91
+ "blocks.11.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
92
+ "blocks.11.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
93
+ "blocks.11.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
94
+ "blocks.11.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
95
+ "blocks.11.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
96
+ "blocks.11.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
97
+ "blocks.11.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
98
+ "blocks.11.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
99
+ "blocks.11.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
100
+ "blocks.11.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
101
+ "blocks.11.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
102
+ "blocks.11.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
103
+ "blocks.11.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
104
+ "blocks.11.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
105
+ "blocks.11.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
106
+ "blocks.11.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
107
+ "blocks.11.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
108
+ "blocks.11.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
109
+ "blocks.11.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
110
+ "blocks.11.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
111
+ "blocks.11.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
112
+ "blocks.11.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
113
+ "blocks.11.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
114
+ "blocks.12.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
115
+ "blocks.12.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
116
+ "blocks.12.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
117
+ "blocks.12.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
118
+ "blocks.12.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
119
+ "blocks.12.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
120
+ "blocks.12.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
121
+ "blocks.12.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
122
+ "blocks.12.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
123
+ "blocks.12.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
124
+ "blocks.12.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
125
+ "blocks.12.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
126
+ "blocks.12.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
127
+ "blocks.12.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
128
+ "blocks.12.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
129
+ "blocks.12.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
130
+ "blocks.12.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
131
+ "blocks.12.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
132
+ "blocks.12.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
133
+ "blocks.12.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
134
+ "blocks.12.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
135
+ "blocks.12.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
136
+ "blocks.12.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
137
+ "blocks.12.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
138
+ "blocks.12.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
139
+ "blocks.12.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
140
+ "blocks.12.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
141
+ "blocks.13.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
142
+ "blocks.13.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
143
+ "blocks.13.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
144
+ "blocks.13.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
145
+ "blocks.13.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
146
+ "blocks.13.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
147
+ "blocks.13.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
148
+ "blocks.13.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
149
+ "blocks.13.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
150
+ "blocks.13.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
151
+ "blocks.13.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
152
+ "blocks.13.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
153
+ "blocks.13.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
154
+ "blocks.13.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
155
+ "blocks.13.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
156
+ "blocks.13.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
157
+ "blocks.13.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
158
+ "blocks.13.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
159
+ "blocks.13.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
160
+ "blocks.13.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
161
+ "blocks.13.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
162
+ "blocks.13.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
163
+ "blocks.13.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
164
+ "blocks.13.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
165
+ "blocks.13.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
166
+ "blocks.13.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
167
+ "blocks.13.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
168
+ "blocks.14.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
169
+ "blocks.14.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
170
+ "blocks.14.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
171
+ "blocks.14.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
172
+ "blocks.14.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
173
+ "blocks.14.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
174
+ "blocks.14.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
175
+ "blocks.14.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
176
+ "blocks.14.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
177
+ "blocks.14.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
178
+ "blocks.14.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
179
+ "blocks.14.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
180
+ "blocks.14.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
181
+ "blocks.14.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
182
+ "blocks.14.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
183
+ "blocks.14.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
184
+ "blocks.14.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
185
+ "blocks.14.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
186
+ "blocks.14.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
187
+ "blocks.14.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
188
+ "blocks.14.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
189
+ "blocks.14.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
190
+ "blocks.14.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
191
+ "blocks.14.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
192
+ "blocks.14.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
193
+ "blocks.14.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
194
+ "blocks.14.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
195
+ "blocks.15.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
196
+ "blocks.15.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
197
+ "blocks.15.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
198
+ "blocks.15.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
199
+ "blocks.15.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
200
+ "blocks.15.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
201
+ "blocks.15.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
202
+ "blocks.15.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
203
+ "blocks.15.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
204
+ "blocks.15.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
205
+ "blocks.15.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
206
+ "blocks.15.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
207
+ "blocks.15.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
208
+ "blocks.15.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
209
+ "blocks.15.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
210
+ "blocks.15.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
211
+ "blocks.15.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
212
+ "blocks.15.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
213
+ "blocks.15.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
214
+ "blocks.15.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
215
+ "blocks.15.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
216
+ "blocks.15.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
217
+ "blocks.15.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
218
+ "blocks.15.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
219
+ "blocks.15.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
220
+ "blocks.15.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
221
+ "blocks.15.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
222
+ "blocks.16.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
223
+ "blocks.16.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
224
+ "blocks.16.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
225
+ "blocks.16.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
226
+ "blocks.16.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
227
+ "blocks.16.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
228
+ "blocks.16.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
229
+ "blocks.16.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
230
+ "blocks.16.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
231
+ "blocks.16.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
232
+ "blocks.16.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
233
+ "blocks.16.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
234
+ "blocks.16.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
235
+ "blocks.16.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
236
+ "blocks.16.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
237
+ "blocks.16.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
238
+ "blocks.16.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
239
+ "blocks.16.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
240
+ "blocks.16.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
241
+ "blocks.16.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
242
+ "blocks.16.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
243
+ "blocks.16.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
244
+ "blocks.16.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
245
+ "blocks.16.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
246
+ "blocks.16.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
247
+ "blocks.16.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
248
+ "blocks.16.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
249
+ "blocks.17.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
250
+ "blocks.17.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
251
+ "blocks.17.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
252
+ "blocks.17.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
253
+ "blocks.17.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
254
+ "blocks.17.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
255
+ "blocks.17.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
256
+ "blocks.17.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
257
+ "blocks.17.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
258
+ "blocks.17.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
259
+ "blocks.17.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
260
+ "blocks.17.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
261
+ "blocks.17.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
262
+ "blocks.17.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
263
+ "blocks.17.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
264
+ "blocks.17.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
265
+ "blocks.17.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
266
+ "blocks.17.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
267
+ "blocks.17.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
268
+ "blocks.17.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
269
+ "blocks.17.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
270
+ "blocks.17.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
271
+ "blocks.17.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
272
+ "blocks.17.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
273
+ "blocks.17.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
274
+ "blocks.17.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
275
+ "blocks.17.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
276
+ "blocks.18.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
277
+ "blocks.18.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
278
+ "blocks.18.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
279
+ "blocks.18.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
280
+ "blocks.18.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
281
+ "blocks.18.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
282
+ "blocks.18.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
283
+ "blocks.18.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
284
+ "blocks.18.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
285
+ "blocks.18.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
286
+ "blocks.18.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
287
+ "blocks.18.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
288
+ "blocks.18.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
289
+ "blocks.18.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
290
+ "blocks.18.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
291
+ "blocks.18.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
292
+ "blocks.18.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
293
+ "blocks.18.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
294
+ "blocks.18.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
295
+ "blocks.18.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
296
+ "blocks.18.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
297
+ "blocks.18.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
298
+ "blocks.18.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
299
+ "blocks.18.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
300
+ "blocks.18.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
301
+ "blocks.18.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
302
+ "blocks.18.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
303
+ "blocks.19.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
304
+ "blocks.19.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
305
+ "blocks.19.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
306
+ "blocks.19.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
307
+ "blocks.19.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
308
+ "blocks.19.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
309
+ "blocks.19.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
310
+ "blocks.19.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
311
+ "blocks.19.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
312
+ "blocks.19.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
313
+ "blocks.19.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
314
+ "blocks.19.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
315
+ "blocks.19.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
316
+ "blocks.19.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
317
+ "blocks.19.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
318
+ "blocks.19.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
319
+ "blocks.19.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
320
+ "blocks.19.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
321
+ "blocks.19.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
322
+ "blocks.19.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
323
+ "blocks.19.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
324
+ "blocks.19.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
325
+ "blocks.19.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
326
+ "blocks.19.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
327
+ "blocks.19.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
328
+ "blocks.19.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
329
+ "blocks.19.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
330
+ "blocks.2.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
331
+ "blocks.2.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
332
+ "blocks.2.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
333
+ "blocks.2.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
334
+ "blocks.2.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
335
+ "blocks.2.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
336
+ "blocks.2.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
337
+ "blocks.2.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
338
+ "blocks.2.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
339
+ "blocks.2.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
340
+ "blocks.2.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
341
+ "blocks.2.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
342
+ "blocks.2.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
343
+ "blocks.2.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
344
+ "blocks.2.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
345
+ "blocks.2.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
346
+ "blocks.2.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
347
+ "blocks.2.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
348
+ "blocks.2.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
349
+ "blocks.2.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
350
+ "blocks.2.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
351
+ "blocks.2.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
352
+ "blocks.2.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
353
+ "blocks.2.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
354
+ "blocks.2.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
355
+ "blocks.2.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
356
+ "blocks.2.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
357
+ "blocks.20.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
358
+ "blocks.20.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
359
+ "blocks.20.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
360
+ "blocks.20.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
361
+ "blocks.20.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
362
+ "blocks.20.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
363
+ "blocks.20.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
364
+ "blocks.20.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
365
+ "blocks.20.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
366
+ "blocks.20.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
367
+ "blocks.20.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
368
+ "blocks.20.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
369
+ "blocks.20.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
370
+ "blocks.20.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
371
+ "blocks.20.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
372
+ "blocks.20.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
373
+ "blocks.20.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
374
+ "blocks.20.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
375
+ "blocks.20.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
376
+ "blocks.20.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
377
+ "blocks.20.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
378
+ "blocks.20.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
379
+ "blocks.20.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
380
+ "blocks.20.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
381
+ "blocks.20.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
382
+ "blocks.20.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
383
+ "blocks.20.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
384
+ "blocks.21.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
385
+ "blocks.21.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
386
+ "blocks.21.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
387
+ "blocks.21.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
388
+ "blocks.21.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
389
+ "blocks.21.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
390
+ "blocks.21.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
391
+ "blocks.21.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
392
+ "blocks.21.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
393
+ "blocks.21.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
394
+ "blocks.21.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
395
+ "blocks.21.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
396
+ "blocks.21.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
397
+ "blocks.21.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
398
+ "blocks.21.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
399
+ "blocks.21.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
400
+ "blocks.21.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
401
+ "blocks.21.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
402
+ "blocks.21.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
403
+ "blocks.21.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
404
+ "blocks.21.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
405
+ "blocks.21.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
406
+ "blocks.21.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
407
+ "blocks.21.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
408
+ "blocks.21.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
409
+ "blocks.21.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
410
+ "blocks.21.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
411
+ "blocks.22.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
412
+ "blocks.22.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
413
+ "blocks.22.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
414
+ "blocks.22.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
415
+ "blocks.22.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
416
+ "blocks.22.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
417
+ "blocks.22.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
418
+ "blocks.22.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
419
+ "blocks.22.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
420
+ "blocks.22.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
421
+ "blocks.22.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
422
+ "blocks.22.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
423
+ "blocks.22.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
424
+ "blocks.22.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
425
+ "blocks.22.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
426
+ "blocks.22.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
427
+ "blocks.22.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
428
+ "blocks.22.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
429
+ "blocks.22.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
430
+ "blocks.22.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
431
+ "blocks.22.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
432
+ "blocks.22.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
433
+ "blocks.22.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
434
+ "blocks.22.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
435
+ "blocks.22.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
436
+ "blocks.22.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
437
+ "blocks.22.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
438
+ "blocks.23.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
439
+ "blocks.23.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
440
+ "blocks.23.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
441
+ "blocks.23.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
442
+ "blocks.23.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
443
+ "blocks.23.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
444
+ "blocks.23.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
445
+ "blocks.23.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
446
+ "blocks.23.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
447
+ "blocks.23.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
448
+ "blocks.23.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
449
+ "blocks.23.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
450
+ "blocks.23.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
451
+ "blocks.23.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
452
+ "blocks.23.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
453
+ "blocks.23.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
454
+ "blocks.23.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
455
+ "blocks.23.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
456
+ "blocks.23.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
457
+ "blocks.23.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
458
+ "blocks.23.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
459
+ "blocks.23.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
460
+ "blocks.23.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
461
+ "blocks.23.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
462
+ "blocks.23.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
463
+ "blocks.23.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
464
+ "blocks.23.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
465
+ "blocks.24.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
466
+ "blocks.24.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
467
+ "blocks.24.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
468
+ "blocks.24.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
469
+ "blocks.24.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
470
+ "blocks.24.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
471
+ "blocks.24.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
472
+ "blocks.24.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
473
+ "blocks.24.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
474
+ "blocks.24.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
475
+ "blocks.24.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
476
+ "blocks.24.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
477
+ "blocks.24.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
478
+ "blocks.24.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
479
+ "blocks.24.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
480
+ "blocks.24.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
481
+ "blocks.24.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
482
+ "blocks.24.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
483
+ "blocks.24.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
484
+ "blocks.24.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
485
+ "blocks.24.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
486
+ "blocks.24.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
487
+ "blocks.24.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
488
+ "blocks.24.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
489
+ "blocks.24.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
490
+ "blocks.24.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
491
+ "blocks.24.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
492
+ "blocks.25.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
493
+ "blocks.25.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
494
+ "blocks.25.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
495
+ "blocks.25.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
496
+ "blocks.25.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
497
+ "blocks.25.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
498
+ "blocks.25.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
499
+ "blocks.25.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
500
+ "blocks.25.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
501
+ "blocks.25.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
502
+ "blocks.25.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
503
+ "blocks.25.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
504
+ "blocks.25.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
505
+ "blocks.25.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
506
+ "blocks.25.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
507
+ "blocks.25.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
508
+ "blocks.25.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
509
+ "blocks.25.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
510
+ "blocks.25.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
511
+ "blocks.25.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
512
+ "blocks.25.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
513
+ "blocks.25.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
514
+ "blocks.25.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
515
+ "blocks.25.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
516
+ "blocks.25.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
517
+ "blocks.25.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
518
+ "blocks.25.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
519
+ "blocks.26.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
520
+ "blocks.26.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
521
+ "blocks.26.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
522
+ "blocks.26.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
523
+ "blocks.26.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
524
+ "blocks.26.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
525
+ "blocks.26.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
526
+ "blocks.26.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
527
+ "blocks.26.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
528
+ "blocks.26.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
529
+ "blocks.26.ffn.0.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
530
+ "blocks.26.ffn.0.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
531
+ "blocks.26.ffn.2.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
532
+ "blocks.26.ffn.2.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
533
+ "blocks.26.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
534
+ "blocks.26.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
535
+ "blocks.26.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
536
+ "blocks.26.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
537
+ "blocks.26.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
538
+ "blocks.26.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
539
+ "blocks.26.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
540
+ "blocks.26.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
541
+ "blocks.26.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
542
+ "blocks.26.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
543
+ "blocks.26.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
544
+ "blocks.26.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
545
+ "blocks.26.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
546
+ "blocks.27.cross_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
547
+ "blocks.27.cross_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
548
+ "blocks.27.cross_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
549
+ "blocks.27.cross_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
550
+ "blocks.27.cross_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
551
+ "blocks.27.cross_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
552
+ "blocks.27.cross_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
553
+ "blocks.27.cross_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
554
+ "blocks.27.cross_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
555
+ "blocks.27.cross_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
556
+ "blocks.27.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
557
+ "blocks.27.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
558
+ "blocks.27.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
559
+ "blocks.27.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
560
+ "blocks.27.modulation": "diffusion_pytorch_model-00002-of-00003.safetensors",
561
+ "blocks.27.norm3.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
562
+ "blocks.27.norm3.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
563
+ "blocks.27.self_attn.k.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
564
+ "blocks.27.self_attn.k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
565
+ "blocks.27.self_attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
566
+ "blocks.27.self_attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
567
+ "blocks.27.self_attn.o.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
568
+ "blocks.27.self_attn.o.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
569
+ "blocks.27.self_attn.q.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
570
+ "blocks.27.self_attn.q.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
571
+ "blocks.27.self_attn.v.bias": "diffusion_pytorch_model-00002-of-00003.safetensors",
572
+ "blocks.27.self_attn.v.weight": "diffusion_pytorch_model-00002-of-00003.safetensors",
573
+ "blocks.28.cross_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
574
+ "blocks.28.cross_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
575
+ "blocks.28.cross_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
576
+ "blocks.28.cross_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
577
+ "blocks.28.cross_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
578
+ "blocks.28.cross_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
579
+ "blocks.28.cross_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
580
+ "blocks.28.cross_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
581
+ "blocks.28.cross_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
582
+ "blocks.28.cross_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
583
+ "blocks.28.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
584
+ "blocks.28.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
585
+ "blocks.28.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
586
+ "blocks.28.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
587
+ "blocks.28.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
588
+ "blocks.28.norm3.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
589
+ "blocks.28.norm3.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
590
+ "blocks.28.self_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
591
+ "blocks.28.self_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
592
+ "blocks.28.self_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
593
+ "blocks.28.self_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
594
+ "blocks.28.self_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
595
+ "blocks.28.self_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
596
+ "blocks.28.self_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
597
+ "blocks.28.self_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
598
+ "blocks.28.self_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
599
+ "blocks.28.self_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
600
+ "blocks.29.cross_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
601
+ "blocks.29.cross_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
602
+ "blocks.29.cross_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
603
+ "blocks.29.cross_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
604
+ "blocks.29.cross_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
605
+ "blocks.29.cross_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
606
+ "blocks.29.cross_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
607
+ "blocks.29.cross_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
608
+ "blocks.29.cross_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
609
+ "blocks.29.cross_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
610
+ "blocks.29.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
611
+ "blocks.29.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
612
+ "blocks.29.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
613
+ "blocks.29.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
614
+ "blocks.29.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
615
+ "blocks.29.norm3.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
616
+ "blocks.29.norm3.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
617
+ "blocks.29.self_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
618
+ "blocks.29.self_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
619
+ "blocks.29.self_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
620
+ "blocks.29.self_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
621
+ "blocks.29.self_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
622
+ "blocks.29.self_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
623
+ "blocks.29.self_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
624
+ "blocks.29.self_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
625
+ "blocks.29.self_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
626
+ "blocks.29.self_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
627
+ "blocks.3.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
628
+ "blocks.3.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
629
+ "blocks.3.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
630
+ "blocks.3.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
631
+ "blocks.3.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
632
+ "blocks.3.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
633
+ "blocks.3.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
634
+ "blocks.3.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
635
+ "blocks.3.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
636
+ "blocks.3.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
637
+ "blocks.3.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
638
+ "blocks.3.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
639
+ "blocks.3.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
640
+ "blocks.3.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
641
+ "blocks.3.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
642
+ "blocks.3.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
643
+ "blocks.3.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
644
+ "blocks.3.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
645
+ "blocks.3.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
646
+ "blocks.3.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
647
+ "blocks.3.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
648
+ "blocks.3.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
649
+ "blocks.3.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
650
+ "blocks.3.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
651
+ "blocks.3.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
652
+ "blocks.3.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
653
+ "blocks.3.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
654
+ "blocks.30.cross_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
655
+ "blocks.30.cross_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
656
+ "blocks.30.cross_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
657
+ "blocks.30.cross_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
658
+ "blocks.30.cross_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
659
+ "blocks.30.cross_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
660
+ "blocks.30.cross_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
661
+ "blocks.30.cross_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
662
+ "blocks.30.cross_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
663
+ "blocks.30.cross_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
664
+ "blocks.30.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
665
+ "blocks.30.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
666
+ "blocks.30.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
667
+ "blocks.30.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
668
+ "blocks.30.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
669
+ "blocks.30.norm3.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
670
+ "blocks.30.norm3.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
671
+ "blocks.30.self_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
672
+ "blocks.30.self_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
673
+ "blocks.30.self_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
674
+ "blocks.30.self_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
675
+ "blocks.30.self_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
676
+ "blocks.30.self_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
677
+ "blocks.30.self_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
678
+ "blocks.30.self_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
679
+ "blocks.30.self_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
680
+ "blocks.30.self_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
681
+ "blocks.31.cross_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
682
+ "blocks.31.cross_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
683
+ "blocks.31.cross_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
684
+ "blocks.31.cross_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
685
+ "blocks.31.cross_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
686
+ "blocks.31.cross_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
687
+ "blocks.31.cross_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
688
+ "blocks.31.cross_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
689
+ "blocks.31.cross_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
690
+ "blocks.31.cross_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
691
+ "blocks.31.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
692
+ "blocks.31.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
693
+ "blocks.31.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
694
+ "blocks.31.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
695
+ "blocks.31.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
696
+ "blocks.31.norm3.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
697
+ "blocks.31.norm3.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
698
+ "blocks.31.self_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
699
+ "blocks.31.self_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
700
+ "blocks.31.self_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
701
+ "blocks.31.self_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
702
+ "blocks.31.self_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
703
+ "blocks.31.self_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
704
+ "blocks.31.self_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
705
+ "blocks.31.self_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
706
+ "blocks.31.self_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
707
+ "blocks.31.self_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
708
+ "blocks.32.cross_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
709
+ "blocks.32.cross_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
710
+ "blocks.32.cross_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
711
+ "blocks.32.cross_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
712
+ "blocks.32.cross_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
713
+ "blocks.32.cross_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
714
+ "blocks.32.cross_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
715
+ "blocks.32.cross_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
716
+ "blocks.32.cross_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
717
+ "blocks.32.cross_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
718
+ "blocks.32.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
719
+ "blocks.32.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
720
+ "blocks.32.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
721
+ "blocks.32.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
722
+ "blocks.32.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
723
+ "blocks.32.norm3.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
724
+ "blocks.32.norm3.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
725
+ "blocks.32.self_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
726
+ "blocks.32.self_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
727
+ "blocks.32.self_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
728
+ "blocks.32.self_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
729
+ "blocks.32.self_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
730
+ "blocks.32.self_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
731
+ "blocks.32.self_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
732
+ "blocks.32.self_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
733
+ "blocks.32.self_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
734
+ "blocks.32.self_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
735
+ "blocks.33.cross_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
736
+ "blocks.33.cross_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
737
+ "blocks.33.cross_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
738
+ "blocks.33.cross_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
739
+ "blocks.33.cross_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
740
+ "blocks.33.cross_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
741
+ "blocks.33.cross_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
742
+ "blocks.33.cross_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
743
+ "blocks.33.cross_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
744
+ "blocks.33.cross_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
745
+ "blocks.33.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
746
+ "blocks.33.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
747
+ "blocks.33.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
748
+ "blocks.33.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
749
+ "blocks.33.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
750
+ "blocks.33.norm3.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
751
+ "blocks.33.norm3.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
752
+ "blocks.33.self_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
753
+ "blocks.33.self_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
754
+ "blocks.33.self_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
755
+ "blocks.33.self_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
756
+ "blocks.33.self_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
757
+ "blocks.33.self_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
758
+ "blocks.33.self_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
759
+ "blocks.33.self_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
760
+ "blocks.33.self_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
761
+ "blocks.33.self_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
762
+ "blocks.34.cross_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
763
+ "blocks.34.cross_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
764
+ "blocks.34.cross_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
765
+ "blocks.34.cross_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
766
+ "blocks.34.cross_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
767
+ "blocks.34.cross_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
768
+ "blocks.34.cross_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
769
+ "blocks.34.cross_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
770
+ "blocks.34.cross_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
771
+ "blocks.34.cross_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
772
+ "blocks.34.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
773
+ "blocks.34.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
774
+ "blocks.34.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
775
+ "blocks.34.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
776
+ "blocks.34.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
777
+ "blocks.34.norm3.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
778
+ "blocks.34.norm3.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
779
+ "blocks.34.self_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
780
+ "blocks.34.self_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
781
+ "blocks.34.self_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
782
+ "blocks.34.self_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
783
+ "blocks.34.self_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
784
+ "blocks.34.self_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
785
+ "blocks.34.self_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
786
+ "blocks.34.self_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
787
+ "blocks.34.self_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
788
+ "blocks.34.self_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
789
+ "blocks.35.cross_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
790
+ "blocks.35.cross_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
791
+ "blocks.35.cross_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
792
+ "blocks.35.cross_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
793
+ "blocks.35.cross_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
794
+ "blocks.35.cross_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
795
+ "blocks.35.cross_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
796
+ "blocks.35.cross_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
797
+ "blocks.35.cross_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
798
+ "blocks.35.cross_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
799
+ "blocks.35.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
800
+ "blocks.35.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
801
+ "blocks.35.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
802
+ "blocks.35.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
803
+ "blocks.35.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
804
+ "blocks.35.norm3.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
805
+ "blocks.35.norm3.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
806
+ "blocks.35.self_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
807
+ "blocks.35.self_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
808
+ "blocks.35.self_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
809
+ "blocks.35.self_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
810
+ "blocks.35.self_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
811
+ "blocks.35.self_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
812
+ "blocks.35.self_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
813
+ "blocks.35.self_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
814
+ "blocks.35.self_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
815
+ "blocks.35.self_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
816
+ "blocks.36.cross_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
817
+ "blocks.36.cross_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
818
+ "blocks.36.cross_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
819
+ "blocks.36.cross_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
820
+ "blocks.36.cross_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
821
+ "blocks.36.cross_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
822
+ "blocks.36.cross_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
823
+ "blocks.36.cross_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
824
+ "blocks.36.cross_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
825
+ "blocks.36.cross_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
826
+ "blocks.36.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
827
+ "blocks.36.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
828
+ "blocks.36.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
829
+ "blocks.36.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
830
+ "blocks.36.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
831
+ "blocks.36.norm3.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
832
+ "blocks.36.norm3.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
833
+ "blocks.36.self_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
834
+ "blocks.36.self_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
835
+ "blocks.36.self_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
836
+ "blocks.36.self_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
837
+ "blocks.36.self_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
838
+ "blocks.36.self_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
839
+ "blocks.36.self_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
840
+ "blocks.36.self_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
841
+ "blocks.36.self_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
842
+ "blocks.36.self_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
843
+ "blocks.37.cross_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
844
+ "blocks.37.cross_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
845
+ "blocks.37.cross_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
846
+ "blocks.37.cross_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
847
+ "blocks.37.cross_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
848
+ "blocks.37.cross_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
849
+ "blocks.37.cross_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
850
+ "blocks.37.cross_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
851
+ "blocks.37.cross_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
852
+ "blocks.37.cross_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
853
+ "blocks.37.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
854
+ "blocks.37.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
855
+ "blocks.37.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
856
+ "blocks.37.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
857
+ "blocks.37.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
858
+ "blocks.37.norm3.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
859
+ "blocks.37.norm3.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
860
+ "blocks.37.self_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
861
+ "blocks.37.self_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
862
+ "blocks.37.self_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
863
+ "blocks.37.self_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
864
+ "blocks.37.self_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
865
+ "blocks.37.self_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
866
+ "blocks.37.self_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
867
+ "blocks.37.self_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
868
+ "blocks.37.self_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
869
+ "blocks.37.self_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
870
+ "blocks.38.cross_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
871
+ "blocks.38.cross_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
872
+ "blocks.38.cross_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
873
+ "blocks.38.cross_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
874
+ "blocks.38.cross_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
875
+ "blocks.38.cross_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
876
+ "blocks.38.cross_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
877
+ "blocks.38.cross_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
878
+ "blocks.38.cross_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
879
+ "blocks.38.cross_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
880
+ "blocks.38.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
881
+ "blocks.38.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
882
+ "blocks.38.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
883
+ "blocks.38.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
884
+ "blocks.38.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
885
+ "blocks.38.norm3.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
886
+ "blocks.38.norm3.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
887
+ "blocks.38.self_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
888
+ "blocks.38.self_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
889
+ "blocks.38.self_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
890
+ "blocks.38.self_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
891
+ "blocks.38.self_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
892
+ "blocks.38.self_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
893
+ "blocks.38.self_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
894
+ "blocks.38.self_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
895
+ "blocks.38.self_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
896
+ "blocks.38.self_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
897
+ "blocks.39.cross_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
898
+ "blocks.39.cross_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
899
+ "blocks.39.cross_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
900
+ "blocks.39.cross_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
901
+ "blocks.39.cross_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
902
+ "blocks.39.cross_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
903
+ "blocks.39.cross_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
904
+ "blocks.39.cross_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
905
+ "blocks.39.cross_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
906
+ "blocks.39.cross_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
907
+ "blocks.39.ffn.0.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
908
+ "blocks.39.ffn.0.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
909
+ "blocks.39.ffn.2.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
910
+ "blocks.39.ffn.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
911
+ "blocks.39.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
912
+ "blocks.39.norm3.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
913
+ "blocks.39.norm3.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
914
+ "blocks.39.self_attn.k.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
915
+ "blocks.39.self_attn.k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
916
+ "blocks.39.self_attn.norm_k.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
917
+ "blocks.39.self_attn.norm_q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
918
+ "blocks.39.self_attn.o.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
919
+ "blocks.39.self_attn.o.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
920
+ "blocks.39.self_attn.q.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
921
+ "blocks.39.self_attn.q.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
922
+ "blocks.39.self_attn.v.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
923
+ "blocks.39.self_attn.v.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
924
+ "blocks.4.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
925
+ "blocks.4.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
926
+ "blocks.4.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
927
+ "blocks.4.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
928
+ "blocks.4.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
929
+ "blocks.4.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
930
+ "blocks.4.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
931
+ "blocks.4.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
932
+ "blocks.4.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
933
+ "blocks.4.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
934
+ "blocks.4.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
935
+ "blocks.4.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
936
+ "blocks.4.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
937
+ "blocks.4.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
938
+ "blocks.4.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
939
+ "blocks.4.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
940
+ "blocks.4.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
941
+ "blocks.4.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
942
+ "blocks.4.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
943
+ "blocks.4.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
944
+ "blocks.4.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
945
+ "blocks.4.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
946
+ "blocks.4.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
947
+ "blocks.4.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
948
+ "blocks.4.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
949
+ "blocks.4.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
950
+ "blocks.4.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
951
+ "blocks.5.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
952
+ "blocks.5.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
953
+ "blocks.5.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
954
+ "blocks.5.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
955
+ "blocks.5.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
956
+ "blocks.5.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
957
+ "blocks.5.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
958
+ "blocks.5.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
959
+ "blocks.5.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
960
+ "blocks.5.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
961
+ "blocks.5.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
962
+ "blocks.5.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
963
+ "blocks.5.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
964
+ "blocks.5.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
965
+ "blocks.5.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
966
+ "blocks.5.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
967
+ "blocks.5.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
968
+ "blocks.5.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
969
+ "blocks.5.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
970
+ "blocks.5.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
971
+ "blocks.5.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
972
+ "blocks.5.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
973
+ "blocks.5.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
974
+ "blocks.5.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
975
+ "blocks.5.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
976
+ "blocks.5.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
977
+ "blocks.5.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
978
+ "blocks.6.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
979
+ "blocks.6.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
980
+ "blocks.6.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
981
+ "blocks.6.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
982
+ "blocks.6.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
983
+ "blocks.6.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
984
+ "blocks.6.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
985
+ "blocks.6.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
986
+ "blocks.6.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
987
+ "blocks.6.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
988
+ "blocks.6.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
989
+ "blocks.6.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
990
+ "blocks.6.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
991
+ "blocks.6.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
992
+ "blocks.6.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
993
+ "blocks.6.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
994
+ "blocks.6.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
995
+ "blocks.6.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
996
+ "blocks.6.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
997
+ "blocks.6.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
998
+ "blocks.6.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
999
+ "blocks.6.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1000
+ "blocks.6.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1001
+ "blocks.6.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1002
+ "blocks.6.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1003
+ "blocks.6.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1004
+ "blocks.6.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1005
+ "blocks.7.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1006
+ "blocks.7.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1007
+ "blocks.7.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1008
+ "blocks.7.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1009
+ "blocks.7.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1010
+ "blocks.7.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1011
+ "blocks.7.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1012
+ "blocks.7.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1013
+ "blocks.7.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1014
+ "blocks.7.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1015
+ "blocks.7.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1016
+ "blocks.7.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1017
+ "blocks.7.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1018
+ "blocks.7.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1019
+ "blocks.7.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
1020
+ "blocks.7.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1021
+ "blocks.7.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1022
+ "blocks.7.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1023
+ "blocks.7.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1024
+ "blocks.7.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1025
+ "blocks.7.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1026
+ "blocks.7.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1027
+ "blocks.7.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1028
+ "blocks.7.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1029
+ "blocks.7.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1030
+ "blocks.7.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1031
+ "blocks.7.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1032
+ "blocks.8.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1033
+ "blocks.8.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1034
+ "blocks.8.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1035
+ "blocks.8.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1036
+ "blocks.8.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1037
+ "blocks.8.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1038
+ "blocks.8.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1039
+ "blocks.8.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1040
+ "blocks.8.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1041
+ "blocks.8.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1042
+ "blocks.8.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1043
+ "blocks.8.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1044
+ "blocks.8.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1045
+ "blocks.8.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1046
+ "blocks.8.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
1047
+ "blocks.8.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1048
+ "blocks.8.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1049
+ "blocks.8.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1050
+ "blocks.8.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1051
+ "blocks.8.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1052
+ "blocks.8.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1053
+ "blocks.8.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1054
+ "blocks.8.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1055
+ "blocks.8.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1056
+ "blocks.8.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1057
+ "blocks.8.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1058
+ "blocks.8.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1059
+ "blocks.9.cross_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1060
+ "blocks.9.cross_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1061
+ "blocks.9.cross_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1062
+ "blocks.9.cross_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1063
+ "blocks.9.cross_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1064
+ "blocks.9.cross_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1065
+ "blocks.9.cross_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1066
+ "blocks.9.cross_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1067
+ "blocks.9.cross_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1068
+ "blocks.9.cross_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1069
+ "blocks.9.ffn.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1070
+ "blocks.9.ffn.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1071
+ "blocks.9.ffn.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1072
+ "blocks.9.ffn.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1073
+ "blocks.9.modulation": "diffusion_pytorch_model-00001-of-00003.safetensors",
1074
+ "blocks.9.norm3.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1075
+ "blocks.9.norm3.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1076
+ "blocks.9.self_attn.k.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1077
+ "blocks.9.self_attn.k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1078
+ "blocks.9.self_attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1079
+ "blocks.9.self_attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1080
+ "blocks.9.self_attn.o.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1081
+ "blocks.9.self_attn.o.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1082
+ "blocks.9.self_attn.q.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1083
+ "blocks.9.self_attn.q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1084
+ "blocks.9.self_attn.v.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1085
+ "blocks.9.self_attn.v.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1086
+ "head.head.bias": "diffusion_pytorch_model-00003-of-00003.safetensors",
1087
+ "head.head.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
1088
+ "head.modulation": "diffusion_pytorch_model-00003-of-00003.safetensors",
1089
+ "patch_embedding.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1090
+ "patch_embedding.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1091
+ "text_embedding.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1092
+ "text_embedding.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1093
+ "text_embedding.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1094
+ "text_embedding.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1095
+ "time_embedding.0.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1096
+ "time_embedding.0.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1097
+ "time_embedding.2.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1098
+ "time_embedding.2.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
1099
+ "time_projection.1.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
1100
+ "time_projection.1.weight": "diffusion_pytorch_model-00001-of-00003.safetensors"
1101
+ }
1102
+ }
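
The `weight_map` above assigns every transformer parameter to one of three sharded `.safetensors` files. As a minimal sketch of how such a sharded checkpoint can be reassembled by hand — assuming the shards and their index were downloaded into a local `transformer/` folder and that the index uses the conventional `diffusion_pytorch_model.safetensors.index.json` filename, neither of which is shown in this excerpt:

```py
import json
from pathlib import Path

from safetensors.torch import load_file

# Assumed local layout; adjust to wherever the shards were downloaded.
shard_dir = Path("transformer")
index_path = shard_dir / "diffusion_pytorch_model.safetensors.index.json"

with open(index_path) as f:
    index = json.load(f)

# Load each shard once and merge everything into a single state dict keyed by parameter name.
state_dict = {}
for shard_name in sorted(set(index["weight_map"].values())):
    state_dict.update(load_file(shard_dir / shard_name, device="cpu"))

print(f"{len(state_dict)} tensors loaded")  # e.g. blocks.9.self_attn.q.weight, head.head.bias, ...
```

In normal use `diffusers` resolves this index automatically when loading the model, so merging shards by hand is only needed for custom tooling.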
transformer/model.py ADDED
@@ -0,0 +1,1002 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.modeling_utils import ModelMixin
8
+ from einops import repeat
9
+
10
+ from .attention import (
11
+ flash_attention,
12
+ sageattn_func,
13
+ _SAGEATTN_AVAILABLE,
14
+ _FLASH_ATTN_2_AVAILABLE,
15
+ _FLASH_ATTN_3_AVAILABLE,
16
+ )
17
+
18
+ print("SAGEATTN_AVAILABLE:", _SAGEATTN_AVAILABLE)
19
+
20
+ __all__ = ["WanModel"]
21
+
22
+
23
+ def sinusoidal_embedding_1d(dim, position):
24
+ # preprocess
25
+ assert dim % 2 == 0
26
+ half = dim // 2
27
+ position = position.type(torch.float64)
28
+
29
+ # calculation
30
+ sinusoid = torch.outer(
31
+ position,
32
+ torch.pow(
33
+ 10000,
34
+ -torch.arange(
35
+ half, device=torch.cuda.current_device(), dtype=torch.float64
36
+ ).div(half),
37
+ ),
38
+ )
39
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
40
+ return x
41
+
42
+
43
+ # @amp.autocast(enabled=False)
44
+ def rope_params(max_seq_len, dim, theta=10000):
45
+ assert dim % 2 == 0
46
+ freqs = torch.outer(
47
+ torch.arange(max_seq_len),
48
+ 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
49
+ )
50
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
51
+ return freqs
52
+
53
+
54
+ # @amp.autocast(enabled=False)
55
+ def rope_apply(x, grid_sizes, freqs):
56
+ n, c = x.size(2), x.size(3) // 2
57
+
58
+ # split freqs
59
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
60
+
61
+ # loop over samples
62
+ output = []
63
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
64
+ seq_len = f * h * w
65
+
66
+ # precompute multipliers
67
+ x_i = torch.view_as_complex(
68
+ x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
69
+ )
70
+ freqs_i = torch.cat(
71
+ [
72
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
73
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
74
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
75
+ ],
76
+ dim=-1,
77
+ ).reshape(seq_len, 1, -1)
78
+
79
+ # apply rotary embedding
80
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
81
+ x_i = torch.cat([x_i, x[i, seq_len:]])
82
+
83
+ # append to collection
84
+ output.append(x_i)
85
+ return torch.stack(output).type_as(x)
86
+
87
+
88
+ class WanRMSNorm(nn.Module):
89
+ def __init__(self, dim, eps=1e-5):
90
+ super().__init__()
91
+ self.dim = dim
92
+ self.eps = eps
93
+ self.weight = nn.Parameter(torch.ones(dim))
94
+
95
+ def forward(self, x):
96
+ r"""
97
+ Args:
98
+ x(Tensor): Shape [B, L, C]
99
+ """
100
+ return self._norm(x.float()).type_as(x) * self.weight
101
+
102
+ def _norm(self, x):
103
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
104
+
105
+
106
+ class WanLayerNorm(nn.LayerNorm):
107
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
108
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
109
+
110
+ def forward(self, x):
111
+ r"""
112
+ Args:
113
+ x(Tensor): Shape [B, L, C]
114
+ """
115
+ return super().forward(x).type_as(x)
116
+
117
+
118
+ class WanSelfAttention(nn.Module):
119
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
120
+ assert dim % num_heads == 0
121
+ super().__init__()
122
+ self.dim = dim
123
+ self.num_heads = num_heads
124
+ self.head_dim = dim // num_heads
125
+ self.window_size = window_size
126
+ self.qk_norm = qk_norm
127
+ self.eps = eps
128
+
129
+ # layers
130
+ self.q = nn.Linear(dim, dim)
131
+ self.k = nn.Linear(dim, dim)
132
+ self.v = nn.Linear(dim, dim)
133
+ self.o = nn.Linear(dim, dim)
134
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
135
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
136
+
137
+ def forward(self, x, seq_lens, grid_sizes, freqs):
138
+ r"""
139
+ Args:
140
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
141
+ seq_lens(Tensor): Shape [B]
142
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
143
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
144
+ """
145
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
146
+
147
+ # query, key, value function
148
+ def qkv_fn(x):
149
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
150
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
151
+ v = self.v(x).view(b, s, n, d)
152
+ return q, k, v
153
+
154
+ q, k, v = qkv_fn(x)
155
+
156
+ if _SAGEATTN_AVAILABLE:
+ # sageattention expects fp16/bf16 inputs: cast to bf16 here (mirroring the
+ # cross-attention path) and restore the original dtype after attention
+ dtype = torch.bfloat16
+ og_dtype = q.dtype
+ q = q.transpose(1, 2).to(dtype)
+ k = k.transpose(1, 2).to(dtype)
+ v = v.transpose(1, 2).to(dtype)
162
+ x = sageattn_func(
163
+ q=rope_apply(q, grid_sizes, freqs),
164
+ k=rope_apply(k, grid_sizes, freqs),
165
+ v=v,
166
+ )
167
+ x = x.transpose(1, 2).contiguous().to(og_dtype)
168
+ else:
169
+ x = flash_attention(
170
+ q=rope_apply(q, grid_sizes, freqs),
171
+ k=rope_apply(k, grid_sizes, freqs),
172
+ v=v,
173
+ k_lens=seq_lens,
174
+ window_size=self.window_size,
175
+ )
176
+
177
+ # output
178
+ x = x.flatten(2)
179
+ x = self.o(x)
180
+ return x
181
+
182
+
183
+ class WanT2VCrossAttention(WanSelfAttention):
184
+ def forward(self, x, context, context_lens, crossattn_cache=None):
185
+ r"""
186
+ Args:
187
+ x(Tensor): Shape [B, L1, C]
188
+ context(Tensor): Shape [B, L2, C]
189
+ context_lens(Tensor): Shape [B]
190
+ crossattn_cache (List[dict], *optional*): Contains the cached key and value tensors for context embedding.
191
+ """
192
+ b, n, d = x.size(0), self.num_heads, self.head_dim
193
+
194
+ # compute query, key, value
195
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
196
+
197
+ if crossattn_cache is not None:
198
+ if not crossattn_cache["is_init"]:
199
+ crossattn_cache["is_init"] = True
200
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
201
+ v = self.v(context).view(b, -1, n, d)
202
+ crossattn_cache["k"] = k
203
+ crossattn_cache["v"] = v
204
+ else:
205
+ k = crossattn_cache["k"]
206
+ v = crossattn_cache["v"]
207
+ else:
208
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
209
+ v = self.v(context).view(b, -1, n, d)
210
+
211
+ # compute attention
212
+ if _SAGEATTN_AVAILABLE:
213
+ # print("Using sageattention in crossattn")
214
+ dtype = torch.bfloat16
215
+ og_dtype = q.dtype
216
+ q = q.transpose(1, 2).to(dtype)
217
+ k = k.transpose(1, 2).to(dtype)
218
+ v = v.transpose(1, 2).to(dtype)
219
+ x = sageattn_func(
220
+ q=q,
221
+ k=k,
222
+ v=v,
223
+ )
224
+ x = x.transpose(1, 2).contiguous().to(og_dtype)
225
+ elif _FLASH_ATTN_2_AVAILABLE or _FLASH_ATTN_3_AVAILABLE:
226
+ x = flash_attention(q, k, v, k_lens=context_lens)
227
+ else:
228
+ dtype = torch.bfloat16
229
+ q = q.transpose(1, 2).to(dtype)
230
+ k = k.transpose(1, 2).to(dtype)
231
+ v = v.transpose(1, 2).to(dtype)
232
+
233
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
234
+ x = x.transpose(1, 2).contiguous()
235
+
236
+ # output
237
+ x = x.flatten(2)
238
+ x = self.o(x)
239
+ return x
240
+
241
+
242
+ class WanGanCrossAttention(WanSelfAttention):
243
+ def forward(self, x, context, crossattn_cache=None):
244
+ r"""
245
+ Args:
246
+ x(Tensor): Shape [B, L1, C]
247
+ context(Tensor): Shape [B, L2, C]
248
+ context_lens(Tensor): Shape [B]
249
+ crossattn_cache (List[dict], *optional*): Contains the cached key and value tensors for context embedding.
250
+ """
251
+ b, n, d = x.size(0), self.num_heads, self.head_dim
252
+
253
+ # compute query, key, value
254
+ qq = self.norm_q(self.q(context)).view(b, 1, -1, d)
255
+
256
+ kk = self.norm_k(self.k(x)).view(b, -1, n, d)
257
+ vv = self.v(x).view(b, -1, n, d)
258
+
259
+ # compute attention
260
+ x = flash_attention(qq, kk, vv)
261
+
262
+ # output
263
+ x = x.flatten(2)
264
+ x = self.o(x)
265
+ return x
266
+
267
+
268
+ class WanI2VCrossAttention(WanSelfAttention):
269
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
270
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
271
+
272
+ self.k_img = nn.Linear(dim, dim)
273
+ self.v_img = nn.Linear(dim, dim)
274
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
275
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
276
+
277
+ def forward(self, x, context, context_lens):
278
+ r"""
279
+ Args:
280
+ x(Tensor): Shape [B, L1, C]
281
+ context(Tensor): Shape [B, L2, C]
282
+ context_lens(Tensor): Shape [B]
283
+ """
284
+ context_img = context[:, :257]
285
+ context = context[:, 257:]
286
+ b, n, d = x.size(0), self.num_heads, self.head_dim
287
+
288
+ # compute query, key, value
289
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
290
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
291
+ v = self.v(context).view(b, -1, n, d)
292
+ k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
293
+ v_img = self.v_img(context_img).view(b, -1, n, d)
294
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
295
+ # compute attention
296
+ x = flash_attention(q, k, v, k_lens=context_lens)
297
+
298
+ # output
299
+ x = x.flatten(2)
300
+ img_x = img_x.flatten(2)
301
+ x = x + img_x
302
+ x = self.o(x)
303
+ return x
304
+
305
+
306
+ WAN_CROSSATTENTION_CLASSES = {
307
+ "t2v_cross_attn": WanT2VCrossAttention,
308
+ "i2v_cross_attn": WanI2VCrossAttention,
309
+ }
310
+
311
+
312
+ class WanAttentionBlock(nn.Module):
313
+ def __init__(
314
+ self,
315
+ cross_attn_type,
316
+ dim,
317
+ ffn_dim,
318
+ num_heads,
319
+ window_size=(-1, -1),
320
+ qk_norm=True,
321
+ cross_attn_norm=False,
322
+ eps=1e-6,
323
+ ):
324
+ super().__init__()
325
+ self.dim = dim
326
+ self.ffn_dim = ffn_dim
327
+ self.num_heads = num_heads
328
+ self.window_size = window_size
329
+ self.qk_norm = qk_norm
330
+ self.cross_attn_norm = cross_attn_norm
331
+ self.eps = eps
332
+
333
+ # layers
334
+ self.norm1 = WanLayerNorm(dim, eps)
335
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
336
+ self.norm3 = (
337
+ WanLayerNorm(dim, eps, elementwise_affine=True)
338
+ if cross_attn_norm
339
+ else nn.Identity()
340
+ )
341
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](
342
+ dim, num_heads, (-1, -1), qk_norm, eps
343
+ )
344
+ self.norm2 = WanLayerNorm(dim, eps)
345
+ self.ffn = nn.Sequential(
346
+ nn.Linear(dim, ffn_dim),
347
+ nn.GELU(approximate="tanh"),
348
+ nn.Linear(ffn_dim, dim),
349
+ )
350
+
351
+ # modulation
352
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
353
+
354
+ def forward(
355
+ self,
356
+ x,
357
+ e,
358
+ seq_lens,
359
+ grid_sizes,
360
+ freqs,
361
+ context,
362
+ context_lens,
363
+ ):
364
+ r"""
365
+ Args:
366
+ x(Tensor): Shape [B, L, C]
367
+ e(Tensor): Shape [B, 6, C]
368
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
369
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
370
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
371
+ """
372
+ # assert e.dtype == torch.float32
373
+ # with amp.autocast(dtype=torch.float32):
374
+ e = (self.modulation + e).chunk(6, dim=1)
375
+ # assert e[0].dtype == torch.float32
376
+
377
+ # self-attention
378
+ y = self.self_attn(
379
+ self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs
380
+ )
381
+ # with amp.autocast(dtype=torch.float32):
382
+ x = x + y * e[2]
383
+
384
+ # cross-attention & ffn function
385
+ def cross_attn_ffn(x, context, context_lens, e):
386
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
387
+ y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
388
+ # with amp.autocast(dtype=torch.float32):
389
+ x = x + y * e[5]
390
+ return x
391
+
392
+ x = cross_attn_ffn(x, context, context_lens, e)
393
+ return x
394
+
395
+
396
+ class GanAttentionBlock(nn.Module):
397
+ def __init__(
398
+ self,
399
+ dim=1536,
400
+ ffn_dim=8192,
401
+ num_heads=12,
402
+ window_size=(-1, -1),
403
+ qk_norm=True,
404
+ cross_attn_norm=True,
405
+ eps=1e-6,
406
+ ):
407
+ super().__init__()
408
+ self.dim = dim
409
+ self.ffn_dim = ffn_dim
410
+ self.num_heads = num_heads
411
+ self.window_size = window_size
412
+ self.qk_norm = qk_norm
413
+ self.cross_attn_norm = cross_attn_norm
414
+ self.eps = eps
415
+
416
+ # layers
417
+ # self.norm1 = WanLayerNorm(dim, eps)
418
+ # self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
419
+ # eps)
420
+ self.norm3 = (
421
+ WanLayerNorm(dim, eps, elementwise_affine=True)
422
+ if cross_attn_norm
423
+ else nn.Identity()
424
+ )
425
+
426
+ self.norm2 = WanLayerNorm(dim, eps)
427
+ self.ffn = nn.Sequential(
428
+ nn.Linear(dim, ffn_dim),
429
+ nn.GELU(approximate="tanh"),
430
+ nn.Linear(ffn_dim, dim),
431
+ )
432
+
433
+ self.cross_attn = WanGanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps)
434
+
435
+ # modulation
436
+ # self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
437
+
438
+ def forward(
439
+ self,
440
+ x,
441
+ context,
442
+ # seq_lens,
443
+ # grid_sizes,
444
+ # freqs,
445
+ # context,
446
+ # context_lens,
447
+ ):
448
+ r"""
449
+ Args:
450
+ x(Tensor): Shape [B, L, C]
451
+ e(Tensor): Shape [B, 6, C]
452
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
453
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
454
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
455
+ """
456
+ # assert e.dtype == torch.float32
457
+ # with amp.autocast(dtype=torch.float32):
458
+ # e = (self.modulation + e).chunk(6, dim=1)
459
+ # assert e[0].dtype == torch.float32
460
+
461
+ # # self-attention
462
+ # y = self.self_attn(
463
+ # self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes,
464
+ # freqs)
465
+ # # with amp.autocast(dtype=torch.float32):
466
+ # x = x + y * e[2]
467
+
468
+ # cross-attention & ffn function
469
+ def cross_attn_ffn(x, context):
470
+ token = context + self.cross_attn(self.norm3(x), context)
471
+ y = self.ffn(self.norm2(token)) + token # * (1 + e[4]) + e[3])
472
+ # with amp.autocast(dtype=torch.float32):
473
+ # x = x + y * e[5]
474
+ return y
475
+
476
+ x = cross_attn_ffn(x, context)
477
+ return x
478
+
479
+
480
+ class Head(nn.Module):
481
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
482
+ super().__init__()
483
+ self.dim = dim
484
+ self.out_dim = out_dim
485
+ self.patch_size = patch_size
486
+ self.eps = eps
487
+
488
+ # layers
489
+ out_dim = math.prod(patch_size) * out_dim
490
+ self.norm = WanLayerNorm(dim, eps)
491
+ self.head = nn.Linear(dim, out_dim)
492
+
493
+ # modulation
494
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
495
+
496
+ def forward(self, x, e):
497
+ r"""
498
+ Args:
499
+ x(Tensor): Shape [B, L1, C]
500
+ e(Tensor): Shape [B, C]
501
+ """
502
+ # assert e.dtype == torch.float32
503
+ # with amp.autocast(dtype=torch.float32):
504
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
505
+ x = self.head(self.norm(x) * (1 + e[1]) + e[0])
506
+ return x
507
+
508
+
509
+ class MLPProj(torch.nn.Module):
510
+ def __init__(self, in_dim, out_dim):
511
+ super().__init__()
512
+
513
+ self.proj = torch.nn.Sequential(
514
+ torch.nn.LayerNorm(in_dim),
515
+ torch.nn.Linear(in_dim, in_dim),
516
+ torch.nn.GELU(),
517
+ torch.nn.Linear(in_dim, out_dim),
518
+ torch.nn.LayerNorm(out_dim),
519
+ )
520
+
521
+ def forward(self, image_embeds):
522
+ clip_extra_context_tokens = self.proj(image_embeds)
523
+ return clip_extra_context_tokens
524
+
525
+
526
+ class RegisterTokens(nn.Module):
527
+ def __init__(self, num_registers: int, dim: int):
528
+ super().__init__()
529
+ self.register_tokens = nn.Parameter(torch.randn(num_registers, dim) * 0.02)
530
+ self.rms_norm = WanRMSNorm(dim, eps=1e-6)
531
+
532
+ def forward(self):
533
+ return self.rms_norm(self.register_tokens)
534
+
535
+ def reset_parameters(self):
536
+ nn.init.normal_(self.register_tokens, std=0.02)
537
+
538
+
539
+ class WanModel(ModelMixin, ConfigMixin):
540
+ r"""
541
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
542
+ """
543
+
544
+ ignore_for_config = [
545
+ "patch_size",
546
+ "cross_attn_norm",
547
+ "qk_norm",
548
+ "text_dim",
549
+ "window_size",
550
+ ]
551
+ _no_split_modules = ["WanAttentionBlock"]
552
+ _supports_gradient_checkpointing = True
553
+
554
+ @register_to_config
555
+ def __init__(
556
+ self,
557
+ model_type="t2v",
558
+ patch_size=(1, 2, 2),
559
+ text_len=512,
560
+ in_dim=16,
561
+ dim=2048,
562
+ ffn_dim=8192,
563
+ freq_dim=256,
564
+ text_dim=4096,
565
+ out_dim=16,
566
+ num_heads=16,
567
+ num_layers=32,
568
+ window_size=(-1, -1),
569
+ qk_norm=True,
570
+ cross_attn_norm=True,
571
+ eps=1e-6,
572
+ ):
573
+ r"""
574
+ Initialize the diffusion model backbone.
575
+
576
+ Args:
577
+ model_type (`str`, *optional*, defaults to 't2v'):
578
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
579
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
580
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
581
+ text_len (`int`, *optional*, defaults to 512):
582
+ Fixed length for text embeddings
583
+ in_dim (`int`, *optional*, defaults to 16):
584
+ Input video channels (C_in)
585
+ dim (`int`, *optional*, defaults to 2048):
586
+ Hidden dimension of the transformer
587
+ ffn_dim (`int`, *optional*, defaults to 8192):
588
+ Intermediate dimension in feed-forward network
589
+ freq_dim (`int`, *optional*, defaults to 256):
590
+ Dimension for sinusoidal time embeddings
591
+ text_dim (`int`, *optional*, defaults to 4096):
592
+ Input dimension for text embeddings
593
+ out_dim (`int`, *optional*, defaults to 16):
594
+ Output video channels (C_out)
595
+ num_heads (`int`, *optional*, defaults to 16):
596
+ Number of attention heads
597
+ num_layers (`int`, *optional*, defaults to 32):
598
+ Number of transformer blocks
599
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
600
+ Window size for local attention (-1 indicates global attention)
601
+ qk_norm (`bool`, *optional*, defaults to True):
602
+ Enable query/key normalization
603
+ cross_attn_norm (`bool`, *optional*, defaults to True):
+ Enable cross-attention normalization
605
+ eps (`float`, *optional*, defaults to 1e-6):
606
+ Epsilon value for normalization layers
607
+ """
608
+
609
+ super().__init__()
610
+
611
+ assert model_type in ["t2v", "i2v"]
612
+ self.model_type = model_type
613
+
614
+ self.patch_size = patch_size
615
+ self.text_len = text_len
616
+ self.in_dim = in_dim
617
+ self.dim = dim
618
+ self.ffn_dim = ffn_dim
619
+ self.freq_dim = freq_dim
620
+ self.text_dim = text_dim
621
+ self.out_dim = out_dim
622
+ self.num_heads = num_heads
623
+ self.num_layers = num_layers
624
+ self.window_size = window_size
625
+ self.qk_norm = qk_norm
626
+ self.cross_attn_norm = cross_attn_norm
627
+ self.eps = eps
628
+ self.local_attn_size = 21
629
+
630
+ # embeddings
631
+ self.patch_embedding = nn.Conv3d(
632
+ in_dim, dim, kernel_size=patch_size, stride=patch_size
633
+ )
634
+ self.text_embedding = nn.Sequential(
635
+ nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
636
+ )
637
+
638
+ self.time_embedding = nn.Sequential(
639
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
640
+ )
641
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
642
+
643
+ # blocks
644
+ cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn"
645
+ self.blocks = nn.ModuleList(
646
+ [
647
+ WanAttentionBlock(
648
+ cross_attn_type,
649
+ dim,
650
+ ffn_dim,
651
+ num_heads,
652
+ window_size,
653
+ qk_norm,
654
+ cross_attn_norm,
655
+ eps,
656
+ )
657
+ for _ in range(num_layers)
658
+ ]
659
+ )
660
+
661
+ # head
662
+ self.head = Head(dim, out_dim, patch_size, eps)
663
+
664
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
665
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
666
+ d = dim // num_heads
667
+ self.freqs = torch.cat(
668
+ [
669
+ # rope_params(1024, d - 4 * (d // 6)),
670
+ rope_params_riflex(
671
+ 1024,
672
+ d - 4 * (d // 6),
673
+ ),
674
+ rope_params(1024, 2 * (d // 6)),
675
+ rope_params(1024, 2 * (d // 6)),
676
+ ],
677
+ dim=1,
678
+ )
679
+
680
+ if model_type == "i2v":
681
+ self.img_emb = MLPProj(1280, dim)
682
+
683
+ # initialize weights
684
+ self.init_weights()
685
+
686
+ self.gradient_checkpointing = False
687
+
688
+ def _set_gradient_checkpointing(self, module, value=False):
689
+ self.gradient_checkpointing = value
690
+
691
+ def forward(self, *args, **kwargs):
692
+ # if kwargs.get('classify_mode', False) is True:
693
+ # kwargs.pop('classify_mode')
694
+ # return self._forward_classify(*args, **kwargs)
695
+ # else:
696
+ return self._forward(*args, **kwargs)
697
+
698
+ def _forward(
699
+ self,
700
+ x,
701
+ t,
702
+ context,
703
+ seq_len,
704
+ classify_mode=False,
705
+ concat_time_embeddings=False,
706
+ register_tokens=None,
707
+ cls_pred_branch=None,
708
+ gan_ca_blocks=None,
709
+ clip_fea=None,
710
+ y=None,
711
+ ):
712
+ r"""
713
+ Forward pass through the diffusion model
714
+
715
+ Args:
716
+ x (List[Tensor]):
717
+ List of input video tensors, each with shape [C_in, F, H, W]
718
+ t (Tensor):
719
+ Diffusion timesteps tensor of shape [B]
720
+ context (List[Tensor]):
721
+ List of text embeddings each with shape [L, C]
722
+ seq_len (`int`):
723
+ Maximum sequence length for positional encoding
724
+ clip_fea (Tensor, *optional*):
725
+ CLIP image features for image-to-video mode
726
+ y (List[Tensor], *optional*):
727
+ Conditional video inputs for image-to-video mode, same shape as x
728
+
729
+ Returns:
730
+ List[Tensor]:
731
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
732
+ """
733
+ if self.model_type == "i2v":
734
+ assert clip_fea is not None and y is not None
735
+ # params
736
+ device = self.patch_embedding.weight.device
737
+ if self.freqs.device != device:
738
+ self.freqs = self.freqs.to(device)
739
+
740
+ if y is not None:
741
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
742
+
743
+ # embeddings
744
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
745
+ grid_sizes = torch.stack(
746
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
747
+ )
748
+ x = [u.flatten(2).transpose(1, 2) for u in x]
749
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
750
+ assert seq_lens.max() <= seq_len
751
+ x = torch.cat(
752
+ [
753
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
754
+ for u in x
755
+ ]
756
+ )
757
+
758
+ # time embeddings
759
+ # with amp.autocast(dtype=torch.float32):
760
+ e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).type_as(x))
761
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
762
+ # assert e.dtype == torch.float32 and e0.dtype == torch.float32
763
+
764
+ # context
765
+ context_lens = None
766
+ context = self.text_embedding(
767
+ torch.stack(
768
+ [
769
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
770
+ for u in context
771
+ ]
772
+ )
773
+ )
774
+
775
+ if clip_fea is not None:
776
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
777
+ context = torch.concat([context_clip, context], dim=1)
778
+
779
+ # arguments
780
+ kwargs = dict(
781
+ e=e0,
782
+ seq_lens=seq_lens,
783
+ grid_sizes=grid_sizes,
784
+ freqs=self.freqs,
785
+ context=context,
786
+ context_lens=context_lens,
787
+ )
788
+
789
+ def create_custom_forward(module):
790
+ def custom_forward(*inputs, **kwargs):
791
+ return module(*inputs, **kwargs)
792
+
793
+ return custom_forward
794
+
795
+ # TODO: Tune the number of blocks for feature extraction
796
+ final_x = None
797
+ if classify_mode:
798
+ assert register_tokens is not None
799
+ assert gan_ca_blocks is not None
800
+ assert cls_pred_branch is not None
801
+
802
+ final_x = []
803
+ registers = repeat(register_tokens(), "n d -> b n d", b=x.shape[0])
804
+ # x = torch.cat([registers, x], dim=1)
805
+
806
+ gan_idx = 0
807
+ for ii, block in enumerate(self.blocks):
808
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
809
+ x = torch.utils.checkpoint.checkpoint(
810
+ create_custom_forward(block),
811
+ x,
812
+ **kwargs,
813
+ use_reentrant=False,
814
+ )
815
+ else:
816
+ x = block(x, **kwargs)
817
+
818
+ if classify_mode and ii in [13, 21, 29]:
819
+ gan_token = registers[:, gan_idx : gan_idx + 1]
820
+ final_x.append(gan_ca_blocks[gan_idx](x, gan_token))
821
+ gan_idx += 1
822
+
823
+ if classify_mode:
824
+ final_x = torch.cat(final_x, dim=1)
825
+ if concat_time_embeddings:
826
+ final_x = cls_pred_branch(
827
+ torch.cat([final_x, 10 * e[:, None, :]], dim=1).view(
828
+ final_x.shape[0], -1
829
+ )
830
+ )
831
+ else:
832
+ final_x = cls_pred_branch(final_x.view(final_x.shape[0], -1))
833
+
834
+ # head
835
+ x = self.head(x, e)
836
+
837
+ # unpatchify
838
+ x = self.unpatchify(x, grid_sizes)
839
+
840
+ if classify_mode:
841
+ return torch.stack(x), final_x
842
+
843
+ return torch.stack(x)
844
+
845
+ def _forward_classify(
846
+ self,
847
+ x,
848
+ t,
849
+ context,
850
+ seq_len,
851
+ register_tokens,
852
+ cls_pred_branch,
853
+ clip_fea=None,
854
+ y=None,
855
+ ):
856
+ r"""
857
+ Feature extraction through the diffusion model
858
+
859
+ Args:
860
+ x (List[Tensor]):
861
+ List of input video tensors, each with shape [C_in, F, H, W]
862
+ t (Tensor):
863
+ Diffusion timesteps tensor of shape [B]
864
+ context (List[Tensor]):
865
+ List of text embeddings each with shape [L, C]
866
+ seq_len (`int`):
867
+ Maximum sequence length for positional encoding
868
+ clip_fea (Tensor, *optional*):
869
+ CLIP image features for image-to-video mode
870
+ y (List[Tensor], *optional*):
871
+ Conditional video inputs for image-to-video mode, same shape as x
872
+
873
+ Returns:
874
+ List[Tensor]:
875
+ List of video features with original input shapes [C_block, F, H / 8, W / 8]
876
+ """
877
+ if self.model_type == "i2v":
878
+ assert clip_fea is not None and y is not None
879
+ # params
880
+ device = self.patch_embedding.weight.device
881
+ if self.freqs.device != device:
882
+ self.freqs = self.freqs.to(device)
883
+
884
+ if y is not None:
885
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
886
+
887
+ # embeddings
888
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
889
+ grid_sizes = torch.stack(
890
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
891
+ )
892
+ x = [u.flatten(2).transpose(1, 2) for u in x]
893
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
894
+ assert seq_lens.max() <= seq_len
895
+ x = torch.cat(
896
+ [
897
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
898
+ for u in x
899
+ ]
900
+ )
901
+
902
+ # time embeddings
903
+ # with amp.autocast(dtype=torch.float32):
904
+ e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).type_as(x))
905
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
906
+ # assert e.dtype == torch.float32 and e0.dtype == torch.float32
907
+
908
+ # context
909
+ context_lens = None
910
+ context = self.text_embedding(
911
+ torch.stack(
912
+ [
913
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
914
+ for u in context
915
+ ]
916
+ )
917
+ )
918
+
919
+ if clip_fea is not None:
920
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
921
+ context = torch.concat([context_clip, context], dim=1)
922
+
923
+ # arguments
924
+ kwargs = dict(
925
+ e=e0,
926
+ seq_lens=seq_lens,
927
+ grid_sizes=grid_sizes,
928
+ freqs=self.freqs,
929
+ context=context,
930
+ context_lens=context_lens,
931
+ )
932
+
933
+ def create_custom_forward(module):
934
+ def custom_forward(*inputs, **kwargs):
935
+ return module(*inputs, **kwargs)
936
+
937
+ return custom_forward
938
+
939
+ # TODO: Tune the number of blocks for feature extraction
940
+ for block in self.blocks[:16]:
941
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
942
+ x = torch.utils.checkpoint.checkpoint(
943
+ create_custom_forward(block),
944
+ x,
945
+ **kwargs,
946
+ use_reentrant=False,
947
+ )
948
+ else:
949
+ x = block(x, **kwargs)
950
+
951
+ # unpatchify
952
+ x = self.unpatchify(x, grid_sizes, c=self.dim // 4)
953
+ return torch.stack(x)
954
+
955
+ def unpatchify(self, x, grid_sizes, c=None):
956
+ r"""
957
+ Reconstruct video tensors from patch embeddings.
958
+
959
+ Args:
960
+ x (List[Tensor]):
961
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
962
+ grid_sizes (Tensor):
963
+ Original spatial-temporal grid dimensions before patching,
964
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
965
+
966
+ Returns:
967
+ List[Tensor]:
968
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
969
+ """
970
+
971
+ c = self.out_dim if c is None else c
972
+ out = []
973
+ for u, v in zip(x, grid_sizes.tolist()):
974
+ u = u[: math.prod(v)].view(*v, *self.patch_size, c)
975
+ u = torch.einsum("fhwpqrc->cfphqwr", u)
976
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
977
+ out.append(u)
978
+ return out
979
+
980
+ def init_weights(self):
981
+ r"""
982
+ Initialize model parameters using Xavier initialization.
983
+ """
984
+
985
+ # basic init
986
+ for m in self.modules():
987
+ if isinstance(m, nn.Linear):
988
+ nn.init.xavier_uniform_(m.weight)
989
+ if m.bias is not None:
990
+ nn.init.zeros_(m.bias)
991
+
992
+ # init embeddings
993
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
994
+ for m in self.text_embedding.modules():
995
+ if isinstance(m, nn.Linear):
996
+ nn.init.normal_(m.weight, std=0.02)
997
+ for m in self.time_embedding.modules():
998
+ if isinstance(m, nn.Linear):
999
+ nn.init.normal_(m.weight, std=0.02)
1000
+
1001
+ # init output layer
1002
+ nn.init.zeros_(self.head.head.weight)
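
Since `WanModel` subclasses `ModelMixin` and `ConfigMixin`, the sharded weights indexed above can in principle be loaded straight into this class. Below is a minimal sketch, assuming the repo keeps a `config.json` next to the shards under `transformer/` (not shown in this excerpt), that `transformer/model.py` is importable locally as `model`, and that `rope_params_riflex`, which `WanModel.__init__` references, is supplied by the full codebase:

```py
import torch

# Assumed import path: presumes this diff has been checked out locally so that
# transformer/model.py is importable as `model`, and that the full repository
# provides rope_params_riflex, which WanModel.__init__ calls.
from model import WanModel

# ModelMixin.from_pretrained reads config.json for the hyperparameters and resolves
# the sharded .safetensors files through the index shown earlier in this diff.
transformer = WanModel.from_pretrained(
    "krea/krea-realtime-video",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
transformer.to("cuda").eval()

print(f"{sum(p.numel() for p in transformer.parameters()) / 1e9:.1f}B parameters")
```

This is only an illustration of how the files added in this diff fit together; it is not a supported inference entry point.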