# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple, Union

import numpy as np
import PIL.Image
import torch

from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKLWan
from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)
from diffusers.utils import logging
from diffusers.video_processor import VideoProcessor


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class WanRTDecodeStep(ModularPipelineBlocks):
    model_name = "WanRT"
    decoder_cache = []

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "vae",
                AutoencoderKLWan,
                repo="Wan-AI/Wan2.1-T2V-14B-Diffusers",
                subfolder="vae",
            ),
            ComponentSpec(
                "video_processor",
                VideoProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Step that decodes the denoised latents into video frames"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("output_type", default="pil"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The denoised latents from the denoising step",
            ),
            InputParam(
                "frame_cache_context",
                description="Running cache of previously decoded frame chunks; extended with the frames decoded in this step",
            ),
            InputParam(
                "block_idx",
                description="Index of the current latent block in the streaming sequence; the VAE decoder cache is reset when it is 0",
            ),
            InputParam(
                "decoder_cache",
                description="VAE decoder feature maps carried over from the previous block for streaming decoding",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "videos",
                type_hint=Union[
                    List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]
                ],
                description="The generated videos, as lists of PIL.Image.Image frames, torch.Tensor, or numpy arrays",
            )
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        vae_dtype = components.vae.dtype

        if block_state.block_idx == 0:
            # On the first block, clear the VAE cache once, then disable further
            # clearing so the decoder feature maps persist across blocks.
            components.vae.clear_cache()
            components.vae.clear_cache = lambda: None
            components.vae._feat_map = [None] * 55
        else:
            # Restore the decoder feature maps cached by the previous block.
            components.vae._feat_map = block_state.decoder_cache

        if block_state.output_type != "latent":
            latents = block_state.latents.to(components.vae.device)
            # Create tensors directly on the target device and dtype to avoid redundant conversions
            latents_mean = torch.tensor(
                components.vae.config.latents_mean,
                device=latents.device,
                dtype=latents.dtype,
            ).view(1, components.vae.config.z_dim, 1, 1, 1)
            latents_std = 1.0 / torch.tensor(
                components.vae.config.latents_std,
                device=latents.device,
                dtype=latents.dtype,
            ).view(1, components.vae.config.z_dim, 1, 1, 1)
            latents = latents / latents_std + latents_mean
            latents = latents.to(vae_dtype)
            videos = components.vae.decode(latents, return_dict=False)[0]

            # Keep the raw decoded frames as context for subsequent blocks before postprocessing.
            block_state.frame_cache_context.extend(videos.split(1, dim=2))
            videos = components.video_processor.postprocess_video(
                videos, output_type=block_state.output_type
            )
            block_state.videos = videos
        else:
            block_state.videos = block_state.latents

        # Carry the decoder feature maps forward so the next block can reuse them.
        block_state.decoder_cache = components.vae._feat_map

        self.set_block_state(state, block_state)
        return components, state
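

# Usage sketch (illustrative only, not part of the block implementation).
# It assumes the modular-pipeline helpers shown in the diffusers modular
# quicktour (`init_pipeline`, `load_default_components`) and a hypothetical
# `denoised_latents` tensor produced by a preceding denoise step; adjust the
# repo id and loading calls to match your diffusers version.
#
#   decode_block = WanRTDecodeStep()
#   pipeline = decode_block.init_pipeline("Wan-AI/Wan2.1-T2V-14B-Diffusers")
#   pipeline.load_default_components(torch_dtype=torch.bfloat16)
#   pipeline.to("cuda")
#   output = pipeline(
#       latents=denoised_latents,      # from the preceding denoise step
#       block_idx=0,                   # first block resets the decoder cache
#       frame_cache_context=[],        # filled with decoded frame chunks
#       decoder_cache=None,            # reused on subsequent blocks
#       output="videos",
#   )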