|
|
from typing import Any, Literal, Callable |
|
|
import math |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn.utils import weight_norm |
|
|
import torchaudio |
|
|
from alias_free_torch import Activation1d |
|
|
|
|
|
from models.common import LoadPretrainedBase |
|
|
from models.autoencoder.autoencoder_base import AutoEncoderBase |
|
|
from utils.torch_utilities import remove_key_prefix_factory, create_mask_from_length |
|
|
|
|
|
|
|
|
|
|
|
@torch.jit.script
def snake_beta(x, alpha, beta):
    """Apply the Snake-Beta non-linearity ``x + sin(alpha * x)^2 / beta``.

    A small constant (1e-9) is added to ``beta`` so the reciprocal stays
    finite when ``beta`` is exactly zero. The arguments are combined with
    ordinary broadcasting.
    """
    periodic = pow(torch.sin(x * alpha), 2)
    inv_beta = 1.0 / (beta + 1e-9)
    return x + inv_beta * periodic
|
|
|
|
|
|
|
|
class SnakeBeta(nn.Module):
    """Snake activation with separate learnable ``alpha`` and ``beta`` per feature.

    Both parameters have shape ``(in_features,)`` and are initialised from the
    same ``alpha`` argument. The forward pass evaluates ``snake_beta`` with the
    parameters reshaped to ``(1, in_features, 1)``.

    Args:
        in_features: Number of features, i.e. the length of each parameter vector.
        alpha: Initial *effective* value for both ``alpha`` and ``beta``.
        alpha_trainable: Whether the two parameters receive gradients.
        alpha_logscale: If true, the parameters are stored as logarithms and
            exponentiated in ``forward``.

    Raises:
        ValueError: If ``alpha_logscale`` is true and ``alpha`` is not positive,
            because such a value has no logarithm.
    """

    def __init__(
        self,
        in_features,
        alpha=1.0,
        alpha_trainable=True,
        alpha_logscale=True
    ):
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        if self.alpha_logscale:
            # The stored value is log(alpha) so that exp() in forward() recovers
            # the requested value. The previous `zeros * alpha` was always zero,
            # which silently ignored `alpha`; the default alpha=1.0 still maps
            # to an all-zero parameter.
            if alpha <= 0:
                raise ValueError(
                    f"alpha must be positive when alpha_logscale=True, got {alpha}"
                )
            init_value = math.log(alpha)
        else:
            # Linear scale: the parameter holds the value directly.
            init_value = float(alpha)

        self.alpha = nn.Parameter(
            torch.full((in_features,), init_value),
            requires_grad=alpha_trainable,
        )
        self.beta = nn.Parameter(
            torch.full((in_features,), init_value),
            requires_grad=alpha_trainable,
        )

    def forward(self, x):
        """Apply the activation to ``x``.

        The parameters are reshaped to ``(1, in_features, 1)`` before use, so
        ``x`` must broadcast against that shape — presumably
        ``(batch, in_features, time)``; confirm against callers.
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        return snake_beta(x, alpha, beta)
|
|
|
|
|
|
|
|
def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` from the given arguments and wrap it in weight normalisation."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
|
|
|
|
|
|
|
|
def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` from the given arguments and wrap it in weight normalisation."""
    transposed_conv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(transposed_conv)
|
|
|
|
|
|
|
|
def get_activation(
    activation: Literal["elu", "snake", "none"],
    antialias=False,
    channels=None
) -> nn.Module:
    """Create an activation module by name.

    Args:
        activation: ``"elu"`` for ``nn.ELU``, ``"snake"`` for ``SnakeBeta``
            or ``"none"`` for ``nn.Identity``.
        antialias: If true, the module is wrapped in ``Activation1d``.
        channels: Feature count passed to ``SnakeBeta``; unused by the others.

    Returns:
        The activation module, optionally wrapped.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def _build() -> nn.Module:
        # Guard clauses: the first matching name wins.
        if activation == "elu":
            return nn.ELU()
        if activation == "snake":
            return SnakeBeta(channels)
        if activation == "none":
            return nn.Identity()
        raise ValueError(f"Unknown activation {activation}")

    module = _build()
    return Activation1d(module) if antialias else module
|
|
|
|
|
|
|
|
class ResidualUnit(nn.Module):
    """Residual block: activation, dilated 7-tap conv, activation, 1x1 conv, plus skip.

    Both activations are created with ``channels=out_channels``. The input is
    added to the output of the layer stack, so the two must have shapes that
    can be added.

    Args:
        in_channels: Input channels of the dilated convolution.
        out_channels: Output channels of both convolutions.
        dilation: Dilation of the 7-tap convolution.
        use_snake: Use the ``"snake"`` activation instead of ``"elu"``.
        antialias_activation: Forwarded to ``get_activation`` as ``antialias``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dilation,
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        self.dilation = dilation

        kernel_size = 7
        # Half of the dilated kernel span on each side.
        pad = (dilation * (kernel_size - 1)) // 2
        act_name = "snake" if use_snake else "elu"

        def make_act():
            return get_activation(
                act_name,
                antialias=antialias_activation,
                channels=out_channels
            )

        # Modules are created in execution order.
        stack = [
            make_act(),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=pad
            ),
            make_act(),
            WNConv1d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1
            ),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the layer stack's output summed with the unmodified input."""
        return self.layers(x) + x
