# NOTE(review): non-Python extraction artifacts ("Spaces:" / "Runtime error")
# appeared here; replaced with this comment so the module parses.
| #!/usr/bin/env python3 | |
| # Copyright 2021 Tomoki Hayashi | |
| # MIT License (https://opensource.org/licenses/MIT) | |
| """Test code for HiFi-GAN modules.""" | |
| import logging | |
| import numpy as np | |
| import pytest | |
| import torch | |
| from parallel_wavegan.losses import DiscriminatorAdversarialLoss | |
| from parallel_wavegan.losses import FeatureMatchLoss | |
| from parallel_wavegan.losses import GeneratorAdversarialLoss | |
| from parallel_wavegan.losses import MultiResolutionSTFTLoss | |
| from parallel_wavegan.models import HiFiGANGenerator | |
| from parallel_wavegan.models import HiFiGANMultiScaleMultiPeriodDiscriminator | |
| from test_parallel_wavegan import make_mutli_reso_stft_loss_args | |
# Emit debug-level logs with module/line context so test failures are easy to trace.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
def make_hifigan_generator_args(**kwargs):
    """Return default HiFiGAN generator arguments, overridden by ``kwargs``."""
    base = {
        "in_channels": 80,
        "out_channels": 1,
        "channels": 512,
        "kernel_size": 7,
        "upsample_scales": (8, 8, 2, 2),
        "upsample_kernel_sizes": (16, 16, 4, 4),
        "resblock_kernel_sizes": (3, 7, 11),
        "resblock_dilations": [(1, 3, 5), (1, 3, 5), (1, 3, 5)],
        "use_additional_convs": True,
        "bias": True,
        "nonlinear_activation": "LeakyReLU",
        "nonlinear_activation_params": {"negative_slope": 0.1},
        "use_weight_norm": True,
    }
    # Caller-supplied overrides win over the defaults.
    return {**base, **kwargs}
def make_hifigan_multi_scale_multi_period_discriminator_args(**kwargs):
    """Return default HiFiGAN multi-scale + multi-period discriminator args.

    Any entry may be overridden via ``kwargs``.
    """
    base = {
        "scales": 3,
        "scale_downsample_pooling": "AvgPool1d",
        "scale_downsample_pooling_params": {
            "kernel_size": 4,
            "stride": 2,
            "padding": 2,
        },
        "scale_discriminator_params": {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [15, 41, 5, 3],
            "channels": 128,
            "max_downsample_channels": 128,
            "max_groups": 16,
            "bias": True,
            "downsample_scales": [2, 2, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
        },
        "follow_official_norm": False,
        "periods": [2, 3, 5, 7, 11],
        "period_discriminator_params": {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 32,
            "downsample_scales": [3, 3, 3, 3, 1],
            "max_downsample_channels": 128,
            "bias": True,
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
            "use_weight_norm": True,
            "use_spectral_norm": False,
        },
    }
    # Caller-supplied overrides win over the defaults.
    return {**base, **kwargs}
def test_hifigan_trainable(dict_g=None, dict_d=None, dict_loss=None):
    """Check one generator step and one discriminator step run end-to-end.

    Args:
        dict_g (dict): Overrides for the generator arguments.
        dict_d (dict): Overrides for the discriminator arguments.
        dict_loss (dict): Overrides for the multi-resolution STFT loss args.
    """
    # NOTE(review): no @pytest.mark.parametrize is visible for this test, so
    # pytest would fail to resolve these arguments as fixtures.  Defaulting
    # them to None (pytest ignores defaulted parameters when collecting
    # fixtures) keeps the test collectable and runnable stand-alone while
    # remaining backward-compatible with explicit calls.
    dict_g = dict_g if dict_g is not None else {}
    dict_d = dict_d if dict_d is not None else {}
    dict_loss = dict_loss if dict_loss is not None else {}

    # setup
    batch_size = 4
    batch_length = 2 ** 13
    args_g = make_hifigan_generator_args(**dict_g)
    args_d = make_hifigan_multi_scale_multi_period_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    # Conditioning features: time axis shrunk by the total upsampling factor
    # so that the generator's output length matches y.
    c = torch.randn(
        batch_size,
        args_g["in_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = HiFiGANGenerator(**args_g)
    model_d = HiFiGANMultiScaleMultiPeriodDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    feat_match_criterion = FeatureMatchLoss(
        average_by_layers=False,
        average_by_discriminators=False,
        include_final_outputs=True,
    )
    gen_adv_criterion = GeneratorAdversarialLoss(
        average_by_discriminators=False,
    )
    dis_adv_criterion = DiscriminatorAdversarialLoss(
        average_by_discriminators=False,
    )
    optimizer_g = torch.optim.AdamW(model_g.parameters())
    optimizer_d = torch.optim.AdamW(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    adv_loss = gen_adv_criterion(p_hat)
    with torch.no_grad():
        # Real-sample features serve as fixed targets for feature matching;
        # no discriminator gradients should flow here.
        p = model_d(y)
    fm_loss = feat_match_criterion(p_hat, p)
    loss_g = adv_loss + aux_loss + fm_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    # Detach generator output so the discriminator step does not backprop
    # into (already-stepped) generator parameters.
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
    print(model_d)
    print(model_g)