import os
import json

import numpy as np
import torch
import torch.nn.functional as F

from .models.wav2vec_detector import Wav2VecClassifier
from .utils.audio import load_audio, pad_or_trim, TARGET_SR

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- Thresholds & biases ----------
# All tunables in this file are read once, at import time, from environment variables.
AI_THRESHOLD_DEFAULT = float(os.getenv("DETECTOR_AI_THRESHOLD", "0.60"))
MIC_THRESHOLD = float(os.getenv("DETECTOR_MIC_THRESHOLD", "0.68"))
UPLOAD_THRESHOLD = float(os.getenv("DETECTOR_UPLOAD_THRESHOLD", str(AI_THRESHOLD_DEFAULT)))
AI_LOGIT_BIAS = float(os.getenv("DETECTOR_AI_LOGIT_BIAS", "0.00"))  # added to the AI logit globally

# ---------- Decision rule ----------
# 'threshold' -> AI if ai_prob >= threshold
# 'argmax'    -> AI if ai_prob > human_prob
# 'hybrid'    -> threshold, but AI if replay_score >= REPLAY_T1 and ai_prob >= 0.50;
#                additionally, for microphone input with replay_score >= REPLAY_T2,
#                AI if REPLAY_FORCE_LABEL is set or ai_prob >= threshold - 0.05
# Any other value falls back to 'threshold'.
DECISION_RULE = os.getenv("DECISION_RULE", "threshold").lower()

# ---------- Replay-attack heuristic ----------
REPLAY_ENABLE = os.getenv("REPLAY_ENABLE", "1") != "0"
REPLAY_AI_BONUS = float(os.getenv("REPLAY_AI_BONUS", "1.2"))
REPLAY_FORCE_LABEL = os.getenv("REPLAY_FORCE_LABEL", "0") == "1"
REPLAY_T1 = float(os.getenv("REPLAY_T1", "0.35"))  # soft start
REPLAY_T2 = float(os.getenv("REPLAY_T2", "0.55"))  # strong replay

# Encoder used when neither the caller nor the sidecar JSON names one.
_DEFAULT_ENCODER = "facebook/wav2vec2-base"


# ---------- DSP helpers ----------
def peak_normalize(y: np.ndarray, peak: float = 0.95, eps: float = 1e-9) -> np.ndarray:
    """Scale ``y`` so that its largest absolute sample equals ``peak``.

    ``eps`` keeps the division finite for an all-zero signal. An empty array is
    returned unchanged (``np.max`` raises on empty input).
    """
    if y.size == 0:
        return y
    m = float(np.max(np.abs(y))) + eps
    return (y / m) * peak if m > 0 else y


def rms_normalize(y: np.ndarray, target_rms: float = 0.03, eps: float = 1e-9) -> np.ndarray:
    """Scale ``y`` to RMS ``target_rms``, then hard-clip to [-1, 1].

    An empty array is returned unchanged (its mean would be NaN).
    """
    if y.size == 0:
        return y
    rms = float(np.sqrt(np.mean(y ** 2))) + eps
    return np.clip(y * (target_rms / rms), -1.0, 1.0)


def trim_silence(y: np.ndarray, sr: int, thresh_db: float = 40.0, min_ms: int = 30) -> np.ndarray:
    """Cut leading/trailing low-energy audio.

    A 20 ms moving average of the squared signal is compared to its own maximum;
    samples more than ``thresh_db`` dB below that maximum count as silence. The kept
    span runs from the first to the last loud sample, widened by ``min_ms``
    milliseconds on each side. If nothing is loud, ``y`` is returned unchanged.
    """
    if y.size == 0:
        return y
    win = max(1, int(sr * 0.02))
    pad = max(1, int(sr * (min_ms / 1000.0)))  # min_ms acts as a padding margin
    energy = np.convolve(y ** 2, np.ones(win) / win, mode="same")
    ref = np.max(energy) + 1e-12
    mask = 10.0 * np.log10(energy / ref + 1e-12) > -thresh_db
    if not np.any(mask):
        return y
    idx = np.where(mask)[0]
    start = max(0, int(idx[0] - pad))
    end = min(len(y), int(idx[-1] + pad))
    return y[start:end]


