Spaces:
Runtime error
Runtime error
| """ | |
| Implementation of objective functions used in the task 'ITO-Master' | |
| """ | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import auraloss | |
| import os | |
| import sys | |
| currentdir = os.path.dirname(os.path.realpath(__file__)) | |
| sys.path.append(os.path.dirname(currentdir)) | |
| from modules.front_back_end import * | |
# Root Mean Squared Loss
# penalizes the volume (level) mismatch with a non-linear weight
class RMSLoss(nn.Module):
    """Loss on per-channel RMS level, weighted non-linearly by the level error.

    Both inputs are flattened to ``(batch * channels, samples)``. The RMS of
    every row is compared with a squared error, and each row's error is scaled
    by ``(max(|rms_tgt - rms_est|, 1 / weight_factor) * weight_factor) ** 1.5``.
    Above the ``1 / weight_factor`` floor the per-row term therefore grows like
    ``|rms_tgt - rms_est| ** 3.5``.

    Args:
        reduce: kept for interface compatibility; this loss does not use it
            (the result is always averaged to a scalar).
        loss_type: element-wise distance; only ``"l2"`` is supported.

    Raises:
        ValueError: if ``loss_type`` is not supported.
    """

    def __init__(self, reduce, loss_type="l2"):
        super(RMSLoss, self).__init__()
        self.weight_factor = 100.
        if loss_type != "l2":
            # Previously an unknown type left ``self.loss`` undefined and only
            # failed later, inside ``forward``, with an AttributeError.
            raise ValueError(f"Unsupported loss_type {loss_type!r}; expected 'l2'")
        # Keep per-row errors so that each row is scaled by its own weight
        # before averaging. (The former ``reduce=None`` resolved to
        # ``reduction='mean'``, which multiplied every weight by the batch-mean
        # error instead of by that row's error.)
        self.loss = nn.MSELoss(reduction="none")

    def forward(self, est_targets, targets):
        """Return the weighted RMS-level loss as a scalar tensor.

        Args:
            est_targets: estimate, indexed as ``(batch, channels, samples)``.
            targets: reference with the same shape as ``est_targets``.
        """
        est_targets = est_targets.reshape(est_targets.shape[0] * est_targets.shape[1], est_targets.shape[2])
        targets = targets.reshape(targets.shape[0] * targets.shape[1], targets.shape[2])
        normalized_est = torch.sqrt(torch.mean(est_targets ** 2, dim=-1))
        normalized_tgt = torch.sqrt(torch.mean(targets ** 2, dim=-1))
        # The floor keeps near-matching rows from getting a vanishing weight.
        weight = torch.clamp(torch.abs(normalized_tgt - normalized_est), min=1 / self.weight_factor) * self.weight_factor
        return torch.mean(weight ** 1.5 * self.loss(normalized_est, normalized_tgt))
