DKT / dkt /pipelines /pipeline.py
shaocong's picture
save gpu
03adfb4
import torch, warnings, glob, os, types
import numpy as np
from PIL import Image
from einops import repeat, reduce
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download as ms_snap_download
from huggingface_hub import snapshot_download as hf_snap_download
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from typing_extensions import Literal
from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
from ..models import ModelManager, load_state_dict
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..schedulers.flow_match import FlowMatchScheduler
from ..prompters import WanPrompter
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
from ..lora import GeneralLoRALoader
from loguru import logger
import spaces
class BasePipeline(torch.nn.Module):
    """Shared plumbing for diffusion pipelines.

    Provides shape checks, PIL <-> tensor conversion, sub-model device
    placement (offload/onload), noise generation and train/eval freezing.
    Sub-models are expected to be registered as child modules so that
    ``named_children`` enumerates them.
    """

    def __init__(
        self,
        device="cuda", torch_dtype=torch.float16,
        height_division_factor=64, width_division_factor=64,
        time_division_factor=None, time_division_remainder=None,
    ):
        super().__init__()
        # The device and torch_dtype are used for the storage of intermediate variables, not models.
        self.device = device
        self.torch_dtype = torch_dtype
        # The following parameters are used for shape checks.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.vram_management_enabled = False

    def to(self, *args, **kwargs):
        """Move/cast like ``nn.Module.to`` and remember the new device/dtype.

        The stored ``device``/``torch_dtype`` are what intermediate tensors use.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def check_resize_height_width(self, height, width, num_frames=None):
        """Round the requested sizes up to values the pipeline accepts.

        Height and width are rounded up to multiples of the division factors.
        When ``num_frames`` is given and ``time_division_factor`` is set, it is
        rounded up to the smallest value >= ``num_frames`` with
        ``num_frames % time_division_factor == time_division_remainder``
        (a ``None`` remainder is treated as 0). A message is printed for every
        adjustment.

        Returns:
            ``(height, width)`` or ``(height, width, num_frames)``.
        """
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
        if num_frames is None:
            return height, width
        if self.time_division_factor is None:
            # No temporal constraint configured: leave num_frames untouched.
            return height, width, num_frames
        remainder = self.time_division_remainder or 0
        if num_frames % self.time_division_factor != remainder:
            # Smallest increase that lands on the required residue; unlike
            # "ceil to multiple, then add remainder" this never overshoots.
            num_frames += (remainder - num_frames) % self.time_division_factor
            print(f"num_frames % {self.time_division_factor} != {remainder}. We round it up to {num_frames}.")
        return height, width, num_frames

    def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
        """Convert an H x W x C image (e.g. PIL.Image) to a tensor in [min_value, max_value].

        Assumes 8-bit input (values scaled by 1/255). The result is rearranged
        to ``pattern``; a batch dim of size 1 is added when ``pattern`` has "B".
        """
        image = torch.Tensor(np.array(image, dtype=np.float32))
        image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        image = image * ((max_value - min_value) / 255) + min_value
        image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
        return image

    def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
        """Convert a sequence of images to a video tensor stacked along the "T" axis of ``pattern``.

        Objects exposing a non-None ``length`` attribute are indexed
        ``0..length-1``; anything else is iterated.
        """
        if hasattr(video, 'length') and video.length is not None:
            video = [self.preprocess_image(video[idx], torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for idx in range(video.length)]
        else:
            video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
        # Position of "T" in a space-separated single-letter pattern, e.g. "B C T H W" -> dim 2.
        video = torch.stack(video, dim=pattern.index("T") // 2)
        return video

    def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
        """Convert a tensor in [min_value, max_value] to an 8-bit PIL image.

        Axes not in "H W C" (e.g. batch) are averaged away.
        """
        if pattern != "H W C":
            vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
        image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
        image = image.to(device="cpu", dtype=torch.uint8)
        image = Image.fromarray(image.numpy())
        return image

    def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
        """Convert a video tensor to a list of PIL images (one per frame)."""
        if pattern != "T H W C":
            vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
        video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
        return video

    def load_models_to_device(self, model_names=()):
        """Keep only the named child models on ``self.device``.

        Only acts when VRAM management is enabled. Other children are
        offloaded; children that manage VRAM themselves are handled
        per-submodule via their ``offload``/``onload`` hooks.
        """
        if self.vram_management_enabled:
            # Offload models that are not requested.
            for name, model in self.named_children():
                if name not in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        for module in model.modules():
                            if hasattr(module, "offload"):
                                module.offload()
                    else:
                        model.cpu()
            torch.cuda.empty_cache()
            # Onload the requested models.
            for name, model in self.named_children():
                if name in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        for module in model.modules():
                            if hasattr(module, "onload"):
                                module.onload()
                    else:
                        model.to(self.device)

    def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
        """Sample standard Gaussian noise.

        Sampling happens on ``rand_device`` so a given seed is reproducible
        across target devices; the result is then moved/cast to
        ``device``/``torch_dtype`` (defaults: the pipeline's).
        """
        generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
        noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        return noise

    def enable_cpu_offload(self):
        """Deprecated alias that only turns on the VRAM-management flag."""
        warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
        self.vram_management_enabled = True

    def get_vram(self):
        """Return the total memory of ``self.device`` in GiB."""
        return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)

    def freeze_except(self, model_names):
        """Put the named children in train mode with gradients; freeze and eval() the rest."""
        for name, model in self.named_children():
            if name in model_names:
                model.train()
                model.requires_grad_(True)
            else:
                model.eval()
                model.requires_grad_(False)
@dataclass
class ModelConfig:
    """Where to find a model checkpoint: a local ``path`` or a remote repo + file pattern."""
    # Local file(s) or folder; filled in by download_if_necessary when None.
    path: Union[str, list[str]] = None
    # Remote repository id, e.g. "Wan-AI/Wan2.1-T2V-1.3B".
    model_id: str = None
    # Glob pattern inside the repo; None/"" or a trailing "/" means a folder.
    origin_file_pattern: Union[str, list[str]] = None
    # NOTE(review): not read by download_if_necessary, which picks the hub from model_id.
    download_resource: str = "ModelScope"
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None

    def download_if_necessary(self, local_model_path="./checkpoints", skip_download=False, use_usp=False):
        """Resolve ``self.path``, downloading the files first if they are missing locally.

        Files live under ``<local_model_path>/<model_id>/``. If nothing there
        matches ``origin_file_pattern``, the repo is fetched from ModelScope
        (model ids containing "Wan2") or the Hugging Face Hub (everything else).
        Does nothing when ``self.path`` is already set.

        Args:
            local_model_path: Root directory of the local checkpoint cache.
            skip_download: Never download; just resolve whatever is on disk.
            use_usp: Distributed mode. Only rank 0 downloads, and all ranks
                meet at a barrier before resolving paths.

        Raises:
            ValueError: if neither ``path`` nor ``model_id`` is set.
        """
        if self.path is not None:
            return
        if self.model_id is None:
            raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
        # Only rank 0 downloads; an explicit skip_download=True is honoured on every rank.
        if use_usp:
            import torch.distributed as dist
            skip_download = skip_download or dist.get_rank() != 0
        # Decide whether the pattern refers to a folder.
        if self.origin_file_pattern is None or self.origin_file_pattern == "":
            self.origin_file_pattern = ""
            allow_file_pattern = None
            is_folder = True
        elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
            allow_file_pattern = self.origin_file_pattern + "*"
            is_folder = True
        else:
            allow_file_pattern = self.origin_file_pattern
            is_folder = False
        model_dir = os.path.join(local_model_path, self.model_id)
        if not skip_download:
            downloaded_files = glob.glob(os.path.join(model_dir, self.origin_file_pattern))
            # glob can report dangling symlinks, so also verify the first hit really exists.
            if not downloaded_files or not os.path.exists(downloaded_files[0]):
                if 'Wan2' in self.model_id:
                    ms_snap_download(
                        self.model_id,
                        local_dir=model_dir,
                        allow_file_pattern=allow_file_pattern,
                        ignore_file_pattern=downloaded_files,
                    )
                else:
                    hf_snap_download(
                        repo_id=self.model_id,
                        local_dir=model_dir,
                        allow_patterns=allow_file_pattern,
                        ignore_patterns=downloaded_files if downloaded_files else None
                    )
        # Let ranks 1, 2, ... wait for rank 0 to finish downloading.
        if use_usp:
            import torch.distributed as dist
            dist.barrier(device_ids=[dist.get_rank()])
        # Resolve the local path(s); a single match is unwrapped to a plain string.
        if is_folder:
            self.path = os.path.join(local_model_path, self.model_id, self.origin_file_pattern)
        else:
            self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
        if isinstance(self.path, list) and len(self.path) == 1:
            self.path = self.path[0]
class WanVideoPipeline(BasePipeline):
    """Wan video diffusion pipeline (DiT + VAE + T5 text encoder, optional extras).

    Sub-models are attached by ``from_pretrained``. ``__call__`` runs the
    preprocessing units in ``self.units``, the flow-matching denoising loop
    (optionally switching from ``dit`` to ``dit2``) and VAE decoding.
    """

    def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
        )
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
        # Sub-models; None means "not loaded". Populated by from_pretrained.
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        self.dit2: WanModel = None
        self.vae: WanVideoVAE = None
        self.motion_controller: WanMotionControllerModel = None
        self.vace: VaceWanModel = None
        # Models kept on device during denoising; the "_2" set is used after
        # __call__ switches from dit to dit2.
        self.in_iteration_models = ("dit", "motion_controller", "vace")
        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
        self.unit_runner = PipelineUnitRunner()
        self.units = [
            WanVideoUnit_ShapeChecker(),
            WanVideoUnit_NoiseInitializer(),
            WanVideoUnit_InputVideoEmbedder(),
            WanVideoUnit_PromptEmbedder(),
            # WanVideoUnit_ImageEmbedderVAE(),
            # WanVideoUnit_ImageEmbedderCLIP(),
            # WanVideoUnit_ImageEmbedderFused(),
            # WanVideoUnit_FunControl(),
            WanVideoUnit_FunControl_Mask(),
            # WanVideoUnit_FunReference(),
            # WanVideoUnit_FunCameraControl(),
            # WanVideoUnit_SpeedControl(),
            # WanVideoUnit_VACE(),
            # WanVideoUnit_UnifiedSequenceParallel(),
            # WanVideoUnit_TeaCache(),
            # WanVideoUnit_CfgMerger(),
        ]
        self.model_fn = model_fn_wan_video

    def load_lora(self, module, path, alpha=1):
        """Load a LoRA state dict from ``path`` and merge it into ``module`` with weight ``alpha``."""
        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
        loader.load(module, lora, alpha=alpha)

    def training_loss(self, **inputs):
        """Flow-matching training loss for one randomly sampled timestep.

        ``inputs`` must contain "input_latents" and "noise" plus whatever
        ``model_fn`` needs. Optional "min_timestep_boundary"/"max_timestep_boundary"
        (fractions of ``num_train_timesteps``, default 0 and 1) bound the sampled timestep.
        """
        max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
        min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
        timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
        timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
        # For single-step denoising, pure noise is returned every time.
        # Question: does this refer to input_latents?
        # inputs["latents"] already exists (equal to inputs["noise"]); it is updated and overwritten here.
        inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
        training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
        noise_pred = self.model_fn(**inputs, timestep=timestep)  # timestep === 1
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.scheduler.training_weight(timestep)
        return loss

    def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
        """Wrap the loaded sub-models so their weights can be offloaded and onloaded on demand.

        Args:
            num_persistent_param_in_dit: Forwarded as ``max_num_param`` for the
                DiT(s); when given, ``vram_limit`` is ignored.
            vram_limit: VRAM budget in GiB; defaults to the device's total memory.
            vram_buffer: GiB subtracted from ``vram_limit`` as a safety margin.
        """
        self.vram_management_enabled = True
        if num_persistent_param_in_dit is not None:
            vram_limit = None
        else:
            if vram_limit is None:
                vram_limit = self.get_vram()
            vram_limit = vram_limit - vram_buffer
        if self.text_encoder is not None:
            dtype = next(iter(self.text_encoder.parameters())).dtype
            enable_vram_management(
                self.text_encoder,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Embedding: AutoWrappedModule,
                    T5RelativeEmbedding: AutoWrappedModule,
                    T5LayerNorm: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
                vram_limit=vram_limit,
            )
        # dit and dit2 share one wrapping configuration.
        for dit_model in (self.dit, self.dit2):
            if dit_model is None:
                continue
            dtype = next(iter(dit_model.parameters())).dtype
            device = "cpu" if vram_limit is not None else self.device
            enable_vram_management(
                dit_model,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv3d: AutoWrappedModule,
                    torch.nn.LayerNorm: WanAutoCastLayerNorm,
                    RMSNorm: AutoWrappedModule,
                    torch.nn.Conv2d: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device=device,
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
                max_num_param=num_persistent_param_in_dit,
                overflow_module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
                vram_limit=vram_limit,
            )
        if self.vae is not None:
            dtype = next(iter(self.vae.parameters())).dtype
            enable_vram_management(
                self.vae,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    RMS_norm: AutoWrappedModule,
                    CausalConv3d: AutoWrappedModule,
                    Upsample: AutoWrappedModule,
                    torch.nn.SiLU: AutoWrappedModule,
                    torch.nn.Dropout: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device=self.device,
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
            )
        if self.image_encoder is not None:
            dtype = next(iter(self.image_encoder.parameters())).dtype
            enable_vram_management(
                self.image_encoder,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=dtype,
                    computation_device=self.device,
                ),
            )
        if self.motion_controller is not None:
            dtype = next(iter(self.motion_controller.parameters())).dtype
            enable_vram_management(
                self.motion_controller,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=dtype,
                    computation_device=self.device,
                ),
            )
        if self.vace is not None:
            # Use VACE's own parameter dtype instead of whatever a previous branch left in `dtype`
            # (previously a NameError when VACE was the only model loaded).
            dtype = next(iter(self.vace.parameters())).dtype
            device = "cpu" if vram_limit is not None else self.device
            enable_vram_management(
                self.vace,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv3d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                    RMSNorm: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device=device,
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
                vram_limit=vram_limit,
            )

    def initialize_usp(self):
        """Initialise NCCL and xDiT sequence parallelism (Ulysses degree = world size)."""
        import torch.distributed as dist
        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
        dist.init_process_group(backend="nccl", init_method="env://")
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=1,
            ulysses_degree=dist.get_world_size(),
        )
        torch.cuda.set_device(dist.get_rank())

    def enable_usp(self):
        """Patch the DiT(s) to use xDiT unified-sequence-parallel attention and forward."""
        from xfuser.core.distributed import get_sequence_parallel_world_size
        from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
        # self.dit is required; self.dit2 is patched only when loaded.
        dit_models = [self.dit] if self.dit2 is None else [self.dit, self.dit2]
        for dit_model in dit_models:
            for block in dit_model.blocks:
                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
            dit_model.forward = types.MethodType(usp_dit_forward, dit_model)
        self.sp_size = get_sequence_parallel_world_size()
        self.use_unified_sequence_parallel = True

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
        local_model_path: str = "./checkpoints",
        skip_download: bool = False,
        redirect_common_files: bool = True,
        use_usp=False,
        training_strategy='origin',
    ):
        """Build a pipeline, download (if needed) and load all models.

        Args:
            model_configs: Checkpoints to load; the ModelManager decides which
                sub-model each file is.
            tokenizer_config: Location of the T5 tokenizer files.
            local_model_path: Root of the local checkpoint cache.
            skip_download: Only use files already on disk.
            redirect_common_files: Fetch shared files (T5, VAE, CLIP) from one
                canonical repo to avoid duplicate downloads.
            use_usp: Enable unified sequence parallelism.
            training_strategy: Only 'origin' is supported.

        Raises:
            ValueError: for an unknown ``training_strategy``.
        """
        # Redirect model path
        if redirect_common_files:
            redirect_dict = {
                "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
                "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
                "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
            }
            for model_config in model_configs:
                if model_config.origin_file_pattern is None or model_config.model_id is None:
                    continue
                if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
                    print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
                    model_config.model_id = redirect_dict[model_config.origin_file_pattern]
        # Initialize pipeline
        if training_strategy == 'origin':
            pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
            logger.warning("Using origin generative model training")
        else:
            raise ValueError(f"Invalid training strategy: {training_strategy}")
        if use_usp: pipe.initialize_usp()
        # Download and load models (honouring the cache root and skip_download arguments).
        model_manager = ModelManager()
        for model_config in model_configs:
            model_config.download_if_necessary(local_model_path=local_model_path, skip_download=skip_download, use_usp=use_usp)
            model_manager.load_model(
                model_config.path,
                device=model_config.offload_device or device,
                torch_dtype=model_config.offload_dtype or torch_dtype
            )
        # Load models
        pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
        dit = model_manager.fetch_model("wan_video_dit", index=2)
        if isinstance(dit, list):
            pipe.dit, pipe.dit2 = dit
        else:
            pipe.dit = dit
        pipe.vae = model_manager.fetch_model("wan_video_vae")
        pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
        pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
        pipe.vace = model_manager.fetch_model("wan_video_vace")
        # Size division factor
        if pipe.vae is not None:
            pipe.height_division_factor = pipe.vae.upsampling_factor * 2
            pipe.width_division_factor = pipe.vae.upsampling_factor * 2
        # Initialize tokenizer
        tokenizer_config.download_if_necessary(local_model_path=local_model_path, skip_download=skip_download, use_usp=use_usp)
        pipe.prompter.fetch_models(pipe.text_encoder)
        pipe.prompter.fetch_tokenizer(tokenizer_config.path)
        # Unified Sequence Parallel
        if use_usp: pipe.enable_usp()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: Optional[str] = "",
        # Image-to-video
        input_image: Optional[Image.Image] = None,
        # First-last-frame-to-video
        end_image: Optional[Image.Image] = None,
        # Video-to-video
        input_video: Optional[list[Image.Image]] = None,
        denoising_strength: Optional[float] = 1.0,
        # ControlNet
        control_video: Optional[list[Image.Image]] = None,
        reference_image: Optional[Image.Image] = None,
        # Camera control
        camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None,
        camera_control_speed: Optional[float] = 1/54,
        camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
        # VACE
        vace_video: Optional[list[Image.Image]] = None,
        vace_video_mask: Optional[Image.Image] = None,
        vace_reference_image: Optional[Image.Image] = None,
        vace_scale: Optional[float] = 1.0,
        # Randomness
        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",
        # Shape
        height: Optional[int] = 480,
        width: Optional[int] = 832,
        num_frames=81,
        # Classifier-free guidance
        cfg_scale: Optional[float] = 5.0,
        cfg_merge: Optional[bool] = False,
        # Boundary
        switch_DiT_boundary: Optional[float] = 0.875,
        # Scheduler
        num_inference_steps: Optional[int] = 50,
        sigma_shift: Optional[float] = 5.0,
        # Speed control
        motion_bucket_id: Optional[int] = None,
        # VAE tiling
        tiled: Optional[bool] = True,
        tile_size: Optional[tuple[int, int]] = (30, 52),
        tile_stride: Optional[tuple[int, int]] = (15, 26),
        # Sliding window
        sliding_window_size: Optional[int] = None,
        sliding_window_stride: Optional[int] = None,
        # Teacache
        tea_cache_l1_thresh: Optional[float] = None,
        tea_cache_model_id: Optional[str] = "",
        # progress_bar
        progress_bar_cmd=tqdm,
        mask: Optional[Image.Image] = None,
    ):
        """Generate a video.

        Runs the preprocessing units, then the denoising loop. When ``dit2`` is
        loaded, the loop switches to it once the timestep drops below
        ``switch_DiT_boundary * num_train_timesteps``. The latents are then
        decoded with the VAE.

        Returns:
            ``(video, vae_outs)``: list of PIL frames and the raw VAE output tensor.
        """
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
        # Inputs
        inputs_posi = {
            "prompt": prompt,
            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
        }
        inputs_shared = {
            "input_image": input_image,
            "end_image": end_image,
            "input_video": input_video, "denoising_strength": denoising_strength,
            "control_video": control_video, "reference_image": reference_image,
            "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin,
            "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale,
            "seed": seed, "rand_device": rand_device,
            "height": height, "width": width, "num_frames": num_frames,
            "cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
            "sigma_shift": sigma_shift,
            "motion_bucket_id": motion_bucket_id,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
            "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
            "mask":mask,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            # Switch DiT if necessary
            if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
                self.load_models_to_device(self.in_iteration_models_2)
                models["dit"] = self.dit2
            # Timestep
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            # Inference
            noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
            if cfg_scale != 1.0:
                if cfg_merge:
                    noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
                else:
                    noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi
            # Scheduler
            inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
            if "first_frame_latents" in inputs_shared:
                inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]
        # VACE (TODO: remove it)
        if vace_reference_image is not None:
            inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
        # Decode
        self.load_models_to_device(['vae'])
        vae_outs = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        video = self.vae_output_to_video(vae_outs)
        self.load_models_to_device([])
        return video,vae_outs
def extract_frames_from_video_file(video_path, default_fps=15.0):
    """Decode every frame of a video file into RGB PIL images.

    Args:
        video_path: Path to a video file readable by OpenCV.
        default_fps: FPS reported when OpenCV gives a non-positive FPS or
            when decoding fails.

    Returns:
        ``(frames, fps)``: a list of RGB ``PIL.Image.Image`` and the FPS.
        On any error the error is logged and ``([], default_fps)`` is returned
        (deliberate best-effort behaviour).
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        frames = []
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = default_fps
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV decodes to BGR; PIL expects RGB.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        return frames, fps
    except Exception as e:
        logger.error(f"Error extracting frames from {video_path}: {str(e)}")
        return [], default_fps
    finally:
        # Release the capture even if decoding raised midway.
        if cap is not None:
            cap.release()
def resize_frame(frame, height, width):
    """Resize an H x W x C 8-bit image to ``(height, width)``.

    Uses antialiased bicubic interpolation. The result is clamped to the
    valid range and converted back to 8-bit (fractions truncated), then
    returned as a new ``PIL.Image.Image``.
    """
    pixels = torch.from_numpy(np.array(frame))
    # HWC uint8 -> 1 x C x H x W float in [0, 1], the layout F.interpolate expects.
    batch = pixels.permute(2, 0, 1).unsqueeze(0).float() / 255.0
    resized = torch.nn.functional.interpolate(
        batch, (height, width), mode="bicubic", align_corners=False, antialias=True
    )
    hwc = resized.squeeze(0).permute(1, 2, 0)
    as_uint8 = (hwc.clamp(0, 1) * 255).byte().numpy()
    return Image.fromarray(as_uint8)
from moge.model.v2 import MoGeModel
from tools.eval_utils import transfer_pred_disp2depth, transfer_pred_disp2depth_v2, colorize_depth_map
from tools.depth2pcd import depth2pcd
import cv2, copy
class DKTPipeline:
def __init__(self, ):
self.main_pipe = self.init_model()
self.moge_pipe = self.load_moge_model()
def init_model(self ):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device=device,
model_configs=[
ModelConfig(
model_id="PAI/Wan2.1-Fun-1.3B-Control",
origin_file_pattern="diffusion_pytorch_model*.safetensors",
offload_device="cpu",
),
ModelConfig(
model_id="PAI/Wan2.1-Fun-1.3B-Control",
origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth",
offload_device="cpu",
),
ModelConfig(
model_id="PAI/Wan2.1-Fun-1.3B-Control",
origin_file_pattern="Wan2.1_VAE.pth",
offload_device="cpu",
),
ModelConfig(
model_id="PAI/Wan2.1-Fun-1.3B-Control",
origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
offload_device="cpu",
),
],
training_strategy="origin",
)
lora_config = ModelConfig(
model_id="Daniellesry/DKT-Depth-1-3B",
origin_file_pattern="dkt-1-3B.safetensors",
offload_device="cpu",
)
lora_config.download_if_necessary(use_usp=False)
pipe.load_lora(pipe.dit, lora_config.path, alpha=1.0)#todo is it work?
pipe.enable_vram_management()
return pipe
def load_moge_model(self):
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
cached_model_path = 'checkpoints/moge_ckpt/moge-2-vitl-normal/model.pt'
if os.path.exists(cached_model_path):
logger.info(f"Found cached model at {cached_model_path}, loading from cache...")
moge_pipe = MoGeModel.from_pretrained(cached_model_path).to(device)
else:
logger.info(f"Cache not found at {cached_model_path}, downloading from HuggingFace...")
os.makedirs(os.path.dirname(cached_model_path), exist_ok=True)
moge_pipe = MoGeModel.from_pretrained('Ruicheng/moge-2-vitl-normal', cache_dir=os.path.dirname(cached_model_path)).to(device)
return moge_pipe
@spaces.GPU(duration=120)
@torch.inference_mode()
def __call__(self, video_file, prompt='depth', \
negative_prompt='', height=480, width=832, \
num_inference_steps=5, window_size=21, \
overlap=3, vis_pc = False, return_rgb = False):
origin_frames, input_fps = extract_frames_from_video_file(video_file)
frame_length = len(origin_frames)
original_width, original_height = origin_frames[0].size
ROTATE = False
if original_width < original_height:#* ensure the width is the longer side
ROTATE = True
origin_frames = [x.transpose(Image.ROTATE_90) for x in origin_frames]
tmp = original_width
original_width = original_height
original_height = tmp
frames = [resize_frame(frame, height, width) for frame in origin_frames]
if (frame_length - 1) % 4 != 0:
new_len = ((frame_length - 1) // 4 + 1) * 4 + 1
frames = frames + [copy.deepcopy(frames[-1]) for _ in range(new_len - frame_length)]
video, vae_outs = self.main_pipe(
prompt=prompt,
negative_prompt=negative_prompt,
control_video=frames,
height=height,
width=width,
num_frames=len(frames),
seed=1,
tiled=False,
num_inference_steps=num_inference_steps,
sliding_window_size=window_size,
sliding_window_stride=window_size - overlap,
cfg_scale=1.0,
)
torch.cuda.empty_cache()
processed_video = video[:frame_length]
processed_video = [resize_frame(frame, original_height, original_width) for frame in processed_video]
if ROTATE:
processed_video = [x.transpose(Image.ROTATE_270) for x in processed_video]
origin_frames = [x.transpose(Image.ROTATE_270) for x in origin_frames]
color_predictions = []
if prompt == 'depth':
prediced_depth_map_np = [np.array(item).astype(np.float32).mean(-1) for item in processed_video]
prediced_depth_map_np = np.stack(prediced_depth_map_np)
prediced_depth_map_np = prediced_depth_map_np / 255.0
__min = prediced_depth_map_np.min()
__max = prediced_depth_map_np.max()
prediced_depth_map_np_normalized = (prediced_depth_map_np - __min) / (__max - __min)
color_predictions = [colorize_depth_map(item) for item in prediced_depth_map_np_normalized]
else:
color_predictions = processed_video
return_dict = {}
return_dict['depth_map'] = prediced_depth_map_np
return_dict['colored_depth_map'] = color_predictions
if vis_pc and prompt == 'depth':
vis_pc_num = 4
indices = np.linspace(0, frame_length-1, vis_pc_num)
indices = np.round(indices).astype(np.int32)
return_dict['point_clouds'] = self.prediction2pc(prediced_depth_map_np, origin_frames, indices)
if return_rgb:
return_dict['rgb_frames'] = origin_frames
return return_dict
def prediction2pc(self, prediction_depth_map, RGB_frames, indices, return_pcd = True,nb_neighbors = 20, std_ratio = 3.0):
resize_W,resize_H = RGB_frames[0].size
pcds = []
moge_device = self.moge_pipe.device if self.moge_pipe is not None else torch.device("cuda:0")
for idx in tqdm(indices):
orgin_rgb_frame = RGB_frames[idx]
predicted_depth = prediction_depth_map[idx]
# Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
input_image_np = np.array(orgin_rgb_frame) # Convert PIL Image to numpy array
input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
output = self.moge_pipe.infer(input_image)
#* "dict_keys(['points', 'intrinsics', 'depth', 'mask', 'normal'])"
moge_intrinsics = output['intrinsics'].cpu().numpy()
moge_mask = output['mask'].cpu().numpy()
moge_depth = output['depth'].cpu().numpy()
metric_depth = transfer_pred_disp2depth(predicted_depth, moge_depth, moge_mask)
moge_intrinsics[0, 0] *= resize_W
moge_intrinsics[1, 1] *= resize_H
moge_intrinsics[0, 2] *= resize_W
moge_intrinsics[1, 2] *= resize_H
pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=return_pcd)
if return_pcd:
#* [15,50], [2,3]
cl, ind = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio)
pcd = pcd.select_by_index(ind)
#todo downsample
pcds.append(pcd)
return pcds
@spaces.GPU()
@torch.inference_mode()
def moge_infer(self, input_image):
return self.moge_pipe.infer(input_image)
def prediction2pc_v2(self, prediction_depth_map, RGB_frames, indices, return_pcd = True,nb_neighbors = 20, std_ratio = 3.0):
    """Convert selected predicted depth frames into point clouds, calling MoGe only once.

    MoGe runs on the first frame in ``indices`` only. From that single call this method
    keeps:
      * the intrinsics (scaled to pixel units by the frame size);
      * the validity mask;
      * the ``(scale, shift)`` returned by ``transfer_pred_disp2depth(..., return_scale_shift=True)``.
    Every later frame is converted with ``transfer_pred_disp2depth_v2(pred, scale, shift)`` and
    back-projected with those same intrinsics and mask.

    NOTE(review): reusing frame-0 intrinsics, mask and scale/shift presumes a fixed camera and a
    consistent depth scale across the selected frames. Confirm this for moving-camera inputs.

    Args:
        prediction_depth_map: per-frame predictions, indexed by frame index.
        RGB_frames: sequence of PIL images; ``RGB_frames[0].size`` is read as ``(W, H)``.
        indices: iterable of frame indices to convert; MoGe runs on the first one.
        return_pcd: forwarded to ``depth2pcd`` as ``ret_pcd``; also enables outlier removal.
        nb_neighbors: ``nb_neighbors`` passed to ``remove_statistical_outlier``.
        std_ratio: ``std_ratio`` passed to ``remove_statistical_outlier``.

    Returns:
        A list with one ``depth2pcd`` result per index (empty if ``indices`` is empty).

    Raises:
        RuntimeError: if ``indices`` is non-empty and ``self.moge_pipe`` is None.
    """
    pcds = []
    indices = list(indices)
    if not indices:
        return pcds
    # The former "cuda:0" fallback was dead code: ``moge_infer`` needs a MoGe model.
    if self.moge_pipe is None:
        raise RuntimeError("prediction2pc_v2 requires a MoGe model, but self.moge_pipe is None.")
    resize_W, resize_H = RGB_frames[0].size
    moge_device = self.moge_pipe.device

    # --- Calibrate once, using the first selected frame. ---
    first_idx = indices[0]
    first_rgb_np = np.array(RGB_frames[first_idx])  # PIL image -> (H, W, C) array
    # (H, W, C) -> (C, H, W) float32 in [0, 1], the layout fed to MoGe.
    first_input = torch.tensor(first_rgb_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
    output = self.moge_infer(first_input)
    # MoGe output keys: 'points', 'intrinsics', 'depth', 'mask', 'normal'.
    moge_intrinsics = output['intrinsics'].cpu().numpy()
    moge_mask = output['mask'].cpu().numpy()
    moge_depth = output['depth'].cpu().numpy()
    first_metric_depth, scale, shift = transfer_pred_disp2depth(
        prediction_depth_map[first_idx], moge_depth, moge_mask, return_scale_shift=True
    )
    # Intrinsics are treated as normalized by image size; convert them to pixel units.
    moge_intrinsics[0, 0] *= resize_W
    moge_intrinsics[1, 1] *= resize_H
    moge_intrinsics[0, 2] *= resize_W
    moge_intrinsics[1, 2] *= resize_H

    # --- Back-project every selected frame with the shared calibration. ---
    for position, idx in enumerate(tqdm(indices)):
        if position == 0:
            input_image_np, metric_depth = first_rgb_np, first_metric_depth
        else:
            input_image_np = np.array(RGB_frames[idx])
            metric_depth = transfer_pred_disp2depth_v2(prediction_depth_map[idx], scale, shift)
        pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=return_pcd)
        if return_pcd:
            # Author's note "[15,50], [2,3]": presumably the ranges tried for nb_neighbors / std_ratio.
            _, ind = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio)
            pcd = pcd.select_by_index(ind)
        # TODO: optional downsampling.
        pcds.append(pcd)
    return pcds
