Spaces:
Runtime error
Runtime error
| from typing import Any, Tuple | |
| import torch | |
| import torchmetrics as tm | |
| from torchmetrics.audio.snr import SignalNoiseRatio | |
| from torchmetrics.functional.audio.snr import signal_noise_ratio, scale_invariant_signal_noise_ratio | |
| from torchmetrics.utilities.checks import _check_same_shape | |
def safe_signal_noise_ratio(
    preds: torch.Tensor, target: torch.Tensor, zero_mean: bool = False
) -> torch.Tensor:
    """Compute the signal-to-noise ratio and replace infinite results.

    Thin wrapper around ``signal_noise_ratio`` whose output is passed through
    ``torch.nan_to_num``: ``+inf`` is replaced by ``100.0`` and ``-inf`` by
    ``-100.0``. NaN entries are kept as NaN (``nan=torch.nan``).

    Args:
        preds: estimated signal, passed straight to ``signal_noise_ratio``.
        target: reference signal, passed straight to ``signal_noise_ratio``.
        zero_mean: forwarded to ``signal_noise_ratio`` as ``zero_mean``.

    Returns:
        The SNR tensor with infinities replaced as described above.
    """
    raw_snr = signal_noise_ratio(preds, target, zero_mean=zero_mean)
    return torch.nan_to_num(raw_snr, nan=torch.nan, posinf=100.0, neginf=-100.0)
def safe_scale_invariant_signal_noise_ratio(
    preds: torch.Tensor, target: torch.Tensor,
    zero_mean: bool = False
) -> torch.Tensor:
    """Compute the scale-invariant SNR and replace infinite results.

    Thin wrapper around ``scale_invariant_signal_noise_ratio`` whose output is
    passed through ``torch.nan_to_num``: ``+inf`` is replaced by ``100.0`` and
    ``-inf`` by ``-100.0``. NaN entries are kept as NaN (``nan=torch.nan``).

    Args:
        preds: float tensor with shape ``(..., time)``
        target: float tensor with shape ``(..., time)``
        zero_mean: accepted so the signature mirrors
            ``safe_signal_noise_ratio``, but it is NOT forwarded to
            ``scale_invariant_signal_noise_ratio`` and therefore has no
            effect on the result.

    Returns:
        Float tensor of per-signal values with infinities replaced as
        described above.

    Note:
        Any shape validation (and the exception it raises) comes from the
        wrapped torchmetrics function, not from this wrapper.
    """
    # NOTE(review): ``zero_mean`` is intentionally left unused here to keep the
    # reported numbers unchanged -- confirm whether callers expect it to matter.
    raw_si_snr = scale_invariant_signal_noise_ratio(preds, target)
    return torch.nan_to_num(
        raw_si_snr, nan=torch.nan, posinf=100.0, neginf=-100.0
    )
def decibels(x: torch.Tensor, threshold: float = 1e-6) -> Tuple[torch.Tensor, int]:
    """Sum the power level (in dB) of every signal in ``x`` and count them.

    The last axis of ``x`` is reduced to a mean-square value, converted with
    ``10 * log10(mean_square + threshold)``, and the resulting levels are
    summed. The sum is returned together with the number of levels that went
    into it, so a caller can accumulate both and divide to obtain a mean.

    Args:
        x: tensor whose last axis is reduced; all leading axes are treated as
            independent signals.
        threshold: constant added to the mean-square value before the
            logarithm, which keeps the result finite for all-zero signals.

    Returns:
        A ``(total_db, n_signals)`` tuple: the 0-dim sum of the per-signal
        levels and the number of signals summed.
    """
    mean_squared = torch.mean(torch.square(x), dim=-1)
    level_db = 10 * torch.log10(mean_squared + threshold)
    # Count exactly what is summed. ``x.shape[0]`` is only correct for 2-D
    # input: it reports the time length for a 1-D signal and ignores the
    # extra leading axes (e.g. channels) of a 3-D batch.
    return torch.sum(level_db), level_db.numel()
class Decibels(tm.Metric):
    """Metric that averages the dB level returned by ``decibels``.

    Two states are registered, both reduced with ``"sum"``:

    * ``running_mean`` -- despite its name, the running SUM of the dB totals
      returned by ``decibels``; the division happens in ``compute``.
    * ``running_count`` -- the running sum of the counts returned by
      ``decibels``.
    """

    def __init__(self, threshold: float = 1e-6, **kwargs: Any) -> None:
        """Register the accumulator states.

        Args:
            threshold: passed to ``decibels`` on every ``update`` call.
            **kwargs: forwarded unchanged to ``tm.Metric.__init__``.
        """
        super().__init__(**kwargs)
        self.threshold = threshold
        self.add_state(
            "running_mean", default=torch.tensor(0.0), dist_reduce_fx="sum"
        )
        self.add_state(
            "running_count", default=torch.tensor(0), dist_reduce_fx="sum"
        )

    def update(self, y):
        """Accumulate the dB total and the count for one batch ``y``."""
        batch_db, batch_count = decibels(y, self.threshold)
        # The dB total is moved to the host before being accumulated.
        self.running_mean += batch_db.cpu()
        self.running_count += batch_count

    def compute(self) -> torch.Tensor:
        """Return the accumulated dB total divided by the accumulated count."""
        total_db = self.running_mean
        n_seen = self.running_count
        return total_db / n_seen

    # A custom reset was left disabled by the original author:
    # def reset(self) -> None:
    #     self.running_mean = torch.tensor(0.0)
    #     self.running_count = torch.tensor(0)