# Multi-Scale Spectral Loss proposed in the paper "DDSP: DIFFERENTIABLE DIGITAL SIGNAL PROCESSING" (https://arxiv.org/abs/2001.04643)
# We extend this loss by applying it to mid/side channels.
class MultiScale_Spectral_Loss_MidSide_DDSP(nn.Module):
    """Multi-scale spectral loss (DDSP style) on mid/side or original channels.

    For every analysis scale, a ``FrontEnd`` computes the magnitude spectrograms
    of the estimate and the target. Per scale, an L1 magnitude term and an L2
    log10-magnitude term are accumulated, and the result is
    ``(1 - logmag_weight) * sum(mag) + logmag_weight * sum(logmag)``.

    Args:
        mode: ``'midside'`` compares mid (ch0 + ch1) and side (ch0 - ch1)
            signals with a ``'mono'`` front end, mixed by ``mid_weight``;
            ``'ori'`` compares the inputs directly with a ``'stereo'`` front end.
        reduce: True (or None) averages the element-wise L1/L2 objectives;
            False leaves them unreduced.
        n_filters: per-scale FFT sizes; ``None`` -> ``[4096, 2048, 1024, 512]``.
        windows_size: per-scale window lengths; ``None`` -> ``[4096, 2048, 1024, 512]``.
            Its length sets the number of scales.
        hops_size: per-scale hop lengths; ``None`` -> ``[1024, 512, 256, 128]``.
        window: window name forwarded to ``FrontEnd``.
        eps: offset added before ``log10`` to avoid ``log(0)``.
        device: forwarded to ``FrontEnd``.

    Raises:
        ValueError: if ``mode`` is neither ``'midside'`` nor ``'ori'``.
    """

    # mode -> channel layout handed to FrontEnd
    _FRONT_END_CHANNEL = {'midside': 'mono', 'ori': 'stereo'}

    def __init__(self, mode='midside',
                 reduce=True,
                 n_filters=None,
                 windows_size=None,
                 hops_size=None,
                 window="hann",
                 eps=1e-7,
                 device=torch.device("cpu")):
        super(MultiScale_Spectral_Loss_MidSide_DDSP, self).__init__()
        if mode not in self._FRONT_END_CHANNEL:
            # Previously an unknown mode built scales without a front end and
            # forward() silently returned None.
            raise ValueError(f"Invalid mode {mode!r}; expected 'midside' or 'ori'")
        self.mode = mode
        self.eps = eps
        self.mid_weight = 0.5  # weight of the mid term vs. the side term, in [0.0, 1.0]
        self.logmag_weight = 0.1  # weight of the log-magnitude term vs. the linear one
        if n_filters is None:
            n_filters = [4096, 2048, 1024, 512]
        if windows_size is None:
            windows_size = [4096, 2048, 1024, 512]
        if hops_size is None:
            hops_size = [1024, 512, 256, 128]
        channel = self._FRONT_END_CHANNEL[mode]
        # NOTE: the front ends live in a plain list of dicts. They are not
        # registered submodules, so ``.to()`` / ``state_dict()`` do not see them;
        # placement relies on the ``device`` argument.
        self.multiscales = []
        for i, win_length in enumerate(windows_size):
            self.multiscales.append({
                'window_size': float(win_length),
                'front_end': FrontEnd(channel=channel,
                                      n_fft=n_filters[i],
                                      hop_length=hops_size[i],
                                      win_length=win_length,
                                      window=window,
                                      device=device),
            })
        # ``reduce=`` is deprecated on torch.nn losses. Map it to the equivalent
        # ``reduction=`` value (legacy semantics: True/None -> 'mean', False -> 'none').
        reduction = 'mean' if reduce is None or reduce else 'none'
        self.objective_l1 = nn.L1Loss(reduction=reduction)
        self.objective_l2 = nn.MSELoss(reduction=reduction)

    def forward(self, est_targets, targets):
        """Dispatch to the mid/side or original-channel loss according to ``self.mode``."""
        if self.mode == 'midside':
            return self.forward_midside(est_targets, targets)
        if self.mode == 'ori':
            return self.forward_ori(est_targets, targets)
        raise ValueError(f"Invalid mode {self.mode!r}; expected 'midside' or 'ori'")

    def forward_ori(self, est_targets, targets):
        """Spectral loss computed directly on the input channels."""
        total_mag_loss = 0.0
        total_logmag_loss = 0.0
        for cur_scale in self.multiscales:
            front_end = cur_scale['front_end']
            est_mag = front_end(est_targets, mode=["mag"])
            tgt_mag = front_end(targets, mode=["mag"])
            total_mag_loss += self.magnitude_loss(est_mag, tgt_mag)
            total_logmag_loss += self.log_magnitude_loss(est_mag, tgt_mag)
        return (1 - self.logmag_weight) * total_mag_loss + \
            self.logmag_weight * total_logmag_loss

    def forward_midside(self, est_targets, targets):
        """Spectral loss on mid/side signals, mixed by ``mid_weight``."""
        est_mid, est_side = self.to_mid_side(est_targets)
        tgt_mid, tgt_side = self.to_mid_side(targets)
        total_mag_loss = 0.0
        total_logmag_loss = 0.0
        for cur_scale in self.multiscales:
            front_end = cur_scale['front_end']
            est_mid_mag = front_end(est_mid, mode=["mag"])
            est_side_mag = front_end(est_side, mode=["mag"])
            tgt_mid_mag = front_end(tgt_mid, mode=["mag"])
            tgt_side_mag = front_end(tgt_side, mode=["mag"])
            total_mag_loss += self.mid_weight * self.magnitude_loss(est_mid_mag, tgt_mid_mag) + \
                (1 - self.mid_weight) * self.magnitude_loss(est_side_mag, tgt_side_mag)
            total_logmag_loss += self.mid_weight * self.log_magnitude_loss(est_mid_mag, tgt_mid_mag) + \
                (1 - self.mid_weight) * self.log_magnitude_loss(est_side_mag, tgt_side_mag)
        return (1 - self.logmag_weight) * total_mag_loss + \
            self.logmag_weight * total_logmag_loss

    def to_mid_side(self, stereo_in):
        """Return ``(ch0 + ch1, ch0 - ch1)`` taken along dim 1 (no 0.5 scaling)."""
        mid = stereo_in[:, 0] + stereo_in[:, 1]
        side = stereo_in[:, 0] - stereo_in[:, 1]
        return mid, side

    def magnitude_loss(self, est_mag_spec, tgt_mag_spec):
        """L1 distance between magnitudes, passed through ``torch.norm``.

        When reduced, the L1 output is a scalar and the norm is its absolute value.
        """
        return torch.norm(self.objective_l1(est_mag_spec, tgt_mag_spec))

    def log_magnitude_loss(self, est_mag_spec, tgt_mag_spec):
        """MSE between ``log10(mag + eps)`` of the estimate and the target."""
        est_log_mag_spec = torch.log10(est_mag_spec + self.eps)
        tgt_log_mag_spec = torch.log10(tgt_mag_spec + self.eps)
        return self.objective_l2(est_log_mag_spec, tgt_log_mag_spec)
