lispdetector / detect.py
import librosa
import torch
from transformers import WhisperForAudioClassification

# Load the fine-tuned classification model from the local checkpoint
model = WhisperForAudioClassification.from_pretrained("results/checkpoint-30")
model.eval()  # disable dropout for inference

# Load the audio file at its native sampling rate
audio_path = "dataset/lisp/sample_01.wav"
audio, original_sr = librosa.load(audio_path, sr=None)

# Resample to the model's 16 kHz target sample rate (if needed)
target_sr = 16000
if original_sr != target_sr:
    audio = librosa.resample(audio, orig_sr=original_sr, target_sr=target_sr)

# Extract an 80-bin log-mel spectrogram (parameters should match those used during fine-tuning)
mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=target_sr, n_mels=80, hop_length=512)
mel_spectrogram_db = librosa.power_to_db(mel_spectrogram)
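# Note: the stock Whisper feature extractor uses n_fft=400 and hop_length=160, so that
# 3000 frames correspond to 30 s of 16 kHz audio; with hop_length=512 the same 3000-frame
# window covers a longer clip. The values above are left as-is on the assumption that they
# match the preprocessing used to fine-tune the checkpoint.
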
# Pad (or trim) the mel spectrogram along the time axis to a fixed length of 3000 frames
max_len = 3000
pad_width = (0, max_len - mel_spectrogram_db.shape[1])  # right-pad the time dimension
mel_tensor = torch.from_numpy(mel_spectrogram_db).float().unsqueeze(0)  # (1, n_mels, time)
input_features = torch.nn.functional.pad(mel_tensor, pad_width, mode='constant', value=0)
# input_features now has shape (batch_size, n_mels, sequence_length) = (1, 80, 3000)

# Create the input dictionary with the key the model expects
inputs = {'input_features': input_features}

# Make the prediction
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits

predicted_class_id = torch.argmax(logits, dim=-1).item()
predicted_label = model.config.id2label[predicted_class_id]
print("Predicted label:", predicted_label)