import librosa
import torch
from transformers import WhisperForAudioClassification

# Load the fine-tuned classification model and switch it to inference mode
model = WhisperForAudioClassification.from_pretrained("results/checkpoint-30")
model.eval()

# Load the audio file at its native sample rate
audio_path = "dataset/lisp/sample_01.wav"
audio, original_sr = librosa.load(audio_path, sr=None)

# Resample to the 16 kHz rate Whisper models expect (if needed)
target_sr = 16000
if original_sr != target_sr:
    audio = librosa.resample(audio, orig_sr=original_sr, target_sr=target_sr)

# Compute an 80-bin log-mel spectrogram; these parameters must match the
# preprocessing used when the checkpoint was trained
mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=target_sr, n_mels=80, hop_length=512)
mel_spectrogram_db = librosa.power_to_db(mel_spectrogram)
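
# Alternative (not used here): the Hugging Face WhisperFeatureExtractor can build
# the log-mel features and handles padding/truncation to 3000 frames itself. This
# is only a sketch: "openai/whisper-tiny" is an assumed base checkpoint, so swap in
# whichever Whisper model this classifier was fine-tuned from.
# from transformers import WhisperFeatureExtractor
# feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
# extracted = feature_extractor(audio, sampling_rate=target_sr, return_tensors="pt")
# input_features = extracted.input_features  # shape: (1, 80, 3000)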

# Pad (or truncate) the spectrogram to the fixed 3000-frame length Whisper expects
max_len = 3000
mel = torch.from_numpy(mel_spectrogram_db).float()[:, :max_len]  # truncate if longer
pad_width = (0, max_len - mel.shape[1])                          # right-pad if shorter
mel_padded = torch.nn.functional.pad(mel, pad_width, mode='constant', value=0)

# Add a batch dimension: (batch_size, n_mels, sequence_length) = (1, 80, 3000)
input_features = mel_padded.unsqueeze(0)

# Create input dictionary with expected key
inputs = {'input_features': input_features}

# Run inference and map the highest-scoring logit back to its label
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits                                   # shape: (1, num_labels)
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    predicted_label = model.config.id2label[predicted_class_id]

print("Predicted label:", predicted_label)