# Voice-guard — app/app.py
# Gradio Space entry point for the AI voice detector.
# (Hugging Face web-UI chrome — uploader, revision e2c61ce, file size —
# removed; it was scrape residue, not part of the source.)
import os
import json
import numpy as np
import gradio as gr
from dotenv import load_dotenv
from matplotlib import cm
# Load a local .env (if present) so the settings below can be overridden
# without editing the host/Space environment.
load_dotenv()
# -------------------------
# 0) Env & defaults
# -------------------------
# Which detector implementation to load: "wav2vec2" (default) or "cnn".
BACKEND = os.getenv("DETECTOR_BACKEND", "wav2vec2").strip().lower()  # "wav2vec2" or "cnn"
# Per-backend default checkpoint paths (relative to the repo root).
DEFAULT_W2V_WEIGHTS = "app/models/weights/wav2vec2_classifier.pth"
DEFAULT_CNN_WEIGHTS = "app/models/weights/cnn_melspec.pth"
DEFAULT_WEIGHTS = DEFAULT_W2V_WEIGHTS if BACKEND == "wav2vec2" else DEFAULT_CNN_WEIGHTS
# MODEL_WEIGHTS_PATH env var overrides the backend-specific default above.
MODEL_WEIGHTS_PATH = os.getenv("MODEL_WEIGHTS_PATH", DEFAULT_WEIGHTS).strip()
# -------------------------
# 1) Import your Detector
# -------------------------
def _import_detector(backend):
"""
Import the correct Detector class depending on backend and package layout.
Works both when run as a module ('.inference_*') and as a script ('app.inference_*').
"""
try:
if backend == "wav2vec2":
from .inference_wav2vec import Detector # type: ignore
else:
from .inference import Detector # type: ignore
except Exception:
if backend == "wav2vec2":
from app.inference_wav2vec import Detector # type: ignore
else:
from app.inference import Detector # type: ignore
return Detector
try:
    Detector = _import_detector(BACKEND)
except Exception as e:
    # Fallback dummy to keep the UI alive even if import fails,
    # so you can see the error in the JSON panel.
    #
    # BUGFIX: the message must be captured HERE. Python deletes the
    # `except ... as e` binding when the handler exits, so reading `e`
    # later from inside Detector.__init__ raised NameError at call time
    # instead of showing the import error.
    _DETECTOR_IMPORT_ERROR = f"Detector import failed: {e}"

    class Detector:  # type: ignore
        """Stub detector that surfaces the import error instead of predictions."""

        def __init__(self, *args, **kwargs):
            self._err = _DETECTOR_IMPORT_ERROR

        def predict_proba(self, *args, **kwargs):
            # Shown in the JSON output panel so the failure is visible in the UI.
            return {"error": self._err}

        def explain(self, *args, **kwargs):
            # Blank 128x128 saliency map keeps the heatmap widget rendering.
            return {"cam": np.zeros((128, 128), dtype=np.float32).tolist()}
# Single, shared detector (created lazily so startup is fast on Spaces)
_DET = None


def _get_detector():
    """Return the process-wide Detector, building it on first use."""
    global _DET
    det = _DET
    if det is None:
        det = Detector(weights_path=MODEL_WEIGHTS_PATH)
        _DET = det
    return det
# -------------------------
# 2) Core functions
# -------------------------
def predict_and_explain(audio_path: str | None, source_hint: str):
"""
audio_path: filepath from Gradio (since type='filepath')
source_hint: "Auto", "Microphone", "Upload"
"""
source = (source_hint or "Auto").strip().lower()
if not audio_path or not os.path.exists(audio_path):
return {"error": "No audio received. Record or upload a 2–4s clip."}, None
det = _get_detector()
# Your Detector is expected to accept a file path and optional source hint
proba = det.predict_proba(audio_path, source_hint=source)
exp = det.explain(audio_path, source_hint=source)
# Explanation to heatmap (float [0,1] -> magma RGB uint8)
cam = np.array(exp.get("cam", []), dtype=np.float32)
if cam.ndim == 1:
# if model returned a 1D vector, tile to square-ish map
side = int(np.sqrt(cam.size))
side = max(side, 2)
cam = cam[: side * side].reshape(side, side)
cam = np.clip(cam, 0.0, 1.0)
cam_rgb = (cm.magma(cam)[..., :3] * 255).astype(np.uint8)
# Ensure proba is JSON-serializable
if not isinstance(proba, dict):
proba = {"result": proba}
return proba, cam_rgb
def provenance(audio_path: str | None):
# Stub (you can wire a provenance model or checksum here)
return {"ok": True, "note": "Provenance check not wired in this app.py."}
# -------------------------
# 3) UI
# -------------------------
# NOTE(review): the original file's indentation was lost in transit; the
# widget nesting below is reconstructed from context — confirm against the
# deployed Space layout.
with gr.Blocks(title=f"AI Voice Detector Β· {BACKEND.upper()}") as demo:
    gr.Markdown(f"# πŸ”Ž AI Voice Detector β€” Backend: **{BACKEND.upper()}**")
    gr.Markdown(
        "Record or upload a short clip (~3s). Get probabilities, a label, and an explanation heatmap."
    )
    with gr.Row():
        # type='filepath' hands predict_and_explain a temp-file path
        audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
        with gr.Column():
            src = gr.Radio(choices=["Auto", "Microphone", "Upload"], value="Auto", label="Source")
            btn_predict = gr.Button("Analyze", variant="primary")
            btn_prov = gr.Button("Provenance Check (optional)")
    with gr.Row():
        json_out = gr.JSON(label="Prediction (probabilities + label)")
        cam_out = gr.Image(label="Explanation Heatmap (saliency)")
    prov_out = gr.JSON(label="Provenance Result (if available)")
    # Wire the buttons to the handlers defined above.
    btn_predict.click(predict_and_explain, inputs=[audio_in, src], outputs=[json_out, cam_out])
    btn_prov.click(provenance, inputs=audio_in, outputs=prov_out)
# -------------------------
# 4) Launch (Spaces-friendly)
# -------------------------
if __name__ == "__main__":
    # queue() keeps UI responsive under load; host/port are Spaces-safe and local-friendly
    # (0.0.0.0:7860 is the port Hugging Face Spaces expects).
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)