Diggz10's picture
Update app.py
7a18e70 verified
raw
history blame
3.08 kB
import gradio as gr
from transformers import pipeline
import soundfile as sf
import os
import base64
import tempfile
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn
# --- Load Model ---
# Build the audio-classification pipeline once, at import time. A failure
# here is not fatal: the message is kept in ``model_load_error`` and
# ``classifier`` is left as None so request handlers can report the problem.
model_load_error = None
try:
    classifier = pipeline(
        "audio-classification",
        model="superb/wav2vec2-base-superb-er",
    )
except Exception as load_failure:
    classifier = None
    model_load_error = str(load_failure)
# --- Gradio Prediction Function ---
def predict_emotion(audio_file):
    """Classify the emotion expressed in an audio clip.

    Args:
        audio_file: Either a filesystem path (``str``) to an audio file, or a
            ``(sample_rate, audio_array)`` tuple, which is written to a
            temporary WAV file before being classified.

    Returns:
        A dict mapping each predicted label to its score rounded to three
        decimals (the classifier is asked for the top five), or
        ``{"error": message}`` when the model failed to load, no input was
        provided, the input type is unsupported, or prediction raised.
    """
    if classifier is None:
        return {"error": f"Model load failed: {model_load_error}"}
    if audio_file is None:
        return {"error": "No audio input provided."}

    # Path of the scratch WAV file; stays None unless one is created below.
    temp_audio_path = None
    try:
        if isinstance(audio_file, str):
            audio_path = audio_file
        elif isinstance(audio_file, tuple):
            sample_rate, audio_array = audio_file
            # Use a unique temporary name: a fixed "temp_audio.wav" in the
            # working directory lets concurrent requests overwrite or delete
            # each other's audio.
            fd, temp_audio_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)  # soundfile reopens the file by path
            sf.write(temp_audio_path, audio_array, sample_rate)
            audio_path = temp_audio_path
        else:
            return {"error": f"Unsupported input type: {type(audio_file)}"}
        results = classifier(audio_path, top_k=5)
        return {item['label']: round(item['score'], 3) for item in results}
    except Exception as e:
        # Report failures as data instead of raising into the UI layer.
        return {"error": f"Prediction error: {str(e)}"}
    finally:
        # Remove the scratch file on every exit path, success or failure.
        if temp_audio_path is not None and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
# --- FastAPI App for Base64 API ---
# Application object that carries the JSON/base64 prediction route; the
# Gradio UI is mounted onto this same app further down in the file.
app = FastAPI()
@app.post("/api/predict/")
async def predict_emotion_api(request: Request):
    """Classify base64-encoded audio posted as JSON.

    The request body must be a JSON object whose ``"data"`` field holds the
    base64-encoded bytes of an audio file.

    Returns:
        On success, a dict mapping each predicted label to its score rounded
        to three decimals. Otherwise a ``JSONResponse`` with an ``"error"``
        message: status 400 when ``"data"`` is missing or is not valid
        base64, status 500 when the model failed to load or prediction
        raised.
    """
    if classifier is None:
        return JSONResponse(content={"error": f"Model load failed: {model_load_error}"}, status_code=500)

    # Path of the scratch file; stays None until the file has been created.
    temp_audio_path = None
    try:
        body = await request.json()
        # A JSON array or scalar has no "data" field; treat it as missing
        # rather than failing on ``.get`` with a 500.
        base64_audio = body.get("data") if isinstance(body, dict) else None
        if not base64_audio:
            return JSONResponse(content={"error": "Missing 'data' field with base64 audio."}, status_code=400)
        try:
            audio_data = base64.b64decode(base64_audio)
        except (ValueError, TypeError) as e:
            # Undecodable input is the client's mistake, not a server fault.
            return JSONResponse(content={"error": f"Invalid base64 audio data: {str(e)}"}, status_code=400)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            # Record the name before writing so a failed write is cleaned up.
            temp_audio_path = temp_file.name
            temp_file.write(audio_data)
        results = classifier(temp_audio_path, top_k=5)
        return {item['label']: round(item['score'], 3) for item in results}
    except Exception as e:
        return JSONResponse(content={"error": f"API prediction failed: {str(e)}"}, status_code=500)
    finally:
        # delete=False means nothing removes the file automatically; do it
        # here so it is not leaked when classification raises.
        if temp_audio_path is not None and os.path.exists(temp_audio_path):
            os.unlink(temp_audio_path)
# --- Gradio UI ---
# Browser front end around predict_emotion: audio comes from the microphone
# or an upload, and the result is rendered as a label widget.
gradio_interface = gr.Interface(
    title="Audio Emotion Detector",
    description="Upload or record your voice to detect emotions.",
    fn=predict_emotion,
    inputs=gr.Audio(
        label="Upload Audio or Record",
        type="filepath",
        sources=["microphone", "upload"],
    ),
    outputs=gr.Label(label="Emotion Predictions", num_top_classes=5),
    allow_flagging="never",
)

# --- Mount Gradio inside FastAPI ---
# Serve the UI from the root path of the FastAPI application.
app = gr.mount_gradio_app(app, gradio_interface, path="/")

# --- Launch for local/dev use only ---
if __name__ == "__main__":
    # Enable Gradio's request queue, then serve the combined app on every
    # network interface at port 7860.
    gradio_interface.queue()
    uvicorn.run(app, port=7860, host="0.0.0.0")