# krea-realtime-video / denoise.py
# (Hugging Face file-page header preserved as comments: uploaded by viccpoes,
# commit 5b1c701 "Add diffusers (#1)", verified — the raw page text was not
# valid Python.)
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple
import torch
from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.models import AutoModel
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
ModularPipeline,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
InputParam,
OutputParam,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanRTStreamingLoopDenoiser(ModularPipelineBlocks):
    """Denoiser sub-block of the streaming Wan denoise loop.

    On each loop iteration it runs the transformer once on the current latent
    block, conditioning on the prompt embeddings and the persistent KV /
    cross-attention caches, and stores the prediction on
    ``block_state.noise_pred``.
    """

    model_name = "WanRTStreaming"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("transformer", AutoModel)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoise the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanRTStreamingDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        # Fixed annotation: this was declared `List[Tuple[str, Any]]` but the
        # method actually returns a list of `InputParam` specs.
        return [
            InputParam("attention_kwargs"),
            InputParam("block_idx"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "kv_cache",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "crossattn_cache",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "current_start_frame",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                default=4,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                kwargs_type="guider_input_fields",
                description=(
                    "All conditional model inputs that need to be prepared with guider. "
                    "It should contain prompt_embeds/negative_prompt_embeds. "
                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
                ),
            ),
        ]

    @torch.no_grad()
    def __call__(
        self,
        components: ModularPipeline,
        block_state: BlockState,
        i: int,
        t: torch.Tensor,
    ) -> Tuple[ModularPipeline, BlockState]:
        # Fixed annotation: this was declared `-> PipelineState` but the method
        # returns the `(components, block_state)` pair that
        # `LoopSequentialPipelineBlocks.loop_step` expects from a sub-block.
        #
        # Clamp the window start so it never exceeds the KV cache capacity
        # (measured in frames).
        start_frame = min(
            block_state.current_start_frame, components.config.kv_cache_num_frames
        )
        # NOTE(review): `block_state.num_frames_per_block` is read here but is
        # not declared in this block's `inputs`; it appears to be supplied by
        # the wrapping loop's `loop_inputs` — confirm when composing.
        block_state.noise_pred = components.transformer(
            x=block_state.latents,
            t=t.expand(block_state.latents.shape[0], block_state.num_frames_per_block),
            context=block_state.prompt_embeds,
            kv_cache=block_state.kv_cache,
            seq_len=components.config.seq_length,
            crossattn_cache=block_state.crossattn_cache,
            current_start=start_frame * components.config.frame_seq_length,
            cache_start=start_frame * components.config.frame_seq_length,
        )
        return components, block_state
class WanRTStreamingLoopAfterDenoiser(ModularPipelineBlocks):
    """Post-denoiser sub-block of the streaming Wan denoise loop.

    Turns the transformer's prediction into updated latents with a single
    flow-matching style step: ``latents - sigma_t * noise_pred``, where
    ``sigma_t`` is looked up from the schedule entry nearest to ``t``.
    """

    model_name = "WanRTStreaming"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "step within the denoising loop that update the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanRTStreamingDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        # Fixed annotation: this was declared `-> List[str]` but the method
        # returns `InputParam` specs.
        # NOTE(review): "block_id" differs from the "block_idx" key used by the
        # sibling blocks in this file — likely a typo, but the key is
        # behavior-relevant state, so it is flagged rather than renamed here.
        return [
            InputParam("generator"),
            InputParam("block_id"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents", type_hint=torch.Tensor, description="The denoised latents"
            )
        ]

    @torch.no_grad()
    def __call__(
        self,
        components: ModularPipeline,
        block_state: BlockState,
        i: int,
        t: torch.Tensor,
    ) -> Tuple[ModularPipeline, BlockState]:
        # Perform scheduler step using the predicted output
        latents_dtype = block_state.latents.dtype
        timesteps = block_state.all_timesteps
        sigmas = block_state.sigmas
        # Map `t` to the nearest entry of the full timestep schedule so its
        # sigma can be looked up (t may not match an entry exactly).
        timestep_id = torch.argmin((timesteps - t).abs())
        sigma_t = sigmas[timestep_id]
        # Perform computation in double precision, then convert back once
        latents = (
            block_state.latents.double()
            - sigma_t.double() * block_state.noise_pred.double()
        ).to(latents_dtype)
        block_state.latents = latents
        return components, block_state
class WanRTStreamingDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """Loop wrapper that denoises a single latent block in streaming fashion.

    Iterates over the per-block timesteps, delegating each step to the
    composed ``sub_blocks``, re-noising the partially denoised latents to the
    next noise level between steps, and finally writing the finished block
    into ``final_latents`` at the current frame offset.
    """

    model_name = "WanRTStreaming"

    @property
    def description(self) -> str:
        return (
            "Streaming denoising loop that processes a single block with persistent KV cache. "
            "Recomputes cache from context frames, denoises current block, and updates cache."
        )

    def add_noise(self, components, block_state, sample, noise, timestep, index):
        """Re-noise ``sample`` to the noise level of ``timestep``.

        For each entry of ``timestep``, looks up the sigma of the nearest
        schedule entry and blends ``(1 - sigma) * sample + sigma * noise`` in
        float64, casting back to ``noise``'s dtype. ``index`` is currently
        unused. ``components`` is unused as well.
        """
        timesteps = block_state.all_timesteps
        sigmas = block_state.sigmas.to(timesteps.device)
        # Flatten a 2-D (batch, frames) timestep grid so each entry is matched
        # against the schedule independently.
        if timestep.ndim == 2:
            timestep = timestep.flatten(0, 1)
        # Nearest-neighbour lookup of each timestep in the full schedule.
        timestep_id = torch.argmin(
            (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1
        )
        # Broadcast the per-entry sigma over the remaining sample dimensions.
        sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1)
        sample = (
            1 - sigma.double()
        ) * sample.double() + sigma.double() * noise.double()
        sample = sample.type_as(noise)
        return sample

    @property
    def loop_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "all_timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "sigmas",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam("final_latents", type_hint=torch.Tensor),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_frames_per_block",
                required=True,
                type_hint=int,
                default=3,
            ),
            InputParam(
                "current_start_frame",
                required=True,
                type_hint=int,
            ),
            InputParam(
                "block_idx",
            ),
            InputParam(
                "generator",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        # NOTE: return annotation fixed — the method returns the
        # `(components, state)` pair, not a bare `PipelineState`.
        block_state = self.get_block_state(state)
        for i, t in enumerate(block_state.timesteps):
            # Run the composed sub-blocks (denoiser, after-denoiser) once.
            components, block_state = self.loop_step(components, block_state, i=i, t=t)
            if i < (block_state.num_inference_steps - 1):
                # Between steps, re-noise the partially denoised block up to
                # the NEXT timestep's noise level with fresh noise.
                t1 = block_state.timesteps[i + 1]
                # transpose(1, 2).squeeze(0) moves the frame axis first so
                # add_noise's per-frame sigma broadcast lines up; the inverse
                # transform restores the original layout afterwards.
                # NOTE(review): assumes latents are (1, C, F, H, W) — confirm
                # against the prepare_latents step.
                block_state.latents = (
                    self.add_noise(
                        components,
                        block_state,
                        block_state.latents.transpose(1, 2).squeeze(0),
                        randn_tensor(
                            block_state.latents.transpose(1, 2).squeeze(0).shape,
                            device=block_state.latents.device,
                            dtype=block_state.latents.dtype,
                            generator=block_state.generator,
                        ),
                        t1.expand(
                            block_state.latents.shape[0],
                            block_state.num_frames_per_block,
                        ),
                        i,
                    )
                    .unsqueeze(0)
                    .transpose(1, 2)
                )
        # Update the state
        # Write the finished block into the full-video latent buffer at the
        # current frame offset.
        block_state.final_latents[
            :,
            :,
            block_state.current_start_frame : block_state.current_start_frame
            + block_state.num_frames_per_block,
        ] = block_state.latents
        self.set_block_state(state, block_state)
        return components, state
class WanRTStreamingDenoiseStep(WanRTStreamingDenoiseLoopWrapper):
    """Concrete streaming denoise step: denoiser followed by after-denoiser.

    Inherits the loop logic from `WanRTStreamingDenoiseLoopWrapper` and plugs
    in the two sub-blocks that run at each iteration.
    """

    block_classes = [
        WanRTStreamingLoopDenoiser,
        WanRTStreamingLoopAfterDenoiser,
    ]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        # Fixed typos in the user-facing description: "sequencially" ->
        # "sequentially", and dropped the stray "both" (only one task listed).
        return (
            "Denoise step that iteratively denoise the latents. \n"
            "Its loop logic is defined in `WanRTStreamingDenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `WanRTStreamingLoopDenoiser`\n"
            " - `WanRTStreamingLoopAfterDenoiser`\n"
            "This block supports text2vid tasks."
        )