| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| # Copyright 2019 Tomoki Hayashi | |
| # MIT License (https://opensource.org/licenses/MIT) | |
| import logging | |
| import numpy as np | |
| import pytest | |
| import torch | |
| from parallel_wavegan.losses import DiscriminatorAdversarialLoss | |
| from parallel_wavegan.losses import FeatureMatchLoss | |
| from parallel_wavegan.losses import GeneratorAdversarialLoss | |
| from parallel_wavegan.losses import MultiResolutionSTFTLoss | |
| from parallel_wavegan.models import MelGANGenerator | |
| from parallel_wavegan.models import MelGANMultiScaleDiscriminator | |
| from parallel_wavegan.models import ParallelWaveGANDiscriminator | |
| from parallel_wavegan.models import ResidualParallelWaveGANDiscriminator | |
| from parallel_wavegan.optimizers import RAdam | |
| from test_parallel_wavegan import make_discriminator_args | |
| from test_parallel_wavegan import make_mutli_reso_stft_loss_args | |
| from test_parallel_wavegan import make_residual_discriminator_args | |
# Emit verbose, source-located logs while the tests run to ease debugging
# of model-construction and training failures.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
def make_melgan_generator_args(**kwargs):
    """Return a default MelGANGenerator config, updated with *kwargs* overrides."""
    base = {
        "in_channels": 80,
        "out_channels": 1,
        "kernel_size": 7,
        "channels": 512,
        "bias": True,
        "upsample_scales": [8, 8, 2, 2],
        "stack_kernel_size": 3,
        "stacks": 3,
        "nonlinear_activation": "LeakyReLU",
        "nonlinear_activation_params": {"negative_slope": 0.2},
        "pad": "ReflectionPad1d",
        "pad_params": {},
        "use_final_nonlinear_activation": True,
        "use_weight_norm": True,
        "use_causal_conv": False,
    }
    # Caller-supplied keys win over the defaults.
    return {**base, **kwargs}
def make_melgan_discriminator_args(**kwargs):
    """Return a default MelGANMultiScaleDiscriminator config, updated with *kwargs* overrides."""
    base = {
        "in_channels": 1,
        "out_channels": 1,
        "scales": 3,
        "downsample_pooling": "AvgPool1d",
        # follow the official implementation setting
        "downsample_pooling_params": {
            "kernel_size": 4,
            "stride": 2,
            "padding": 1,
            "count_include_pad": False,
        },
        "kernel_sizes": [5, 3],
        "channels": 16,
        "max_downsample_channels": 1024,
        "bias": True,
        "downsample_scales": [4, 4, 4, 4],
        "nonlinear_activation": "LeakyReLU",
        "nonlinear_activation_params": {"negative_slope": 0.2},
        "pad": "ReflectionPad1d",
        "pad_params": {},
        "use_weight_norm": True,
    }
    # Caller-supplied keys win over the defaults.
    return {**base, **kwargs}
# Without parametrization pytest treats dict_g/dict_d/dict_loss as fixtures
# and fails collection with "fixture 'dict_g' not found".
@pytest.mark.parametrize(
    "dict_g, dict_d, dict_loss",
    [
        ({}, {}, {}),
    ],
)
def test_melgan_trainable(dict_g, dict_d, dict_loss):
    """Check one G/D optimization step with MelGAN G + ParallelWaveGAN D.

    Args:
        dict_g (dict): Overrides for the generator config.
        dict_d (dict): Overrides for the discriminator config.
        dict_loss (dict): Overrides for the multi-resolution STFT loss config.
    """
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    # Conditioning length must equal output length / total upsampling factor.
    c = torch.randn(
        batch_size,
        args_g["in_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = MelGANGenerator(**args_g)
    model_d = ParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    # Detach so the discriminator step does not backprop into the generator.
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
# Without parametrization pytest treats dict_g/dict_d/dict_loss as fixtures
# and fails collection with "fixture 'dict_g' not found".
@pytest.mark.parametrize(
    "dict_g, dict_d, dict_loss",
    [
        ({}, {}, {}),
    ],
)
def test_melgan_trainable_with_residual_discriminator(dict_g, dict_d, dict_loss):
    """Check one G/D optimization step with MelGAN G + residual ParallelWaveGAN D.

    Args:
        dict_g (dict): Overrides for the generator config.
        dict_d (dict): Overrides for the residual discriminator config.
        dict_loss (dict): Overrides for the multi-resolution STFT loss config.
    """
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_residual_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    # Conditioning length must equal output length / total upsampling factor.
    c = torch.randn(
        batch_size,
        args_g["in_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = MelGANGenerator(**args_g)
    model_d = ResidualParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    # Detach so the discriminator step does not backprop into the generator.
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
# Without parametrization pytest treats dict_g/dict_d/dict_loss as fixtures
# and fails collection with "fixture 'dict_g' not found".
@pytest.mark.parametrize(
    "dict_g, dict_d, dict_loss",
    [
        ({}, {}, {}),
    ],
)
def test_melgan_trainable_with_melgan_discriminator(dict_g, dict_d, dict_loss):
    """Check one G/D optimization step with MelGAN G + multi-scale MelGAN D.

    Also exercises the feature-matching loss, which needs discriminator
    intermediate outputs for both real and generated audio.

    Args:
        dict_g (dict): Overrides for the generator config.
        dict_d (dict): Overrides for the multi-scale discriminator config.
        dict_loss (dict): Overrides for the multi-resolution STFT loss config.
    """
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    # Conditioning length must equal output length / total upsampling factor.
    c = torch.randn(
        batch_size,
        args_g["in_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = MelGANGenerator(**args_g)
    model_d = MelGANMultiScaleDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    feat_match_criterion = FeatureMatchLoss()
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    adv_loss = gen_adv_criterion(p_hat)
    # Real-audio features are targets only; no gradient needed through them.
    with torch.no_grad():
        p = model_d(y)
    fm_loss = feat_match_criterion(p_hat, p)
    loss_g = adv_loss + aux_loss + fm_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    # Detach so the discriminator step does not backprop into the generator.
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
# The causality check only holds for a causal generator; the default config
# has use_causal_conv=False, so the parametrization must enable it. It also
# fixes pytest collection ("fixture 'dict_g' not found" without it).
@pytest.mark.parametrize(
    "dict_g",
    [
        {"use_causal_conv": True},
    ],
)
def test_causal_melgan(dict_g):
    """Check that a causal MelGAN generator's output depends only on past inputs.

    Args:
        dict_g (dict): Overrides for the generator config (must keep the
            generator causal for the assertions to be meaningful).
    """
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    upsampling_factor = np.prod(args_g["upsample_scales"])
    c = torch.randn(
        batch_size, args_g["in_channels"], batch_length // upsampling_factor
    )
    model_g = MelGANGenerator(**args_g)
    # Perturb only the second half of the conditioning input.
    c_ = c.clone()
    c_[..., c.size(-1) // 2 :] = torch.randn(c[..., c.size(-1) // 2 :].shape)
    try:
        # check not equal
        np.testing.assert_array_equal(c.numpy(), c_.numpy())
    except AssertionError:
        pass
    else:
        raise AssertionError("Must be different.")
    # check causality: the first half of the output must be unchanged because
    # only future (second-half) conditioning was modified
    y = model_g(c)
    y_ = model_g(c_)
    assert y.size(2) == c.size(2) * upsampling_factor
    np.testing.assert_array_equal(
        y[..., : c.size(-1) // 2 * upsampling_factor].detach().cpu().numpy(),
        y_[..., : c_.size(-1) // 2 * upsampling_factor].detach().cpu().numpy(),
    )