Jihuai's picture
have to create an orphan branch to bypass large file history: cleanup .ipynb and create LFS
d572f56
raw
history blame
5.16 kB
import argparse
import torch
import torchaudio
from typing import Tuple
import museval
from museval.aggregate import TrackStore
from core.metrics.snr import safe_signal_noise_ratio, safe_scale_invariant_signal_noise_ratio
from torchmetrics.audio import SignalNoiseRatio, ScaleInvariantSignalNoiseRatio, SignalDistortionRatio, ScaleInvariantSignalDistortionRatio
def load_and_preprocess_audio(
audio_path: str,
target_sample_rate: int = 44100,
target_channels: int = 1,
verbose: bool = True
) -> torch.Tensor:
# εŠ θ½½ιŸ³ι’‘
waveform, orig_sample_rate = torchaudio.load(audio_path)
if verbose:
print(f"Loaded audio: {audio_path}")
print(f" Original shape: {waveform.shape}")
print(f" Original sample rate: {orig_sample_rate} Hz")
print(f" Duration: {waveform.shape[1] / orig_sample_rate:.2f} seconds")
print(f" Target sample rate: {target_sample_rate} Hz")
print(f" Target channels: {target_channels}")
# 重采样倄理
if orig_sample_rate != target_sample_rate:
if verbose:
print(f" Resampling: {orig_sample_rate} Hz -> {target_sample_rate} Hz")
resampler = torchaudio.transforms.Resample(
orig_sample_rate, target_sample_rate
)
waveform = resampler(waveform)
if verbose:
print(f" After resampling shape: {waveform.shape}")
print(f" New duration: {waveform.shape[1] / target_sample_rate:.2f} seconds")
# ι€šι“ε€„η†
current_channels = waveform.shape[0]
if current_channels > target_channels:
if verbose:
print(f" Downmixing: {current_channels} channels -> {target_channels} channel(s)")
print(f" Using mean averaging for downmixing")
assert target_channels == 1, "Downmixing only supported to mono"
waveform = waveform.mean(dim=0, keepdim=True)
if verbose:
print(f" After downmixing shape: {waveform.shape}")
elif current_channels < target_channels:
if verbose:
print(f" Upmixing: {current_channels} channel(s) -> {target_channels} channels")
print(f" Repeating single channel data")
assert waveform.shape[0] == 1, "Upmixing only supported from mono"
waveform = waveform.repeat(target_channels, 1)
if verbose:
print(f" After upmixing shape: {waveform.shape}")
else:
if verbose:
print(f" No channel conversion needed (already {target_channels} channels)")
# ζœ€η»ˆδΏ‘ζ―
if verbose:
print(f" Final shape: {waveform.shape}")
print(f" Final sample rate: {target_sample_rate} Hz")
print(f" Final duration: {waveform.shape[1] / target_sample_rate:.2f} seconds")
print("-" * 50)
return waveform
def align_audio_length(
audio1: torch.Tensor,
audio2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
min_length = min(audio1.shape[1], audio2.shape[1])
audio1_aligned = audio1[:, :min_length]
audio2_aligned = audio2[:, :min_length]
return audio1_aligned, audio2_aligned
def calculate_audio_snr(
pred_audio_path: str,
target_audio_path: str,
):
target_audio, target_sr = torchaudio.load(target_audio_path)
target_sample_rate = target_sr
target_channels = target_audio.shape[0]
pred_audio = load_and_preprocess_audio(
pred_audio_path, target_sample_rate, target_channels
)
target_audio = load_and_preprocess_audio(
target_audio_path, target_sample_rate, target_channels
)
pred_audio, target_audio = align_audio_length(pred_audio, target_audio)
SDR, ISR, SIR, SAR = museval.evaluate(pred_audio.T.unsqueeze(0), target_audio.T.unsqueeze(0))
data = TrackStore(track_name="test")
data.add_target(target_name="target", values={"SDR": SDR[0].tolist(), "SIR": SIR[0].tolist(), "SAR": SAR[0].tolist(), "ISR": ISR[0].tolist()})
print("MusEval results: ")
print(data)
snr_value = safe_signal_noise_ratio(pred_audio, target_audio)
print("SNR", snr_value)
sisnr_value = safe_scale_invariant_signal_noise_ratio(pred_audio, target_audio)
print("SI-SNR", sisnr_value)
torch_snr = SignalNoiseRatio()
torch_snr_value = torch_snr(pred_audio, target_audio)
print("Torch SNR", torch_snr_value)
torch_sisnr = ScaleInvariantSignalNoiseRatio()
torch_sisnr_value = torch_sisnr(pred_audio, target_audio)
print("Torch SI-SNR", torch_sisnr_value)
torch_sdr = SignalDistortionRatio()
torch_sdr_value = torch_sdr(pred_audio, target_audio)
print("Torch SDR", torch_sdr_value)
torch_sisdr = ScaleInvariantSignalDistortionRatio()
torch_sisdr_value = torch_sisdr(pred_audio, target_audio)
print("Torch SI-SDR", torch_sisdr_value)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="metric")
parser.add_argument("a", help="a.wav")
parser.add_argument("b", help="b.wav")
args = parser.parse_args()
calculate_audio_snr(args.a, args.b)