Yilin0601's picture
Update app.py
c076438 verified
raw
history blame
2.63 kB
# app.py
import gradio as gr
import torch
import numpy as np
import librosa
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
# ------------------------------------------------
# 1. Load base Wav2Vec2 model + classification head
# ------------------------------------------------
# Checkpoint used for both the feature extractor and the model weights.
model_name = "facebook/wav2vec2-base-960h"

# Pre-processing comes from the same checkpoint as the model.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

# num_labels=8 puts a freshly (randomly) initialised 8-way classification
# head on top of the base encoder; the head is NOT fine-tuned.
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name, num_labels=8)

# Inference only: put the module into evaluation mode.
model.eval()
# ------------------------------------------------
# 2. Define inference function
# ------------------------------------------------
def classify_accuracy(audio):
    """
    Classify a Gradio audio clip into an accuracy level.

    Args:
        audio: Either None (nothing was recorded/uploaded) or the
            (sample_rate, data) tuple Gradio passes when the Audio
            component is configured with type='numpy'.

    Returns:
        A display string: "No audio provided." when there is nothing to
        classify, otherwise "Predicted Accuracy Level: N", where N is the
        predicted class index (0..7) shifted into the range 3..10.
    """
    if audio is None:
        return "No audio provided."
    sample_rate, data = audio

    # Gradio hands over integer PCM (and possibly several channels);
    # librosa and the feature extractor need a 1-D floating-point waveform.
    waveform = _to_mono_float32(data)

    # A zero-length recording gives the model nothing to run on.
    if waveform.size == 0:
        return "No audio provided."

    # Resample to 16 kHz if needed.
    target_sr = 16000
    if sample_rate != target_sr:
        waveform = librosa.resample(
            waveform, orig_sr=sample_rate, target_sr=target_sr
        )

    # Extract features from the audio data.
    inputs = feature_extractor(
        waveform,
        sampling_rate=target_sr,
        return_tensors="pt",
        padding=True,
    )

    # Run model inference without tracking gradients.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_id = torch.argmax(logits, dim=-1).item()

    # Map predicted id (0..7) to the final level (3..10).
    accuracy_level = predicted_id + 3
    return f"Predicted Accuracy Level: {accuracy_level}"


def _to_mono_float32(data):
    """Return *data* as a 1-D float32 waveform, scaling integer PCM to [-1, 1]."""
    samples = np.asarray(data)
    if np.issubdtype(samples.dtype, np.integer):
        # Scale by the dtype's full range so e.g. int16 maps onto [-1, 1].
        info = np.iinfo(samples.dtype)
        full_scale = float(max(abs(info.min), info.max))
        samples = samples.astype(np.float32) / full_scale
    else:
        samples = samples.astype(np.float32)
    if samples.ndim > 1:
        # Multi-channel audio is assumed to be (samples, channels), which is
        # how Gradio lays it out; average the channels down to mono.
        samples = samples.mean(axis=-1)
    return samples
# ------------------------------------------------
# 3. Build Gradio interface
# ------------------------------------------------
title = "Speech Accuracy Classifier (Base Wav2Vec2)"
description = " ".join(
    [
        "Record audio using your microphone or upload an audio file (left).",
        "The model (not fine-tuned) will classify the audio into an accuracy level (right).",
    ]
)

# The audio input is configured for microphone recording and is delivered to
# the callback as a (sample_rate, numpy array) tuple.
# NOTE(review): `source=` and `allow_flagging=` are version-dependent Gradio
# keywords (newer releases use `sources=[...]` / `flagging_mode=`) -- confirm
# against the Gradio version this app is pinned to before changing them.
demo = gr.Interface(
    title=title,
    description=description,
    fn=classify_accuracy,
    inputs=gr.Audio(type="numpy", source="microphone", label="Record/Upload Audio"),
    outputs=gr.Textbox(label="Classification Result"),
    allow_flagging="never",
)

# Start the web UI only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()