# krea-realtime-video / before_denoise.py
# Uploaded by viccpoes in commit 5b1c701 ("Add diffusers (#1)", verified)
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Union, Dict
import torch
from diffusers import AutoencoderKLWan
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
ModularPipeline,
ModularPipelineBlocks,
SequentialPipelineBlocks,
PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
OutputParam,
)
# Module-level logger from diffusers' logging utilities (not referenced by the code currently in this file).
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Run ``scheduler.set_timesteps`` and hand back the schedule it produced.

    The schedule is specified in exactly one way: a custom ``timesteps`` list, a custom ``sigmas`` list, or
    (when neither is given) ``num_inference_steps``. Any extra keyword arguments are forwarded unchanged to
    ``scheduler.set_timesteps``.

    Args:
        scheduler (`SchedulerMixin`):
            Scheduler whose ``set_timesteps`` is called and whose ``timesteps`` attribute is read afterwards.
        num_inference_steps (`int`, *optional*):
            Number of denoising steps; only used when neither ``timesteps`` nor ``sigmas`` is passed.
        device (`str` or `torch.device`, *optional*):
            Forwarded to ``set_timesteps`` as ``device``.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule. Mutually exclusive with ``sigmas``.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule. Mutually exclusive with ``timesteps``.

    Returns:
        `Tuple[torch.Tensor, int]`: ``(scheduler.timesteps, step_count)``. For a custom schedule the step count is
        ``len(scheduler.timesteps)``; otherwise it is the ``num_inference_steps`` that was passed in.

    Raises:
        ValueError: if both ``timesteps`` and ``sigmas`` are given, or if the scheduler's ``set_timesteps`` does
            not accept the custom-schedule keyword that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    if timesteps is None and sigmas is None:
        # Default path: the scheduler derives its own spacing from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: make sure the scheduler's set_timesteps accepts the keyword we are about to pass.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
def retrieve_latents(
    encoder_output: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    sample_mode: str = "sample",
):
    """
    Pull a latent tensor out of a VAE encoder output.

    When ``encoder_output`` has a ``latent_dist`` attribute, ``sample_mode="sample"`` returns
    ``latent_dist.sample(generator)`` and ``sample_mode="argmax"`` returns ``latent_dist.mode()``.
    Otherwise (including any other ``sample_mode``), a ``latents`` attribute is returned if present.

    Raises:
        AttributeError: if none of the above applies.
    """
    has_distribution = hasattr(encoder_output, "latent_dist")
    if has_distribution and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_distribution and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
