# krea-realtime-video/encoders.py
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import html
from typing import List, Optional, Union

import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.utils import is_ftfy_available, logging
from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState, WanModularPipeline
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
OutputParam,
)

if is_ftfy_available():
    import ftfy

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def prompt_clean(text):
    return whitespace_clean(basic_clean(text))


class WanRTStreamingTextEncoderStep(ModularPipelineBlocks):
model_name = "WanRTStreaming"

    @property
def description(self) -> str:
return "Text Encoder step that generate text_embeddings to guide the video generation"

    @property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", UMT5EncoderModel),
ComponentSpec("tokenizer", AutoTokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0}),
default_creation_method="from_config",
),
]

    @property
def expected_configs(self) -> List[ConfigSpec]:
return []

    @property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("negative_prompt"),
InputParam(
"prompt_embeds",
type_hint=torch.Tensor,
description="text embeddings used to guide the image generation",
),
InputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
description="negative text embeddings used to guide the image generation",
),
InputParam("attention_kwargs"),
]

    @property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="negative text embeddings used to guide the image generation",
),
]

    @staticmethod
def check_inputs(block_state):
if block_state.prompt is not None and (
not isinstance(block_state.prompt, str)
and not isinstance(block_state.prompt, list)
):
raise ValueError(
f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}"
)

    @staticmethod
def _get_t5_prompt_embeds(
components,
prompt: Union[str, List[str]],
max_sequence_length: int,
device: torch.device,
):
dtype = components.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
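        # Normalize each prompt: ftfy text fixes, HTML unescaping, and whitespace collapsing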
prompt = [prompt_clean(u) for u in prompt]
text_inputs = components.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = components.text_encoder(
text_input_ids.to(device), mask.to(device)
).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype)
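        # Trim each sequence to its true length, then zero-pad back to
        # max_sequence_length so padding positions hold exact zeros rather
        # than encoder outputs for pad tokens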
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[
torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))])
for u in prompt_embeds
],
dim=0,
)
return prompt_embeds

    @staticmethod
def encode_prompt(
components,
prompt: str,
device: Optional[torch.device] = None,
num_videos_per_prompt: int = 1,
prepare_unconditional_embeds: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
            device (`torch.device`):
                torch device
num_videos_per_prompt (`int`):
number of videos that should be generated per prompt
            prepare_unconditional_embeds (`bool`):
                whether to prepare unconditional (negative) text embeddings
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
max_sequence_length (`int`, defaults to `512`):
The maximum number of text tokens to be used for the generation process.
"""
device = device or components._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = WanRTStreamingTextEncoderStep._get_t5_prompt_embeds(
components, prompt, max_sequence_length, device
)
if prepare_unconditional_embeds and negative_prompt_embeds is None:
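            # Fall back to an empty negative prompt and broadcast a single string across the batch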
negative_prompt = negative_prompt or ""
negative_prompt = (
batch_size * [negative_prompt]
if isinstance(negative_prompt, str)
else negative_prompt
)
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = (
WanRTStreamingTextEncoderStep._get_t5_prompt_embeds(
components, negative_prompt, max_sequence_length, device
)
)
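        # Duplicate the embeddings for each generation per prompt, using an mps-friendly repeat + view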
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(
bs_embed * num_videos_per_prompt, seq_len, -1
)
if prepare_unconditional_embeds:
negative_prompt_embeds = negative_prompt_embeds.repeat(
1, num_videos_per_prompt, 1
)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_videos_per_prompt, seq_len, -1
)
return prompt_embeds, negative_prompt_embeds

    @torch.no_grad()
def __call__(
self, components: WanModularPipeline, state: PipelineState
) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
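        # Unconditional embeddings are not prepared here; any negative_prompt_embeds
        # passed in are forwarded unchanged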
block_state.prepare_unconditional_embeds = False
block_state.device = components._execution_device
# Encode input prompt
(
block_state.prompt_embeds,
block_state.negative_prompt_embeds,
) = WanRTStreamingTextEncoderStep.encode_prompt(
components,
block_state.prompt,
block_state.device,
1,
block_state.prepare_unconditional_embeds,
block_state.negative_prompt,
prompt_embeds=block_state.prompt_embeds,
negative_prompt_embeds=block_state.negative_prompt_embeds,
)
# Add outputs
self.set_block_state(state, block_state)
return components, state
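

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original module): wire the block
    # into a one-step modular pipeline and fetch the prompt embeddings. This assumes
    # the standard diffusers modular-pipelines flow; the checkpoint id below is a
    # placeholder, and the exact loading arguments may differ for this repository.
    from diffusers.modular_pipelines import SequentialPipelineBlocks

    blocks = SequentialPipelineBlocks.from_blocks_dict(
        {"text_encoder": WanRTStreamingTextEncoderStep()}
    )
    pipe = blocks.init_pipeline("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")  # assumed repo id
    pipe.load_components(torch_dtype=torch.bfloat16)

    prompt_embeds = pipe(prompt="a koala surfing a wave at sunset", output="prompt_embeds")
    print(prompt_embeds.shape)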