Spaces:
Running
Running
| import math | |
| import os | |
| import random | |
| import torch | |
| import torch.utils.data | |
| import numpy as np | |
| from librosa.util import normalize | |
| from scipy.io.wavfile import read | |
| import scipy | |
| import librosa | |
| import wave | |
| from pydub import AudioSegment | |
# Full-scale magnitude of signed 16-bit PCM (2**15): int16 samples divided by
# this value land in [-1, 1).
MAX_WAV_VALUE = float(2 ** 15)
def load_wav(full_path):
    """Read a whole WAV file with scipy and return its first channel as floats.

    :param full_path: Path to the WAV file.
    :return: ``(samples, sampling_rate)`` where ``samples`` is a 1-D array.
        Integer PCM is scaled to [-1, 1): signed data is divided by 2**(bits-1)
        (32768 for 16-bit), and unsigned 8-bit data is centred on 128.
        Floating-point WAV data is returned unscaled.
        ``(None, None)`` is returned when the file is missing, unreadable, or
        shorter than 0.5 seconds; the reason for a failure is printed.
    """
    try:
        sampling_rate, data = read(full_path)
        # Reject clips shorter than half a second (max() covers mono and multi-channel shapes).
        if max(data.shape) / sampling_rate < 0.5:
            return None, None
    except FileNotFoundError:
        print(f"File not found: {full_path}")
        return None, None
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return None, None

    # scipy returns multi-channel audio as (samples, channels); keep channel 0.
    if data.ndim > 1:
        data = data[:, 0]

    if data.dtype == np.uint8:
        # 8-bit WAV PCM is unsigned with its midpoint at 128.
        data = (data.astype(np.float64) - 128.0) / 128.0
    elif np.issubdtype(data.dtype, np.integer):
        # Signed PCM: full scale is iinfo.max + 1 == 2**(bits-1) (32768 for int16).
        data = data / (float(np.iinfo(data.dtype).max) + 1.0)
    return data, sampling_rate
def get_wave_duration(file_path):
    """
    Gets the duration and basic properties of a WAV file from its header.
    :param file_path: Path to the WAV file.
    :return: Tuple ``(duration_seconds, frame_rate, num_frames)``, or
        ``(None, None, None)`` if the file is missing or cannot be parsed by the
        ``wave`` module (the reason is printed).
    """
    try:
        with wave.open(file_path, 'rb') as wf:
            num_frames = wf.getnframes()
            frame_rate = wf.getframerate()
        # A zero frame rate raises ZeroDivisionError here and falls to the generic handler.
        return num_frames / float(frame_rate), frame_rate, num_frames
    except wave.Error as e:
        print(f"Error reading {file_path}: {e}")
        return None, None, None
    except FileNotFoundError:
        print(f"File not found: {file_path}")
        return None, None, None
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return None, None, None
def read_audio_segment(file_path, start_ms, end_ms):
    """
    Reads a time slice of a WAV file via pydub and returns its first channel as floats.
    :param file_path: Path to the WAV file.
    :param start_ms: Start time of the segment in milliseconds.
    :param end_ms: End time of the segment in milliseconds.
    :return: 1-D float array scaled by 2**(bits-1) (so 16-bit data lands in
        [-1, 1)), or None if loading/decoding fails (the error is printed).
    """
    try:
        # pydub slices AudioSegments by milliseconds.
        segment = AudioSegment.from_wav(file_path)[start_ms:end_ms]
        sample_width = segment.sample_width
        # Interpret the raw bytes using the segment's real sample width instead of
        # assuming int16; an unsupported width raises KeyError -> handled below.
        dtype = {1: np.int8, 2: np.int16, 4: np.int32}[sample_width]
        audio_array = np.frombuffer(segment.raw_data, dtype=dtype)
        # Interleaved multi-channel data: reshape to (frames, channels), keep channel 0.
        if segment.channels > 1:
            audio_array = audio_array.reshape((-1, segment.channels))[..., 0]
        full_scale = float(1 << (8 * sample_width - 1))  # 32768.0 for 16-bit
        return audio_array / full_scale
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
def resample(audio, sr_in, sr_out, target_len=None):
    """Resample a 1-D signal with scipy's FFT-based ``scipy.signal.resample``.

    :param audio: 1-D array of samples.
    :param sr_in: Input sampling rate (ignored when ``target_len`` is given).
    :param sr_out: Output sampling rate (ignored when ``target_len`` is given).
    :param target_len: If given, resample to exactly this many samples.
    :return: Resampled array; without ``target_len`` its length is
        ``int(len(audio) * sr_out / sr_in)`` (truncated).
    """
    # `import scipy` alone does not load the signal submodule on older SciPy
    # releases; import it explicitly instead of relying on another library.
    import scipy.signal

    if target_len is not None:
        return scipy.signal.resample(audio, target_len)
    resample_factor = sr_out / sr_in
    new_samples = int(len(audio) * resample_factor)
    return scipy.signal.resample(audio, new_samples)