def _initialize_kv_cache(
components: ModularPipeline,
kv_cache_existing: Optional[List[Dict]],
batch_size: int,
dtype: torch.dtype,
device: torch.device,
local_attn_size: int,
frame_seq_length: int,
):
"""
Initialize a Per-GPU KV cache for the Wan model.
Mirrors causal_inference.py:279-313
"""
kv_cache = []
# Calculate KV cache size
if local_attn_size != -1:
# Use the local attention size to compute the KV cache size
kv_cache_size = local_attn_size * frame_seq_length
else:
# Use the default KV cache size
kv_cache_size = 32760
# Get transformer config
num_transformer_blocks = len(components.transformer.blocks)
num_heads = components.transformer.config.num_heads
dim = components.transformer.config.dim
k_shape = [batch_size, kv_cache_size, num_heads, dim // num_heads]
v_shape = [batch_size, kv_cache_size, num_heads, dim // num_heads]
# Check if we can reuse existing cache
if (
kv_cache_existing
and len(kv_cache_existing) > 0
and list(kv_cache_existing[0]["k"].shape) == k_shape
and list(kv_cache_existing[0]["v"].shape) == v_shape
):
for i in range(num_transformer_blocks):
kv_cache_existing[i]["k"].zero_()
kv_cache_existing[i]["v"].zero_()
kv_cache_existing[i]["global_end_index"] = 0
kv_cache_existing[i]["local_end_index"] = 0
return kv_cache_existing
else:
# Create new cache
for _ in range(num_transformer_blocks):
kv_cache.append(
{
"k": torch.zeros(k_shape, dtype=dtype, device=device).contiguous(),
"v": torch.zeros(v_shape, dtype=dtype, device=device).contiguous(),
"global_end_index": 0,
"local_end_index": 0,
}
)
return kv_cache
def _initialize_crossattn_cache(
components: ModularPipeline,
crossattn_cache_existing: Optional[List[Dict]],
batch_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""
Initialize a Per-GPU cross-attention cache for the Wan model.
Mirrors causal_inference.py:315-338
"""
crossattn_cache = []
# Get transformer config
num_transformer_blocks = len(components.transformer.blocks)
num_heads = components.transformer.config.num_heads
dim = components.transformer.config.dim
k_shape = [batch_size, 512, num_heads, dim // num_heads]
v_shape = [batch_size, 512, num_heads, dim // num_heads]
# Check if we can reuse existing cache
if (
crossattn_cache_existing
and len(crossattn_cache_existing) > 0
and list(crossattn_cache_existing[0]["k"].shape) == k_shape
and list(crossattn_cache_existing[0]["v"].shape) == v_shape
):
for i in range(num_transformer_blocks):
crossattn_cache_existing[i]["k"].zero_()
crossattn_cache_existing[i]["v"].zero_()
crossattn_cache_existing[i]["is_init"] = False
return crossattn_cache_existing
else:
# Create new cache
for _ in range(num_transformer_blocks):
crossattn_cache.append(
{
"k": torch.zeros(k_shape, dtype=dtype, device=device).contiguous(),
"v": torch.zeros(v_shape, dtype=dtype, device=device).contiguous(),
"is_init": False,
}
)
return crossattn_cache
class WanInputStep(ModularPipelineBlocks):
    """
    Normalizes the text-embedding inputs of the WanRT pipeline.

    Derives ``batch_size`` and ``dtype`` from ``prompt_embeds``. It then tiles ``prompt_embeds`` (and
    ``negative_prompt_embeds`` when given) ``num_videos_per_prompt`` times along the batch dimension. The
    embeddings are unpacked as 3-D tensors ``(batch, seq_len, hidden)``.
    """

    model_name = "WanRT"

    @property
    def description(self) -> str:
        return (
            "Input processing step that:\n"
            " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
            " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n"
            "All input tensors are expected to have either batch_size=1 or match the batch_size\n"
            "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
            "have a final batch_size of batch_size * num_videos_per_prompt."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_videos_per_prompt", default=1),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Pre-generated text embeddings. Can be generated from text_encoder step.",
            ),
            InputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        # Annotation fixed: this returns OutputParam objects, not strings (matches the other blocks).
        return [
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
            ),
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",  # already in intermediate state but declared here again for denoiser_input_fields
                description="text embeddings used to guide the image generation",
            ),
            OutputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",  # already in intermediate state but declared here again for denoiser_input_fields
                description="negative text embeddings used to guide the image generation",
            ),
        ]

    def check_inputs(self, components, block_state):
        """Raise ``ValueError`` when both embeddings are provided with different shapes."""
        if (
            block_state.prompt_embeds is not None
            and block_state.negative_prompt_embeds is not None
        ):
            if (
                block_state.prompt_embeds.shape
                != block_state.negative_prompt_embeds.shape
            ):
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {block_state.negative_prompt_embeds.shape}."
                )

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        # NOTE(review): despite the annotation, this returns a (components, state) tuple like the other blocks.
        block_state = self.get_block_state(state)
        self.check_inputs(components, block_state)
        block_state.batch_size = block_state.prompt_embeds.shape[0]
        block_state.dtype = block_state.prompt_embeds.dtype

        # repeat along the sequence axis, then view as (batch * num_videos_per_prompt, seq_len, hidden):
        # each prompt's embedding appears num_videos_per_prompt times consecutively.
        _, seq_len, _ = block_state.prompt_embeds.shape
        block_state.prompt_embeds = block_state.prompt_embeds.repeat(
            1, block_state.num_videos_per_prompt, 1
        )
        block_state.prompt_embeds = block_state.prompt_embeds.view(
            block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
        )

        if block_state.negative_prompt_embeds is not None:
            _, seq_len, _ = block_state.negative_prompt_embeds.shape
            block_state.negative_prompt_embeds = (
                block_state.negative_prompt_embeds.repeat(
                    1, block_state.num_videos_per_prompt, 1
                )
            )
            block_state.negative_prompt_embeds = (
                block_state.negative_prompt_embeds.view(
                    block_state.batch_size * block_state.num_videos_per_prompt,
                    seq_len,
                    -1,
                )
            )

        self.set_block_state(state, block_state)
        return components, state
class WanRTStreamingSetTimestepsStep(ModularPipelineBlocks):
    """
    Builds the denoising timestep schedule for streaming generation.

    A shifted sigma grid (shift = 5.0, 1000 points, with the terminal sigma=0 dropped) is scaled to timesteps.
    Then ``num_inference_steps`` entries are picked from it at evenly spaced positions, and position 1000 maps
    to an appended timestep of 0. The declared ``timesteps`` / ``sigmas`` inputs are not read by ``__call__``.
    """

    model_name = "WanRT"

    @property
    def description(self) -> str:
        return "Step that sets the scheduler's timesteps for inference"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=4),
            InputParam("timesteps"),
            InputParam("sigmas"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "timesteps",
                type_hint=torch.Tensor,
                description="The timesteps to use for inference",
            ),
            OutputParam(
                "all_timesteps",
                type_hint=torch.Tensor,
                description="The timesteps to use for inference",
            ),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="The number of denoising steps to perform at inference time",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device
        target_device = components.transformer.device

        # Shifted sigma grid on 1000 points (the final sigma=0 sample is dropped).
        shift = 5.0
        base_sigmas = torch.linspace(1.0, 0.0, 1001)[:-1]
        shifted_sigmas = shift * base_sigmas / (1 + (shift - 1) * base_sigmas)
        full_schedule = shifted_sigmas.to(target_device) * 1000.0

        # Append a trailing 0 so that index 1000 selects t=0.
        padded_schedule = torch.cat(
            [full_schedule, torch.tensor([0], device=target_device)]
        )

        # Evenly spaced positions from 1000 down to 0, truncated to integers.
        step_positions = torch.linspace(
            1000, 0, block_state.num_inference_steps, dtype=torch.float32
        ).to(torch.long)

        block_state.timesteps = padded_schedule[1000 - step_positions]
        block_state.all_timesteps = full_schedule
        block_state.sigmas = shifted_sigmas
        self.set_block_state(state, block_state)
        return components, state
class WanRTStreamingPrepareLatentsStep(ModularPipelineBlocks):
    """
    Allocates the latent buffers for streaming text-to-video generation.

    ``init_latents`` is drawn with ``randn_tensor`` (batch size 1) unless already provided. ``final_latents``
    defaults to a zero tensor of the same shape. The frame count is ``num_blocks`` times the
    ``num_frames_per_block`` pipeline config. The ``num_frames_per_block`` and ``dtype`` block inputs are not
    read here; the config value and the transformer's dtype are used instead.
    """

    model_name = "WanRT"

    @property
    def description(self) -> str:
        return "Prepare latents step that prepares the latents for the text-to-video generation process"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLWan),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("num_frames_per_block", 3)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("num_blocks", type_hint=int),
            InputParam("num_frames_per_block", type_hint=int),
            InputParam("latents", type_hint=Optional[torch.Tensor]),
            InputParam("init_latents", type_hint=Optional[torch.Tensor]),
            InputParam("final_latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_videos_per_prompt", type_hint=int, default=1),
            InputParam("generator"),
            InputParam(
                "dtype",
                type_hint=torch.dtype,
                description="The dtype of the model inputs",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process",
            ),
            OutputParam(
                "init_latents",
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process",
            ),
            OutputParam(
                "final_latents",
                type_hint=torch.Tensor,
            ),
        ]

    @staticmethod
    def check_inputs(components, block_state):
        """Raise ``ValueError`` unless height/width (when set) are multiples of the VAE spatial factor."""
        factor = components.vae_scale_factor_spatial
        for size in (block_state.height, block_state.width):
            if size is not None and size % factor != 0:
                raise ValueError(
                    f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
                )

    @staticmethod
    def prepare_latents(
        components,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 352,
        width: int = 640,
        num_blocks: int = 9,
        num_frames_per_block: int = 3,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Return ``latents`` moved to ``device``/``dtype`` if given. Otherwise draw a fresh
        ``(batch, channels, num_blocks * num_frames_per_block, H // s, W // s)`` tensor with ``randn_tensor``
        on the transformer's device, where ``s`` is the VAE spatial scale factor.
        """
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        spatial_factor = components.vae_scale_factor_spatial
        shape = (
            batch_size,
            num_channels_latents,
            num_blocks * num_frames_per_block,
            int(height) // spatial_factor,
            int(width) // spatial_factor,
        )

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        return randn_tensor(
            shape,
            generator=generator,
            device=components.transformer.device,
            dtype=dtype,
        )

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)

        # Fill in pipeline defaults for unset (falsy) sizes before validating them.
        block_state.height = block_state.height or components.default_height
        block_state.width = block_state.width or components.default_width
        block_state.device = components._execution_device
        block_state.num_channels_latents = components.num_channels_latents
        self.check_inputs(components, block_state)

        block_state.init_latents = self.prepare_latents(
            components,
            batch_size=1,
            num_channels_latents=block_state.num_channels_latents,
            height=block_state.height,
            width=block_state.width,
            num_blocks=block_state.num_blocks,
            num_frames_per_block=components.config.num_frames_per_block,
            dtype=components.transformer.dtype,
            device=block_state.device,
            generator=block_state.generator,
            latents=block_state.init_latents,
        )

        if block_state.final_latents is None:
            block_state.final_latents = torch.zeros_like(
                block_state.init_latents, device=components.transformer.device
            )

        self.set_block_state(state, block_state)
        return components, state
class WanRTStreamingExtractBlockLatentsStep(ModularPipelineBlocks):
    """
    Extracts a single block of latents from the full video buffer for streaming generation.

    This block slices the ``init_latents`` buffer along the frame dimension (dim 2) to get the current block's
    latents, frames ``[block_idx * num_frames_per_block, (block_idx + 1) * num_frames_per_block)``. It also
    records the block's starting frame index. The ``init_latents`` buffer should be created beforehand, e.g.
    by ``WanRTStreamingPrepareLatentsStep``. ``final_latents`` is declared as an input but is not read here.
    """

    model_name = "WanRT"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def description(self) -> str:
        # Fixed: the code slices init_latents, not final_latents.
        return (
            "Extracts a single block from the full latent buffer for streaming generation. "
            "Slices init_latents based on block_idx to get current block's latents."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "final_latents",
                required=True,
                type_hint=torch.Tensor,
                description="Full latent buffer [B, C, total_frames, H, W]",
            ),
            InputParam(
                "init_latents",
                required=True,
                type_hint=torch.Tensor,
                description="Full latent buffer [B, C, total_frames, H, W]",
            ),
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Full latent buffer [B, C, total_frames, H, W]",
            ),
            InputParam(
                "block_idx",
                required=True,
                type_hint=int,
                default=0,
                description="Current block index to process",
            ),
            InputParam(
                "num_frames_per_block",
                required=True,
                type_hint=int,
                default=3,
                description="Number of frames per block",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latents for current block [B, C, num_frames_per_block, H, W]",
            ),
            OutputParam(
                "current_start_frame",
                type_hint=int,
                description="Starting frame index for current block",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        num_frames_per_block = block_state.num_frames_per_block
        block_idx = block_state.block_idx

        # Calculate frame range for current block
        start_frame = block_idx * num_frames_per_block
        end_frame = start_frame + num_frames_per_block

        # Extract the current block from init_latents ([B, C, total_frames, H, W]) along the frame axis (dim=2).
        # NOTE(review): an out-of-range block_idx yields a short/empty slice rather than an error.
        block_state.latents = block_state.init_latents[
            :, :, start_frame:end_frame, :, :
        ]
        block_state.current_start_frame = start_frame

        self.set_block_state(state, block_state)
        return components, state
class WanRTStreamingSetupKVCache(ModularPipelineBlocks):
    """
    Initializes KV cache and cross-attention cache for streaming generation.

    The caches are allocated (or zeroed and reused) with batch size 1 and the transformer's dtype/device.
    The cross-attention cache is only (re)initialized when it is missing or ``update_prompt_embeds`` is truthy.
    The KV cache length in frames is ``kv_cache_num_frames + num_frames_per_block`` from the pipeline config.
    Every transformer block's ``self_attn`` gets ``local_attn_size = -1`` and
    ``num_frame_per_block = num_frames_per_block``. The ``local_attn_size`` and ``dtype`` inputs are not read.
    Mirrors the cache initialization logic from causal_inference.py.
    """

    model_name = "WanRT"

    @property
    def description(self) -> str:
        return (
            "Initializes KV cache and cross-attention cache for streaming generation. "
            "Creates persistent caches that will be reused across all blocks."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("transformer", torch.nn.Module),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec("kv_cache_num_frames", 3),
            ConfigSpec("num_frames_per_block", 3),
            ConfigSpec("frame_seq_length", 1560),
            ConfigSpec("frame_cache_len", 9),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "kv_cache",
                required=False,
                type_hint=Optional[List[Dict]],
                description="Existing KV cache. If provided and shape matches, will be zeroed instead of recreated.",
            ),
            InputParam(
                "crossattn_cache",
                required=False,
                type_hint=Optional[List[Dict]],
                description="Existing cross-attention cache. If provided and shape matches, will be zeroed.",
            ),
            InputParam(
                "local_attn_size",
                required=False,
                type_hint=int,
                default=-1,
                description="Local attention size for computing KV cache size. -1 uses default (32760).",
            ),
            InputParam(
                "dtype",
                required=False,
                type_hint=torch.dtype,
                description="Data type for caches (defaults to bfloat16)",
            ),
            InputParam(
                "update_prompt_embeds",
                required=False,
                description="Flag to reinitialize prompt embeds if they are updated.",
                default=False,
            ),
        ]

    @property
    def outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "kv_cache",
                type_hint=List[Dict],
                description="Initialized KV cache (list of dicts per transformer block)",
            ),
            OutputParam(
                "crossattn_cache",
                type_hint=List[Dict],
                description="Initialized cross-attention cache",
            ),
            OutputParam(
                "local_attn_size",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        transformer = components.transformer
        config = components.config
        batch_size = 1  # Streaming always uses batch_size=1

        # Hold on to whatever caches are already in state so they can be reused.
        previous_kv_cache = block_state.kv_cache
        previous_crossattn_cache = block_state.crossattn_cache

        # Rebuild the cross-attention cache only when it is missing or the caller flags new prompt embeddings.
        if previous_crossattn_cache is None or block_state.update_prompt_embeds:
            block_state.crossattn_cache = _initialize_crossattn_cache(
                components,
                previous_crossattn_cache,
                batch_size,
                transformer.dtype,
                transformer.device,
            )

        # KV cache length in frames: cached context frames plus the frames of the block being generated.
        block_state.local_attn_size = (
            config.kv_cache_num_frames + config.num_frames_per_block
        )

        # Configure every attention layer in a single pass.
        for transformer_block in transformer.blocks:
            transformer_block.self_attn.local_attn_size = -1
            transformer_block.self_attn.num_frame_per_block = config.num_frames_per_block

        block_state.kv_cache = _initialize_kv_cache(
            components,
            previous_kv_cache,
            batch_size,
            transformer.dtype,
            transformer.device,
            block_state.local_attn_size,
            config.frame_seq_length,
        )

        self.set_block_state(state, block_state)
        return components, state
class WanRTStreamingRecomputeKVCache(ModularPipelineBlocks):
    """
    Re-runs the transformer over context frames to refill the attention caches before a new block.

    For ``block_idx == 0`` this is a no-op. Otherwise it does the following:
      1. Builds a context clip from the already generated ``final_latents``.
      2. Prepends a VAE re-encoding of ``frame_cache_context[0]``.
      3. Installs a block-wise causal mask on the transformer.
      4. Calls the transformer with timestep 0 for every context frame, passing ``kv_cache`` and
         ``crossattn_cache``. Presumably the transformer updates these in place; its return value is
         discarded.
    """

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Current block latents [B, C, num_frames_per_block, H, W]",
            ),
            InputParam(
                "num_frames_per_block",
                type_hint=int,
                description="Number of frames per block",
            ),
            InputParam(
                "block_idx",
                type_hint=int,
                description="Current block index to process",
            ),
            InputParam(
                "block_mask",
                description="Block-wise causal attention mask",
            ),
            InputParam(
                "current_start_frame",
                type_hint=int,
                description="Starting frame index for current block",
            ),
            InputParam(
                "videos",
                type_hint=torch.Tensor,
                description="Video frames for context encoding",
            ),
            InputParam(
                "final_latents",
                type_hint=torch.Tensor,
                description="Full latent buffer [B, C, total_frames, H, W]",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Text embeddings to guide generation",
            ),
            InputParam(
                "kv_cache",
                type_hint=torch.Tensor,
                description="Key-value cache for attention",
            ),
            InputParam(
                "crossattn_cache",
                type_hint=torch.Tensor,
                description="Cross-attention cache",
            ),
            InputParam(
                "encoder_cache",
                description="Encoder feature cache",
            ),
            InputParam(
                "frame_cache_context",
                description="Cached context frames for reencoding",
            ),
            InputParam(
                "local_attn_size",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("seq_length", 32760)]

    def prepare_latents(self, components, block_state):
        """
        VAE-encode ``block_state.frame_cache_context[0]`` and normalize it into latent space.

        Uses the encoder's deterministic mode (``sample_mode="argmax"``). The result is then normalized per
        channel as ``(latents - latents_mean) / latents_std`` from the VAE config and cast to the transformer's
        dtype.
        """
        # Match the VAE's own dtype; the previous hard-coded `.half()` mismatched a bf16/fp32 VAE.
        frames = block_state.frame_cache_context[0].to(dtype=components.vae.dtype)
        # NOTE(review): presumably resets the VAE's encoder feature cache; 55 is kept from the original code.
        components.vae._enc_feat_map = [None] * 55
        latents = retrieve_latents(components.vae.encode(frames), sample_mode="argmax")
        latents_mean = (
            torch.tensor(components.vae.config.latents_mean)
            .view(1, components.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
            1, components.vae.config.z_dim, 1, 1, 1
        ).to(latents.device, latents.dtype)
        latents = (latents - latents_mean) * latents_std
        return latents.to(components.transformer.dtype)

    @staticmethod
    def _last_frames(frames, count):
        """Return the last ``count`` frames along dim 2; an empty slice when ``count <= 0``."""
        # `frames[:, :, -count:]` would return *every* frame when count == 0 (since -0 == 0).
        if count <= 0:
            return frames[:, :, frames.shape[2]:]
        return frames[:, :, max(frames.shape[2] - count, 0):]

    def get_context_frames(self, components, block_state):
        """
        Assemble the context clip used to refill the KV cache.

        Takes up to ``kv_cache_num_frames - 1`` of the most recent generated frames after frame 0 of
        ``final_latents[:, :, :current_start_frame]``. While ``(block_idx - 1) * num_frames_per_block`` is below
        ``kv_cache_num_frames``, frame 0 itself is also kept in front. The re-encoded first frame from
        :meth:`prepare_latents` is then prepended along dim 2.
        """
        current_kv_cache_num_frames = components.config.kv_cache_num_frames
        generated = block_state.final_latents[
            :, :, : block_state.current_start_frame
        ]
        if (
            block_state.block_idx - 1
        ) * block_state.num_frames_per_block < current_kv_cache_num_frames:
            if current_kv_cache_num_frames == 1:
                context_frames = generated[:, :, :1]
            else:
                context_frames = torch.cat(
                    (
                        generated[:, :, :1],
                        self._last_frames(
                            generated[:, :, 1:], current_kv_cache_num_frames - 1
                        ),
                    ),
                    dim=2,
                )
        else:
            # Fixed: with kv_cache_num_frames == 1 this previously sliced `[-0:]` and kept every frame.
            context_frames = self._last_frames(
                generated[:, :, 1:], current_kv_cache_num_frames - 1
            )
        first_frame_latent = self.prepare_latents(components, block_state)
        # Match the current block latents' device and dtype before concatenating.
        first_frame_latent = first_frame_latent.to(block_state.latents)
        context_frames = torch.cat((first_frame_latent, context_frames), dim=2)
        return context_frames

    @torch.no_grad()
    def __call__(self, components, state):
        block_state = self.get_block_state(state)
        # The first block has no generated history to re-encode.
        if block_state.block_idx == 0:
            return components, state

        transformer = components.transformer
        frame_seq_length = components.config.frame_seq_length
        start_frame = min(
            block_state.current_start_frame, components.config.kv_cache_num_frames
        )
        context_frames = self.get_context_frames(components, block_state)

        block_state.block_mask = transformer._prepare_blockwise_causal_attn_mask(
            transformer.device,
            num_frames=context_frames.shape[2],
            frame_seqlen=frame_seq_length,
            num_frame_per_block=block_state.num_frames_per_block,
            local_attn_size=-1,
        )
        transformer.block_mask = block_state.block_mask

        # Context frames are already denoised, so every frame is fed with timestep 0.
        context_timestep = torch.zeros(
            (context_frames.shape[0], context_frames.shape[2]),
            device=transformer.device,
            dtype=torch.int64,
        )
        try:
            transformer(
                x=context_frames.to(transformer.dtype),
                t=context_timestep,
                context=block_state.prompt_embeds.to(transformer.dtype),
                kv_cache=block_state.kv_cache,
                seq_len=components.config.seq_length,
                crossattn_cache=block_state.crossattn_cache,
                current_start=start_frame * frame_seq_length,
                cache_start=None,
            )
        finally:
            # Always clear the mask so a failed pass cannot leak it into later forward calls.
            transformer.block_mask = None

        # NOTE(review): set_block_state is not called, so block_state.block_mask is not written back to state.
        return components, state
class WanRTStreamingBeforeDenoiseStep(SequentialPipelineBlocks):
    """
    Sequential composition of the streaming before-denoise blocks, run in the order listed in
    ``block_classes`` under the names in ``block_names``.
    """

    block_classes = [
        WanRTStreamingSetTimestepsStep,
        WanRTStreamingPrepareLatentsStep,
        WanRTStreamingExtractBlockLatentsStep,
        WanRTStreamingSetupKVCache,
        WanRTStreamingRecomputeKVCache,
    ]
    block_names = [
        "set_timesteps",
        "prepare_latents",
        "extract_block_init_latents",
        "setup_kv_cache",
        "recompute_kv_cache",
    ]

    @property
    def description(self):
        # Fixed: previously listed blocks that are not part of this sequence and omitted three that are.
        return (
            "Before denoise step that prepares the inputs for the denoise step.\n"
            + "This is a sequential pipeline block composed of:\n"
            + " - `WanRTStreamingSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanRTStreamingPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanRTStreamingExtractBlockLatentsStep` is used to extract the current block's latents\n"
            + " - `WanRTStreamingSetupKVCache` is used to initialize the KV and cross-attention caches\n"
            + " - `WanRTStreamingRecomputeKVCache` is used to recompute the KV cache from context frames\n"
        )