Spaces:
Runtime error
Runtime error
| # Copyright 2024 The YourMT3 Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Please see the details in the LICENSE file. | |
| """pitchshift.py""" | |
| # import math | |
| import numpy as np | |
| # from scipy import special | |
| from einops import rearrange | |
| from typing import Optional, Literal, Dict, List, Tuple, Callable | |
| import torch | |
| from torch import nn | |
| import torchaudio | |
| from torchaudio import transforms | |
| # from torchaudio import functional as F | |
| # from torchaudio.functional.functional import ( | |
| # _fix_waveform_shape, | |
| # _stretch_waveform, | |
| # ) | |
| # from model.ops import adjust_b_to_gcd, check_all_elements_equal | |
class PitchShiftLayer(nn.Module):
    """Apply batch-wise pitch shifting to time-domain audio signals.

    One ``torchaudio.transforms.PitchShift`` module is created for every
    non-zero semitone in ``pshift_range`` and stored in ``self.pshifters``
    under the key ``str(semitone)``. A shift of 0 semitones is the identity
    and its entry is ``None``.

    Args:
        pshift_range (List[int]): Inclusive ``[low, high]`` range of pitch shift
            in semitones. Default: ``[-2, 2]``.
        resample_source_fs (int): Sample rate passed to every ``PitchShift``
            transform. Default: 4000.
        strecth_n_fft (int): ``n_fft`` passed to every ``PitchShift`` transform.
            Default: 512. NOTE: the keyword really is spelled ``strecth_n_fft``
            (sic); a correctly spelled ``stretch_n_fft=...`` is absorbed by
            ``**kwargs`` and has no effect.
        win_length (Optional[int]): ``win_length`` passed to ``PitchShift``. Default: None.
        hop_length (Optional[int]): ``hop_length`` passed to ``PitchShift``. Default: None.
        window (Optional[str]): ``None`` selects ``torch.hann_window``; any string
            containing ``'kaiser'`` selects a periodic ``torch.kaiser_window``.
            Default: None.
        beta (Optional[float]): Shape parameter of the Kaiser window. When ``None``
            and a Kaiser window is requested, 12.0 (the ``torch.kaiser_window``
            default) is used. Ignored for the Hann window. Default: None.
        expected_input_shape (Optional[Tuple[int]]): If given, a random dummy tensor
            of this shape is used to initialize the lazy parameters of every
            pitch shifter at construction time. Default: None.
        device (Optional[torch.device]): Device the dummy tensor is moved to. Default: None.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``window`` is neither ``None`` nor a string containing ``'kaiser'``.
    """

    def __init__(
            self,
            pshift_range: List[int] = [-2, 2],
            resample_source_fs: int = 4000,
            strecth_n_fft: int = 512,
            win_length: Optional[int] = None,
            hop_length: Optional[int] = None,
            window: Optional[Literal['kaiser']] = None,
            beta: Optional[float] = None,
            expected_input_shape: Optional[Tuple[int]] = None,
            device: Optional[torch.device] = None,
            **kwargs,
    ) -> None:
        super().__init__()
        # Copy so that neither the shared default list nor the caller's list is aliased.
        self.pshift_range = list(pshift_range)
        self.resample_source_fs = resample_source_fs
        self.strecth_n_fft = strecth_n_fft
        self.win_length = win_length
        self.hop_length = hop_length

        if window is None:
            self.window_fn = torch.hann_window
            self.window_kwargs = None
        elif 'kaiser' in window:

            def custom_kaiser_window(window_length, beta, **kwargs):
                return torch.kaiser_window(window_length, periodic=True, beta=beta, **kwargs)

            self.window_fn = custom_kaiser_window
            # torch.kaiser_window rejects beta=None, so fall back to its documented default.
            self.window_kwargs = {'beta': 12.0 if beta is None else beta}
        else:
            # Fail here instead of leaving window_fn undefined until first use.
            raise ValueError(f"Unsupported window: {window!r}. Expected None or 'kaiser'.")

        # Initialize pitch shifters for every semitone
        self.pshifters = None
        self.frame_gaps = None
        self._initialize_pshifters(expected_input_shape, device=device)
        self.requires_grad_(False)

    def _initialize_pshifters(self,
                              expected_input_shape: Optional[Tuple[int]] = None,
                              device: Optional[torch.device] = None) -> None:
        """Build ``self.pshifters``, one ``PitchShift`` transform per non-zero semitone.

        Args:
            expected_input_shape: Shape of the dummy tensor used to initialize the
                lazy parameters of each transform; ``None`` skips initialization.
            device: Device the dummy tensor is moved to, if given.
        """
        # DDP requires initializing parameters with a dummy input
        dummy_input = None
        if expected_input_shape is not None:
            dummy_input = torch.randn(expected_input_shape, requires_grad=False)
            if device is not None:
                dummy_input = dummy_input.to(device)

        pshifters = nn.ModuleDict()
        for semitone in range(self.pshift_range[0], self.pshift_range[1] + 1):
            if semitone == 0:
                # No need to shift and resample
                pshifters[str(semitone)] = None
                continue

            pshifter = transforms.PitchShift(self.resample_source_fs,
                                             n_steps=semitone,
                                             n_fft=self.strecth_n_fft,
                                             win_length=self.win_length,
                                             hop_length=self.hop_length,
                                             window_fn=self.window_fn,
                                             wkwargs=self.window_kwargs)
            pshifters[str(semitone)] = pshifter

            # Pass dummy input to initialize parameters
            if dummy_input is not None:
                with torch.no_grad():
                    pshifter.initialize_parameters(dummy_input)
        self.pshifters = pshifters

    def calculate_frame_gaps(self) -> Dict[int, float]:
        """Calculate the expected gap between the original and the stretched audio.

        Returns:
            Dict[int, float]: Maps each semitone in ``pshift_range`` to a gap in
            milliseconds derived from the shape of that shifter's ``kernel``;
            semitone 0 maps to ``0.``.

        NOTE(review): this reads ``pshifter.kernel.shape``, so it presumably requires
        the kernels to have been initialized already (e.g. by passing
        ``expected_input_shape`` to the constructor) -- confirm.
        """
        frame_gaps = {}  # for debugging
        for semitone in range(self.pshift_range[0], self.pshift_range[1] + 1):
            if semitone == 0:
                # No need to shift and resample
                frame_gaps[semitone] = 0.
                continue
            kernel_shape = self.pshifters[str(semitone)].kernel.shape
            stretch_rate = 2.0**(-float(semitone) / 12)
            gap_in_ms = 1000. * (kernel_shape[2] - kernel_shape[0] / stretch_rate) / self.resample_source_fs
            frame_gaps[semitone] = gap_in_ms
        return frame_gaps

    def forward(self, x: torch.Tensor, semitone: int) -> torch.Tensor:
        """Pitch-shift ``x`` by ``semitone`` semitones.

        Args:
            x (torch.Tensor): (B, 1, T) or (B, T)
            semitone (int): Shift in semitones; must lie within ``pshift_range``.
                A value of 0 returns ``x`` unchanged.

        Returns:
            torch.Tensor: (B, 1, T) or (B, T)

        Raises:
            ValueError: If ``semitone`` is outside ``pshift_range``.
        """
        if semitone == 0:
            return x
        if min(self.pshift_range) <= semitone <= max(self.pshift_range):
            return self.pshifters[str(semitone)](x)
        raise ValueError(f"semitone must be in range {self.pshift_range}")
