Update custom_pipeline/AutoencoderKLWan.py
custom_pipeline/AutoencoderKLWan.py (+13 -220)
CHANGED
@@ -1,17 +1,3 @@
-# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -19,28 +5,24 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
 
-from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin
-from ...utils import logging
-from ...utils.accelerate_utils import apply_forward_hook
-from ..activations import get_activation
-from ..modeling_outputs import AutoencoderKLOutput
-from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
-
+from diffusers.configuration_utils import ConfigMixin, register_to_config  # relative -> absolute import
+from diffusers.loaders import FromOriginalModelMixin  # relative -> absolute import
+from diffusers.utils import logging  # relative -> absolute import
+from diffusers.utils.accelerate_utils import apply_forward_hook  # relative -> absolute import
+from diffusers.models.activations import get_activation  # relative -> absolute import
+from diffusers.models.modeling_outputs import AutoencoderKLOutput  # relative -> absolute import
+from diffusers.models.modeling_utils import ModelMixin  # relative -> absolute import
+from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution  # relative -> absolute import
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 CACHE_T = 2
 
-
 class WanCausalConv3d(nn.Conv3d):
     r"""
     A custom 3D causal convolution layer with feature caching support.
-
     This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
     caching for efficient inference.
-
     Args:
         in_channels (int): Number of channels in the input image
         out_channels (int): Number of channels produced by the convolution
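Note: the point of this hunk is that `custom_pipeline/` sits outside the `diffusers` package, so the package-relative imports (`from ...configuration_utils import ...`) raise an ImportError when the file is loaded standalone (there is no parent package to resolve the dots against). A minimal sketch of the kind of standalone loading that motivates the switch (the loading mechanism is an assumption, not shown in this commit):

    # Hypothetical usage sketch: loading this module as a plain file rather than
    # as a diffusers submodule, which is why absolute imports are required.
    import importlib.util

    spec = importlib.util.spec_from_file_location(
        "AutoencoderKLWan", "custom_pipeline/AutoencoderKLWan.py"
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # relative imports would fail here; absolute ones resolve

    vae_cls = module.AutoencoderKLWan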
@@ -48,7 +30,6 @@ class WanCausalConv3d(nn.Conv3d):
         stride (int or tuple, optional): Stride of the convolution. Default: 1
         padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
     """
-
     def __init__(
         self,
         in_channels: int,
@@ -64,8 +45,6 @@ class WanCausalConv3d(nn.Conv3d):
             stride=stride,
             padding=padding,
         )
-
-        # Set up causal padding
         self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
         self.padding = (0, 0, 0)
 
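Note: `_padding` reorders `nn.Conv3d`'s symmetric padding into the layout `F.pad` expects, (W_left, W_right, H_top, H_bottom, T_front, T_back), and puts all `2 * pad_t` temporal padding in front so no frame ever sees the future. A small sketch of the arithmetic (values assume the common kernel_size=3, padding=1 case):

    # For kernel_size=(3, 3, 3) and padding=(1, 1, 1), nn.Conv3d would pad one
    # frame on each side of time; the causal variant pads 2 frames in front, 0 behind.
    pad_t, pad_h, pad_w = 1, 1, 1
    causal_padding = (pad_w, pad_w, pad_h, pad_h, 2 * pad_t, 0)  # matches self._padding
    # F.pad consumes dimensions last-to-first: (W, W, H, H, T_front, T_back)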
@@ -78,11 +57,9 @@ class WanCausalConv3d(nn.Conv3d):
         x = F.pad(x, padding)
         return super().forward(x)
 
-
 class WanRMS_norm(nn.Module):
     r"""
     A custom RMS normalization layer.
-
     Args:
         dim (int): The number of dimensions to normalize over.
         channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
@@ -90,12 +67,10 @@ class WanRMS_norm(nn.Module):
         images (bool, optional): Whether the input represents image data. Default is True.
         bias (bool, optional): Whether to include a learnable bias term. Default is False.
     """
-
     def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
         super().__init__()
         broadcastable_dims = (1, 1, 1) if not images else (1, 1)
         shape = (dim, *broadcastable_dims) if channel_first else (dim,)
-
         self.channel_first = channel_first
         self.scale = dim**0.5
         self.gamma = nn.Parameter(torch.ones(shape))
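Note: `F.normalize` divides by the L2 norm, so multiplying by `scale = dim**0.5` turns it into RMS normalization, x / sqrt(mean(x^2)). A quick equivalence check (a sketch, not part of the commit):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8)  # channels-last toy input, dim = 8
    rms = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True))
    via_normalize = F.normalize(x, dim=-1) * 8**0.5  # what WanRMS_norm computes before gamma/bias
    assert torch.allclose(rms, via_normalize, atol=1e-6)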
@@ -104,26 +79,20 @@ class WanRMS_norm(nn.Module):
     def forward(self, x):
         return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
 
-
 class WanUpsample(nn.Upsample):
     r"""
     Perform upsampling while ensuring the output tensor has the same data type as the input.
-
     Args:
         x (torch.Tensor): Input tensor to be upsampled.
-
     Returns:
         torch.Tensor: Upsampled tensor with the same data type as the input.
     """
-
     def forward(self, x):
         return super().forward(x.float()).type_as(x)
 
-
 class WanResample(nn.Module):
     r"""
     A custom resampling module for 2D and 3D data.
-
     Args:
         dim (int): The number of input/output channels.
         mode (str): The resampling mode. Must be one of:
@@ -133,13 +102,10 @@ class WanResample(nn.Module):
             - 'downsample2d': 2D downsampling with zero-padding and convolution.
             - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
     """
-
     def __init__(self, dim: int, mode: str) -> None:
         super().__init__()
         self.dim = dim
         self.mode = mode
-
-        # layers
         if mode == "upsample2d":
             self.resample = nn.Sequential(
                 WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
@@ -149,13 +115,11 @@ class WanResample(nn.Module):
                 WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
             )
             self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
-
         elif mode == "downsample2d":
             self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
         elif mode == "downsample3d":
             self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
             self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
-
         else:
             self.resample = nn.Identity()
 
@@ -170,7 +134,6 @@ class WanResample(nn.Module):
                 else:
                     cache_x = x[:, :, -CACHE_T:, :, :].clone()
                     if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
-                        # cache last frame of last two chunk
                         cache_x = torch.cat(
                             [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
                         )
@@ -182,7 +145,6 @@ class WanResample(nn.Module):
                         x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
-
                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                    x = x.reshape(b, c, t * 2, h, w)
@@ -190,7 +152,6 @@ class WanResample(nn.Module):
        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        x = self.resample(x)
        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
-
        if self.mode == "downsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
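Note: in the 'upsample3d' path, `time_conv` doubles the channel count and the two channel halves are then interleaved along time, which is what the reshape/stack/reshape sequence above implements. Toy-size sketch of the shapes:

    import torch

    b, c, t, h, w = 1, 2, 3, 4, 4
    x = torch.randn(b, 2 * c, t, h, w)       # output of time_conv: channels doubled
    x = x.reshape(b, 2, c, t, h, w)
    x = torch.stack((x[:, 0], x[:, 1]), 3)   # (b, c, t, 2, h, w): pair frames
    x = x.reshape(b, c, t * 2, h, w)         # time dimension doubled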
@@ -204,18 +165,15 @@ class WanResample(nn.Module):
                    feat_idx[0] += 1
        return x
 
-
 class WanResidualBlock(nn.Module):
     r"""
     A custom residual block module.
-
     Args:
         in_dim (int): Number of input channels.
         out_dim (int): Number of output channels.
         dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
         non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
     """
-
     def __init__(
         self,
         in_dim: int,
@@ -227,8 +185,6 @@ class WanResidualBlock(nn.Module):
         self.in_dim = in_dim
         self.out_dim = out_dim
         self.nonlinearity = get_activation(non_linearity)
-
-        # layers
         self.norm1 = WanRMS_norm(in_dim, images=False)
         self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
         self.norm2 = WanRMS_norm(out_dim, images=False)
@@ -237,61 +193,43 @@ class WanResidualBlock(nn.Module):
         self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
-        # Apply shortcut connection
         h = self.conv_shortcut(x)
-
-        # First normalization and activation
         x = self.norm1(x)
         x = self.nonlinearity(x)
-
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-
             x = self.conv1(x, feat_cache[idx])
             feat_cache[idx] = cache_x
             feat_idx[0] += 1
         else:
             x = self.conv1(x)
-
-        # Second normalization and activation
         x = self.norm2(x)
         x = self.nonlinearity(x)
-
-        # Dropout
         x = self.dropout(x)
-
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-
             x = self.conv2(x, feat_cache[idx])
             feat_cache[idx] = cache_x
             feat_idx[0] += 1
         else:
             x = self.conv2(x)
-
-        # Add residual connection
         return x + h
 
-
 class WanAttentionBlock(nn.Module):
     r"""
     Causal self-attention with a single head.
-
     Args:
         dim (int): The number of channels in the input tensor.
     """
-
     def __init__(self, dim):
         super().__init__()
         self.dim = dim
-
-        # layers
         self.norm = WanRMS_norm(dim)
         self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
         self.proj = nn.Conv2d(dim, dim, 1)
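Note: this hunk only strips comments and blank lines; the caching logic is untouched. The pattern: before each causal conv, the last `CACHE_T = 2` frames of the input are saved in `feat_cache[idx]` so the next temporal chunk can be convolved as if the whole video were processed in one pass. A stripped-down sketch of that pattern (hypothetical helper, not from the file):

    CACHE_T = 2

    def cached_causal_conv(conv, x, feat_cache, idx):
        # Save the trailing CACHE_T frames for the next chunk, then convolve the
        # current chunk against the cached tail of the previous one.
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # chunk shorter than the cache window: borrow the last cached frame
            cache_x = torch.cat([feat_cache[idx][:, :, -1:, :, :], cache_x], dim=2)
        out = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        return out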
@@ -299,46 +237,30 @@ class WanAttentionBlock(nn.Module):
     def forward(self, x):
         identity = x
         batch_size, channels, time, height, width = x.size()
-
         x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
         x = self.norm(x)
-
-        # compute query, key, value
         qkv = self.to_qkv(x)
         qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
         qkv = qkv.permute(0, 1, 3, 2).contiguous()
         q, k, v = qkv.chunk(3, dim=-1)
-
-        # apply attention
         x = F.scaled_dot_product_attention(q, k, v)
-
         x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
-
-        # output projection
         x = self.proj(x)
-
-        # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
         x = x.view(batch_size, time, channels, height, width)
         x = x.permute(0, 2, 1, 3, 4)
-
         return x + identity
 
-
 class WanMidBlock(nn.Module):
     """
     Middle block for WanVAE encoder and decoder.
-
     Args:
         dim (int): Number of input/output channels.
         dropout (float): Dropout rate.
         non_linearity (str): Type of non-linearity to use.
     """
-
     def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
         super().__init__()
         self.dim = dim
-
-        # Create the components
         resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
         attentions = []
         for _ in range(num_layers):
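Note: despite the "causal self-attention" docstring, attention here is single-head and spatial only: each frame is folded into the batch, so `scaled_dot_product_attention` sees `(b*t, 1, h*w, c)` and frames never attend to each other. A shape walk-through under toy sizes (sketch only):

    import torch
    import torch.nn.functional as F

    b, c, t, h, w = 1, 4, 2, 8, 8
    qkv = torch.randn(b * t, 1, h * w, 3 * c)       # after to_qkv + reshape/permute
    q, k, v = qkv.chunk(3, dim=-1)                  # each (b*t, 1, h*w, c)
    out = F.scaled_dot_product_attention(q, k, v)   # attends over the h*w axis
    print(out.shape)  # torch.Size([2, 1, 64, 4])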
@@ -346,27 +268,19 @@ class WanMidBlock(nn.Module):
             resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
-
         self.gradient_checkpointing = False
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
-        # First residual block
         x = self.resnets[0](x, feat_cache, feat_idx)
-
-        # Process through attention and residual blocks
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
                 x = attn(x)
-
             x = resnet(x, feat_cache, feat_idx)
-
         return x
 
-
 class WanEncoder3d(nn.Module):
     r"""
     A 3D encoder module.
-
     Args:
         dim (int): The base number of channels in the first layer.
         z_dim (int): The dimensionality of the latent space.
@@ -377,7 +291,6 @@ class WanEncoder3d(nn.Module):
         dropout (float): Dropout rate for the dropout layers.
         non_linearity (str): Type of non-linearity to use.
     """
-
     def __init__(
         self,
         dim=128,
@@ -397,37 +310,23 @@ class WanEncoder3d(nn.Module):
         self.attn_scales = attn_scales
         self.temperal_downsample = temperal_downsample
         self.nonlinearity = get_activation(non_linearity)
-
-        # dimensions
         dims = [dim * u for u in [1] + dim_mult]
         scale = 1.0
-
-        # init block
         self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
-
-        # downsample blocks
         self.down_blocks = nn.ModuleList([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
-            # residual (+attention) blocks
             for _ in range(num_res_blocks):
                 self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
                 if scale in attn_scales:
                     self.down_blocks.append(WanAttentionBlock(out_dim))
                 in_dim = out_dim
-
-            # downsample block
             if i != len(dim_mult) - 1:
                 mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                 self.down_blocks.append(WanResample(out_dim, mode=mode))
                 scale /= 2.0
-
-        # middle blocks
         self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
-
-        # output blocks
         self.norm_out = WanRMS_norm(out_dim, images=False)
         self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
-
         self.gradient_checkpointing = False
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
@@ -435,32 +334,24 @@ class WanEncoder3d(nn.Module):
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                # cache last frame of last two chunk
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
             x = self.conv_in(x, feat_cache[idx])
             feat_cache[idx] = cache_x
             feat_idx[0] += 1
         else:
             x = self.conv_in(x)
-
-        ## downsamples
         for layer in self.down_blocks:
             if feat_cache is not None:
                 x = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)
-
-        ## middle
         x = self.mid_block(x, feat_cache, feat_idx)
-
-        ## head
         x = self.norm_out(x)
         x = self.nonlinearity(x)
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                # cache last frame of last two chunk
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
             x = self.conv_out(x, feat_cache[idx])
             feat_cache[idx] = cache_x
@@ -469,11 +360,9 @@ class WanEncoder3d(nn.Module):
             x = self.conv_out(x)
         return x
 
-
 class WanUpBlock(nn.Module):
     """
     A block that handles upsampling for the WanVAE decoder.
-
     Args:
         in_dim (int): Input dimension
         out_dim (int): Output dimension
@@ -482,7 +371,6 @@ class WanUpBlock(nn.Module):
         upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
         non_linearity (str): Type of non-linearity to use
     """
-
     def __init__(
         self,
         in_dim: int,
@@ -495,42 +383,23 @@ class WanUpBlock(nn.Module):
         super().__init__()
         self.in_dim = in_dim
         self.out_dim = out_dim
-
-        # Create layers list
         resnets = []
-        # Add residual blocks and attention if needed
         current_dim = in_dim
         for _ in range(num_res_blocks + 1):
             resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
             current_dim = out_dim
-
         self.resnets = nn.ModuleList(resnets)
-
-        # Add upsampling layer if needed
         self.upsamplers = None
         if upsample_mode is not None:
            self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
-
         self.gradient_checkpointing = False
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
-        """
-        Forward pass through the upsampling block.
-
-        Args:
-            x (torch.Tensor): Input tensor
-            feat_cache (list, optional): Feature cache for causal convolutions
-            feat_idx (list, optional): Feature index for cache management
-
-        Returns:
-            torch.Tensor: Output tensor
-        """
         for resnet in self.resnets:
             if feat_cache is not None:
                 x = resnet(x, feat_cache, feat_idx)
             else:
                 x = resnet(x)
-
         if self.upsamplers is not None:
             if feat_cache is not None:
                 x = self.upsamplers[0](x, feat_cache, feat_idx)
@@ -538,11 +407,9 @@ class WanUpBlock(nn.Module):
                 x = self.upsamplers[0](x)
         return x
 
-
 class WanDecoder3d(nn.Module):
     r"""
     A 3D decoder module.
-
     Args:
         dim (int): The base number of channels in the first layer.
         z_dim (int): The dimensionality of the latent space.
@@ -553,7 +420,6 @@ class WanDecoder3d(nn.Module):
         dropout (float): Dropout rate for the dropout layers.
         non_linearity (str): Type of non-linearity to use.
     """
-
     def __init__(
         self,
         dim=128,
@@ -572,32 +438,18 @@ class WanDecoder3d(nn.Module):
         self.num_res_blocks = num_res_blocks
         self.attn_scales = attn_scales
         self.temperal_upsample = temperal_upsample
-
         self.nonlinearity = get_activation(non_linearity)
-
-        # dimensions
         dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
         scale = 1.0 / 2 ** (len(dim_mult) - 2)
-
-        # init block
         self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
-
-        # middle blocks
         self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
-
-        # upsample blocks
         self.up_blocks = nn.ModuleList([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
-            # residual (+attention) blocks
             if i > 0:
                 in_dim = in_dim // 2
-
-            # Determine if we need upsampling
             upsample_mode = None
             if i != len(dim_mult) - 1:
                 upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
-
-            # Create and add the upsampling block
             up_block = WanUpBlock(
                 in_dim=in_dim,
                 out_dim=out_dim,
@@ -607,46 +459,32 @@ class WanDecoder3d(nn.Module):
                 non_linearity=non_linearity,
             )
             self.up_blocks.append(up_block)
-
-            # Update scale for next iteration
             if upsample_mode is not None:
                 scale *= 2.0
-
-        # output blocks
         self.norm_out = WanRMS_norm(out_dim, images=False)
         self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
-
         self.gradient_checkpointing = False
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
-        ## conv1
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                # cache last frame of last two chunk
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
             x = self.conv_in(x, feat_cache[idx])
             feat_cache[idx] = cache_x
             feat_idx[0] += 1
         else:
             x = self.conv_in(x)
-
-        ## middle
         x = self.mid_block(x, feat_cache, feat_idx)
-
-        ## upsamples
         for up_block in self.up_blocks:
             x = up_block(x, feat_cache, feat_idx)
-
-        ## head
         x = self.norm_out(x)
         x = self.nonlinearity(x)
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                # cache last frame of last two chunk
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
             x = self.conv_out(x, feat_cache[idx])
             feat_cache[idx] = cache_x
@@ -655,16 +493,13 @@ class WanDecoder3d(nn.Module):
             x = self.conv_out(x)
         return x
 
-
 class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
     Introduced in [Wan 2.1].
-
     This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
     for all models (such as downloading or saving).
     """
-
     _supports_gradient_checkpointing = False
 
     @register_to_config
@@ -678,54 +513,23 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         temperal_downsample: List[bool] = [False, True, True],
         dropout: float = 0.0,
         latents_mean: List[float] = [
-            -0.7571,
-            -0.7089,
-            -0.9113,
-            0.1075,
-            -0.1745,
-            0.9653,
-            -0.1517,
-            1.5508,
-            0.4134,
-            -0.0715,
-            0.5517,
-            -0.3632,
-            -0.1922,
-            -0.9497,
-            0.2503,
-            -0.2921,
+            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
+            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
         ],
         latents_std: List[float] = [
-            2.8184,
-            1.4541,
-            2.3275,
-            2.6558,
-            1.2196,
-            1.7708,
-            2.6052,
-            2.0743,
-            3.2687,
-            2.1526,
-            2.8652,
-            1.5579,
-            1.6382,
-            1.1253,
-            2.8251,
-            1.9160,
+            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
+            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
         ],
     ) -> None:
         super().__init__()
-
         self.z_dim = z_dim
         self.temperal_downsample = temperal_downsample
         self.temperal_upsample = temperal_downsample[::-1]
-
         self.encoder = WanEncoder3d(
             base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
         )
         self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
         self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
-
         self.decoder = WanDecoder3d(
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )
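Note: reflowing `latents_mean`/`latents_std` to two rows is purely cosmetic; the 16 per-channel statistics are unchanged. Downstream they standardize latents before the diffusion model sees them, roughly as below (a sketch of the usual Wan pipeline convention, not code from this file; `vae` and `latents` are assumed to be in scope):

    import torch

    latents_mean = torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1)
    latents_std = torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1)

    normalized = (latents - latents_mean) / latents_std  # before the denoiser
    restored = normalized * latents_std + latents_mean   # before vae.decode(...)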
@@ -741,14 +545,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self._conv_num = _count_conv3d(self.decoder)
         self._conv_idx = [0]
         self._feat_map = [None] * self._conv_num
-        # cache encode
         self._enc_conv_num = _count_conv3d(self.encoder)
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num
 
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         self.clear_cache()
-        ## cache
         t = x.shape[2]
         iter_ = 1 + (t - 1) // 4
         for i in range(iter_):
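Note: `iter_ = 1 + (t - 1) // 4` mirrors Wan's temporal compression: the first latent frame covers input frame 0 alone, and every later latent frame covers a block of 4 input frames, so t is expected to be of the form 4k + 1 (e.g. 17, 81). Worked check (sketch):

    # frames -> encoder chunks: [0], [1..4], [5..8], ...
    for t in (1, 5, 17, 81):
        iter_ = 1 + (t - 1) // 4
        chunks = [1] + [4] * (iter_ - 1)
        assert sum(chunks) == t
        print(t, "->", iter_, "chunks")  # 81 -> 21 chunks, i.e. 21 latent frames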
@@ -762,7 +564,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                     feat_idx=self._enc_conv_idx,
                 )
                 out = torch.cat([out, out_], 2)
-
         enc = self.quant_conv(out)
         mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
         enc = torch.cat([mu, logvar], dim=1)
@@ -775,12 +576,10 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
         r"""
         Encode a batch of images into latents.
-
         Args:
             x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
-
         Returns:
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
@@ -793,7 +592,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
     def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         self.clear_cache()
-
         iter_ = z.shape[2]
         x = self.post_quant_conv(z)
         for i in range(iter_):
@@ -803,24 +601,20 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             else:
                 out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                 out = torch.cat([out, out_], 2)
-
         out = torch.clamp(out, min=-1.0, max=1.0)
         self.clear_cache()
         if not return_dict:
             return (out,)
-
         return DecoderOutput(sample=out)
 
     @apply_forward_hook
     def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Decode a batch of images.
-
         Args:
             z (`torch.Tensor`): Input batch of latent vectors.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
-
         Returns:
             [`~models.vae.DecoderOutput`] or `tuple`:
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
@@ -829,7 +623,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         decoded = self._decode(z).sample
         if not return_dict:
             return (decoded,)
-
         return DecoderOutput(sample=decoded)
 
     def forward(
@@ -852,4 +645,4 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         else:
             z = posterior.mode()
         dec = self.decode(z, return_dict=return_dict)
-        return dec
+        return dec
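Note: with the absolute imports in place, the class can be exercised standalone. A minimal smoke test, assuming the class's defaults populate every constructor argument and using illustrative sizes (a 4k+1-frame clip):

    import torch
    from custom_pipeline.AutoencoderKLWan import AutoencoderKLWan

    vae = AutoencoderKLWan()  # assumes default base_dim/z_dim/dim_mult in the config
    video = torch.randn(1, 3, 17, 64, 64)  # (batch, rgb, 4k+1 frames, height, width)

    with torch.no_grad():
        posterior = vae.encode(video).latent_dist
        z = posterior.mode()           # temporally 4x, spatially 8x compressed latents
        recon = vae.decode(z).sample
    print(recon.shape)  # expected: torch.Size([1, 3, 17, 64, 64])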