#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

import logging

import numpy as np
import pytest
import torch

from parallel_wavegan.losses import DiscriminatorAdversarialLoss
from parallel_wavegan.losses import GeneratorAdversarialLoss
from parallel_wavegan.losses import MultiResolutionSTFTLoss
from parallel_wavegan.models import ParallelWaveGANDiscriminator
from parallel_wavegan.models import ParallelWaveGANGenerator
from parallel_wavegan.models import ResidualParallelWaveGANDiscriminator
from parallel_wavegan.optimizers import RAdam

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


def make_generator_args(**kwargs):
    defaults = dict(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        layers=6,
        stacks=3,
        residual_channels=8,
        gate_channels=16,
        skip_channels=8,
        aux_channels=10,
        aux_context_window=0,
        dropout=1 - 0.95,
        use_weight_norm=True,
        use_causal_conv=False,
        upsample_conditional_features=True,
        upsample_net="ConvInUpsampleNetwork",
        upsample_params={"upsample_scales": [4, 4]},
    )
    defaults.update(kwargs)
    return defaults
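

# Hypothetical usage note: make_generator_args(layers=8, stacks=4) returns the
# defaults above with "layers" and "stacks" replaced, since
# defaults.update(kwargs) lets any keyword override its default. The same
# pattern applies to the other make_*_args helpers below.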


def make_discriminator_args(**kwargs):
    defaults = dict(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        layers=5,
        conv_channels=16,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
        bias=True,
        use_weight_norm=True,
    )
    defaults.update(kwargs)
    return defaults


def make_residual_discriminator_args(**kwargs):
    defaults = dict(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        layers=10,
        stacks=1,
        residual_channels=8,
        gate_channels=16,
        skip_channels=8,
        dropout=0.0,
        use_weight_norm=True,
        use_causal_conv=False,
        nonlinear_activation_params={"negative_slope": 0.2},
    )
    defaults.update(kwargs)
    return defaults


def make_multi_reso_stft_loss_args(**kwargs):
    defaults = dict(
        fft_sizes=[64, 128, 256],
        hop_sizes=[32, 64, 128],
        win_lengths=[48, 96, 192],
        window="hann_window",
    )
    defaults.update(kwargs)
    return defaults
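

# Each (fft_size, hop_size, win_length) triple above defines one STFT
# resolution; MultiResolutionSTFTLoss then averages its spectral-convergence
# and log-STFT-magnitude terms over the resolutions, so the (sc_loss,
# mag_loss) pair used in the tests below is already resolution-averaged.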


@pytest.mark.parametrize(
    # NOTE: minimal parametrization added so pytest can inject the dicts; the
    # empty dicts simply exercise the defaults. Extend the list with more
    # (dict_g, dict_d, dict_loss) combinations to cover other configurations.
    "dict_g, dict_d, dict_loss",
    [({}, {}, {})],
)
def test_parallel_wavegan_trainable(dict_g, dict_d, dict_loss):
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_generator_args(**dict_g)
    args_d = make_discriminator_args(**dict_d)
    args_loss = make_multi_reso_stft_loss_args(**dict_loss)
    z = torch.randn(batch_size, 1, batch_length)
    y = torch.randn(batch_size, 1, batch_length)
    # conditioning length = output length / prod(upsample_scales), plus the
    # context frames consumed on each side by the aux window
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_params"]["upsample_scales"])
        + 2 * args_g["aux_context_window"],
    )
    model_g = ParallelWaveGANGenerator(**args_g)
    model_d = ParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(z, c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()


@pytest.mark.parametrize(
    # NOTE: minimal parametrization added so pytest can inject the dicts; the
    # empty dicts exercise the defaults of the residual discriminator.
    "dict_g, dict_d, dict_loss",
    [({}, {}, {})],
)
def test_parallel_wavegan_with_residual_discriminator_trainable(
    dict_g, dict_d, dict_loss
):
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_generator_args(**dict_g)
    args_d = make_residual_discriminator_args(**dict_d)
    args_loss = make_multi_reso_stft_loss_args(**dict_loss)
    z = torch.randn(batch_size, 1, batch_length)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_params"]["upsample_scales"])
        + 2 * args_g["aux_context_window"],
    )
    model_g = ParallelWaveGANGenerator(**args_g)
    model_d = ResidualParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(z, c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()


@pytest.mark.parametrize(
    # NOTE: representative (upsample_net, aux_context_window) pairs added so
    # pytest can inject the arguments; both upsampling networks exist in
    # parallel_wavegan.models. Adjust to match the configurations you need.
    "upsample_net, aux_context_window",
    [
        ("ConvInUpsampleNetwork", 0),
        ("ConvInUpsampleNetwork", 2),
        ("UpsampleNetwork", 0),
    ],
)
def test_causal_parallel_wavegan(upsample_net, aux_context_window):
    # setup
    batch_size = 1
    batch_length = 4096
    args_g = make_generator_args(
        use_causal_conv=True,
        upsample_net=upsample_net,
        aux_context_window=aux_context_window,
        dropout=0.0,
    )
    model_g = ParallelWaveGANGenerator(**args_g)
    z = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_params"]["upsample_scales"]),
    )

    # make input pairs whose first halves match but whose second halves differ
    z_ = z.clone()
    c_ = c.clone()
    z_[..., z.size(-1) // 2 :] = torch.randn(z[..., z.size(-1) // 2 :].shape)
    c_[..., c.size(-1) // 2 :] = torch.randn(c[..., c.size(-1) // 2 :].shape)
    c = torch.nn.ConstantPad1d(args_g["aux_context_window"], 0.0)(c)
    c_ = torch.nn.ConstantPad1d(args_g["aux_context_window"], 0.0)(c_)

    # sanity check: the second halves were actually randomized
    assert not np.array_equal(c.numpy(), c_.numpy()), "Must be different."
    assert not np.array_equal(z.numpy(), z_.numpy()), "Must be different."

    # check causality: with causal convolutions, the first half of the output
    # must depend only on the first half of the inputs, which are identical
    y = model_g(z, c)
    y_ = model_g(z_, c_)
    np.testing.assert_array_equal(
        y[..., : y.size(-1) // 2].detach().cpu().numpy(),
        y_[..., : y_.size(-1) // 2].detach().cpu().numpy(),
    )
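

# To run these checks directly (assuming pytest and the parallel_wavegan
# package are installed), invoke pytest on this file, e.g.:
#   pytest -s <path-to-this-file>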