def noise_gate(y: np.ndarray, sr: int, gate_db: float = -42.0) -> np.ndarray:
    """Zero every sample whose magnitude is below ``gate_db`` dB relative to the peak.

    ``sr`` is accepted for signature symmetry with the other helpers but is unused.
    Returns a copy; the input array is not modified.
    """
    m = np.max(np.abs(y)) + 1e-9
    thr = m * (10.0 ** (gate_db / 20.0))
    y2 = y.copy()
    y2[np.abs(y2) < thr] = 0.0
    return y2


def bandpass_fft(y: np.ndarray, sr: int, low: float = 100.0, high: float = 3800.0) -> np.ndarray:
    """Keep only FFT bins in ``[low, high]`` Hz (ideal brick-wall mask).

    The signal is zero-padded to a power of two longer than ``len(y)`` before the
    FFT, and the result is cropped back to ``len(y)`` samples as float32.
    """
    n = int(2 ** np.ceil(np.log2(len(y) + 1)))
    spectrum = np.fft.rfft(y, n=n)
    freqs = np.fft.rfftfreq(n, d=1.0 / sr)
    keep = (freqs >= low) & (freqs <= high)
    y_filt = np.fft.irfft(spectrum * keep, n=n)[:len(y)]
    return y_filt.astype(np.float32, copy=False)


# ---------- Replay score ----------
def replay_score(y: np.ndarray, sr: int) -> float:
    """Heuristic replay-attack score in [0, 1] (higher = more replay-like).

    Only the first 4096 samples are analysed (after zero-padding short clips).
    The score is ``0.6 * cep_score + 0.4 * hf_term`` where

    * ``cep_score`` measures how far the strongest real-cepstrum coefficient in the
      0.3-4 ms quefrency window stands above the window mean, and
    * ``hf_term`` grows as the share of spectral magnitude at >= 5 kHz falls below 25%.

    NOTE(review): the interpretation (echo-like cepstral peaks and loudspeaker
    high-frequency loss indicating replay) is presumed from the thresholds, not
    established here.
    """
    n_fft = 4096
    if len(y) < sr:
        y = np.pad(y, (0, sr - len(y)))
    if len(y) < n_fft:
        y = np.pad(y, (0, n_fft - len(y)))
    seg = y[:n_fft] * np.hanning(n_fft)
    mag = np.abs(np.fft.rfft(seg)) + 1e-9

    # Cepstral peakiness inside the 0.3 ms .. 4 ms quefrency window.
    cep = np.fft.irfft(np.log(mag))
    qmin = max(1, int(0.0003 * sr))
    qmax = min(len(cep) - 1, int(0.0040 * sr))
    if qmax > qmin:
        cwin = np.abs(cep[qmin:qmax])
        c_peak = float(np.max(cwin))
        c_mean = float(np.mean(cwin) + 1e-9)
        cep_score = np.clip((c_peak - c_mean) / (c_peak + c_mean), 0.0, 1.0)
    else:
        # Sample rate too low for the window to contain any coefficient.
        cep_score = 0.0

    # High-frequency energy share (named `freqs` to avoid shadowing module-level F).
    freqs = np.fft.rfftfreq(n_fft, 1.0 / sr)
    total = float(np.sum(mag))
    hf = float(np.sum(mag[freqs >= 5000.0]))
    hf_ratio = hf / (total + 1e-9)
    hf_term = np.clip((0.25 - hf_ratio) / 0.25, 0.0, 1.0)
    return float(np.clip(0.6 * cep_score + 0.4 * hf_term, 0.0, 1.0))


def _is_mic(source_hint: str | None) -> bool:
    """Return True when ``source_hint`` starts with "micro" (case-insensitive)."""
    return bool(source_hint) and source_hint.lower().startswith("micro")