|
|
|
|
|
|
|
|
class EncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9), an activation, then a strided conv.

    The residual units keep ``in_channels`` and are built without
    ``antialias_activation``; that flag only reaches the activation placed
    before the strided convolution. The strided convolution maps
    ``in_channels`` to ``out_channels`` with kernel size ``2 * stride`` and
    padding ``ceil(stride / 2)``.

    Args:
        in_channels: Channels entering the block.
        out_channels: Channels leaving the block.
        stride: Stride of the final convolution.
        use_snake: Use the ``"snake"`` activation instead of ``"elu"``.
        antialias_activation: Forwarded to the block-level activation only.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        stages = [
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=rate,
                use_snake=use_snake
            )
            for rate in (1, 3, 9)
        ]
        stages.append(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels
            )
        )
        stages.append(
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2)
            )
        )
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the residual units, activation and strided convolution."""
        return self.layers(x)
|
|
|
|
|
|
|
|
class DecoderBlock(nn.Module):
    """Activation, an upsampling layer, then three residual units (dilations 1, 3, 9).

    The upsampling layer is either a weight-normalised transposed convolution
    (kernel ``2 * stride``, padding ``ceil(stride / 2)``) or, when
    ``use_nearest_upsample`` is set, nearest-neighbour ``nn.Upsample`` followed
    by a bias-free stride-1 convolution with ``'same'`` padding. The residual
    units are built without ``antialias_activation``.

    Args:
        in_channels: Channels entering the block.
        out_channels: Channels leaving the block.
        stride: Upsampling factor.
        use_snake: Use the ``"snake"`` activation instead of ``"elu"``.
        antialias_activation: Forwarded to the leading activation only.
        use_nearest_upsample: Select the upsample-then-conv variant.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False
    ):
        super().__init__()

        kernel_size = 2 * stride

        # The upsampler is created first, before the activation and the residual units.
        if not use_nearest_upsample:
            upsample_layer = WNConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=math.ceil(stride / 2)
            )
        else:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    bias=False,
                    padding='same'
                )
            )

        stages = [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels
            ),
            upsample_layer,
        ]
        for rate in (1, 3, 9):
            stages.append(
                ResidualUnit(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    dilation=rate,
                    use_snake=use_snake
                )
            )
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the activation, upsampler and residual units."""
        return self.layers(x)
|
|
|
|
|
|
|
|
class OobleckEncoder(nn.Module):
    """Convolutional encoder: stem conv, a chain of ``EncoderBlock``s, output conv.

    A leading ``1`` is prepended to ``c_mults``. Stage ``i`` maps
    ``c_mults[i] * channels`` to ``c_mults[i + 1] * channels`` with stride
    ``strides[i]``. A final activation and 3-tap convolution project to
    ``latent_dim`` channels. ``antialias_activation`` is applied only to that
    final activation; the encoder blocks are built without it.

    Args:
        in_channels: Channels of the input signal.
        channels: Base channel width that the multipliers scale.
        latent_dim: Channels of the encoder output.
        c_mults: Channel multipliers, one per encoder block.
        strides: Stride of each encoder block.
        use_snake: Use the ``"snake"`` activation instead of ``"elu"``.
        antialias_activation: Wrap the final activation for anti-aliasing.
    """

    def __init__(
        self,
        in_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        multipliers = [1] + c_mults
        self.depth = len(multipliers)

        # Channel width at every stage boundary.
        widths = [mult * channels for mult in multipliers]

        stem = WNConv1d(
            in_channels=in_channels,
            out_channels=widths[0],
            kernel_size=7,
            padding=3
        )
        modules = [stem]

        for idx in range(self.depth - 1):
            modules.append(
                EncoderBlock(
                    in_channels=widths[idx],
                    out_channels=widths[idx + 1],
                    stride=strides[idx],
                    use_snake=use_snake
                )
            )

        modules.append(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=widths[-1]
            )
        )
        modules.append(
            WNConv1d(
                in_channels=widths[-1],
                out_channels=latent_dim,
                kernel_size=3,
                padding=1
            )
        )

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Encode ``x`` by applying the full layer stack."""
        return self.layers(x)
