# Jihuai's picture
# have to create an orphan branch to bypass large file history: cleanup .ipynb and create LFS
# d572f56
import argparse
import torch
import torchaudio
from typing import Tuple
import museval
from museval.aggregate import TrackStore
from core.metrics.snr import safe_signal_noise_ratio, safe_scale_invariant_signal_noise_ratio
from torchmetrics.audio import SignalNoiseRatio, ScaleInvariantSignalNoiseRatio, SignalDistortionRatio, ScaleInvariantSignalDistortionRatio
def load_and_preprocess_audio(
    audio_path: str,
    target_sample_rate: int = 44100,
    target_channels: int = 1,
    verbose: bool = True
) -> torch.Tensor:
    """Load an audio file and convert it to a target sample rate and channel count.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        target_sample_rate: Desired sample rate in Hz; the waveform is
            resampled when the file's rate differs.
        target_channels: Desired channel count. Downmixing is only supported
            to mono (mean over channels); upmixing only from mono (channel
            repetition).
        verbose: When True, print a step-by-step summary of each conversion.

    Returns:
        A ``(target_channels, num_samples)`` waveform tensor at
        ``target_sample_rate``.

    Raises:
        ValueError: If an unsupported channel conversion is requested
            (multi-channel -> multi-channel).
    """
    # Load audio
    waveform, orig_sample_rate = torchaudio.load(audio_path)
    if verbose:
        print(f"Loaded audio: {audio_path}")
        print(f" Original shape: {waveform.shape}")
        print(f" Original sample rate: {orig_sample_rate} Hz")
        print(f" Duration: {waveform.shape[1] / orig_sample_rate:.2f} seconds")
        print(f" Target sample rate: {target_sample_rate} Hz")
        print(f" Target channels: {target_channels}")
    # Resample if the source rate differs from the target rate
    if orig_sample_rate != target_sample_rate:
        if verbose:
            print(f" Resampling: {orig_sample_rate} Hz -> {target_sample_rate} Hz")
        resampler = torchaudio.transforms.Resample(
            orig_sample_rate, target_sample_rate
        )
        waveform = resampler(waveform)
        if verbose:
            print(f" After resampling shape: {waveform.shape}")
            print(f" New duration: {waveform.shape[1] / target_sample_rate:.2f} seconds")
    # Channel conversion
    current_channels = waveform.shape[0]
    if current_channels > target_channels:
        if verbose:
            print(f" Downmixing: {current_channels} channels -> {target_channels} channel(s)")
            print(f" Using mean averaging for downmixing")
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would silently allow an unsupported conversion.
        if target_channels != 1:
            raise ValueError("Downmixing only supported to mono")
        waveform = waveform.mean(dim=0, keepdim=True)
        if verbose:
            print(f" After downmixing shape: {waveform.shape}")
    elif current_channels < target_channels:
        if verbose:
            print(f" Upmixing: {current_channels} channel(s) -> {target_channels} channels")
            print(f" Repeating single channel data")
        if waveform.shape[0] != 1:
            raise ValueError("Upmixing only supported from mono")
        waveform = waveform.repeat(target_channels, 1)
        if verbose:
            print(f" After upmixing shape: {waveform.shape}")
    else:
        if verbose:
            print(f" No channel conversion needed (already {target_channels} channels)")
    # Final summary
    if verbose:
        print(f" Final shape: {waveform.shape}")
        print(f" Final sample rate: {target_sample_rate} Hz")
        print(f" Final duration: {waveform.shape[1] / target_sample_rate:.2f} seconds")
        print("-" * 50)
    return waveform
def align_audio_length(
    audio1: torch.Tensor,
    audio2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Truncate two waveforms to the duration of the shorter one.

    Both inputs are ``(channels, num_samples)`` tensors; each is sliced along
    the sample axis to the smaller of the two lengths so they can be compared
    sample-for-sample.

    Returns:
        A ``(audio1, audio2)`` pair, each of shape ``(channels, common_length)``.
    """
    common_length = min(audio1.shape[1], audio2.shape[1])
    return audio1[:, :common_length], audio2[:, :common_length]
def calculate_audio_snr(
    pred_audio_path: str,
    target_audio_path: str,
):
    """Compute and print audio quality metrics for a prediction vs. a reference.

    The reference file's sample rate and channel count define the common
    format; both signals are loaded, converted to that format, truncated to a
    common length, and then scored with museval (SDR/ISR/SIR/SAR), the
    project's safe SNR/SI-SNR helpers, and torchmetrics (SNR, SI-SNR, SDR,
    SI-SDR). Results are printed to stdout.

    Args:
        pred_audio_path: Path to the predicted / estimated audio file.
        target_audio_path: Path to the ground-truth reference audio file.
    """
    # Probe the reference once to pick the sample rate / channel layout
    # that both signals will be converted to.
    target_audio, target_sr = torchaudio.load(target_audio_path)
    target_sample_rate = target_sr
    target_channels = target_audio.shape[0]
    pred_audio = load_and_preprocess_audio(
        pred_audio_path, target_sample_rate, target_channels
    )
    target_audio = load_and_preprocess_audio(
        target_audio_path, target_sample_rate, target_channels
    )
    pred_audio, target_audio = align_audio_length(pred_audio, target_audio)
    # BUG FIX: museval.evaluate(references, estimates) takes the REFERENCE
    # signal first and the estimate second; the original call passed the
    # prediction as the reference, swapping the roles in every metric.
    SDR, ISR, SIR, SAR = museval.evaluate(target_audio.T.unsqueeze(0), pred_audio.T.unsqueeze(0))
    data = TrackStore(track_name="test")
    data.add_target(target_name="target", values={"SDR": SDR[0].tolist(), "SIR": SIR[0].tolist(), "SAR": SAR[0].tolist(), "ISR": ISR[0].tolist()})
    print("MusEval results: ")
    print(data)
    # Project-local metric helpers (preds, target) ordering.
    snr_value = safe_signal_noise_ratio(pred_audio, target_audio)
    print("SNR", snr_value)
    sisnr_value = safe_scale_invariant_signal_noise_ratio(pred_audio, target_audio)
    print("SI-SNR", sisnr_value)
    # torchmetrics modules are called as metric(preds, target).
    torch_snr = SignalNoiseRatio()
    torch_snr_value = torch_snr(pred_audio, target_audio)
    print("Torch SNR", torch_snr_value)
    torch_sisnr = ScaleInvariantSignalNoiseRatio()
    torch_sisnr_value = torch_sisnr(pred_audio, target_audio)
    print("Torch SI-SNR", torch_sisnr_value)
    torch_sdr = SignalDistortionRatio()
    torch_sdr_value = torch_sdr(pred_audio, target_audio)
    print("Torch SDR", torch_sdr_value)
    torch_sisdr = ScaleInvariantSignalDistortionRatio()
    torch_sisdr_value = torch_sisdr(pred_audio, target_audio)
    print("Torch SI-SDR", torch_sisdr_value)
if __name__ == "__main__":
    # CLI entry point. Argument order matters: the first positional is the
    # predicted (estimated) audio, the second is the ground-truth reference —
    # the original help strings ("a.wav"/"b.wav") did not say which was which.
    parser = argparse.ArgumentParser(description="metric")
    parser.add_argument("a", help="path to the predicted (estimated) audio file")
    parser.add_argument("b", help="path to the reference (ground-truth) audio file")
    args = parser.parse_args()
    calculate_audio_snr(args.a, args.b)