File size: 2,279 Bytes
e2c61ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import base64
import io
import os
import threading
from typing import Any, Dict, Optional

import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from matplotlib import cm
from PIL import Image
from pydantic import BaseModel

# Detector implementation selected via DETECTOR_BACKEND. "wav2vec2" loads the
# wav2vec2 detector; any other value loads the generic ``inference`` module.
BACKEND = os.getenv("DETECTOR_BACKEND", "wav2vec2").strip().lower()
try:
    # Package-relative import: works when this file is imported as part of a package.
    if BACKEND == "wav2vec2":
        from .inference_wav2vec import Detector  # type: ignore
    else:
        from .inference import Detector  # type: ignore
except ImportError:
    # Only fall back on import-resolution failures (e.g. "attempted relative
    # import with no known parent package"). Any other error raised while the
    # inference module executes propagates instead of being masked by a retry.
    if BACKEND == "wav2vec2":
        from app.inference_wav2vec import Detector  # type: ignore
    else:
        from app.inference import Detector  # type: ignore

DEFAULT_WEIGHTS = (
    "app/models/weights/wav2vec2_classifier.pth"
    if BACKEND == "wav2vec2"
    else "app/models/weights/cnn_melspec.pth"
)
# MODEL_WEIGHTS_PATH overrides the backend-specific default checkpoint.
WEIGHTS = os.getenv("MODEL_WEIGHTS_PATH", DEFAULT_WEIGHTS)
# A single Detector is built at import time and shared by every request.
det = Detector(weights_path=WEIGHTS)

app = FastAPI(title="Voice Guard API", version="1.1.0")

# Comma-separated allow-list of browser origins, read from CORS_ALLOW_ORIGINS.
# The default "*" (any origin) keeps the previous behaviour; set the variable in
# production to restrict it, e.g. "https://app.example.com,https://admin.example.com".
_CORS_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)

class AnalyzeResponse(BaseModel):
    """Response schema for ``POST /analyze``.

    Every field except ``heatmap_b64`` is taken from the dict returned by
    ``det.predict_proba``; the endpoint adds ``heatmap_b64`` itself. Fields
    without a default must be present in that dict, or response validation
    against this model fails.
    """
    human: float  # presumably the detector's "human" score/probability — confirm
    ai: float  # presumably the detector's "AI-generated" score/probability — confirm
    label: str
    threshold: float
    threshold_source: Optional[str] = None
    backend: str
    source_hint: str  # assumed to echo the request's source_hint — verify in Detector
    replay_score: Optional[float] = None
    decision: Optional[str] = None
    decision_details: Optional[Dict[str, Any]] = None
    heatmap_b64: str  # "data:image/png;base64,..." string built by heatmap_png_b64

def heatmap_png_b64(cam: np.ndarray) -> str:
    """Render a saliency map as a base64-encoded PNG data URI.

    Values are clipped to [0, 1], coloured with matplotlib's ``magma``
    colormap (the alpha channel is discarded) and encoded as PNG.

    Args:
        cam: Array of saliency values; anything outside [0, 1] is clipped.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    unit = np.clip(cam, 0.0, 1.0).astype(np.float32)
    rgba = cm.magma(unit)
    pixels = (rgba[..., :3] * 255).astype(np.uint8)
    with io.BytesIO() as stream:
        Image.fromarray(pixels).save(stream, format="PNG")
        encoded = base64.b64encode(stream.getvalue()).decode("ascii")
    prefix = "data:image/png;base64,"
    return prefix + encoded

# Serializes calls into the shared Detector. The previous handler ran inference
# directly on the event loop, so detector calls never overlapped; running them
# in a worker thread under this lock keeps that guarantee without stalling the
# event loop (and therefore every other request) while a clip is processed.
_INFERENCE_LOCK = threading.Lock()


def _run_inference(raw: bytes, source_hint: str) -> Dict[str, Any]:
    """Score *raw* and attach its heatmap. This call blocks, so run it off the event loop."""
    with _INFERENCE_LOCK:
        proba = det.predict_proba(raw, source_hint=source_hint)
        explanation = det.explain(raw, source_hint=source_hint)
    cam = np.array(explanation["cam"], dtype=np.float32)
    return {
        **proba,
        "heatmap_b64": heatmap_png_b64(cam),
    }


@app.post("/analyze", response_model=AnalyzeResponse)
async def analyze(file: UploadFile = File(...), source_hint: str = Form("auto")):
    """Classify an uploaded audio clip and return scores plus a heatmap.

    Args:
        file: The uploaded clip. Its raw bytes are passed to the detector
            unchanged; decoding is left to ``Detector``.
        source_hint: Hint forwarded to the detector (default ``"auto"``).

    Returns:
        The fields from ``det.predict_proba`` plus ``heatmap_b64``, validated
        against ``AnalyzeResponse``.

    Raises:
        HTTPException: 400 if the uploaded file is empty.
    """
    # NOTE(review): the upload is read fully into memory with no size cap;
    # consider enforcing a maximum upload size.
    raw = await file.read()
    if not raw:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    # Detector calls are synchronous and CPU-bound, so run them in a thread.
    return await run_in_threadpool(_run_inference, raw, source_hint)

@app.get("/health")
def health():
    """Report that the service is up, along with the configured ``BACKEND`` name."""
    payload = {"ok": True, "backend": BACKEND}
    return payload