# Voice-guard/app/inference_wav2vec.py
import os
import json
import numpy as np
import torch
import torch.nn.functional as F
from .models.wav2vec_detector import Wav2VecClassifier
from .utils.audio import load_audio, pad_or_trim, TARGET_SR
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ---------- Thresholds & biases ----------
AI_THRESHOLD_DEFAULT = float(os.getenv("DETECTOR_AI_THRESHOLD", "0.60"))
MIC_THRESHOLD = float(os.getenv("DETECTOR_MIC_THRESHOLD", "0.68"))
UPLOAD_THRESHOLD = float(os.getenv("DETECTOR_UPLOAD_THRESHOLD", str(AI_THRESHOLD_DEFAULT)))
AI_LOGIT_BIAS = float(os.getenv("DETECTOR_AI_LOGIT_BIAS", "0.00")) # add to AI logit globally
# ---------- Decision rule ----------
# 'threshold' -> AI if ai_prob >= threshold
# 'argmax' -> AI if ai_prob > human_prob
# 'hybrid' -> threshold, but if replay_score >= T1 and ai_prob >= 0.50 -> AI
DECISION_RULE = os.getenv("DECISION_RULE", "threshold").lower()
# ---------- Replay-attack heuristic ----------
REPLAY_ENABLE = os.getenv("REPLAY_ENABLE", "1") != "0"
REPLAY_AI_BONUS = float(os.getenv("REPLAY_AI_BONUS", "1.2"))
REPLAY_FORCE_LABEL = os.getenv("REPLAY_FORCE_LABEL", "0") == "1"
REPLAY_T1 = float(os.getenv("REPLAY_T1", "0.35")) # soft start
REPLAY_T2 = float(os.getenv("REPLAY_T2", "0.55")) # strong replay
# ---------- DSP helpers ----------
def peak_normalize(y: np.ndarray, peak: float = 0.95, eps: float = 1e-9) -> np.ndarray:
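    """Scale so the absolute peak sits at `peak`; `eps` guards against silence."""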
m = float(np.max(np.abs(y)) + eps)
return (y / m) * peak if m > 0 else y
def rms_normalize(y: np.ndarray, target_rms: float = 0.03, eps: float = 1e-9) -> np.ndarray:
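    """Scale to a target RMS level, clipping to [-1, 1] to avoid overshoot."""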
rms = float(np.sqrt(np.mean(y**2)) + eps)
g = target_rms / rms
return np.clip(y * g, -1.0, 1.0)
def trim_silence(y: np.ndarray, sr: int, thresh_db: float = 40.0, min_ms: int = 30) -> np.ndarray:
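    """Trim leading/trailing silence below `thresh_db` relative to the peak of a
    20 ms moving-average energy envelope, keeping `min_ms` of padding on each side."""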
if y.size == 0: return y
win = max(1, int(sr * 0.02))
pad = max(1, int(sr * (min_ms / 1000.0)))
energy = np.convolve(y ** 2, np.ones(win) / win, mode="same")
ref = np.max(energy) + 1e-12
mask = 10.0 * np.log10(energy / ref + 1e-12) > -thresh_db
if not np.any(mask): return y
idx = np.where(mask)[0]
start = max(0, int(idx[0] - pad))
end = min(len(y), int(idx[-1] + pad))
return y[start:end]
def noise_gate(y, sr, gate_db=-42.0):
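    """Zero out samples below `gate_db` relative to the peak. `sr` is unused but
    kept so all DSP helpers share the (y, sr, ...) call shape."""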
m = np.max(np.abs(y)) + 1e-9
thr = m * (10.0 ** (gate_db / 20.0))
y2 = y.copy()
y2[np.abs(y2) < thr] = 0.0
return y2
def bandpass_fft(y: np.ndarray, sr: int, low=100.0, high=3800.0):
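    """Crude brick-wall band-pass via rFFT masking; the 100-3800 Hz default keeps
    roughly the telephone speech band."""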
n = int(2 ** np.ceil(np.log2(len(y) + 1)))
Y = np.fft.rfft(y, n=n)
freqs = np.fft.rfftfreq(n, d=1.0/sr)
mask = (freqs >= low) & (freqs <= high)
Y_filtered = Y * mask
y_filt = np.fft.irfft(Y_filtered, n=n)[:len(y)]
return y_filt.astype(np.float32, copy=False)
# ---------- Replay score ----------
def replay_score(y: np.ndarray, sr: int) -> float:
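    """Heuristic replay/loudspeaker score in [0, 1].

    Combines two cues from the first 4096-sample frame: a cepstral peak in the
    0.3-4.0 ms quefrency window (comb-filtering from play-and-rerecord) and a
    high-frequency deficit (energy above 5 kHz rolled off by small speakers).
    """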
if len(y) < sr:
y = np.pad(y, (0, sr - len(y)))
N = 4096
if len(y) < N:
y = np.pad(y, (0, N - len(y)))
w = np.hanning(N)
seg = y[:N] * w
X = np.abs(np.fft.rfft(seg)) + 1e-9
cep = np.fft.irfft(np.log(X))
qmin = max(1, int(0.0003 * sr))
qmax = min(len(cep) - 1, int(0.0040 * sr))
cwin = np.abs(cep[qmin:qmax])
    c_peak = float(np.max(cwin))
    c_mean = float(np.mean(cwin) + 1e-9)
cep_score = np.clip((c_peak - c_mean) / (c_peak + c_mean), 0.0, 1.0)
F = np.fft.rfftfreq(N, 1.0 / sr)
total = float(np.sum(X))
hf = float(np.sum(X[F >= 5000.0]))
hf_ratio = hf / (total + 1e-9)
hf_term = np.clip((0.25 - hf_ratio) / 0.25, 0.0, 1.0)
return float(np.clip(0.6 * cep_score + 0.4 * hf_term, 0.0, 1.0))
# ---------- Detector ----------
class Detector:
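    """Wav2Vec2-based human-vs-AI voice classifier with replay-aware decisions.

    An optional JSON config sitting next to the weights file can supply the
    encoder name and how many encoder layers were unfrozen during training.
    """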
def __init__(self, weights_path: str, encoder: str | None = None, unfreeze_last: int = 0):
cfg = None
js = weights_path.replace(".pth", ".json")
if os.path.exists(js):
try:
with open(js, "r", encoding="utf-8") as f:
cfg = json.load(f)
except Exception:
cfg = None
        enc = encoder or ((cfg or {}).get("encoder") or "facebook/wav2vec2-base")
        # explicit argument wins; otherwise fall back to the JSON config, then 0
        unf = unfreeze_last or (int(cfg.get("unfreeze_last", 0)) if cfg else 0)
self.model = Wav2VecClassifier(encoder=enc, unfreeze_last=unf).to(DEVICE)
if weights_path and os.path.exists(weights_path):
state = torch.load(weights_path, map_location=DEVICE)
self.model.load_state_dict(state, strict=False)
self.model.eval()
def _preprocess(self, y: np.ndarray, sr: int, source_hint: str | None):
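        """Silence-trim, band-pass, and level-normalize (with an extra noise gate
        for mic input), then pad/trim to a fixed 3 s window."""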
y = trim_silence(y, sr, 40.0, 30)
y = bandpass_fft(y, sr, 100.0, 3800.0)
if source_hint and source_hint.lower().startswith("micro"):
y = noise_gate(y, sr, -42.0)
y = rms_normalize(y, 0.035)
y = peak_normalize(y, 0.95)
else:
y = rms_normalize(y, 0.03)
y = peak_normalize(y, 0.95)
y = pad_or_trim(y, duration_s=3.0, sr=sr)
return y
@torch.inference_mode()
def predict_proba(self, wav_bytes_or_path, source_hint: str | None = None):
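        """Return class probabilities plus a label under the configured rule.

        The replay bonus is added to the AI logit only for mic input, ramping
        linearly between REPLAY_T1 and REPLAY_T2; DECISION_RULE selects between
        'threshold', 'argmax', and 'hybrid' labeling.
        """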
y0, sr = load_audio(wav_bytes_or_path, target_sr=TARGET_SR)
rscore = replay_score(y0, sr) if REPLAY_ENABLE else 0.0
y = self._preprocess(y0, sr, source_hint)
x = torch.from_numpy(y).float().unsqueeze(0).to(DEVICE)
logits, _ = self.model(x)
logits = logits.clone()
logits[:, 1] += AI_LOGIT_BIAS
# Replay bonus on AI logit
if REPLAY_ENABLE and (source_hint and source_hint.lower().startswith("micro")) and (rscore >= REPLAY_T1):
ramp = np.clip((rscore - REPLAY_T1) / max(REPLAY_T2 - REPLAY_T1, 1e-6), 0.0, 1.0)
logits[:, 1] += REPLAY_AI_BONUS * ramp
probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
p_h, p_ai = float(probs[0]), float(probs[1])
thr_source = "mic" if (source_hint and source_hint.lower().startswith("micro")) else "upload"
thr = MIC_THRESHOLD if thr_source == "mic" else UPLOAD_THRESHOLD
# Labels by different rules
label_thresh = "ai" if p_ai >= thr else "human"
label_argmax = "ai" if p_ai > p_h else "human"
label_hybrid = label_thresh
if REPLAY_ENABLE and rscore >= REPLAY_T1 and p_ai >= 0.50:
label_hybrid = "ai"
if REPLAY_ENABLE and rscore >= REPLAY_T2 and (source_hint and source_hint.lower().startswith("micro")):
if REPLAY_FORCE_LABEL or p_ai >= (thr - 0.05):
label_hybrid = "ai"
if DECISION_RULE == "argmax":
label = label_argmax
rule_used = "argmax"
elif DECISION_RULE == "hybrid":
label = label_hybrid
rule_used = "hybrid(threshold+replay)"
else:
label = label_thresh
rule_used = "threshold"
return {
"human": p_h,
"ai": p_ai,
"label": label,
"threshold": float(thr),
"threshold_source": thr_source,
"backend": "wav2vec2",
"source_hint": (source_hint or "auto"),
"replay_score": float(rscore),
"decision": rule_used,
"decision_details": {
"ai_prob": p_ai,
"human_prob": p_h,
"prob_margin": p_ai - p_h,
"ai_vs_threshold_margin": p_ai - thr,
"replay_score": rscore,
"mic_threshold": MIC_THRESHOLD,
"upload_threshold": UPLOAD_THRESHOLD,
"force_label_AI": bool(REPLAY_FORCE_LABEL and rscore >= REPLAY_T2),
},
}
def explain(self, wav_bytes_or_path, source_hint: str | None = None):
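        """Gradient-based saliency: backprop the AI logit and build a 64-row
        heatmap from encoder-feature gradients, falling back to downsampled
        input-sample gradients when feature gradients are unavailable."""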
self.model.eval()
y0, sr = load_audio(wav_bytes_or_path, target_sr=TARGET_SR)
y = self._preprocess(y0, sr, source_hint)
x = torch.from_numpy(y).float().unsqueeze(0).to(DEVICE)
x.requires_grad_(True)
        logits, feats = self.model(x)
        if feats.requires_grad:
            feats.retain_grad()  # non-leaf tensor: .grad is only populated if retained
        logits[:, 1].sum().backward(retain_graph=True)
if feats.grad is None:
s = x.grad.detach().abs().squeeze(0)
s = s / (s.max() + 1e-6)
H = 64
step = max(1, s.numel() // 256)
s_small = s[::step][:256].cpu().numpy()
cam = np.tile(s_small[None, :], (H, 1))
else:
g = feats.grad.detach().abs().sum(dim=-1).squeeze(0)
g = g / (g.max() + 1e-6)
H = 64
cam = np.tile(g.cpu().numpy()[None, :], (H, 1))
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-6)
return {"cam": cam.tolist(), "probs": None}