# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple

import torch

from diffusers.models import AutoModel
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipelineBlocks,
    PipelineState,
    ModularPipeline,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class WanRTLoopDenoiser(ModularPipelineBlocks):
    """Loop sub-block that calls the transformer once per denoising iteration.

    The transformer output is stored on ``block_state.noise_pred`` so that a
    following sub-block in the same loop can consume it.
    """

    model_name = "wan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block needs: the ``transformer`` model."""
        return [ComponentSpec("transformer", AutoModel)]

    @property
    def description(self) -> str:
        """Human-readable summary of the block."""
        return (
            "Step within the denoising loop that denoise the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanRTDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        """State entries read by ``__call__``."""
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Text embeddings to condition the denoising process",
            ),
            InputParam(
                "kv_cache",
                required=True,
                type_hint=torch.Tensor,
                description="KV Cache of the transformer model",
            ),
            InputParam(
                "crossattn_cache",
                required=True,
                type_hint=torch.Tensor,
                description="Cross Attention Cache of the transformer model",
            ),
            InputParam(
                "current_start_frame",
                required=True,
                type_hint=torch.Tensor,
                description="Starting frame index for the current block in the streaming generation",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self,
        components: ModularPipeline,
        block_state: BlockState,
        i: int,
        t: torch.Tensor,
    ) -> Tuple[ModularPipeline, BlockState]:
        """Run the transformer for timestep ``t`` and store the result.

        Args:
            components: Pipeline providing ``transformer`` and ``config``.
            block_state: Mutable loop state; ``noise_pred`` is written to it.
            i: Index of the current loop iteration (unused here).
            t: Current timestep; expanded to
                ``(latents.shape[0], config.num_frames_per_block)``.

        Returns:
            The ``(components, block_state)`` pair, with
            ``block_state.noise_pred`` set to the transformer output.
        """
        # The start position handed to the transformer is capped at the
        # configured number of cached frames.
        start_frame = min(
            block_state.current_start_frame, components.config.kv_cache_num_frames
        )

        block_state.noise_pred = components.transformer(
            x=block_state.latents,
            t=t.expand(
                block_state.latents.shape[0], components.config.num_frames_per_block
            ),
            context=block_state.prompt_embeds,
            kv_cache=block_state.kv_cache,
            seq_len=components.config.seq_length,
            crossattn_cache=block_state.crossattn_cache,
            # Frame index converted to a token offset.
            current_start=start_frame * components.config.frame_seq_length,
            cache_start=None,
        )

        return components, block_state


class WanRTLoopAfterDenoiser(ModularPipelineBlocks):
    """Loop sub-block that updates the latents from ``block_state.noise_pred``.

    The update is ``latents - sigma_t * noise_pred`` where ``sigma_t`` is the
    sigma whose timestep in ``all_timesteps`` is closest to the current ``t``.
    """

    model_name = "wan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components declared by this block: the ``scheduler``."""
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        """Human-readable summary of the block."""
        return (
            "step within the denoising loop that update the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanRTDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        """State entries read by ``__call__``."""
        return [
            InputParam(
                "latents",
                description="Current latents being denoised",
            ),
            InputParam(
                "all_timesteps",
                description="All timesteps for the denoising process",
            ),
            InputParam(
                "sigmas",
                description="Noise schedule sigmas for each timestep",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """State entries written by ``__call__``."""
        return [
            OutputParam(
                "latents", type_hint=torch.Tensor, description="The denoised latents"
            )
        ]

    @torch.no_grad()
    def __call__(
        self,
        components: ModularPipeline,
        block_state: BlockState,
        i: int,
        t: torch.Tensor,
    ) -> Tuple[ModularPipeline, BlockState]:
        """Replace ``block_state.latents`` with ``latents - sigma_t * noise_pred``.

        Args:
            components: Pipeline object; passed through unchanged.
            block_state: Mutable loop state providing ``latents``,
                ``noise_pred``, ``all_timesteps`` and ``sigmas``.
            i: Index of the current loop iteration (unused here).
            t: Current timestep, matched against ``all_timesteps``.

        Returns:
            The ``(components, block_state)`` pair with updated latents.
        """
        # Perform scheduler step using the predicted output
        latents_dtype = block_state.latents.dtype
        timesteps = block_state.all_timesteps
        sigmas = block_state.sigmas

        # Pick the schedule entry whose timestep is nearest to `t`.
        timestep_id = torch.argmin((timesteps - t).abs())
        sigma_t = sigmas[timestep_id]

        # Perform computation in double precision, then convert back once
        latents = (
            block_state.latents.double()
            - sigma_t.double() * block_state.noise_pred.double()
        ).to(latents_dtype)

        block_state.latents = latents

        return components, block_state


class WanRTDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """Denoising loop that runs its sub-blocks once per timestep.

    Between consecutive timesteps the latents are re-noised to the sigma of
    the next timestep via :meth:`add_noise`.
    """

    model_name = "wan"

    @property
    def description(self) -> str:
        """Human-readable summary of the block."""
        return (
            "Streaming denoising loop that processes a single block with persistent KV cache. "
            "Recomputes cache from context frames, denoises current block, and updates cache."
        )

    def add_noise(self, block_state, sample, noise, timestep):
        """Blend ``sample`` and ``noise`` as ``(1 - sigma) * sample + sigma * noise``.

        Args:
            block_state: State providing ``all_timesteps`` and ``sigmas``.
            sample: Tensor to be noised.
            noise: Noise tensor; the result is cast to its dtype.
            timestep: Timesteps selecting the sigma per leading entry of
                ``sample``. A 2-D tensor is flattened over its first two
                dimensions.

        Returns:
            The blended tensor, computed in float64 and cast with
            ``type_as(noise)``.
        """
        timesteps = block_state.all_timesteps
        sigmas = block_state.sigmas.to(timesteps.device)

        if timestep.ndim == 2:
            timestep = timestep.flatten(0, 1)

        # For every requested timestep, index of the nearest schedule entry.
        timestep_id = torch.argmin(
            (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1
        )
        # One sigma per leading entry, broadcast over three trailing dims.
        sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1)

        sample = (
            1 - sigma.double()
        ) * sample.double() + sigma.double() * noise.double()
        sample = sample.type_as(noise)
        return sample

    @property
    def loop_inputs(self) -> List[InputParam]:
        """State entries read by the loop itself (not by its sub-blocks)."""
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "all_timesteps",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "sigmas",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam("current_denoised_latents", type_hint=torch.Tensor),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "current_start_frame",
                required=True,
                type_hint=int,
            ),
            InputParam("generator", type_hint=torch.Generator),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Run the sub-blocks for every timestep, re-noising in between.

        Args:
            components: Pipeline providing the models and ``config``.
            state: Pipeline state the block state is read from and written to.

        Returns:
            The ``(components, state)`` pair. ``latents`` and
            ``current_denoised_latents`` on the block state both hold the
            final latents before the block state is written back.
        """
        block_state = self.get_block_state(state)

        for i, t in enumerate(block_state.timesteps):
            components, block_state = self.loop_step(components, block_state, i=i, t=t)

            # Re-noise only while another step follows. The second condition
            # keeps `timesteps[i + 1]` in range even if `num_inference_steps`
            # is larger than the number of timesteps actually provided.
            has_next_timestep = i + 1 < len(block_state.timesteps)
            if i < (block_state.num_inference_steps - 1) and has_next_timestep:
                next_t = block_state.timesteps[i + 1]

                # NOTE(review): squeeze(0) only drops the leading dim when it
                # has size 1 - presumably batch size 1 is assumed; confirm.
                frames_first = block_state.latents.transpose(1, 2).squeeze(0)
                noise = randn_tensor(
                    frames_first.shape,
                    device=block_state.latents.device,
                    dtype=block_state.latents.dtype,
                    generator=block_state.generator,
                )
                renoised = self.add_noise(
                    block_state,
                    frames_first,
                    noise,
                    next_t.expand(
                        block_state.latents.shape[0],
                        components.config.num_frames_per_block,
                    ),
                )
                # Undo the squeeze/transpose applied above.
                block_state.latents = renoised.unsqueeze(0).transpose(1, 2)

        block_state.current_denoised_latents = block_state.latents

        self.set_block_state(state, block_state)

        return components, state


class WanRTDenoiseStep(WanRTDenoiseLoopWrapper):
    """Concrete denoising loop: ``WanRTLoopDenoiser`` then ``WanRTLoopAfterDenoiser``."""

    block_classes = [
        WanRTLoopDenoiser,
        WanRTLoopAfterDenoiser,
    ]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        """Human-readable summary of the block."""
        return (
            "Denoise step that iteratively denoise the latents. \n"
            "Its loop logic is defined in `WanRTDenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
            " - `WanRTLoopDenoiser`\n"
            " - `WanRTLoopAfterDenoiser`\n"
            "This block supports both text2vid tasks."
        )