from typing import List, Tuple

import torch

from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.models import AutoModel
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipeline,
    ModularPipelineBlocks,
    PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)


logger = logging.get_logger(__name__)
|
|
class WanRTStreamingLoopDenoiser(ModularPipelineBlocks):
    model_name = "WanRTStreaming"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("transformer", AutoModel)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanRTStreamingDenoiseLoopWrapper`)."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("attention_kwargs"),
            InputParam("block_idx"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in the prepare_latents step.",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "kv_cache",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "crossattn_cache",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "current_start_frame",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                default=4,
                description="The number of inference steps to use for the denoising process. Can be generated in the set_timesteps step.",
            ),
            InputParam(
                kwargs_type="guider_input_fields",
                description=(
                    "All conditional model inputs that need to be prepared with the guider. "
                    "It should contain prompt_embeds/negative_prompt_embeds. "
                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state."
                ),
            ),
        ]

    @torch.no_grad()
    def __call__(
        self,
        components: ModularPipeline,
        block_state: BlockState,
        i: int,
        t: torch.Tensor,
    ) -> Tuple[ModularPipeline, BlockState]:
        # Clamp the attention window start so that, once the stream has advanced
        # past the cache capacity, new blocks attend to the most recent
        # `kv_cache_num_frames` frames.
        start_frame = min(
            block_state.current_start_frame, components.config.kv_cache_num_frames
        )

        # Predict the velocity for the current block, reusing the persistent KV
        # cache holding the previously generated context frames. Offsets are
        # given in tokens (frame index * tokens per frame).
        block_state.noise_pred = components.transformer(
            x=block_state.latents,
            t=t.expand(block_state.latents.shape[0], block_state.num_frames_per_block),
            context=block_state.prompt_embeds,
            kv_cache=block_state.kv_cache,
            seq_len=components.config.seq_length,
            crossattn_cache=block_state.crossattn_cache,
            current_start=start_frame * components.config.frame_seq_length,
            cache_start=start_frame * components.config.frame_seq_length,
        )

        return components, block_state
|
|
class WanRTStreamingLoopAfterDenoiser(ModularPipelineBlocks):
    model_name = "WanRTStreaming"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that updates the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanRTStreamingDenoiseLoopWrapper`)."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam("generator"),
            InputParam("block_idx"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents", type_hint=torch.Tensor, description="The denoised latents"
            )
        ]

    @torch.no_grad()
    def __call__(
        self,
        components: ModularPipeline,
        block_state: BlockState,
        i: int,
        t: torch.Tensor,
    ) -> Tuple[ModularPipeline, BlockState]:
        latents_dtype = block_state.latents.dtype
        timesteps = block_state.all_timesteps
        sigmas = block_state.sigmas

        # Look up the sigma corresponding to the current timestep.
        timestep_id = torch.argmin((timesteps - t).abs())
        sigma_t = sigmas[timestep_id]

        # Flow-matching x0 prediction: with x_t = (1 - sigma) * x0 + sigma * noise
        # and the model predicting the velocity, x0 = x_t - sigma_t * v. Computed
        # in float64 to limit rounding error, then cast back.
        latents = (
            block_state.latents.double()
            - sigma_t.double() * block_state.noise_pred.double()
        ).to(latents_dtype)

        block_state.latents = latents

        return components, block_state
|
|
class WanRTStreamingDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    model_name = "WanRTStreaming"

    @property
    def description(self) -> str:
        return (
            "Streaming denoising loop that processes a single block with a persistent KV cache. "
            "Recomputes the cache from context frames, denoises the current block, and updates the cache."
        )

    def add_noise(self, components, block_state, sample, noise, timestep, index):
        timesteps = block_state.all_timesteps
        sigmas = block_state.sigmas.to(timesteps.device)

        if timestep.ndim == 2:
            timestep = timestep.flatten(0, 1)

        # Map each per-frame timestep to the index of its closest entry in the
        # full timestep schedule, then gather the matching sigmas.
        timestep_id = torch.argmin(
            (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1
        )
        sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1)

        # Flow-matching forward process: x_t = (1 - sigma) * x0 + sigma * noise,
        # computed in float64 for numerical stability.
        sample = (
            1 - sigma.double()
        ) * sample.double() + sigma.double() * noise.double()
        sample = sample.type_as(noise)

        return sample

    @property
    def loop_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in the set_timesteps step.",
            ),
            InputParam(
                "all_timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The full timestep schedule, used to match timesteps to sigmas. Can be generated in the set_timesteps step.",
            ),
            InputParam(
                "sigmas",
                required=True,
                type_hint=torch.Tensor,
                description="The sigmas to use for the denoising process. Can be generated in the set_timesteps step.",
            ),
            InputParam("final_latents", type_hint=torch.Tensor),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in the set_timesteps step.",
            ),
            InputParam(
                "num_frames_per_block",
                required=True,
                type_hint=int,
                default=3,
            ),
            InputParam(
                "current_start_frame",
                required=True,
                type_hint=int,
            ),
            InputParam("block_idx"),
            InputParam("generator"),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        block_state = self.get_block_state(state)

        for i, t in enumerate(block_state.timesteps):
            components, block_state = self.loop_step(components, block_state, i=i, t=t)
            if i < (block_state.num_inference_steps - 1):
                # Renoise the x0 prediction to the next noise level with fresh
                # noise so the next iteration starts from the correct point on
                # the flow.
                t1 = block_state.timesteps[i + 1]
                block_state.latents = (
                    self.add_noise(
                        components,
                        block_state,
                        block_state.latents.transpose(1, 2).squeeze(0),
                        randn_tensor(
                            block_state.latents.transpose(1, 2).squeeze(0).shape,
                            device=block_state.latents.device,
                            dtype=block_state.latents.dtype,
                            generator=block_state.generator,
                        ),
                        t1.expand(
                            block_state.latents.shape[0],
                            block_state.num_frames_per_block,
                        ),
                        i,
                    )
                    .unsqueeze(0)
                    .transpose(1, 2)
                )

        # Write the finished block into its frame range of the full video latents.
        block_state.final_latents[
            :,
            :,
            block_state.current_start_frame : block_state.current_start_frame
            + block_state.num_frames_per_block,
        ] = block_state.latents

        self.set_block_state(state, block_state)

        return components, state
|
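# The few-step sampling scheme implemented by `WanRTStreamingDenoiseLoopWrapper`
# above, summarized as equations (a reading of the code, not additional API):
# the forward process is x_t = (1 - sigma_t) * x0 + sigma_t * eps; each loop
# iteration recovers x0_pred = x_t - sigma_t * v from the predicted velocity v
# (`WanRTStreamingLoopAfterDenoiser`), and every step except the last renoises
# with fresh noise to the next level,
# x_{t_{i+1}} = (1 - sigma_{t_{i+1}}) * x0_pred + sigma_{t_{i+1}} * eps'
# (`add_noise`).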
|
class WanRTStreamingDenoiseStep(WanRTStreamingDenoiseLoopWrapper):
    block_classes = [
        WanRTStreamingLoopDenoiser,
        WanRTStreamingLoopAfterDenoiser,
    ]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in the `WanRTStreamingDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `WanRTStreamingLoopDenoiser`\n"
            " - `WanRTStreamingLoopAfterDenoiser`\n"
            "This block supports text2vid tasks."
        )
|
|
|
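# Example usage (a minimal sketch): promoting the denoise step to a runnable
# `ModularPipeline`. The repo id below is hypothetical, and a real pipeline
# would prepend the text-encoding, set_timesteps, and latent/KV-cache
# preparation blocks that produce this step's required inputs:
#
#     import torch
#     from diffusers.modular_pipelines import SequentialPipelineBlocks
#
#     blocks = SequentialPipelineBlocks.from_blocks_dict(
#         {"denoise": WanRTStreamingDenoiseStep()}
#     )
#     pipe = blocks.init_pipeline("path/to/wan-rt-streaming-modular")  # hypothetical repo
#     pipe.load_default_components(torch_dtype=torch.bfloat16)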