import os
import random
from multiprocessing import Manager
from multiprocessing import Process

import librosa
import numpy
import soundfile as sf
import torch
import torchaudio
from torch.utils.data import Dataset
from tqdm import tqdm

from Preprocessing.AudioPreprocessor import AudioPreprocessor


def random_pitch_shifter(x):
    n_steps = random.choice([-12, -9, -6, 3, 12])  # with 12 steps per octave, these are the only shift amounts that run reasonably fast. Benchmarking showed the runtime varies by orders of magnitude between step values.
    return torchaudio.transforms.PitchShift(sample_rate=24000, n_steps=n_steps)(x)


def polarity_inverter(x):
    return x * -1


class HiFiGANDataset(Dataset):

    def __init__(self,
                 list_of_paths,
                 desired_samplingrate=24000,
                 samples_per_segment=12288,  # = (8192 * 3) / 2, since 8192 was used for 16kHz previously
                 loading_processes=max(os.cpu_count() - 2, 1),
                 use_random_corruption=False):
        self.use_random_corruption = use_random_corruption
        self.samples_per_segment = samples_per_segment
        self.desired_samplingrate = desired_samplingrate
        self.melspec_ap = AudioPreprocessor(input_sr=self.desired_samplingrate,
                                            output_sr=16000,
                                            cut_silence=False)
        # the hop length of the spectrogram loss should be the same as the product of the upscale factors
        # samples_per_segment must be a multiple of the hop length of the spectrogram loss
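        # illustrative arithmetic (a sketch; the hop length of 256 at 16kHz is an assumption
        # for this example, not read from the preprocessor): 12288 samples at 24kHz resample
        # to 8192 samples at 16kHz, i.e. exactly 8192 / 256 = 32 frames, and the 24kHz vocoder
        # would need an upscale-factor product of 256 * (24000 / 16000) = 384, with
        # 12288 / 384 = 32 frames again, so wave and spectrogram line up.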
        if loading_processes == 1:
            self.waves = list()
            self.cache_builder_process(list_of_paths)
        else:
            resource_manager = Manager()
            self.waves = resource_manager.list()
            # split the paths into roughly equal chunks and give each loading process one chunk
            path_splits = list()
            process_list = list()
            for i in range(loading_processes):
                path_splits.append(list_of_paths[i * len(list_of_paths) // loading_processes:(i + 1) * len(list_of_paths) // loading_processes])
            for path_split in path_splits:
                process_list.append(Process(target=self.cache_builder_process, args=(path_split,), daemon=True))
                process_list[-1].start()
            for process in process_list:
                process.join()
        # self.masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True)  # up to 16 consecutive bands can be masked, each element in the batch gets a different mask. Taken out because it seems too extreme.
        self.wave_augs = [random_pitch_shifter, polarity_inverter, lambda x: x, lambda x: x, lambda x: x, lambda x: x]  # data augmentation: one transform is drawn uniformly, so pitch shift and polarity inversion each have a 1-in-6 chance, identity otherwise
        self.wave_distortions = [CodecSimulator(), lambda x: x, lambda x: x, lambda x: x, lambda x: x]  # simulates the fact that the TTS is trained on codec-compressed waves (1-in-5 chance)
        print("{} eligible audios found".format(len(self.waves)))

    def cache_builder_process(self, path_split):
        for path in tqdm(path_split):
            try:
                wave, sr = sf.read(path)
                if len(wave.shape) == 2:
                    wave = librosa.to_mono(numpy.transpose(wave))
                if sr != self.desired_samplingrate:
                    wave = librosa.resample(y=wave, orig_sr=sr, target_sr=self.desired_samplingrate)
                self.waves.append(wave)
            except RuntimeError:
                print(f"Problem with the following path: {path}")

    def __getitem__(self, index):
        """
        Load the audio from the cache and clean it.
        All audio segments have to be cut to the same length,
        according to the NeurIPS reference implementation.

        Returns a pair of a high-resolution audio segment and the corresponding
        low-resolution spectrogram, as if it had been predicted by the TTS.
        """
        try:
            wave = self.waves[index]
            while len(wave) < self.samples_per_segment + 50:  # + 50 is just to be extra sure
                # catch files that are too short to apply meaningful signal processing and make them longer
                wave = numpy.concatenate([wave, numpy.zeros(shape=1000), wave])
                # add some true silence into the mix, so the vocoder is exposed to that as well during training
            wave = torch.Tensor(wave)
            if self.use_random_corruption:
                # augmentations for the wave
                wave = random.choice(self.wave_augs)(wave.unsqueeze(0)).squeeze(0)  # it is intentional that this affects the target as well. This is not a distortion, but an augmentation.
            max_audio_start = len(wave) - self.samples_per_segment
            audio_start = random.randint(0, max_audio_start)
            segment = wave[audio_start: audio_start + self.samples_per_segment]
            resampled_segment = self.melspec_ap.resample(segment).float()  # 16kHz spectrogram as input, 24kHz wave as output, see Blizzard 2021 DelightfulTTS
            if self.use_random_corruption:
                # distortions for the input side only, so the target stays clean
                resampled_segment = random.choice(self.wave_distortions)(resampled_segment.unsqueeze(0)).squeeze(0)
            melspec = self.melspec_ap.audio_to_mel_spec_tensor(resampled_segment,
                                                               explicit_sampling_rate=16000,
                                                               normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
            return segment.detach(), melspec.detach()
        except RuntimeError:
            print("encountered a runtime error, using fallback strategy")
            if index == 0:
                index = len(self.waves) - 1
            return self.__getitem__(index - 1)

    def __len__(self):
        return len(self.waves)
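
# A minimal usage sketch (illustrative only; the paths and batch size below are made up):
#     train_set = HiFiGANDataset(list_of_paths=["corpus/clip_0.wav", "corpus/clip_1.wav"])
#     loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True, drop_last=True)
#     for wave_batch, melspec_batch in loader:
#         ...  # wave_batch: [16, 12288] target waves, melspec_batch: the matching 16kHz input spectrograms
# since every item is cut to samples_per_segment, the default collate function can stack batches directly.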


class CodecSimulator(torch.nn.Module):
    # round-trips the wave through mu-law encoding and decoding to mimic lossy codec compression

    def __init__(self):
        super().__init__()
        self.encoder = torchaudio.transforms.MuLawEncoding(quantization_channels=64)
        self.decoder = torchaudio.transforms.MuLawDecoding(quantization_channels=64)

    def forward(self, x):
        return self.decoder(self.encoder(x))
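
# Quick intuition for CodecSimulator (a sketch, not part of the original tests): with 64 mu-law
# quantization channels, every sample snaps to one of 64 levels (6-bit audio), e.g.:
#     cs = CodecSimulator()
#     x = torch.linspace(-1.0, 1.0, steps=5)
#     cs(x)  # each value is replaced by its nearest mu-law level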

if __name__ == '__main__':
    import matplotlib.pyplot as plt

    wav, sr = sf.read("../../audios/speaker_references/female_high_voice.wav")
    resampled_wave = torch.Tensor(librosa.resample(y=wav, orig_sr=sr, target_sr=24000))
    melspec_ap = AudioPreprocessor(input_sr=24000,
                                   output_sr=16000,
                                   cut_silence=False)
    spec = melspec_ap.audio_to_mel_spec_tensor(melspec_ap.resample(resampled_wave).float(),
                                               explicit_sampling_rate=16000,
                                               normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
    cs = CodecSimulator()
    masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True)  # up to 16 consecutive bands can be masked

    # testing the codec simulator
    out = cs(resampled_wave.unsqueeze(0)).squeeze(0)
    plt.plot(resampled_wave, alpha=0.5)
    plt.plot(out, alpha=0.5)
    plt.title("Codec Simulator")
    plt.show()

    # testing spectrogram masking
    for _ in range(5):
        masked_spec = masker(spec.unsqueeze(0)).squeeze(0)
        print(masked_spec)
        plt.imshow(masked_spec.cpu().numpy(), origin="lower", cmap='GnBu')
        plt.title("Masked Spec")
        plt.show()

    # testing pitch shifting
    for _ in range(5):
        shifted_wave = random_pitch_shifter(resampled_wave.unsqueeze(0)).squeeze(0)
        shifted_spec = melspec_ap.audio_to_mel_spec_tensor(melspec_ap.resample(shifted_wave).float(),
                                                           explicit_sampling_rate=16000,
                                                           normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
        plt.imshow(shifted_spec.detach().cpu().numpy(), origin="lower", cmap='GnBu')
        plt.title("Pitch Shifted Spec")
        plt.show()