| import numpy as np | |
| import torch | |
| from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator | |
def test_pwgan_generator():
    """Check the output shapes of ParallelWaveganGenerator.

    Builds a generator, feeds it a random conditioning tensor and verifies
    the shape returned by the regular forward call and, after weight
    normalization has been removed, by ``inference``.
    """
    batch_size = 2
    aux_channels = 80
    num_frames = 5
    upsample_factors = [4, 4, 4, 4]
    # Output samples per conditioning frame: the product of the per-layer
    # upsampling factors (4 * 4 * 4 * 4 = 256).
    hop_length = int(np.prod(upsample_factors))
    # The inference output is expected to be this many frames longer than the
    # input -- presumably padding applied inside ``inference``; confirm
    # against the model implementation.
    extra_inference_frames = 4

    model = ParallelWaveganGenerator(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_res_blocks=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=aux_channels,
        dropout=0.0,
        bias=True,
        use_weight_norm=True,
        upsample_factors=upsample_factors,
    )
    dummy_c = torch.rand((batch_size, aux_channels, num_frames))

    # Forward pass: one output channel, ``hop_length`` samples per frame.
    output = model(dummy_c)
    expected_shape = (batch_size, 1, num_frames * hop_length)
    assert tuple(output.shape) == expected_shape, output.shape

    # Inference pass, run after stripping weight normalization.
    model.remove_weight_norm()
    output = model.inference(dummy_c)
    expected_shape = (
        batch_size,
        1,
        (num_frames + extra_inference_frames) * hop_length,
    )
    assert tuple(output.shape) == expected_shape, output.shape