Spaces:
Sleeping
Sleeping
File size: 2,279 Bytes
e2c61ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import base64
import io
import os
from typing import Any, Dict, Optional

import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from matplotlib import cm
from PIL import Image
from pydantic import BaseModel
# Backend selection: DETECTOR_BACKEND picks the Detector implementation.
# Any value other than "wav2vec2" (compared case-insensitively) selects the
# generic `inference` module.
BACKEND = os.getenv("DETECTOR_BACKEND", "wav2vec2").lower()

# Prefer the package-relative import and fall back to the absolute `app.`
# path when this module is loaded without a parent package (e.g. executed
# directly). Only ImportError triggers the fallback, so any other error raised
# while the inference module itself is executing propagates as-is instead of
# being hidden behind a second import attempt.
try:
    if BACKEND == "wav2vec2":
        from .inference_wav2vec import Detector  # type: ignore
    else:
        from .inference import Detector  # type: ignore
except ImportError:
    if BACKEND == "wav2vec2":
        from app.inference_wav2vec import Detector  # type: ignore
    else:
        from app.inference import Detector  # type: ignore

# The default checkpoint depends on the backend; MODEL_WEIGHTS_PATH overrides it.
DEFAULT_WEIGHTS = (
    "app/models/weights/wav2vec2_classifier.pth"
    if BACKEND == "wav2vec2"
    else "app/models/weights/cnn_melspec.pth"
)
WEIGHTS = os.getenv("MODEL_WEIGHTS_PATH", DEFAULT_WEIGHTS)

# Single module-level detector instance, created once at import time and used
# by the request handlers in this module.
det = Detector(weights_path=WEIGHTS)
# FastAPI application object; the route decorators in this module register on it.
app = FastAPI(version="1.1.0", title="Voice Guard API")

# CORS is fully open (any origin, method and header) -- tighten in prod.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
class AnalyzeResponse(BaseModel):
    """Response schema for ``POST /analyze``.

    Every field except ``heatmap_b64`` is populated from the dict returned by
    ``det.predict_proba``; the precise meaning of those values is defined by
    the detector backend and is not visible in this module.
    """

    # Per-class scores -- presumably probabilities; TODO confirm with Detector.
    human: float
    ai: float

    # Classification outcome and the threshold reported alongside it.
    # NOTE(review): how `threshold` relates to `label` is decided in the
    # detector, not here.
    label: str
    threshold: float
    threshold_source: Optional[str] = None

    # Request/runtime context reported back to the client.
    backend: str
    source_hint: str

    # Optional extras; they default to None when the detector omits them.
    replay_score: Optional[float] = None
    decision: Optional[str] = None
    decision_details: Optional[Dict[str, Any]] = None

    # PNG data URI ("data:image/png;base64,...") built by heatmap_png_b64.
    heatmap_b64: str
def heatmap_png_b64(cam: np.ndarray) -> str:
    """Render an activation map as a magma-coloured PNG data URI.

    Args:
        cam: Array of activation values; anything outside [0, 1] is clipped.
            Presumably a 2-D map -- TODO confirm against ``Detector.explain``.

    Returns:
        A ``data:image/png;base64,...`` string containing the encoded image.
    """
    unit_range = np.clip(cam, 0.0, 1.0).astype(np.float32)

    # The colormap yields RGBA floats; keep the first three channels (RGB)
    # and scale them to 8-bit values.
    rgba = cm.magma(unit_range)
    pixels = (rgba[..., :3] * 255).astype(np.uint8)

    with io.BytesIO() as png_buffer:
        Image.fromarray(pixels).save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())

    return "data:image/png;base64," + encoded.decode("ascii")
@app.post("/analyze", response_model=AnalyzeResponse)
async def analyze(file: UploadFile = File(...), source_hint: str = Form("auto")):
    """Score an uploaded file with the detector and attach an explanation heatmap.

    Args:
        file: Multipart upload; its entire content is read into memory.
        source_hint: Form field forwarded unchanged to the detector
            (defaults to ``"auto"``).

    Returns:
        The dict produced by ``det.predict_proba`` extended with a
        ``heatmap_b64`` entry: the PNG data URI rendered from the ``"cam"``
        value of ``det.explain``.

    Raises:
        HTTPException: 400 when the uploaded file is empty.
    """
    raw = await file.read()
    if not raw:
        # An empty payload cannot be scored; reject it as a client error
        # rather than handing it to the detector.
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # NOTE(review): both detector calls are synchronous and execute on the
    # event loop; if they are slow, consider offloading them to a worker
    # thread once the detector's thread-safety has been confirmed.
    proba = det.predict_proba(raw, source_hint=source_hint)
    explanation = det.explain(raw, source_hint=source_hint)
    cam = np.array(explanation["cam"], dtype=np.float32)

    return {
        **proba,
        "heatmap_b64": heatmap_png_b64(cam),
    }
@app.get("/health")
def health():
    """Liveness probe: report that the service is up and which backend is configured."""
    status = {"ok": True, "backend": BACKEND}
    return status