|
|
|
|
|
|
|
|
class OobleckDecoder(nn.Module):
    """Convolutional decoder: input conv, a chain of ``DecoderBlock``s, output conv.

    A leading ``1`` is prepended to ``c_mults``. The blocks run from the widest
    stage down: block ``i`` (from ``depth - 1`` to ``1``) maps
    ``c_mults[i] * channels`` to ``c_mults[i - 1] * channels`` with stride
    ``strides[i - 1]``. A final activation and bias-free 7-tap convolution
    produce ``out_channels``, optionally followed by ``tanh``.

    Args:
        out_channels: Channels of the decoded signal.
        channels: Base channel width that the multipliers scale.
        latent_dim: Channels of the latent input.
        c_mults: Channel multipliers, one per decoder block.
        strides: Upsampling factor of each decoder block.
        use_snake: Use the ``"snake"`` activation instead of ``"elu"``.
        antialias_activation: Forwarded to every block and the final activation.
        use_nearest_upsample: Forwarded to every ``DecoderBlock``.
        final_tanh: Append ``nn.Tanh`` (otherwise ``nn.Identity``) at the end.
    """

    def __init__(
        self,
        out_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False,
        final_tanh=True
    ):
        super().__init__()

        multipliers = [1] + c_mults
        self.depth = len(multipliers)

        # Channel width at every stage boundary.
        widths = [mult * channels for mult in multipliers]

        modules = [
            WNConv1d(
                in_channels=latent_dim,
                out_channels=widths[-1],
                kernel_size=7,
                padding=3
            )
        ]

        # Walk the stages from widest to narrowest.
        for idx in reversed(range(1, self.depth)):
            modules.append(
                DecoderBlock(
                    in_channels=widths[idx],
                    out_channels=widths[idx - 1],
                    stride=strides[idx - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample
                )
            )

        modules.append(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=widths[0]
            )
        )
        modules.append(
            WNConv1d(
                in_channels=widths[0],
                out_channels=out_channels,
                kernel_size=7,
                padding=3,
                bias=False
            )
        )
        modules.append(nn.Tanh() if final_tanh else nn.Identity())

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Decode ``x`` by applying the full layer stack."""
        return self.layers(x)
|
|
|
|
|
|
|
|
class Bottleneck(nn.Module):
    """Base class for latent bottlenecks; subclasses implement ``encode`` and ``decode``.

    Args:
        is_discrete: Stored on the instance as ``is_discrete``.
    """

    def __init__(self, is_discrete: bool = False):
        super().__init__()
        self.is_discrete = is_discrete

    def encode(self, x, return_info=False, **kwargs):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def decode(self, x):
        """Not implemented in the base class."""
        raise NotImplementedError()
|
|
|
|
|
|
|
|
@torch.jit.script
def vae_sample(mean, scale) -> dict[str, torch.Tensor]:
    """Draw a reparameterised sample and compute a KL-style penalty.

    The standard deviation is ``softplus(scale) + 1e-4``. The sample is
    ``mean + stdev * noise`` with standard-normal noise. The penalty is
    ``mean^2 + var - log(var) - 1`` summed over dim 1 and averaged over the
    remaining dims.

    Returns:
        A dict with keys ``"latents"`` and ``"kl"``.
    """
    stdev = nn.functional.softplus(scale) + 1e-4
    var = stdev * stdev
    logvar = torch.log(var)

    noise = torch.randn_like(mean)
    latents = noise * stdev + mean

    penalty = mean * mean + var - logvar - 1
    kl = penalty.sum(1).mean()
    return {"latents": latents, "kl": kl}
|
|
|
|
|
|
|
|
class VAEBottleneck(Bottleneck):
    """Continuous VAE bottleneck that samples latents from encoder statistics."""

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(
        self,
        x,
        return_info=False,
        **kwargs
    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Split ``x`` into mean and scale halves along dim 1 and sample latents.

        Args:
            x: Tensor that ``chunk(2, dim=1)`` splits into ``(mean, scale)``.
            return_info: If true, also return the KL term.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The sampled latents, or ``(latents, {"kl": kl})`` when
            ``return_info`` is true. The previous annotation declared a dict
            for this case, but a tuple is what is actually returned.
        """
        mean, scale = x.chunk(2, dim=1)
        sampled = vae_sample(mean, scale)
        latents = sampled["latents"]

        if return_info:
            return latents, {"kl": sampled["kl"]}
        return latents

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x
|
|
|
|
|
|
|
|
def compute_mean_kernel(x, y):
    """Return the mean of ``exp(-d)`` over all pairs of rows from ``x`` and ``y``.

    ``d`` is the squared difference of a pair, averaged over dim 2 of the
    pairwise difference tensor and then divided by ``x.shape[-1]``.
    """
    pairwise_diff = x.unsqueeze(1) - y.unsqueeze(0)
    scaled_sq_dist = pairwise_diff.pow(2).mean(2) / x.shape[-1]
    return torch.exp(-scaled_sq_dist).mean()
