Spaces:
Runtime error
Runtime error
| # Copyright 2024 The YourMT3 Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Please see the details in the LICENSE file. | |
| """spectrogram.py""" | |
| import importlib | |
| from typing import Optional, Literal, Dict, Tuple | |
| from packaging.version import parse as VersionParse | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from model.ops import minmax_normalize | |
| from config.config import audio_cfg as default_audio_cfg | |
| """ | |
| Example usage: | |
| # MT3 setup | |
| hop = 8 ms or 128 samples | |
| melspec = Melspectrogram(sample_rate=16000, n_fft=2048, hop_length=128, | |
| f_min=50, f_max=8000, n_mels=512) | |
| x = torch.randn(2, 1, 32767) # (B, C=1, T): 2.048 s | |
| y = melspec(x) # (2, 256, 512) (B, T, F) | |
| # PerceiverTF-like setup | |
| hop = 18.75 ms or 300 samples | |
| spec = Spectrogram(n_fft=2048, hop_length=300) | |
| x = torch.randn(2, 1, 95999) # (B, C=1, T): 6.000 s | |
| y = spec(x) # (2, 320, 1024) (B, T, F) | |
| # Hybrid setup (2.048 seconds segment and spectrogram with hop=300) | |
| hop = 18.75 ms or 300 samples | |
| spec = Spectrogram(n_fft=2048, hop_length=300) | |
| x = torch.randn(2, 1, 32767) # (B, C=1, T): 2.048 s | |
| y = spec(x) # (2, 110, 1024) (B, T, F) | |
| # PerceiverTF-like setup, hop=256 | |
| hop = 16 ms or 256 samples | |
| spec256 = Spectrogram(sample_rate=16000, n_fft=2048, hop_length=256, | |
| f_min=20, f_max=8000, n_mels=256) | |
| x = torch.randn(2, 1, 32767) # (B, C=1, T): 2.048 s | |
| y = spec256(x) # (2, 128, 1024) (B, T, F) | |
| """ | |
def optional_compiler_disable(func):
    """Exclude ``func`` from ``torch.compile`` tracing when PyTorch supports it.

    Intended for use as a decorator.

    Args:
        func (Callable): function or method to protect from compilation.

    Returns:
        Callable: ``torch.compiler.disable(func)`` when that API exists in the
            running PyTorch build, otherwise ``func`` itself, unchanged.
    """
    # Probe for the API instead of comparing version strings: pre-release builds
    # such as "2.1.0.dev20230801" or "2.1.0a0+git..." sort *below* "2.1" under
    # PEP 440, so a version comparison silently skips the decorator on builds
    # that do ship torch.compiler.disable.
    compiler = getattr(torch, "compiler", None)
    disable = getattr(compiler, "disable", None)
    if disable is None:
        # Older PyTorch (the original check targeted < 2.1): nothing to disable.
        return func
    return disable(func)
| # ------------------------------------------------------------------------------------- | |
| # Log-Mel spectrogram | |
| # ------------------------------------------------------------------------------------- | |
class Melspectrogram(nn.Module):
    """Log-mel spectrogram layer backed by either torchaudio or nnAudio."""

    def __init__(
            self,
            audio_backend: Literal['torchaudio', 'nnaudio'] = 'torchaudio',
            sample_rate: int = 16000,
            n_fft: int = 2048,
            hop_length: int = 128,
            f_min: int = 50,  # 20 Hz in the MT3 paper, but we can only use 20 Hz with nnAudio
            f_max: Optional[int] = 8000,
            n_mels: int = 512,
            eps: float = 1e-5,
            **kwargs,
    ):
        """
        Log-Melspectrogram

        Args:
            audio_backend (str): 'torchaudio' or 'nnaudio' (case-insensitive)
            sample_rate (int): sample rate in Hz
            n_fft (int): FFT window size
            hop_length (int): hop length in samples
            f_min (int): minimum frequency in Hz. Only forwarded to the torchaudio
                backend; the nnAudio backend is pinned to 20 Hz.
            f_max (int): maximum frequency in Hz
            n_mels (int): number of mel frequency bins
            eps (float): epsilon added before the log for numerical stability
            **kwargs: ignored, so that a full audio config dict can be splatted in.

        Raises:
            NotImplementedError: if ``audio_backend`` is neither 'torchaudio' nor 'nnaudio'.
        """
        super().__init__()
        self.audio_backend = audio_backend.lower()
        # Backends are imported lazily so that only the selected one has to be installed.
        if self.audio_backend == 'torchaudio':
            torchaudio = importlib.import_module('torchaudio')
            self.mel_stft = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                f_min=f_min,
                f_max=f_max,
                n_mels=n_mels,
            )
        elif self.audio_backend == 'nnaudio':
            nnaudio = importlib.import_module('nnAudio.features')
            self.mel_stft_nnaudio = nnaudio.mel.MelSpectrogram(
                sr=sample_rate,
                # n_fft must be passed explicitly: nnAudio otherwise keeps its own
                # default FFT size (2048) and only the window length would follow
                # the requested value.
                n_fft=n_fft,
                win_length=n_fft,
                n_mels=n_mels,
                hop_length=hop_length,
                fmin=20,  # deliberately fixed; f_min is not used for this backend
                fmax=f_max)
        else:
            raise NotImplementedError(audio_backend)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # (B, 1, T)
        """
        Args:
            x (torch.Tensor): (B, 1, T)

        Returns:
            torch.Tensor: (B, T, F), log-mel values passed through ``minmax_normalize``

        Raises:
            NotImplementedError: if ``self.audio_backend`` is not a supported backend.
        """
        if self.audio_backend == 'torchaudio':
            x = self.mel_stft(x)  # (B, 1, F, T)
            x = rearrange(x, 'b 1 f t -> b t f')
            x = minmax_normalize(torch.log(x + self.eps))
            # some versions of torchaudio returns nan when input is all-zeros
            return torch.nan_to_num(x)
        elif self.audio_backend == 'nnaudio':
            x = self.mel_stft_nnaudio(x)  # (B, F, T)
            x = rearrange(x, 'b f t -> b t f')
            x = minmax_normalize(torch.log(x + self.eps))
            return x
        # Fail loudly instead of implicitly returning None, which would violate
        # the declared return type and only surface later as an obscure error.
        raise NotImplementedError(self.audio_backend)
