from typing import Sequence import random from typing import Any from pathlib import Path from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F import diffusers.schedulers as noise_schedulers from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils.torch_utils import randn_tensor from models.autoencoder.autoencoder_base import AutoEncoderBase from models.content_encoder.content_encoder import ContentEncoder from models.content_adapter import ContentAdapterBase from models.common import LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase from utils.torch_utilities import ( create_alignment_path, create_mask_from_length, loss_with_mask, trim_or_pad_length ) from safetensors.torch import load_file class DiffusionMixin: def __init__( self, noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1", snr_gamma: float = None, cfg_drop_ratio: float = 0.2 ) -> None: self.noise_scheduler_name = noise_scheduler_name self.snr_gamma = snr_gamma self.classifier_free_guidance = cfg_drop_ratio > 0.0 self.cfg_drop_ratio = cfg_drop_ratio self.noise_scheduler = noise_schedulers.DDPMScheduler.from_pretrained( self.noise_scheduler_name, subfolder="scheduler" ) def compute_snr(self, timesteps) -> torch.Tensor: """ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 """ alphas_cumprod = self.noise_scheduler.alphas_cumprod sqrt_alphas_cumprod = alphas_cumprod**0.5 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod)**0.5 # Expand the tensors. # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device )[timesteps].float() while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] alpha = sqrt_alphas_cumprod.expand(timesteps.shape) sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( device=timesteps.device )[timesteps].float() while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) # Compute SNR. snr = (alpha / sigma)**2 return snr def get_timesteps( self, batch_size: int, device: torch.device, training: bool = True ) -> torch.Tensor: if training: timesteps = torch.randint( 0, self.noise_scheduler.config.num_train_timesteps, (batch_size, ), device=device ) else: # validation on half of the total timesteps timesteps = (self.noise_scheduler.config.num_train_timesteps // 2) * torch.ones((batch_size, ), dtype=torch.int64, device=device) timesteps = timesteps.long() return timesteps def get_target( self, latent: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor ) -> torch.Tensor: """ Get the target for loss depending on the prediction type """ if self.noise_scheduler.config.prediction_type == "epsilon": target = noise elif self.noise_scheduler.config.prediction_type == "v_prediction": target = self.noise_scheduler.get_velocity( latent, noise, timesteps ) else: raise ValueError( f"Unknown prediction type {self.noise_scheduler.config.prediction_type}" ) return target def loss_with_snr( self, pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, mask: torch.Tensor, loss_reduce: bool = True, ) -> torch.Tensor: if self.snr_gamma is None: loss = F.mse_loss(pred.float(), target.float(), reduction="none") loss = loss_with_mask(loss, mask, reduce=loss_reduce) else: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Adapted from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L1006 snr = self.compute_snr(timesteps) mse_loss_weights = torch.stack( [ snr, self.snr_gamma * torch.ones_like(timesteps), ], dim=1, ).min(dim=1)[0] # division by (snr + 1) does not work well, not clear about the reason mse_loss_weights = mse_loss_weights / snr loss = F.mse_loss(pred.float(), target.float(), reduction="none") loss = loss_with_mask(loss, mask, reduce=False) * mse_loss_weights if loss_reduce: loss = loss.mean() return loss def rescale_cfg( self, pred_cond: torch.Tensor, pred_cfg: torch.Tensor, guidance_rescale: float ): """ Rescale `pred_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ std_cond = pred_cond.std( dim=list(range(1, pred_cond.ndim)), keepdim=True ) std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True) pred_rescaled = pred_cfg * (std_cond / std_cfg) pred_cfg = guidance_rescale * pred_rescaled + ( 1 - guidance_rescale ) * pred_cfg return pred_cfg class CrossAttentionAudioDiffusion( LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase, DiffusionMixin ): def __init__( self, autoencoder: AutoEncoderBase, content_encoder: ContentEncoder, content_adapter: ContentAdapterBase, backbone: nn.Module, duration_offset: float = 1.0, noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1", snr_gamma: float = None, cfg_drop_ratio: float = 0.2, ): nn.Module.__init__(self) DiffusionMixin.__init__( self, noise_scheduler_name, snr_gamma, cfg_drop_ratio ) self.autoencoder = autoencoder for param in self.autoencoder.parameters(): param.requires_grad = False self.content_encoder = content_encoder self.content_encoder.audio_encoder.model = self.autoencoder self.content_adapter = content_adapter self.backbone = backbone self.duration_offset = duration_offset self.dummy_param = nn.Parameter(torch.empty(0)) def forward( self, content: list[Any], task: list[str], waveform: torch.Tensor, waveform_lengths: torch.Tensor, instruction: torch.Tensor, instruction_lengths: Sequence[int], **kwargs ): device = self.dummy_param.device num_train_timesteps = self.noise_scheduler.config.num_train_timesteps self.noise_scheduler.set_timesteps(num_train_timesteps, device=device) self.autoencoder.eval() with torch.no_grad(): latent, latent_mask = self.autoencoder.encode( waveform.unsqueeze(1), waveform_lengths ) content_output: dict[ str, torch.Tensor] = self.content_encoder.encode_content( content, task, device=device ) content, content_mask = content_output["content"], content_output[ "content_mask"] instruction_mask = create_mask_from_length(instruction_lengths) content, content_mask, global_duration_pred, _ = \ self.content_adapter(content, content_mask, instruction, instruction_mask) global_duration_target = torch.log( latent_mask.sum(1) / self.autoencoder.latent_token_rate + self.duration_offset ) global_duration_loss = F.mse_loss( global_duration_target, global_duration_pred ) if self.training and self.classifier_free_guidance: mask_indices = [ k for k in range(len(waveform)) if random.random() < self.cfg_drop_ratio ] if len(mask_indices) > 0: content[mask_indices] = 0 batch_size = latent.shape[0] timesteps = self.get_timesteps(batch_size, device, self.training) noise = torch.randn_like(latent) noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps) target = self.get_target(latent, noise, timesteps) pred: torch.Tensor = self.backbone( x=noisy_latent, timesteps=timesteps, context=content, x_mask=latent_mask, context_mask=content_mask ) pred = pred.transpose(1, self.autoencoder.time_dim) target = target.transpose(1, self.autoencoder.time_dim) diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask) return { "diff_loss": diff_loss, "global_duration_loss": global_duration_loss, } @torch.no_grad() def inference( self, content: list[Any], condition: list[Any], task: list[str], instruction: torch.Tensor, instruction_lengths: Sequence[int], scheduler: SchedulerMixin, num_steps: int = 20, guidance_scale: float = 3.0, guidance_rescale: float = 0.0, disable_progress: bool = True, **kwargs ): device = self.dummy_param.device classifier_free_guidance = guidance_scale > 1.0 content_output: dict[ str, torch.Tensor] = self.content_encoder.encode_content( content, task, device=device ) content, content_mask = content_output["content"], content_output[ "content_mask"] instruction_mask = create_mask_from_length(instruction_lengths) content, content_mask, global_duration_pred, _ = \ self.content_adapter(content, content_mask, instruction, instruction_mask) batch_size = content.size(0) if classifier_free_guidance: uncond_content = torch.zeros_like(content) uncond_content_mask = content_mask.detach().clone() content = torch.cat([uncond_content, content]) content_mask = torch.cat([uncond_content_mask, content_mask]) scheduler.set_timesteps(num_steps, device=device) timesteps = scheduler.timesteps global_duration_pred = torch.exp( global_duration_pred ) - self.duration_offset global_duration_pred *= self.autoencoder.latent_token_rate global_duration_pred = torch.round(global_duration_pred) latent_shape = tuple( int(global_duration_pred.max().item()) if dim is None else dim for dim in self.autoencoder.latent_shape ) latent = self.prepare_latent( batch_size, scheduler, latent_shape, content.dtype, device ) latent_mask = create_mask_from_length(global_duration_pred).to( content_mask.device ) if classifier_free_guidance: latent_mask = torch.cat([latent_mask, latent_mask]) num_warmup_steps = len(timesteps) - num_steps * scheduler.order progress_bar = tqdm(range(num_steps), disable=disable_progress) for i, timestep in enumerate(timesteps): # expand the latent if we are doing classifier free guidance latent_input = torch.cat([latent, latent] ) if classifier_free_guidance else latent latent_input = scheduler.scale_model_input(latent_input, timestep) noise_pred = self.backbone( x=latent_input, x_mask=latent_mask, timesteps=timestep, context=content, context_mask=content_mask, ) # perform guidance if classifier_free_guidance: noise_pred_uncond, noise_pred_content = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_content - noise_pred_uncond ) if guidance_rescale != 0.0: noise_pred = self.rescale_cfg( noise_pred_content, noise_pred, guidance_rescale ) # compute the previous noisy sample x_t -> x_t-1 latent = scheduler.step(noise_pred, timestep, latent).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): progress_bar.update(1) waveform = self.autoencoder.decode(latent) return waveform def prepare_latent( self, batch_size: int, scheduler: SchedulerMixin, latent_shape: Sequence[int], dtype: torch.dtype, device: str ): shape = (batch_size, *latent_shape) latent = randn_tensor( shape, generator=None, device=device, dtype=dtype ) # scale the initial noise by the standard deviation required by the scheduler latent = latent * scheduler.init_noise_sigma return latent class SingleTaskCrossAttentionAudioDiffusion(CrossAttentionAudioDiffusion ): def __init__( self, autoencoder: AutoEncoderBase, content_encoder: ContentEncoder, backbone: nn.Module, pretrained_ckpt: str | Path = None, noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1", snr_gamma: float = None, cfg_drop_ratio: float = 0.2, ): nn.Module.__init__(self) DiffusionMixin.__init__( self, noise_scheduler_name, snr_gamma, cfg_drop_ratio ) self.autoencoder = autoencoder for param in self.autoencoder.parameters(): param.requires_grad = False self.backbone = backbone if pretrained_ckpt is not None: pretrained_state_dict = load_file(pretrained_ckpt) self.load_pretrained(pretrained_state_dict) self.content_encoder = content_encoder #self.content_encoder.audio_encoder.model = self.autoencoder self.dummy_param = nn.Parameter(torch.empty(0)) def forward( self, content: list[Any], condition: list[Any], task: list[str], waveform: torch.Tensor, waveform_lengths: torch.Tensor, loss_reduce: bool = True, **kwargs ): loss_reduce = self.training or (loss_reduce and not self.training) device = self.dummy_param.device num_train_timesteps = self.noise_scheduler.config.num_train_timesteps self.noise_scheduler.set_timesteps(num_train_timesteps, device=device) self.autoencoder.eval() with torch.no_grad(): latent, latent_mask = self.autoencoder.encode( waveform.unsqueeze(1), waveform_lengths ) content_output: dict[ str, torch.Tensor] = self.content_encoder.encode_content( content, task, device=device ) content, content_mask = content_output["content"], content_output[ "content_mask"] if self.training and self.classifier_free_guidance: mask_indices = [ k for k in range(len(waveform)) if random.random() < self.cfg_drop_ratio ] if len(mask_indices) > 0: content[mask_indices] = 0 batch_size = latent.shape[0] timesteps = self.get_timesteps(batch_size, device, self.training) noise = torch.randn_like(latent) noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps) target = self.get_target(latent, noise, timesteps) pred: torch.Tensor = self.backbone( x=noisy_latent, timesteps=timesteps, context=content, x_mask=latent_mask, context_mask=content_mask ) pred = pred.transpose(1, self.autoencoder.time_dim) target = target.transpose(1, self.autoencoder.time_dim) diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask, loss_reduce=loss_reduce) return { "diff_loss": diff_loss, } @torch.no_grad() def inference( self, content: list[Any], condition: list[Any], task: list[str], scheduler: SchedulerMixin, latent_shape: Sequence[int], num_steps: int = 20, guidance_scale: float = 3.0, guidance_rescale: float = 0.0, disable_progress: bool = True, **kwargs ): device = self.dummy_param.device classifier_free_guidance = guidance_scale > 1.0 content_output: dict[ str, torch.Tensor] = self.content_encoder.encode_content( content, task, device=device ) content, content_mask = content_output["content"], content_output[ "content_mask"] batch_size = content.size(0) if classifier_free_guidance: uncond_content = torch.zeros_like(content) uncond_content_mask = content_mask.detach().clone() content = torch.cat([uncond_content, content]) content_mask = torch.cat([uncond_content_mask, content_mask]) scheduler.set_timesteps(num_steps, device=device) timesteps = scheduler.timesteps latent = self.prepare_latent( batch_size, scheduler, latent_shape, content.dtype, device ) num_warmup_steps = len(timesteps) - num_steps * scheduler.order progress_bar = tqdm(range(num_steps), disable=disable_progress) for i, timestep in enumerate(timesteps): # expand the latent if we are doing classifier free guidance latent_input = torch.cat([latent, latent] ) if classifier_free_guidance else latent latent_input = scheduler.scale_model_input(latent_input, timestep) noise_pred = self.backbone( x=latent_input, timesteps=timestep, context=content, context_mask=content_mask, ) # perform guidance if classifier_free_guidance: noise_pred_uncond, noise_pred_content = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_content - noise_pred_uncond ) if guidance_rescale != 0.0: noise_pred = self.rescale_cfg( noise_pred_content, noise_pred, guidance_rescale ) # compute the previous noisy sample x_t -> x_t-1 latent = scheduler.step(noise_pred, timestep, latent).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): progress_bar.update(1) waveform = self.autoencoder.decode(latent) return waveform class DummyContentAudioDiffusion(CrossAttentionAudioDiffusion): def __init__( self, autoencoder: AutoEncoderBase, content_encoder: ContentEncoder, content_adapter: ContentAdapterBase, backbone: nn.Module, content_dim: int, frame_resolution: float, duration_offset: float = 1.0, noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1", snr_gamma: float = None, cfg_drop_ratio: float = 0.2, ): """ Args: autoencoder: Pretrained audio autoencoder that encodes raw waveforms into latent space and decodes latents back to waveforms. content_encoder: Module that produces content embeddings (e.g., from text, MIDI, or other modalities) used to guide the diffusion. content_adapter (ContentAdapterBase): Adapter module that fuses task instruction embeddings and content embeddings, and performs duration prediction for time-aligned tasks. backbone: U‑Net or Transformer backbone that performs the core denoising operations in latent space. content_dim: Dimension of the content embeddings produced by the `content_encoder` and `content_adapter`. frame_resolution: Time resolution, in seconds, of each content frame when predicting duration alignment. Used when calculating duration loss. duration_offset: A small positive offset (frame number) added to predicted durations to ensure numerical stability of log-scaled duration prediction. noise_scheduler_name: Identifier of the pretrained noise scheduler to use. snr_gamma: Clipping value in min-SNR diffusion loss weighting strategy. cfg_drop_ratio: Probability of dropping the content conditioning during training to support CFG. """ super().__init__( autoencoder=autoencoder, content_encoder=content_encoder, content_adapter=content_adapter, backbone=backbone, duration_offset=duration_offset, noise_scheduler_name=noise_scheduler_name, snr_gamma=snr_gamma, cfg_drop_ratio=cfg_drop_ratio, ) self.frame_resolution = frame_resolution self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim)) self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim)) def forward( self, content, duration, task, is_time_aligned, waveform, waveform_lengths, instruction, instruction_lengths, **kwargs ): device = self.dummy_param.device num_train_timesteps = self.noise_scheduler.config.num_train_timesteps self.noise_scheduler.set_timesteps(num_train_timesteps, device=device) self.autoencoder.eval() with torch.no_grad(): latent, latent_mask = self.autoencoder.encode( waveform.unsqueeze(1), waveform_lengths ) # content: (B, L, E) content_output: dict[ str, torch.Tensor] = self.content_encoder.encode_content( content, task, device=device ) length_aligned_content = content_output["length_aligned_content"] content, content_mask = content_output["content"], content_output[ "content_mask"] instruction_mask = create_mask_from_length(instruction_lengths) content, content_mask, global_duration_pred, local_duration_pred = \ self.content_adapter(content, content_mask, instruction, instruction_mask) n_frames = torch.round(duration / self.frame_resolution) local_duration_target = torch.log(n_frames + self.duration_offset) global_duration_target = torch.log( latent_mask.sum(1) / self.autoencoder.latent_token_rate + self.duration_offset ) # truncate unused non time aligned duration prediction if is_time_aligned.sum() > 0: trunc_ta_length = content_mask[is_time_aligned].sum(1).max() else: trunc_ta_length = content.size(1) # local duration loss local_duration_pred = local_duration_pred[:, :trunc_ta_length] ta_content_mask = content_mask[:, :trunc_ta_length] local_duration_target = local_duration_target.to( dtype=local_duration_pred.dtype ) local_duration_loss = loss_with_mask( (local_duration_target - local_duration_pred)**2, ta_content_mask, reduce=False ) local_duration_loss *= is_time_aligned if is_time_aligned.sum().item() == 0: local_duration_loss *= 0.0 local_duration_loss = local_duration_loss.mean() else: local_duration_loss = local_duration_loss.sum( ) / is_time_aligned.sum() # global duration loss global_duration_loss = F.mse_loss( global_duration_target, global_duration_pred ) # -------------------------------------------------------------------- # prepare latent and diffusion-related noise # -------------------------------------------------------------------- batch_size = latent.shape[0] timesteps = self.get_timesteps(batch_size, device, self.training) noise = torch.randn_like(latent) noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps) target = self.get_target(latent, noise, timesteps) # -------------------------------------------------------------------- # duration adapter # -------------------------------------------------------------------- if is_time_aligned.sum() == 0 and \ duration.size(1) < content_mask.size(1): # for non time-aligned tasks like TTA, `duration` is dummy one duration = F.pad( duration, (0, content_mask.size(1) - duration.size(1)) ) n_latents = torch.round(duration * self.autoencoder.latent_token_rate) # content_mask: [B, L], helper_latent_mask: [B, T] helper_latent_mask = create_mask_from_length(n_latents.sum(1)).to( content_mask.device ) attn_mask = ta_content_mask.unsqueeze( -1 ) * helper_latent_mask.unsqueeze(1) # attn_mask: [B, L, T] align_path = create_alignment_path(n_latents, attn_mask) time_aligned_content = content[:, :trunc_ta_length] time_aligned_content = torch.matmul( align_path.transpose(1, 2).to(content.dtype), time_aligned_content ) # (B, T, L) x (B, L, E) -> (B, T, E) # -------------------------------------------------------------------- # prepare input to the backbone # -------------------------------------------------------------------- # TODO compatility for 2D spectrogram VAE latent_length = noisy_latent.size(self.autoencoder.time_dim) time_aligned_content = trim_or_pad_length( time_aligned_content, latent_length, 1 ) length_aligned_content = trim_or_pad_length( length_aligned_content, latent_length, 1 ) # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme) # length_aligned_content: from aligned input (f0/energy) time_aligned_content = time_aligned_content + length_aligned_content time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to( time_aligned_content.dtype ) context = content context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype) # only use the first dummy non time aligned embedding context_mask = content_mask.detach().clone() context_mask[is_time_aligned, 1:] = False # truncate dummy non time aligned context if is_time_aligned.sum().item() < batch_size: trunc_nta_length = content_mask[~is_time_aligned].sum(1).max() else: trunc_nta_length = content.size(1) context = context[:, :trunc_nta_length] context_mask = context_mask[:, :trunc_nta_length] # -------------------------------------------------------------------- # classifier free guidance # -------------------------------------------------------------------- if self.training and self.classifier_free_guidance: mask_indices = [ k for k in range(len(waveform)) if random.random() < self.cfg_drop_ratio ] if len(mask_indices) > 0: context[mask_indices] = 0 time_aligned_content[mask_indices] = 0 pred: torch.Tensor = self.backbone( x=noisy_latent, timesteps=timesteps, time_aligned_context=time_aligned_content, context=context, x_mask=latent_mask, context_mask=context_mask ) pred = pred.transpose(1, self.autoencoder.time_dim) target = target.transpose(1, self.autoencoder.time_dim) diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask) return { "diff_loss": diff_loss, "local_duration_loss": local_duration_loss, "global_duration_loss": global_duration_loss } @torch.no_grad() def inference( self, content: list[Any], condition: list[Any], task: list[str], is_time_aligned: list[bool], instruction: torch.Tensor, instruction_lengths: Sequence[int], scheduler: SchedulerMixin, num_steps: int = 20, guidance_scale: float = 3.0, guidance_rescale: float = 0.0, disable_progress: bool = True, use_gt_duration: bool = False, **kwargs ): device = self.dummy_param.device classifier_free_guidance = guidance_scale > 1.0 content_output: dict[ str, torch.Tensor] = self.content_encoder.encode_content( content, task, device=device ) length_aligned_content = content_output["length_aligned_content"] content, content_mask = content_output["content"], content_output[ "content_mask"] instruction_mask = create_mask_from_length(instruction_lengths) content, content_mask, global_duration_pred, local_duration_pred = \ self.content_adapter(content, content_mask, instruction, instruction_mask) scheduler.set_timesteps(num_steps, device=device) timesteps = scheduler.timesteps batch_size = content.size(0) # truncate dummy time aligned duration prediction is_time_aligned = torch.as_tensor(is_time_aligned) if is_time_aligned.sum() > 0: trunc_ta_length = content_mask[is_time_aligned].sum(1).max() else: trunc_ta_length = content.size(1) # prepare local duration local_duration_pred = torch.exp(local_duration_pred) * content_mask local_duration_pred = torch.ceil( local_duration_pred ) - self.duration_offset # frame number in `self.frame_resolution` local_duration_pred = torch.round(local_duration_pred * self.frame_resolution * \ self.autoencoder.latent_token_rate) local_duration_pred = local_duration_pred[:, :trunc_ta_length] # use ground truth duration if use_gt_duration and "duration" in kwargs: local_duration_pred = torch.round( torch.as_tensor(kwargs["duration"]) * self.autoencoder.latent_token_rate ).to(device) # prepare global duration global_duration = local_duration_pred.sum(1) global_duration_pred = torch.exp( global_duration_pred ) - self.duration_offset global_duration_pred *= self.autoencoder.latent_token_rate global_duration_pred = torch.round(global_duration_pred) global_duration[~is_time_aligned] = global_duration_pred[ ~is_time_aligned] # -------------------------------------------------------------------- # duration adapter # -------------------------------------------------------------------- time_aligned_content = content[:, :trunc_ta_length] ta_content_mask = content_mask[:, :trunc_ta_length] latent_mask = create_mask_from_length(global_duration).to( content_mask.device ) attn_mask = ta_content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1) # attn_mask: [B, L, T] align_path = create_alignment_path(local_duration_pred, attn_mask) time_aligned_content = torch.matmul( align_path.transpose(1, 2).to(content.dtype), time_aligned_content ) # (B, T, L) x (B, L, E) -> (B, T, E) time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to( time_aligned_content.dtype ) length_aligned_content = trim_or_pad_length( length_aligned_content, time_aligned_content.size(1), 1 ) time_aligned_content = time_aligned_content + length_aligned_content # -------------------------------------------------------------------- # prepare unconditional input # -------------------------------------------------------------------- context = content context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype) context_mask = content_mask context_mask[ is_time_aligned, 1:] = False # only use the first dummy non time aligned embedding # truncate dummy non time aligned context if is_time_aligned.sum().item() < batch_size: trunc_nta_length = content_mask[~is_time_aligned].sum(1).max() else: trunc_nta_length = content.size(1) context = context[:, :trunc_nta_length] context_mask = context_mask[:, :trunc_nta_length] if classifier_free_guidance: uncond_time_aligned_content = torch.zeros_like( time_aligned_content ) uncond_context = torch.zeros_like(context) uncond_context_mask = context_mask.detach().clone() time_aligned_content = torch.cat([ uncond_time_aligned_content, time_aligned_content ]) context = torch.cat([uncond_context, context]) context_mask = torch.cat([uncond_context_mask, context_mask]) latent_mask = torch.cat([ latent_mask, latent_mask.detach().clone() ]) # -------------------------------------------------------------------- # prepare input to the backbone # -------------------------------------------------------------------- latent_shape = tuple( int(global_duration.max().item()) if dim is None else dim for dim in self.autoencoder.latent_shape ) shape = (batch_size, *latent_shape) latent = randn_tensor( shape, generator=None, device=device, dtype=content.dtype ) # scale the initial noise by the standard deviation required by the scheduler latent = latent * scheduler.init_noise_sigma num_warmup_steps = len(timesteps) - num_steps * scheduler.order progress_bar = tqdm(range(num_steps), disable=disable_progress) # -------------------------------------------------------------------- # iteratively denoising # -------------------------------------------------------------------- for i, timestep in enumerate(timesteps): # expand the latent if we are doing classifier free guidance if classifier_free_guidance: latent_input = torch.cat([latent, latent]) else: latent_input = latent latent_input = scheduler.scale_model_input(latent_input, timestep) noise_pred = self.backbone( x=latent_input, x_mask=latent_mask, timesteps=timestep, time_aligned_context=time_aligned_content, context=context, context_mask=context_mask ) if classifier_free_guidance: noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) if guidance_rescale != 0.0: noise_pred = self.rescale_cfg( noise_pred_cond, noise_pred, guidance_rescale ) # compute the previous noisy sample x_t -> x_t-1 latent = scheduler.step(noise_pred, timestep, latent).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): progress_bar.update(1) progress_bar.close() # TODO variable length decoding, using `latent_mask` waveform = self.autoencoder.decode(latent) return waveform class DoubleContentAudioDiffusion(CrossAttentionAudioDiffusion): def __init__( self, autoencoder: AutoEncoderBase, content_encoder: ContentEncoder, content_adapter: nn.Module, backbone: nn.Module, content_dim: int, frame_resolution: float, duration_offset: float = 1.0, noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1", snr_gamma: float = None, cfg_drop_ratio: float = 0.2, ): super().__init__( autoencoder=autoencoder, content_encoder=content_encoder, content_adapter=content_adapter, backbone=backbone, duration_offset=duration_offset, noise_scheduler_name=noise_scheduler_name, snr_gamma=snr_gamma, cfg_drop_ratio=cfg_drop_ratio ) self.frame_resolution = frame_resolution def forward( self, content, duration, task, is_time_aligned, waveform, waveform_lengths, instruction, instruction_lengths, **kwargs ): device = self.dummy_param.device num_train_timesteps = self.noise_scheduler.config.num_train_timesteps self.noise_scheduler.set_timesteps(num_train_timesteps, device=device) self.autoencoder.eval() with torch.no_grad(): latent, latent_mask = self.autoencoder.encode( waveform.unsqueeze(1), waveform_lengths ) content_output: dict[ str, torch.Tensor] = self.content_encoder.encode_content( content, task, device=device ) length_aligned_content = content_output["length_aligned_content"] content, content_mask = content_output["content"], content_output[ "content_mask"] context_mask = content_mask.detach() instruction_mask = create_mask_from_length(instruction_lengths) content, content_mask, global_duration_pred, local_duration_pred = \ self.content_adapter(content, content_mask, instruction, instruction_mask) # TODO if all non time aligned, content length > duration length n_frames = torch.round(duration / self.frame_resolution) local_duration_target = torch.log(n_frames + self.duration_offset) global_duration_target = torch.log( latent_mask.sum(1) / self.autoencoder.latent_token_rate + self.duration_offset ) # truncate unused non time aligned duration prediction if is_time_aligned.sum() > 0: trunc_ta_length = content_mask[is_time_aligned].sum(1).max() else: trunc_ta_length = content.size(1) # local duration loss local_duration_pred = local_duration_pred[:, :trunc_ta_length] ta_content_mask = content_mask[:, :trunc_ta_length] local_duration_target = local_duration_target.to( dtype=local_duration_pred.dtype ) local_duration_loss = loss_with_mask( (local_duration_target - local_duration_pred)**2, ta_content_mask, reduce=False ) local_duration_loss *= is_time_aligned if is_time_aligned.sum().item() == 0: local_duration_loss *= 0.0 local_duration_loss = local_duration_loss.mean() else: local_duration_loss = local_duration_loss.sum( ) / is_time_aligned.sum() # global duration loss global_duration_loss = F.mse_loss( global_duration_target, global_duration_pred ) # -------------------------------------------------------------------- # prepare latent and diffusion-related noise # -------------------------------------------------------------------- batch_size = latent.shape[0] timesteps = self.get_timesteps(batch_size, device, self.training) noise = torch.randn_like(latent) noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps) target = self.get_target(latent, noise, timesteps) # -------------------------------------------------------------------- # duration adapter # -------------------------------------------------------------------- # content_mask: [B, L], helper_latent_mask: [B, T] if is_time_aligned.sum() == 0 and \ duration.size(1) < content_mask.size(1): # for non time-aligned tasks like TTA, `duration` is dummy one duration = F.pad( duration, (0, content_mask.size(1) - duration.size(1)) ) n_latents = torch.round(duration * self.autoencoder.latent_token_rate) helper_latent_mask = create_mask_from_length(n_latents.sum(1)).to( content_mask.device ) attn_mask = ta_content_mask.unsqueeze( -1 ) * helper_latent_mask.unsqueeze(1) align_path = create_alignment_path(n_latents, attn_mask) time_aligned_content = content[:, :trunc_ta_length] time_aligned_content = torch.matmul( align_path.transpose(1, 2).to(content.dtype), time_aligned_content ) latent_length = noisy_latent.size(self.autoencoder.time_dim) time_aligned_content = trim_or_pad_length( time_aligned_content, latent_length, 1 ) length_aligned_content = trim_or_pad_length( length_aligned_content, latent_length, 1 ) time_aligned_content = time_aligned_content + length_aligned_content context = content # -------------------------------------------------------------------- # classifier free guidance # -------------------------------------------------------------------- if self.training and self.classifier_free_guidance: mask_indices = [ k for k in range(len(waveform)) if random.random() < self.cfg_drop_ratio ] if len(mask_indices) > 0: context[mask_indices] = 0 time_aligned_content[mask_indices] = 0 pred: torch.Tensor = self.backbone( x=noisy_latent, timesteps=timesteps, time_aligned_context=time_aligned_content, context=context, x_mask=latent_mask, context_mask=context_mask, ) pred = pred.transpose(1, self.autoencoder.time_dim) target = target.transpose(1, self.autoencoder.time_dim) diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask) return { "diff_loss": diff_loss, "local_duration_loss": local_duration_loss, "global_duration_loss": global_duration_loss, } @torch.no_grad() def inference( self, content: list[Any], condition: list[Any], task: list[str], is_time_aligned: list[bool], instruction: torch.Tensor, instruction_lengths: Sequence[int], scheduler: SchedulerMixin, num_steps: int = 20, guidance_scale: float = 3.0, guidance_rescale: float = 0.0, disable_progress: bool = True, use_gt_duration: bool = False, **kwargs ): device = self.dummy_param.device classifier_free_guidance = guidance_scale > 1.0 content_output: dict[ str, torch.Tensor] = self.content_encoder.encode_content( content, task, device=device ) length_aligned_content = content_output["length_aligned_content"] content, content_mask = content_output["content"], content_output[ "content_mask"] instruction_mask = create_mask_from_length(instruction_lengths) content, content_mask, global_duration_pred, local_duration_pred = \ self.content_adapter(content, content_mask, instruction, instruction_mask) scheduler.set_timesteps(num_steps, device=device) timesteps = scheduler.timesteps batch_size = content.size(0) # truncate dummy time aligned duration prediction is_time_aligned = torch.as_tensor(is_time_aligned) if is_time_aligned.sum() > 0: trunc_ta_length = content_mask[is_time_aligned].sum(1).max() else: trunc_ta_length = content.size(1) # prepare local duration local_duration_pred = torch.exp(local_duration_pred) * content_mask local_duration_pred = torch.ceil( local_duration_pred ) - self.duration_offset # frame number in `self.frame_resolution` local_duration_pred = torch.round(local_duration_pred * self.frame_resolution * \ self.autoencoder.latent_token_rate) local_duration_pred = local_duration_pred[:, :trunc_ta_length] # use ground truth duration if use_gt_duration and "duration" in kwargs: local_duration_pred = torch.round( torch.as_tensor(kwargs["duration"]) * self.autoencoder.latent_token_rate ).to(device) # prepare global duration global_duration = local_duration_pred.sum(1) global_duration_pred = torch.exp( global_duration_pred ) - self.duration_offset global_duration_pred *= self.autoencoder.latent_token_rate global_duration_pred = torch.round(global_duration_pred) global_duration[~is_time_aligned] = global_duration_pred[ ~is_time_aligned] # -------------------------------------------------------------------- # duration adapter # -------------------------------------------------------------------- time_aligned_content = content[:, :trunc_ta_length] ta_content_mask = content_mask[:, :trunc_ta_length] latent_mask = create_mask_from_length(global_duration).to( content_mask.device ) attn_mask = ta_content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1) # attn_mask: [B, L, T] align_path = create_alignment_path(local_duration_pred, attn_mask) time_aligned_content = torch.matmul( align_path.transpose(1, 2).to(content.dtype), time_aligned_content ) # (B, T, L) x (B, L, E) -> (B, T, E) # time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to( # time_aligned_content.dtype # ) length_aligned_content = trim_or_pad_length( length_aligned_content, time_aligned_content.size(1), 1 ) time_aligned_content = time_aligned_content + length_aligned_content # -------------------------------------------------------------------- # prepare unconditional input # -------------------------------------------------------------------- context = content # context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype) context_mask = content_mask # context_mask[ # is_time_aligned, # 1:] = False # only use the first dummy non time aligned embedding # # truncate dummy non time aligned context # if is_time_aligned.sum().item() < batch_size: # trunc_nta_length = content_mask[~is_time_aligned].sum(1).max() # else: # trunc_nta_length = content.size(1) # context = context[:, :trunc_nta_length] # context_mask = context_mask[:, :trunc_nta_length] if classifier_free_guidance: uncond_time_aligned_content = torch.zeros_like( time_aligned_content ) uncond_context = torch.zeros_like(context) uncond_context_mask = context_mask.detach().clone() time_aligned_content = torch.cat([ uncond_time_aligned_content, time_aligned_content ]) context = torch.cat([uncond_context, context]) context_mask = torch.cat([uncond_context_mask, context_mask]) latent_mask = torch.cat([ latent_mask, latent_mask.detach().clone() ]) # -------------------------------------------------------------------- # prepare input to the backbone # -------------------------------------------------------------------- latent_shape = tuple( int(global_duration.max().item()) if dim is None else dim for dim in self.autoencoder.latent_shape ) shape = (batch_size, *latent_shape) latent = randn_tensor( shape, generator=None, device=device, dtype=content.dtype ) # scale the initial noise by the standard deviation required by the scheduler latent = latent * scheduler.init_noise_sigma num_warmup_steps = len(timesteps) - num_steps * scheduler.order progress_bar = tqdm(range(num_steps), disable=disable_progress) # -------------------------------------------------------------------- # iteratively denoising # -------------------------------------------------------------------- for i, timestep in enumerate(timesteps): # expand the latent if we are doing classifier free guidance if classifier_free_guidance: latent_input = torch.cat([latent, latent]) else: latent_input = latent latent_input = scheduler.scale_model_input(latent_input, timestep) noise_pred = self.backbone( x=latent_input, x_mask=latent_mask, timesteps=timestep, time_aligned_context=time_aligned_content, context=context, context_mask=context_mask ) if classifier_free_guidance: noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) if guidance_rescale != 0.0: noise_pred = self.rescale_cfg( noise_pred_cond, noise_pred, guidance_rescale ) # compute the previous noisy sample x_t -> x_t-1 latent = scheduler.step(noise_pred, timestep, latent).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): progress_bar.update(1) progress_bar.close() # TODO variable length decoding, using `latent_mask` waveform = self.autoencoder.decode(latent) return waveform