# NOTE: extraction artifact removed here ("Spaces: / Runtime error") -- not part of the original source.
| #!/usr/bin/env python3 | |
| # Copyright 2021 Tomoki Hayashi | |
| # MIT License (https://opensource.org/licenses/MIT) | |
| """Test code for StyleMelGAN modules.""" | |
| import logging | |
| import numpy as np | |
| import pytest | |
| import torch | |
| from parallel_wavegan.losses import DiscriminatorAdversarialLoss | |
| from parallel_wavegan.losses import GeneratorAdversarialLoss | |
| from parallel_wavegan.losses import MultiResolutionSTFTLoss | |
| from parallel_wavegan.models import StyleMelGANDiscriminator | |
| from parallel_wavegan.models import StyleMelGANGenerator | |
| from test_parallel_wavegan import make_mutli_reso_stft_loss_args | |
# Verbose logging with timestamp, module, and line number makes test
# failures in CI easier to trace back to their origin.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
def make_style_melgan_generator_args(**kwargs):
    """Build a StyleMelGANGenerator config dict.

    Args:
        **kwargs: Overrides merged on top of the defaults (extra keys are
            passed through unchanged).

    Returns:
        dict: Keyword arguments suitable for ``StyleMelGANGenerator(**args)``.

    """
    base = {
        "in_channels": 128,
        "aux_channels": 80,
        "channels": 64,
        "out_channels": 1,
        "kernel_size": 9,
        "dilation": 2,
        "bias": True,
        "noise_upsample_scales": [11, 2, 2, 2],
        "noise_upsample_activation": "LeakyReLU",
        "noise_upsample_activation_params": {"negative_slope": 0.2},
        "upsample_scales": [2, 2, 2, 2, 2, 2, 2, 2, 1],
        "upsample_mode": "nearest",
        "gated_function": "softmax",
        "use_weight_norm": True,
    }
    return {**base, **kwargs}
def make_style_melgan_discriminator_args(**kwargs):
    """Build a StyleMelGANDiscriminator config dict.

    Args:
        **kwargs: Overrides merged on top of the defaults (extra keys are
            passed through unchanged).

    Returns:
        dict: Keyword arguments suitable for ``StyleMelGANDiscriminator(**args)``.

    """
    base = {
        "repeats": 2,
        "window_sizes": [512, 1024, 2048, 4096],
        # Each entry: [subbands, taps, cutoff_ratio, beta]; the first band
        # uses no PQMF analysis (None placeholders).
        "pqmf_params": [
            [1, None, None, None],
            [2, 62, 0.26700, 9.0],
            [4, 62, 0.14200, 9.0],
            [8, 62, 0.07949, 9.0],
        ],
        "discriminator_params": {
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 16,
            "max_downsample_channels": 32,
            "bias": True,
            "downsample_scales": [4, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.2},
            "pad": "ReflectionPad1d",
            "pad_params": {},
        },
        "use_weight_norm": True,
    }
    return {**base, **kwargs}
def test_style_melgan_discriminator(dict_d):
    """Check StyleMelGANDiscriminator forward output feeds the adversarial loss.

    NOTE(review): ``dict_d`` looks like a pytest parametrize/fixture argument;
    the decorator is not visible in this chunk -- confirm it exists upstream.
    """
    disc_args = make_style_melgan_discriminator_args(**dict_d)
    discriminator = StyleMelGANDiscriminator(**disc_args)
    # (B, 1, T) waveform batch; T = 2 ** 14 covers all window sizes.
    waveform = torch.randn(4, 1, 2 ** 14)
    adv_criterion = GeneratorAdversarialLoss()
    adv_criterion(discriminator(waveform))
def test_style_melgan_generator(dict_g):
    """Check StyleMelGANGenerator forward and inference run without error.

    NOTE(review): ``dict_g`` looks like a pytest parametrize/fixture argument;
    the decorator is not visible in this chunk -- confirm it exists upstream.
    """
    args_g = make_style_melgan_generator_args(**dict_g)
    batch_size = 4
    # Total output length is the product of both upsampling stages.
    batch_length = np.prod(args_g["noise_upsample_scales"]) * np.prod(
        args_g["upsample_scales"]
    )
    noise = torch.randn(batch_size, args_g["in_channels"], 1)
    aux = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    generator = StyleMelGANGenerator(**args_g)
    generator(aux, noise)
    # Inference path: unbatched (T, aux_channels) conditioning input.
    aux_inference = torch.randn(512, args_g["aux_channels"])
    wave = generator.inference(aux_inference)
    print(wave.shape)
def test_style_melgan_trainable(dict_g, dict_d, dict_loss, loss_type):
    """Check one optimizer step for both generator and discriminator.

    NOTE(review): the four arguments look like pytest parametrize/fixture
    inputs; the decorator is not visible in this chunk -- confirm upstream.
    """
    # ---- setup ----
    args_g = make_style_melgan_generator_args(**dict_g)
    args_d = make_style_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    batch_size = 4
    # Total waveform length produced by both upsampling stages.
    batch_length = np.prod(args_g["noise_upsample_scales"]) * np.prod(
        args_g["upsample_scales"]
    )
    target_wave = torch.randn(batch_size, 1, batch_length)
    aux = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    generator = StyleMelGANGenerator(**args_g)
    discriminator = StyleMelGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss(loss_type=loss_type)
    dis_adv_criterion = DiscriminatorAdversarialLoss(loss_type=loss_type)
    optimizer_g = torch.optim.Adam(generator.parameters())
    optimizer_d = torch.optim.Adam(discriminator.parameters())

    # ---- generator step ----
    fake_wave = generator(aux)
    adv_loss = gen_adv_criterion(discriminator(fake_wave))
    sc_loss, mag_loss = aux_criterion(fake_wave, target_wave)
    loss_g = adv_loss + sc_loss + mag_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # ---- discriminator step ----
    real_outs = discriminator(target_wave)
    # Detach so gradients do not flow back into the generator here.
    fake_outs = discriminator(fake_wave.detach())
    real_loss, fake_loss = dis_adv_criterion(fake_outs, real_outs)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()