| # Class of available loss functions | |
class Loss:
    """Bundle of the objectives used for training.

    Args:
        args: run configuration. Reads ``args.gpu`` (CUDA device index, used
            only when CUDA is available), ``args.eps`` and ``args.sample_rate``.
        reduce: True (or None) averages the element-wise ``l1``/``mse``
            objectives; False leaves them unreduced. Also passed to ``RMSLoss``.
    """

    def __init__(self, args, reduce=True):
        device = torch.device("cpu")
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{args.gpu}")
        # ``reduce=`` is deprecated on torch.nn losses and warns on every
        # construction. Translate it to the equivalent ``reduction=`` value
        # (legacy semantics: True/None -> 'mean', False -> 'none').
        reduction = "mean" if reduce is None or reduce else "none"
        self.l1 = nn.L1Loss(reduction=reduction)
        self.mse = nn.MSELoss(reduction=reduction)
        self.ce = nn.CrossEntropyLoss()
        self.triplet = nn.TripletMarginLoss(margin=1., p=2)
        self.cos = nn.CosineSimilarity(eps=args.eps)
        self.cosemb = nn.CosineEmbeddingLoss()
        self.multi_scale_spectral_midside = MultiScale_Spectral_Loss_MidSide_DDSP(mode='midside', eps=args.eps, device=device)
        self.multi_scale_spectral_ori = MultiScale_Spectral_Loss_MidSide_DDSP(mode='ori', eps=args.eps, device=device)
        self.gain = RMSLoss(reduce=reduce)
        # perceptual weighting with mel scaled spectrograms
        self.mrs_mel_perceptual = auraloss.freq.MultiResolutionSTFTLoss(
            fft_sizes=[1024, 2048, 8192],
            hop_sizes=[256, 512, 2048],
            win_lengths=[1024, 2048, 8192],
            scale="mel",
            n_bins=128,
            sample_rate=args.sample_rate,
            perceptual_weighting=True,
        )
| """ | |
| Audio Feature Loss implementation | |
| copied from https://github.com/sai-soum/Diff-MST/blob/main/mst/loss.py | |
| """ | |
| import librosa | |
| from typing import List | |
| from modules.filter import barkscale_fbanks | |
def compute_mid_side(x: torch.Tensor):
    """Split a stereo batch into mid and side signals.

    Args:
        x: tensor indexed as ``(bs, 2, seq_len)``.

    Returns:
        ``(mid, side)``, where ``mid = ch0 + ch1`` and ``side = ch0 - ch1``,
        each of shape ``(bs, seq_len)``. No 0.5 scaling is applied.
    """
    first, second = x[:, 0, :], x[:, 1, :]
    return first + second, first - second