class PredictedDecibels(Decibels):
    """``Decibels`` variant with a ``(ypred, ytrue)`` update that measures ``ypred``."""

    def update(self, ypred, ytrue) -> None:
        """Accumulate the level of ``ypred``; ``ytrue`` is accepted but ignored."""
        super().update(ypred)
class TargetDecibels(Decibels):
    """``Decibels`` variant with a ``(ypred, ytrue)`` update that measures ``ytrue``."""

    def update(self, ypred, ytrue) -> None:
        """Accumulate the level of ``ytrue``; ``ypred`` is accepted but ignored."""
        super().update(ytrue)
class SafeSignalNoiseRatio(SignalNoiseRatio):
    """SNR metric that stores every per-signal value and reports the median.

    Differences from a plain call to ``safe_signal_noise_ratio``:

    * ``preds`` and ``target`` may differ in length along the last axis by up
      to ``sample_mismatch_thresh_seconds`` (converted with ``fs``); both are
      truncated to the shorter length. A larger mismatch raises ``ValueError``.
    * Every value produced in ``update`` is kept in the ``snr_list`` state
      (reduced with ``"cat"``), and ``compute`` returns their ``nanmedian``,
      so NaN entries are skipped.
    """

    def __init__(
        self,
        zero_mean: bool = False,
        threshold: float = 1e-6,
        fs: int = 44100,
        **kwargs: Any
    ) -> None:
        """Set up the metric.

        Args:
            zero_mean: passed positionally to the parent constructor; later
                read back as ``self.zero_mean`` in ``update`` (presumably
                stored by the parent -- verify against torchmetrics).
            threshold: stored on the instance; not read anywhere in this class.
            fs: sampling rate in Hz used to convert a length mismatch in
                samples into seconds.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(zero_mean, **kwargs)
        self.threshold = threshold
        self.fs = fs
        # Maximum tolerated length difference between preds and target.
        self.sample_mismatch_thresh_seconds = 0.1
        self.add_state("snr_list", default=[], dist_reduce_fx="cat")

    def _fix_shape(self, preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Truncate ``preds`` and ``target`` to a common last-axis length.

        Args:
            preds: estimated signal; its last axis is the sample axis.
            target: reference signal; its last axis is the sample axis.

        Returns:
            ``(preds, target)``, unchanged when the lengths already match,
            otherwise both cut to the shorter length.

        Raises:
            ValueError: if the length difference, divided by ``self.fs``,
                exceeds ``self.sample_mismatch_thresh_seconds``.
        """
        n_samples_preds = preds.shape[-1]
        n_samples_target = target.shape[-1]

        if n_samples_preds == n_samples_target:
            return preds, target

        mismatch_seconds = abs(n_samples_preds - n_samples_target) / self.fs
        if mismatch_seconds > self.sample_mismatch_thresh_seconds:
            raise ValueError(
                "The difference between the number of samples of the predictions and the target is too large (100 ms)"
            )

        min_samples = min(n_samples_preds, n_samples_target)
        return preds[..., :min_samples], target[..., :min_samples]

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Update state with predictions and targets."""
        preds, target = self._fix_shape(preds, target)
        snr_batch = safe_signal_noise_ratio(
            preds=preds, target=target, zero_mean=self.zero_mean
        )
        # Store a flat 1-D tensor: a 0-dim result (1-D waveform input) cannot
        # be concatenated by torch.cat, and batches with different leading
        # shapes would not line up either. The median over all values is
        # unaffected by flattening.
        self.snr_list.append(snr_batch.reshape(-1))

    def compute(self) -> torch.Tensor:
        """Compute metric.

        Returns:
            The ``nanmedian`` of every stored value, or a NaN tensor when
            nothing has been accumulated.
        """
        snr_values = self.snr_list
        # A list state reduced with "cat" may already have been merged into a
        # single tensor by torchmetrics' synchronisation; torch.cat would
        # reject it, so use it directly.
        if isinstance(snr_values, torch.Tensor):
            if snr_values.numel() == 0:
                return torch.tensor(float("nan"))
            return torch.nanmedian(snr_values)

        if len(snr_values) == 0:
            return torch.tensor(float("nan"))
        return torch.nanmedian(torch.cat([v.reshape(-1) for v in snr_values]))
class SafeScaleInvariantSignalNoiseRatio(SafeSignalNoiseRatio):
    """Scale-invariant counterpart of ``SafeSignalNoiseRatio``.

    Only ``update`` differs: it scores each batch with
    ``safe_scale_invariant_signal_noise_ratio``. Length fixing (``_fix_shape``)
    and the median aggregation (``compute``) are inherited from
    ``SafeSignalNoiseRatio``; this class previously carried a verbatim copy of
    ``compute``.
    """

    def __init__(
        self,
        zero_mean: bool = False,
        threshold: float = 1e-6,
        fs: int = 44100,
        **kwargs: Any
    ) -> None:
        """Forward all arguments unchanged to ``SafeSignalNoiseRatio.__init__``."""
        super().__init__(zero_mean, threshold, fs, **kwargs)

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Update state with predictions and targets."""
        preds, target = self._fix_shape(preds, target)
        snr_batch = safe_scale_invariant_signal_noise_ratio(
            preds=preds, target=target, zero_mean=self.zero_mean
        )
        # Store a flat 1-D tensor: a 0-dim result (1-D waveform input) cannot
        # be concatenated by torch.cat in compute(), and batches with
        # different leading shapes would not line up either.
        self.snr_list.append(snr_batch.reshape(-1))