def _fit_to_length(data, segment_size):
    """Zero-pad *data* at the end, or crop it at a random offset, to exactly *segment_size* samples."""
    if len(data) < segment_size:
        return np.pad(data, (0, segment_size - len(data)), mode='constant')
    start_idx = random.randint(0, len(data) - segment_size)
    return data[start_idx:start_idx + segment_size]


def load_segment(full_path, target_sampling_rate=None, segment_size=None):
    """Load a WAV file (or a random segment of it) resampled to ``target_sampling_rate``.

    Files whose header cannot be read, or whose native rate is below 44.1 kHz,
    are rejected.

    :param full_path: Path to the WAV file.
    :param target_sampling_rate: Output rate; ``None`` keeps the file's native rate.
    :param segment_size: Output length in samples at the output rate. ``None``
        returns the whole file. Files shorter than the segment are zero-padded at
        the end; longer files are read around a random offset and cut to length.
    :return: ``(samples, rate)`` — a 1-D float array and the output rate — or
        ``(None, None)`` if the file was rejected or could not be read.
    """
    dur, sampling_rate, len_data = get_wave_duration(full_path)
    if sampling_rate is None or sampling_rate < 44100:
        return None, None
    out_rate = sampling_rate if target_sampling_rate is None else target_sampling_rate

    if segment_size is not None and dur >= segment_size / out_rate:
        # Long file: read only a window of the segment's duration at a random
        # offset (measured in native-rate frames), then convert it below.
        target_dur = segment_size / out_rate
        target_len = int(target_dur * sampling_rate)
        start_idx = random.randint(0, max(0, len_data - target_len))
        start_ms = start_idx / sampling_rate * 1000
        data = read_audio_segment(full_path, start_ms, start_ms + target_dur * 1000)
        if data is None:
            return None, None
    else:
        # Whole file requested, or the file is shorter than one segment.
        data, sampling_rate = load_wav(full_path)
        if data is None:
            return None, None

    if sampling_rate != out_rate:
        data = resample(data, sampling_rate, out_rate)
    if segment_size is not None:
        # Rounding in the millisecond slice / resample can leave a few samples
        # short or long, so always normalise to the exact segment length.
        data = _fit_to_length(data, segment_size)
    return data, out_rate
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Natural-log compress *x* (NumPy) after flooring it at *clip_val* and scaling by *C*."""
    floored = np.clip(x, clip_val, None)
    scaled = floored * C
    return np.log(scaled)
def dynamic_range_decompression(x, C=1):
    """Invert ``dynamic_range_compression`` (NumPy): exponentiate, then divide out *C*."""
    expanded = np.exp(x)
    return expanded / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Natural-log compress tensor *x* after flooring it at *clip_val* and scaling by *C*."""
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)
def dynamic_range_decompression_torch(x, C=1):
    """Invert ``dynamic_range_compression_torch``: exponentiate, then divide out *C*."""
    expanded = torch.exp(x)
    return expanded / C