def compute_melspectrum(
    x: torch.Tensor,
    sample_rate: int = 44100,
    fft_size: int = 32768,
    n_bins: int = 128,
    **kwargs,
):
    """Compute the log mel spectrum of the channel-averaged signal.

    Processing steps:
    1. Average the channels.
    2. Take one real FFT of length ``fft_size``. The signal is zero-padded or
       trimmed to ``fft_size`` samples, so only that many samples are analysed.
    3. Project the FFT magnitude onto a librosa mel filterbank.
    4. Return the natural log.

    Args:
        x: audio indexed as ``(bs, channels, seq_len)``.
        sample_rate: sample rate used to build the mel filterbank.
        fft_size: FFT length (also the number of samples analysed).
        n_bins: number of mel bands.
        **kwargs: ignored.

    Returns:
        Tensor of shape ``(bs, n_bins, 1)``.
    """
    mel_fb = torch.tensor(
        librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
    ).unsqueeze(0).type_as(x)                                 # (1, n_bins, freq)
    mono = x.mean(dim=1, keepdim=True)                        # (bs, 1, seq_len)
    magnitude = torch.fft.rfft(mono, n=fft_size, dim=-1).abs()  # (bs, 1, freq)
    # Dim 1 already has size 1 here, so this mean leaves values unchanged.
    magnitude = magnitude.mean(dim=1, keepdim=True)
    magnitude = magnitude.permute(0, 2, 1)                    # (bs, freq, 1)
    return torch.log(torch.matmul(mel_fb, magnitude) + 1e-8)
