Spaces:
Running
Running
| import gradio as gr | |
| from transformers import pipeline | |
| import soundfile as sf | |
| import os | |
| import base64 | |
| import tempfile | |
| from fastapi import FastAPI, Request | |
| from fastapi.responses import JSONResponse | |
| import uvicorn | |
| from fastapi.middleware.cors import CORSMiddleware # <--- 1. ADD THIS IMPORT | |
# --- Load Model ---
# Loading is best-effort: if it fails the app still starts, `classifier` stays
# None, and the request handlers check for that and report `model_load_error`.
classifier = None
model_load_error = None
try:
    classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-er")
except Exception as e:
    model_load_error = str(e)
# --- FastAPI App ---
# Browser origin(s) allowed to call this API cross-origin (the WordPress site).
# Without matching CORS headers the browser would block those requests.
_ALLOWED_ORIGINS = ["https://tknassetshub.io"]

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Prediction API ---
@app.post("/api/predict/")
async def predict_emotion_api(request: Request):
    """Classify the emotion in a base64-encoded audio clip sent as JSON.

    Registered as ``POST /api/predict/``, the URL the Gradio UI description
    advertises to websites.

    Request body: a JSON object ``{"data": "data:audio/wav;base64,AABB..."}``,
    i.e. the data-URI string a browser ``FileReader`` produces. Everything up to
    the first comma is treated as the header and discarded.

    Returns:
        ``JSONResponse`` with one of:
        - 200 ``{"data": results}`` -- the classifier pipeline's output;
        - 400 ``{"error": ...}`` -- body is not a JSON object, ``data`` is
          missing, is not a decodable data URI, or decodes to no bytes;
        - 500 ``{"error": ...}`` -- the classifier raised;
        - 503 ``{"error": ...}`` -- the model failed to load at startup.
    """
    from fastapi.concurrency import run_in_threadpool  # local: keeps this block self-contained

    if classifier is None:
        return JSONResponse(content={"error": f"Model is not loaded: {model_load_error}"}, status_code=503)

    # Malformed JSON (json.JSONDecodeError is a ValueError) is a client error, not a 500.
    try:
        body = await request.json()
    except ValueError:
        return JSONResponse(content={"error": "Request body must be valid JSON."}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse(content={"error": "Request body must be a JSON object."}, status_code=400)

    base64_with_prefix = body.get("data")
    if not base64_with_prefix:
        return JSONResponse(content={"error": "Missing 'data' field in request body."}, status_code=400)

    # Strip the "data:<mime>;base64," header and decode the payload.
    # binascii.Error (bad padding/characters) is a ValueError subclass;
    # AttributeError covers a non-string "data" value (e.g. a number or list).
    try:
        _header, encoded = base64_with_prefix.split(",", 1)
        audio_data = base64.b64decode(encoded)
    except (ValueError, TypeError, AttributeError):
        return JSONResponse(content={"error": "Invalid base64 data format. Please send the full data URI."}, status_code=400)
    if not audio_data:
        return JSONResponse(content={"error": "Decoded audio data is empty."}, status_code=400)

    # The audio-classification pipeline accepts raw audio-file bytes (decoded
    # with ffmpeg, just as it does after reading a file path), so no temp file
    # is needed and there is nothing to clean up.
    # Inference is blocking; running it in the threadpool keeps the event loop
    # (and the mounted Gradio UI) responsive while a prediction runs.
    # NOTE(review): concurrent requests can now run the pipeline in parallel
    # threads; add a lock if that proves unsafe for this model.
    try:
        results = await run_in_threadpool(classifier, audio_data)
    except Exception as e:
        return JSONResponse(content={"error": f"Internal server error during prediction: {e}"}, status_code=500)

    return JSONResponse(content={"data": results})
# --- Gradio UI (for demonstration on the Space's page) ---
def gradio_predict_wrapper(audio_file_path):
    """Classify an audio clip for the Gradio demo UI.

    Args:
        audio_file_path: Path to the recorded/uploaded clip (the ``gr.Audio``
            input is configured with ``type="filepath"``), or ``None`` when
            nothing was provided.

    Returns:
        A ``{label: score}`` dict of the top 5 predictions, the shape
        ``gr.Label`` renders as confidences.

    Raises:
        gr.Error: if the model is not loaded, no audio was provided, or
            classification fails. Raising (rather than returning an
            ``{"error": "<message>"}`` dict) matters because ``gr.Label``
            treats dict values as float confidences, so a string value would
            fail during output processing instead of showing the message.
    """
    if classifier is None:
        raise gr.Error(f"Model is not loaded: {model_load_error}")
    if audio_file_path is None:
        raise gr.Error("Please provide an audio file.")
    try:
        results = classifier(audio_file_path, top_k=5)
    except Exception as e:
        raise gr.Error(str(e)) from e
    return {item["label"]: item["score"] for item in results}
# Components for the demo: microphone/upload input delivered as a file path,
# and a label display showing the five highest-scoring emotions.
_audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record")
_label_output = gr.Label(num_top_classes=5, label="Emotion Predictions")

gradio_interface = gr.Interface(
    fn=gradio_predict_wrapper,
    inputs=_audio_input,
    outputs=_label_output,
    title="Audio Emotion Detector",
    description="This UI is for direct demonstration. The primary API for websites is at /api/predict/",
    allow_flagging="never",
)
# Serve the Gradio demo under /gradio on the same FastAPI app, and therefore
# the same server and port as the API.
app = gr.mount_gradio_app(app=app, blocks=gradio_interface, path="/gradio")

if __name__ == "__main__":
    # Server launch used by Hugging Face Spaces.
    _host, _port = "0.0.0.0", 7860
    uvicorn.run(app, host=_host, port=_port)