File size: 3,083 Bytes
cae86cc
 
 
 
5cde27f
 
7a18e70
 
 
cae86cc
7a18e70
0ea75df
6608513
0ea75df
6608513
7a18e70
 
 
cae86cc
7a18e70
cae86cc
7a18e70
 
 
5cde27f
7a18e70
cae86cc
7a18e70
 
 
 
 
 
 
 
 
 
6608513
 
7a18e70
 
cae86cc
7a18e70
5cde27f
 
7a18e70
 
 
 
 
5cde27f
7a18e70
5cde27f
 
7a18e70
 
 
 
 
 
 
5cde27f
 
7a18e70
5cde27f
 
7a18e70
5cde27f
 
7a18e70
cae86cc
7a18e70
 
cae86cc
7a18e70
 
 
 
 
cae86cc
 
7a18e70
 
 
 
cae86cc
7a18e70
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
from transformers import pipeline
import soundfile as sf
import os
import base64
import tempfile
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn

# --- Load Model ---
# Try to build the emotion classifier once at import time.  On failure we
# keep the app alive and surface the error message from the handlers
# instead of crashing at startup.
model_load_error = None
try:
    classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-er")
except Exception as load_exc:
    classifier = None
    model_load_error = str(load_exc)

# --- Gradio Prediction Function ---
def predict_emotion(audio_file):
    """Classify the emotion in an audio clip.

    Args:
        audio_file: Either a filesystem path (``str``) or a
            ``(sample_rate, numpy_array)`` tuple as produced by
            ``gr.Audio`` in "numpy" mode.

    Returns:
        dict: ``label -> score`` (rounded to 3 places) for the top 5
        predictions, or ``{"error": ...}`` describing the failure.
    """
    if classifier is None:
        return {"error": f"Model load failed: {model_load_error}"}
    if audio_file is None:
        return {"error": "No audio input provided."}

    # Sentinel instead of the fragile `'temp_audio_path' in locals()` probe.
    temp_audio_path = None
    try:
        if isinstance(audio_file, str):
            audio_path = audio_file
        elif isinstance(audio_file, tuple):
            sample_rate, audio_array = audio_file
            # Use a unique temp file rather than a fixed "temp_audio.wav"
            # in the CWD, so concurrent requests cannot clobber each other.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                temp_audio_path = tmp.name
            sf.write(temp_audio_path, audio_array, sample_rate)
            audio_path = temp_audio_path
        else:
            return {"error": f"Unsupported input type: {type(audio_file)}"}

        results = classifier(audio_path, top_k=5)
        return {item['label']: round(item['score'], 3) for item in results}
    except Exception as e:
        return {"error": f"Prediction error: {str(e)}"}
    finally:
        # Remove the temp file if we created one (runs on every exit path).
        if temp_audio_path and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)

# --- FastAPI App for Base64 API ---
app = FastAPI()

@app.post("/api/predict/")
async def predict_emotion_api(request: Request):
    """JSON API endpoint: ``{"data": <base64-encoded WAV>}``.

    Returns:
        ``label -> score`` dict (HTTP 200) on success, or a JSON error
        payload with status 400 (missing field) / 500 (model or
        prediction failure).
    """
    if classifier is None:
        return JSONResponse(content={"error": f"Model load failed: {model_load_error}"}, status_code=500)

    try:
        body = await request.json()
        base64_audio = body.get("data")
        if not base64_audio:
            return JSONResponse(content={"error": "Missing 'data' field with base64 audio."}, status_code=400)

        audio_data = base64.b64decode(base64_audio)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(audio_data)
            temp_audio_path = temp_file.name

        try:
            results = classifier(temp_audio_path, top_k=5)
        finally:
            # Always delete the temp file — the original leaked it
            # whenever inference raised, since cleanup only ran on
            # the success path.
            os.unlink(temp_audio_path)

        return {item['label']: round(item['score'], 3) for item in results}
    except Exception as e:
        return JSONResponse(content={"error": f"API prediction failed: {str(e)}"}, status_code=500)

# --- Gradio UI ---
# Web front-end: record or upload a clip, see the top-5 emotion labels
# with their confidence scores.
_audio_in = gr.Audio(
    sources=["microphone", "upload"],
    type="filepath",
    label="Upload Audio or Record",
)
_label_out = gr.Label(num_top_classes=5, label="Emotion Predictions")

gradio_interface = gr.Interface(
    fn=predict_emotion,
    inputs=_audio_in,
    outputs=_label_out,
    title="Audio Emotion Detector",
    description="Upload or record your voice to detect emotions.",
    allow_flagging="never",
)

# --- Mount Gradio inside FastAPI ---
# Enable the request queue *before* mounting so it is active both when
# this file is run directly and when served by an external ASGI server.
# The original called queue() only under __main__, after mounting, so a
# deployed instance (e.g. `uvicorn module:app`) ran without a queue.
gradio_interface.queue()
app = gr.mount_gradio_app(app, gradio_interface, path="/")

# --- Launch for local/dev use only ---
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)