import inspect |
|
|
from typing import List, Optional, Union, Dict |
|
|
|
|
|
import torch |
|
|
|
|
|
from diffusers import AutoencoderKLWan |
|
|
from diffusers.schedulers import UniPCMultistepScheduler |
|
|
from diffusers.utils import logging |
|
|
from diffusers.utils.torch_utils import randn_tensor |
|
|
from diffusers.modular_pipelines import ( |
|
|
ModularPipeline, |
|
|
ModularPipelineBlocks, |
|
|
SequentialPipelineBlocks, |
|
|
PipelineState, |
|
|
) |
|
|
from diffusers.modular_pipelines.modular_pipeline_utils import ( |
|
|
ComponentSpec, |
|
|
ConfigSpec, |
|
|
InputParam, |
|
|
OutputParam, |
|
|
) |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
def retrieve_timesteps( |
|
|
scheduler, |
|
|
num_inference_steps: Optional[int] = None, |
|
|
device: Optional[Union[str, torch.device]] = None, |
|
|
timesteps: Optional[List[int]] = None, |
|
|
sigmas: Optional[List[float]] = None, |
|
|
**kwargs, |
|
|
): |
|
|
r""" |
|
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles |
|
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. |
|
|
|
|
|
Args: |
|
|
scheduler (`SchedulerMixin`): |
|
|
The scheduler to get timesteps from. |
|
|
num_inference_steps (`int`): |
|
|
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` |
|
|
must be `None`. |
|
|
device (`str` or `torch.device`, *optional*): |
|
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. |
|
|
timesteps (`List[int]`, *optional*): |
|
|
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, |
|
|
`num_inference_steps` and `sigmas` must be `None`. |
|
|
sigmas (`List[float]`, *optional*): |
|
|
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, |
|
|
`num_inference_steps` and `timesteps` must be `None`. |
|
|
|
|
|
Returns: |
|
|
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the |
|
|
second element is the number of inference steps. |
|
|
""" |
|
|
if timesteps is not None and sigmas is not None: |
|
|
raise ValueError( |
|
|
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" |
|
|
) |
|
|
if timesteps is not None: |
|
|
accepts_timesteps = "timesteps" in set( |
|
|
inspect.signature(scheduler.set_timesteps).parameters.keys() |
|
|
) |
|
|
if not accepts_timesteps: |
|
|
raise ValueError( |
|
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
|
|
f" timestep schedules. Please check whether you are using the correct scheduler." |
|
|
) |
|
|
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) |
|
|
timesteps = scheduler.timesteps |
|
|
num_inference_steps = len(timesteps) |
|
|
elif sigmas is not None: |
|
|
accept_sigmas = "sigmas" in set( |
|
|
inspect.signature(scheduler.set_timesteps).parameters.keys() |
|
|
) |
|
|
if not accept_sigmas: |
|
|
raise ValueError( |
|
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
|
|
f" sigmas schedules. Please check whether you are using the correct scheduler." |
|
|
) |
|
|
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) |
|
|
timesteps = scheduler.timesteps |
|
|
num_inference_steps = len(timesteps) |
|
|
else: |
|
|
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) |
|
|
timesteps = scheduler.timesteps |
|
|
return timesteps, num_inference_steps |
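# Example usage (hypothetical scheduler instance, shown only for illustration):
#   timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=4, device="cuda")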
|
|
|
|
|
|
|
|
def retrieve_latents( |
|
|
encoder_output: torch.Tensor, |
|
|
generator: Optional[torch.Generator] = None, |
|
|
sample_mode: str = "sample", |
|
|
): |
|
|
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": |
|
|
return encoder_output.latent_dist.sample(generator) |
|
|
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": |
|
|
return encoder_output.latent_dist.mode() |
|
|
elif hasattr(encoder_output, "latents"): |
|
|
return encoder_output.latents |
|
|
else: |
|
|
raise AttributeError("Could not access latents of provided encoder_output") |
|
|
|
|
|
|
|
|
def _initialize_kv_cache( |
|
|
components: ModularPipeline, |
|
|
kv_cache_existing: Optional[List[Dict]], |
|
|
batch_size: int, |
|
|
dtype: torch.dtype, |
|
|
device: torch.device, |
|
|
local_attn_size: int, |
|
|
frame_seq_length: int, |
|
|
): |
|
|
""" |
|
|
Initialize a Per-GPU KV cache for the Wan model. |
|
|
Mirrors causal_inference.py:279-313 |
|
|
""" |
|
|
kv_cache = [] |
|
|
|
|
|
|
|
|
if local_attn_size != -1: |
|
|
|
|
|
kv_cache_size = local_attn_size * frame_seq_length |
|
|
else: |
|
|
|
|
|
kv_cache_size = 32760 |
|
|
|
|
|
|
|
|
num_transformer_blocks = len(components.transformer.blocks) |
|
|
num_heads = components.transformer.config.num_heads |
|
|
dim = components.transformer.config.dim |
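    # Each transformer block gets its own K/V buffer of shape
    # [batch, kv_cache_size, num_heads, head_dim]. If an existing cache with matching
    # shapes is passed in, it is zeroed in place below instead of being reallocated.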
|
|
k_shape = [batch_size, kv_cache_size, num_heads, dim // num_heads] |
|
|
v_shape = [batch_size, kv_cache_size, num_heads, dim // num_heads] |
|
|
|
|
|
|
|
|
if ( |
|
|
kv_cache_existing |
|
|
and len(kv_cache_existing) > 0 |
|
|
and list(kv_cache_existing[0]["k"].shape) == k_shape |
|
|
and list(kv_cache_existing[0]["v"].shape) == v_shape |
|
|
): |
|
|
for i in range(num_transformer_blocks): |
|
|
kv_cache_existing[i]["k"].zero_() |
|
|
kv_cache_existing[i]["v"].zero_() |
|
|
kv_cache_existing[i]["global_end_index"] = 0 |
|
|
kv_cache_existing[i]["local_end_index"] = 0 |
|
|
return kv_cache_existing |
|
|
else: |
|
|
|
|
|
for _ in range(num_transformer_blocks): |
|
|
kv_cache.append( |
|
|
{ |
|
|
"k": torch.zeros(k_shape, dtype=dtype, device=device).contiguous(), |
|
|
"v": torch.zeros(v_shape, dtype=dtype, device=device).contiguous(), |
|
|
"global_end_index": 0, |
|
|
"local_end_index": 0, |
|
|
} |
|
|
) |
|
|
return kv_cache |
|
|
|
|
|
|
|
|
def _initialize_crossattn_cache( |
|
|
components: ModularPipeline, |
|
|
crossattn_cache_existing: Optional[List[Dict]], |
|
|
batch_size: int, |
|
|
dtype: torch.dtype, |
|
|
device: torch.device, |
|
|
): |
|
|
""" |
|
|
Initialize a Per-GPU cross-attention cache for the Wan model. |
|
|
Mirrors causal_inference.py:315-338 |
|
|
""" |
|
|
crossattn_cache = [] |
|
|
|
|
|
|
|
|
num_transformer_blocks = len(components.transformer.blocks) |
|
|
num_heads = components.transformer.config.num_heads |
|
|
dim = components.transformer.config.dim |
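    # Cross-attention keys/values are cached for the fixed text-context length of 512 tokens.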
|
|
k_shape = [batch_size, 512, num_heads, dim // num_heads] |
|
|
v_shape = [batch_size, 512, num_heads, dim // num_heads] |
|
|
|
|
|
|
|
|
if ( |
|
|
crossattn_cache_existing |
|
|
and len(crossattn_cache_existing) > 0 |
|
|
and list(crossattn_cache_existing[0]["k"].shape) == k_shape |
|
|
and list(crossattn_cache_existing[0]["v"].shape) == v_shape |
|
|
): |
|
|
for i in range(num_transformer_blocks): |
|
|
crossattn_cache_existing[i]["k"].zero_() |
|
|
crossattn_cache_existing[i]["v"].zero_() |
|
|
crossattn_cache_existing[i]["is_init"] = False |
|
|
return crossattn_cache_existing |
|
|
else: |
|
|
|
|
|
for _ in range(num_transformer_blocks): |
|
|
crossattn_cache.append( |
|
|
{ |
|
|
"k": torch.zeros(k_shape, dtype=dtype, device=device).contiguous(), |
|
|
"v": torch.zeros(v_shape, dtype=dtype, device=device).contiguous(), |
|
|
"is_init": False, |
|
|
} |
|
|
) |
|
|
return crossattn_cache |
|
|
|
|
|
|
|
|
class WanInputStep(ModularPipelineBlocks): |
|
|
model_name = "WanRT" |
|
|
|
|
|
@property |
|
|
def description(self) -> str: |
|
|
return ( |
|
|
"Input processing step that:\n" |
|
|
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" |
|
|
" 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n" |
|
|
"All input tensors are expected to have either batch_size=1 or match the batch_size\n" |
|
|
"of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" |
|
|
"have a final batch_size of batch_size * num_videos_per_prompt." |
|
|
) |
|
|
|
|
|
@property |
|
|
def inputs(self) -> List[InputParam]: |
|
|
return [ |
|
|
InputParam("num_videos_per_prompt", default=1), |
|
|
InputParam( |
|
|
"prompt_embeds", |
|
|
required=True, |
|
|
type_hint=torch.Tensor, |
|
|
description="Pre-generated text embeddings. Can be generated from text_encoder step.", |
|
|
), |
|
|
InputParam( |
|
|
"negative_prompt_embeds", |
|
|
type_hint=torch.Tensor, |
|
|
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", |
|
|
), |
|
|
] |
|
|
|
|
|
@property |
|
|
    def intermediate_outputs(self) -> List[OutputParam]:
|
|
return [ |
|
|
OutputParam( |
|
|
"batch_size", |
|
|
type_hint=int, |
|
|
description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt", |
|
|
), |
|
|
OutputParam( |
|
|
"dtype", |
|
|
type_hint=torch.dtype, |
|
|
description="Data type of model tensor inputs (determined by `prompt_embeds`)", |
|
|
), |
|
|
OutputParam( |
|
|
"prompt_embeds", |
|
|
type_hint=torch.Tensor, |
|
|
kwargs_type="denoiser_input_fields", |
|
|
description="text embeddings used to guide the image generation", |
|
|
), |
|
|
OutputParam( |
|
|
"negative_prompt_embeds", |
|
|
type_hint=torch.Tensor, |
|
|
kwargs_type="denoiser_input_fields", |
|
|
description="negative text embeddings used to guide the image generation", |
|
|
), |
|
|
] |
|
|
|
|
|
def check_inputs(self, components, block_state): |
|
|
if ( |
|
|
block_state.prompt_embeds is not None |
|
|
and block_state.negative_prompt_embeds is not None |
|
|
): |
|
|
if ( |
|
|
block_state.prompt_embeds.shape |
|
|
!= block_state.negative_prompt_embeds.shape |
|
|
): |
|
|
raise ValueError( |
|
|
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" |
|
|
f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" |
|
|
f" {block_state.negative_prompt_embeds.shape}." |
|
|
) |
|
|
|
|
|
@torch.no_grad() |
|
|
def __call__( |
|
|
self, components: ModularPipeline, state: PipelineState |
|
|
) -> PipelineState: |
|
|
block_state = self.get_block_state(state) |
|
|
self.check_inputs(components, block_state) |
|
|
|
|
|
block_state.batch_size = block_state.prompt_embeds.shape[0] |
|
|
block_state.dtype = block_state.prompt_embeds.dtype |
|
|
|
|
|
_, seq_len, _ = block_state.prompt_embeds.shape |
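        # Duplicate each prompt's embeddings num_videos_per_prompt times along the batch dimension.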
|
|
block_state.prompt_embeds = block_state.prompt_embeds.repeat( |
|
|
1, block_state.num_videos_per_prompt, 1 |
|
|
) |
|
|
block_state.prompt_embeds = block_state.prompt_embeds.view( |
|
|
block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 |
|
|
) |
|
|
|
|
|
if block_state.negative_prompt_embeds is not None: |
|
|
_, seq_len, _ = block_state.negative_prompt_embeds.shape |
|
|
block_state.negative_prompt_embeds = ( |
|
|
block_state.negative_prompt_embeds.repeat( |
|
|
1, block_state.num_videos_per_prompt, 1 |
|
|
) |
|
|
) |
|
|
block_state.negative_prompt_embeds = ( |
|
|
block_state.negative_prompt_embeds.view( |
|
|
block_state.batch_size * block_state.num_videos_per_prompt, |
|
|
seq_len, |
|
|
-1, |
|
|
) |
|
|
) |
|
|
|
|
|
self.set_block_state(state, block_state) |
|
|
|
|
|
return components, state |
|
|
|
|
|
|
|
|
class WanRTStreamingSetTimestepsStep(ModularPipelineBlocks): |
|
|
model_name = "WanRT" |
|
|
|
|
|
@property |
|
|
def expected_components(self) -> List[ComponentSpec]: |
|
|
return [ |
|
|
ComponentSpec("scheduler", UniPCMultistepScheduler), |
|
|
] |
|
|
|
|
|
@property |
|
|
def description(self) -> str: |
|
|
return "Step that sets the scheduler's timesteps for inference" |
|
|
|
|
|
@property |
|
|
def inputs(self) -> List[InputParam]: |
|
|
return [ |
|
|
InputParam("num_inference_steps", default=4), |
|
|
InputParam("timesteps"), |
|
|
InputParam("sigmas"), |
|
|
] |
|
|
|
|
|
@property |
|
|
def intermediate_outputs(self) -> List[OutputParam]: |
|
|
return [ |
|
|
OutputParam( |
|
|
"timesteps", |
|
|
type_hint=torch.Tensor, |
|
|
description="The timesteps to use for inference", |
|
|
), |
|
|
OutputParam( |
|
|
"all_timesteps", |
|
|
type_hint=torch.Tensor, |
|
|
description="The timesteps to use for inference", |
|
|
), |
|
|
OutputParam( |
|
|
"num_inference_steps", |
|
|
type_hint=int, |
|
|
description="The number of denoising steps to perform at inference time", |
|
|
), |
|
|
] |
|
|
|
|
|
@torch.no_grad() |
|
|
def __call__( |
|
|
self, components: ModularPipeline, state: PipelineState |
|
|
) -> PipelineState: |
|
|
block_state = self.get_block_state(state) |
|
|
block_state.device = components._execution_device |
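        # Build the shifted flow-matching sigma schedule (shift=5.0), then pick
        # `num_inference_steps` evenly spaced timesteps from the 1000-step schedule
        # (with a trailing 0 appended so the final step maps to t=0).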
|
|
|
|
|
shift = 5.0 |
|
|
sigmas = torch.linspace(1.0, 0.0, 1001)[:-1] |
|
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) |
|
|
|
|
|
timesteps = sigmas.to(components.transformer.device) * 1000.0 |
|
|
zero_padded_timesteps = torch.cat( |
|
|
[ |
|
|
timesteps, |
|
|
torch.tensor([0], device=components.transformer.device), |
|
|
] |
|
|
) |
|
|
denoising_steps = torch.linspace( |
|
|
1000, 0, block_state.num_inference_steps, dtype=torch.float32 |
|
|
).to(torch.long) |
|
|
|
|
|
block_state.timesteps = zero_padded_timesteps[1000 - denoising_steps] |
|
|
block_state.all_timesteps = timesteps |
|
|
block_state.sigmas = sigmas |
|
|
|
|
|
self.set_block_state(state, block_state) |
|
|
|
|
|
return components, state |
|
|
|
|
|
|
|
|
class WanRTStreamingPrepareLatentsStep(ModularPipelineBlocks): |
|
|
model_name = "WanRT" |
|
|
|
|
|
@property |
|
|
def expected_components(self) -> List[ComponentSpec]: |
|
|
return [ |
|
|
ComponentSpec("vae", AutoencoderKLWan), |
|
|
] |
|
|
|
|
|
@property |
|
|
def expected_configs(self) -> List[ConfigSpec]: |
|
|
return [ConfigSpec("num_frames_per_block", 3)] |
|
|
|
|
|
@property |
|
|
def description(self) -> str: |
|
|
return "Prepare latents step that prepares the latents for the text-to-video generation process" |
|
|
|
|
|
@property |
|
|
def inputs(self) -> List[InputParam]: |
|
|
return [ |
|
|
InputParam("height", type_hint=int), |
|
|
InputParam("width", type_hint=int), |
|
|
InputParam("num_blocks", type_hint=int), |
|
|
InputParam("num_frames_per_block", type_hint=int), |
|
|
InputParam("latents", type_hint=Optional[torch.Tensor]), |
|
|
InputParam("init_latents", type_hint=Optional[torch.Tensor]), |
|
|
InputParam("final_latents", type_hint=Optional[torch.Tensor]), |
|
|
InputParam("num_videos_per_prompt", type_hint=int, default=1), |
|
|
InputParam("generator"), |
|
|
InputParam( |
|
|
"dtype", |
|
|
type_hint=torch.dtype, |
|
|
description="The dtype of the model inputs", |
|
|
), |
|
|
] |
|
|
|
|
|
@property |
|
|
def intermediate_outputs(self) -> List[OutputParam]: |
|
|
return [ |
|
|
OutputParam( |
|
|
"latents", |
|
|
type_hint=torch.Tensor, |
|
|
description="The initial latents to use for the denoising process", |
|
|
), |
|
|
OutputParam( |
|
|
"init_latents", |
|
|
type_hint=torch.Tensor, |
|
|
description="The initial latents to use for the denoising process", |
|
|
), |
|
|
            OutputParam(
                "final_latents",
                type_hint=torch.Tensor,
                description="Buffer holding the latents of all blocks generated so far",
            ),
|
|
] |
|
|
|
|
|
@staticmethod |
|
|
def check_inputs(components, block_state): |
|
|
if ( |
|
|
block_state.height is not None |
|
|
and block_state.height % components.vae_scale_factor_spatial != 0 |
|
|
) or ( |
|
|
block_state.width is not None |
|
|
and block_state.width % components.vae_scale_factor_spatial != 0 |
|
|
): |
|
|
raise ValueError( |
|
|
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def prepare_latents( |
|
|
components, |
|
|
batch_size: int, |
|
|
num_channels_latents: int = 16, |
|
|
height: int = 352, |
|
|
width: int = 640, |
|
|
num_blocks: int = 9, |
|
|
num_frames_per_block: int = 3, |
|
|
dtype: Optional[torch.dtype] = None, |
|
|
device: Optional[torch.device] = None, |
|
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
|
latents: Optional[torch.Tensor] = None, |
|
|
) -> torch.Tensor: |
|
|
if latents is not None: |
|
|
return latents.to(device=device, dtype=dtype) |
|
|
|
|
|
num_latent_frames = num_blocks * num_frames_per_block |
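        # The latent video holds num_blocks * num_frames_per_block latent frames;
        # spatial dims are downscaled by the VAE's spatial scale factor.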
|
|
shape = ( |
|
|
batch_size, |
|
|
num_channels_latents, |
|
|
num_latent_frames, |
|
|
int(height) // components.vae_scale_factor_spatial, |
|
|
int(width) // components.vae_scale_factor_spatial, |
|
|
) |
|
|
if isinstance(generator, list) and len(generator) != batch_size: |
|
|
raise ValueError( |
|
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" |
|
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators." |
|
|
) |
|
|
|
|
|
latents = randn_tensor( |
|
|
shape, |
|
|
generator=generator, |
|
|
device=components.transformer.device, |
|
|
dtype=dtype, |
|
|
) |
|
|
return latents |
|
|
|
|
|
@torch.no_grad() |
|
|
def __call__( |
|
|
self, components: ModularPipeline, state: PipelineState |
|
|
) -> PipelineState: |
|
|
block_state = self.get_block_state(state) |
|
|
|
|
|
block_state.height = block_state.height or components.default_height |
|
|
block_state.width = block_state.width or components.default_width |
|
|
block_state.device = components._execution_device |
|
|
block_state.num_channels_latents = components.num_channels_latents |
|
|
|
|
|
self.check_inputs(components, block_state) |
|
|
|
|
|
block_state.init_latents = self.prepare_latents( |
|
|
components, |
|
|
1, |
|
|
block_state.num_channels_latents, |
|
|
block_state.height, |
|
|
block_state.width, |
|
|
block_state.num_blocks, |
|
|
components.config.num_frames_per_block, |
|
|
components.transformer.dtype, |
|
|
block_state.device, |
|
|
block_state.generator, |
|
|
block_state.init_latents, |
|
|
) |
|
|
if block_state.final_latents is None: |
|
|
block_state.final_latents = torch.zeros_like( |
|
|
block_state.init_latents, device=components.transformer.device |
|
|
) |
|
|
self.set_block_state(state, block_state) |
|
|
|
|
|
return components, state |
|
|
|
|
|
|
|
|
class WanRTStreamingExtractBlockLatentsStep(ModularPipelineBlocks): |
|
|
""" |
|
|
Extracts a single block of latents from the full video buffer for streaming generation. |
|
|
|
|
|
This block simply slices the final_latents buffer to get the current block's latents. |
|
|
    The final_latents buffer should be created beforehand using WanRTStreamingPrepareLatentsStep.
|
|
""" |
|
|
|
|
|
model_name = "WanRT" |
|
|
|
|
|
@property |
|
|
def expected_components(self) -> List[ComponentSpec]: |
|
|
return [] |
|
|
|
|
|
@property |
|
|
def description(self) -> str: |
|
|
return ( |
|
|
"Extracts a single block from the full latent buffer for streaming generation. " |
|
|
"Slices final_latents based on block_idx to get current block's latents." |
|
|
) |
|
|
|
|
|
@property |
|
|
def inputs(self) -> List[InputParam]: |
|
|
return [ |
|
|
InputParam( |
|
|
"final_latents", |
|
|
required=True, |
|
|
type_hint=torch.Tensor, |
|
|
description="Full latent buffer [B, C, total_frames, H, W]", |
|
|
), |
|
|
InputParam( |
|
|
"init_latents", |
|
|
required=True, |
|
|
type_hint=torch.Tensor, |
|
|
description="Full latent buffer [B, C, total_frames, H, W]", |
|
|
), |
|
|
InputParam( |
|
|
"latents", |
|
|
type_hint=torch.Tensor, |
|
|
description="Full latent buffer [B, C, total_frames, H, W]", |
|
|
), |
|
|
InputParam( |
|
|
"block_idx", |
|
|
required=True, |
|
|
type_hint=int, |
|
|
default=0, |
|
|
description="Current block index to process", |
|
|
), |
|
|
InputParam( |
|
|
"num_frames_per_block", |
|
|
required=True, |
|
|
type_hint=int, |
|
|
default=3, |
|
|
description="Number of frames per block", |
|
|
), |
|
|
] |
|
|
|
|
|
@property |
|
|
def intermediate_outputs(self) -> List[OutputParam]: |
|
|
return [ |
|
|
OutputParam( |
|
|
"latents", |
|
|
type_hint=torch.Tensor, |
|
|
description="Latents for current block [B, C, num_frames_per_block, H, W]", |
|
|
), |
|
|
OutputParam( |
|
|
"current_start_frame", |
|
|
type_hint=int, |
|
|
description="Starting frame index for current block", |
|
|
), |
|
|
] |
|
|
|
|
|
@torch.no_grad() |
|
|
def __call__( |
|
|
self, components: ModularPipeline, state: PipelineState |
|
|
) -> PipelineState: |
|
|
block_state = self.get_block_state(state) |
|
|
|
|
|
num_frames_per_block = block_state.num_frames_per_block |
|
|
block_idx = block_state.block_idx |
|
|
|
|
|
|
|
|
start_frame = block_idx * num_frames_per_block |
|
|
end_frame = start_frame + num_frames_per_block |
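        # Slice this block's noise out of the full init_latents buffer. Previously
        # denoised blocks are read back from final_latents when the KV cache is recomputed.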
|
|
|
|
|
|
|
|
|
|
|
|
|
|
block_state.latents = block_state.init_latents[ |
|
|
:, :, start_frame:end_frame, :, : |
|
|
] |
|
|
block_state.current_start_frame = start_frame |
|
|
|
|
|
self.set_block_state(state, block_state) |
|
|
return components, state |
|
|
|
|
|
|
|
|
class WanRTStreamingSetupKVCache(ModularPipelineBlocks): |
|
|
""" |
|
|
Initializes KV cache and cross-attention cache for streaming generation. |
|
|
|
|
|
This block sets up the persistent caches used across all blocks in streaming |
|
|
generation. Mirrors the cache initialization logic from causal_inference.py. |
|
|
Should be called once at the start of streaming generation. |
|
|
""" |
|
|
|
|
|
model_name = "WanRT" |
|
|
|
|
|
@property |
|
|
def expected_components(self) -> List[ComponentSpec]: |
|
|
return [ |
|
|
ComponentSpec("transformer", torch.nn.Module), |
|
|
] |
|
|
|
|
|
@property |
|
|
def expected_configs(self) -> List[ConfigSpec]: |
|
|
return [ |
|
|
ConfigSpec("kv_cache_num_frames", 3), |
|
|
ConfigSpec("num_frames_per_block", 3), |
|
|
ConfigSpec("frame_seq_length", 1560), |
|
|
ConfigSpec("frame_cache_len", 9), |
|
|
] |
|
|
|
|
|
@property |
|
|
def description(self) -> str: |
|
|
return ( |
|
|
"Initializes KV cache and cross-attention cache for streaming generation. " |
|
|
"Creates persistent caches that will be reused across all blocks." |
|
|
) |
|
|
|
|
|
@property |
|
|
def inputs(self) -> List[InputParam]: |
|
|
return [ |
|
|
InputParam( |
|
|
"kv_cache", |
|
|
required=False, |
|
|
type_hint=Optional[List[Dict]], |
|
|
description="Existing KV cache. If provided and shape matches, will be zeroed instead of recreated.", |
|
|
), |
|
|
InputParam( |
|
|
"crossattn_cache", |
|
|
required=False, |
|
|
type_hint=Optional[List[Dict]], |
|
|
description="Existing cross-attention cache. If provided and shape matches, will be zeroed.", |
|
|
), |
|
|
InputParam( |
|
|
"local_attn_size", |
|
|
required=False, |
|
|
type_hint=int, |
|
|
default=-1, |
|
|
description="Local attention size for computing KV cache size. -1 uses default (32760).", |
|
|
), |
|
|
InputParam( |
|
|
"dtype", |
|
|
required=False, |
|
|
type_hint=torch.dtype, |
|
|
description="Data type for caches (defaults to bfloat16)", |
|
|
), |
|
|
InputParam( |
|
|
"update_prompt_embeds", |
|
|
required=False, |
|
|
description="Flag to reinitialize prompt embeds if they are updated.", |
|
|
default=False, |
|
|
), |
|
|
] |
|
|
|
|
|
@property |
|
|
    def intermediate_outputs(self) -> List[OutputParam]:
|
|
return [ |
|
|
OutputParam( |
|
|
"kv_cache", |
|
|
type_hint=List[Dict], |
|
|
description="Initialized KV cache (list of dicts per transformer block)", |
|
|
), |
|
|
OutputParam( |
|
|
"crossattn_cache", |
|
|
type_hint=List[Dict], |
|
|
description="Initialized cross-attention cache", |
|
|
), |
|
|
            OutputParam(
                "local_attn_size",
                type_hint=int,
                description="Local attention window in latent frames (kv_cache_num_frames + num_frames_per_block)",
            ),
|
|
] |
|
|
|
|
|
@torch.no_grad() |
|
|
def __call__( |
|
|
self, components: ModularPipeline, state: PipelineState |
|
|
) -> PipelineState: |
|
|
block_state = self.get_block_state(state) |
|
|
batch_size = 1 |
|
|
|
|
|
|
|
|
kv_cache = block_state.kv_cache |
|
|
crossattn_cache = block_state.crossattn_cache |
|
|
|
|
|
if block_state.crossattn_cache is None or block_state.update_prompt_embeds: |
|
|
block_state.crossattn_cache = _initialize_crossattn_cache( |
|
|
components, |
|
|
crossattn_cache, |
|
|
batch_size, |
|
|
components.transformer.dtype, |
|
|
components.transformer.device, |
|
|
) |
|
|
|
|
|
block_state.local_attn_size = ( |
|
|
components.config.kv_cache_num_frames |
|
|
+ components.config.num_frames_per_block |
|
|
) |
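        # The KV cache holds kv_cache_num_frames context frames plus one new block of
        # frames; _initialize_kv_cache multiplies this by frame_seq_length to size the cache.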
|
|
        for block in components.transformer.blocks:
            block.self_attn.local_attn_size = -1
            block.self_attn.num_frame_per_block = components.config.num_frames_per_block
|
|
|
|
|
block_state.kv_cache = _initialize_kv_cache( |
|
|
components, |
|
|
kv_cache, |
|
|
batch_size, |
|
|
components.transformer.dtype, |
|
|
components.transformer.device, |
|
|
block_state.local_attn_size, |
|
|
components.config.frame_seq_length, |
|
|
) |
|
|
|
|
|
self.set_block_state(state, block_state) |
|
|
return components, state |
|
|
|
|
|
|
|
|
class WanRTStreamingRecomputeKVCache(ModularPipelineBlocks): |
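    model_name = "WanRT"

    @property
    def description(self) -> str:
        return (
            "Recomputes the self-attention KV cache from previously generated context "
            "frames before denoising a new block in streaming generation."
        )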
|
|
@property |
|
|
def inputs(self) -> List[InputParam]: |
|
|
return [ |
|
|
InputParam( |
|
|
"latents", |
|
|
type_hint=torch.Tensor, |
|
|
description="Current block latents [B, C, num_frames_per_block, H, W]", |
|
|
), |
|
|
InputParam( |
|
|
"num_frames_per_block", |
|
|
type_hint=int, |
|
|
description="Number of frames per block", |
|
|
), |
|
|
InputParam( |
|
|
"block_idx", |
|
|
type_hint=int, |
|
|
description="Current block index to process", |
|
|
), |
|
|
InputParam( |
|
|
"block_mask", |
|
|
description="Block-wise causal attention mask", |
|
|
), |
|
|
InputParam( |
|
|
"current_start_frame", |
|
|
type_hint=int, |
|
|
description="Starting frame index for current block", |
|
|
), |
|
|
InputParam( |
|
|
"videos", |
|
|
type_hint=torch.Tensor, |
|
|
description="Video frames for context encoding", |
|
|
), |
|
|
InputParam( |
|
|
"final_latents", |
|
|
type_hint=torch.Tensor, |
|
|
description="Full latent buffer [B, C, total_frames, H, W]", |
|
|
), |
|
|
InputParam( |
|
|
"prompt_embeds", |
|
|
type_hint=torch.Tensor, |
|
|
description="Text embeddings to guide generation", |
|
|
), |
|
|
InputParam( |
|
|
"kv_cache", |
|
|
type_hint=torch.Tensor, |
|
|
description="Key-value cache for attention", |
|
|
), |
|
|
InputParam( |
|
|
"crossattn_cache", |
|
|
type_hint=torch.Tensor, |
|
|
description="Cross-attention cache", |
|
|
), |
|
|
InputParam( |
|
|
"encoder_cache", |
|
|
description="Encoder feature cache", |
|
|
), |
|
|
InputParam( |
|
|
"frame_cache_context", |
|
|
description="Cached context frames for reencoding", |
|
|
), |
|
|
InputParam( |
|
|
"local_attn_size", |
|
|
), |
|
|
] |
|
|
|
|
|
@property |
|
|
def expected_configs(self) -> List[ConfigSpec]: |
|
|
return [ConfigSpec("seq_length", 32760)] |
|
|
|
|
|
def prepare_latents(self, components, block_state): |
|
|
frames = block_state.frame_cache_context[0].half() |
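        # Reset the VAE encoder's internal feature-map cache (private attribute), then
        # re-encode the cached context frames and take the posterior mode ("argmax")
        # as the anchor-frame latent.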
|
|
|
|
|
components.vae._enc_feat_map = [None] * 55 |
|
|
latents = retrieve_latents(components.vae.encode(frames), sample_mode="argmax") |
|
|
latents_mean = ( |
|
|
torch.tensor(components.vae.config.latents_mean) |
|
|
.view(1, components.vae.config.z_dim, 1, 1, 1) |
|
|
.to(latents.device, latents.dtype) |
|
|
) |
|
|
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( |
|
|
1, components.vae.config.z_dim, 1, 1, 1 |
|
|
).to(latents.device, latents.dtype) |
|
|
latents = (latents - latents_mean) * latents_std |
|
|
|
|
|
return latents.to(components.transformer.dtype) |
|
|
|
|
|
def get_context_frames(self, components, block_state): |
|
|
current_kv_cache_num_frames = components.config.kv_cache_num_frames |
|
|
context_frames = block_state.final_latents[ |
|
|
:, :, : block_state.current_start_frame |
|
|
] |
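        # Keep at most kv_cache_num_frames context frames: retain the first (anchor)
        # frame, re-encoded from cached pixels once enough blocks exist, and fill the
        # rest with the most recent frames.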
|
|
|
|
|
if ( |
|
|
block_state.block_idx - 1 |
|
|
) * block_state.num_frames_per_block < current_kv_cache_num_frames: |
|
|
if current_kv_cache_num_frames == 1: |
|
|
context_frames = context_frames[:, :, :1] |
|
|
else: |
|
|
context_frames = torch.cat( |
|
|
( |
|
|
context_frames[:, :, :1], |
|
|
context_frames[:, :, 1:][ |
|
|
:, :, -current_kv_cache_num_frames + 1 : |
|
|
], |
|
|
), |
|
|
dim=2, |
|
|
) |
|
|
else: |
|
|
context_frames = context_frames[:, :, 1:][ |
|
|
:, :, -current_kv_cache_num_frames + 1 : |
|
|
] |
|
|
first_frame_latent = self.prepare_latents(components, block_state) |
|
|
first_frame_latent = first_frame_latent.to(block_state.latents) |
|
|
context_frames = torch.cat((first_frame_latent, context_frames), dim=2) |
|
|
|
|
|
return context_frames |
|
|
|
|
|
def __call__(self, components, state): |
|
|
block_state = self.get_block_state(state) |
|
|
if block_state.block_idx == 0: |
|
|
return components, state |
|
|
|
|
|
start_frame = min( |
|
|
block_state.current_start_frame, components.config.kv_cache_num_frames |
|
|
) |
|
|
context_frames = self.get_context_frames(components, block_state) |
|
|
block_state.block_mask = ( |
|
|
components.transformer._prepare_blockwise_causal_attn_mask( |
|
|
components.transformer.device, |
|
|
num_frames=context_frames.shape[2], |
|
|
frame_seqlen=components.config.frame_seq_length, |
|
|
num_frame_per_block=block_state.num_frames_per_block, |
|
|
local_attn_size=-1, |
|
|
) |
|
|
) |
|
|
components.transformer.block_mask = block_state.block_mask |
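        # Re-run the transformer over the context frames with t=0 (fully denoised) to
        # repopulate the self-attention KV cache; the block-wise causal mask is only
        # needed for this pass and is cleared afterwards.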
|
|
context_timestep = torch.zeros( |
|
|
(context_frames.shape[0], context_frames.shape[2]), |
|
|
device=components.transformer.device, |
|
|
dtype=torch.int64, |
|
|
) |
|
|
components.transformer( |
|
|
x=context_frames.to(components.transformer.dtype), |
|
|
t=context_timestep, |
|
|
context=block_state.prompt_embeds.to(components.transformer.dtype), |
|
|
kv_cache=block_state.kv_cache, |
|
|
seq_len=components.config.seq_length, |
|
|
crossattn_cache=block_state.crossattn_cache, |
|
|
current_start=start_frame * components.config.frame_seq_length, |
|
|
cache_start=None, |
|
|
) |
|
|
components.transformer.block_mask = None |
|
|
|
|
|
return components, state |
|
|
|
|
|
|
|
|
class WanRTStreamingBeforeDenoiseStep(SequentialPipelineBlocks): |
|
|
block_classes = [ |
|
|
WanRTStreamingSetTimestepsStep, |
|
|
WanRTStreamingPrepareLatentsStep, |
|
|
WanRTStreamingExtractBlockLatentsStep, |
|
|
WanRTStreamingSetupKVCache, |
|
|
WanRTStreamingRecomputeKVCache, |
|
|
] |
|
|
block_names = [ |
|
|
"set_timesteps", |
|
|
"prepare_latents", |
|
|
"extract_block_init_latents", |
|
|
"setup_kv_cache", |
|
|
"recompute_kv_cache", |
|
|
] |
|
|
|
|
|
@property |
|
|
def description(self): |
|
|
return ( |
|
|
"Before denoise step that prepare the inputs for the denoise step.\n" |
|
|
+ "This is a sequential pipeline blocks:\n" |
|
|
+ " - `WanRTInputStep` is used to adjust the batch size of the model inputs\n" |
|
|
+ " - `WanRTSetTimestepsStep` is used to set the timesteps\n" |
|
|
+ " - `WanRTPrepareLatentsStep` is used to prepare the latents\n" |
|
|
) |
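# A minimal usage sketch (hypothetical, not part of this module): the composed
# before-denoise block is meant to be combined with a denoise loop and a VAE decode
# step into a full ModularPipeline, e.g.
#
#   before_denoise = WanRTStreamingBeforeDenoiseStep()
#   # before_denoise(components, state) prepares timesteps, latents and the KV /
#   # cross-attention caches for one streaming block before denoising runs.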
|
|
|