from typing import Literal, Callable
import math
from pathlib import Path

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
import torchaudio
from alias_free_torch import Activation1d

from models.common import LoadPretrainedBase
from models.autoencoder.autoencoder_base import AutoEncoderBase
from utils.torch_utilities import (
    remove_key_prefix_factory, create_mask_from_length
)


# torch.jit.script makes this about 1.4x faster and saves GPU memory
@torch.jit.script
def snake_beta(x, alpha, beta):
    return x + (1.0 / (beta + 1e-9)) * torch.pow(torch.sin(x * alpha), 2)


class SnakeBeta(nn.Module):
    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
    ):
        super().__init__()
        self.in_features = in_features

        # initialize alpha (and beta)
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            # log-scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else:
            # linear-scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

    def forward(self, x):
        # line up with x to [B, C, T]
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)
        return x


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


def get_activation(
    activation: Literal["elu", "snake", "none"], antialias=False, channels=None
) -> nn.Module:
    if activation == "elu":
        act = nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        act = Activation1d(act)

    return act


class ResidualUnit(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        dilation,
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        self.dilation = dilation
        padding = (dilation * (7 - 1)) // 2

        self.layers = nn.Sequential(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=out_channels
            ),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=7,
                dilation=dilation,
                padding=padding
            ),
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=out_channels
            ),
            WNConv1d(
                in_channels=out_channels, out_channels=out_channels, kernel_size=1
            )
        )

    def forward(self, x):
        res = x
        # x = checkpoint(self.layers, x)
        x = self.layers(x)
        return x + res


class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        self.layers = nn.Sequential(
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=1,
                use_snake=use_snake
            ),
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=3,
                use_snake=use_snake
            ),
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=9,
                use_snake=use_snake
            ),
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels
            ),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2)
            ),
        )

    def forward(self, x):
        return self.layers(x)
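

# --- Illustrative sketch (not part of the original model code): a quick shape
# sanity check for EncoderBlock. With kernel_size=2*stride and
# padding=ceil(stride/2), the strided output convolution downsamples the time
# axis by exactly `stride` for even strides and lengths divisible by the
# stride. `_demo_encoder_block` is a hypothetical helper added only for
# illustration.
def _demo_encoder_block():
    block = EncoderBlock(in_channels=64, out_channels=128, stride=4)
    x = torch.randn(1, 64, 1024)  # [B, C, T]
    y = block(x)
    # out length: floor((1024 + 2*2 - 8) / 4) + 1 = 256 = 1024 / 4
    assert y.shape == (1, 128, 1024 // 4)
    return y.shape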


class DecoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False
    ):
        super().__init__()

        if use_nearest_upsample:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=2 * stride,
                    stride=1,
                    bias=False,
                    padding='same'
                )
            )
        else:
            upsample_layer = WNConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2)
            )

        self.layers = nn.Sequential(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels
            ),
            upsample_layer,
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=1,
                use_snake=use_snake
            ),
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=3,
                use_snake=use_snake
            ),
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=9,
                use_snake=use_snake
            ),
        )

    def forward(self, x):
        return self.layers(x)


class OobleckEncoder(nn.Module):
    def __init__(
        self,
        in_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        c_mults = [1] + c_mults
        self.depth = len(c_mults)

        layers = [
            WNConv1d(
                in_channels=in_channels,
                out_channels=c_mults[0] * channels,
                kernel_size=7,
                padding=3
            )
        ]

        for i in range(self.depth - 1):
            layers += [
                EncoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i + 1] * channels,
                    stride=strides[i],
                    use_snake=use_snake
                )
            ]

        layers += [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=c_mults[-1] * channels
            ),
            WNConv1d(
                in_channels=c_mults[-1] * channels,
                out_channels=latent_dim,
                kernel_size=3,
                padding=1
            )
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class OobleckDecoder(nn.Module):
    def __init__(
        self,
        out_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False,
        final_tanh=True
    ):
        super().__init__()

        c_mults = [1] + c_mults
        self.depth = len(c_mults)

        layers = [
            WNConv1d(
                in_channels=latent_dim,
                out_channels=c_mults[-1] * channels,
                kernel_size=7,
                padding=3
            ),
        ]

        for i in range(self.depth - 1, 0, -1):
            layers += [
                DecoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i - 1] * channels,
                    stride=strides[i - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample
                )
            ]

        layers += [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=c_mults[0] * channels
            ),
            WNConv1d(
                in_channels=c_mults[0] * channels,
                out_channels=out_channels,
                kernel_size=7,
                padding=3,
                bias=False
            ),
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class Bottleneck(nn.Module):
    def __init__(self, is_discrete: bool = False):
        super().__init__()
        self.is_discrete = is_discrete

    def encode(self, x, return_info=False, **kwargs):
        raise NotImplementedError

    def decode(self, x):
        raise NotImplementedError


@torch.jit.script
def vae_sample(mean, scale) -> dict[str, torch.Tensor]:
    stdev = nn.functional.softplus(scale) + 1e-4
    var = stdev * stdev
    logvar = torch.log(var)
    latents = torch.randn_like(mean) * stdev + mean
    kl = (mean * mean + var - logvar - 1).sum(1).mean()
    return {"latents": latents, "kl": kl}
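

# --- Illustrative sketch (not part of the original model code): how
# `vae_sample` implements the reparameterization trick. `scale` is mapped to a
# positive stdev via softplus, latents are mean + stdev * eps with
# eps ~ N(0, I), and `kl` is the closed-form KL divergence against a standard
# normal, summed over channels and averaged over batch and time.
# `_demo_vae_sample` is a hypothetical helper for illustration only.
def _demo_vae_sample():
    mean = torch.zeros(2, 32, 100)   # [B, C, T]
    scale = torch.zeros(2, 32, 100)  # softplus(0) ~= 0.693, so stdev ~= 0.693
    out = vae_sample(mean, scale)
    assert out["latents"].shape == mean.shape
    assert out["kl"].ndim == 0  # scalar KL value
    return out["kl"].item()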


class VAEBottleneck(Bottleneck):
    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(
        self, x, return_info=False, **kwargs
    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
        mean, scale = x.chunk(2, dim=1)
        sampled = vae_sample(mean, scale)
        if return_info:
            return sampled["latents"], {"kl": sampled["kl"]}
        else:
            return sampled["latents"]

    def decode(self, x):
        return x


def compute_mean_kernel(x, y):
    kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
    return torch.exp(-kernel_input).mean()


class Pretransform(nn.Module):
    def __init__(self, enable_grad, io_channels, is_discrete):
        super().__init__()

        self.is_discrete = is_discrete
        self.io_channels = io_channels
        self.encoded_channels = None
        self.downsampling_ratio = None
        self.enable_grad = enable_grad

    def encode(self, x):
        raise NotImplementedError

    def decode(self, z):
        raise NotImplementedError

    def tokenize(self, x):
        raise NotImplementedError

    def decode_tokens(self, tokens):
        raise NotImplementedError
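

# --- Illustrative sketch (not part of the original model code): VAEBottleneck
# expects the encoder output to carry 2 * latent_dim channels, which it splits
# into (mean, scale) along dim=1 before sampling, so a matching OobleckEncoder
# would be built with twice the channel count of the latents it should
# produce. `_demo_bottleneck` is a hypothetical helper for illustration only.
def _demo_bottleneck():
    bottleneck = VAEBottleneck()
    z = torch.randn(1, 2 * 32, 50)  # encoder output: [B, 2 * latent_dim, T']
    latents, info = bottleneck.encode(z, return_info=True)
    assert latents.shape == (1, 32, 50)  # channels halved by the mean/scale split
    return latents.shape, info["kl"].item()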


class StableVAE(LoadPretrainedBase, AutoEncoderBase):
    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: Bottleneck | None = None,
        pretransform: Pretransform | None = None,
        in_channels=None,
        out_channels=None,
        soft_clip=False,
        pretrained_ckpt: str | Path | None = None
    ):
        LoadPretrainedBase.__init__(self)
        AutoEncoderBase.__init__(
            self,
            downsampling_ratio=downsampling_ratio,
            sample_rate=sample_rate,
            latent_shape=(latent_dim, None)
        )

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels
        self.out_channels = io_channels
        self.min_length = self.downsampling_ratio

        if in_channels is not None:
            self.in_channels = in_channels

        if out_channels is not None:
            self.out_channels = out_channels

        self.bottleneck = bottleneck
        self.encoder = encoder
        self.decoder = decoder
        self.pretransform = pretransform
        self.soft_clip = soft_clip
        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete

        self.remove_autoencoder_prefix_fn: Callable = remove_key_prefix_factory(
            "autoencoder."
        )
        if pretrained_ckpt is not None:
            self.load_pretrained(pretrained_ckpt)

    def process_state_dict(self, model_dict, state_dict):
        state_dict = state_dict["state_dict"]
        state_dict = self.remove_autoencoder_prefix_fn(model_dict, state_dict)
        return state_dict

    def encode(
        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        z = self.encoder(waveform)
        z = self.bottleneck.encode(z)
        z_length = waveform_lengths // self.downsampling_ratio
        z_mask = create_mask_from_length(z_length)
        return z, z_mask

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        waveform = self.decoder(latents)
        return waveform


class StableVAEProjectorWrapper(nn.Module):
    def __init__(
        self,
        vae_dim: int,
        embed_dim: int,
        model: StableVAE | None = None,
    ):
        super().__init__()
        self.model = model
        self.proj = nn.Linear(vae_dim, embed_dim)

    def forward(
        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        # the frozen VAE runs without gradients; only the projection is trained
        self.model.eval()
        with torch.no_grad():
            z, z_mask = self.model.encode(waveform, waveform_lengths)
        z = self.proj(z.transpose(1, 2))
        return {"output": z, "mask": z_mask}


if __name__ == '__main__':
    import hydra
    from utils.config import generate_config_from_command_line_overrides

    model_config = generate_config_from_command_line_overrides(
        "configs/model/autoencoder/stable_vae.yaml"
    )
    autoencoder: StableVAE = hydra.utils.instantiate(model_config)
    autoencoder.eval()

    waveform, sr = torchaudio.load(
        "/hpc_stor03/sjtu_home/xuenan.xu/data/m4singer/Tenor-1#童话/0006.wav"
    )
    waveform = waveform.mean(0, keepdim=True)
    waveform = torchaudio.functional.resample(
        waveform, sr, model_config["sample_rate"]
    )
    print("waveform: ", waveform.shape)

    with torch.no_grad():
        latent, latent_mask = autoencoder.encode(
            waveform, torch.as_tensor([waveform.shape[-1]])
        )
        print("latent: ", latent.shape)
        reconstructed = autoencoder.decode(latent)
        print("reconstructed: ", reconstructed.shape)

    import soundfile as sf
    sf.write(
        "./reconstructed.wav",
        reconstructed[0, 0].numpy(),
        samplerate=model_config["sample_rate"]
    )