class WanVideoUnit_ShapeChecker(PipelineUnit):
    """Pipeline unit that snaps height, width and num_frames to values the pipeline accepts."""

    def __init__(self):
        super().__init__(input_params=("height", "width", "num_frames"))

    def process(self, pipe: WanVideoPipeline, height, width, num_frames):
        """Return the shape parameters as adjusted by ``pipe.check_resize_height_width``."""
        checked = pipe.check_resize_height_width(height, width, num_frames)
        new_height, new_width, new_num_frames = checked
        return dict(height=new_height, width=new_width, num_frames=new_num_frames)
class WanVideoUnit_NoiseInitializer(PipelineUnit):
    """Pipeline unit that creates the initial latent noise for the requested video shape."""

    def __init__(self):
        super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"))

    def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image):
        """Return ``{"noise": tensor}`` of shape (1, z_dim, latent_frames, H / f, W / f).

        Here ``f`` is ``pipe.vae.upsampling_factor``. The latent frame count is
        ``(num_frames - 1) // 4 + 1``, plus one extra slot when a VACE reference image is given.
        """
        has_reference = vace_reference_image is not None
        latent_frames = 1 + (num_frames - 1) // 4
        if has_reference:
            latent_frames += 1
        factor = pipe.vae.upsampling_factor
        noise_shape = (1, pipe.vae.model.z_dim, latent_frames, height // factor, width // factor)
        noise = pipe.generate_noise(noise_shape, seed=seed, rand_device=rand_device)
        if has_reference:
            # Rotate the last temporal slice to the front of the time axis (dim=2).
            noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
        return {"noise": noise}
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
    """Pipeline unit that VAE-encodes ``input_video`` and derives the starting latents.

    Behavior:
      * Without an input video, the latents are just the noise.
      * When ``pipe.scheduler.training`` is set, the pure noise is returned as ``latents``
        and the clean encoding as ``input_latents``.
      * Otherwise the encoding is noised at ``pipe.scheduler.timesteps[0]`` and returned
        as ``latents``.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
        if input_video is None:
            return {"latents": noise}
        pipe.load_models_to_device(["vae"])  # the input video is the ground-truth target
        vae, device, dtype = pipe.vae, pipe.device, pipe.torch_dtype
        video_tensor = pipe.preprocess_video(input_video)
        encoded = vae.encode(video_tensor, device=device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        input_latents = encoded.to(dtype=dtype, device=device)
        if vace_reference_image is not None:
            reference_tensor = pipe.preprocess_video([vace_reference_image])
            reference_latents = vae.encode(reference_tensor, device=device).to(dtype=dtype, device=device)
            # Prepend the reference latents along the temporal axis (dim=2).
            input_latents = torch.concat([reference_latents, input_latents], dim=2)
        if pipe.scheduler.training:
            # Training: the clean latents are kept separate from the noise.
            return {"latents": noise, "input_latents": input_latents}
        # Inference: start denoising from the input latents noised at the first timestep.
        first_timestep = pipe.scheduler.timesteps[0]
        return {"latents": pipe.scheduler.add_noise(input_latents, noise, timestep=first_timestep)}
class WanVideoUnit_PromptEmbedder(PipelineUnit):
    """Pipeline unit that encodes the prompt into text context.

    Runs separately for the positive (``prompt``) and negative (``negative_prompt``) branches.
    """

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt", "positive": "positive"},
            input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
            onload_model_names=("text_encoder",)
        )

    def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict:
        """Return ``{"context": embedding}`` produced by ``pipe.prompter.encode_prompt``."""
        pipe.load_models_to_device(self.onload_model_names)
        context = pipe.prompter.encode_prompt(prompt, positive=positive, device=pipe.device)
        return dict(context=context)
class WanVideoUnit_ImageEmbedder(PipelineUnit):
    """
    Deprecated: combined CLIP + VAE image embedding.

    The CLIP part is mirrored by ``WanVideoUnit_ImageEmbedderCLIP`` and the VAE part by
    ``WanVideoUnit_ImageEmbedderVAE``. Returns ``{"clip_feature", "y"}``, or ``{}`` when
    there is no input image or no image encoder.
    """
    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("image_encoder", "vae")
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
        if input_image is None or pipe.image_encoder is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
        clip_context = pipe.image_encoder.encode_image([image])
        # Per-frame conditioning mask (1, num_frames, H/8, W/8): 1 marks frames supplied as
        # images (the first frame, and the last one when end_image is given), 0 otherwise.
        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
        msk[:, 1:] = 0
        if end_image is not None:
            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            # VAE input along time: start image, num_frames-2 zero frames, end image.
            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
            if pipe.dit.has_image_pos_emb:
                clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
            msk[:, -1:] = 1
        else:
            # VAE input along time: start image followed by num_frames-1 zero frames.
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

        # Repeat the first mask frame 4x (num_frames + 3 frames total), then fold groups of 4
        # frames into 4 channels: (1, T', 4, h, w) -> transpose -> [0] -> (4, T', h, w).
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
        msk = msk.transpose(1, 2)[0]

        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        # Channel-wise [mask (4 ch), VAE latents], then add a batch dimension.
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"clip_feature": clip_context, "y": y}
class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
    """Pipeline unit that encodes the start (and optionally end) image with the CLIP image encoder.

    Returns ``{}`` unless an input image, an image encoder, and
    ``pipe.dit.require_clip_embedding`` are all present. The end-image features are appended
    along dim 1 only when ``pipe.dit.has_image_pos_emb`` is set.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "height", "width"),
            onload_model_names=("image_encoder",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
        skip = input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding
        if skip:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        target_size = (width, height)
        encoder = pipe.image_encoder
        start = pipe.preprocess_image(input_image.resize(target_size)).to(pipe.device)
        clip_context = encoder.encode_image([start])
        if end_image is not None:
            end = pipe.preprocess_image(end_image.resize(target_size)).to(pipe.device)
            if pipe.dit.has_image_pos_emb:
                clip_context = torch.concat([clip_context, encoder.encode_image([end])], dim=1)
        return {"clip_feature": clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)}
