| from typing import Any, Optional, Union, List, Sequence | |
| import inspect | |
| import random | |
| from tqdm import tqdm | |
| import numpy as np | |
| import copy | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from diffusers.utils.torch_utils import randn_tensor | |
| from diffusers import FlowMatchEulerDiscreteScheduler | |
| from diffusers.training_utils import compute_density_for_timestep_sampling | |
| from models.autoencoder.autoencoder_base import AutoEncoderBase | |
| from models.content_encoder.content_encoder import ContentEncoder | |
| from models.content_adapter import ContentAdapterBase | |
| from models.common import LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase | |
| from utils.torch_utilities import ( | |
| create_alignment_path, create_mask_from_length, loss_with_mask, | |
| trim_or_pad_length | |
| ) | |
| from safetensors.torch import load_file | |
| class FlowMatchingMixin: | |
| def __init__( | |
| self, | |
| cfg_drop_ratio: float = 0.2, | |
| sample_strategy: str = 'normal', | |
| num_train_steps: int = 1000 | |
| ) -> None: | |
| r""" | |
| Args: | |
| cfg_drop_ratio (float): Probability of dropping the conditioning content during training (for classifier-free guidance). | |
| sample_strategy (str): Sampling strategy for timesteps during training. | |
| num_train_steps (int): Number of training steps for the noise scheduler. | |
| """ | |
| self.sample_strategy = sample_strategy | |
| self.infer_noise_scheduler = FlowMatchEulerDiscreteScheduler( | |
| num_train_timesteps=num_train_steps | |
| ) | |
| self.train_noise_scheduler = copy.deepcopy(self.infer_noise_scheduler) | |
| self.classifier_free_guidance = cfg_drop_ratio > 0.0 | |
| self.cfg_drop_ratio = cfg_drop_ratio | |
| def get_input_target_and_timesteps( | |
| self, | |
| latent: torch.Tensor, | |
| training: bool = True | |
| ): | |
| bsz = latent.shape[0] | |
| noise = torch.randn_like(latent) | |
| if training: | |
| if self.sample_strategy == 'normal': | |
| u = compute_density_for_timestep_sampling( | |
| weighting_scheme="logit_normal", | |
| batch_size=bsz, | |
| logit_mean=0, | |
| logit_std=1, | |
| mode_scale=None, | |
| ) | |
| elif self.sample_strategy == 'uniform': | |
| u = torch.rand(bsz, )  # uniform samples in [0, 1) | |
| else: | |
| raise NotImplementedError( | |
| f"{self.sample_strategy} samlping for timesteps is not supported now" | |
| ) | |
| else: | |
| u = torch.ones(bsz, ) / 2 | |
| indices = (u * self.train_noise_scheduler.config.num_train_timesteps | |
| ).long() | |
| # train_noise_scheduler.timesteps: runs from num_train_steps down to 1 with an interval of 1 | |
| timesteps = self.train_noise_scheduler.timesteps[indices].to( | |
| device=latent.device | |
| ) | |
| sigmas = self.get_sigmas( | |
| timesteps, n_dim=latent.ndim, dtype=latent.dtype | |
| ) | |
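| # Rectified-flow / flow-matching interpolation: the noisy input is the linear | |
| # blend x_t = (1 - sigma) * x_0 + sigma * noise, and the regression target is | |
| # the constant velocity (noise - x_0) pointing from data towards noise. | |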
| noisy_latent = (1.0 - sigmas) * latent + sigmas * noise | |
| target = noise - latent | |
| return noisy_latent, target, timesteps | |
| def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32): | |
| device = timesteps.device | |
| # a list from 1 declining to 1/num_train_steps | |
| sigmas = self.train_noise_scheduler.sigmas.to( | |
| device=device, dtype=dtype | |
| ) | |
| schedule_timesteps = self.train_noise_scheduler.timesteps.to(device) | |
| timesteps = timesteps.to(device) | |
| step_indices = [(schedule_timesteps == t).nonzero().item() | |
| for t in timesteps] | |
| sigma = sigmas[step_indices].flatten() | |
| while len(sigma.shape) < n_dim: | |
| sigma = sigma.unsqueeze(-1) | |
| return sigma | |
| def retrieve_timesteps( | |
| self, | |
| num_inference_steps: Optional[int] = None, | |
| device: Optional[Union[str, torch.device]] = None, | |
| timesteps: Optional[List[int]] = None, | |
| sigmas: Optional[List[float]] = None, | |
| **kwargs, | |
| ): | |
| # used at inference time: set the scheduler timesteps from custom timesteps, custom sigmas, or a plain step count | |
| scheduler = self.infer_noise_scheduler | |
| if timesteps is not None and sigmas is not None: | |
| raise ValueError( | |
| "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" | |
| ) | |
| if timesteps is not None: | |
| accepts_timesteps = "timesteps" in set( | |
| inspect.signature(scheduler.set_timesteps).parameters.keys() | |
| ) | |
| if not accepts_timesteps: | |
| raise ValueError( | |
| f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" | |
| f" timestep schedules. Please check whether you are using the correct scheduler." | |
| ) | |
| scheduler.set_timesteps( | |
| timesteps=timesteps, device=device, **kwargs | |
| ) | |
| timesteps = scheduler.timesteps | |
| num_inference_steps = len(timesteps) | |
| elif sigmas is not None: | |
| accept_sigmas = "sigmas" in set( | |
| inspect.signature(scheduler.set_timesteps).parameters.keys() | |
| ) | |
| if not accept_sigmas: | |
| raise ValueError( | |
| f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" | |
| f" sigmas schedules. Please check whether you are using the correct scheduler." | |
| ) | |
| scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) | |
| timesteps = scheduler.timesteps | |
| num_inference_steps = len(timesteps) | |
| else: | |
| scheduler.set_timesteps( | |
| num_inference_steps, device=device, **kwargs | |
| ) | |
| timesteps = scheduler.timesteps | |
| return timesteps, num_inference_steps | |
| class ContentEncoderAdapterMixin: | |
| def __init__( | |
| self, | |
| content_encoder: ContentEncoder, | |
| content_adapter: ContentAdapterBase | None = None | |
| ): | |
| self.content_encoder = content_encoder | |
| self.content_adapter = content_adapter | |
| def encode_content( | |
| self, | |
| content: list[Any], | |
| task: list[str], | |
| device: str | torch.device, | |
| instruction: torch.Tensor | None = None, | |
| instruction_lengths: torch.Tensor | None = None | |
| ): | |
| content_output: dict[ | |
| str, torch.Tensor] = self.content_encoder.encode_content( | |
| content, task, device=device | |
| ) | |
| content, content_mask = content_output["content"], content_output[ | |
| "content_mask"] | |
| if instruction is not None: | |
| instruction_mask = create_mask_from_length(instruction_lengths) | |
| ( | |
| content, | |
| content_mask, | |
| global_duration_pred, | |
| local_duration_pred, | |
| ) = self.content_adapter( | |
| content, content_mask, instruction, instruction_mask | |
| ) | |
| return_dict = { | |
| "content": content, | |
| "content_mask": content_mask, | |
| "length_aligned_content": content_output["length_aligned_content"], | |
| } | |
| if instruction is not None: | |
| return_dict["global_duration_pred"] = global_duration_pred | |
| return_dict["local_duration_pred"] = local_duration_pred | |
| return return_dict | |
| class SingleTaskCrossAttentionAudioFlowMatching( | |
| LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase, | |
| FlowMatchingMixin, ContentEncoderAdapterMixin | |
| ): | |
| def __init__( | |
| self, | |
| autoencoder: nn.Module, | |
| content_encoder: ContentEncoder, | |
| backbone: nn.Module, | |
| cfg_drop_ratio: float = 0.2, | |
| sample_strategy: str = 'normal', | |
| num_train_steps: int = 1000, | |
| pretrained_ckpt: str | None = None, | |
| ): | |
| nn.Module.__init__(self) | |
| FlowMatchingMixin.__init__( | |
| self, cfg_drop_ratio, sample_strategy, num_train_steps | |
| ) | |
| ContentEncoderAdapterMixin.__init__( | |
| self, content_encoder=content_encoder | |
| ) | |
| self.autoencoder = autoencoder | |
| for param in self.autoencoder.parameters(): | |
| param.requires_grad = False | |
| if hasattr(self.content_encoder, "audio_encoder"): | |
| if self.content_encoder.audio_encoder is not None: | |
| self.content_encoder.audio_encoder.model = self.autoencoder | |
| self.backbone = backbone | |
| self.dummy_param = nn.Parameter(torch.empty(0)) | |
| if pretrained_ckpt is not None: | |
| print(f"Load pretrain FlowMatching model from {pretrained_ckpt}") | |
| pretrained_state_dict = load_file(pretrained_ckpt) | |
| self.load_pretrained(pretrained_state_dict) | |
| # missing, unexpected = self.load_state_dict(pretrained_state_dict, strict=False) | |
| # print("Missing keys:", missing) | |
| # print("Unexpected keys:", unexpected) | |
| # if content_encoder.embed_dim != 1024: | |
| # self.context_proj = nn.Sequential( | |
| # nn.Linear(content_encoder.embed_dim, 1024), | |
| # nn.SiLU(), | |
| # nn.Linear(1024, 1024), | |
| # ) | |
| # else: | |
| # self.context_proj = nn.Identity() | |
| def forward( | |
| self, content: list[Any], condition: list[Any], task: list[str], | |
| waveform: torch.Tensor, waveform_lengths: torch.Tensor, loss_reduce: bool = True, **kwargs | |
| ): | |
| loss_reduce = self.training or loss_reduce  # always reduce the loss during training | |
| device = self.dummy_param.device | |
| self.autoencoder.eval() | |
| with torch.no_grad(): | |
| latent, latent_mask = self.autoencoder.encode( | |
| waveform.unsqueeze(1), waveform_lengths | |
| ) | |
| content_dict = self.encode_content(content, task, device) | |
| content, content_mask = content_dict["content"], content_dict[ | |
| "content_mask"] | |
| # content = self.context_proj(content) | |
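| # Classifier-free guidance training: with probability cfg_drop_ratio, zero out | |
| # a sample's content embeddings so the backbone also learns an unconditional | |
| # prediction that can be combined with the conditional one at inference time. | |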
| if self.training and self.classifier_free_guidance: | |
| mask_indices = [ | |
| k for k in range(len(waveform)) | |
| if random.random() < self.cfg_drop_ratio | |
| ] | |
| if len(mask_indices) > 0: | |
| content[mask_indices] = 0 | |
| noisy_latent, target, timesteps = self.get_input_target_and_timesteps( | |
| latent, | |
| training = self.training | |
| ) | |
| pred: torch.Tensor = self.backbone( | |
| x=noisy_latent, | |
| timesteps=timesteps, | |
| context=content, | |
| x_mask=latent_mask, | |
| context_mask=content_mask | |
| ) | |
| diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none") | |
| diff_loss = loss_with_mask(diff_loss, latent_mask.unsqueeze(1), reduce=loss_reduce) | |
| #diff_loss = loss_with_mask(diff_loss, latent_mask.unsqueeze(1)) | |
| output = {"diff_loss": diff_loss} | |
| return output | |
| def iterative_denoise( | |
| self, latent: torch.Tensor, timesteps: list[int], num_steps: int, | |
| verbose: bool, cfg: bool, cfg_scale: float, backbone_input: dict | |
| ): | |
| progress_bar = tqdm(range(num_steps), disable=not verbose) | |
| for i, timestep in enumerate(timesteps): | |
| # expand the latent if we are doing classifier free guidance | |
| if cfg: | |
| latent_input = torch.cat([latent, latent]) | |
| else: | |
| latent_input = latent | |
| noise_pred: torch.Tensor = self.backbone( | |
| x=latent_input, timesteps=timestep, **backbone_input | |
| ) | |
| # perform guidance | |
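| # Combine the two halves of the batch (unconditional, conditional) with the | |
| # standard CFG formula: uncond + cfg_scale * (cond - uncond). | |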
| if cfg: | |
| noise_pred_uncond, noise_pred_content = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + cfg_scale * ( | |
| noise_pred_content - noise_pred_uncond | |
| ) | |
| latent = self.infer_noise_scheduler.step( | |
| noise_pred, timestep, latent | |
| ).prev_sample | |
| progress_bar.update(1) | |
| progress_bar.close() | |
| return latent | |
| def inference( | |
| self, | |
| content: list[Any], | |
| condition: list[Any], | |
| task: list[str], | |
| latent_shape: Sequence[int], | |
| num_steps: int = 50, | |
| sway_sampling_coef: float | None = -1.0, | |
| guidance_scale: float = 3.0, | |
| num_samples_per_content: int = 1, | |
| disable_progress: bool = True, | |
| **kwargs | |
| ): | |
| device = self.dummy_param.device | |
| classifier_free_guidance = guidance_scale > 1.0 | |
| batch_size = len(content) * num_samples_per_content | |
| if classifier_free_guidance: | |
| content, content_mask = self.encode_content_classifier_free( | |
| content, task, device, num_samples_per_content | |
| ) | |
| else: | |
| content_output: dict[ | |
| str, torch.Tensor] = self.content_encoder.encode_content( | |
| content, task | |
| ) | |
| content, content_mask = content_output["content"], content_output[ | |
| "content_mask"] | |
| content = content.repeat_interleave(num_samples_per_content, 0) | |
| content_mask = content_mask.repeat_interleave( | |
| num_samples_per_content, 0 | |
| ) | |
| latent = self.prepare_latent( | |
| batch_size, latent_shape, content.dtype, device | |
| ) | |
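| # Sigma schedule for inference: either uniform, or (when sway_sampling_coef is | |
| # set) the "sway sampling" schedule t' = t + s * (cos(pi/2 * t) - 1 + t); with a | |
| # negative coefficient this concentrates steps near sigma = 1, the high-noise end. | |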
| if not sway_sampling_coef: | |
| sigmas = np.linspace(1.0, 1 / num_steps, num_steps) | |
| else: | |
| t = torch.linspace(0, 1, num_steps + 1) | |
| t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) | |
| sigmas = 1 - t | |
| timesteps, num_steps = self.retrieve_timesteps( | |
| num_steps, device, timesteps=None, sigmas=sigmas | |
| ) | |
| latent = self.iterative_denoise( | |
| latent=latent, | |
| timesteps=timesteps, | |
| num_steps=num_steps, | |
| verbose=not disable_progress, | |
| cfg=classifier_free_guidance, | |
| cfg_scale=guidance_scale, | |
| backbone_input={ | |
| "context": content, | |
| "context_mask": content_mask, | |
| }, | |
| ) | |
| waveform = self.autoencoder.decode(latent) | |
| return waveform | |
| def prepare_latent( | |
| self, batch_size: int, latent_shape: Sequence[int], dtype: torch.dtype, | |
| device: str | |
| ): | |
| shape = (batch_size, *latent_shape) | |
| latent = randn_tensor( | |
| shape, generator=None, device=device, dtype=dtype | |
| ) | |
| return latent | |
| def encode_content_classifier_free( | |
| self, | |
| content: list[Any], | |
| task: list[str], | |
| device, | |
| num_samples_per_content: int = 1 | |
| ): | |
| content_dict = self.content_encoder.encode_content( | |
| content, task, device | |
| ) | |
| content, content_mask = content_dict["content"], content_dict["content_mask"] | |
| # content, content_mask = self.content_encoder.encode_content( | |
| # content, task, device=device | |
| # ) | |
| content = content.repeat_interleave(num_samples_per_content, 0) | |
| content_mask = content_mask.repeat_interleave( | |
| num_samples_per_content, 0 | |
| ) | |
| # get unconditional embeddings for classifier free guidance | |
| # `content` and `content_mask` have already been repeated per sample above, | |
| # so the unconditional tensors derived from them must not be repeated again | |
| uncond_content = torch.zeros_like(content) | |
| uncond_content_mask = content_mask.detach().clone() | |
| # For classifier-free guidance we need both unconditional and conditional predictions, | |
| # so we concatenate them into a single batch and run one forward pass instead of two. | |
| content = torch.cat([uncond_content, content]) | |
| content_mask = torch.cat([uncond_content_mask, content_mask]) | |
| return content, content_mask | |
| class MultiContentAudioFlowMatching(SingleTaskCrossAttentionAudioFlowMatching): | |
| def __init__( | |
| self, | |
| autoencoder: AutoEncoderBase, | |
| content_encoder: ContentEncoder, | |
| backbone: nn.Module, | |
| cfg_drop_ratio: float = 0.2, | |
| sample_strategy: str = 'normal', | |
| num_train_steps: int = 1000, | |
| pretrained_ckpt: str | None = None, | |
| embed_dim: int = 1024, | |
| ): | |
| super().__init__( | |
| autoencoder=autoencoder, | |
| content_encoder=content_encoder, | |
| backbone=backbone, | |
| cfg_drop_ratio=cfg_drop_ratio, | |
| sample_strategy=sample_strategy, | |
| num_train_steps=num_train_steps, | |
| pretrained_ckpt=pretrained_ckpt, | |
| ) | |
| def forward( | |
| self, | |
| content: list[Any], | |
| duration: Sequence[float], | |
| task: list[str], | |
| waveform: torch.Tensor, | |
| waveform_lengths: torch.Tensor, | |
| loss_reduce: bool = True, | |
| **kwargs | |
| ): | |
| device = self.dummy_param.device | |
| loss_reduce = self.training or loss_reduce  # always reduce the loss during training | |
| self.autoencoder.eval() | |
| with torch.no_grad(): | |
| latent, latent_mask = self.autoencoder.encode( | |
| waveform.unsqueeze(1), waveform_lengths | |
| )  # latent: [B, 128, L], latent_mask: [B, L] (e.g. L = 500 for T = 10 s) | |
| content_dict = self.encode_content(content, task, device) | |
| context, context_mask, length_aligned_content = content_dict["content"], content_dict[ | |
| "content_mask"], content_dict["length_aligned_content"] | |
| # -------------------------------------------------------------------- | |
| # prepare latent and noise | |
| # -------------------------------------------------------------------- | |
| noisy_latent, target, timesteps = self.get_input_target_and_timesteps( | |
| latent, | |
| training = self.training | |
| ) | |
| # -------------------------------------------------------------------- | |
| # prepare input to the backbone | |
| # -------------------------------------------------------------------- | |
| # TODO: compatibility with 2D spectrogram VAE | |
| latent_length = noisy_latent.size(self.autoencoder.time_dim) | |
| time_aligned_content = trim_or_pad_length( | |
| length_aligned_content, latent_length, 1 | |
| ) | |
| # -------------------------------------------------------------------- | |
| # classifier free guidance | |
| # -------------------------------------------------------------------- | |
| if self.training and self.classifier_free_guidance: | |
| mask_indices = [ | |
| k for k in range(len(waveform)) | |
| if random.random() < self.cfg_drop_ratio | |
| ] | |
| if len(mask_indices) > 0: | |
| context[mask_indices] = 0 | |
| time_aligned_content[mask_indices] = 0 | |
| pred: torch.Tensor = self.backbone( | |
| x=noisy_latent, | |
| x_mask=latent_mask, | |
| timesteps=timesteps, | |
| context=context, | |
| context_mask=context_mask, | |
| time_aligned_context=time_aligned_content, | |
| ) | |
| pred = pred.transpose(1, self.autoencoder.time_dim) | |
| target = target.transpose(1, self.autoencoder.time_dim) | |
| diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none") | |
| diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce) | |
| return { | |
| "diff_loss": diff_loss, | |
| } | |
| def inference( | |
| self, | |
| content: list[Any], | |
| task: list[str], | |
| latent_shape: Sequence[int], | |
| num_steps: int = 50, | |
| sway_sampling_coef: float | None = -1.0, | |
| guidance_scale: float = 3.0, | |
| disable_progress: bool = True, | |
| **kwargs | |
| ): | |
| device = self.dummy_param.device | |
| classifier_free_guidance = guidance_scale > 1.0 | |
| batch_size = len(content) | |
| content_dict: dict[ | |
| str, torch.Tensor] = self.encode_content( | |
| content, task, device | |
| ) | |
| context, context_mask, length_aligned_content = \ | |
| content_dict["content"], content_dict[ | |
| "content_mask"], content_dict["length_aligned_content"] | |
| shape = (batch_size, *latent_shape) | |
| latent_length = shape[self.autoencoder.time_dim] | |
| time_aligned_content = trim_or_pad_length( | |
| length_aligned_content, latent_length, 1 | |
| ) | |
| # -------------------------------------------------------------------- | |
| # prepare unconditional input | |
| # -------------------------------------------------------------------- | |
| if classifier_free_guidance: | |
| uncond_time_aligned_content = torch.zeros_like( | |
| time_aligned_content | |
| ) | |
| uncond_context = torch.zeros_like(context) | |
| uncond_context_mask = context_mask.detach().clone() | |
| time_aligned_content = torch.cat([ | |
| uncond_time_aligned_content, time_aligned_content | |
| ]) | |
| context = torch.cat([uncond_context, context]) | |
| context_mask = torch.cat([uncond_context_mask, context_mask]) | |
| latent = randn_tensor( | |
| shape, generator=None, device=device, dtype=context.dtype | |
| ) | |
| if not sway_sampling_coef: | |
| sigmas = np.linspace(1.0, 1 / num_steps, num_steps) | |
| else: | |
| t = torch.linspace(0, 1, num_steps + 1) | |
| t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) | |
| sigmas = 1 - t | |
| timesteps, num_steps = self.retrieve_timesteps( | |
| num_steps, device, timesteps=None, sigmas=sigmas | |
| ) | |
| latent = self.iterative_denoise( | |
| latent=latent, | |
| timesteps=timesteps, | |
| num_steps=num_steps, | |
| verbose=not disable_progress, | |
| cfg=classifier_free_guidance, | |
| cfg_scale=guidance_scale, | |
| backbone_input={ | |
| "context": context, | |
| "context_mask": context_mask, | |
| "time_aligned_context": time_aligned_content, | |
| } | |
| ) | |
| waveform = self.autoencoder.decode(latent) | |
| return waveform | |
| class DurationAdapterMixin: | |
| def __init__( | |
| self, | |
| latent_token_rate: int, | |
| offset: float = 1.0, | |
| frame_resolution: float | None = None | |
| ): | |
| self.latent_token_rate = latent_token_rate | |
| self.offset = offset | |
| self.frame_resolution = frame_resolution | |
| def get_global_duration_loss( | |
| self, | |
| pred: torch.Tensor, | |
| latent_mask: torch.Tensor, | |
| reduce: bool = True, | |
| ): | |
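| # The regression target is the log of the ground-truth audio duration in | |
| # seconds (number of valid latent frames / latent_token_rate), shifted by | |
| # `offset` so the argument of the log stays positive. | |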
| target = torch.log( | |
| latent_mask.sum(1) / self.latent_token_rate + self.offset | |
| ) | |
| loss = F.mse_loss(target, pred, reduction="mean" if reduce else "none") | |
| return loss | |
| def get_local_duration_loss( | |
| self, ground_truth: torch.Tensor, pred: torch.Tensor, | |
| mask: torch.Tensor, is_time_aligned: Sequence[bool], reduce: bool | |
| ): | |
| n_frames = torch.round(ground_truth / self.frame_resolution) | |
| target = torch.log(n_frames + self.offset) | |
| loss = loss_with_mask( | |
| (target - pred)**2, | |
| mask, | |
| reduce=False, | |
| ) | |
| loss *= is_time_aligned | |
| if reduce: | |
| if is_time_aligned.sum().item() == 0: | |
| loss *= 0.0 | |
| loss = loss.mean() | |
| else: | |
| loss = loss.sum() / is_time_aligned.sum() | |
| return loss | |
| def prepare_local_duration(self, pred: torch.Tensor, mask: torch.Tensor): | |
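| # Invert the log/offset parameterization used for training: exponentiate, mask | |
| # out padded positions, round up and remove the offset to get frame counts, | |
| # then convert frames to seconds via frame_resolution. | |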
| pred = torch.exp(pred) * mask | |
| pred = torch.ceil(pred) - self.offset | |
| pred *= self.frame_resolution | |
| return pred | |
| def prepare_global_duration( | |
| self, | |
| global_pred: torch.Tensor, | |
| local_pred: torch.Tensor, | |
| is_time_aligned: Sequence[bool], | |
| use_local: bool = True, | |
| ): | |
| """ | |
| global_pred: predicted global duration, in log domain with `offset` applied | |
| local_pred: predicted per-token durations in seconds (converted to latent frames internally) | |
| """ | |
| global_pred = torch.exp(global_pred) - self.offset | |
| result = global_pred | |
| # for time-aligned content, derive the global duration from the per-token predictions to avoid accumulating per-frame errors | |
| if use_local: | |
| pred_from_local = torch.round(local_pred * self.latent_token_rate) | |
| pred_from_local = pred_from_local.sum(1) / self.latent_token_rate | |
| result[is_time_aligned] = pred_from_local[is_time_aligned] | |
| return result | |
| def expand_by_duration( | |
| self, | |
| x: torch.Tensor, | |
| content_mask: torch.Tensor, | |
| local_duration: torch.Tensor, | |
| global_duration: torch.Tensor | None = None, | |
| ): | |
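| # Expand token-level features to frame level: round each local duration to a | |
| # number of latent frames, build a hard monotonic alignment path over the | |
| # (token, frame) grid, and matmul with it so every token is repeated for its | |
| # predicted number of frames. | |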
| n_latents = torch.round(local_duration * self.latent_token_rate) | |
| if global_duration is not None: | |
| latent_length = torch.round( | |
| global_duration * self.latent_token_rate | |
| ) | |
| else: | |
| latent_length = n_latents.sum(1) | |
| latent_mask = create_mask_from_length(latent_length).to( | |
| content_mask.device | |
| ) | |
| attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1) | |
| align_path = create_alignment_path(n_latents, attn_mask) | |
| expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x) | |
| return expanded_x, latent_mask | |
| class CrossAttentionAudioFlowMatching( | |
| SingleTaskCrossAttentionAudioFlowMatching, DurationAdapterMixin | |
| ): | |
| def __init__( | |
| self, | |
| autoencoder: AutoEncoderBase, | |
| content_encoder: ContentEncoder, | |
| content_adapter: ContentAdapterBase, | |
| backbone: nn.Module, | |
| content_dim: int, | |
| frame_resolution: float, | |
| duration_offset: float = 1.0, | |
| cfg_drop_ratio: float = 0.2, | |
| sample_strategy: str = 'normal', | |
| num_train_steps: int = 1000 | |
| ): | |
| super().__init__( | |
| autoencoder=autoencoder, | |
| content_encoder=content_encoder, | |
| backbone=backbone, | |
| cfg_drop_ratio=cfg_drop_ratio, | |
| sample_strategy=sample_strategy, | |
| num_train_steps=num_train_steps, | |
| ) | |
| ContentEncoderAdapterMixin.__init__( | |
| self, | |
| content_encoder=content_encoder, | |
| content_adapter=content_adapter | |
| ) | |
| DurationAdapterMixin.__init__( | |
| self, | |
| latent_token_rate=autoencoder.latent_token_rate, | |
| offset=duration_offset | |
| ) | |
| def encode_content_with_instruction( | |
| self, content: list[Any], task: list[str], device, | |
| instruction: torch.Tensor, instruction_lengths: torch.Tensor | |
| ): | |
| content_dict = self.encode_content( | |
| content, task, device, instruction, instruction_lengths | |
| ) | |
| return ( | |
| content_dict["content"], content_dict["content_mask"], | |
| content_dict["global_duration_pred"], | |
| content_dict["local_duration_pred"], | |
| content_dict["length_aligned_content"] | |
| ) | |
| def forward( | |
| self, | |
| content: list[Any], | |
| task: list[str], | |
| waveform: torch.Tensor, | |
| waveform_lengths: torch.Tensor, | |
| instruction: torch.Tensor, | |
| instruction_lengths: torch.Tensor, | |
| loss_reduce: bool = True, | |
| **kwargs | |
| ): | |
| device = self.dummy_param.device | |
| loss_reduce = self.training or loss_reduce  # always reduce the loss during training | |
| self.autoencoder.eval() | |
| with torch.no_grad(): | |
| latent, latent_mask = self.autoencoder.encode( | |
| waveform.unsqueeze(1), waveform_lengths | |
| ) | |
| content, content_mask, global_duration_pred, _, _ = \ | |
| self.encode_content_with_instruction( | |
| content, task, device, instruction, instruction_lengths | |
| ) | |
| global_duration_loss = self.get_global_duration_loss( | |
| global_duration_pred, latent_mask, reduce=loss_reduce | |
| ) | |
| if self.training and self.classifier_free_guidance: | |
| mask_indices = [ | |
| k for k in range(len(waveform)) | |
| if random.random() < self.cfg_drop_ratio | |
| ] | |
| if len(mask_indices) > 0: | |
| content[mask_indices] = 0 | |
| noisy_latent, target, timesteps = self.get_input_target_and_timesteps( | |
| latent, | |
| training = self.training | |
| ) | |
| pred: torch.Tensor = self.backbone( | |
| x=noisy_latent, | |
| timesteps=timesteps, | |
| context=content, | |
| x_mask=latent_mask, | |
| context_mask=content_mask, | |
| ) | |
| pred = pred.transpose(1, self.autoencoder.time_dim) | |
| target = target.transpose(1, self.autoencoder.time_dim) | |
| diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none") | |
| diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce) | |
| return { | |
| "diff_loss": diff_loss, | |
| "global_duration_loss": global_duration_loss, | |
| } | |
| def inference( | |
| self, | |
| content: list[Any], | |
| condition: list[Any], | |
| task: list[str], | |
| is_time_aligned: Sequence[bool], | |
| instruction: torch.Tensor, | |
| instruction_lengths: torch.Tensor, | |
| num_steps: int = 20, | |
| sway_sampling_coef: float | None = -1.0, | |
| guidance_scale: float = 3.0, | |
| disable_progress=True, | |
| use_gt_duration: bool = False, | |
| **kwargs | |
| ): | |
| device = self.dummy_param.device | |
| classifier_free_guidance = guidance_scale > 1.0 | |
| ( | |
| content, | |
| content_mask, | |
| global_duration_pred, | |
| local_duration_pred, | |
| _, | |
| ) = self.encode_content_with_instruction( | |
| content, task, device, instruction, instruction_lengths | |
| ) | |
| batch_size = content.size(0) | |
| if use_gt_duration: | |
| raise NotImplementedError( | |
| "Using ground truth global duration only is not implemented yet" | |
| ) | |
| # prepare global duration | |
| global_duration = self.prepare_global_duration( | |
| global_duration_pred, | |
| local_duration_pred, | |
| is_time_aligned, | |
| use_local=False | |
| ) | |
| latent_length = torch.round(global_duration * self.latent_token_rate) | |
| latent_mask = create_mask_from_length(latent_length).to(device) | |
| max_latent_length = latent_mask.sum(1).max().item() | |
| # prepare latent and noise | |
| if classifier_free_guidance: | |
| uncond_context = torch.zeros_like(content) | |
| uncond_content_mask = content_mask.detach().clone() | |
| context = torch.cat([uncond_context, content]) | |
| context_mask = torch.cat([uncond_content_mask, content_mask]) | |
| else: | |
| context = content | |
| context_mask = content_mask | |
| latent_shape = tuple( | |
| max_latent_length if dim is None else dim | |
| for dim in self.autoencoder.latent_shape | |
| ) | |
| shape = (batch_size, *latent_shape) | |
| latent = randn_tensor( | |
| shape, generator=None, device=device, dtype=content.dtype | |
| ) | |
| if not sway_sampling_coef: | |
| sigmas = np.linspace(1.0, 1 / num_steps, num_steps) | |
| else: | |
| t = torch.linspace(0, 1, num_steps + 1) | |
| t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) | |
| sigmas = 1 - t | |
| timesteps, num_steps = self.retrieve_timesteps( | |
| num_steps, device, timesteps=None, sigmas=sigmas | |
| ) | |
| latent = self.iterative_denoise( | |
| latent=latent, | |
| timesteps=timesteps, | |
| num_steps=num_steps, | |
| verbose=not disable_progress, | |
| cfg=classifier_free_guidance, | |
| cfg_scale=guidance_scale, | |
| backbone_input={ | |
| "x_mask": latent_mask, | |
| "context": context, | |
| "context_mask": context_mask, | |
| } | |
| ) | |
| waveform = self.autoencoder.decode(latent) | |
| return waveform | |
| class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching): | |
| def __init__( | |
| self, | |
| autoencoder: AutoEncoderBase, | |
| content_encoder: ContentEncoder, | |
| content_adapter: ContentAdapterBase, | |
| backbone: nn.Module, | |
| content_dim: int, | |
| frame_resolution: float, | |
| duration_offset: float = 1.0, | |
| cfg_drop_ratio: float = 0.2, | |
| sample_strategy: str = 'normal', | |
| num_train_steps: int = 1000 | |
| ): | |
| super().__init__( | |
| autoencoder=autoencoder, | |
| content_encoder=content_encoder, | |
| content_adapter=content_adapter, | |
| backbone=backbone, | |
| content_dim=content_dim, | |
| frame_resolution=frame_resolution, | |
| duration_offset=duration_offset, | |
| cfg_drop_ratio=cfg_drop_ratio, | |
| sample_strategy=sample_strategy, | |
| num_train_steps=num_train_steps | |
| ) | |
| DurationAdapterMixin.__init__( | |
| self, | |
| latent_token_rate=autoencoder.latent_token_rate, | |
| offset=duration_offset, | |
| frame_resolution=frame_resolution | |
| ) | |
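| # Learned placeholder embeddings: dummy_ta_embed stands in for the | |
| # time-aligned context of samples that have no time-aligned content, and | |
| # dummy_nta_embed replaces the cross-attention context of samples that are | |
| # time-aligned (see get_backbone_input below). | |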
| self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim)) | |
| self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim)) | |
| def get_backbone_input( | |
| self, target_length: int, content: torch.Tensor, | |
| content_mask: torch.Tensor, time_aligned_content: torch.Tensor, | |
| length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor | |
| ): | |
| # TODO: compatibility with 2D spectrogram VAE | |
| time_aligned_content = trim_or_pad_length( | |
| time_aligned_content, target_length, 1 | |
| ) | |
| length_aligned_content = trim_or_pad_length( | |
| length_aligned_content, target_length, 1 | |
| ) | |
| # time_aligned_content: from monotonically aligned input that needs frame expansion (e.g. phonemes) | |
| # length_aligned_content: from input that is already frame-aligned (e.g. f0/energy) | |
| time_aligned_content = time_aligned_content + length_aligned_content | |
| time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to( | |
| time_aligned_content.dtype | |
| ) | |
| context = content | |
| context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype) | |
| # only use the first dummy non time aligned embedding | |
| context_mask = content_mask.detach().clone() | |
| context_mask[is_time_aligned, 1:] = False | |
| # truncate dummy non time aligned context | |
| if is_time_aligned.sum().item() < content.size(0): | |
| trunc_nta_length = content_mask[~is_time_aligned].sum(1).max() | |
| else: | |
| trunc_nta_length = content.size(1) | |
| context = context[:, :trunc_nta_length] | |
| context_mask = context_mask[:, :trunc_nta_length] | |
| return context, context_mask, time_aligned_content | |
| def forward( | |
| self, | |
| content: list[Any], | |
| duration: Sequence[float], | |
| task: list[str], | |
| is_time_aligned: Sequence[bool], | |
| waveform: torch.Tensor, | |
| waveform_lengths: torch.Tensor, | |
| instruction: torch.Tensor, | |
| instruction_lengths: torch.Tensor, | |
| loss_reduce: bool = True, | |
| **kwargs | |
| ): | |
| device = self.dummy_param.device | |
| loss_reduce = self.training or loss_reduce  # always reduce the loss during training | |
| self.autoencoder.eval() | |
| with torch.no_grad(): | |
| latent, latent_mask = self.autoencoder.encode( | |
| waveform.unsqueeze(1), waveform_lengths | |
| ) | |
| ( | |
| content, content_mask, global_duration_pred, local_duration_pred, | |
| length_aligned_content | |
| ) = self.encode_content_with_instruction( | |
| content, task, device, instruction, instruction_lengths | |
| ) | |
| # drop duration predictions beyond the longest time-aligned content; those positions are unused | |
| if is_time_aligned.sum() > 0: | |
| trunc_ta_length = content_mask[is_time_aligned].sum(1).max() | |
| else: | |
| trunc_ta_length = content.size(1) | |
| # duration loss | |
| local_duration_pred = local_duration_pred[:, :trunc_ta_length] | |
| ta_content_mask = content_mask[:, :trunc_ta_length] | |
| local_duration_loss = self.get_local_duration_loss( | |
| duration, | |
| local_duration_pred, | |
| ta_content_mask, | |
| is_time_aligned, | |
| reduce=loss_reduce | |
| ) | |
| global_duration_loss = self.get_global_duration_loss( | |
| global_duration_pred, latent_mask, reduce=loss_reduce | |
| ) | |
| # -------------------------------------------------------------------- | |
| # prepare latent and noise | |
| # -------------------------------------------------------------------- | |
| noisy_latent, target, timesteps = self.get_input_target_and_timesteps( | |
| latent, | |
| training = self.training | |
| ) | |
| # -------------------------------------------------------------------- | |
| # duration adapter | |
| # -------------------------------------------------------------------- | |
| if is_time_aligned.sum() == 0 and \ | |
| duration.size(1) < content_mask.size(1): | |
| duration = F.pad( | |
| duration, (0, content_mask.size(1) - duration.size(1)) | |
| ) | |
| time_aligned_content, _ = self.expand_by_duration( | |
| x=content[:, :trunc_ta_length], | |
| content_mask=ta_content_mask, | |
| local_duration=duration, | |
| ) | |
| # -------------------------------------------------------------------- | |
| # prepare input to the backbone | |
| # -------------------------------------------------------------------- | |
| # TODO: compatibility with 2D spectrogram VAE | |
| latent_length = noisy_latent.size(self.autoencoder.time_dim) | |
| context, context_mask, time_aligned_content = self.get_backbone_input( | |
| latent_length, content, content_mask, time_aligned_content, | |
| length_aligned_content, is_time_aligned | |
| ) | |
| # -------------------------------------------------------------------- | |
| # classifier free guidance | |
| # -------------------------------------------------------------------- | |
| if self.training and self.classifier_free_guidance: | |
| mask_indices = [ | |
| k for k in range(len(waveform)) | |
| if random.random() < self.cfg_drop_ratio | |
| ] | |
| if len(mask_indices) > 0: | |
| context[mask_indices] = 0 | |
| time_aligned_content[mask_indices] = 0 | |
| pred: torch.Tensor = self.backbone( | |
| x=noisy_latent, | |
| x_mask=latent_mask, | |
| timesteps=timesteps, | |
| context=context, | |
| context_mask=context_mask, | |
| time_aligned_context=time_aligned_content, | |
| ) | |
| pred = pred.transpose(1, self.autoencoder.time_dim) | |
| target = target.transpose(1, self.autoencoder.time_dim) | |
| diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none") | |
| diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce) | |
| return { | |
| "diff_loss": diff_loss, | |
| "local_duration_loss": local_duration_loss, | |
| "global_duration_loss": global_duration_loss, | |
| } | |
| def inference( | |
| self, | |
| content: list[Any], | |
| task: list[str], | |
| is_time_aligned: Sequence[bool], | |
| instruction: torch.Tensor, | |
| instruction_lengths: Sequence[int], | |
| num_steps: int = 20, | |
| sway_sampling_coef: float | None = -1.0, | |
| guidance_scale: float = 3.0, | |
| disable_progress: bool = True, | |
| use_gt_duration: bool = False, | |
| **kwargs | |
| ): | |
| device = self.dummy_param.device | |
| classifier_free_guidance = guidance_scale > 1.0 | |
| ( | |
| content, content_mask, global_duration_pred, local_duration_pred, | |
| length_aligned_content | |
| ) = self.encode_content_with_instruction( | |
| content, task, device, instruction, instruction_lengths | |
| ) | |
| # print("content std: ", content.std()) | |
| batch_size = content.size(0) | |
| # truncate dummy time aligned duration prediction | |
| is_time_aligned = torch.as_tensor(is_time_aligned) | |
| if is_time_aligned.sum() > 0: | |
| trunc_ta_length = content_mask[is_time_aligned].sum(1).max() | |
| else: | |
| trunc_ta_length = content.size(1) | |
| # prepare local duration | |
| local_duration = self.prepare_local_duration( | |
| local_duration_pred, content_mask | |
| ) | |
| local_duration = local_duration[:, :trunc_ta_length] | |
| # use ground truth duration | |
| if use_gt_duration and "duration" in kwargs: | |
| local_duration = torch.as_tensor(kwargs["duration"]).to(device) | |
| # prepare global duration | |
| global_duration = self.prepare_global_duration( | |
| global_duration_pred, local_duration, is_time_aligned | |
| ) | |
| # -------------------------------------------------------------------- | |
| # duration adapter | |
| # -------------------------------------------------------------------- | |
| time_aligned_content, latent_mask = self.expand_by_duration( | |
| x=content[:, :trunc_ta_length], | |
| content_mask=content_mask[:, :trunc_ta_length], | |
| local_duration=local_duration, | |
| global_duration=global_duration, | |
| ) | |
| context, context_mask, time_aligned_content = self.get_backbone_input( | |
| target_length=time_aligned_content.size(1), | |
| content=content, | |
| content_mask=content_mask, | |
| time_aligned_content=time_aligned_content, | |
| length_aligned_content=length_aligned_content, | |
| is_time_aligned=is_time_aligned | |
| ) | |
| # -------------------------------------------------------------------- | |
| # prepare unconditional input | |
| # -------------------------------------------------------------------- | |
| if classifier_free_guidance: | |
| uncond_time_aligned_content = torch.zeros_like( | |
| time_aligned_content | |
| ) | |
| uncond_context = torch.zeros_like(context) | |
| uncond_context_mask = context_mask.detach().clone() | |
| time_aligned_content = torch.cat([ | |
| uncond_time_aligned_content, time_aligned_content | |
| ]) | |
| context = torch.cat([uncond_context, context]) | |
| context_mask = torch.cat([uncond_context_mask, context_mask]) | |
| latent_mask = torch.cat([ | |
| latent_mask, latent_mask.detach().clone() | |
| ]) | |
| # -------------------------------------------------------------------- | |
| # prepare input to the backbone | |
| # -------------------------------------------------------------------- | |
| latent_length = latent_mask.sum(1).max().item() | |
| latent_shape = tuple( | |
| latent_length if dim is None else dim | |
| for dim in self.autoencoder.latent_shape | |
| ) | |
| shape = (batch_size, *latent_shape) | |
| latent = randn_tensor( | |
| shape, generator=None, device=device, dtype=content.dtype | |
| ) | |
| if not sway_sampling_coef: | |
| sigmas = np.linspace(1.0, 1 / num_steps, num_steps) | |
| else: | |
| t = torch.linspace(0, 1, num_steps + 1) | |
| t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) | |
| sigmas = 1 - t | |
| timesteps, num_steps = self.retrieve_timesteps( | |
| num_steps, device, timesteps=None, sigmas=sigmas | |
| ) | |
| latent = self.iterative_denoise( | |
| latent=latent, | |
| timesteps=timesteps, | |
| num_steps=num_steps, | |
| verbose=not disable_progress, | |
| cfg=classifier_free_guidance, | |
| cfg_scale=guidance_scale, | |
| backbone_input={ | |
| "x_mask": latent_mask, | |
| "context": context, | |
| "context_mask": context_mask, | |
| "time_aligned_context": time_aligned_content, | |
| } | |
| ) | |
| waveform = self.autoencoder.decode(latent) | |
| return waveform | |
| class DoubleContentAudioFlowMatching(DummyContentAudioFlowMatching): | |
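| # Variant that keeps both context streams for every sample: unlike | |
| # DummyContentAudioFlowMatching, no dummy embeddings are substituted and the | |
| # cross-attention context is neither masked nor truncated. | |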
| def get_backbone_input( | |
| self, target_length: int, content: torch.Tensor, | |
| content_mask: torch.Tensor, time_aligned_content: torch.Tensor, | |
| length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor | |
| ): | |
| # TODO: compatibility with 2D spectrogram VAE | |
| time_aligned_content = trim_or_pad_length( | |
| time_aligned_content, target_length, 1 | |
| ) | |
| length_aligned_content = trim_or_pad_length( | |
| length_aligned_content, target_length, 1 | |
| ) | |
| # time_aligned_content: from monotonically aligned input that needs frame expansion (e.g. phonemes) | |
| # length_aligned_content: from input that is already frame-aligned (e.g. f0/energy) | |
| time_aligned_content = time_aligned_content + length_aligned_content | |
| context = content | |
| context_mask = content_mask.detach().clone() | |
| return context, context_mask, time_aligned_content | |
| class HybridContentAudioFlowMatching(DummyContentAudioFlowMatching): | |
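| # Hybrid variant: like DummyContentAudioFlowMatching it substitutes the dummy | |
| # time-aligned embedding for samples without time-aligned content, but like | |
| # DoubleContentAudioFlowMatching it keeps the full cross-attention context for | |
| # all samples. | |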
| def get_backbone_input( | |
| self, target_length: int, content: torch.Tensor, | |
| content_mask: torch.Tensor, time_aligned_content: torch.Tensor, | |
| length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor | |
| ): | |
| # TODO: compatibility with 2D spectrogram VAE | |
| time_aligned_content = trim_or_pad_length( | |
| time_aligned_content, target_length, 1 | |
| ) | |
| length_aligned_content = trim_or_pad_length( | |
| length_aligned_content, target_length, 1 | |
| ) | |
| # time_aligned_content: from monotonically aligned input that needs frame expansion (e.g. phonemes) | |
| # length_aligned_content: from input that is already frame-aligned (e.g. f0/energy) | |
| time_aligned_content = time_aligned_content + length_aligned_content | |
| time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to( | |
| time_aligned_content.dtype | |
| ) | |
| context = content | |
| context_mask = content_mask.detach().clone() | |
| return context, context_mask, time_aligned_content | |