import os
import sys

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import wandb  # used by AutoencoderKL._log_img; was missing from the imports
from einops import rearrange
from librosa.filters import mel as librosa_mel_fn
from librosa.util import pad_center, tiny
import librosa.util as librosa_util
from scipy.signal import get_window

import tools.torch_tools as torch_tools


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)
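

# Example (informal): get_padding(3, dilation=2) == (3 * 2 - 2) // 2 == 2,
# which keeps a stride-1 nn.Conv1d output the same length as its input
# ("same" padding for dilated convolutions).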


LRELU_SLOPE = 0.1


class ResBlock(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList(
            [
                torch.nn.utils.weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList(
            [
                torch.nn.utils.weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.convs2:
            torch.nn.utils.remove_weight_norm(l)
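

# Minimal shape-check sketch (not part of the original module). Assumption:
# ResBlock only stores `h`, so passing None works for a standalone check.
def _example_resblock_shapes():
    block = ResBlock(h=None, channels=32, kernel_size=3, dilation=(1, 3, 5))
    x = torch.randn(2, 32, 100)
    y = block(x)
    # Dilated convs use "same" padding via get_padding, so shape is preserved.
    assert y.shape == x.shape
    return y.shape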


class Generator_old(torch.nn.Module):
    def __init__(self, h):
        super(Generator_old, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = torch.nn.utils.weight_norm(
            nn.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                torch.nn.utils.weight_norm(
                    nn.ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d))
        self.conv_post = torch.nn.utils.weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = torch.nn.functional.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        # print("Removing weight norm...")
        for l in self.ups:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        torch.nn.utils.remove_weight_norm(self.conv_pre)
        torch.nn.utils.remove_weight_norm(self.conv_post)


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # Do time downsampling here
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class DownsampleTimeStride4(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # Do time downsampling here
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class UpsampleTimeStride4(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=5, stride=1, padding=2
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # b,hw,c
        k = k.reshape(b, c, h * w).contiguous()  # b,c,hw
        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw  w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)
        # attend to values
        v = v.reshape(b, c, h * w).contiguous()
        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(
            v, w_
        ).contiguous()  # b,c,hw (hw of q)  h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w).contiguous()
        h_ = self.proj_out(h_)
        return x + h_


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
    # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        raise ValueError(attn_type)


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h


class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        downsample_time_stride4_levels=[],
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.downsample_time_stride4_levels = downsample_time_stride4_levels
        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The levels at which to apply stride-4 time downsampling must be "
                "smaller than the total number of resolutions (%s)"
                % str(self.num_resolutions)
            )
        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )
        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                if i_level in self.downsample_time_stride4_levels:
                    down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # timestep embedding
        temb = None
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        downsample_time_stride4_levels=[],
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.downsample_time_stride4_levels = downsample_time_stride4_levels
        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The levels at which to apply stride-4 time downsampling must be "
                "smaller than the total number of resolutions (%s)"
                % str(self.num_resolutions)
            )
        # compute block_in and curr_res at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        # print(
        #     "Working with z of shape {} = {} dimensions.".format(
        #         self.z_shape, np.prod(self.z_shape)
        #     )
        # )
        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                if i_level - 1 in self.downsample_time_stride4_levels:
                    up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape
        # timestep embedding
        temb = None
        # z to block_in
        h = self.conv_in(z)
        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
        # end
        if self.give_pre_end:
            return h
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean
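

# Usage sketch (illustrative, not part of the original module): `parameters`
# packs mean and log-variance along dim=1, so a (B, 2*C, H, W) tensor yields
# (B, C, H, W) samples; kl() reduces over all but the batch dimension.
def _example_diagonal_gaussian():
    params = torch.randn(2, 8, 4, 4)  # 4 mean channels + 4 logvar channels
    dist = DiagonalGaussianDistribution(params)
    z = dist.sample()  # reparameterized draw, shape (2, 4, 4, 4)
    kl = dist.kl()  # KL against a standard normal, shape (2,)
    return z.shape, kl.shape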


def get_vocoder_config_48k():
    return {
        "resblock": "1",
        "num_gpus": 8,
        "batch_size": 128,
        "learning_rate": 0.0001,
        "adam_b1": 0.8,
        "adam_b2": 0.99,
        "lr_decay": 0.999,
        "seed": 1234,
        "upsample_rates": [6, 5, 4, 2, 2],
        "upsample_kernel_sizes": [12, 10, 8, 4, 4],
        "upsample_initial_channel": 1536,
        "resblock_kernel_sizes": [3, 7, 11, 15],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "segment_size": 15360,
        "num_mels": 256,
        "n_fft": 2048,
        "hop_size": 480,
        "win_size": 2048,
        "sampling_rate": 48000,
        "fmin": 20,
        "fmax": 24000,
        "fmax_for_loss": None,
        "num_workers": 8,
        "dist_config": {
            "dist_backend": "nccl",
            "dist_url": "tcp://localhost:18273",
            "world_size": 1,
        },
    }
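

# Sanity-check sketch (illustrative): the generator's total upsampling factor
# is the product of upsample_rates, and it must equal hop_size so that one
# mel frame maps to exactly one hop of audio (6 * 5 * 4 * 2 * 2 = 480 at 48 kHz).
def _check_vocoder_upsampling():
    h = AttrDict(get_vocoder_config_48k())
    factor = int(np.prod(h.upsample_rates))
    assert factor == h.hop_size == 480
    return factor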


def get_vocoder(config, device, mel_bins):
    name = "HiFi-GAN"
    speaker = ""
    if name == "MelGAN":
        if speaker == "LJSpeech":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
            )
        elif speaker == "universal":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
            )
        vocoder.mel2wav.eval()
        vocoder.mel2wav.to(device)
    elif name == "HiFi-GAN":
        if mel_bins == 256:
            config = get_vocoder_config_48k()
            config = AttrDict(config)
            vocoder = Generator_old(config)
            # print("Load hifigan/g_01080000")
            # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
            # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
            # ckpt = torch_version_orig_mod_remove(ckpt)
            # vocoder.load_state_dict(ckpt["generator"])
            vocoder.eval()
            vocoder.remove_weight_norm()
            vocoder = vocoder.to(device)
            # vocoder = vocoder.half()
        else:
            raise ValueError(mel_bins)
    return vocoder


def vocoder_infer(mels, vocoder, lengths=None):
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)
    # wavs = (wavs.cpu().numpy() * 32768).astype("int16")
    wavs = wavs.cpu().numpy()
    if lengths is not None:
        wavs = wavs[:, :lengths]
    # wavs = [wav for wav in wavs]
    # for i in range(len(mels)):
    #     if lengths is not None:
    #         wavs[i] = wavs[i][: lengths[i]]
    return wavs


def vocoder_chunk_infer(mels, vocoder, lengths=None):
    # Decode long mel spectrograms chunk by chunk and cross-fade the overlaps.
    chunk_size = 256 * 4
    shift_size = 256 * 1
    ov_size = chunk_size - shift_size
    for cinx in range(0, mels.shape[2], shift_size):
        if cinx == 0:
            wavs = vocoder(mels[:, :, cinx : cinx + chunk_size]).squeeze(1).float()
            num_samples = int(wavs.shape[-1] / chunk_size) * chunk_size
            wavs = wavs[:, 0:num_samples]
            ov_sample = int(float(wavs.shape[-1]) * ov_size / chunk_size)
            # Use the input's device instead of hard-coding "cuda".
            ov_win = torch.linspace(0, 1, ov_sample, device=mels.device).unsqueeze(0)
            ov_win = torch.cat([ov_win, 1 - ov_win], -1)
            if cinx + chunk_size >= mels.shape[2]:
                break
        else:
            cur_wav = (
                vocoder(mels[:, :, cinx : cinx + chunk_size])
                .squeeze(1)[:, 0:num_samples]
                .float()
            )
            wavs[:, -ov_sample:] = (
                wavs[:, -ov_sample:] * ov_win[:, -ov_sample:]
                + cur_wav[:, 0:ov_sample] * ov_win[:, 0:ov_sample]
            )
            # wavs[:, -ov_sample:] = wavs[:, -ov_sample:] * 1.0 + cur_wav[:, 0:ov_sample] * 0.0
            wavs = torch.cat([wavs, cur_wav[:, ov_sample:]], -1)
            if cinx + chunk_size >= mels.shape[2]:
                break
    wavs = wavs.cpu().numpy()
    if lengths is not None:
        wavs = wavs[:, :lengths]
    return wavs
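

# How the chunked decoder above stitches audio: consecutive chunks overlap by
# `ov_size` mel frames; the tail of the running waveform fades out under a
# descending linear ramp while the head of the new chunk fades in under the
# complementary ramp. A minimal sketch of that cross-fade on dummy signals
# (illustrative helper, not part of the original module):
def _example_crossfade(prev, cur, ov_sample):
    ramp = torch.linspace(0, 1, ov_sample, device=prev.device).unsqueeze(0)
    mixed = prev[:, -ov_sample:] * (1 - ramp) + cur[:, :ov_sample] * ramp
    return torch.cat([prev[:, :-ov_sample], mixed, cur[:, ov_sample:]], dim=-1)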


def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    if vocoder is not None:
        wav_reconstruction = vocoder_infer(
            mel_input.permute(0, 2, 1),
            vocoder,
        )
        wav_prediction = vocoder_infer(
            mel_prediction.permute(0, 2, 1),
            vocoder,
        )
    else:
        wav_reconstruction = wav_prediction = None
    return wav_reconstruction, wav_prediction


class AutoencoderKL(nn.Module):
    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        batchsize=None,
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        sampling_rate=16000,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=[],
        image_key="fbank",
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
        scale_factor=1,
    ):
        super().__init__()
        self.automatic_optimization = False
        assert (
            "mel_bins" in ddconfig.keys()
        ), "mel_bins is not specified in the Autoencoder config"
        num_mel = ddconfig["mel_bins"]
        self.image_key = image_key
        self.sampling_rate = sampling_rate
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = None
        self.subband = int(subband)
        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if self.image_key == "fbank":
            self.vocoder = get_vocoder(None, torch.device("cuda"), num_mel)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.learning_rate = float(base_learning_rate)
        # print("Initial learning rate %s" % self.learning_rate)
        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None
        self.feature_cache = None
        self.flag_first_run = True
        self.train_step = 0
        self.logger_save_dir = None
        self.logger_exp_name = None
        self.scale_factor = scale_factor
        print("Num parameters:")
        print("Encoder : ", sum(p.numel() for p in self.encoder.parameters()))
        print("Decoder : ", sum(p.numel() for p in self.decoder.parameters()))
        print("Vocoder : ", sum(p.numel() for p in self.vocoder.parameters()))

    def get_log_dir(self):
        if self.logger_save_dir is None and self.logger_exp_name is None:
            return os.path.join(self.logger.save_dir, self.logger._project)
        else:
            return os.path.join(self.logger_save_dir, self.logger_exp_name)

    def set_log_dir(self, save_dir, exp_name):
        self.logger_save_dir = save_dir
        self.logger_exp_name = exp_name

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        # x = self.time_shuffle_operation(x)
        # x = self.freq_split_subband(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        # bs, ch, shuffled_timesteps, fbins = dec.size()
        # dec = self.time_unshuffle_operation(dec, bs, int(ch * shuffled_timesteps), fbins)
        # dec = self.freq_merge_subband(dec)
        return dec

    def decode_to_waveform(self, dec):
        if self.image_key == "fbank":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = vocoder_chunk_infer(dec, self.vocoder)
        elif self.image_key == "stft":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = self.wave_decoder(dec)
        return wav_reconstruction

    def mel_spectrogram_to_waveform(
        self, mel, savepath=".", bs=None, name="outwav", save=True
    ):
        # Mel: [bs, 1, t-steps, fbins]
        if len(mel.size()) == 4:
            mel = mel.squeeze(1)
        mel = mel.permute(0, 2, 1)
        waveform = self.vocoder(mel)
        waveform = waveform.cpu().detach().numpy()
        # if save:
        #     self.save_waveform(waveform, savepath, name)
        return waveform

    def encode_first_stage(self, x):
        return self.encode(x)

    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, "b h w c -> b c h w").contiguous()
        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def decode_first_stage_withgrad(self, z):
        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def get_first_stage_encoding(self, encoder_posterior, use_mode=False):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution) and not use_mode:
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, DiagonalGaussianDistribution) and use_mode:
            z = encoder_posterior.mode()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z
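
    # Note: get_first_stage_encoding multiplies latents by `scale_factor` and
    # decode_first_stage divides by it again, so the two methods are inverse
    # mappings around the diffusion latent space.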

    def visualize_latent(self, input):
        import matplotlib.pyplot as plt

        # for i in range(10):
        #     zero_input = torch.zeros_like(input) - 11.59
        #     zero_input[:, :, i * 16 : i * 16 + 16, :16] += 13.59
        #     posterior = self.encode(zero_input)
        #     latent = posterior.sample()
        #     avg_latent = torch.mean(latent, dim=1)[0]
        #     plt.imshow(avg_latent.cpu().detach().numpy().T)
        #     plt.savefig("%s.png" % i)
        #     plt.close()
        np.save("input.npy", input.cpu().detach().numpy())
        # zero_input = torch.zeros_like(input) - 11.59
        time_input = input.clone()
        time_input[:, :, :, :32] *= 0
        time_input[:, :, :, :32] -= 11.59
        np.save("time_input.npy", time_input.cpu().detach().numpy())
        posterior = self.encode(time_input)
        latent = posterior.sample()
        np.save("time_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("freq_%s.png" % i)
            plt.close()
        freq_input = input.clone()
        freq_input[:, :, :512, :] *= 0
        freq_input[:, :, :512, :] -= 11.59
        np.save("freq_input.npy", freq_input.cpu().detach().numpy())
        posterior = self.encode(freq_input)
        latent = posterior.sample()
        np.save("freq_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("time_%s.png" % i)
            plt.close()

    def get_input(self, batch):
        fname, text, label_indices, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["label_vector"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )
        # if self.time_shuffle != 1:
        #     if fbank.size(1) % self.time_shuffle != 0:
        #         pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle)
        #         fbank = torch.nn.functional.pad(fbank, (0, 0, 0, pad_len))
        ret = {}
        ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
            fbank.unsqueeze(1),
            stft.unsqueeze(1),
            fname,
            waveform.unsqueeze(1),
        )
        return ret

    def save_wave(self, batch_wav, fname, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        for wav, name in zip(batch_wav, fname):
            name = os.path.basename(name)
            sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
        log = dict()
        x = batch.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            log["samples"] = self.decode(posterior.sample())
            log["reconstructions"] = xrec
        log["inputs"] = x
        wavs = self._log_img(log, train=train, index=0, waveform=waveform)
        return wavs

    def _log_img(self, log, train=True, index=0, waveform=None):
        images_input = self.tensor2numpy(log["inputs"][index, 0]).T
        images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
        images_samples = self.tensor2numpy(log["samples"][index, 0]).T
        if train:
            name = "train"
        else:
            name = "val"
        if self.logger is not None:
            self.logger.log_image(
                "img_%s" % name,
                [images_input, images_reconstruct, images_samples],
                caption=["input", "reconstruct", "samples"],
            )
        inputs, reconstructions, samples = (
            log["inputs"],
            log["reconstructions"],
            log["samples"],
        )
        if self.image_key == "fbank":
            wav_original, wav_prediction = synth_one_sample(
                inputs[index],
                reconstructions[index],
                labels="validation",
                vocoder=self.vocoder,
            )
            wav_original, wav_samples = synth_one_sample(
                inputs[index], samples[index], labels="validation", vocoder=self.vocoder
            )
            wav_original, wav_samples, wav_prediction = (
                wav_original[0],
                wav_samples[0],
                wav_prediction[0],
            )
        elif self.image_key == "stft":
            wav_prediction = (
                self.decode_to_waveform(reconstructions)[index, 0]
                .cpu()
                .detach()
                .numpy()
            )
            wav_samples = (
                self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
            )
            wav_original = waveform[index, 0].cpu().detach().numpy()
        if self.logger is not None:
            self.logger.experiment.log(
                {
                    "original_%s" % name: wandb.Audio(
                        wav_original, caption="original", sample_rate=self.sampling_rate
                    ),
                    "reconstruct_%s" % name: wandb.Audio(
                        wav_prediction,
                        caption="reconstruct",
                        sample_rate=self.sampling_rate,
                    ),
                    "samples_%s" % name: wandb.Audio(
                        wav_samples, caption="samples", sample_rate=self.sampling_rate
                    ),
                }
            )
        return wav_original, wav_prediction, wav_samples

    def tensor2numpy(self, tensor):
        return tensor.cpu().detach().numpy()

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = torch.nn.functional.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x


def window_sumsquare(
    window,
    n_frames,
    hop_length,
    win_length,
    n_fft,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time Fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.
    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft
    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)
    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)
    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x
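

# Quick length check (illustrative): the envelope spans
# n_fft + hop_length * (n_frames - 1) samples, as documented above.
def _example_window_sumsquare_length():
    env = window_sumsquare(
        "hann", n_frames=10, hop_length=256, win_length=1024, n_fft=1024
    )
    assert env.shape == (1024 + 256 * 9,)
    return env.shape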


def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return normalize_fun(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C
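

# Round-trip sketch (illustrative): compression is log(clamp(x, clip_val) * C)
# and decompression is exp(x) / C, so values above clip_val recover exactly.
def _example_dynamic_range_roundtrip():
    x = torch.tensor([1e-3, 0.1, 1.0])
    y = dynamic_range_decompression(dynamic_range_compression(x))
    assert torch.allclose(x, y)
    return y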


class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

    def __init__(self, filter_length, hop_length, win_length, window="hann"):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )
        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )
        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()
            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window
        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())

    def transform(self, input_data):
        device = self.forward_basis.device
        input_data = input_data.to(device)
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)
        self.num_samples = num_samples
        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = torch.nn.functional.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode="reflect",
        )
        input_data = input_data.squeeze(1)
        forward_transform = torch.nn.functional.conv1d(
            input_data,
            torch.autograd.Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )
        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
        return magnitude, phase

    def inverse(self, magnitude, phase):
        device = self.forward_basis.device
        magnitude, phase = magnitude.to(device), phase.to(device)
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )
        inverse_transform = torch.nn.functional.conv_transpose1d(
            recombine_magnitude_phase,
            torch.autograd.Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )
        if self.window is not None:
            window_sum = window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            )
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False
            )
            # Move the envelope onto the same device as the transform so the
            # in-place division below also works on GPU.
            window_sum = window_sum.to(inverse_transform.device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]
            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length
        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2)]
        return inverse_transform

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
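

# Round-trip sketch for the conv-based STFT (illustrative): transform followed
# by inverse approximately reconstructs the waveform, up to edge effects from
# the reflect padding; the parameter choices here are arbitrary.
def _example_stft_roundtrip():
    stft = STFT(filter_length=1024, hop_length=256, win_length=1024)
    audio = torch.randn(1, 48000).clamp(-1, 1)
    recon = stft(audio)  # (1, 1, T'), slightly shorter than the input
    return audio.shape, recon.shape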


class TacotronSTFT(torch.nn.Module):
    def __init__(
        self,
        filter_length,
        hop_length,
        win_length,
        n_mel_channels,
        sampling_rate,
        mel_fmin,
        mel_fmax,
    ):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)

    def spectral_normalize(self, magnitudes, normalize_fun):
        output = dynamic_range_compression(magnitudes, normalize_fun)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y, normalize_fun=torch.log):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert torch.min(y.data) >= -1, torch.min(y.data)
        assert torch.max(y.data) <= 1, torch.max(y.data)
        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output, normalize_fun)
        energy = torch.norm(magnitudes, dim=1)
        log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
        return mel_output, log_magnitudes, energy
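

# Usage sketch matching the 48 kHz settings used elsewhere in this file
# (illustrative helper): a (B, T) waveform in [-1, 1] yields a
# (B, n_mel_channels, frames) log-mel spectrogram.
def _example_mel_extraction():
    fn = TacotronSTFT(2048, 480, 2048, 256, 48000, 20, 24000)
    y = torch.randn(1, 48000).clamp(-1, 1)
    mel, log_magnitudes, energy = fn.mel_spectrogram(y)
    return mel.shape  # (1, 256, frames)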


def build_pretrained_models(ckpt):
    checkpoint = torch.load(ckpt, map_location="cpu")
    scale_factor = checkpoint["state_dict"]["scale_factor"].item()
    print("scale_factor: ", scale_factor)
    # Strip the "first_stage_model." prefix (18 characters) from the VAE keys.
    vae_state_dict = {
        k[18:]: v
        for k, v in checkpoint["state_dict"].items()
        if "first_stage_model." in k
    }
    config = {
        "preprocessing": {
            "audio": {
                "sampling_rate": 48000,
                "max_wav_value": 32768,
                "duration": 10.24,
            },
            "stft": {
                "filter_length": 2048,
                "hop_length": 480,
                "win_length": 2048,
            },
            "mel": {
                "n_mel_channels": 256,
                "mel_fmin": 20,
                "mel_fmax": 24000,
            },
        },
        "model": {
            "params": {
                "first_stage_config": {
                    "params": {
                        "sampling_rate": 48000,
                        "batchsize": 4,
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 16,
                        "time_shuffle": 1,
                        "lossconfig": {
                            "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
                            "params": {
                                "disc_start": 50001,
                                "kl_weight": 1000,
                                "disc_weight": 0.5,
                                "disc_in_channels": 1,
                            },
                        },
                        "ddconfig": {
                            "double_z": True,
                            "mel_bins": 256,
                            "z_channels": 16,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4, 8],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0,
                        },
                    }
                },
            }
        },
    }
    vae_config = config["model"]["params"]["first_stage_config"]["params"]
    vae_config["scale_factor"] = scale_factor
    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(vae_state_dict)
    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )
    vae.eval()
    fn_STFT.eval()
    return vae, fn_STFT


if __name__ == "__main__":
    # Smoke test. build_pretrained_models requires a checkpoint path; the
    # default below is a placeholder assumption, so pass the real path to an
    # AudioLDM-style checkpoint as the first CLI argument.
    ckpt_path = sys.argv[1] if len(sys.argv) > 1 else "checkpoint.ckpt"
    vae, stft = build_pretrained_models(ckpt_path)
    vae, stft = vae.cuda(), stft.cuda()
    json_file = "outputs/wav.scp"
    out_path = "outputs/Music_inverse"
    wavform = torch.randn(2, int(48000 * 10.24))
    mel, _, waveform = torch_tools.wav_to_fbank2(
        wavform, target_length=-1, fn_STFT=stft
    )
    mel = mel.unsqueeze(1).cuda()
    print(mel.shape)
    # true_latent = torch.cat(
    #     [
    #         vae.get_first_stage_encoding(vae.encode_first_stage(mel[[m]]))
    #         for m in range(mel.shape[0])
    #     ],
    #     0,
    # )
    # print(true_latent.shape)
    true_latent = vae.get_first_stage_encoding(vae.encode_first_stage(mel))
    print(true_latent.shape)
    # Round-trip reshape: fold pairs of batch items into channels, then unfold.
    true_latent = true_latent.reshape(
        true_latent.shape[0] // 2, -1, true_latent.shape[2], true_latent.shape[3]
    ).detach()
    true_latent = true_latent.reshape(
        true_latent.shape[0] * 2, -1, true_latent.shape[2], true_latent.shape[3]
    )
    print("111", true_latent.size())
    mel = vae.decode_first_stage(true_latent)
    print("222", mel.size())
    audio = vae.decode_to_waveform(mel)
    print("333", audio.shape)
    # out_file = out_path + "/" + os.path.basename(fname.strip())
    # sf.write(out_file, audio[0], samplerate=48000)