class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
    """Build the VAE image-conditioning tensor ``y`` from a start (and optional end) image.

    Only runs when ``input_image`` is given and ``pipe.dit.require_vae_embedding`` is true.
    ``y`` is the channel-wise concatenation ``[mask, vae_latents]`` with a leading batch dim.
    """
    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
        if input_image is None or not pipe.dit.require_vae_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
        # Per-frame mask (1, num_frames, H/8, W/8): 1 for frames given as images, 0 otherwise.
        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
        msk[:, 1:] = 0
        if end_image is not None:
            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            # VAE input along time: start image, num_frames-2 zero frames, end image.
            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
            msk[:, -1:] = 1
        else:
            # VAE input along time: start image followed by num_frames-1 zero frames.
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

        # Repeat the first mask frame 4x, then fold groups of 4 frames into 4 channels:
        # (1, T', 4, h, w) -> transpose(1, 2) -> [0] -> (4, T', h, w).
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
        msk = msk.transpose(1, 2)[0]

        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        # Channel-wise [mask (4 ch), VAE latents], then add a batch dimension.
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"y": y}
class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
    """
    Encode the input image with the VAE and write it into the first latent frame.

    This unit is for Wan-AI/Wan2.2-TI2V-5B. It only runs when
    ``pipe.dit.fuse_vae_embedding_in_latents`` is set.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride):
        if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        resized = input_image.resize((width, height))
        # Swap the first two dims of the preprocessed image (assumed (1, C, H, W) -> (C, 1, H, W)).
        image = pipe.preprocess_image(resized).transpose(0, 1)
        first_frame_latents = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        # Slice assignment is in place: the caller's ``latents`` tensor is modified too.
        latents[:, :, 0: 1] = first_frame_latents
        return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": first_frame_latents}
class WanVideoUnit_FunControl(PipelineUnit):
    """Build the Fun-Control conditioning ``y`` from a control video.

    ``y`` is the channel-wise concatenation ``[control_latents, y16]``, where ``y16`` is:
      * the last 16 channels of the incoming ``y``, when both ``clip_feature`` and ``y`` are given;
      * otherwise zeros, in which case a zero ``clip_feature`` of shape (1, 257, 1280) is also
        supplied (the author notes this path is used during training).
    """

    def __init__(self):
        super().__init__(
            input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
        if control_video is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        dtype, device = pipe.torch_dtype, pipe.device
        # Frames -> tensor, then VAE-encode into latent space.
        control_tensor = pipe.preprocess_video(control_video)
        control_latents = pipe.vae.encode(control_tensor, device=device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=dtype, device=device)
        if clip_feature is not None and y is not None:
            y = y[:, -16:]
        else:
            clip_feature = torch.zeros((1, 257, 1280), dtype=dtype, device=device)
            latent_frames = control_latents.shape[-3]
            y = torch.zeros((1, 16, latent_frames, height // 8, width // 8), dtype=dtype, device=device)
        # Control latents first, then the 16-channel y (or zero) block.
        y = torch.concat([control_latents, y], dim=1)
        return {"clip_feature": clip_feature, "y": y}
class WanVideoUnit_FunControl_Mask(PipelineUnit):
    """Build the Fun-Control conditioning ``y`` from a control video and an optional mask video.

    ``y`` is the channel-wise concatenation ``[control_latents, second]`` where ``second`` is:
      * the VAE latents of ``mask`` when a mask is given;
      * otherwise zeros (1, 16, T_latent, H/8, W/8) when ``clip_feature`` or ``y`` is missing;
      * otherwise the last 16 channels of the incoming ``y``.
    When ``clip_feature`` or ``y`` is missing, a zero ``clip_feature`` of shape (1, 257, 1280)
    is supplied (the author notes this path is used during training).
    """

    def __init__(self):
        super().__init__(
            input_params=("control_video", "mask","num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, control_video, mask, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
        if control_video is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        dtype, device = pipe.torch_dtype, pipe.device

        def encode(frames):
            # Frames -> tensor -> VAE latents, cast once to the pipeline dtype/device.
            video = pipe.preprocess_video(frames)
            return pipe.vae.encode(video, device=device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=dtype, device=device)

        control_latents = encode(control_video)
        missing_conditioning = clip_feature is None or y is None
        if missing_conditioning:
            clip_feature = torch.zeros((1, 257, 1280), dtype=dtype, device=device)

        # Only build the second block that is actually used; previously a zero/sliced y was
        # prepared and then discarded whenever a mask was provided.
        if mask is not None:
            second = encode(mask)
        elif missing_conditioning:
            second = torch.zeros((1, 16, control_latents.shape[-3], height // 8, width // 8), dtype=dtype, device=device)
        else:
            second = y[:, -16:]
        y = torch.concat([control_latents, second], dim=1)
        return {"clip_feature": clip_feature, "y": y}
class WanVideoUnit_FunReference(PipelineUnit):
    """Encode a reference image into VAE latents and a CLIP image feature.

    Returns ``{}`` when no reference image is given, otherwise
    ``{"reference_latents", "clip_feature"}``. The image is resized to ``(width, height)``
    before both encodings.
    """

    def __init__(self):
        super().__init__(
            # "reference_image" was previously listed twice; one entry is sufficient.
            input_params=("reference_image", "height", "width"),
            # ``process`` uses both the VAE and the image encoder, so both must be onloaded.
            onload_model_names=("vae", "image_encoder"),
        )

    def process(self, pipe: WanVideoPipeline, reference_image, height, width):
        if reference_image is None:
            return {}
        # Previously only the VAE was onloaded, although the image encoder is used below;
        # with VRAM management enabled the encoder could stay offloaded.
        pipe.load_models_to_device(self.onload_model_names)
        reference_image = reference_image.resize((width, height))
        reference_latents = pipe.preprocess_video([reference_image])
        reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
        clip_feature = pipe.preprocess_image(reference_image)
        clip_feature = pipe.image_encoder.encode_image([clip_feature])
        return {"reference_latents": reference_latents, "clip_feature": clip_feature}
class WanVideoUnit_FunCameraControl(PipelineUnit):
    """Build camera-control latents and the first-frame conditioning ``y``.

    Skipped when ``camera_control_direction`` is None. Otherwise ``input_image`` is required
    (it is resized unconditionally). Returns ``control_camera_latents_input`` and ``y``, where
    ``y`` is zeros shaped like ``latents`` with the VAE encoding of the input image in the
    first temporal slot.
    """
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image):
        if camera_control_direction is None:
            return {}
        camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(
            camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)

        # Take num_frames frames and move the last axis first: presumably (F, H, W, C) -> (1, C, F, H, W).
        control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
        # Repeat the first frame 4x along time, then move time to dim 1: (b, f, c, h, w).
        control_camera_latents = torch.concat(
            [
                torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
                control_camera_video[:, :, 1:]
            ], dim=2
        ).transpose(1, 2)
        b, f, c, h, w = control_camera_latents.shape
        # Fold each group of 4 frames into channels: (b, f//4, 4, c, h, w) -> (b, f//4, c, 4, h, w)
        # -> (b, f//4, c*4, h, w) -> (b, c*4, f//4, h, w).
        control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
        control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
        control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)

        input_image = input_image.resize((width, height))
        input_latents = pipe.preprocess_video([input_image])
        pipe.load_models_to_device(self.onload_model_names)
        input_latents = pipe.vae.encode(input_latents, device=pipe.device)
        # y: zeros like the latents, with the encoded input image in the first temporal slot.
        y = torch.zeros_like(latents).to(pipe.device)
        y[:, :, :1] = input_latents
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"control_camera_latents_input": control_camera_latents_input, "y": y}
class WanVideoUnit_SpeedControl(PipelineUnit):
    """Pipeline unit that wraps ``motion_bucket_id`` into a 1-element tensor.

    The tensor is cast to the pipeline dtype and device.
    """

    def __init__(self):
        super().__init__(input_params=("motion_bucket_id",))

    def process(self, pipe: WanVideoPipeline, motion_bucket_id):
        if motion_bucket_id is None:
            return {}
        bucket = torch.Tensor((motion_bucket_id,))
        return {"motion_bucket_id": bucket.to(dtype=pipe.torch_dtype, device=pipe.device)}
class WanVideoUnit_VACE(PipelineUnit):
    """Build the VACE conditioning context from a video, a mask, and/or a reference image.

    Returns ``{"vace_context": None, ...}`` when none of the three inputs is given. Otherwise
    ``vace_context`` is the channel-wise concatenation of:
      * the latents of the masked-out part (``inactive``);
      * the latents of the masked-in part (``reactive``);
      * a 64-channel latent-resolution mask.
    When a reference image is given, its latents and a zero mask are prepended along time.
    """
    def __init__(self):
        super().__init__(
            input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(
        self,
        pipe: WanVideoPipeline,
        vace_video, vace_video_mask, vace_reference_image, vace_scale,
        height, width, num_frames,
        tiled, tile_size, tile_stride
    ):
        if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None:
            pipe.load_models_to_device(["vae"])
            if vace_video is None:
                vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
            else:
                vace_video = pipe.preprocess_video(vace_video)

            # Missing mask => everything is "reactive". Otherwise the mask is preprocessed with
            # min_value=0, max_value=1 (presumably mapping it to [0, 1]).
            if vace_video_mask is None:
                vace_video_mask = torch.ones_like(vace_video)
            else:
                vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1)

            # inactive keeps pixels where mask == 0, reactive keeps pixels where mask == 1.
            inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
            reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
            inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
            reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
            vace_video_latents = torch.concat((inactive, reactive), dim=1)

            # Mask of batch 0 / channel 0: fold each 8x8 pixel block into 64 channels, then
            # resample time to (T + 3) // 4 frames with nearest-exact interpolation.
            vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
            vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')

            if vace_reference_image is None:
                pass
            else:
                # Reference latents get zero "reactive" channels and a zero mask, and are
                # prepended along the temporal axis (dim=2).
                vace_reference_image = pipe.preprocess_video([vace_reference_image])
                vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
                vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
                vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
                vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)

            vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
            return {"vace_context": vace_context, "vace_scale": vace_scale}
        else:
            return {"vace_context": None, "vace_scale": vace_scale}
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
    """Pipeline unit that forwards the pipeline's unified-sequence-parallel flag, when truthy."""

    def __init__(self):
        super().__init__(input_params=())

    def process(self, pipe: WanVideoPipeline):
        enabled = getattr(pipe, "use_unified_sequence_parallel", False)
        return {"use_unified_sequence_parallel": True} if enabled else {}
