import gradio as gr
import torch
import librosa
import numpy as np
import torch.nn.functional as F
import os

from encoders.transformer import Wav2Vec2EmotionClassifier

# Define the emotions
emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
label_mapping = {str(idx): emotion for idx, emotion in enumerate(emotions)}

# Load the trained model
model_path = "lora_only_model.pth"
cfg = {
    "model": {
        "encoder": "Wav2Vec2Classifier",
        "optimizer": {
            "name": "Adam",
            "lr": 0.0003,
            "weight_decay": 3e-4
        },
        "l1_lambda": 0.0
    }
}
model = Wav2Vec2EmotionClassifier(num_classes=len(emotions), optimizer_cfg=cfg["model"]["optimizer"])
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(state_dict, strict=False)
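# Note on the line above: strict=False is used because lora_only_model.pth presumably
# holds only the LoRA/classifier weights rather than the full Wav2Vec2 backbone.
# An optional sanity check (a sketch, not in the original code) would be to inspect
# the key-mismatch report returned by load_state_dict:
#   result = model.load_state_dict(state_dict, strict=False)
#   print(f"missing keys: {len(result.missing_keys)}, unexpected keys: {len(result.unexpected_keys)}")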
model.eval()

# Sanity check: print the parameters that are still trainable (e.g., the LoRA adapters
# stored in lora_only_model.pth)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.data}")

# Minimum number of samples, to avoid Wav2Vec2 convolution errors on very short clips
MIN_SAMPLES = 10  # or 16000 if you want at least 1 second
# Preprocessing function
def preprocess_audio(file_path, sample_rate=16000):
    """
    Safely loads the file at file_path and returns a (1, samples) torch tensor.
    Returns None if the file is invalid or too short.
    """
    if not file_path or not os.path.exists(file_path):
        # file_path can be None or an empty string if the user didn't record properly
        return None

    # Load with librosa (which merges to mono by default if multi-channel)
    waveform, sr = librosa.load(file_path, sr=sample_rate)

    # Check length
    if len(waveform) < MIN_SAMPLES:
        return None
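    # Optional sketch (not in the original code): the UI advertises clips of at most
    # one minute, so longer uploads could be truncated here to keep inference bounded.
    max_samples = 60 * sample_rate
    if len(waveform) > max_samples:
        waveform = waveform[:max_samples]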
    # Convert to torch tensor, shape (1, samples)
    waveform_tensor = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)
    return waveform_tensor
# Prediction function
def predict_emotion(audio_file):
    """
    audio_file is a file path from Gradio (type='filepath').
    """
    # Preprocess
    waveform = preprocess_audio(audio_file, sample_rate=16000)

    # If invalid or too short, return an error-like message
    if waveform is None:
        return (
            "Audio is too short or invalid. Please record/upload a longer clip.",
            ""
        )

    # Perform inference
    with torch.no_grad():
        logits = model(waveform)
        probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0]

    # Get the predicted class
    predicted_class = int(np.argmax(probabilities))
    predicted_emotion = label_mapping[str(predicted_class)]

    # Format probabilities for visualization
    probabilities_output = [
        f"""
        <div style='display: flex; align-items: center; margin: 5px 0;'>
            <div style='width: 20%; text-align: right; margin-right: 10px; font-weight: bold;'>{emotions[i]}</div>
            <div style='flex-grow: 1; background-color: #374151; border-radius: 4px; overflow: hidden;'>
                <div style='width: {probabilities[i]*100:.2f}%; background-color: #FFA500; height: 10px;'></div>
            </div>
            <div style='width: 10%; text-align: right; margin-left: 10px;'>{probabilities[i]*100:.2f}%</div>
        </div>
        """
        for i in range(len(emotions))
    ]

    return predicted_emotion, "\n".join(probabilities_output)
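# Design note (an alternative, not used here): gr.Label can also render a
# {label: probability} dict directly, e.g. dict(zip(emotions, probabilities.tolist())),
# which would replace the hand-rolled HTML bars built above.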
# Create Gradio interface
def gradio_interface(audio):
    detected_emotion, probabilities_html = predict_emotion(audio)
    return detected_emotion, gr.HTML(probabilities_html)
# Define Gradio UI
with gr.Blocks(css="""
    body {
        background-color: #121212;
        color: white;
        font-family: Arial, sans-serif;
    }
    h1 {
        color: #FFA500;
        font-size: 48px;
        text-align: center;
        margin-bottom: 10px;
    }
    p {
        text-align: center;
        font-size: 18px;
    }
    .gradio-row {
        justify-content: center;
        align-items: center;
    }
    #submit_button {
        background-color: #FFA500 !important;
        color: black !important;
        font-size: 18px;
        padding: 10px 20px;
        margin-top: 20px;
    }
    #detected_emotion {
        font-size: 24px;
        font-weight: bold;
        text-align: center;
    }
    .probabilities-container {
        margin-top: 20px;
        padding: 10px;
        background-color: #1F2937;
        border-radius: 8px;
    }
""") as demo:
    gr.Markdown(
        """
        <div>
            <h1>Speech Emotion Recognition</h1>
            <p>🎵 Upload or record an audio file (max 1 minute) to detect emotions.</p>
            <p>Supported Emotions: 😊 Happy | 😭 Sad | 😡 Angry | 😐 Neutral | 😨 Fear | 🤢 Disgust | 😮 Surprise</p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1, elem_id="audio-block"):
            # type="filepath" means we get a temporary file path from Gradio
            audio_input = gr.Audio(label="🎤 Record or Upload Audio", type="filepath")
            submit_button = gr.Button("Submit", elem_id="submit_button")

        with gr.Column(scale=1):
            detected_emotion_label = gr.Label(label="Detected Emotion", elem_id="detected_emotion")
            probabilities_html = gr.HTML(label="Probabilities", elem_id="probabilities")

    submit_button.click(
        fn=gradio_interface,
        inputs=audio_input,
        outputs=[detected_emotion_label, probabilities_html]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)
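# Dependency sketch (assumed, not pinned anywhere in this file): the app imports
# gradio, torch, librosa, and numpy, plus the local `encoders` package providing
# Wav2Vec2EmotionClassifier, so a matching requirements.txt would list at least:
#   gradio
#   torch
#   librosa
#   numpy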