| # ------------------------------------------------------------------------------------- | |
| # Log-spectrogram | |
| # ------------------------------------------------------------------------------------- | |
class Spectrogram(nn.Module):
    """Log-magnitude spectrogram layer; only the torchaudio backend is implemented."""

    def __init__(
            self,
            audio_backend: Literal['torchaudio', 'nnaudio'] = 'torchaudio',
            n_fft: int = 2048,
            hop_length: int = 128,
            eps: float = 1e-5,
            **kwargs,
    ):
        """
        Log-Magnitude Spectrogram

        Args:
            audio_backend (str): 'torchaudio' or 'nnaudio' (case-insensitive);
                'nnaudio' is not implemented yet.
            n_fft (int): FFT window size. forward() drops the DC bin (index 0)
                from the frequency axis of the STFT output.
            hop_length (int): hop length in samples
            eps (float): epsilon added before the log for numerical stability
            **kwargs: ignored, so that a full audio config dict can be splatted in.

        Raises:
            NotImplementedError: for every backend other than 'torchaudio'.
        """
        super().__init__()
        self.audio_backend = audio_backend.lower()
        if self.audio_backend != 'torchaudio':
            # Covers both the planned-but-missing 'nnaudio' backend
            # (TODO: nnAudio spectrogram) and any unknown backend name.
            raise NotImplementedError(audio_backend)

        # Imported lazily so torchaudio is only required when this layer is built.
        torchaudio = importlib.import_module('torchaudio')
        self.stft = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            window_fn=torch.hann_window,
            power=1.,
        )
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # (B, 1, T)
        """
        Args:
            x (torch.Tensor): (B, 1, T)

        Returns:
            torch.Tensor: (B, T, F)
        """
        if self.audio_backend == 'torchaudio':
            magnitude = self.stft(x)  # (B, 1, F, T)
            magnitude = magnitude[:, :, 1:, :]  # remove DC component
            frames = rearrange(magnitude, 'b 1 f t -> b t f')
            log_spec = torch.log(frames + self.eps)
            normalized = minmax_normalize(log_spec)
            # some versions of torchaudio returns nan when input is all-zeros
            return torch.nan_to_num(normalized)
        if self.audio_backend == 'nnaudio':
            raise NotImplementedError(self.audio_backend)
def get_spectrogram_layer_from_audio_cfg(audio_cfg: Optional[Dict] = None) -> Tuple[nn.Module, Tuple[int]]:
    """Build the mel-/spectrogram layer described by an audio config.

    - Used by 'ymt3' to create a spectrogram layer.
    - Also reports the layer's output shape, which is used to determine the
      input shape of the model.

    Args:
        audio_cfg (dict): see config/config.py. Falls back to the default audio
            config when None. The keys 'codec' ('melspec' or 'spec') and
            'input_frames' are read here; the whole dict is forwarded to the
            layer constructor as keyword arguments.

    Returns:
        layer (nn.Module): mel-/spectrogram layer
        output_shape (tuple): inferred output shape of layer excluding batch dim. (T, F)

    Raises:
        NotImplementedError: if audio_cfg['codec'] is not a supported codec.
    """
    cfg = default_audio_cfg if audio_cfg is None else audio_cfg

    codec = cfg['codec']
    if codec == 'melspec':
        layer_cls = Melspectrogram
    elif codec == 'spec':
        layer_cls = Spectrogram
    else:
        raise NotImplementedError(codec)
    layer = layer_cls(**cfg)

    # Infer the output shape by pushing one random dummy segment through the layer.
    with torch.no_grad():
        probe = torch.randn(1, 1, cfg['input_frames'])
        output_shape = layer(probe).shape[1:]
    return layer, output_shape