class WanVideoUnit_TeaCache(PipelineUnit):
    """Pipeline unit that creates a ``TeaCache`` per CFG branch.

    A cache is created only when ``tea_cache_l1_thresh`` is set.
    """

    def __init__(self):
        shared_params = {
            name: name for name in ("num_inference_steps", "tea_cache_l1_thresh", "tea_cache_model_id")
        }
        super().__init__(
            seperate_cfg=True,
            input_params_posi=shared_params,
            input_params_nega=dict(shared_params),
        )

    def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id):
        if tea_cache_l1_thresh is None:
            return {}
        cache = TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)
        return {"tea_cache": cache}
class WanVideoUnit_CfgMerger(PipelineUnit):
    """When ``cfg_merge`` is set, batch the positive and negative inputs into shared tensors.

    For each name in ``concat_tensor_names``, the shared input becomes either
    ``concat(posi, nega)`` (if both branch tensors exist) or ``concat(shared, shared)`` along
    dim 0. The per-branch dicts are then cleared.
    """

    def __init__(self):
        super().__init__(take_over=True)
        self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]

    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        if inputs_shared["cfg_merge"]:
            for name in self.concat_tensor_names:
                posi, nega = inputs_posi.get(name), inputs_nega.get(name)
                if posi is not None and nega is not None:
                    pair = (posi, nega)
                else:
                    shared = inputs_shared.get(name)
                    if shared is None:
                        continue
                    pair = (shared, shared)
                inputs_shared[name] = torch.concat(pair, dim=0)
            inputs_posi.clear()
            inputs_nega.clear()
        return inputs_shared, inputs_posi, inputs_nega