def compute_barkspectrum(
    x: torch.Tensor,
    fft_size: int = 32768,
    n_bands: int = 24,
    sample_rate: int = 44100,
    f_min: float = 20.0,
    f_max: float = 20000.0,
    mode: str = "mid-side",
    **kwargs,
):
    """Compute a time-averaged log bark spectrum.

    Each selected signal goes through an STFT (Hann window, hop ``fft_size // 4``).
    The magnitude is averaged over frames, projected onto a bark filterbank and
    log-compressed. Per-signal results are concatenated along the last axis.

    Args:
        x: audio indexed as ``(bs, 2, seq_len)``.
        fft_size: STFT size.
        n_bands: number of bark bands.
        sample_rate: sample rate of the audio.
        f_min: lowest filterbank frequency.
        f_max: highest filterbank frequency.
        mode: ``"mono"`` (channel average), ``"stereo"`` (ch0, ch1) or
            ``"mid-side"`` (ch0 + ch1, ch0 - ch1).
        **kwargs: ignored.

    Returns:
        ``(bs, n_bands, 1)`` for ``"mono"``, or ``(bs, n_bands, 2)`` for
        ``"stereo"`` / ``"mid-side"``. This assumes ``barkscale_fbanks``
        returns ``(n_freqs, n_bands)``, as the transpose below implies
        (TODO confirm).

    Raises:
        ValueError: if ``mode`` is not one of the values above.
    """
    # Pick the signals first, so an invalid mode fails before any filterbank work.
    if mode == "mono":
        signals = [x.mean(dim=1)]  # average over channels
    elif mode == "stereo":
        signals = [x[:, 0, :], x[:, 1, :]]
    elif mode == "mid-side":
        signals = list(compute_mid_side(x))
    else:
        raise ValueError(f"Invalid mode {mode}")

    fb = barkscale_fbanks((fft_size // 2) + 1, f_min, f_max, n_bands, sample_rate)
    fb = fb.unsqueeze(0).type_as(x).permute(0, 2, 1)
    # The window is identical for every signal; build it once.
    window = torch.hann_window(fft_size, device=x.device)

    outputs = []
    for signal in signals:
        spec = torch.stft(
            signal,
            n_fft=fft_size,
            hop_length=fft_size // 4,
            return_complex=True,
            window=window,
        ).abs()                                   # (bs, freq, frames)
        spec = spec.mean(dim=-1, keepdim=True)    # average over frames -> (bs, freq, 1)
        outputs.append(torch.log(torch.matmul(fb, spec) + 1e-8))
    return torch.cat(outputs, dim=-1)
def compute_rms(x: torch.Tensor, **kwargs):
    """Return the root-mean-square level along the last axis.

    The mean square is floored at 1e-8 before the square root, so the result
    is never below 1e-4 and never takes ``sqrt(0)``.

    Args:
        x: audio tensor reduced over its last dimension, e.g. ``(bs, 1, seq_len)``.
        **kwargs: ignored.

    Returns:
        Tensor of shape ``x.shape[:-1]``, e.g. ``(bs, 1)`` for the example above.
    """
    mean_square = x.pow(2).mean(dim=-1)
    return mean_square.clamp(min=1e-8).sqrt()
def compute_crest_factor(x: torch.Tensor, **kwargs):
    """Return the peak-to-RMS ratio in dB along the last axis.

    Args:
        x: audio indexed as ``(bs, 2, seq_len)``.
        **kwargs: ignored.

    Returns:
        ``20 * log10(peak / rms)`` with shape ``x.shape[:-1]``. The RMS and the
        ratio are each floored at 1e-8 so the logarithm stays finite.
    """
    peak = torch.max(x.abs(), dim=-1).values
    level = compute_rms(x).clamp(min=1e-8)
    ratio = (peak / level).clamp(min=1e-8)
    return 20 * torch.log10(ratio)
def compute_stereo_width(x: torch.Tensor, **kwargs):
    """Return the side-to-mid energy ratio of a stereo signal.

    Args:
        x: stereo audio of shape ``(bs, 2, seq_len)``.
        **kwargs: ignored.

    Returns:
        ``mean((ch0 - ch1)**2) / max(mean((ch0 + ch1)**2), 1e-8)`` per batch
        item, shape ``(bs,)``. It is 0 when both channels are identical.

    Raises:
        AssertionError: if the channel dimension is not 2.
    """
    _, n_channels, _ = x.size()
    assert n_channels == 2, "Input must be stereo"
    first, second = x[:, 0, :], x[:, 1, :]
    mid_energy = (first + second).pow(2).mean(dim=-1)
    side_energy = (first - second).pow(2).mean(dim=-1)
    return side_energy / mid_energy.clamp(min=1e-8)
def compute_stereo_imbalance(x: torch.Tensor, **kwargs):
    """Return the normalised energy difference between channel 1 and channel 0.

    Args:
        x: stereo audio of shape ``(bs, 2, seq_len)``.
        **kwargs: ignored.

    Returns:
        ``(E1 - E0) / max(E1 + E0, 1e-8)``, where ``Ei`` is the mean square of
        channel ``i``. Shape ``(bs,)``, values in ``[-1, 1]``; silent input gives 0.
    """
    energy_0 = x[:, 0, :].pow(2).mean(dim=-1)
    energy_1 = x[:, 1, :].pow(2).mean(dim=-1)
    total = (energy_1 + energy_0).clamp(min=1e-8)
    return (energy_1 - energy_0) / total
class AudioFeatureLoss(torch.nn.Module):
    """Weighted MSE between differentiable audio features of estimate and target.

    Features (in order): RMS, crest factor, stereo width, stereo imbalance and
    bark spectrum. Based on features proposed in:
    Man, B. D., et al. "An analysis and evaluation of audio features for
    multitrack music mixtures." (2014).
    """

    def __init__(
        self,
        weights: List[float],
        sample_rate: int,
        stem_separation: bool = False,
        use_clap: bool = False,
    ) -> None:
        """Store the configuration and the feature transforms.

        Args:
            weights: one weight per feature transform, same order as the list below.
            sample_rate: forwarded to every transform as ``sample_rate=``.
            stem_separation: stored only; not read by ``forward``.
            use_clap: stored only; not read by ``forward``.

        Raises:
            AssertionError: if ``len(weights)`` differs from the number of transforms.
        """
        super().__init__()
        self.weights = weights
        self.sample_rate = sample_rate
        self.stem_separation = stem_separation
        self.use_clap = use_clap
        # Only the full mix is treated as a "stem", with unit weight.
        self.sources_list = ["mix"]
        self.source_weights = [1.0]
        self.transforms = [
            compute_rms,
            compute_crest_factor,
            compute_stereo_width,
            compute_stereo_imbalance,
            compute_barkspectrum,
        ]
        assert len(self.transforms) == len(weights)

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """Return a dict of weighted per-feature MSE losses.

        Keys look like ``"<source>-<feature>"`` (e.g. ``"mix-rms"``). The feature
        name is the transform's function name with its first ``_``-separated
        token removed.
        """
        # Insert a singleton stem axis so the mix is handled as one stem.
        input_stems = input.unsqueeze(1)
        target_stems = target.unsqueeze(1)
        losses = {}
        for stem_idx in range(input_stems.shape[1]):
            source_name = self.sources_list[stem_idx]
            est_stem = input_stems[:, stem_idx, ...]
            ref_stem = target_stems[:, stem_idx, ...]
            for transform, weight in zip(self.transforms, self.weights):
                feature_name = transform.__name__.partition("_")[2]
                est_feature = transform(est_stem, sample_rate=self.sample_rate)
                ref_feature = transform(ref_stem, sample_rate=self.sample_rate)
                distance = torch.nn.functional.mse_loss(est_feature, ref_feature)
                losses[f"{source_name}-{feature_name}"] = weight * distance * self.source_weights[stem_idx]
        return losses