|
|
|
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
import librosa |
|
|
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face checkpoint used for both the preprocessing config and the weights.
model_name = "facebook/wav2vec2-base-960h"

# Preprocessor that turns raw waveforms into the tensors the model consumes.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

# Wav2Vec2 with a sequence-classification head configured for 8 output classes.
# NOTE(review): the UI description in this file states the model is not
# fine-tuned, so its predictions should not be treated as meaningful.
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name, num_labels=8)

# The model is only used for inference, so put it in evaluation mode.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_accuracy(audio) -> str:
    """Classify a Gradio audio clip and report the predicted accuracy level.

    Args:
        audio: ``(sample_rate, data)`` tuple as delivered by ``gr.Audio`` with
            ``type="numpy"``, or ``None`` when nothing was recorded/uploaded.

    Returns:
        A user-facing message: either the predicted level or a notice that no
        audio was provided (``None`` input or a zero-length clip).
    """
    if audio is None:
        return "No audio provided."

    sample_rate, data = audio
    data = np.asarray(data)

    # A zero-length clip cannot be resampled or classified.
    if data.size == 0:
        return "No audio provided."

    # Integer PCM must become floating point before resampling: librosa only
    # accepts float audio, and the feature extractor expects a float waveform.
    if np.issubdtype(data.dtype, np.integer):
        scale = float(np.iinfo(data.dtype).max)
        data = data.astype(np.float32) / scale
    else:
        data = data.astype(np.float32)

    # Mix multi-channel audio down to mono; otherwise the 2-D array would be
    # resampled along the wrong axis and read as a batch by the extractor.
    # Assumes Gradio's (samples, channels) layout -- TODO confirm.
    if data.ndim > 1:
        data = data.mean(axis=1)

    # The feature extractor is called with 16 kHz audio.
    target_sr = 16000
    if sample_rate != target_sr:
        data = librosa.resample(data, orig_sr=sample_rate, target_sr=target_sr)
        sample_rate = target_sr

    inputs = feature_extractor(
        data,
        sampling_rate=sample_rate,
        return_tensors="pt",
        padding=True,
    )

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_id = torch.argmax(logits, dim=-1).item()

    # Class indices are shifted by 3 to obtain the reported level
    # (index 0 -> level 3).
    accuracy_level = predicted_id + 3
    return f"Predicted Accuracy Level: {accuracy_level}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Text shown at the top of the Gradio page.
title = "Speech Accuracy Classifier (Base Wav2Vec2)"
description = (
    "Record audio using your microphone or upload an audio file (left). "
    "The model (not fine-tuned) will classify the audio into an accuracy level (right)."
)


def _build_audio_input():
    """Create the audio input component for the installed Gradio version.

    Gradio 4 replaced ``gr.Audio(source=...)`` with ``sources=[...]``, so the
    keyword is chosen from the major version. On Gradio 4+ both recording and
    upload are enabled, matching the label and description; Gradio 3.x accepts
    a single source only, so the microphone is kept there.
    """
    label = "Record/Upload Audio"
    major_version = int(gr.__version__.split(".", 1)[0])
    if major_version >= 4:
        return gr.Audio(sources=["microphone", "upload"], type="numpy", label=label)
    return gr.Audio(source="microphone", type="numpy", label=label)


# Single-function UI: audio in, classification text out, flagging disabled.
demo = gr.Interface(
    fn=classify_accuracy,
    inputs=_build_audio_input(),
    outputs=gr.Textbox(label="Classification Result"),
    title=title,
    description=description,
    allow_flagging="never",
)


if __name__ == "__main__":
    # Start the local Gradio server only when run as a script.
    demo.launch()
|
|
|