class TeaCache:
    """Step-skipping cache for the DiT blocks (TeaCache).

    Usage per denoising step:
      * ``check`` accumulates a polynomially rescaled relative L1 change of the timestep
        modulation ``t_mod`` between consecutive calls. It returns True (skip the blocks)
        while that accumulation stays below ``rel_l1_thresh``.
      * The caller then uses ``update`` to add the residual saved by the last ``store``.
      * The first and last step of every ``num_inference_steps``-long cycle always compute,
        and the step counter wraps back to 0 after ``num_inference_steps`` calls.
    """

    # Per-model polynomial coefficients (highest degree first, as consumed by ``np.poly1d``)
    # used to rescale the raw relative L1 distance.
    COEFFICIENTS = {
        "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
        "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
        "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
        "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
    }

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        """
        Args:
            num_inference_steps: length of one denoising cycle.
            rel_l1_thresh: accumulated rescaled distance below which steps are skipped.
            model_id: key of ``COEFFICIENTS``.

        Raises:
            ValueError: if ``model_id`` has no coefficients.
        """
        if model_id not in self.COEFFICIENTS:
            supported_model_ids = ", ".join(self.COEFFICIENTS)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None
        # Kept as a per-instance copy for backward compatibility with code reading it.
        self.coefficients_dict = dict(self.COEFFICIENTS)
        self.coefficients = self.coefficients_dict[model_id]
        # Built once here instead of on every ``check`` call.
        self._rescale_func = np.poly1d(self.coefficients)

    def check(self, dit: "WanModel", x, t_mod):
        """Return True if the DiT blocks may be skipped at this step.

        When computation is needed, ``x`` is snapshotted for the next ``store``.
        ``dit`` is unused and kept only for interface compatibility.
        """
        modulated_inp = t_mod.clone()
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            previous = self.previous_modulated_input
            rel_l1 = ((modulated_inp - previous).abs().mean() / previous.abs().mean()).cpu().item()
            self.accumulated_rel_l1_distance += self._rescale_func(rel_l1)
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Save the residual ``hidden_states - x`` produced by the blocks at a computed step."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Approximate the blocks' output at a skipped step by adding the cached residual."""
        hidden_states = hidden_states + self.previous_residual
        return hidden_states
