| import torch, warnings, glob, os, types | |
| import numpy as np | |
| from PIL import Image | |
| from einops import repeat, reduce, rearrange | |
| from typing import Optional, Union | |
| from dataclasses import dataclass | |
| from modelscope import snapshot_download as ms_snap_download | |
| from huggingface_hub import snapshot_download as hf_snap_download | |
| from tqdm import tqdm | |
| from typing_extensions import Literal | |
| from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner | |
| from ..models import ModelManager, load_state_dict | |
| from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d | |
| from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm | |
| from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample | |
| from ..models.wan_video_image_encoder import WanImageEncoder | |
| from ..models.wan_video_vace import VaceWanModel | |
| from ..models.wan_video_motion_controller import WanMotionControllerModel | |
| from ..schedulers.flow_match import FlowMatchScheduler | |
| from ..prompters import WanPrompter | |
| from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm | |
| from ..lora import GeneralLoRALoader | |
| from loguru import logger | |
| import spaces | |
| class BasePipeline(torch.nn.Module): | |
| def __init__( | |
| self, | |
| device="cuda", torch_dtype=torch.float16, | |
| height_division_factor=64, width_division_factor=64, | |
| time_division_factor=None, time_division_remainder=None, | |
| ): | |
| super().__init__() | |
| # The device and torch_dtype are used for storing intermediate variables, not the models themselves. | |
| self.device = device | |
| self.torch_dtype = torch_dtype | |
| # The following parameters are used for shape checks. | |
| self.height_division_factor = height_division_factor | |
| self.width_division_factor = width_division_factor | |
| self.time_division_factor = time_division_factor | |
| self.time_division_remainder = time_division_remainder | |
| self.vram_management_enabled = False | |
| def to(self, *args, **kwargs): | |
| device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) | |
| if device is not None: | |
| self.device = device | |
| if dtype is not None: | |
| self.torch_dtype = dtype | |
| super().to(*args, **kwargs) | |
| return self | |
| def check_resize_height_width(self, height, width, num_frames=None): | |
| # Shape check | |
| if height % self.height_division_factor != 0: | |
| height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor | |
| print(f"height % {self.height_division_factor} != 0. We round it up to {height}.") | |
| if width % self.width_division_factor != 0: | |
| width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor | |
| print(f"width % {self.width_division_factor} != 0. We round it up to {width}.") | |
| if num_frames is None: | |
| return height, width | |
| else: | |
| if num_frames % self.time_division_factor != self.time_division_remainder: | |
| num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder | |
| print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.") | |
| return height, width, num_frames | |
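| # Worked example: with height_division_factor=16 (as configured by WanVideoPipeline below), | |
| # height=479 is rounded up to 480; with time_division_factor=4 and time_division_remainder=1, | |
| # num_frames=80 is rounded up to 81 (= 4 * 20 + 1). | |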
| def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1): | |
| # Transform a PIL.Image to torch.Tensor | |
| image = torch.Tensor(np.array(image, dtype=np.float32)) | |
| image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) | |
| image = image * ((max_value - min_value) / 255) + min_value | |
| image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {})) | |
| return image | |
| def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1): | |
| # Transform a list of PIL.Image to torch.Tensor | |
| if hasattr(video, 'length') and video.length is not None: | |
| video = [self.preprocess_image(video[idx], torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for idx in range(video.length)] | |
| else: | |
| video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video] | |
| video = torch.stack(video, dim=pattern.index("T") // 2) | |
| return video | |
| def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1): | |
| # Transform a torch.Tensor to PIL.Image | |
| if pattern != "H W C": | |
| vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean") | |
| image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255) | |
| image = image.to(device="cpu", dtype=torch.uint8) | |
| image = Image.fromarray(image.numpy()) | |
| return image | |
| def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1): | |
| # Transform a torch.Tensor to list of PIL.Image | |
| if pattern != "T H W C": | |
| vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean") | |
| video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output] | |
| return video | |
| def load_models_to_device(self, model_names=[]): | |
| if self.vram_management_enabled: | |
| # offload models | |
| for name, model in self.named_children(): | |
| if name not in model_names: | |
| if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: | |
| for module in model.modules(): | |
| if hasattr(module, "offload"): | |
| module.offload() | |
| else: | |
| model.cpu() | |
| torch.cuda.empty_cache() | |
| # onload models | |
| for name, model in self.named_children(): | |
| if name in model_names: | |
| if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: | |
| for module in model.modules(): | |
| if hasattr(module, "onload"): | |
| module.onload() | |
| else: | |
| model.to(self.device) | |
| def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None): | |
| # Initialize Gaussian noise | |
| generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed) | |
| noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype) | |
| noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) | |
| return noise | |
| def enable_cpu_offload(self): | |
| warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.") | |
| self.vram_management_enabled = True | |
| def get_vram(self): | |
| return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3) | |
| def freeze_except(self, model_names): | |
| for name, model in self.named_children(): | |
| if name in model_names: | |
| model.train() | |
| model.requires_grad_(True) | |
| else: | |
| model.eval() | |
| model.requires_grad_(False) | |
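| # Illustrative sketch (hypothetical helper, not used by the pipeline): preprocess_image maps 8-bit | |
| # pixels [0, 255] into [min_value, max_value] (here [-1, 1]) and vae_output_to_image inverts that | |
| # mapping, so a round trip reproduces the input image. | |
| def _example_image_round_trip(pipe: BasePipeline) -> Image.Image: | |
|     dummy = Image.new("RGB", (64, 64), color=(128, 64, 255))  # dummy test image | |
|     tensor = pipe.preprocess_image(dummy)                     # shape "B C H W", values in [-1, 1] | |
|     return pipe.vae_output_to_image(tensor)                   # back to a PIL.Image | |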
| @dataclass | |
| class ModelConfig: | |
| path: Union[str, list[str]] = None | |
| model_id: str = None | |
| origin_file_pattern: Union[str, list[str]] = None | |
| download_resource: str = "ModelScope" | |
| offload_device: Optional[Union[str, torch.device]] = None | |
| offload_dtype: Optional[torch.dtype] = None | |
| def download_if_necessary(self, local_model_path="./checkpoints", skip_download=False, use_usp=False): | |
| if self.path is None: | |
| # Check model_id and origin_file_pattern | |
| if self.model_id is None: | |
| raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""") | |
| # Skip if not in rank 0 | |
| if use_usp: | |
| import torch.distributed as dist | |
| skip_download = dist.get_rank() != 0 | |
| # Check whether the origin path is a folder | |
| if self.origin_file_pattern is None or self.origin_file_pattern == "": | |
| self.origin_file_pattern = "" | |
| allow_file_pattern = None | |
| is_folder = True | |
| elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"): | |
| allow_file_pattern = self.origin_file_pattern + "*" | |
| is_folder = True | |
| else: | |
| allow_file_pattern = self.origin_file_pattern | |
| is_folder = False | |
| # Download | |
| if not skip_download: | |
| # Changed from `glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))` to an absolute-path glob. | |
| downloaded_files = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern)) | |
| if len(downloaded_files) == 0 or not os.path.exists(downloaded_files[0]): | |
| if 'Wan2' in self.model_id: | |
| ms_snap_download( | |
| self.model_id, | |
| local_dir=os.path.join(local_model_path, self.model_id), | |
| allow_file_pattern=allow_file_pattern, | |
| ignore_file_pattern=downloaded_files, | |
| ) | |
| else: | |
| hf_snap_download( | |
| repo_id=self.model_id, | |
| local_dir=os.path.join(local_model_path, self.model_id), | |
| allow_patterns=allow_file_pattern, | |
| ignore_patterns=downloaded_files if downloaded_files else None | |
| ) | |
| # Let rank 1, 2, ... wait for rank 0 | |
| if use_usp: | |
| import torch.distributed as dist | |
| dist.barrier(device_ids=[dist.get_rank()]) | |
| # Return downloaded files | |
| if is_folder: | |
| self.path = os.path.join(local_model_path, self.model_id, self.origin_file_pattern) | |
| else: | |
| self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern)) | |
| if isinstance(self.path, list) and len(self.path) == 1: | |
| self.path = self.path[0] | |
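| # Illustrative usage sketch (hypothetical helper, not called anywhere in this file): how a | |
| # ModelConfig resolves a repo file into a local path. The repo id and file name below are | |
| # taken from the configs used further down and serve only as an example. | |
| def _example_model_config_usage(): | |
|     cfg = ModelConfig( | |
|         model_id="PAI/Wan2.1-Fun-1.3B-Control",   # example repo id (also used by DKTPipeline below) | |
|         origin_file_pattern="Wan2.1_VAE.pth",     # example file inside that repo | |
|     ) | |
|     # Downloads into ./checkpoints/<model_id>/ only if the file is missing, then stores the | |
|     # resolved local path (a single file here, so a plain string) in cfg.path. | |
|     cfg.download_if_necessary(local_model_path="./checkpoints") | |
|     return cfg.path | |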
| class WanVideoPipeline(BasePipeline): | |
| def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None): | |
| super().__init__( | |
| device=device, torch_dtype=torch_dtype, | |
| height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 | |
| ) | |
| self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) | |
| self.prompter = WanPrompter(tokenizer_path=tokenizer_path) | |
| self.text_encoder: WanTextEncoder = None | |
| self.image_encoder: WanImageEncoder = None | |
| self.dit: WanModel = None | |
| self.dit2: WanModel = None | |
| self.vae: WanVideoVAE = None | |
| self.motion_controller: WanMotionControllerModel = None | |
| self.vace: VaceWanModel = None | |
| self.in_iteration_models = ("dit", "motion_controller", "vace") | |
| self.in_iteration_models_2 = ("dit2", "motion_controller", "vace") | |
| self.unit_runner = PipelineUnitRunner() | |
| self.units = [ | |
| WanVideoUnit_ShapeChecker(), | |
| WanVideoUnit_NoiseInitializer(), | |
| WanVideoUnit_InputVideoEmbedder(), | |
| WanVideoUnit_PromptEmbedder(), | |
| # WanVideoUnit_ImageEmbedderVAE(), | |
| # WanVideoUnit_ImageEmbedderCLIP(), | |
| # WanVideoUnit_ImageEmbedderFused(), | |
| # WanVideoUnit_FunControl(), | |
| WanVideoUnit_FunControl_Mask(), | |
| # WanVideoUnit_FunReference(), | |
| # WanVideoUnit_FunCameraControl(), | |
| # WanVideoUnit_SpeedControl(), | |
| # WanVideoUnit_VACE(), | |
| # WanVideoUnit_UnifiedSequenceParallel(), | |
| # WanVideoUnit_TeaCache(), | |
| # WanVideoUnit_CfgMerger(), | |
| ] | |
| self.model_fn = model_fn_wan_video | |
| def load_lora(self, module, path, alpha=1): | |
| loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device) | |
| lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device) | |
| loader.load(module, lora, alpha=alpha) | |
| def training_loss(self, **inputs): | |
| max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps) | |
| min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps) | |
| timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) | |
| timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) | |
| #* In single-step denoising, what gets returned each time is pure noise. | |
| #? Presumably this refers to input_latents? | |
| #* inputs["latents"] already exists, but it is exactly equal to inputs["noise"]; here it is recomputed and overwritten. | |
| inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep) | |
| training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep) | |
| noise_pred = self.model_fn(**inputs, timestep=timestep)#* timestep === 1 | |
| loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) | |
| loss = loss * self.scheduler.training_weight(timestep) | |
| return loss | |
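| # Reference sketch (an assumption about FlowMatchScheduler's conventions, kept here only to | |
| # document the loss above): add_noise and training_target are assumed to follow standard | |
| # rectified-flow definitions, so the DiT learns the constant velocity from data to noise. | |
| def _flow_match_reference(x_0, noise, sigma_t): | |
|     x_t = (1 - sigma_t) * x_0 + sigma_t * noise   # what scheduler.add_noise is assumed to return | |
|     target = noise - x_0                          # what scheduler.training_target is assumed to return | |
|     return x_t, target | |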
| def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): | |
| self.vram_management_enabled = True | |
| if num_persistent_param_in_dit is not None: | |
| vram_limit = None | |
| else: | |
| if vram_limit is None: | |
| vram_limit = self.get_vram() | |
| vram_limit = vram_limit - vram_buffer | |
| if self.text_encoder is not None: | |
| dtype = next(iter(self.text_encoder.parameters())).dtype | |
| enable_vram_management( | |
| self.text_encoder, | |
| module_map = { | |
| torch.nn.Linear: AutoWrappedLinear, | |
| torch.nn.Embedding: AutoWrappedModule, | |
| T5RelativeEmbedding: AutoWrappedModule, | |
| T5LayerNorm: AutoWrappedModule, | |
| }, | |
| module_config = dict( | |
| offload_dtype=dtype, | |
| offload_device="cpu", | |
| onload_dtype=dtype, | |
| onload_device="cpu", | |
| computation_dtype=self.torch_dtype, | |
| computation_device=self.device, | |
| ), | |
| vram_limit=vram_limit, | |
| ) | |
| if self.dit is not None: | |
| dtype = next(iter(self.dit.parameters())).dtype | |
| device = "cpu" if vram_limit is not None else self.device | |
| enable_vram_management( | |
| self.dit, | |
| module_map = { | |
| torch.nn.Linear: AutoWrappedLinear, | |
| torch.nn.Conv3d: AutoWrappedModule, | |
| torch.nn.LayerNorm: WanAutoCastLayerNorm, | |
| RMSNorm: AutoWrappedModule, | |
| torch.nn.Conv2d: AutoWrappedModule, | |
| }, | |
| module_config = dict( | |
| offload_dtype=dtype, | |
| offload_device="cpu", | |
| onload_dtype=dtype, | |
| onload_device=device, | |
| computation_dtype=self.torch_dtype, | |
| computation_device=self.device, | |
| ), | |
| max_num_param=num_persistent_param_in_dit, | |
| overflow_module_config = dict( | |
| offload_dtype=dtype, | |
| offload_device="cpu", | |
| onload_dtype=dtype, | |
| onload_device="cpu", | |
| computation_dtype=self.torch_dtype, | |
| computation_device=self.device, | |
| ), | |
| vram_limit=vram_limit, | |
| ) | |
| if self.dit2 is not None: | |
| dtype = next(iter(self.dit2.parameters())).dtype | |
| device = "cpu" if vram_limit is not None else self.device | |
| enable_vram_management( | |
| self.dit2, | |
| module_map = { | |
| torch.nn.Linear: AutoWrappedLinear, | |
| torch.nn.Conv3d: AutoWrappedModule, | |
| torch.nn.LayerNorm: WanAutoCastLayerNorm, | |
| RMSNorm: AutoWrappedModule, | |
| torch.nn.Conv2d: AutoWrappedModule, | |
| }, | |
| module_config = dict( | |
| offload_dtype=dtype, | |
| offload_device="cpu", | |
| onload_dtype=dtype, | |
| onload_device=device, | |
| computation_dtype=self.torch_dtype, | |
| computation_device=self.device, | |
| ), | |
| max_num_param=num_persistent_param_in_dit, | |
| overflow_module_config = dict( | |
| offload_dtype=dtype, | |
| offload_device="cpu", | |
| onload_dtype=dtype, | |
| onload_device="cpu", | |
| computation_dtype=self.torch_dtype, | |
| computation_device=self.device, | |
| ), | |
| vram_limit=vram_limit, | |
| ) | |
| if self.vae is not None: | |
| dtype = next(iter(self.vae.parameters())).dtype | |
| enable_vram_management( | |
| self.vae, | |
| module_map = { | |
| torch.nn.Linear: AutoWrappedLinear, | |
| torch.nn.Conv2d: AutoWrappedModule, | |
| RMS_norm: AutoWrappedModule, | |
| CausalConv3d: AutoWrappedModule, | |
| Upsample: AutoWrappedModule, | |
| torch.nn.SiLU: AutoWrappedModule, | |
| torch.nn.Dropout: AutoWrappedModule, | |
| }, | |
| module_config = dict( | |
| offload_dtype=dtype, | |
| offload_device="cpu", | |
| onload_dtype=dtype, | |
| onload_device=self.device, | |
| computation_dtype=self.torch_dtype, | |
| computation_device=self.device, | |
| ), | |
| ) | |
| if self.image_encoder is not None: | |
| dtype = next(iter(self.image_encoder.parameters())).dtype | |
| enable_vram_management( | |
| self.image_encoder, | |
| module_map = { | |
| torch.nn.Linear: AutoWrappedLinear, | |
| torch.nn.Conv2d: AutoWrappedModule, | |
| torch.nn.LayerNorm: AutoWrappedModule, | |
| }, | |
| module_config = dict( | |
| offload_dtype=dtype, | |
| offload_device="cpu", | |
| onload_dtype=dtype, | |
| onload_device="cpu", | |
| computation_dtype=dtype, | |
| computation_device=self.device, | |
| ), | |
| ) | |
| if self.motion_controller is not None: | |
| dtype = next(iter(self.motion_controller.parameters())).dtype | |
| enable_vram_management( | |
| self.motion_controller, | |
| module_map = { | |
| torch.nn.Linear: AutoWrappedLinear, | |
| }, | |
| module_config = dict( | |
| offload_dtype=dtype, | |
| offload_device="cpu", | |
| onload_dtype=dtype, | |
| onload_device="cpu", | |
| computation_dtype=dtype, | |
| computation_device=self.device, | |
| ), | |
| ) | |
| if self.vace is not None: | |
| device = "cpu" if vram_limit is not None else self.device | |
| enable_vram_management( | |
| self.vace, | |
| module_map = { | |
| torch.nn.Linear: AutoWrappedLinear, | |
| torch.nn.Conv3d: AutoWrappedModule, | |
| torch.nn.LayerNorm: AutoWrappedModule, | |
| RMSNorm: AutoWrappedModule, | |
| }, | |
| module_config = dict( | |
| offload_dtype=dtype, | |
| offload_device="cpu", | |
| onload_dtype=dtype, | |
| onload_device=device, | |
| computation_dtype=self.torch_dtype, | |
| computation_device=self.device, | |
| ), | |
| vram_limit=vram_limit, | |
| ) | |
| def initialize_usp(self): | |
| import torch.distributed as dist | |
| from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment | |
| dist.init_process_group(backend="nccl", init_method="env://") | |
| init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) | |
| initialize_model_parallel( | |
| sequence_parallel_degree=dist.get_world_size(), | |
| ring_degree=1, | |
| ulysses_degree=dist.get_world_size(), | |
| ) | |
| torch.cuda.set_device(dist.get_rank()) | |
| def enable_usp(self): | |
| from xfuser.core.distributed import get_sequence_parallel_world_size | |
| from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward | |
| for block in self.dit.blocks: | |
| block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) | |
| self.dit.forward = types.MethodType(usp_dit_forward, self.dit) | |
| if self.dit2 is not None: | |
| for block in self.dit2.blocks: | |
| block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) | |
| self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) | |
| self.sp_size = get_sequence_parallel_world_size() | |
| self.use_unified_sequence_parallel = True | |
| @staticmethod | |
| def from_pretrained( | |
| torch_dtype: torch.dtype = torch.bfloat16, | |
| device: Union[str, torch.device] = "cuda", | |
| model_configs: list[ModelConfig] = [], | |
| tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"), | |
| local_model_path: str = "./checkpoints", | |
| skip_download: bool = False, | |
| redirect_common_files: bool = True, | |
| use_usp=False, | |
| training_strategy='origin', | |
| ): | |
| # Redirect model path | |
| if redirect_common_files: | |
| redirect_dict = { | |
| "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B", | |
| "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B", | |
| "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P", | |
| } | |
| for model_config in model_configs: | |
| if model_config.origin_file_pattern is None or model_config.model_id is None: | |
| continue | |
| if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]: | |
| print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.") | |
| model_config.model_id = redirect_dict[model_config.origin_file_pattern] | |
| # Initialize pipeline | |
| if training_strategy == 'origin': | |
| pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) | |
| logger.warning("Using origin generative model training") | |
| else: | |
| raise ValueError(f"Invalid training strategy: {training_strategy}") | |
| if use_usp: pipe.initialize_usp() | |
| # Download and load models | |
| model_manager = ModelManager() | |
| for model_config in model_configs: | |
| model_config.download_if_necessary(use_usp=use_usp) | |
| model_manager.load_model( | |
| model_config.path, | |
| device=model_config.offload_device or device, | |
| torch_dtype=model_config.offload_dtype or torch_dtype | |
| ) | |
| # Load models | |
| pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") | |
| dit = model_manager.fetch_model("wan_video_dit", index=2) | |
| if isinstance(dit, list): | |
| pipe.dit, pipe.dit2 = dit | |
| else: | |
| pipe.dit = dit | |
| pipe.vae = model_manager.fetch_model("wan_video_vae") | |
| pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") | |
| pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") | |
| pipe.vace = model_manager.fetch_model("wan_video_vace") | |
| # Size division factor | |
| if pipe.vae is not None: | |
| pipe.height_division_factor = pipe.vae.upsampling_factor * 2 | |
| pipe.width_division_factor = pipe.vae.upsampling_factor * 2 | |
| # Initialize tokenizer | |
| tokenizer_config.download_if_necessary(use_usp=use_usp) | |
| pipe.prompter.fetch_models(pipe.text_encoder) | |
| pipe.prompter.fetch_tokenizer(tokenizer_config.path) | |
| # Unified Sequence Parallel | |
| if use_usp: pipe.enable_usp() | |
| return pipe | |
| def __call__( | |
| self, | |
| # Prompt | |
| prompt: str, | |
| negative_prompt: Optional[str] = "", | |
| # Image-to-video | |
| input_image: Optional[Image.Image] = None, | |
| # First-last-frame-to-video | |
| end_image: Optional[Image.Image] = None, | |
| # Video-to-video | |
| input_video: Optional[list[Image.Image]] = None, | |
| denoising_strength: Optional[float] = 1.0, | |
| # ControlNet | |
| control_video: Optional[list[Image.Image]] = None, | |
| reference_image: Optional[Image.Image] = None, | |
| # Camera control | |
| camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None, | |
| camera_control_speed: Optional[float] = 1/54, | |
| camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0), | |
| # VACE | |
| vace_video: Optional[list[Image.Image]] = None, | |
| vace_video_mask: Optional[Image.Image] = None, | |
| vace_reference_image: Optional[Image.Image] = None, | |
| vace_scale: Optional[float] = 1.0, | |
| # Randomness | |
| seed: Optional[int] = None, | |
| rand_device: Optional[str] = "cpu", | |
| # Shape | |
| height: Optional[int] = 480, | |
| width: Optional[int] = 832, | |
| num_frames=81, | |
| # Classifier-free guidance | |
| cfg_scale: Optional[float] = 5.0, | |
| cfg_merge: Optional[bool] = False, | |
| # Boundary | |
| switch_DiT_boundary: Optional[float] = 0.875, | |
| # Scheduler | |
| num_inference_steps: Optional[int] = 50, | |
| sigma_shift: Optional[float] = 5.0, | |
| # Speed control | |
| motion_bucket_id: Optional[int] = None, | |
| # VAE tiling | |
| tiled: Optional[bool] = True, | |
| tile_size: Optional[tuple[int, int]] = (30, 52), | |
| tile_stride: Optional[tuple[int, int]] = (15, 26), | |
| # Sliding window | |
| sliding_window_size: Optional[int] = None, | |
| sliding_window_stride: Optional[int] = None, | |
| # Teacache | |
| tea_cache_l1_thresh: Optional[float] = None, | |
| tea_cache_model_id: Optional[str] = "", | |
| # progress_bar | |
| progress_bar_cmd=tqdm, | |
| mask: Optional[Image.Image] = None, | |
| ): | |
| # Scheduler | |
| self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) | |
| # Inputs | |
| inputs_posi = { | |
| "prompt": prompt, | |
| "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, | |
| } | |
| inputs_nega = { | |
| "negative_prompt": negative_prompt, | |
| "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, | |
| } | |
| inputs_shared = { | |
| "input_image": input_image, | |
| "end_image": end_image, | |
| "input_video": input_video, "denoising_strength": denoising_strength, | |
| "control_video": control_video, "reference_image": reference_image, | |
| "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin, | |
| "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale, | |
| "seed": seed, "rand_device": rand_device, | |
| "height": height, "width": width, "num_frames": num_frames, | |
| "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, | |
| "sigma_shift": sigma_shift, | |
| "motion_bucket_id": motion_bucket_id, | |
| "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, | |
| "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, | |
| "mask":mask, | |
| } | |
| for unit in self.units: | |
| inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) | |
| # Denoise | |
| self.load_models_to_device(self.in_iteration_models) | |
| models = {name: getattr(self, name) for name in self.in_iteration_models} | |
| for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): | |
| # Switch DiT if necessary | |
| if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: | |
| self.load_models_to_device(self.in_iteration_models_2) | |
| models["dit"] = self.dit2 | |
| # Timestep | |
| timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) | |
| # Inference | |
| noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) | |
| if cfg_scale != 1.0: | |
| if cfg_merge: | |
| noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) | |
| else: | |
| noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) | |
| noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) | |
| else: | |
| noise_pred = noise_pred_posi | |
| # Scheduler | |
| inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) | |
| if "first_frame_latents" in inputs_shared: | |
| inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] | |
| # VACE (TODO: remove it) | |
| if vace_reference_image is not None: | |
| inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:] | |
| # Decode | |
| self.load_models_to_device(['vae']) | |
| vae_outs = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) | |
| video = self.vae_output_to_video(vae_outs) | |
| self.load_models_to_device([]) | |
| return video, vae_outs | |
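| # Minimal inference sketch (illustrative; mirrors how DKTPipeline calls the pipeline below). | |
| # Assumes `pipe` was built via WanVideoPipeline.from_pretrained and `frames` is a list of | |
| # PIL.Image control frames whose length is of the form 4*k + 1. | |
| def _example_wan_inference(pipe: "WanVideoPipeline", frames): | |
|     video, latents = pipe( | |
|         prompt="depth", | |
|         negative_prompt="", | |
|         control_video=frames, | |
|         height=480, width=832, | |
|         num_frames=len(frames), | |
|         num_inference_steps=5, | |
|         cfg_scale=1.0, | |
|         seed=1, | |
|         tiled=False, | |
|     ) | |
|     return video  # list of PIL.Image frames | |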
| def extract_frames_from_video_file(video_path): | |
| try: | |
| cap = cv2.VideoCapture(video_path) | |
| frames = [] | |
| fps = cap.get(cv2.CAP_PROP_FPS) | |
| if fps <= 0: | |
| fps = 15.0 | |
| while True: | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| frame_rgb = Image.fromarray(frame_rgb) | |
| frames.append(frame_rgb) | |
| cap.release() | |
| return frames, fps | |
| except Exception as e: | |
| logger.error(f"Error extracting frames from {video_path}: {str(e)}") | |
| return [], 15.0 | |
| def resize_frame(frame, height, width): | |
| frame = np.array(frame) | |
| frame = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0 | |
| frame = torch.nn.functional.interpolate(frame, (height, width), mode="bicubic", align_corners=False, antialias=True) | |
| frame = (frame.squeeze(0).permute(1, 2, 0).clamp(0, 1) * 255).byte().numpy() | |
| frame = Image.fromarray(frame) | |
| return frame | |
| from moge.model.v2 import MoGeModel | |
| from tools.eval_utils import transfer_pred_disp2depth, transfer_pred_disp2depth_v2, colorize_depth_map | |
| from tools.depth2pcd import depth2pcd | |
| import cv2, copy | |
| class DKTPipeline: | |
| def __init__(self, ): | |
| self.main_pipe = self.init_model() | |
| self.moge_pipe = self.load_moge_model() | |
| def init_model(self ): | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| pipe = WanVideoPipeline.from_pretrained( | |
| torch_dtype=torch.bfloat16, | |
| device=device, | |
| model_configs=[ | |
| ModelConfig( | |
| model_id="PAI/Wan2.1-Fun-1.3B-Control", | |
| origin_file_pattern="diffusion_pytorch_model*.safetensors", | |
| offload_device="cpu", | |
| ), | |
| ModelConfig( | |
| model_id="PAI/Wan2.1-Fun-1.3B-Control", | |
| origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", | |
| offload_device="cpu", | |
| ), | |
| ModelConfig( | |
| model_id="PAI/Wan2.1-Fun-1.3B-Control", | |
| origin_file_pattern="Wan2.1_VAE.pth", | |
| offload_device="cpu", | |
| ), | |
| ModelConfig( | |
| model_id="PAI/Wan2.1-Fun-1.3B-Control", | |
| origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", | |
| offload_device="cpu", | |
| ), | |
| ], | |
| training_strategy="origin", | |
| ) | |
| lora_config = ModelConfig( | |
| model_id="Daniellesry/DKT-Depth-1-3B", | |
| origin_file_pattern="dkt-1-3B.safetensors", | |
| offload_device="cpu", | |
| ) | |
| lora_config.download_if_necessary(use_usp=False) | |
| pipe.load_lora(pipe.dit, lora_config.path, alpha=1.0)  # TODO: verify that the LoRA weights are applied correctly | |
| pipe.enable_vram_management() | |
| return pipe | |
| def load_moge_model(self): | |
| device= torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| cached_model_path = 'checkpoints/moge_ckpt/moge-2-vitl-normal/model.pt' | |
| if os.path.exists(cached_model_path): | |
| logger.info(f"Found cached model at {cached_model_path}, loading from cache...") | |
| moge_pipe = MoGeModel.from_pretrained(cached_model_path).to(device) | |
| else: | |
| logger.info(f"Cache not found at {cached_model_path}, downloading from HuggingFace...") | |
| os.makedirs(os.path.dirname(cached_model_path), exist_ok=True) | |
| moge_pipe = MoGeModel.from_pretrained('Ruicheng/moge-2-vitl-normal', cache_dir=os.path.dirname(cached_model_path)).to(device) | |
| return moge_pipe | |
| def __call__(self, video_file, prompt='depth', | |
| negative_prompt='', height=480, width=832, | |
| num_inference_steps=5, window_size=21, | |
| overlap=3, vis_pc=False, return_rgb=False): | |
| origin_frames, input_fps = extract_frames_from_video_file(video_file) | |
| frame_length = len(origin_frames) | |
| original_width, original_height = origin_frames[0].size | |
| ROTATE = False | |
| if original_width < original_height:  #* ensure the width is the longer side | |
| ROTATE = True | |
| origin_frames = [x.transpose(Image.ROTATE_90) for x in origin_frames] | |
| original_width, original_height = original_height, original_width | |
| frames = [resize_frame(frame, height, width) for frame in origin_frames] | |
| if (frame_length - 1) % 4 != 0: | |
| new_len = ((frame_length - 1) // 4 + 1) * 4 + 1 | |
| frames = frames + [copy.deepcopy(frames[-1]) for _ in range(new_len - frame_length)] | |
| video, vae_outs = self.main_pipe( | |
| prompt=prompt, | |
| negative_prompt=negative_prompt, | |
| control_video=frames, | |
| height=height, | |
| width=width, | |
| num_frames=len(frames), | |
| seed=1, | |
| tiled=False, | |
| num_inference_steps=num_inference_steps, | |
| sliding_window_size=window_size, | |
| sliding_window_stride=window_size - overlap, | |
| cfg_scale=1.0, | |
| ) | |
| torch.cuda.empty_cache() | |
| processed_video = video[:frame_length] | |
| processed_video = [resize_frame(frame, original_height, original_width) for frame in processed_video] | |
| if ROTATE: | |
| processed_video = [x.transpose(Image.ROTATE_270) for x in processed_video] | |
| origin_frames = [x.transpose(Image.ROTATE_270) for x in origin_frames] | |
| color_predictions = [] | |
| predicted_depth_map_np = None | |
| if prompt == 'depth': | |
| predicted_depth_map_np = [np.array(item).astype(np.float32).mean(-1) for item in processed_video] | |
| predicted_depth_map_np = np.stack(predicted_depth_map_np) | |
| predicted_depth_map_np = predicted_depth_map_np / 255.0 | |
| __min = predicted_depth_map_np.min() | |
| __max = predicted_depth_map_np.max() | |
| predicted_depth_map_np_normalized = (predicted_depth_map_np - __min) / (__max - __min) | |
| color_predictions = [colorize_depth_map(item) for item in predicted_depth_map_np_normalized] | |
| else: | |
| color_predictions = processed_video | |
| return_dict = {} | |
| return_dict['depth_map'] = predicted_depth_map_np  # None when prompt != 'depth' | |
| return_dict['colored_depth_map'] = color_predictions | |
| if vis_pc and prompt == 'depth': | |
| vis_pc_num = 4 | |
| indices = np.linspace(0, frame_length-1, vis_pc_num) | |
| indices = np.round(indices).astype(np.int32) | |
| return_dict['point_clouds'] = self.prediction2pc(predicted_depth_map_np, origin_frames, indices) | |
| if return_rgb: | |
| return_dict['rgb_frames'] = origin_frames | |
| return return_dict | |
| def prediction2pc(self, prediction_depth_map, RGB_frames, indices, return_pcd=True, nb_neighbors=20, std_ratio=3.0): | |
| resize_W, resize_H = RGB_frames[0].size | |
| pcds = [] | |
| moge_device = self.moge_pipe.device if self.moge_pipe is not None else torch.device("cuda:0") | |
| for idx in tqdm(indices): | |
| origin_rgb_frame = RGB_frames[idx] | |
| predicted_depth = prediction_depth_map[idx] | |
| # Read the input image and convert it to a tensor (3, H, W) with RGB values normalized to [0, 1] | |
| input_image_np = np.array(origin_rgb_frame)  # Convert PIL Image to numpy array | |
| input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1) | |
| output = self.moge_pipe.infer(input_image) | |
| #* "dict_keys(['points', 'intrinsics', 'depth', 'mask', 'normal'])" | |
| moge_intrinsics = output['intrinsics'].cpu().numpy() | |
| moge_mask = output['mask'].cpu().numpy() | |
| moge_depth = output['depth'].cpu().numpy() | |
| metric_depth = transfer_pred_disp2depth(predicted_depth, moge_depth, moge_mask) | |
| moge_intrinsics[0, 0] *= resize_W | |
| moge_intrinsics[1, 1] *= resize_H | |
| moge_intrinsics[0, 2] *= resize_W | |
| moge_intrinsics[1, 2] *= resize_H | |
| pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=return_pcd) | |
| if return_pcd: | |
| #* [15,50], [2,3] | |
| cl, ind = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio) | |
| pcd = pcd.select_by_index(ind) | |
| #todo downsample | |
| pcds.append(pcd) | |
| return pcds | |
| def moge_infer(self, input_image): | |
| return self.moge_pipe.infer(input_image) | |
| def prediction2pc_v2(self, prediction_depth_map, RGB_frames, indices, return_pcd=True, nb_neighbors=20, std_ratio=3.0): | |
| """ | |
| Call MoGe only once (on the first selected frame), then reuse its scale/shift and intrinsics for the remaining frames. | |
| """ | |
| resize_W, resize_H = RGB_frames[0].size | |
| pcds = [] | |
| moge_device = self.moge_pipe.device if self.moge_pipe is not None else torch.device("cuda:0") | |
| for iidx, idx in enumerate(tqdm(indices)): | |
| origin_rgb_frame = RGB_frames[idx] | |
| predicted_depth = prediction_depth_map[idx] | |
| input_image_np = np.array(origin_rgb_frame)  # Convert PIL Image to numpy array | |
| if iidx == 0: | |
| # Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1] | |
| input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1) | |
| output = self.moge_infer(input_image) | |
| #* "dict_keys(['points', 'intrinsics', 'depth', 'mask', 'normal'])" | |
| moge_intrinsics = output['intrinsics'].cpu().numpy() | |
| moge_mask = output['mask'].cpu().numpy() | |
| moge_depth = output['depth'].cpu().numpy() | |
| metric_depth, scale, shift = transfer_pred_disp2depth(predicted_depth, moge_depth, moge_mask, return_scale_shift=True) | |
| moge_intrinsics[0, 0] *= resize_W | |
| moge_intrinsics[1, 1] *= resize_H | |
| moge_intrinsics[0, 2] *= resize_W | |
| moge_intrinsics[1, 2] *= resize_H | |
| else: | |
| metric_depth = transfer_pred_disp2depth_v2(predicted_depth, scale, shift) | |
| pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=return_pcd) | |
| if return_pcd: | |
| #* [15,50], [2,3] | |
| cl, ind = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio) | |
| pcd = pcd.select_by_index(ind) | |
| #todo downsample | |
| pcds.append(pcd) | |
| return pcds | |
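| # End-to-end usage sketch (illustrative; the video path is a placeholder and the dictionary keys | |
| # follow the return_dict built in __call__ above): | |
| def _example_dkt_usage(video_path="input.mp4"):  # hypothetical input file | |
|     dkt = DKTPipeline() | |
|     result = dkt(video_path, prompt='depth', num_inference_steps=5, vis_pc=False) | |
|     depth = result['depth_map']            # per-frame depth maps (roughly in [0, 1]); None for non-depth prompts | |
|     colored = result['colored_depth_map']  # colorized depth visualizations (or raw frames for non-depth prompts) | |
|     return depth, colored | |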
| class WanVideoUnit_ShapeChecker(PipelineUnit): | |
| def __init__(self): | |
| super().__init__(input_params=("height", "width", "num_frames")) | |
| def process(self, pipe: WanVideoPipeline, height, width, num_frames): | |
| height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) | |
| return {"height": height, "width": width, "num_frames": num_frames} | |
| class WanVideoUnit_NoiseInitializer(PipelineUnit): | |
| def __init__(self): | |
| super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image")) | |
| def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): | |
| length = (num_frames - 1) // 4 + 1 | |
| if vace_reference_image is not None: | |
| length += 1 | |
| shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) | |
| noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) | |
| if vace_reference_image is not None: | |
| noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2) | |
| return {"noise": noise} | |
| class WanVideoUnit_InputVideoEmbedder(PipelineUnit): | |
| def __init__(self): | |
| super().__init__( | |
| input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), | |
| onload_model_names=("vae",) | |
| ) | |
| def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image): | |
| if input_video is None: | |
| return {"latents": noise} | |
| pipe.load_models_to_device(["vae"])#* input_video is the GT | |
| input_video = pipe.preprocess_video(input_video) #* [B,3,F,W,H] | |
| #* [B,3,(F/4) + 1 ,W/8,H/8] | |
| input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) | |
| if vace_reference_image is not None: | |
| vace_reference_image = pipe.preprocess_video([vace_reference_image]) | |
| vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) | |
| input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) | |
| #? During training, input_latents is returned separately and is independent of the noise; | |
| #? during inference, input_latents is noised (via add_noise) to form the initial latents. | |
| if pipe.scheduler.training: | |
| return {"latents": noise, "input_latents": input_latents} | |
| else: | |
| latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) | |
| return {"latents": latents} | |
| class WanVideoUnit_PromptEmbedder(PipelineUnit): | |
| def __init__(self): | |
| super().__init__( | |
| seperate_cfg=True, | |
| input_params_posi={"prompt": "prompt", "positive": "positive"}, | |
| input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, | |
| onload_model_names=("text_encoder",) | |
| ) | |
| def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: | |
| pipe.load_models_to_device(self.onload_model_names) | |
| prompt_emb = pipe.prompter.encode_prompt(prompt, positive=positive, device=pipe.device) | |
| return {"context": prompt_emb} | |
| class WanVideoUnit_ImageEmbedder(PipelineUnit): | |
| """ | |
| Deprecated | |
| """ | |
| def __init__(self): | |
| super().__init__( | |
| input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), | |
| onload_model_names=("image_encoder", "vae") | |
| ) | |
| def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): | |
| if input_image is None or pipe.image_encoder is None: | |
| return {} | |
| pipe.load_models_to_device(self.onload_model_names) | |
| image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) | |
| clip_context = pipe.image_encoder.encode_image([image]) | |
| msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)  #* mask indicating which frames carry the conditioning image(s) | |
| msk[:, 1:] = 0 | |
| if end_image is not None: | |
| end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) | |
| vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) | |
| if pipe.dit.has_image_pos_emb: | |
| clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) | |
| msk[:, -1:] = 1 | |
| else: | |
| vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) | |
| msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) | |
| msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) | |
| msk = msk.transpose(1, 2)[0] | |
| y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] | |
| y = y.to(dtype=pipe.torch_dtype, device=pipe.device) | |
| y = torch.concat([msk, y]) | |
| y = y.unsqueeze(0) | |
| clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) | |
| y = y.to(dtype=pipe.torch_dtype, device=pipe.device) | |
| return {"clip_feature": clip_context, "y": y} | |
| class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): | |
| def __init__(self): | |
| super().__init__( | |
| input_params=("input_image", "end_image", "height", "width"), | |
| onload_model_names=("image_encoder",) | |
| ) | |
| def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): | |
| if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: | |
| return {} | |
| pipe.load_models_to_device(self.onload_model_names) | |
| image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) | |
| clip_context = pipe.image_encoder.encode_image([image]) | |
| if end_image is not None: | |
| end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) | |
| if pipe.dit.has_image_pos_emb: | |
| clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) | |
| clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) | |
| return {"clip_feature": clip_context} | |
| class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): | |
| def __init__(self): | |
| super().__init__( | |
| input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), | |
| onload_model_names=("vae",) | |
| ) | |
| def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): | |
| if input_image is None or not pipe.dit.require_vae_embedding: | |
| return {} | |
| pipe.load_models_to_device(self.onload_model_names) | |
| image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) | |
| msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) | |
| msk[:, 1:] = 0 | |
| if end_image is not None: | |
| end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) | |
| vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) | |
| msk[:, -1:] = 1 | |
| else: | |
| vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) | |
| msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) | |
| msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) | |
| msk = msk.transpose(1, 2)[0] | |
| y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] | |
| y = y.to(dtype=pipe.torch_dtype, device=pipe.device) | |
| y = torch.concat([msk, y]) | |
| y = y.unsqueeze(0) | |
| y = y.to(dtype=pipe.torch_dtype, device=pipe.device) | |
| return {"y": y} | |
| class WanVideoUnit_ImageEmbedderFused(PipelineUnit): | |
| """ | |
| Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. | |
| """ | |
| def __init__(self): | |
| super().__init__( | |
| input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), | |
| onload_model_names=("vae",) | |
| ) | |
| def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): | |
| if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: | |
| return {} | |
| pipe.load_models_to_device(self.onload_model_names) | |
| image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) | |
| z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) | |
| latents[:, :, 0: 1] = z | |
| return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} | |
| class WanVideoUnit_FunControl(PipelineUnit): | |
| def __init__(self): | |
| super().__init__( | |
| input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"), | |
| onload_model_names=("vae",) | |
| ) | |
| def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y): | |
| if control_video is None: | |
| return {} | |
| pipe.load_models_to_device(self.onload_model_names) | |
| #* convert from a list of PIL.Image to torch.Tensor | |
| #* resulting size: [1, 3, F, H, W] | |
| control_video = pipe.preprocess_video(control_video) | |
| control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) | |
| #* size of control_latents: [1, 16, (F-1)/4 + 1, H/8, W/8] | |
| control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) | |
| if clip_feature is None or y is None: | |
| #* this branch is used during training | |
| clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) | |
| # y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) | |
| #* [1, 16, (F/4) + 1 , H/8, W/8] | |
| y = torch.zeros((1, 16, control_latents.shape[-3], height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) | |
| else: | |
| y = y[:, -16:] | |
| #* control_latents: [1, 16, 21, 60, 80]; y: [1, 16, 21, 60, 80] | |
| #* concatenated result: [1, 32, (F-1)/4 + 1, H/8, W/8]; the first 16 channels are control_latents, the last 16 are y (a zero vector here) | |
| y = torch.concat([control_latents, y], dim=1) | |
| return {"clip_feature": clip_feature, "y": y} | |
| class WanVideoUnit_FunControl_Mask(PipelineUnit): | |
| def __init__(self): | |
| super().__init__( | |
| input_params=("control_video", "mask","num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"), | |
| onload_model_names=("vae",) | |
| ) | |
| def process(self, pipe: WanVideoPipeline, control_video, mask, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y): | |
| if control_video is None: | |
| return {} | |
| pipe.load_models_to_device(self.onload_model_names) | |
| #* convert from a list of PIL.Image to torch.Tensor | |
| #* resulting size: [1, 3, F, H, W] | |
| control_video = pipe.preprocess_video(control_video) | |
| control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) | |
| #* size of control_latents: [1, 16, (F-1)/4 + 1, H/8, W/8] | |
| control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) | |
| if mask is not None: | |
| mask = pipe.preprocess_video(mask) | |
| mask_latents = pipe.vae.encode(mask, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) | |
| mask_latents = mask_latents.to(dtype=pipe.torch_dtype, device=pipe.device) | |
| if clip_feature is None or y is None: | |
| #* this branch is used during training | |
| clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) | |
| # y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) | |
| #* [1, 16, (F/4) + 1 , H/8, W/8] | |
| y = torch.zeros((1, 16, control_latents.shape[-3], height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) | |
| else: | |
| y = y[:, -16:] | |
| #* control_latents: [1, 16, 21, 60, 80]; y: [1, 16, 21, 60, 80] | |
| #* concatenated result: [1, 32, (F-1)/4 + 1, H/8, W/8]; the first 16 channels are control_latents, the last 16 are mask_latents (or the zero/y tensor when no mask is given) | |
| if mask is not None: | |
| y = torch.concat([control_latents, mask_latents], dim=1) | |
| # logger.warning(f"mask is provided, using mask_latents instead of y") | |
| else: | |
| y = torch.concat([control_latents, y], dim=1) | |
| # logger.warning(f"mask is not provided, using y") | |
| return {"clip_feature": clip_feature, "y": y} | |
| class WanVideoUnit_FunReference(PipelineUnit): | |
| def __init__(self): | |
| super().__init__( | |
| input_params=("reference_image", "height", "width", "reference_image"), | |
| onload_model_names=("vae",) | |
| ) | |
| def process(self, pipe: WanVideoPipeline, reference_image, height, width): | |
| if reference_image is None: | |
| return {} | |
| pipe.load_models_to_device(["vae"]) | |
| reference_image = reference_image.resize((width, height)) | |
| reference_latents = pipe.preprocess_video([reference_image]) | |
| reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) | |
| clip_feature = pipe.preprocess_image(reference_image) | |
| clip_feature = pipe.image_encoder.encode_image([clip_feature]) | |
| return {"reference_latents": reference_latents, "clip_feature": clip_feature} | |
| class WanVideoUnit_FunCameraControl(PipelineUnit): | |
| def __init__(self): | |
| super().__init__( | |
| input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image"), | |
| onload_model_names=("vae",) | |
| ) | |
| def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image): | |
| if camera_control_direction is None: | |
| return {} | |
| camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates( | |
| camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin) | |
| control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0) | |
| control_camera_latents = torch.concat( | |
| [ | |
| torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), | |
| control_camera_video[:, :, 1:] | |
| ], dim=2 | |
| ).transpose(1, 2) | |
| b, f, c, h, w = control_camera_latents.shape | |
| control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) | |
| control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) | |
| control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype) | |
| input_image = input_image.resize((width, height)) | |
| input_latents = pipe.preprocess_video([input_image]) | |
| pipe.load_models_to_device(self.onload_model_names) | |
| input_latents = pipe.vae.encode(input_latents, device=pipe.device) | |
| y = torch.zeros_like(latents).to(pipe.device) | |
| y[:, :, :1] = input_latents | |
| y = y.to(dtype=pipe.torch_dtype, device=pipe.device) | |
| return {"control_camera_latents_input": control_camera_latents_input, "y": y} | |
| class WanVideoUnit_SpeedControl(PipelineUnit): | |
| def __init__(self): | |
| super().__init__(input_params=("motion_bucket_id",)) | |
| def process(self, pipe: WanVideoPipeline, motion_bucket_id): | |
| if motion_bucket_id is None: | |
| return {} | |
| motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device) | |
| return {"motion_bucket_id": motion_bucket_id} | |
| class WanVideoUnit_VACE(PipelineUnit): | |
| def __init__(self): | |
| super().__init__( | |
| input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), | |
| onload_model_names=("vae",) | |
| ) | |
| def process( | |
| self, | |
| pipe: WanVideoPipeline, | |
| vace_video, vace_video_mask, vace_reference_image, vace_scale, | |
| height, width, num_frames, | |
| tiled, tile_size, tile_stride | |
| ): | |
| if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None: | |
| pipe.load_models_to_device(["vae"]) | |
| if vace_video is None: | |
| vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) | |
| else: | |
| vace_video = pipe.preprocess_video(vace_video) | |
| if vace_video_mask is None: | |
| vace_video_mask = torch.ones_like(vace_video) | |
| else: | |
| vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1) | |
| inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask | |
| reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask) | |
| inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) | |
| reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) | |
| vace_video_latents = torch.concat((inactive, reactive), dim=1) | |
| vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) | |
| vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') | |
| if vace_reference_image is None: | |
| pass | |
| else: | |
| vace_reference_image = pipe.preprocess_video([vace_reference_image]) | |
| vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) | |
| vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) | |
| vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2) | |
| vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2) | |
| vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) | |
| return {"vace_context": vace_context, "vace_scale": vace_scale} | |
| else: | |
| return {"vace_context": None, "vace_scale": vace_scale} | |
| class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit): | |
| def __init__(self): | |
| super().__init__(input_params=()) | |
| def process(self, pipe: WanVideoPipeline): | |
| if getattr(pipe, "use_unified_sequence_parallel", False): | |
| return {"use_unified_sequence_parallel": True} | |
| return {} | |
| class WanVideoUnit_TeaCache(PipelineUnit): | |
| def __init__(self): | |
| super().__init__( | |
| seperate_cfg=True, | |
| input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, | |
| input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, | |
| ) | |
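| # With seperate_cfg=True this unit runs once per CFG branch, so the positive and negative passes each | |
| # receive their own TeaCache instance and keep independent caches. | |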
| def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id): | |
| if tea_cache_l1_thresh is None: | |
| return {} | |
| return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)} | |
| class WanVideoUnit_CfgMerger(PipelineUnit): | |
| def __init__(self): | |
| super().__init__(take_over=True) | |
| self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"] | |
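| # When cfg_merge is enabled, these conditioning tensors are stacked along the batch axis so that the | |
| # positive and negative CFG branches can be evaluated in a single forward pass. | |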
| def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): | |
| if not inputs_shared["cfg_merge"]: | |
| return inputs_shared, inputs_posi, inputs_nega | |
| for name in self.concat_tensor_names: | |
| tensor_posi = inputs_posi.get(name) | |
| tensor_nega = inputs_nega.get(name) | |
| tensor_shared = inputs_shared.get(name) | |
| if tensor_posi is not None and tensor_nega is not None: | |
| inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0) | |
| elif tensor_shared is not None: | |
| inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0) | |
| inputs_posi.clear() | |
| inputs_nega.clear() | |
| return inputs_shared, inputs_posi, inputs_nega | |
| class TeaCache: | |
| def __init__(self, num_inference_steps, rel_l1_thresh, model_id): | |
| self.num_inference_steps = num_inference_steps | |
| self.step = 0 | |
| self.accumulated_rel_l1_distance = 0 | |
| self.previous_modulated_input = None | |
| self.rel_l1_thresh = rel_l1_thresh | |
| self.previous_residual = None | |
| self.previous_hidden_states = None | |
| self.coefficients_dict = { | |
| "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], | |
| "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], | |
| "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], | |
| "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], | |
| } | |
| if model_id not in self.coefficients_dict: | |
| supported_model_ids = ", ".join([i for i in self.coefficients_dict]) | |
| raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).") | |
| self.coefficients = self.coefficients_dict[model_id] | |
| def check(self, dit: WanModel, x, t_mod): | |
| modulated_inp = t_mod.clone() | |
| if self.step == 0 or self.step == self.num_inference_steps - 1: | |
| should_calc = True | |
| self.accumulated_rel_l1_distance = 0 | |
| else: | |
| coefficients = self.coefficients | |
| rescale_func = np.poly1d(coefficients) | |
| self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) | |
| if self.accumulated_rel_l1_distance < self.rel_l1_thresh: | |
| should_calc = False | |
| else: | |
| should_calc = True | |
| self.accumulated_rel_l1_distance = 0 | |
| self.previous_modulated_input = modulated_inp | |
| self.step += 1 | |
| if self.step == self.num_inference_steps: | |
| self.step = 0 | |
| if should_calc: | |
| self.previous_hidden_states = x.clone() | |
| return not should_calc | |
| def store(self, hidden_states): | |
| self.previous_residual = hidden_states - self.previous_hidden_states | |
| self.previous_hidden_states = None | |
| def update(self, hidden_states): | |
| hidden_states = hidden_states + self.previous_residual | |
| return hidden_states | |
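| # Usage sketch (illustrative only; the threshold, step count, and run_dit_blocks helper are placeholders, | |
| # mirroring how model_fn_wan_video drives the cache below): | |
| #   tea_cache = TeaCache(num_inference_steps=50, rel_l1_thresh=0.05, model_id="Wan2.1-T2V-1.3B") | |
| #   if tea_cache.check(dit, x, t_mod):   # True -> this step's DiT blocks can be skipped | |
| #       x = tea_cache.update(x)          # reuse the cached residual: x + previous_residual | |
| #   else: | |
| #       x = run_dit_blocks(x)            # placeholder for the full block loop | |
| #       tea_cache.store(x)               # refresh the residual against the snapshot taken in check() | |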
| class TemporalTiler_BCTHW: | |
| def __init__(self): | |
| pass | |
| def build_1d_mask(self, length, left_bound, right_bound, border_width): | |
| x = torch.ones((length,)) | |
| if border_width == 0: | |
| return x | |
| shift = 0.5 | |
| if not left_bound: | |
| x[:border_width] = (torch.arange(border_width) + shift) / border_width | |
| if not right_bound: | |
| x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,)) | |
| return x | |
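| # build_1d_mask returns a linear cross-fade ramp: weights rise from ~0 to 1 over border_width samples at | |
| # any edge that is not a sequence boundary, so overlapping temporal windows blend smoothly. build_mask | |
| # below broadcasts this ramp over the batch, channel, and spatial axes, blending along time only. | |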
| def build_mask(self, data, is_bound, border_width): | |
| _, _, T, _, _ = data.shape | |
| t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) | |
| mask = repeat(t, "T -> 1 1 T 1 1") | |
| return mask | |
| def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): | |
| tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] | |
| tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} | |
| B, C, T, H, W = tensor_dict[tensor_names[0]].shape | |
| if batch_size is not None: | |
| B *= batch_size | |
| data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype | |
| value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) | |
| weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) | |
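| # Weighted-overlap accumulation: each window's output is added into `value` scaled by its blend mask, | |
| # the masks are summed into `weight`, and the final division yields a per-frame weighted average. | |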
| for t in range(0, T, sliding_window_stride): | |
| if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: #* if the previous window already covered the last frame, skip this window | |
| continue | |
| t_ = min(t + sliding_window_size, T) | |
| model_kwargs.update({ | |
| tensor_name: tensor_dict[tensor_name][:, :, t: t_, :, :].to(device=computation_device, dtype=computation_dtype) \ | |
| for tensor_name in tensor_names | |
| }) | |
| model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype) | |
| mask = self.build_mask( | |
| model_output, | |
| is_bound=(t == 0, t_ == T), | |
| border_width=(sliding_window_size - sliding_window_stride,) | |
| ).to(device=data_device, dtype=data_dtype) | |
| # logger.info(f"t: {t}, t_: {t_}, sliding_window_size: {sliding_window_size}, sliding_window_stride: {sliding_window_stride}") | |
| value[:, :, t: t_, :, :] += model_output * mask | |
| weight[:, :, t: t_, :, :] += mask | |
| value /= weight | |
| model_kwargs.update(tensor_dict) | |
| return value | |
| def model_fn_wan_video( | |
| dit: WanModel, | |
| motion_controller: WanMotionControllerModel = None, | |
| vace: VaceWanModel = None, | |
| latents: torch.Tensor = None, | |
| timestep: torch.Tensor = None, | |
| context: torch.Tensor = None, | |
| clip_feature: Optional[torch.Tensor] = None, | |
| y: Optional[torch.Tensor] = None, | |
| reference_latents = None, | |
| vace_context = None, | |
| vace_scale = 1.0, | |
| tea_cache: TeaCache = None, | |
| use_unified_sequence_parallel: bool = False, | |
| motion_bucket_id: Optional[torch.Tensor] = None, | |
| sliding_window_size: Optional[int] = None, | |
| sliding_window_stride: Optional[int] = None, | |
| cfg_merge: bool = False, | |
| use_gradient_checkpointing: bool = False, | |
| use_gradient_checkpointing_offload: bool = False, | |
| control_camera_latents_input = None, | |
| fuse_vae_embedding_in_latents: bool = False, | |
| **kwargs, | |
| ): | |
| if sliding_window_size is not None and sliding_window_stride is not None: #* skipped during training | |
| model_kwargs = dict( | |
| dit=dit, | |
| motion_controller=motion_controller, | |
| vace=vace, | |
| latents=latents, | |
| timestep=timestep, | |
| context=context, | |
| clip_feature=clip_feature, | |
| y=y, | |
| reference_latents=reference_latents, | |
| vace_context=vace_context, | |
| vace_scale=vace_scale, | |
| tea_cache=tea_cache, | |
| use_unified_sequence_parallel=use_unified_sequence_parallel, | |
| motion_bucket_id=motion_bucket_id, | |
| ) | |
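| # Note: model_kwargs omits sliding_window_size/stride, so each per-window call below recurses into | |
| # model_fn_wan_video and falls through to the regular single-pass path. | |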
| return TemporalTiler_BCTHW().run( | |
| model_fn_wan_video, | |
| sliding_window_size, sliding_window_stride, | |
| latents.device, latents.dtype, | |
| model_kwargs=model_kwargs, | |
| tensor_names=["latents", "y"], | |
| batch_size=2 if cfg_merge else 1 | |
| ) | |
| if use_unified_sequence_parallel:#* skip | |
| import torch.distributed as dist | |
| from xfuser.core.distributed import (get_sequence_parallel_rank, | |
| get_sequence_parallel_world_size, | |
| get_sp_group) | |
| # Timestep | |
| if dit.seperated_timestep and fuse_vae_embedding_in_latents: | |
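| # When the VAE embedding is fused into the first latent frame, that frame is assigned timestep 0 while | |
| # the remaining frames share the sampled timestep, one value per spatial token. | |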
| timestep = torch.concat([ | |
| torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device), | |
| torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep | |
| ]).flatten() | |
| t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) | |
| if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: | |
| t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1) | |
| t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks] | |
| t = t_chunks[get_sequence_parallel_rank()] | |
| t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) | |
| else:#* this branch | |
| t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) #* out: torch.Size([1, 1536]) | |
| t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) #* out: torch.Size([1, 6, 1536]); dit.dim: 1536 | |
| if motion_bucket_id is not None and motion_controller is not None: #* skip | |
| t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) | |
| context = dit.text_embedding(context)#* text prompt, e.g. "depth": from torch.Size([1, 512, 4096]) to torch.Size([1, 512, 1536]) | |
| # TODO: double-check this x | |
| #* [1, 16, (F-1)/4, H/8, W/8]: pure Gaussian noise, or the noised ground truth | |
| x = latents | |
| # Merged cfg | |
| #* the batch dimension must be consistent across x, timestep, and context | |
| if x.shape[0] != context.shape[0]: | |
| x = torch.concat([x] * context.shape[0], dim=0) | |
| if timestep.shape[0] != context.shape[0]: | |
| timestep = torch.concat([timestep] * context.shape[0], dim=0) | |
| # Image Embedding | |
| """ | |
| new parameters: | |
| #* require_vae_embedding | |
| #* require_clip_embedding | |
| """ | |
| # TODO: x is the target video (i.e. the depth/normal video) after adding noise, or pure Gaussian noise; y is the input RGB video | |
| # TODO: double-check this y, [1, 32, (F-1)/4, H/8, W/8] | |
| if y is not None and dit.require_vae_embedding: | |
| x = torch.cat([x, y], dim=1)# (b, c_x + c_y, f, h, w) #* [1, 48, (F-1)/4, H/8, W/8] | |
| if clip_feature is not None and dit.require_clip_embedding: | |
| #* clip_feature is initialized to zeros; projected from torch.Size([1, 257, 1280]) to torch.Size([1, 257, 1536]) | |
| clip_embedding = dit.img_emb(clip_feature) | |
| #* concatenate the 257 image tokens with the 512 text tokens to form torch.Size([1, 769, 1536]) | |
| context = torch.cat([clip_embedding, context], dim=1) | |
| # Add camera control | |
| #* from torch.Size([1, 48, (F-1)/4, H/8, W/8]), | |
| #* to [1, 1536, (F-1)/4, H/16, W/16] (via the MLP inside patchify) | |
| #* to [1, 1536, ( (F-1)/4 * H/16 * W/16)] | |
| #* x_out: [1, 1536, ( (F-1)/4 * H/16 * W/16)] | |
| x, (f, h, w) = dit.patchify(x, control_camera_latents_input) | |
| # Reference image | |
| if reference_latents is not None: #* skip | |
| if len(reference_latents.shape) == 5: | |
| reference_latents = reference_latents[:, :, 0] | |
| reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2) | |
| x = torch.concat([reference_latents, x], dim=1) | |
| f += 1 | |
| #* RoPE position embedding for 3D video, [ ( (F-1)/4 * H/16 * W/16), 1, 64] | |
| freqs = torch.cat([ | |
| dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), | |
| dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), | |
| dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) | |
| ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) | |
| # TeaCache | |
| if tea_cache is not None:#*skip | |
| tea_cache_update = tea_cache.check(dit, x, t_mod) | |
| else: | |
| tea_cache_update = False | |
| if vace_context is not None:#*skip | |
| vace_hints = vace(x, vace_context, context, t_mod, freqs) | |
| # blocks | |
| if use_unified_sequence_parallel:#* skip | |
| if dist.is_initialized() and dist.get_world_size() > 1: | |
| chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) | |
| pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] | |
| chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] | |
| x = chunks[get_sequence_parallel_rank()] | |
| if tea_cache_update: | |
| x = tea_cache.update(x) | |
| else: | |
| def create_custom_forward(module): | |
| def custom_forward(*inputs): | |
| return module(*inputs) | |
| return custom_forward | |
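| # The closure lets torch.utils.checkpoint re-run the block during backward, trading compute for memory; | |
| # with offloading enabled, saved activations are additionally kept on CPU via save_on_cpu(). | |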
| #* pass through dit blocks 30 times | |
| for block_id, block in enumerate(dit.blocks): | |
| if use_gradient_checkpointing_offload: | |
| with torch.autograd.graph.save_on_cpu(): | |
| x = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward(block), | |
| x, context, t_mod, freqs, | |
| use_reentrant=False, | |
| ) | |
| elif use_gradient_checkpointing: | |
| x = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward(block), | |
| x, context, t_mod, freqs, | |
| use_reentrant=False, | |
| ) | |
| else: | |
| x = block(x, context, t_mod, freqs)#* x_in: [1, ( (F-1)/4 * H/16 * W/16), 1536], context_in: [1, 769, 1536], t_mod_in: [1, 6, 1536], freqs_in: [ ( (F-1)/4 * H/16 * W/16), 1, 64], x_out: [1, ( (F-1)/4 * H/16 * W/16), 1536] | |
| if vace_context is not None and block_id in vace.vace_layers_mapping:#* skip | |
| current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] | |
| if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: | |
| current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] | |
| current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) | |
| x = x + current_vace_hint * vace_scale | |
| if tea_cache is not None:#* skip | |
| tea_cache.store(x) | |
| #* x_in: [1, ( (F-1)/4 * H/16 * W/16), 1536], t_in: [1, 1536], | |
| #* x_out: [1, ( (F-1)/4 * H/16 * W/16), 64] | |
| x = dit.head(x, t) | |
| if use_unified_sequence_parallel:#* skip | |
| if dist.is_initialized() and dist.get_world_size() > 1: | |
| x = get_sp_group().all_gather(x, dim=1) | |
| x = x[:, :-pad_shape] if pad_shape > 0 else x | |
| # Remove reference latents | |
| if reference_latents is not None:#* skip | |
| x = x[:, reference_latents.shape[1]:] | |
| f -= 1 | |
| #* unpatchify, from [1, ( (F-1)/4 * H/16 * W/16), 64] to [1, 16, (F-1)/4, H/8, W/8] | |
| x = dit.unpatchify(x, (f, h, w)) | |
| return x |
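| # Call sketch (illustrative only; shapes follow the comments above and assume the text-to-video setup): | |
| #   prediction = model_fn_wan_video( | |
| #       dit=dit, | |
| #       latents=latents,      # [1, 16, (F-1)/4, H/8, W/8] noisy latents | |
| #       timestep=timestep,    # 1D tensor holding the current diffusion timestep | |
| #       context=context,      # [1, 512, 4096] T5 prompt embeddings | |
| #       y=y,                  # optional VAE condition, concatenated on the channel axis when provided | |
| #   ) | |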