# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import html from typing import List, Optional, Union import regex as re import torch from transformers import AutoTokenizer, UMT5EncoderModel from diffusers.configuration_utils import FrozenDict from diffusers.guiders import ClassifierFreeGuidance from diffusers.utils import is_ftfy_available, logging from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState from diffusers.modular_pipelines.modular_pipeline_utils import ( ComponentSpec, ConfigSpec, InputParam, OutputParam, ) from diffusers.modular_pipelines import WanModularPipeline if is_ftfy_available(): import ftfy logger = logging.get_logger(__name__) # pylint: disable=invalid-name def basic_clean(text): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() def whitespace_clean(text): text = re.sub(r"\s+", " ", text) text = text.strip() return text def prompt_clean(text): text = whitespace_clean(basic_clean(text)) return text class WanRTStreamingTextEncoderStep(ModularPipelineBlocks): model_name = "WanRTStreaming" @property def description(self) -> str: return "Text Encoder step that generate text_embeddings to guide the video generation" @property def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("text_encoder", UMT5EncoderModel), ComponentSpec("tokenizer", AutoTokenizer), ComponentSpec( "guider", ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 5.0}), default_creation_method="from_config", ), ] @property def expected_configs(self) -> List[ConfigSpec]: return [] @property def inputs(self) -> List[InputParam]: return [ InputParam("prompt"), InputParam("negative_prompt"), InputParam( "prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation", ), InputParam( "negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation", ), InputParam("attention_kwargs"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( "prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields", description="text embeddings used to guide the image generation", ), OutputParam( "negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields", description="negative text embeddings used to guide the image generation", ), ] @staticmethod def check_inputs(block_state): if block_state.prompt is not None and ( not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) ): raise ValueError( f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}" ) @staticmethod def _get_t5_prompt_embeds( components, prompt: Union[str, List[str]], max_sequence_length: int, device: torch.device, ): dtype = components.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt_clean(u) for u in prompt] text_inputs = components.tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_attention_mask=True, return_tensors="pt", ) text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() prompt_embeds = components.text_encoder( text_input_ids.to(device), mask.to(device) ).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( [ torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds ], dim=0, ) return prompt_embeds @staticmethod def encode_prompt( components, prompt: str, device: Optional[torch.device] = None, num_videos_per_prompt: int = 1, prepare_unconditional_embeds: bool = True, negative_prompt: Optional[str] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: int = 512, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_videos_per_prompt (`int`): number of videos that should be generated per prompt prepare_unconditional_embeds (`bool`): whether to use prepare unconditional embeddings or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. max_sequence_length (`int`, defaults to `512`): The maximum number of text tokens to be used for the generation process. """ device = device or components._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] if prompt_embeds is None: prompt_embeds = WanRTStreamingTextEncoderStep._get_t5_prompt_embeds( components, prompt, max_sequence_length, device ) if prepare_unconditional_embeds and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt = ( batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt ) if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) negative_prompt_embeds = ( WanRTStreamingTextEncoderStep._get_t5_prompt_embeds( components, negative_prompt, max_sequence_length, device ) ) bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view( bs_embed * num_videos_per_prompt, seq_len, -1 ) if prepare_unconditional_embeds: negative_prompt_embeds = negative_prompt_embeds.repeat( 1, num_videos_per_prompt, 1 ) negative_prompt_embeds = negative_prompt_embeds.view( batch_size * num_videos_per_prompt, seq_len, -1 ) return prompt_embeds, negative_prompt_embeds @torch.no_grad() def __call__( self, components: WanModularPipeline, state: PipelineState ) -> PipelineState: # Get inputs and intermediates block_state = self.get_block_state(state) self.check_inputs(block_state) block_state.prepare_unconditional_embeds = False block_state.device = components._execution_device # Encode input prompt ( block_state.prompt_embeds, block_state.negative_prompt_embeds, ) = WanRTStreamingTextEncoderStep.encode_prompt( components, block_state.prompt, block_state.device, 1, block_state.prepare_unconditional_embeds, block_state.negative_prompt, prompt_embeds=block_state.prompt_embeds, negative_prompt_embeds=block_state.negative_prompt_embeds, ) # Add outputs self.set_block_state(state, block_state) return components, state