class TemporalTiler_BCTHW:
    """Run a model over overlapping temporal windows of (B, C, T, H, W) tensors.

    Window outputs are blended with linear ramps over the overlapping frames and normalized
    by the accumulated blend weights.
    """

    def __init__(self):
        pass

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        """Return a 1-D weight ramp of ``length``.

        The first/last ``border_width`` entries rise/fall linearly, except on sides flagged as
        a sequence bound.
        """
        x = torch.ones((length,))
        if border_width == 0:
            return x
        shift = 0.5
        ramp = (torch.arange(border_width) + shift) / border_width
        if not left_bound:
            x[:border_width] = ramp
        if not right_bound:
            x[-border_width:] = torch.flip(ramp, dims=(0,))
        return x

    def build_mask(self, data, is_bound, border_width):
        """Return a (1, 1, T, 1, 1) temporal blend mask for a (B, C, T, H, W) tensor."""
        _, _, T, _, _ = data.shape
        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
        return t.view(1, 1, T, 1, 1)

    def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None):
        """Apply ``model_fn(**model_kwargs)`` window by window along the temporal axis (dim 2).

        Args:
            model_fn: callable whose output has the same (B, C, t, H, W) layout as its window.
            sliding_window_size: frames per window.
            sliding_window_stride: frames between window starts; must satisfy
                ``0 < stride <= size`` so every frame is covered.
            computation_device, computation_dtype: where/how each window is computed.
            model_kwargs: keyword arguments for ``model_fn``; the tensors named in
                ``tensor_names`` are temporarily replaced by their windows and restored after.
            tensor_names: names of 5-D tensors to slice; None-valued entries are ignored.
            batch_size: output batch multiplier (e.g. 2 when CFG branches are merged).

        Returns:
            The blended (B * batch_size, C, T, H, W) output on the input's device/dtype.

        Raises:
            ValueError: if the window size/stride combination would leave frames uncovered.
        """
        if not 0 < sliding_window_stride <= sliding_window_size:
            raise ValueError(
                f"Invalid sliding window (size={sliding_window_size}, stride={sliding_window_stride}): "
                "expected 0 < stride <= size."
            )
        tensor_names = [name for name in tensor_names if model_kwargs.get(name) is not None]
        tensor_dict = {name: model_kwargs[name] for name in tensor_names}
        reference = tensor_dict[tensor_names[0]]
        B, C, T, H, W = reference.shape
        if batch_size is not None:
            B *= batch_size
        data_device, data_dtype = reference.device, reference.dtype
        value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
        weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
        border_width = sliding_window_size - sliding_window_stride
        try:
            for t in range(0, T, sliding_window_stride):
                # Skip windows after the previous one already reached the last frame.
                if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T:
                    continue
                t_ = min(t + sliding_window_size, T)
                model_kwargs.update({
                    name: tensor_dict[name][:, :, t:t_].to(device=computation_device, dtype=computation_dtype)
                    for name in tensor_names
                })
                model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype)
                mask = self.build_mask(
                    model_output,
                    is_bound=(t == 0, t_ == T),
                    border_width=(border_width,)
                ).to(device=data_device, dtype=data_dtype)
                value[:, :, t:t_] += model_output * mask
                weight[:, :, t:t_] += mask
        finally:
            # Always hand the caller back its full-length tensors, even if model_fn raised.
            model_kwargs.update(tensor_dict)
        value /= weight
        return value
