import os
import json
import numpy as np
import gradio as gr
from dotenv import load_dotenv
from matplotlib import cm

load_dotenv()

# -------------------------
# 0) Env & defaults
# -------------------------
BACKEND = os.getenv("DETECTOR_BACKEND", "wav2vec2").strip().lower()  # "wav2vec2" or "cnn"
DEFAULT_W2V_WEIGHTS = "app/models/weights/wav2vec2_classifier.pth"
DEFAULT_CNN_WEIGHTS = "app/models/weights/cnn_melspec.pth"
DEFAULT_WEIGHTS = DEFAULT_W2V_WEIGHTS if BACKEND == "wav2vec2" else DEFAULT_CNN_WEIGHTS
MODEL_WEIGHTS_PATH = os.getenv("MODEL_WEIGHTS_PATH", DEFAULT_WEIGHTS).strip()
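# Illustrative .env / Space "Variables" settings (values are examples only;
# both variables are optional and fall back to the defaults above):
#   DETECTOR_BACKEND=cnn
#   MODEL_WEIGHTS_PATH=app/models/weights/cnn_melspec.pth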
# -------------------------
# 1) Import your Detector
# -------------------------
def _import_detector(backend):
    """
    Import the correct Detector class depending on backend and package layout.
    Works both when run as a module ('.inference_*') and as a script ('app.inference_*').
    """
    try:
        if backend == "wav2vec2":
            from .inference_wav2vec import Detector  # type: ignore
        else:
            from .inference import Detector  # type: ignore
    except Exception:
        if backend == "wav2vec2":
            from app.inference_wav2vec import Detector  # type: ignore
        else:
            from app.inference import Detector  # type: ignore
    return Detector
try:
    Detector = _import_detector(BACKEND)
except Exception as e:
    # Capture the message now: Python unbinds `e` when the except block exits,
    # so reading it later inside __init__ would raise NameError.
    _IMPORT_ERR = f"Detector import failed: {e}"

    # Fallback dummy to keep the UI alive even if the import fails,
    # so you can see the error in the JSON panel.
    class Detector:  # type: ignore
        def __init__(self, *args, **kwargs):
            self._err = _IMPORT_ERR

        def predict_proba(self, *args, **kwargs):
            return {"error": self._err}

        def explain(self, *args, **kwargs):
            return {"cam": np.zeros((128, 128), dtype=np.float32).tolist()}
# Single, shared detector (created lazily so startup is fast on Spaces)
_DET = None


def _get_detector():
    global _DET
    if _DET is None:
        _DET = Detector(weights_path=MODEL_WEIGHTS_PATH)
    return _DET
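# Sketch of the interface this app assumes a Detector exposes, derived from
# the calls in predict_and_explain below, not from a guaranteed API of the
# real classes in app/inference_wav2vec.py and app/inference.py. The name
# DetectorInterface is illustrative only.
from typing import Any, Protocol


class DetectorInterface(Protocol):
    def predict_proba(self, audio_path: str, source_hint: str = "auto") -> dict[str, Any]:
        """Return class probabilities (and a label) for one audio file."""
        ...

    def explain(self, audio_path: str, source_hint: str = "auto") -> dict[str, Any]:
        """Return {"cam": ...} with a saliency map scaled to [0, 1]."""
        ...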
# -------------------------
# 2) Core functions
# -------------------------
def predict_and_explain(audio_path: str | None, source_hint: str):
    """
    audio_path: filepath from Gradio (since type='filepath')
    source_hint: "Auto", "Microphone", "Upload"
    """
    source = (source_hint or "Auto").strip().lower()
    if not audio_path or not os.path.exists(audio_path):
        return {"error": "No audio received. Record or upload a 2–4s clip."}, None
    det = _get_detector()
    # Your Detector is expected to accept a file path and an optional source hint
    proba = det.predict_proba(audio_path, source_hint=source)
    exp = det.explain(audio_path, source_hint=source)
    # Explanation to heatmap (float [0, 1] -> magma RGB uint8)
    cam = np.array(exp.get("cam", []), dtype=np.float32)
    if cam.ndim == 1:
        # If the model returned a 1D vector, tile it to a square-ish map;
        # zero-pad so short or empty vectors still reshape cleanly.
        side = max(int(np.sqrt(cam.size)), 2)
        cam = np.pad(cam, (0, max(0, side * side - cam.size)))[: side * side]
        cam = cam.reshape(side, side)
    cam = np.clip(cam, 0.0, 1.0)
    cam_rgb = (cm.magma(cam)[..., :3] * 255).astype(np.uint8)
    # Ensure proba is JSON-serializable
    if not isinstance(proba, dict):
        proba = {"result": proba}
    return proba, cam_rgb
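# Worked example of the 1-D fallback above: a 300-element saliency vector
# gives side = int(sqrt(300)) = 17, so the first 289 values are rendered as
# a 17x17 map and the tail is dropped; shorter vectors are zero-padded.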
def provenance(audio_path: str | None):
    # Stub (you can wire a provenance model or checksum here)
    return {"ok": True, "note": "Provenance check not wired in this app.py."}
# -------------------------
# 3) UI
# -------------------------
with gr.Blocks(title=f"AI Voice Detector · {BACKEND.upper()}") as demo:
    gr.Markdown(f"# 🎙️ AI Voice Detector – Backend: **{BACKEND.upper()}**")
    gr.Markdown(
        "Record or upload a short clip (~3s). Get probabilities, a label, and an explanation heatmap."
    )
    with gr.Row():
        audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
        with gr.Column():
            src = gr.Radio(choices=["Auto", "Microphone", "Upload"], value="Auto", label="Source")
            btn_predict = gr.Button("Analyze", variant="primary")
            btn_prov = gr.Button("Provenance Check (optional)")
    with gr.Row():
        json_out = gr.JSON(label="Prediction (probabilities + label)")
        cam_out = gr.Image(label="Explanation Heatmap (saliency)")
        prov_out = gr.JSON(label="Provenance Result (if available)")

    btn_predict.click(predict_and_explain, inputs=[audio_in, src], outputs=[json_out, cam_out])
    btn_prov.click(provenance, inputs=audio_in, outputs=prov_out)
# -------------------------
# 4) Launch (Spaces-friendly)
# -------------------------
if __name__ == "__main__":
    # queue() keeps the UI responsive under load; host/port are Spaces-safe and local-friendly
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
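# Illustrative local invocations (assuming this file lives at app/app.py and
# you run from the repo root; values are examples only):
#   python -m app.app                        # relative imports resolve via the package
#   DETECTOR_BACKEND=cnn python -m app.app   # switch to the mel-spectrogram CNN backend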