grnr9730 committed
Commit dde113c · verified · 1 Parent(s): 7eb61f7

Update custom_pipeline/AutoencoderKLWan.py

Files changed (1):
  1. custom_pipeline/AutoencoderKLWan.py +13 -220

custom_pipeline/AutoencoderKLWan.py CHANGED
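The substantive change in this commit is the import block: the module's package-relative imports (from ...configuration_utils, from ..activations, ...) are rewritten as absolute diffusers imports, so the file no longer has to live inside the diffusers source tree. A minimal usage sketch, assuming the file is vendored at custom_pipeline/AutoencoderKLWan.py in an importable package (e.g. with an __init__.py) and a compatible diffusers version is installed; the path and snippet are illustrative, not part of the commit:

    # Hypothetical standalone use; this only works because the module now uses
    # absolute `diffusers` imports instead of package-relative ones.
    from custom_pipeline.AutoencoderKLWan import AutoencoderKLWan

    vae = AutoencoderKLWan()  # every constructor argument has a registered default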
@@ -1,17 +1,3 @@
- # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
  from typing import List, Optional, Tuple, Union

  import torch
@@ -19,28 +5,24 @@ import torch.nn as nn
  import torch.nn.functional as F
  import torch.utils.checkpoint

- from ...configuration_utils import ConfigMixin, register_to_config
- from ...loaders import FromOriginalModelMixin
- from ...utils import logging
- from ...utils.accelerate_utils import apply_forward_hook
- from ..activations import get_activation
- from ..modeling_outputs import AutoencoderKLOutput
- from ..modeling_utils import ModelMixin
- from .vae import DecoderOutput, DiagonalGaussianDistribution
-
+ from diffusers.configuration_utils import ConfigMixin, register_to_config  # relative → absolute import
+ from diffusers.loaders import FromOriginalModelMixin  # relative → absolute import
+ from diffusers.utils import logging  # relative → absolute import
+ from diffusers.utils.accelerate_utils import apply_forward_hook  # relative → absolute import
+ from diffusers.models.activations import get_activation  # relative → absolute import
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput  # relative → absolute import
+ from diffusers.models.modeling_utils import ModelMixin  # relative → absolute import
+ from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution  # relative → absolute import

  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

  CACHE_T = 2

-
  class WanCausalConv3d(nn.Conv3d):
      r"""
      A custom 3D causal convolution layer with feature caching support.
-
      This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
      caching for efficient inference.
-
      Args:
          in_channels (int): Number of channels in the input image
          out_channels (int): Number of channels produced by the convolution
@@ -48,7 +30,6 @@ class WanCausalConv3d(nn.Conv3d):
          stride (int or tuple, optional): Stride of the convolution. Default: 1
          padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
      """
-
      def __init__(
          self,
          in_channels: int,
@@ -64,8 +45,6 @@ class WanCausalConv3d(nn.Conv3d):
              stride=stride,
              padding=padding,
          )
-
-         # Set up causal padding
          self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
          self.padding = (0, 0, 0)

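For orientation, the _padding tuple assigned above uses F.pad's last-dimension-first ordering (w_left, w_right, h_top, h_bottom, t_front, t_back): the full temporal padding (2 * padding[0]) is placed in front of the clip and none behind, which is what makes the convolution causal in time. A small sketch of the effect, assuming kernel 3 with padding=(1, 1, 1); the snippet is illustrative, not part of the diff:

    import torch
    import torch.nn.functional as F

    causal_pad = (1, 1, 1, 1, 2 * 1, 0)   # (w, w, h, h, 2 * t, 0): all time padding in front
    x = torch.randn(1, 3, 5, 8, 8)        # (batch, channels, frames, height, width)
    print(F.pad(x, causal_pad).shape)     # torch.Size([1, 3, 7, 10, 10])

With both temporal pad frames in front, output frame i never depends on input frames later than i.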
@@ -78,11 +57,9 @@ class WanCausalConv3d(nn.Conv3d):
          x = F.pad(x, padding)
          return super().forward(x)

-
  class WanRMS_norm(nn.Module):
      r"""
      A custom RMS normalization layer.
-
      Args:
          dim (int): The number of dimensions to normalize over.
          channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
@@ -90,12 +67,10 @@ class WanRMS_norm(nn.Module):
          images (bool, optional): Whether the input represents image data. Default is True.
          bias (bool, optional): Whether to include a learnable bias term. Default is False.
      """
-
      def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
          super().__init__()
          broadcastable_dims = (1, 1, 1) if not images else (1, 1)
          shape = (dim, *broadcastable_dims) if channel_first else (dim,)
-
          self.channel_first = channel_first
          self.scale = dim**0.5
          self.gamma = nn.Parameter(torch.ones(shape))
@@ -104,26 +79,20 @@ class WanRMS_norm(nn.Module):
      def forward(self, x):
          return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias

-
  class WanUpsample(nn.Upsample):
      r"""
      Perform upsampling while ensuring the output tensor has the same data type as the input.
-
      Args:
          x (torch.Tensor): Input tensor to be upsampled.
-
      Returns:
          torch.Tensor: Upsampled tensor with the same data type as the input.
      """
-
      def forward(self, x):
          return super().forward(x.float()).type_as(x)

-
  class WanResample(nn.Module):
      r"""
      A custom resampling module for 2D and 3D data.
-
      Args:
          dim (int): The number of input/output channels.
          mode (str): The resampling mode. Must be one of:
@@ -133,13 +102,10 @@ class WanResample(nn.Module):
              - 'downsample2d': 2D downsampling with zero-padding and convolution.
              - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
      """
-
      def __init__(self, dim: int, mode: str) -> None:
          super().__init__()
          self.dim = dim
          self.mode = mode
-
-         # layers
          if mode == "upsample2d":
              self.resample = nn.Sequential(
                  WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
@@ -149,13 +115,11 @@ class WanResample(nn.Module):
                  WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
              )
              self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
-
          elif mode == "downsample2d":
              self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
          elif mode == "downsample3d":
              self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
              self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
-
          else:
              self.resample = nn.Identity()

@@ -170,7 +134,6 @@ class WanResample(nn.Module):
                  else:
                      cache_x = x[:, :, -CACHE_T:, :, :].clone()
                      if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
-                         # cache last frame of last two chunk
                          cache_x = torch.cat(
                              [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
                          )
@@ -182,7 +145,6 @@ class WanResample(nn.Module):
                      x = self.time_conv(x, feat_cache[idx])
                      feat_cache[idx] = cache_x
                      feat_idx[0] += 1
-
                      x = x.reshape(b, 2, c, t, h, w)
                      x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                      x = x.reshape(b, c, t * 2, h, w)
@@ -190,7 +152,6 @@ class WanResample(nn.Module):
          x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
          x = self.resample(x)
          x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
-
          if self.mode == "downsample3d":
              if feat_cache is not None:
                  idx = feat_idx[0]
@@ -204,18 +165,15 @@ class WanResample(nn.Module):
                  feat_idx[0] += 1
          return x

-
  class WanResidualBlock(nn.Module):
      r"""
      A custom residual block module.
-
      Args:
          in_dim (int): Number of input channels.
          out_dim (int): Number of output channels.
          dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
          non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
      """
-
      def __init__(
          self,
          in_dim: int,
@@ -227,8 +185,6 @@ class WanResidualBlock(nn.Module):
          self.in_dim = in_dim
          self.out_dim = out_dim
          self.nonlinearity = get_activation(non_linearity)
-
-         # layers
          self.norm1 = WanRMS_norm(in_dim, images=False)
          self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
          self.norm2 = WanRMS_norm(out_dim, images=False)
@@ -237,61 +193,43 @@ class WanResidualBlock(nn.Module):
          self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

      def forward(self, x, feat_cache=None, feat_idx=[0]):
-         # Apply shortcut connection
          h = self.conv_shortcut(x)
-
-         # First normalization and activation
          x = self.norm1(x)
          x = self.nonlinearity(x)
-
          if feat_cache is not None:
              idx = feat_idx[0]
              cache_x = x[:, :, -CACHE_T:, :, :].clone()
              if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                  cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-
              x = self.conv1(x, feat_cache[idx])
              feat_cache[idx] = cache_x
              feat_idx[0] += 1
          else:
              x = self.conv1(x)
-
-         # Second normalization and activation
          x = self.norm2(x)
          x = self.nonlinearity(x)
-
-         # Dropout
          x = self.dropout(x)
-
          if feat_cache is not None:
              idx = feat_idx[0]
              cache_x = x[:, :, -CACHE_T:, :, :].clone()
              if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                  cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-
              x = self.conv2(x, feat_cache[idx])
              feat_cache[idx] = cache_x
              feat_idx[0] += 1
          else:
              x = self.conv2(x)
-
-         # Add residual connection
          return x + h

-
  class WanAttentionBlock(nn.Module):
      r"""
      Causal self-attention with a single head.
-
      Args:
          dim (int): The number of channels in the input tensor.
      """
-
      def __init__(self, dim):
          super().__init__()
          self.dim = dim
-
-         # layers
          self.norm = WanRMS_norm(dim)
          self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
          self.proj = nn.Conv2d(dim, dim, 1)
@@ -299,46 +237,30 @@ class WanAttentionBlock(nn.Module):
      def forward(self, x):
          identity = x
          batch_size, channels, time, height, width = x.size()
-
          x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
          x = self.norm(x)
-
-         # compute query, key, value
          qkv = self.to_qkv(x)
          qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
          qkv = qkv.permute(0, 1, 3, 2).contiguous()
          q, k, v = qkv.chunk(3, dim=-1)
-
-         # apply attention
          x = F.scaled_dot_product_attention(q, k, v)
-
          x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
-
-         # output projection
          x = self.proj(x)
-
-         # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
          x = x.view(batch_size, time, channels, height, width)
          x = x.permute(0, 2, 1, 3, 4)
-
          return x + identity

-
  class WanMidBlock(nn.Module):
      """
      Middle block for WanVAE encoder and decoder.
-
      Args:
          dim (int): Number of input/output channels.
          dropout (float): Dropout rate.
          non_linearity (str): Type of non-linearity to use.
      """
-
      def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
          super().__init__()
          self.dim = dim
-
-         # Create the components
          resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
          attentions = []
          for _ in range(num_layers):
@@ -346,27 +268,19 @@ class WanMidBlock(nn.Module):
              resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
          self.attentions = nn.ModuleList(attentions)
          self.resnets = nn.ModuleList(resnets)
-
          self.gradient_checkpointing = False

      def forward(self, x, feat_cache=None, feat_idx=[0]):
-         # First residual block
          x = self.resnets[0](x, feat_cache, feat_idx)
-
-         # Process through attention and residual blocks
          for attn, resnet in zip(self.attentions, self.resnets[1:]):
              if attn is not None:
                  x = attn(x)
-
              x = resnet(x, feat_cache, feat_idx)
-
          return x

-
  class WanEncoder3d(nn.Module):
      r"""
      A 3D encoder module.
-
      Args:
          dim (int): The base number of channels in the first layer.
          z_dim (int): The dimensionality of the latent space.
@@ -377,7 +291,6 @@ class WanEncoder3d(nn.Module):
          dropout (float): Dropout rate for the dropout layers.
          non_linearity (str): Type of non-linearity to use.
      """
-
      def __init__(
          self,
          dim=128,
@@ -397,37 +310,23 @@ class WanEncoder3d(nn.Module):
          self.attn_scales = attn_scales
          self.temperal_downsample = temperal_downsample
          self.nonlinearity = get_activation(non_linearity)
-
-         # dimensions
          dims = [dim * u for u in [1] + dim_mult]
          scale = 1.0
-
-         # init block
          self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
-
-         # downsample blocks
          self.down_blocks = nn.ModuleList([])
          for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
-             # residual (+attention) blocks
              for _ in range(num_res_blocks):
                  self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
                  if scale in attn_scales:
                      self.down_blocks.append(WanAttentionBlock(out_dim))
                  in_dim = out_dim
-
-             # downsample block
              if i != len(dim_mult) - 1:
                  mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                  self.down_blocks.append(WanResample(out_dim, mode=mode))
                  scale /= 2.0
-
-         # middle blocks
          self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
-
-         # output blocks
          self.norm_out = WanRMS_norm(out_dim, images=False)
          self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
-
          self.gradient_checkpointing = False

      def forward(self, x, feat_cache=None, feat_idx=[0]):
@@ -435,32 +334,24 @@
              idx = feat_idx[0]
              cache_x = x[:, :, -CACHE_T:, :, :].clone()
              if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                 # cache last frame of last two chunk
                  cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
              x = self.conv_in(x, feat_cache[idx])
              feat_cache[idx] = cache_x
              feat_idx[0] += 1
          else:
              x = self.conv_in(x)
-
-         ## downsamples
          for layer in self.down_blocks:
              if feat_cache is not None:
                  x = layer(x, feat_cache, feat_idx)
              else:
                  x = layer(x)
-
-         ## middle
          x = self.mid_block(x, feat_cache, feat_idx)
-
-         ## head
          x = self.norm_out(x)
          x = self.nonlinearity(x)
          if feat_cache is not None:
              idx = feat_idx[0]
              cache_x = x[:, :, -CACHE_T:, :, :].clone()
              if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                 # cache last frame of last two chunk
                  cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
              x = self.conv_out(x, feat_cache[idx])
              feat_cache[idx] = cache_x
@@ -469,11 +360,9 @@
          else:
              x = self.conv_out(x)
          return x

-
  class WanUpBlock(nn.Module):
      """
      A block that handles upsampling for the WanVAE decoder.
-
      Args:
          in_dim (int): Input dimension
          out_dim (int): Output dimension
@@ -482,7 +371,6 @@
          upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
          non_linearity (str): Type of non-linearity to use
      """
-
      def __init__(
          self,
          in_dim: int,
@@ -495,42 +383,23 @@
          super().__init__()
          self.in_dim = in_dim
          self.out_dim = out_dim
-
-         # Create layers list
          resnets = []
-         # Add residual blocks and attention if needed
          current_dim = in_dim
          for _ in range(num_res_blocks + 1):
              resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
              current_dim = out_dim
-
          self.resnets = nn.ModuleList(resnets)
-
-         # Add upsampling layer if needed
          self.upsamplers = None
          if upsample_mode is not None:
              self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
-
          self.gradient_checkpointing = False

      def forward(self, x, feat_cache=None, feat_idx=[0]):
-         """
-         Forward pass through the upsampling block.
-
-         Args:
-             x (torch.Tensor): Input tensor
-             feat_cache (list, optional): Feature cache for causal convolutions
-             feat_idx (list, optional): Feature index for cache management
-
-         Returns:
-             torch.Tensor: Output tensor
-         """
          for resnet in self.resnets:
              if feat_cache is not None:
                  x = resnet(x, feat_cache, feat_idx)
              else:
                  x = resnet(x)
-
          if self.upsamplers is not None:
              if feat_cache is not None:
                  x = self.upsamplers[0](x, feat_cache, feat_idx)
@@ -538,11 +407,9 @@
                  x = self.upsamplers[0](x)
          return x

-
  class WanDecoder3d(nn.Module):
      r"""
      A 3D decoder module.
-
      Args:
          dim (int): The base number of channels in the first layer.
          z_dim (int): The dimensionality of the latent space.
@@ -553,7 +420,6 @@
          dropout (float): Dropout rate for the dropout layers.
          non_linearity (str): Type of non-linearity to use.
      """
-
      def __init__(
          self,
          dim=128,
@@ -572,32 +438,18 @@
          self.num_res_blocks = num_res_blocks
          self.attn_scales = attn_scales
          self.temperal_upsample = temperal_upsample
-
          self.nonlinearity = get_activation(non_linearity)
-
-         # dimensions
          dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
          scale = 1.0 / 2 ** (len(dim_mult) - 2)
-
-         # init block
          self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
-
-         # middle blocks
          self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
-
-         # upsample blocks
          self.up_blocks = nn.ModuleList([])
          for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
-             # residual (+attention) blocks
              if i > 0:
                  in_dim = in_dim // 2
-
-             # Determine if we need upsampling
              upsample_mode = None
              if i != len(dim_mult) - 1:
                  upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
-
-             # Create and add the upsampling block
              up_block = WanUpBlock(
                  in_dim=in_dim,
                  out_dim=out_dim,
@@ -607,46 +459,32 @@
                  non_linearity=non_linearity,
              )
              self.up_blocks.append(up_block)
-
-             # Update scale for next iteration
              if upsample_mode is not None:
                  scale *= 2.0
-
-         # output blocks
          self.norm_out = WanRMS_norm(out_dim, images=False)
          self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
-
          self.gradient_checkpointing = False

      def forward(self, x, feat_cache=None, feat_idx=[0]):
-         ## conv1
          if feat_cache is not None:
              idx = feat_idx[0]
              cache_x = x[:, :, -CACHE_T:, :, :].clone()
              if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                 # cache last frame of last two chunk
                  cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
              x = self.conv_in(x, feat_cache[idx])
              feat_cache[idx] = cache_x
              feat_idx[0] += 1
          else:
              x = self.conv_in(x)
-
-         ## middle
          x = self.mid_block(x, feat_cache, feat_idx)
-
-         ## upsamples
          for up_block in self.up_blocks:
              x = up_block(x, feat_cache, feat_idx)
-
-         ## head
          x = self.norm_out(x)
          x = self.nonlinearity(x)
          if feat_cache is not None:
              idx = feat_idx[0]
              cache_x = x[:, :, -CACHE_T:, :, :].clone()
              if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                 # cache last frame of last two chunk
                  cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
              x = self.conv_out(x, feat_cache[idx])
              feat_cache[idx] = cache_x
@@ -655,16 +493,13 @@
              x = self.conv_out(x)
          return x

-
  class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
      r"""
      A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
      Introduced in [Wan 2.1].
-
      This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
      for all models (such as downloading or saving).
      """
-
      _supports_gradient_checkpointing = False

      @register_to_config
@@ -678,54 +513,23 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
          temperal_downsample: List[bool] = [False, True, True],
          dropout: float = 0.0,
          latents_mean: List[float] = [
-             -0.7571,
-             -0.7089,
-             -0.9113,
-             0.1075,
-             -0.1745,
-             0.9653,
-             -0.1517,
-             1.5508,
-             0.4134,
-             -0.0715,
-             0.5517,
-             -0.3632,
-             -0.1922,
-             -0.9497,
-             0.2503,
-             -0.2921,
+             -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
+             0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
          ],
          latents_std: List[float] = [
-             2.8184,
-             1.4541,
-             2.3275,
-             2.6558,
-             1.2196,
-             1.7708,
-             2.6052,
-             2.0743,
-             3.2687,
-             2.1526,
-             2.8652,
-             1.5579,
-             1.6382,
-             1.1253,
-             2.8251,
-             1.9160,
+             2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
+             3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
          ],
      ) -> None:
          super().__init__()
-
          self.z_dim = z_dim
          self.temperal_downsample = temperal_downsample
          self.temperal_upsample = temperal_downsample[::-1]
-
          self.encoder = WanEncoder3d(
              base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
          )
          self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
          self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
-
          self.decoder = WanDecoder3d(
              base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
          )
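Because the constructor is wrapped in @register_to_config, every keyword above, including the collapsed latents_mean/latents_std defaults, is recorded on the model's config. A hedged sketch, assuming the usual diffusers ConfigMixin attribute access:

    vae = AutoencoderKLWan()                 # defaults registered by @register_to_config
    print(len(vae.config.latents_mean))      # 16 per-latent-channel means
    print(vae.config.temperal_downsample)    # [False, True, True]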
@@ -741,14 +545,12 @@
          self._conv_num = _count_conv3d(self.decoder)
          self._conv_idx = [0]
          self._feat_map = [None] * self._conv_num
-         # cache encode
          self._enc_conv_num = _count_conv3d(self.encoder)
          self._enc_conv_idx = [0]
          self._enc_feat_map = [None] * self._enc_conv_num

      def _encode(self, x: torch.Tensor) -> torch.Tensor:
          self.clear_cache()
-         ## cache
          t = x.shape[2]
          iter_ = 1 + (t - 1) // 4
          for i in range(iter_):
@@ -762,7 +564,6 @@
                      feat_idx=self._enc_conv_idx,
                  )
                  out = torch.cat([out, out_], 2)
-
          enc = self.quant_conv(out)
          mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
          enc = torch.cat([mu, logvar], dim=1)
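The chunked encode above walks the time axis causally: iter_ = 1 + (t - 1) // 4 processes the first frame on its own and then groups of four frames, carrying feat_cache across chunks so each causal convolution still sees its preceding frames. A sketch of the arithmetic, assuming the 1-then-4 slicing of the upstream implementation (the slice expressions themselves fall outside this hunk's context lines):

    t = 9                                   # input frames
    iter_ = 1 + (t - 1) // 4                # -> 3 chunks
    bounds = [(0, 1)] + [(1 + 4 * i, 5 + 4 * i) for i in range(iter_ - 1)]
    print(bounds)                           # [(0, 1), (1, 5), (5, 9)]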
@@ -775,12 +576,10 @@
      ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
          r"""
          Encode a batch of images into latents.
-
          Args:
              x (`torch.Tensor`): Input batch of images.
              return_dict (`bool`, *optional*, defaults to `True`):
                  Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
-
          Returns:
              The latent representations of the encoded videos. If `return_dict` is True, a
              [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
@@ -793,7 +592,6 @@

      def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
          self.clear_cache()
-
          iter_ = z.shape[2]
          x = self.post_quant_conv(z)
          for i in range(iter_):
@@ -803,24 +601,20 @@
              else:
                  out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                  out = torch.cat([out, out_], 2)
-
          out = torch.clamp(out, min=-1.0, max=1.0)
          self.clear_cache()
          if not return_dict:
              return (out,)
-
          return DecoderOutput(sample=out)

      @apply_forward_hook
      def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
          r"""
          Decode a batch of images.
-
          Args:
              z (`torch.Tensor`): Input batch of latent vectors.
              return_dict (`bool`, *optional*, defaults to `True`):
                  Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
-
          Returns:
              [`~models.vae.DecoderOutput`] or `tuple`:
                  If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
@@ -829,7 +623,6 @@
          decoded = self._decode(z).sample
          if not return_dict:
              return (decoded,)
-
          return DecoderOutput(sample=decoded)

      def forward(
@@ -852,4 +645,4 @@
          else:
              z = posterior.mode()
          dec = self.decode(z, return_dict=return_dict)
-         return dec
+         return dec
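End to end, the updated class keeps the standard diffusers VAE contract shown in the encode/decode hunks above. A hedged round-trip sketch (shapes are illustrative: 3-channel input in [-1, 1], frame count of the form 4k + 1 so the temporal chunking divides evenly; the import path is the hypothetical vendored location from the note at the top):

    import torch
    from custom_pipeline.AutoencoderKLWan import AutoencoderKLWan

    vae = AutoencoderKLWan().eval()
    video = torch.rand(1, 3, 9, 64, 64) * 2 - 1     # (batch, channels, frames, height, width)

    with torch.no_grad():
        posterior = vae.encode(video).latent_dist   # DiagonalGaussianDistribution
        z = posterior.mode()                        # deterministic latent
        recon = vae.decode(z).sample                # decoder output is clamped to [-1, 1]

    print(z.shape, recon.shape)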