def test_resampler_sinewave():
    """Pitch-shift a two-channel sine wave by +2 semitones and write both signals to disk.

    Writes ``x.wav`` (input) and ``y.wav`` (pitch-shifted output) to the current
    working directory.
    """
    fs = 16000  # sample rate of the synthesized signal

    # x: {440Hz, 220Hz} sine wave at 16kHz
    t = torch.arange(0, 2, 1 / fs)  # 2 seconds at 16kHz
    x0 = torch.sin(2 * torch.pi * 440 * t) * 0.5
    x1 = torch.sin(2 * torch.pi * 220 * t) * 0.5
    x = torch.stack((x0, x1), dim=0)  # (2, 32000)

    # Pitch shift
    psl = PitchShiftLayer(pshift_range=[-2, 2], resample_source_fs=4000)
    y = psl(x, 2)  # same shape as x: (2, 32000)

    # Export to wav. The shifted signal keeps the input length, so it must be
    # written at the input sample rate, not at 12 kHz.
    torchaudio.save("x.wav", x, fs, bits_per_sample=16)
    torchaudio.save("y.wav", y, fs, bits_per_sample=16)
| # class Resampler(nn.Module): | |
| # """ | |
| # Resampling using conv1d operations, more memory-efficient than torchaudio's resampler. | |
| # Based on Dan Povey's resampler.py: | |
| # https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py | |
| # """ | |
| # def __init__(self, | |
| # input_sr: int, | |
| # output_sr: int, | |
| # dtype: torch.dtype = torch.float32, | |
| # filter_width: int = 16, | |
| # cutoff_ratio: float = 0.85, | |
| # filter: Literal['kaiser', 'kaiser_best', 'kaiser_fast', 'hann'] = 'kaiser_fast', | |
| # beta: float = 8.555504641634386) -> None: | |
| # super().__init__() # init the base class | |
| # """ | |
| # Initialize the Resampler. | |
| # Args: | |
| # - input_sr (int): Input sampling rate. | |
| # - output_sr (int): Output sampling rate. | |
| # - dtype (torch.dtype): Computation data type. Default: torch.float32. | |
| # - filter_width (int): Number of zeros per side in the sinc function. Default: 16. | |
| # - cutoff_ratio (float): Filter rolloff point as a fraction of Nyquist freq. Default: 0.85. | |
| # - filter (str): Filter type. One of ['kaiser', 'kaiser_best', 'kaiser_fast', 'hann']. Default: 'kaiser_fast'. | |
| # - beta (float): Parameter for 'kaiser' filter. Default: 8.555504641634386. | |
| # Note: Ratio between input_sr and output_sr should be reduced to simplest form. | |
| # """ | |
| # assert isinstance(input_sr, int) and isinstance(output_sr, int) | |
| # if input_sr == output_sr: | |
| # self.resample_type = 'trivial' | |
| # return | |
| # d = math.gcd(input_sr, output_sr) | |
| # input_sr, output_sr = input_sr // d, output_sr // d | |
| # assert dtype in [torch.float32, torch.float64] | |
| # assert filter_width > 3 # a reasonable bare minimum | |
| # np_dtype = np.float32 if dtype == torch.float32 else np.float64 | |
| # assert filter in ['hann', 'kaiser', 'kaiser_best', 'kaiser_fast'] | |
| # if filter == 'kaiser_best': | |
| # filter_width = 64 | |
| # beta = 14.769656459379492 | |
| # cutoff_ratio = 0.9475937167399596 | |
| # filter = 'kaiser' | |
| # elif filter == 'kaiser_fast': | |
| # filter_width = 16 | |
| # beta = 8.555504641634386 | |
| # cutoff_ratio = 0.85 | |
| # filter = 'kaiser' | |
| # """ | |
| # - Define a sample 'block' correlating `input_sr` input samples to `output_sr` output samples. | |
| # - Dividing samples into these blocks allows corresponding block alignment. | |
| # - On average, `zeros_per_block` zeros per block are present in the sinc function. | |
| # """ | |
| # zeros_per_block = min(input_sr, output_sr) * cutoff_ratio | |
| # """ | |
| # - Define conv kernel size n = (blocks_per_side*2 + 1), adding blocks to each side of the center. | |
| # - `blocks_per_side` blocks as window radius ensures each central block sample accesses its window. | |
| # - `blocks_per_side` is determined, rounding up if needed, as 1 + int(filter_width / zeros_per_block). | |
| # """ | |
| # blocks_per_side = int(np.ceil(filter_width / zeros_per_block)) | |
| # kernel_width = 2 * blocks_per_side + 1 | |
| # # Shape of conv1d weights: (out_channels, in_channels, kernel_width) | |
| # """ Time computations are in units of 1 block, aligning with the `canonical` time axis, | |
| # since each block has input_sr input samples, adhering to our time unit.""" | |
| # window_radius_in_blocks = blocks_per_side | |
| # """`times` will be sinc function arguments, expanding to shape (output_sr, input_sr, kernel_width) | |
| # via broadcasting. Ensuring t == 0 along the central block diagonal (when input_sr == output_sr)""" | |
| # times = ( | |
| # np.arange(output_sr, dtype=np_dtype).reshape( | |
| # (output_sr, 1, 1)) / output_sr - np.arange(input_sr, dtype=np_dtype).reshape( | |
| # (1, input_sr, 1)) / input_sr - (np.arange(kernel_width, dtype=np_dtype).reshape( | |
| # (1, 1, kernel_width)) - blocks_per_side)) | |
| # def hann_window(a): | |
| # """ | |
| # returning 0.5 + 0.5 cos(a*pi) on [-1,1] and 0 outside. | |
| # """ | |
| # return np.heaviside(1 - np.abs(a), 0.0) * (0.5 + 0.5 * np.cos(a * np.pi)) | |
| # def kaiser_window(a, beta): | |
| # w = special.i0(beta * np.sqrt(np.clip(1 - ( | |
| # (a - 0.0) / 1.0)**2.0, 0.0, 1.0))) / special.i0(beta) | |
| # return np.heaviside(1 - np.abs(a), 0.0) * w | |
| # """The weights are computed as a sinc function times a Hann-window function, normalized by | |
| # `zeros_per_block` (sinc) and `input_sr` (input function) to maintain integral and magnitude.""" | |
| # if filter == 'hann': | |
| # weights = ( | |
| # np.sinc(times * zeros_per_block) * hann_window(times / window_radius_in_blocks) * | |
| # zeros_per_block / input_sr) | |
| # else: | |
| # weights = ( | |
| # np.sinc(times * zeros_per_block) * | |
| # kaiser_window(times / window_radius_in_blocks, beta) * zeros_per_block / input_sr) | |
| # self.input_sr = input_sr | |
| # self.output_sr = output_sr | |
| # """If output_sr == 1, merge input_sr into kernel_width for weights (shape: output_sr, input_sr, | |
| # kernel_width) to optimize convolution speed and avoid extra reshaping.""" | |
| # assert weights.shape == (output_sr, input_sr, kernel_width) | |
| # if output_sr == 1: | |
| # self.resample_type = 'integer_downsample' | |
| # self.padding = input_sr * blocks_per_side | |
| # weights = torch.tensor(weights, dtype=dtype, requires_grad=False) | |
| # weights = weights.transpose(1, 2).contiguous().view(1, 1, input_sr * kernel_width) | |
| # elif input_sr == 1: | |
| # # For conv_transpose, use weights as if input_sr and output_sr were swapped, simulating downsampling. | |
| # self.resample_type = 'integer_upsample' | |
| # self.padding = output_sr * blocks_per_side | |
| # weights = torch.tensor(weights, dtype=dtype, requires_grad=False) | |
| # weights = weights.flip(2).transpose(0, | |
| # 2).contiguous().view(1, 1, output_sr * kernel_width) | |
| # else: | |
| # self.resample_type = 'general' | |
| # self.reshaped = False | |
| # self.padding = blocks_per_side | |
| # weights = torch.tensor(weights, dtype=dtype, requires_grad=False) | |
| # self.weights = torch.nn.Parameter(weights, requires_grad=False) | |
| # @torch.no_grad() | |
| # def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| # """ | |
| # Parameters: | |
| # - x: torch.Tensor, with shape (minibatch_size, sequence_length), dtype should match the instance's dtype. | |
| # Returns: | |
| # - A torch.Tensor with shape (minibatch_size, (sequence_length//input_sr)*output_sr), dtype matching the input, | |
| # and content resampled. | |
| # """ | |
| # if self.resample_type == 'trivial': | |
| # return x | |
| # elif self.resample_type == 'integer_downsample': | |
| # (minibatch_size, seq_len) = x.shape # (B, in_C, L) with in_C == 1 | |
| # x = x.unsqueeze(1) | |
| # x = torch.nn.functional.conv1d( | |
| # x, self.weights, stride=self.input_sr, padding=self.padding) # (B, out_C, L) | |
| # return x.squeeze(1) # (B, L) | |
| # elif self.resample_type == 'integer_upsample': | |
| # x = x.unsqueeze(1) | |
| # x = torch.nn.functional.conv_transpose1d( | |
| # x, self.weights, stride=self.output_sr, padding=self.padding) | |
| # return x.squeeze(1) | |
| # else: | |
| # assert self.resample_type == 'general' | |
| # (minibatch_size, seq_len) = x.shape | |
| # num_blocks = seq_len // self.input_sr | |
| # if num_blocks == 0: | |
| # # TODO: pad with zeros. | |
| # raise RuntimeError("Signal is too short to resample") | |
| # # Truncate input | |
| # x = x[:, 0:(num_blocks * self.input_sr)].view(minibatch_size, num_blocks, self.input_sr) | |
| # x = x.transpose(1, 2) # (B, in_C, L) | |
| # x = torch.nn.functional.conv1d( | |
| # x, self.weights, padding=self.padding) # (B, out_C, num_blocks) | |
| # return x.transpose(1, 2).contiguous().view(minibatch_size, num_blocks * self.output_sr) | |
| # def test_resampler_sinewave(): | |
| # import torchaudio | |
| # # x: {440Hz, 220Hz} sine wave at 16kHz | |
| # t = torch.arange(0, 2, 1 / 16000) # 2 seconds at 16kHz | |
| # x0 = torch.sin(2 * torch.pi * 440 * t) * 0.5 | |
| # x1 = torch.sin(2 * torch.pi * 220 * t) * 0.5 | |
| # x = torch.stack((x0, x1), dim=0) # (2, 32000) | |
| # # Resample | |
| # resampler = Resampler(input_sr=16000, output_sr=12000) | |
| # y = resampler(x) # (2, 24000) | |
| # # Export to wav | |
| # torchaudio.save("x.wav", x, 16000, bits_per_sample=16) | |
| # torchaudio.save("y.wav", y, 12000, bits_per_sample=16) | |
| # def test_resampler_music(): | |
| # import torchaudio | |
| # # x: music at 16kHz | |
| # x, _ = torchaudio.load("music.wav") | |
| # slice_length = 32000 | |
| # n_slices = 80 | |
| # slices = [x[0, i * slice_length:(i + 1) * slice_length] for i in range(n_slices)] | |
| # x = torch.stack(slices) # (80, 32000) | |
| # # Resample | |
| # filter_width = 32 | |
| # resampler = Resampler(16000, 12000, filter_width=filter_width) | |
| # y = resampler(x) # (80, 24000) | |
| # y = y.reshape(1, -1) # (1, 1920000) | |
| # torchaudio.save(f"y_filter_width{filter_width}.wav", y, 12000, bits_per_sample=16) | |
| # class PitchShiftLayer(nn.Module): | |
| # """Applying batch-wise pitch-shift to time-domain audio signals. | |
| # Args: | |
| # expected_input_length (int): Expected input length. Default: ``32767``. | |
| # pshift_range (List[int]): Range of pitch shift in semitones. Default: ``[-2, 2]``. | |
| # min_gcd (int): Minimum GCD of input and output sampling rates for resampling. Setting high value can save GPU memory. Default: ``16``. | |
| # max_timing_error (float): Maximum allowed timing error in seconds. Default: ``0.002``. | |
| # fs (int): Sample rate of input waveform, x. Default: 16000. | |
| # bins_per_octave (int, optional): The number of steps per octave (Default : ``12``). | |
| # n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``). | |
| # win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``). | |
| # hop_length (int or None, optional): Length of hop between STFT windows. If None, then ``win_length // 4`` | |
| # is used (Default: ``None``). | |
| # window (Tensor or None, optional): Window tensor that is applied/multiplied to each frame/window. | |
| # If None, then ``torch.hann_window(win_length)`` is used (Default: ``None``). | |
| # """ | |
| # def __init__( | |
| # self, | |
| # expected_input_length: int = 32767, | |
| # pshift_range: List[int] = [-2, 2], | |
| # min_gcd: int = 16, | |
| # max_timing_error: float = 0.002, | |
| # fs: int = 16000, | |
| # bins_per_octave: int = 12, | |
| # n_fft: int = 2048, | |
| # win_length: Optional[int] = None, | |
| # hop_length: Optional[int] = None, | |
| # window: Optional[torch.Tensor] = None, | |
| # filter_width: int = 16, | |
| # filter: Literal['kaiser', 'kaiser_best', 'kaiser_fast', 'hann'] = 'kaiser_fast', | |
| # cutoff_ratio: float = 0.85, | |
| # beta: float = 8.555504641634386, | |
| # **kwargs, | |
| # ): | |
| # super().__init__() | |
| # self.expected_input_length = expected_input_length | |
| # self.pshift_range = pshift_range | |
| # self.min_gcd = min_gcd | |
| # self.max_timing_error = max_timing_error | |
| # self.fs = fs | |
| # self.bins_per_octave = bins_per_octave | |
| # self.n_fft = n_fft | |
| # self.win_length = win_length | |
| # self.hop_length = hop_length | |
| # self.window = window | |
| # self.resample_args = { | |
| # "filter_width": filter_width, | |
| # "filter": filter, | |
| # "cutoff_ratio": cutoff_ratio, | |
| # "beta": beta, | |
| # } | |
| # # Initialize Resamplers | |
| # self._initialize_resamplers() | |
| # def _initialize_resamplers(self): | |
| # resamplers = nn.ModuleDict() | |
| # self.frame_gaps = {} # for debugging | |
| # for i in range(self.pshift_range[0], self.pshift_range[1] + 1): | |
| # if i == 0: | |
| # # No need to shift and resample | |
| # resamplers[str(i)] = None | |
| # else: | |
| # # Find optimal reconversion frames meeting the min_gcd | |
| # stretched_frames, recon_frames, gap = self._find_optimal_reconversion_frames(i) | |
| # self.frame_gaps[i] = gap | |
| # resamplers[str(i)] = Resampler(stretched_frames, recon_frames, **self.resample_args) | |
| # self.resamplers = resamplers | |
| # def _find_optimal_reconversion_frames(self, semitone: int): | |
| # """ | |
| # Find the optimal reconversion frames for a given source sample rate, input length, and semitone for stretch. | |
| # Parameters: | |
| # - sr (int): Input audio sample rate, which should be power of 2 | |
| # - n_step (int): The number of pitch-shift steps in semi-tone. | |
| # - min_gcd (int): The minimum desired GCD, power of 2. Defaults to 16. 16 or 32 are good choices. | |
| # - max_timing_error (float): The maximum allowed timing error, in seconds. Defaults to 0.002 (2 ms). | |
| # Returns: | |
| # - int: The optimal target sample rate | |
| # """ | |
| # stretch_rate = 1 / 2.0**(-float(semitone) / self.bins_per_octave) | |
| # stretched_frames = round(self.expected_input_length * stretch_rate) | |
| # gcd = math.gcd(self.expected_input_length, stretched_frames) | |
| # if gcd >= self.min_gcd: | |
| # return stretched_frames, self.expected_input_length, 0 | |
| # else: | |
| # reconversion_frames = adjust_b_to_gcd(stretched_frames, self.expected_input_length, | |
| # self.min_gcd) | |
| # gap = reconversion_frames - self.expected_input_length | |
| # gap_sec = gap / self.fs | |
| # if gap_sec > self.max_timing_error: | |
| # # TODO: modifying vocoder of stretch_waveform to adjust pitch-shift rate in cents | |
| # raise ValueError( | |
| # gap_sec < self.max_timing_error, | |
| # f"gap_sec={gap_sec} > max_timing_error={self.max_timing_error} with semitone={semitone}, stretched_frames={stretched_frames}, recon_frames={reconversion_frames}. Try adjusting input lenght or decreasing min_gcd." | |
| # ) | |
| # else: | |
| # return stretched_frames, reconversion_frames, gap_sec | |
| # @torch.no_grad() | |
| # def forward(self, | |
| # x: torch.Tensor, | |
| # semitone: int, | |
| # resample: bool = True, | |
| # fix_shape: bool = True) -> torch.Tensor: | |
| # """ | |
| # Args: | |
| # x (torch.Tensor): (B, 1, T) | |
| # Returns: | |
| # torch.Tensor: (B, 1, T) | |
| # """ | |
| # if semitone == 0: | |
| # return x | |
| # elif semitone >= min(self.pshift_range) and semitone <= max(self.pshift_range): | |
| # x = x.squeeze(1) # (B, T) | |
| # original_x_size = x.size() | |
| # x = _stretch_waveform( | |
| # x, | |
| # semitone, | |
| # self.bins_per_octave, | |
| # self.n_fft, | |
| # self.win_length, | |
| # self.hop_length, | |
| # self.window, | |
| # ) | |
| # if resample: | |
| # x = self.resamplers[str(semitone)].forward(x) | |
| # # Fix waveform shape | |
| # if fix_shape: | |
| # if x.size(1) != original_x_size[1]: | |
| # # print(f"Warning: {x.size(1)} != {original_x_length}") | |
| # x = _fix_waveform_shape(x, original_x_size) | |
| # return x.unsqueeze(1) # (B, 1, T) | |
| # else: | |
| # raise ValueError(f"semitone must be in range {self.pshift_range}") | |
| # def test_pitchshift_layer(): | |
| # import torchaudio | |
| # # music | |
| # # x, _ = torchaudio.load("music.wav") | |
| # # slice_length = 32767 | |
| # # n_slices = 80 | |
| # # slices = [x[0, i * slice_length:(i + 1) * slice_length] for i in range(n_slices)] | |
| # # x = torch.stack(slices).unsqueeze(1) # (80, 1, 32767) | |
| # # sine wave | |
| # t = torch.arange(0, 2.0479, 1 / 16000) # 2.05 seconds at 16kHz | |
| # x = torch.sin(2 * torch.pi * 440 * t) * 0.5 | |
| # x = x.reshape(1, 1, 32767).tile(80, 1, 1) | |
| # # Resample | |
| # pos = 0 | |
| # ps = PitchShiftLayer( | |
| # pshift_range=[-3, 4], | |
| # expected_input_length=32767, | |
| # fs=16000, | |
| # min_gcd=16, | |
| # max_timing_error=0.002, | |
| # # filter_width=64, | |
| # filter='kaiser_fast', | |
| # n_fft=2048) | |
| # y = [] | |
| # for i in range(-3, 4): | |
| # y.append(ps(x[[pos], :, :], i, resample=False, fix_shape=False)[0, 0, :]) | |
| # y = torch.cat(y).unsqueeze(0) # (1, 32767 * 7) | |
| # torchaudio.save("y_2048_kaiser_fast.wav", y, 16000, bits_per_sample=16) | |
| # # TorchAudio PitchShifter for comparison | |
| # y_ta = [] | |
| # for i in range(-3, 4): | |
| # ta_transform = torchaudio.transforms.PitchShift(16000, n_steps=i) | |
| # y_ta.append(ta_transform(x[[pos], :, :])[0, 0, :]) | |
| # y_ta = torch.cat(y_ta).unsqueeze(0) # (1, 32767 * 7) | |
| # torchaudio.save("y_ta.wav", y_ta, 16000, bits_per_sample=16) | |
| # def test_min_gcd_mem_usage(): | |
| # min_gcd = 16 | |
| # for i in range(-3, 4): | |
| # stretched_frames = _stretch_waveform(x, i).shape[1] | |
| # adjusted = adjust_b_to_gcd(stretched_frames, 32767, min_gcd) | |
| # gcd_val = math.gcd(adjusted, stretched_frames) | |
| # gap = adjusted - 32767 | |
| # gap_ms = (gap / 16000) * 1000 | |
| # mem_mb = (stretched_frames / gcd_val) * (adjusted / gcd_val) * 3 * 4 / 1000 / 1000 | |
| # print(f'\033[92mmin_gcd={min_gcd}\033[0m', f'ps={i}', f'frames={stretched_frames}', | |
| # f'adjusted_frames={adjusted}', f'gap={gap}', f'\033[91mgap_ms={gap_ms}\033[0m', | |
| # f'gcd={gcd_val}', f'mem_MB={mem_mb}') | |