# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple

import torch

from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.models import AutoModel
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipelineBlocks,
    PipelineState,
    ModularPipeline,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class WanRTStreamingLoopDenoiser(ModularPipelineBlocks):
    """Loop sub-block that runs the transformer once and stores its output as `block_state.noise_pred`."""

    model_name = "WanRTStreaming"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block needs: the `transformer` model."""
        return [ComponentSpec("transformer", AutoModel)]

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does."""
        return (
            "Step within the denoising loop that denoise the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanRTStreamingDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        """Input parameters declared by this block."""
        return [
            InputParam("attention_kwargs"),
            InputParam("block_idx"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "kv_cache",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "crossattn_cache",
                required=True,
                type_hint=torch.Tensor,
            ),
            # NOTE(review): declared as `torch.Tensor` here but as `int` in
            # `WanRTStreamingDenoiseLoopWrapper.loop_inputs`; it is fed to the builtin
            # `min()` below and used as a slice bound in the wrapper -- confirm the type.
            InputParam(
                "current_start_frame",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                default=4,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                kwargs_type="guider_input_fields",
                description=(
                    "All conditional model inputs that need to be prepared with guider. "
                    "It should contain prompt_embeds/negative_prompt_embeds. "
                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
                ),
            ),
        ]

    @torch.no_grad()
    def __call__(
        self,
        components: ModularPipeline,
        block_state: BlockState,
        i: int,
        t: torch.Tensor,
    ) -> Tuple[ModularPipeline, BlockState]:
        """Run the transformer for the current timestep and store the result in `block_state.noise_pred`.

        Args:
            components: Pipeline components; `transformer` and `config` are read.
            block_state: Loop state; reads `latents`, `prompt_embeds`, `kv_cache`,
                `crossattn_cache`, `current_start_frame` and `num_frames_per_block`.
            i: Index of the current loop iteration (unused here).
            t: Current timestep; expanded to `(latents.shape[0], num_frames_per_block)`.

        Returns:
            The `(components, block_state)` pair, with `block_state.noise_pred` set.
        """
        # The frame offset handed to the transformer never exceeds `kv_cache_num_frames`.
        start_frame = min(block_state.current_start_frame, components.config.kv_cache_num_frames)
        # The same token offset is used for both the current position and the cache position.
        token_offset = start_frame * components.config.frame_seq_length

        block_state.noise_pred = components.transformer(
            x=block_state.latents,
            t=t.expand(block_state.latents.shape[0], block_state.num_frames_per_block),
            context=block_state.prompt_embeds,
            kv_cache=block_state.kv_cache,
            seq_len=components.config.seq_length,
            crossattn_cache=block_state.crossattn_cache,
            current_start=token_offset,
            cache_start=token_offset,
        )

        return components, block_state


class WanRTStreamingLoopAfterDenoiser(ModularPipelineBlocks):
    """Loop sub-block that updates `block_state.latents` from `block_state.noise_pred`."""

    model_name = "WanRTStreaming"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block declares: the `scheduler`."""
        # NOTE(review): the scheduler is declared but `__call__` below does not use it;
        # the update is computed directly from `block_state.sigmas`.
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does."""
        return (
            "step within the denoising loop that update the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanRTStreamingDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        """This block declares no user-facing inputs."""
        return []

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        """Intermediate inputs declared by this block."""
        return [
            InputParam("generator"),
            # NOTE(review): the other blocks in this file declare `block_idx`;
            # `block_id` may be a typo -- confirm before renaming.
            InputParam("block_id"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Outputs written by this block: the updated `latents`."""
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(
        self,
        components: ModularPipeline,
        block_state: BlockState,
        i: int,
        t: torch.Tensor,
    ) -> Tuple[ModularPipeline, BlockState]:
        """Replace `block_state.latents` with `latents - sigma_t * noise_pred`.

        `sigma_t` is the entry of `block_state.sigmas` at the index where
        `block_state.all_timesteps` is closest to `t`.

        Args:
            components: Pipeline components (not read here).
            block_state: Loop state; reads `latents`, `noise_pred`, `all_timesteps`
                and `sigmas`, and writes `latents`.
            i: Index of the current loop iteration (unused here).
            t: Current timestep.

        Returns:
            The `(components, block_state)` pair with updated `latents`.
        """
        latents_dtype = block_state.latents.dtype

        # Look up the sigma that belongs to the schedule entry nearest to `t`.
        timestep_id = torch.argmin((block_state.all_timesteps - t).abs())
        sigma_t = block_state.sigmas[timestep_id]

        # Perform computation in double precision, then convert back once.
        block_state.latents = (
            block_state.latents.double() - sigma_t.double() * block_state.noise_pred.double()
        ).to(latents_dtype)

        return components, block_state


class WanRTStreamingDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """Loop wrapper that iterates over `timesteps`, re-noising the latents between steps."""

    model_name = "WanRTStreaming"

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does."""
        return (
            "Streaming denoising loop that processes a single block with persistent KV cache. "
            "Recomputes cache from context frames, denoises current block, and updates cache."
        )

    def add_noise(self, components, block_state, sample, noise, timestep, index):
        """Blend `sample` with `noise` as `(1 - sigma) * sample + sigma * noise`.

        For every entry of `timestep`, `sigma` is the value of `block_state.sigmas`
        at the index where `block_state.all_timesteps` is closest to that entry.

        Args:
            components: Pipeline components (unused here).
            block_state: Loop state; reads `all_timesteps` and `sigmas`.
            sample: Tensor to be noised.
            noise: Noise tensor; the result is cast to its dtype.
            timestep: 0-, 1- or 2-dimensional tensor of target timesteps.
            index: Loop index (unused here).

        Returns:
            The blended tensor, computed in double precision and cast to `noise`'s type.
        """
        timesteps = block_state.all_timesteps
        sigmas = block_state.sigmas.to(timesteps.device)

        if timestep.ndim == 0:
            # A scalar timestep would make `unsqueeze(1)` below fail; treat it as one entry.
            timestep = timestep.reshape(1)
        elif timestep.ndim == 2:
            timestep = timestep.flatten(0, 1)

        # Row k holds |all_timesteps - timestep[k]|; argmin over dim 1 picks the nearest entry.
        timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1).double()

        blended = (1 - sigma) * sample.double() + sigma * noise.double()
        return blended.type_as(noise)

    def _renoise_latents(self, components, block_state, next_timestep, index):
        """Return `block_state.latents` re-noised to `next_timestep` with freshly sampled noise."""
        latents = block_state.latents
        # Swap dims 1 and 2 and drop dim 0 when it has size 1, so the leading dim pairs
        # with the flattened per-frame timesteps inside `add_noise`.
        # NOTE(review): the squeeze/unsqueeze round trip presumably assumes batch size 1 -- confirm.
        frames_first = latents.transpose(1, 2).squeeze(0)
        noise = randn_tensor(
            frames_first.shape,
            device=latents.device,
            dtype=latents.dtype,
            generator=block_state.generator,
        )
        per_frame_timestep = next_timestep.expand(latents.shape[0], block_state.num_frames_per_block)
        renoised = self.add_noise(components, block_state, frames_first, noise, per_frame_timestep, index)
        # Undo the layout change applied above.
        return renoised.unsqueeze(0).transpose(1, 2)

    @property
    def loop_inputs(self) -> List[InputParam]:
        """Inputs read by the loop itself (as opposed to its sub-blocks)."""
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "all_timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The full timestep schedule used to look up the sigma for a given timestep. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "sigmas",
                required=True,
                type_hint=torch.Tensor,
                description="The sigmas aligned index-for-index with `all_timesteps`. Can be generated in set_timesteps step.",
            ),
            InputParam("final_latents", type_hint=torch.Tensor),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_frames_per_block",
                required=True,
                type_hint=int,
                default=3,
            ),
            InputParam(
                "current_start_frame",
                required=True,
                type_hint=int,
            ),
            InputParam(
                "block_idx",
            ),
            InputParam(
                "generator",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> Tuple[ModularPipeline, PipelineState]:
        """Run every sub-block once per timestep, then write the result into `final_latents`.

        After each step except the last, the latents are re-noised to the next
        timestep. Once the loop finishes, the latents are copied into
        `block_state.final_latents` along dim 2 at
        `[current_start_frame : current_start_frame + num_frames_per_block]`.

        Args:
            components: Pipeline components passed through to the sub-blocks.
            state: Pipeline state from which the block state is read and to which it is written back.

        Returns:
            The `(components, state)` pair.
        """
        block_state = self.get_block_state(state)
        total_timesteps = len(block_state.timesteps)

        for i, t in enumerate(block_state.timesteps):
            components, block_state = self.loop_step(components, block_state, i=i, t=t)

            # Re-noise only while a next timestep exists. The bound on `total_timesteps`
            # prevents an IndexError when `timesteps` holds fewer entries than
            # `num_inference_steps`.
            has_next_timestep = i + 1 < total_timesteps
            if has_next_timestep and i < (block_state.num_inference_steps - 1):
                block_state.latents = self._renoise_latents(
                    components, block_state, block_state.timesteps[i + 1], i
                )

        # Update the state
        frame_start = block_state.current_start_frame
        frame_stop = frame_start + block_state.num_frames_per_block
        block_state.final_latents[:, :, frame_start:frame_stop] = block_state.latents

        self.set_block_state(state, block_state)

        return components, state


class WanRTStreamingDenoiseStep(WanRTStreamingDenoiseLoopWrapper):
    """Concrete denoise loop: runs the denoiser and then the after-denoiser on every iteration."""

    block_classes = [
        WanRTStreamingLoopDenoiser,
        WanRTStreamingLoopAfterDenoiser,
    ]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does."""
        return (
            "Denoise step that iteratively denoise the latents. \n"
            "Its loop logic is defined in `WanRTStreamingDenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
            " - `WanRTStreamingLoopDenoiser`\n"
            " - `WanRTStreamingLoopAfterDenoiser`\n"
            "This block supports both text2vid tasks."
        )