|
|
|
|
|
|
|
|
class Pretransform(nn.Module):
    """Base class for pretransforms; every transform method is left to subclasses.

    Args:
        enable_grad: Stored as ``enable_grad``.
        io_channels: Stored as ``io_channels``.
        is_discrete: Stored as ``is_discrete``.
    """

    def __init__(self, enable_grad, io_channels, is_discrete):
        super().__init__()

        self.is_discrete = is_discrete
        self.io_channels = io_channels
        # Both start as None; nothing in this base class assigns them later.
        self.encoded_channels = None
        self.downsampling_ratio = None

        self.enable_grad = enable_grad

    def encode(self, x):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def decode(self, z):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def tokenize(self, x):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def decode_tokens(self, tokens):
        """Not implemented in the base class."""
        raise NotImplementedError()
|
|
|
|
|
|
|
|
class StableVAE(LoadPretrainedBase, AutoEncoderBase):
    """Waveform autoencoder combining an encoder, an optional bottleneck and a decoder.

    Args:
        encoder: Module mapping a waveform to pre-bottleneck features.
        decoder: Module mapping latents back to a waveform.
        latent_dim: Latent channel count; also passed to ``AutoEncoderBase``
            as ``latent_shape=(latent_dim, None)``.
        downsampling_ratio: Used to derive latent lengths from waveform lengths.
        sample_rate: Forwarded to ``AutoEncoderBase``.
        io_channels: Default for both ``in_channels`` and ``out_channels``.
        bottleneck: Optional bottleneck applied after the encoder.
        pretransform: Stored on the instance; not used by the methods shown here.
        in_channels: Overrides ``io_channels`` for the input side when given.
        out_channels: Overrides ``io_channels`` for the output side when given.
        soft_clip: Stored on the instance; not used by the methods shown here.
        pretrained_ckpt: If given, ``load_pretrained`` is called with it at the
            end of construction.
    """

    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: Bottleneck | None = None,
        pretransform: Pretransform | None = None,
        in_channels=None,
        out_channels=None,
        soft_clip=False,
        pretrained_ckpt: str | Path | None = None
    ):
        LoadPretrainedBase.__init__(self)
        AutoEncoderBase.__init__(
            self,
            downsampling_ratio=downsampling_ratio,
            sample_rate=sample_rate,
            latent_shape=(latent_dim, None)
        )

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels
        self.out_channels = io_channels
        # Presumably `downsampling_ratio` is set by AutoEncoderBase.__init__ — confirm.
        self.min_length = self.downsampling_ratio

        if in_channels is not None:
            self.in_channels = in_channels

        if out_channels is not None:
            self.out_channels = out_channels

        self.bottleneck = bottleneck
        self.encoder = encoder
        self.decoder = decoder
        self.pretransform = pretransform
        self.soft_clip = soft_clip
        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete

        # Callable applied to checkpoint weights in `process_state_dict`.
        self.remove_autoencoder_prefix_fn: Callable = remove_key_prefix_factory(
            "autoencoder."
        )
        if pretrained_ckpt is not None:
            self.load_pretrained(pretrained_ckpt)

    def process_state_dict(self, model_dict, state_dict):
        """Take the ``"state_dict"`` entry of a checkpoint and pass it through the prefix-removal function."""
        state_dict = state_dict["state_dict"]
        state_dict = self.remove_autoencoder_prefix_fn(model_dict, state_dict)
        return state_dict

    def encode(
        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a waveform batch into latents and a latent mask.

        Args:
            waveform: Input passed directly to the encoder.
            waveform_lengths: Per-item lengths; floor-divided by
                ``downsampling_ratio`` to obtain latent lengths.

        Returns:
            ``(z, z_mask)`` where ``z_mask`` is built from the latent lengths
            by ``create_mask_from_length``.
        """
        z = self.encoder(waveform)
        # `bottleneck` is optional (default None); without the guard this
        # raised AttributeError for bottleneck-free models.
        if self.bottleneck is not None:
            z = self.bottleneck.encode(z)
        z_length = waveform_lengths // self.downsampling_ratio
        z_mask = create_mask_from_length(z_length)
        return z, z_mask

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents into a waveform with the decoder."""
        return self.decoder(latents)
|
|
|
|
|
|
|
|
class StableVAEProjectorWrapper(nn.Module):
    """Encode with a ``StableVAE`` under ``no_grad`` and project latents with a linear layer.

    Args:
        vae_dim: Input feature size of the projection.
        embed_dim: Output feature size of the projection.
        model: The autoencoder used for encoding. It may be left as ``None``
            at construction, but must be set before ``forward`` is called.
    """

    def __init__(
        self,
        vae_dim: int,
        embed_dim: int,
        model: StableVAE | None = None,
    ):
        super().__init__()
        self.model = model
        self.proj = nn.Linear(vae_dim, embed_dim)

    def forward(
        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        """Encode the waveform and project the latents.

        Returns:
            A dict with ``"output"`` (projected latents) and ``"mask"`` (the
            latent mask from the autoencoder). The previous annotation
            declared a tuple, but a dict is what is actually returned.

        Raises:
            RuntimeError: If no autoencoder has been assigned to ``model``.
        """
        if self.model is None:
            raise RuntimeError(
                "StableVAEProjectorWrapper.model is None; "
                "assign a StableVAE before calling forward()"
            )

        # The autoencoder is put in eval mode on every call and encoded without gradients.
        self.model.eval()
        with torch.no_grad():
            z, z_mask = self.model.encode(waveform, waveform_lengths)

        # The projection runs outside no_grad so it can receive gradients.
        # Dims 1 and 2 are swapped so the linear layer acts on the former dim 1.
        z = self.proj(z.transpose(1, 2))
        return {"output": z, "mask": z_mask}
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the autoencoder from its config, round-trip one audio
    # file through encode/decode and write the reconstruction to disk.
    import hydra
    from utils.config import generate_config_from_command_line_overrides

    model_config = generate_config_from_command_line_overrides(
        "configs/model/autoencoder/stable_vae.yaml"
    )
    autoencoder: StableVAE = hydra.utils.instantiate(model_config)
    autoencoder.eval()

    waveform, sr = torchaudio.load(
        "/hpc_stor03/sjtu_home/xuenan.xu/data/m4singer/Tenor-1#童话/0006.wav"
    )
    # Average over dim 0 (kept as a singleton), then resample to the model rate.
    waveform = waveform.mean(0, keepdim=True)
    target_sr = model_config["sample_rate"]
    waveform = torchaudio.functional.resample(waveform, sr, target_sr)
    print("waveform: ", waveform.shape)

    with torch.no_grad():
        num_samples = torch.as_tensor([waveform.shape[-1]])
        latent, latent_mask = autoencoder.encode(waveform, num_samples)
        print("latent: ", latent.shape)

        reconstructed = autoencoder.decode(latent)
        print("reconstructed: ", reconstructed.shape)

        import soundfile as sf

        sf.write(
            "./reconstructed.wav",
            reconstructed[0, 0].numpy(),
            samplerate=target_sr
        )
|
|
|