Spaces:
Runtime error
Runtime error
| # Copyright 2021 Tomoki Hayashi | |
| # MIT License (https://opensource.org/licenses/MIT) | |
| """StyleMelGAN Modules.""" | |
| import copy | |
| import logging | |
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from parallel_wavegan.layers import PQMF | |
| from parallel_wavegan.layers import TADEResBlock | |
| from parallel_wavegan.models import MelGANDiscriminator as BaseDiscriminator | |
| from parallel_wavegan.utils import read_hdf5 | |
| class StyleMelGANGenerator(torch.nn.Module): | |
| """Style MelGAN generator module.""" | |
| def __init__( | |
| self, | |
| in_channels=128, | |
| aux_channels=80, | |
| channels=64, | |
| out_channels=1, | |
| kernel_size=9, | |
| dilation=2, | |
| bias=True, | |
| noise_upsample_scales=[11, 2, 2, 2], | |
| noise_upsample_activation="LeakyReLU", | |
| noise_upsample_activation_params={"negative_slope": 0.2}, | |
| upsample_scales=[2, 2, 2, 2, 2, 2, 2, 2, 1], | |
| upsample_mode="nearest", | |
| gated_function="softmax", | |
| use_weight_norm=True, | |
| ): | |
| """Initilize Style MelGAN generator. | |
| Args: | |
| in_channels (int): Number of input noise channels. | |
| aux_channels (int): Number of auxiliary input channels. | |
| channels (int): Number of channels for conv layer. | |
| out_channels (int): Number of output channels. | |
| kernel_size (int): Kernel size of conv layers. | |
| dilation (int): Dilation factor for conv layers. | |
| bias (bool): Whether to add bias parameter in convolution layers. | |
| noise_upsample_scales (list): List of noise upsampling scales. | |
| noise_upsample_activation (str): Activation function module name for noise upsampling. | |
| noise_upsample_activation_params (dict): Hyperparameters for the above activation function. | |
| upsample_scales (list): List of upsampling scales. | |
| upsample_mode (str): Upsampling mode in TADE layer. | |
| gated_function (str): Gated function in TADEResBlock ("softmax" or "sigmoid"). | |
| use_weight_norm (bool): Whether to use weight norm. | |
| If set to true, it will be applied to all of the conv layers. | |
| """ | |
| super().__init__() | |
| self.in_channels = in_channels | |
| noise_upsample = [] | |
| in_chs = in_channels | |
| for noise_upsample_scale in noise_upsample_scales: | |
| # NOTE(kan-bayashi): How should we design noise upsampling part? | |
| noise_upsample += [ | |
| torch.nn.ConvTranspose1d( | |
| in_chs, | |
| channels, | |
| noise_upsample_scale * 2, | |
| stride=noise_upsample_scale, | |
| padding=noise_upsample_scale // 2 + noise_upsample_scale % 2, | |
| output_padding=noise_upsample_scale % 2, | |
| bias=bias, | |
| ) | |
| ] | |
| noise_upsample += [ | |
| getattr(torch.nn, noise_upsample_activation)( | |
| **noise_upsample_activation_params | |
| ) | |
| ] | |
| in_chs = channels | |
| self.noise_upsample = torch.nn.Sequential(*noise_upsample) | |
| self.noise_upsample_factor = np.prod(noise_upsample_scales) | |
| self.blocks = torch.nn.ModuleList() | |
| aux_chs = aux_channels | |
| for upsample_scale in upsample_scales: | |
| self.blocks += [ | |
| TADEResBlock( | |
| in_channels=channels, | |
| aux_channels=aux_chs, | |
| kernel_size=kernel_size, | |
| dilation=dilation, | |
| bias=bias, | |
| upsample_factor=upsample_scale, | |
| upsample_mode=upsample_mode, | |
| gated_function=gated_function, | |
| ), | |
| ] | |
| aux_chs = channels | |
| self.upsample_factor = np.prod(upsample_scales) | |
| self.output_conv = torch.nn.Sequential( | |
| torch.nn.Conv1d( | |
| channels, | |
| out_channels, | |
| kernel_size, | |
| 1, | |
| bias=bias, | |
| padding=(kernel_size - 1) // 2, | |
| ), | |
| torch.nn.Tanh(), | |
| ) | |
| # apply weight norm | |
| if use_weight_norm: | |
| self.apply_weight_norm() | |
| # reset parameters | |
| self.reset_parameters() | |
| def forward(self, c, z=None): | |
| """Calculate forward propagation. | |
| Args: | |
| c (Tensor): Auxiliary input tensor (B, channels, T). | |
| z (Tensor): Input noise tensor (B, in_channels, 1). | |
| Returns: | |
| Tensor: Output tensor (B, out_channels, T ** prod(upsample_scales)). | |
| """ | |
| if z is None: | |
| z = torch.randn(c.size(0), self.in_channels, 1).to( | |
| device=c.device, | |
| dtype=c.dtype, | |
| ) | |
| x = self.noise_upsample(z) | |
| for block in self.blocks: | |
| x, c = block(x, c) | |
| x = self.output_conv(x) | |
| return x | |
| def remove_weight_norm(self): | |
| """Remove weight normalization module from all of the layers.""" | |
| def _remove_weight_norm(m): | |
| try: | |
| logging.debug(f"Weight norm is removed from {m}.") | |
| torch.nn.utils.remove_weight_norm(m) | |
| except ValueError: # this module didn't have weight norm | |
| return | |
| self.apply(_remove_weight_norm) | |
| def apply_weight_norm(self): | |
| """Apply weight normalization module from all of the layers.""" | |
| def _apply_weight_norm(m): | |
| if isinstance(m, torch.nn.Conv1d) or isinstance( | |
| m, torch.nn.ConvTranspose1d | |
| ): | |
| torch.nn.utils.weight_norm(m) | |
| logging.debug(f"Weight norm is applied to {m}.") | |
| self.apply(_apply_weight_norm) | |
| def reset_parameters(self): | |
| """Reset parameters.""" | |
| def _reset_parameters(m): | |
| if isinstance(m, torch.nn.Conv1d) or isinstance( | |
| m, torch.nn.ConvTranspose1d | |
| ): | |
| m.weight.data.normal_(0.0, 0.02) | |
| logging.debug(f"Reset parameters in {m}.") | |
| self.apply(_reset_parameters) | |
| def register_stats(self, stats): | |
| """Register stats for de-normalization as buffer. | |
| Args: | |
| stats (str): Path of statistics file (".npy" or ".h5"). | |
| """ | |
| assert stats.endswith(".h5") or stats.endswith(".npy") | |
| if stats.endswith(".h5"): | |
| mean = read_hdf5(stats, "mean").reshape(-1) | |
| scale = read_hdf5(stats, "scale").reshape(-1) | |
| else: | |
| mean = np.load(stats)[0].reshape(-1) | |
| scale = np.load(stats)[1].reshape(-1) | |
| self.register_buffer("mean", torch.from_numpy(mean).float()) | |
| self.register_buffer("scale", torch.from_numpy(scale).float()) | |
| logging.info("Successfully registered stats as buffer.") | |
| def inference(self, c, normalize_before=False): | |
| """Perform inference. | |
| Args: | |
| c (Union[Tensor, ndarray]): Input tensor (T, in_channels). | |
| normalize_before (bool): Whether to perform normalization. | |
| Returns: | |
| Tensor: Output tensor (T ** prod(upsample_scales), out_channels). | |
| """ | |
| if not isinstance(c, torch.Tensor): | |
| c = torch.tensor(c, dtype=torch.float).to(next(self.parameters()).device) | |
| if normalize_before: | |
| c = (c - self.mean) / self.scale | |
| c = c.transpose(1, 0).unsqueeze(0) | |
| # prepare noise input | |
| noise_size = ( | |
| 1, | |
| self.in_channels, | |
| math.ceil(c.size(2) / self.noise_upsample_factor), | |
| ) | |
| noise = torch.randn(*noise_size, dtype=torch.float).to( | |
| next(self.parameters()).device | |
| ) | |
| x = self.noise_upsample(noise) | |
| # NOTE(kan-bayashi): To remove pop noise at the end of audio, perform padding | |
| # for feature sequence and after generation cut the generated audio. This | |
| # requires additional computation but it can prevent pop noise. | |
| total_length = c.size(2) * self.upsample_factor | |
| c = F.pad(c, (0, x.size(2) - c.size(2)), "replicate") | |
| # This version causes pop noise. | |
| # x = x[:, :, :c.size(2)] | |
| for block in self.blocks: | |
| x, c = block(x, c) | |
| x = self.output_conv(x)[..., :total_length] | |
| return x.squeeze(0).transpose(1, 0) | |
| class StyleMelGANDiscriminator(torch.nn.Module): | |
| """Style MelGAN disciminator module.""" | |
| def __init__( | |
| self, | |
| repeats=2, | |
| window_sizes=[512, 1024, 2048, 4096], | |
| pqmf_params=[ | |
| [1, None, None, None], | |
| [2, 62, 0.26700, 9.0], | |
| [4, 62, 0.14200, 9.0], | |
| [8, 62, 0.07949, 9.0], | |
| ], | |
| discriminator_params={ | |
| "out_channels": 1, | |
| "kernel_sizes": [5, 3], | |
| "channels": 16, | |
| "max_downsample_channels": 512, | |
| "bias": True, | |
| "downsample_scales": [4, 4, 4, 1], | |
| "nonlinear_activation": "LeakyReLU", | |
| "nonlinear_activation_params": {"negative_slope": 0.2}, | |
| "pad": "ReflectionPad1d", | |
| "pad_params": {}, | |
| }, | |
| use_weight_norm=True, | |
| ): | |
| """Initilize Style MelGAN discriminator. | |
| Args: | |
| repeats (int): Number of repititons to apply RWD. | |
| window_sizes (list): List of random window sizes. | |
| pqmf_params (list): List of list of Parameters for PQMF modules | |
| discriminator_params (dict): Parameters for base discriminator module. | |
| use_weight_nom (bool): Whether to apply weight normalization. | |
| """ | |
| super().__init__() | |
| # window size check | |
| assert len(window_sizes) == len(pqmf_params) | |
| sizes = [ws // p[0] for ws, p in zip(window_sizes, pqmf_params)] | |
| assert len(window_sizes) == sum([sizes[0] == size for size in sizes]) | |
| self.repeats = repeats | |
| self.window_sizes = window_sizes | |
| self.pqmfs = torch.nn.ModuleList() | |
| self.discriminators = torch.nn.ModuleList() | |
| for pqmf_param in pqmf_params: | |
| d_params = copy.deepcopy(discriminator_params) | |
| d_params["in_channels"] = pqmf_param[0] | |
| if pqmf_param[0] == 1: | |
| self.pqmfs += [torch.nn.Identity()] | |
| else: | |
| self.pqmfs += [PQMF(*pqmf_param)] | |
| self.discriminators += [BaseDiscriminator(**d_params)] | |
| # apply weight norm | |
| if use_weight_norm: | |
| self.apply_weight_norm() | |
| # reset parameters | |
| self.reset_parameters() | |
| def forward(self, x): | |
| """Calculate forward propagation. | |
| Args: | |
| x (Tensor): Input tensor (B, 1, T). | |
| Returns: | |
| List: List of discriminator outputs, #items in the list will be | |
| equal to repeats * #discriminators. | |
| """ | |
| outs = [] | |
| for _ in range(self.repeats): | |
| outs += self._forward(x) | |
| return outs | |
| def _forward(self, x): | |
| outs = [] | |
| for idx, (ws, pqmf, disc) in enumerate( | |
| zip(self.window_sizes, self.pqmfs, self.discriminators) | |
| ): | |
| # NOTE(kan-bayashi): Is it ok to apply different window for real and fake samples? | |
| start_idx = np.random.randint(x.size(-1) - ws) | |
| x_ = x[:, :, start_idx : start_idx + ws] | |
| if idx == 0: | |
| x_ = pqmf(x_) | |
| else: | |
| x_ = pqmf.analysis(x_) | |
| outs += [disc(x_)] | |
| return outs | |
| def apply_weight_norm(self): | |
| """Apply weight normalization module from all of the layers.""" | |
| def _apply_weight_norm(m): | |
| if isinstance(m, torch.nn.Conv1d) or isinstance( | |
| m, torch.nn.ConvTranspose1d | |
| ): | |
| torch.nn.utils.weight_norm(m) | |
| logging.debug(f"Weight norm is applied to {m}.") | |
| self.apply(_apply_weight_norm) | |
| def reset_parameters(self): | |
| """Reset parameters.""" | |
| def _reset_parameters(m): | |
| if isinstance(m, torch.nn.Conv1d) or isinstance( | |
| m, torch.nn.ConvTranspose1d | |
| ): | |
| m.weight.data.normal_(0.0, 0.02) | |
| logging.debug(f"Reset parameters in {m}.") | |
| self.apply(_reset_parameters) | |