# krea-realtime-video / decoders.py
# Last commit by viccpoes: 5b1c701 "Add diffusers (#1)" (verified)
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, Union
import numpy as np
import PIL
import torch
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKLWan
from diffusers.utils import logging
from diffusers.video_processor import VideoProcessor
from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
InputParam,
OutputParam,
)
import types
# Module-level logger obtained from diffusers' logging utilities
# (not referenced by the code visible in this file section).
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanRTDecodeStep(ModularPipelineBlocks):
    """Decode one block of denoised Wan latents into video frames.

    The VAE decoder keeps a per-layer feature cache (``vae._feat_map``) that
    lets consecutive latent blocks be decoded as one continuous video. In
    diffusers' ``AutoencoderKLWan``, ``decode`` calls ``clear_cache()``
    internally, which would drop that cache on every call. This step
    therefore:

    * resets the cache once, when ``block_idx == 0``;
    * replaces ``vae.clear_cache`` on the instance with a no-op;
    * carries the cache from block to block through ``decoder_cache``.
    """

    model_name = "WanRT"
    # NOTE(review): mutable class attribute shared by every instance. This
    # class never reads it; the cache travels via ``block_state.decoder_cache``.
    # It is kept only in case external code refers to it.
    decoder_cache = []

    # Decoder cache length used when the VAE does not expose ``_conv_num``.
    _FALLBACK_DECODER_CONV_COUNT = 55

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "vae",
                AutoencoderKLWan,
                repo="Wan-AI/Wan2.1-T2V-14B-Diffusers",
                subfolder="vae",
            ),
            ComponentSpec(
                "video_processor",
                VideoProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Step that decodes the denoised latents into images"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("output_type", default="pil"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The denoised latents from the denoising step",
            ),
            InputParam(
                "frame_cache_context",
                description="List that the decoded frames of this block are appended to, one tensor per frame",
            ),
            InputParam(
                "block_idx",
                description="Index of the current block; block 0 resets the VAE decoder feature cache",
            ),
            InputParam(
                "decoder_cache",
                description="VAE decoder feature cache carried over from the previous block",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "videos",
                type_hint=Union[
                    List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]
                ],
                description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array",
            )
        ]

    @classmethod
    def _prepare_decoder_cache(cls, vae, block_idx, decoder_cache) -> None:
        """Reset (block 0) or restore (later blocks) the VAE decoder feature cache."""
        if block_idx == 0:
            # Call the class implementation, not ``vae.clear_cache()``. After
            # the first run, the instance attribute is the no-op installed
            # below, so the reset would be skipped on every later generation.
            type(vae).clear_cache(vae)
            # Disable clearing cache so ``vae.decode`` keeps features across blocks.
            vae.clear_cache = lambda: None
            # Size the cache the way diffusers does (one slot per decoder conv).
            vae._feat_map = [None] * getattr(
                vae, "_conv_num", cls._FALLBACK_DECODER_CONV_COUNT
            )
        else:
            vae._feat_map = decoder_cache

    @staticmethod
    def _decode_latents(vae, latents: torch.Tensor) -> torch.Tensor:
        """Undo the per-channel latent normalization and run the VAE decoder."""
        latents = latents.to(vae.device)
        # Create the statistics directly on the target device and dtype to avoid
        # redundant conversions. The view broadcasts them over a 5-D latent
        # whose channel axis (size ``z_dim``) is dim 1.
        latents_mean = torch.tensor(
            vae.config.latents_mean, device=latents.device, dtype=latents.dtype
        ).view(1, vae.config.z_dim, 1, 1, 1)
        latents_std = 1.0 / torch.tensor(
            vae.config.latents_std, device=latents.device, dtype=latents.dtype
        ).view(1, vae.config.z_dim, 1, 1, 1)
        # Dividing by the reciprocal std gives x * std + mean (presumably the
        # inverse of the normalization applied to the latents upstream).
        latents = latents / latents_std + latents_mean
        return vae.decode(latents.to(vae.dtype), return_dict=False)[0]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> Tuple[Any, PipelineState]:
        """Decode this block's latents and publish them as ``videos``.

        Args:
            components: Pipeline components; must provide ``vae`` and
                ``video_processor``.
            state: Pipeline state holding this block's inputs.

        Returns:
            ``(components, state)``. ``state`` receives ``videos`` and the
            updated ``decoder_cache``. When ``output_type == "latent"``,
            ``videos`` is the undecoded latents and ``frame_cache_context``
            is left untouched.
        """
        block_state = self.get_block_state(state)
        vae = components.vae

        self._prepare_decoder_cache(
            vae, block_state.block_idx, block_state.decoder_cache
        )

        if block_state.output_type == "latent":
            # Nothing to decode: pass the latents through. ``postprocess_video``
            # is skipped because it does not accept "latent".
            block_state.videos = block_state.latents
        else:
            videos = self._decode_latents(vae, block_state.latents)
            if block_state.frame_cache_context is not None:
                # One tensor per frame. Dim 2 is presumably the frame axis of
                # the (B, C, T, H, W) VAE output.
                block_state.frame_cache_context.extend(videos.split(1, dim=2))
            block_state.videos = components.video_processor.postprocess_video(
                videos, output_type=block_state.output_type
            )

        # Hand the decoder feature cache to the next block.
        block_state.decoder_cache = vae._feat_map
        self.set_block_state(state, block_state)
        return components, state