def spectral_normalize_torch(magnitudes):
    """Map linear spectral magnitudes to the log domain via ``dynamic_range_compression_torch``."""
    return dynamic_range_compression_torch(magnitudes)
def spectral_de_normalize_torch(magnitudes):
    """Map log-domain spectral values back to linear magnitudes via ``dynamic_range_decompression_torch``."""
    return dynamic_range_decompression_torch(magnitudes)
# Per-process caches for mel filterbanks and STFT windows, keyed by every
# parameter they depend on plus the device they live on.
mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram of waveform tensor *y*.

    :param y: Waveform tensor; it is unsqueezed at dim 1 for reflect padding, so
        presumably shaped (batch, samples) — confirm against callers.
    :param n_fft: FFT size.
    :param num_mels: Number of mel bands.
    :param sampling_rate: Sampling rate of *y* (used to build the filterbank).
    :param hop_size: STFT hop length in samples.
    :param win_size: Hann window length in samples.
    :param fmin: Lowest filterbank frequency.
    :param fmax: Highest filterbank frequency (``None`` is passed through to librosa).
    :param center: Forwarded to ``torch.stft``.
    :return: Natural log of the mel magnitudes, floored at 1e-5 before the log.
    """
    device_key = str(y.device)
    # Key the caches on everything the cached tensors depend on. The previous
    # check (`fmax not in mel_basis`) never matched the stored string key, so
    # the filterbank and window were rebuilt on every call.
    mel_key = f"{fmax}_{fmin}_{num_mels}_{n_fft}_{sampling_rate}_{device_key}"
    if mel_key not in mel_basis:
        mel = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    window_key = f"{win_size}_{device_key}"
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(y.device)

    # Reflect-pad both ends by (n_fft - hop_size) / 2 so frames line up with hops
    # when center=False.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    # return_complex=False is deprecated; view_as_real restores the same
    # (..., 2) real/imag layout the magnitude formula below expects.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[window_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)
    # Magnitude with a small epsilon to keep sqrt differentiable at zero.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
    spec = torch.matmul(mel_basis[mel_key], spec)
    return spectral_normalize_torch(spec)
def get_dataset_filelist_org(a):
    """Build WAV paths from the training/validation list files.

    Each non-empty line contributes ``<input_wavs_dir>/<first '|' field>.wav``.

    :param a: Object exposing ``input_training_file``, ``input_validation_file``
        and ``input_wavs_dir``.
    :return: ``(training_files, validation_files)`` lists of paths.
    """
    def _paths_from(list_file):
        with open(list_file, 'r', encoding='utf-8') as handle:
            lines = handle.read().split('\n')
        return [os.path.join(a.input_wavs_dir, line.split('|')[0] + '.wav')
                for line in lines if line]

    return _paths_from(a.input_training_file), _paths_from(a.input_validation_file)
def get_dataset_filelist(a):
    """Return the non-empty lines of the training and validation list files verbatim.

    :param a: Object exposing ``input_training_file`` and ``input_validation_file``.
    :return: ``(training_files, validation_files)`` lists of strings.
    """
    splits = []
    for list_file in (a.input_training_file, a.input_validation_file):
        with open(list_file, 'r', encoding='utf-8') as handle:
            splits.append([entry for entry in handle.read().split('\n') if entry])
    return splits[0], splits[1]
class MelDataset(torch.utils.data.Dataset):
    """Dataset of (degraded-input mel, target audio, filename, target mel) items.

    ``__getitem__`` resamples each segment down to a random rate from
    ``supported_samples`` and back up, and computes the input mel from that
    band-limited copy (presumably for bandwidth-extension training — confirm).
    """

    def __init__(self, training_files, segment_size, n_fft, num_mels,
                 hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
                 device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
        """
        :param training_files: Sequence of WAV paths; copied, never modified.
        :param segment_size: Samples per item at ``sampling_rate``.
        :param n_fft, num_mels, hop_size, win_size, fmin: Forwarded to ``mel_spectrogram``.
        :param sampling_rate: Rate all audio is converted to.
        :param fmax: Upper mel frequency for the input mel.
        :param split: Only read by the legacy ``__getitem__org`` (crop/pad to segment).
        :param shuffle: Shuffle the file list once (global RNG seeded with 1234).
        :param n_cache_reuse: Number of extra reuses of a cached wave (legacy loader).
        :param fmax_loss: Upper mel frequency for the target (loss) mel.
        :param fine_tuning: When True, skip peak normalisation.
        :param device, base_mels_path: Stored; not read by the methods in this class.
        """
        # Copy so shuffling does not reorder the caller's list in place.
        self.audio_files = list(training_files)
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_files)
        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.fmax_loss = fmax_loss
        self.cached_wav = None
        self.cached_wav_up = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device
        self.fine_tuning = fine_tuning
        self.base_mels_path = base_mels_path
        # Rates the target audio is degraded to before being upsampled again.
        self.supported_samples = [16000, 22050, 24000]

    def __getitem__(self, index):
        """Return ``(mel, audio, filename, mel_loss)`` for one random segment."""
        filename = self.audio_files[index]
        while True:
            audio, sampling_rate = load_segment(filename, self.sampling_rate, self.segment_size)
            if audio is not None:
                break
            # Retry with a random file from the whole list; drawing only from
            # [0, index] spins forever when index 0 is unusable.
            filename = random.choice(self.audio_files)

        if not self.fine_tuning:
            audio = normalize(audio) * 0.95

        # Band-limit the target: down to a random supported rate, then back to
        # the original number of samples.
        sr_out = random.choice(self.supported_samples)
        audio_down = resample(audio, self.sampling_rate, sr_out)
        audio_up = resample(audio_down, None, None, len(audio))

        audio = torch.FloatTensor(audio).unsqueeze(0)
        audio_up = torch.FloatTensor(audio_up).unsqueeze(0)

        mel = mel_spectrogram(audio_up, self.n_fft, self.num_mels,
                              self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
                              center=False)
        mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                   self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
                                   center=False)
        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __getitem__org(self, index):
        """Legacy loader: whole-file load, fixed 16 kHz degradation, optional wave reuse."""
        filename = self.audio_files[index]
        if self._cache_ref_count == 0:
            while True:
                audio, sampling_rate = load_wav(filename)
                if audio is not None:
                    break
                filename = random.choice(self.audio_files)
            # load_wav already returns scaled floats; the former extra division
            # by MAX_WAV_VALUE shrank the signal by another 32768x.
            if not self.fine_tuning:
                audio = normalize(audio) * 0.95
            if sampling_rate != self.sampling_rate:
                audio = resample(audio, sampling_rate, self.sampling_rate)
            # Degrade through 16 kHz and back to the original length.
            audio_down = resample(audio, self.sampling_rate, 16000)
            audio_up = resample(audio_down, None, None, len(audio))
            self.cached_wav = audio
            self.cached_wav_up = audio_up
            self._cache_ref_count = self.n_cache_reuse
        else:
            audio = self.cached_wav
            audio_up = self.cached_wav_up
            self._cache_ref_count -= 1

        audio = torch.FloatTensor(audio).unsqueeze(0)
        audio_up = torch.FloatTensor(audio_up).unsqueeze(0)

        if self.split:
            if audio.size(1) >= self.segment_size:
                max_audio_start = audio.size(1) - self.segment_size
                audio_start = random.randint(0, max_audio_start)
                audio = audio[:, audio_start:audio_start + self.segment_size]
                audio_up = audio_up[:, audio_start:audio_start + self.segment_size]
            else:
                audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
                audio_up = torch.nn.functional.pad(audio_up, (0, self.segment_size - audio_up.size(1)), 'constant')

        mel = mel_spectrogram(audio_up, self.n_fft, self.num_mels,
                              self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
                              center=False)
        mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                   self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
                                   center=False)
        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __len__(self):
        return len(self.audio_files)