def model_fn_wan_video(
    dit: WanModel,
    motion_controller: WanMotionControllerModel = None,
    vace: VaceWanModel = None,
    latents: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,
    clip_feature: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    reference_latents = None,
    vace_context = None,
    vace_scale = 1.0,
    tea_cache: TeaCache = None,
    use_unified_sequence_parallel: bool = False,
    motion_bucket_id: Optional[torch.Tensor] = None,
    sliding_window_size: Optional[int] = None,
    sliding_window_stride: Optional[int] = None,
    cfg_merge: bool = False,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    control_camera_latents_input = None,
    fuse_vae_embedding_in_latents: bool = False,
    **kwargs,
):
    """Run one Wan DiT forward pass and return the prediction, unpatchified to latent layout.

    Outline:
      1. If both sliding-window parameters are set, delegate to ``TemporalTiler_BCTHW``, which
         slices ``latents`` and ``y`` along time and calls this function per window.
      2. Embed the timestep (per-token when ``dit.seperated_timestep`` and
         ``fuse_vae_embedding_in_latents``), optionally add motion-controller modulation,
         and embed the text ``context``.
      3. Concatenate ``y`` to ``latents`` on channels (if ``dit.require_vae_embedding``),
         prepend CLIP image tokens to ``context`` (if ``dit.require_clip_embedding``), then
         patchify (with optional camera-control input) and prepend reference-image tokens.
      4. Run the DiT blocks (optionally with gradient checkpointing, VACE hints, sequence
         parallelism, or a TeaCache skip), then ``dit.head`` and ``dit.unpatchify``.

    NOTE(review): the sliding-window branch forwards only the kwargs listed in
    ``model_kwargs``. ``control_camera_latents_input``, ``fuse_vae_embedding_in_latents``,
    the gradient-checkpointing flags and ``**kwargs`` are dropped in that path.
    """
    if sliding_window_size is not None and sliding_window_stride is not None: # author note: not used in training
        model_kwargs = dict(
            dit=dit,
            motion_controller=motion_controller,
            vace=vace,
            latents=latents,
            timestep=timestep,
            context=context,
            clip_feature=clip_feature,
            y=y,
            reference_latents=reference_latents,
            vace_context=vace_context,
            vace_scale=vace_scale,
            tea_cache=tea_cache,
            use_unified_sequence_parallel=use_unified_sequence_parallel,
            motion_bucket_id=motion_bucket_id,
        )
        return TemporalTiler_BCTHW().run(
            model_fn_wan_video,
            sliding_window_size, sliding_window_stride,
            latents.device, latents.dtype,
            model_kwargs=model_kwargs,
            tensor_names=["latents", "y"],
            batch_size=2 if cfg_merge else 1
        )

    if use_unified_sequence_parallel: # author note: not used in this project
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                            get_sequence_parallel_world_size,
                                            get_sp_group)

    # Timestep
    if dit.seperated_timestep and fuse_vae_embedding_in_latents:
        # Per-token timesteps: 0 for the tokens of the first latent frame (the fused image),
        # ``timestep`` for all remaining tokens (H*W/4 tokens per latent frame).
        timestep = torch.concat([
            torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device),
            torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
        ]).flatten()
        t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
        if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
            t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1)
            t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks]
            t = t_chunks[get_sequence_parallel_rank()]
        t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
    else: # author note: the branch used in this project
        t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) # author-observed: torch.Size([1, 1536])
        t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) # shape (B, 6, dit.dim); author-observed dit.dim = 1536

    if motion_bucket_id is not None and motion_controller is not None: # author note: not used
        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
    # Text prompt embedding (the prompt is e.g. "depth"); author-observed
    # torch.Size([1, 512, 4096]) -> torch.Size([1, 512, 1536]).
    context = dit.text_embedding(context)
    # TODO: double-check x. It is either pure Gaussian noise or the noised ground truth;
    # presumably (B, 16, (F-1)//4 + 1, H/8, W/8) latents — confirm.
    x = latents
    # Merged CFG: the batch dimension of x and timestep must match that of context.
    if x.shape[0] != context.shape[0]:
        x = torch.concat([x] * context.shape[0], dim=0)
    if timestep.shape[0] != context.shape[0]:
        timestep = torch.concat([timestep] * context.shape[0], dim=0)

    # Image Embedding
    """
    new parameters:
    #* require_vae_embedding
    #* require_clip_embedding
    """
    # x: the noised target video (e.g. depth/normal latents) or pure Gaussian noise.
    # y: the conditioning built from the input RGB video (author-observed [1, 32, T, H/8, W/8]).
    # TODO: double-check y.
    if y is not None and dit.require_vae_embedding:
        x = torch.cat([x, y], dim=1) # channel concat: (b, c_x + c_y, f, h, w); author-observed 48 channels
    if clip_feature is not None and dit.require_clip_embedding:
        # Project CLIP features (author-observed [1, 257, 1280] -> [1, 257, 1536]).
        clip_embdding = dit.img_emb(clip_feature)
        # Prepend image tokens to text tokens (author-observed 257 + 512 = 769 tokens).
        context = torch.cat([clip_embdding, context], dim=1)

    # Patchify, with optional camera-control input. Author-observed: (B, 48, T, H/8, W/8)
    # -> tokens (B, T * H/16 * W/16, 1536), with (f, h, w) the patch-grid sizes.
    x, (f, h, w) = dit.patchify(x, control_camera_latents_input)

    # Reference image: prepend its tokens as one extra frame (f += 1), removed again after the head.
    if reference_latents is not None: # author note: not used
        if len(reference_latents.shape) == 5:
            reference_latents = reference_latents[:, :, 0]
        reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
        x = torch.concat([reference_latents, x], dim=1)
        f += 1

    # 3-D RoPE frequencies: per-axis tables for (f, h, w) broadcast and concatenated,
    # giving shape (f * h * w, 1, -1).
    freqs = torch.cat([
        dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache: True means the blocks are skipped and the cached residual is added instead.
    if tea_cache is not None: # author note: not used
        tea_cache_update = tea_cache.check(dit, x, t_mod)
    else:
        tea_cache_update = False

    if vace_context is not None: # author note: not used
        vace_hints = vace(x, vace_context, context, t_mod, freqs)

    # blocks
    if use_unified_sequence_parallel: # author note: not used
        if dist.is_initialized() and dist.get_world_size() > 1:
            # Split tokens across ranks; pad every chunk to the size of the first.
            chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
            pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
            chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
            x = chunks[get_sequence_parallel_rank()]
    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        # Pass through every DiT block (author-observed: 30 blocks).
        for block_id, block in enumerate(dit.blocks):
            if use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
            else:
                # Author-observed: x (1, tokens, 1536), context (1, 769, 1536),
                # t_mod (1, 6, 1536), freqs (tokens, 1, 64); output shape equals x.
                x = block(x, context, t_mod, freqs)
            if vace_context is not None and block_id in vace.vace_layers_mapping: # author note: not used
                current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
                if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
                    current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
                    current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
                x = x + current_vace_hint * vace_scale
        if tea_cache is not None: # author note: not used
            tea_cache.store(x)

    # Output head. Author-observed: x (1, tokens, 1536) with t (1, 1536) -> (1, tokens, 64).
    x = dit.head(x, t)
    if use_unified_sequence_parallel: # author note: not used
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
            x = x[:, :-pad_shape] if pad_shape > 0 else x
    # Remove reference latents
    if reference_latents is not None: # author note: not used
        x = x[:, reference_latents.shape[1]:]
        f -= 1
    # Unpatchify tokens back to the latent layout (author-observed [1, 16, T, H/8, W/8]).
    x = dit.unpatchify(x, (f, h, w))
    return x