# ---------- Detector ----------
class Detector:
    """Binary human-vs-AI voice classifier wrapping ``Wav2VecClassifier``.

    Output index 0 of the model's logits is treated as "human" and index 1 as "ai".
    """

    def __init__(self, weights_path: str, encoder: str | None = None, unfreeze_last: int = 0):
        """Build the model and load weights if ``weights_path`` exists.

        Args:
            weights_path: checkpoint path. For ``foo.pth``, an optional sidecar
                ``foo.json`` may supply ``"encoder"`` and ``"unfreeze_last"``.
            encoder: encoder name; overrides the sidecar. Falls back to
                ``facebook/wav2vec2-base``.
            unfreeze_last: overrides the sidecar when non-zero.
        """
        cfg = self._load_sidecar_config(weights_path)
        enc = encoder or cfg.get("encoder") or _DEFAULT_ENCODER
        # Parenthesisation matters: an explicit argument must win even without a sidecar.
        unf = unfreeze_last or int(cfg.get("unfreeze_last", 0))
        self.model = Wav2VecClassifier(encoder=enc, unfreeze_last=unf).to(DEVICE)
        if weights_path and os.path.exists(weights_path):
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
            # (consider weights_only=True on torch versions that support it).
            state = torch.load(weights_path, map_location=DEVICE)
            # strict=False: missing/unexpected keys are ignored without any report.
            self.model.load_state_dict(state, strict=False)
        self.model.eval()

    @staticmethod
    def _load_sidecar_config(weights_path: str | None) -> dict:
        """Return the JSON dict stored next to a ``.pth`` checkpoint, or ``{}``.

        A missing, unreadable, malformed, or non-object sidecar yields ``{}`` so the
        constructor falls back to its defaults (loading stays best-effort).
        """
        if not weights_path or not weights_path.endswith(".pth"):
            return {}
        # Replace only the suffix, not ".pth" occurrences elsewhere in the path.
        sidecar = weights_path[: -len(".pth")] + ".json"
        if not os.path.exists(sidecar):
            return {}
        try:
            with open(sidecar, "r", encoding="utf-8") as f:
                cfg = json.load(f)
        except (OSError, ValueError):  # ValueError covers JSON and UTF-8 decode errors
            return {}
        return cfg if isinstance(cfg, dict) else {}

    def _preprocess(self, y: np.ndarray, sr: int, source_hint: str | None):
        """Trim silence, band-limit to 100-3800 Hz, normalise, and fit to 3 s.

        Microphone input is additionally noise-gated and normalised to a slightly
        higher RMS (0.035 instead of 0.03).
        """
        y = trim_silence(y, sr, 40.0, 30)
        y = bandpass_fft(y, sr, 100.0, 3800.0)
        if _is_mic(source_hint):
            y = noise_gate(y, sr, -42.0)
            y = rms_normalize(y, 0.035)
        else:
            y = rms_normalize(y, 0.03)
        y = peak_normalize(y, 0.95)
        y = pad_or_trim(y, duration_s=3.0, sr=sr)
        return y

    @torch.inference_mode()
    def predict_proba(self, wav_bytes_or_path, source_hint: str | None = None):
        """Classify one clip and return probabilities, label and diagnostics.

        The replay score is computed on the raw decoded audio (before preprocessing).
        For microphone input it can raise the AI logit by up to ``REPLAY_AI_BONUS``,
        ramping linearly between ``REPLAY_T1`` and ``REPLAY_T2``. The returned
        ``label`` follows ``DECISION_RULE`` (see the module-level comment).
        """
        is_mic = _is_mic(source_hint)
        y0, sr = load_audio(wav_bytes_or_path, target_sr=TARGET_SR)
        rscore = replay_score(y0, sr) if REPLAY_ENABLE else 0.0
        y = self._preprocess(y0, sr, source_hint)
        x = torch.from_numpy(y).float().unsqueeze(0).to(DEVICE)
        logits, _ = self.model(x)
        logits = logits.clone()  # never mutate a tensor the model may still own
        logits[:, 1] += AI_LOGIT_BIAS

        # Replay bonus on the AI logit (microphone input only).
        replay_active = REPLAY_ENABLE and is_mic
        if replay_active and rscore >= REPLAY_T1:
            span = max(REPLAY_T2 - REPLAY_T1, 1e-6)
            ramp = float(np.clip((rscore - REPLAY_T1) / span, 0.0, 1.0))
            logits[:, 1] += REPLAY_AI_BONUS * ramp

        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
        p_h, p_ai = float(probs[0]), float(probs[1])

        thr_source = "mic" if is_mic else "upload"
        thr = MIC_THRESHOLD if is_mic else UPLOAD_THRESHOLD

        # Labels under each rule.
        label_thresh = "ai" if p_ai >= thr else "human"
        label_argmax = "ai" if p_ai > p_h else "human"

        label_hybrid = label_thresh
        if REPLAY_ENABLE and rscore >= REPLAY_T1 and p_ai >= 0.50:
            label_hybrid = "ai"
        strong_replay = replay_active and rscore >= REPLAY_T2
        if strong_replay and (REPLAY_FORCE_LABEL or p_ai >= (thr - 0.05)):
            label_hybrid = "ai"

        if DECISION_RULE == "argmax":
            label, rule_used = label_argmax, "argmax"
        elif DECISION_RULE == "hybrid":
            label, rule_used = label_hybrid, "hybrid(threshold+replay)"
        else:
            label, rule_used = label_thresh, "threshold"

        return {
            "human": p_h,
            "ai": p_ai,
            "label": label,
            "threshold": float(thr),
            "threshold_source": thr_source,
            "backend": "wav2vec2",
            "source_hint": (source_hint or "auto"),
            "replay_score": float(rscore),
            "decision": rule_used,
            "decision_details": {
                "ai_prob": p_ai,
                "human_prob": p_h,
                "prob_margin": p_ai - p_h,
                "ai_vs_threshold_margin": p_ai - thr,
                "replay_score": rscore,
                "mic_threshold": MIC_THRESHOLD,
                "upload_threshold": UPLOAD_THRESHOLD,
                # True only when the forced-AI condition of the hybrid rule is met
                # (microphone input, replay enabled, score >= REPLAY_T2).
                "force_label_AI": bool(REPLAY_FORCE_LABEL and strong_replay),
            },
        }

    def explain(self, wav_bytes_or_path, source_hint: str | None = None):
        """Return a gradient saliency map for the AI logit as a 64-row CAM.

        If the model's second output is a ``(batch, time, dim)`` tensor on the
        autograd graph (presumably frame features; confirm with
        ``Wav2VecClassifier``), saliency is ``|d logit_ai / d feats|`` summed over
        ``dim``, giving one column per frame. Otherwise the input-waveform gradient
        is subsampled to at most 256 columns. The single row is tiled 64 times and
        min-max normalised to [0, 1].

        Gradients are taken with ``torch.autograd.grad`` so model parameters'
        ``.grad`` fields are left untouched.
        """
        self.model.eval()
        y0, sr = load_audio(wav_bytes_or_path, target_sr=TARGET_SR)
        y = self._preprocess(y0, sr, source_hint)
        x = torch.from_numpy(y).float().unsqueeze(0).to(DEVICE)
        x.requires_grad_(True)
        logits, feats = self.model(x)
        target = logits[:, 1].sum()

        use_feats = (
            isinstance(feats, torch.Tensor) and feats.requires_grad and feats.dim() == 3
        )
        inputs = (x, feats) if use_feats else (x,)
        grads = torch.autograd.grad(target, inputs, allow_unused=True)
        feat_grad = grads[1] if use_feats else None

        H = 64
        if feat_grad is not None:
            g = feat_grad.detach().abs().sum(dim=-1).squeeze(0)
            g = g / (g.max() + 1e-6)
            cam = np.tile(g.cpu().numpy()[None, :], (H, 1))
        else:
            x_grad = grads[0]
            if x_grad is None:
                raise RuntimeError(
                    "AI logit does not depend on the input waveform; cannot build saliency"
                )
            s = x_grad.detach().abs().squeeze(0)
            s = s / (s.max() + 1e-6)
            step = max(1, s.numel() // 256)
            s_small = s[::step][:256].cpu().numpy()
            cam = np.tile(s_small[None, :], (H, 1))
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-6)
        return {"cam": cam.tolist(), "probs": None}