| import numpy as np | |
| import torch | |
| from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator | |
| from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator | |
def test_melgan_discriminator():
    """Check the output shape of ``MelganDiscriminator`` on a random batch.

    A random input of shape ``(4, 1, 2560)`` is passed through a
    default-constructed model, and the first returned value is expected to
    have shape ``(4, 1, 10)``. The second returned value is not inspected.
    """
    model = MelganDiscriminator()
    print(model)
    dummy_input = torch.rand((4, 1, 256 * 10))

    # Only shapes are verified, so gradient tracking is not needed.
    with torch.no_grad():
        output, _ = model(dummy_input)

    # ``torch.Size`` is a tuple subclass: comparing it to a tuple already
    # yields a single bool, so no ``np.all`` reduction is required.
    expected_shape = (4, 1, 10)
    assert tuple(output.shape) == expected_shape, (
        f"unexpected output shape {tuple(output.shape)}, "
        f"expected {expected_shape}"
    )
def test_melgan_multi_scale_discriminator():
    """Check output structure of ``MelganMultiscaleDiscriminator``.

    A random input of shape ``(4, 1, 4096)`` is passed through a
    default-constructed model. The test expects:

    * three entries in ``scores`` and the same number in ``feats``;
    * ``scores[0]`` to have shape ``(4, 1, 64)``;
    * the first three entries of ``feats[0]`` to have shapes
      ``(4, 16, 4096)``, ``(4, 64, 1024)`` and ``(4, 256, 256)``.

    Only the first entry of ``scores`` and ``feats`` is shape-checked.
    """
    model = MelganMultiscaleDiscriminator()
    print(model)
    dummy_input = torch.rand((4, 1, 256 * 16))

    # Only shapes are verified, so gradient tracking is not needed.
    with torch.no_grad():
        scores, feats = model(dummy_input)

    assert len(scores) == 3, f"expected 3 score tensors, got {len(scores)}"
    assert len(scores) == len(feats), (
        f"scores/feats length mismatch: {len(scores)} != {len(feats)}"
    )

    # ``torch.Size`` is a tuple subclass, so direct tuple comparison is exact
    # and does not need an ``np.all`` reduction.
    expected_score_shape = (4, 1, 64)
    assert tuple(scores[0].shape) == expected_score_shape, (
        f"unexpected scores[0] shape {tuple(scores[0].shape)}, "
        f"expected {expected_score_shape}"
    )

    expected_feat_shapes = [(4, 16, 4096), (4, 64, 1024), (4, 256, 256)]
    for idx, expected_shape in enumerate(expected_feat_shapes):
        actual_shape = tuple(feats[0][idx].shape)
        assert actual_shape == expected_shape, (
            f"unexpected feats[0][{idx}] shape {actual_shape}, "
            f"expected {expected_shape}"
        )