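"""Modular diffusers pipeline block: the text-encoder step for Wan real-time
streaming video generation. It cleans the prompt, encodes it with UMT5, and
stores the resulting embeddings on the pipeline state for downstream denoiser
blocks."""
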
import html
from typing import List, Optional, Union

import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState, WanModularPipeline
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    ConfigSpec,
    InputParam,
    OutputParam,
)
from diffusers.utils import is_ftfy_available, logging


if is_ftfy_available():
    import ftfy


logger = logging.get_logger(__name__)


def basic_clean(text):
    # Fix mojibake with ftfy and unescape HTML entities twice, so doubly
    # escaped input such as "&amp;amp;" is fully decoded.
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def prompt_clean(text):
    text = whitespace_clean(basic_clean(text))
    return text


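# Example (illustrative): the cleaners collapse whitespace runs and decode
# doubly escaped HTML entities, e.g.
#
#     prompt_clean(" A   &amp;amp; B ")  # -> "A & B"

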
class WanRTStreamingTextEncoderStep(ModularPipelineBlocks):
    model_name = "WanRTStreaming"

    @property
    def description(self) -> str:
        return "Text encoder step that generates text embeddings to guide the video generation."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("negative_prompt"),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="text embeddings used to guide the video generation",
            ),
            InputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                description="negative text embeddings used to guide the video generation",
            ),
            InputParam("attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="text embeddings used to guide the video generation",
            ),
            OutputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="negative text embeddings used to guide the video generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        if block_state.prompt is not None and not isinstance(block_state.prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")

    @staticmethod
    def _get_t5_prompt_embeds(
        components,
        prompt: Union[str, List[str]],
        max_sequence_length: int,
        device: torch.device,
    ):
        dtype = components.text_encoder.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(u) for u in prompt]

        text_inputs = components.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype)
        # Trim each sequence to its true token length, then zero-pad back to
        # `max_sequence_length` so padding positions carry no encoder output.
        prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
        prompt_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds],
            dim=0,
        )

        return prompt_embeds

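    # Note (illustrative): for two prompts and max_sequence_length=512, the
    # tensor returned above has shape (2, 512, hidden_dim); positions past a
    # prompt's true token length are zeros rather than pad-token encodings.
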
    @staticmethod
    def encode_prompt(
        components,
        prompt: str,
        device: Optional[torch.device] = None,
        num_videos_per_prompt: int = 1,
        prepare_unconditional_embeds: bool = True,
        negative_prompt: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_videos_per_prompt (`int`):
                number of videos that should be generated per prompt
            prepare_unconditional_embeds (`bool`):
                whether to prepare unconditional embeddings or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
            max_sequence_length (`int`, defaults to `512`):
                The maximum number of text tokens to be used for the generation process.
        """
        device = device or components._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = WanRTStreamingTextEncoderStep._get_t5_prompt_embeds(
                components, prompt, max_sequence_length, device
            )

        if prepare_unconditional_embeds and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = WanRTStreamingTextEncoderStep._get_t5_prompt_embeds(
                components, negative_prompt, max_sequence_length, device
            )

        # Duplicate embeddings for each requested video: tile along the sequence
        # dimension, then fold the copies into the batch dimension, e.g.
        # (2, 512, D) -> (6, 512, D) for num_videos_per_prompt=3, ordered
        # [p0, p0, p0, p1, p1, p1].
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        if prepare_unconditional_embeds:
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds, negative_prompt_embeds

    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        # The streaming path does not prepare unconditional (negative) embeddings.
        block_state.prepare_unconditional_embeds = False
        block_state.device = components._execution_device

        (
            block_state.prompt_embeds,
            block_state.negative_prompt_embeds,
        ) = WanRTStreamingTextEncoderStep.encode_prompt(
            components,
            block_state.prompt,
            block_state.device,
            1,
            block_state.prepare_unconditional_embeds,
            block_state.negative_prompt,
            prompt_embeds=block_state.prompt_embeds,
            negative_prompt_embeds=block_state.negative_prompt_embeds,
        )

        self.set_block_state(state, block_state)
        return components, state
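

# Usage sketch (assumptions flagged inline): wiring this block into a modular
# pipeline. `init_pipeline`, `load_default_components`, and the `output=`
# selector follow the diffusers modular-pipelines API at the time of writing;
# the repo id is a placeholder, not a real checkpoint.
#
#     blocks = WanRTStreamingTextEncoderStep()
#     pipe = blocks.init_pipeline("<modular-wan-repo>")  # placeholder repo id
#     pipe.load_default_components(torch_dtype=torch.bfloat16)
#     prompt_embeds = pipe(prompt="a cat surfing a wave", output="prompt_embeds")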