# app.py
import gradio as gr
import torch
import numpy as np
import librosa
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification

# ------------------------------------------------
# 1. Load base Wav2Vec2 model + classification head
# ------------------------------------------------
model_name = "facebook/wav2vec2-base-960h"

NUM_LABELS = 8  # classifier outputs class ids 0..NUM_LABELS-1
MIN_LEVEL = 3  # class id 0 maps to level 3, so levels span 3..10
TARGET_SR = 16000  # sampling rate the feature extractor/model is used with here
# wav2vec2's convolutional feature encoder has a receptive field of 400
# samples (25 ms at 16 kHz); shorter clips produce no frames and fail.
MIN_SAMPLES = 400

# Passing num_labels creates a randomly initialised classification head on top
# of the pretrained encoder. Until that head is fine-tuned, predictions carry
# no real meaning (the UI description says the model is not fine-tuned).
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    model_name, num_labels=NUM_LABELS
)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
model.eval()


# ------------------------------------------------
# 2. Define inference function
# ------------------------------------------------
def _to_mono_float32(data):
    """Convert Gradio numpy audio to a 1-D float32 waveform scaled to [-1, 1].

    Gradio returns integer PCM of shape (samples,) or (samples, channels);
    librosa.resample requires floating-point input and resamples along the
    last axis, so integers are rescaled and channels are averaged first.
    """
    data = np.asarray(data)
    if np.issubdtype(data.dtype, np.unsignedinteger):
        # Unsigned PCM (e.g. 8-bit WAV) is centred on the midpoint of its range.
        midpoint = (int(np.iinfo(data.dtype).max) + 1) / 2
        data = (data.astype(np.float32) - midpoint) / midpoint
    elif np.issubdtype(data.dtype, np.signedinteger):
        data = data.astype(np.float32) / -float(np.iinfo(data.dtype).min)
    else:
        data = data.astype(np.float32, copy=False)
    if data.ndim > 1:
        # Collapse every non-time axis (channels) into a single mono signal.
        data = data.reshape(data.shape[0], -1).mean(axis=1)
    return data


def classify_accuracy(audio):
    """Classify a recorded or uploaded clip into an accuracy level.

    Args:
        audio: ``(sample_rate, data)`` tuple supplied by
            ``gr.Audio(type="numpy")``, or ``None`` when nothing was provided.

    Returns:
        ``"Predicted Accuracy Level: N"`` with N in
        MIN_LEVEL..MIN_LEVEL + NUM_LABELS - 1, or a message explaining why no
        prediction could be made (missing, empty or too-short audio).
    """
    if audio is None:
        return "No audio provided."

    sample_rate, data = audio
    data = _to_mono_float32(data)
    if data.size == 0:
        return "No audio provided."

    # Resample to the rate the feature extractor is called with.
    if sample_rate != TARGET_SR:
        data = librosa.resample(data, orig_sr=sample_rate, target_sr=TARGET_SR)
        sample_rate = TARGET_SR

    if data.shape[0] < MIN_SAMPLES:
        return "Audio is too short to classify."

    inputs = feature_extractor(
        data, sampling_rate=sample_rate, return_tensors="pt", padding=True
    )

    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_id = int(torch.argmax(logits, dim=-1).item())

    # Map class id 0..NUM_LABELS-1 onto the user-facing level scale.
    accuracy_level = predicted_id + MIN_LEVEL
    return f"Predicted Accuracy Level: {accuracy_level}"


# ------------------------------------------------
# 3. Build Gradio interface
# ------------------------------------------------
title = "Speech Accuracy Classifier (Base Wav2Vec2)"
description = (
    "Record audio using your microphone or upload an audio file (left).\n"
    "The model (not fine-tuned) will classify the audio into an accuracy level (right)."
)


def _gradio_major_version():
    """Return the installed Gradio major version, or 0 if it cannot be parsed."""
    try:
        return int(gr.__version__.split(".")[0])
    except (AttributeError, ValueError):
        return 0


_GRADIO_MAJOR = _gradio_major_version()

if _GRADIO_MAJOR >= 4:
    # Gradio 4 replaced the single-valued `source` with a `sources` list;
    # listing both enables recording AND file upload, matching the label.
    audio_input = gr.Audio(
        sources=["microphone", "upload"], type="numpy", label="Record/Upload Audio"
    )
else:
    # Gradio 3.x: `source` takes a single value; "microphone" allows recording only.
    audio_input = gr.Audio(
        source="microphone", type="numpy", label="Record/Upload Audio"
    )

# Gradio 5 renamed Interface's `allow_flagging` parameter to `flagging_mode`.
_flagging_kwargs = (
    {"flagging_mode": "never"} if _GRADIO_MAJOR >= 5 else {"allow_flagging": "never"}
)

demo = gr.Interface(
    fn=classify_accuracy,
    inputs=audio_input,
    outputs=gr.Textbox(label="Classification Result"),
    title=title,
    description=description,
    **_flagging_kwargs,
)

if __name__ == "__main__":
    demo.launch()