import librosa
import torch

from transformers import WhisperForAudioClassification

# Load the fine-tuned Whisper audio classification checkpoint.
model = WhisperForAudioClassification.from_pretrained("results/checkpoint-30")

# Load the audio clip to classify at its native sampling rate (sr=None keeps the
# file's original rate, so the resampling check below compares against the true rate).
audio_path = "dataset/lisp/sample_01.wav"
audio, original_sr = librosa.load(audio_path, sr=None)

# Whisper operates on 16 kHz audio; resample only if the clip uses a different rate.
target_sr = 16000
if original_sr != target_sr:
    audio = librosa.resample(audio, orig_sr=original_sr, target_sr=target_sr)

# Compute an 80-bin log-mel spectrogram; 80 mel bins match the Whisper encoder's
# feature dimension. This preprocessing (including hop_length) should mirror
# whatever was used when the checkpoint was fine-tuned.
mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=target_sr, n_mels=80, hop_length=512)
mel_spectrogram_db = librosa.power_to_db(mel_spectrogram)
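
# For reference only (a hedged aside, not used below): the stock WhisperFeatureExtractor
# builds its log-mels with n_fft=400 and hop_length=160, so 30 s of 16 kHz audio maps to
# exactly 3000 frames. If the checkpoint was fine-tuned on those features, the same
# extractor should be used at inference time. "openai/whisper-base" is an assumed base
# model, not something stated in this script.
from transformers import WhisperFeatureExtractor
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
reference_features = feature_extractor(audio, sampling_rate=target_sr, return_tensors="pt").input_features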

# The encoder expects a fixed-width input of max_len mel frames: truncate longer
# clips and zero-pad shorter ones on the right.
max_len = 3000
mel_spectrogram_db = mel_spectrogram_db[:, :max_len]
pad_width = (0, max_len - mel_spectrogram_db.shape[1])
mel_spectrogram_db_padded = torch.nn.functional.pad(
    torch.from_numpy(mel_spectrogram_db).float().unsqueeze(1),
    pad_width, mode='constant', value=0)

# Rearrange to (batch, n_mels, frames) = (1, 80, 3000), the layout the Whisper
# encoder expects.
input_features = mel_spectrogram_db_padded
input_features = input_features.permute(1, 0, 2)
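
# Quick sanity check (a small addition, not part of the original pipeline): the
# classification head sits on top of the Whisper encoder, which expects log-mel
# features of shape (batch, 80, 3000), so fail early if the shape has drifted.
assert input_features.shape == (1, 80, 3000), input_features.shape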

# Package the features under the keyword argument the model expects.
inputs = {'input_features': input_features}

# Run inference without tracking gradients, then map the highest-scoring logit
# back to its class label.
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits

predicted_class_ids = torch.argmax(logits, dim=-1).item()
predicted_label = model.config.id2label[predicted_class_ids]
print("Predicted label:", predicted_label)
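
# Beyond the top-1 prediction, it can help to see the full class distribution.
# A minimal sketch: softmax over the logits, printed per label (assumes id2label,
# loaded from the checkpoint config, covers every class index).
probabilities = torch.softmax(logits, dim=-1).squeeze(0)
for class_id, probability in enumerate(probabilities.tolist()):
    print(f"{model.config.id2label[class_id]}: {probability:.3f}")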