import os, json
import numpy as np
import torch
import torch.nn.functional as F

from .models.wav2vec_detector import Wav2VecClassifier
from .utils.audio import load_audio, pad_or_trim, TARGET_SR

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- Thresholds & biases ----------
AI_THRESHOLD_DEFAULT = float(os.getenv("DETECTOR_AI_THRESHOLD", "0.60"))
MIC_THRESHOLD = float(os.getenv("DETECTOR_MIC_THRESHOLD", "0.68"))
UPLOAD_THRESHOLD = float(os.getenv("DETECTOR_UPLOAD_THRESHOLD", str(AI_THRESHOLD_DEFAULT)))
AI_LOGIT_BIAS = float(os.getenv("DETECTOR_AI_LOGIT_BIAS", "0.00"))  # add to AI logit globally
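
# Illustrative only: these env vars let operators tune behaviour without code changes, e.g.
#   export DETECTOR_MIC_THRESHOLD=0.72   # require more confidence before labelling mic audio "ai"
#   export DETECTOR_AI_LOGIT_BIAS=0.10   # nudge every prediction slightly toward "ai"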

# ---------- Decision rule ----------
# 'threshold' -> AI if ai_prob >= threshold
# 'argmax'    -> AI if ai_prob > human_prob
# 'hybrid'    -> threshold, but if replay_score >= T1 and ai_prob >= 0.50 -> AI
DECISION_RULE = os.getenv("DECISION_RULE", "threshold").lower()
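
# Worked example (illustrative): a mic clip with ai_prob = 0.62, replay_score = 0.40
# and MIC_THRESHOLD = 0.68 resolves to:
#   'threshold' -> human (0.62 < 0.68)
#   'argmax'    -> ai    (0.62 > 0.38)
#   'hybrid'    -> ai    (replay_score >= REPLAY_T1 and ai_prob >= 0.50)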

# ---------- Replay-attack heuristic ----------
REPLAY_ENABLE = os.getenv("REPLAY_ENABLE", "1") != "0"
REPLAY_AI_BONUS = float(os.getenv("REPLAY_AI_BONUS", "1.2"))
REPLAY_FORCE_LABEL = os.getenv("REPLAY_FORCE_LABEL", "0") == "1"
REPLAY_T1 = float(os.getenv("REPLAY_T1", "0.35"))  # soft start
REPLAY_T2 = float(os.getenv("REPLAY_T2", "0.55"))  # strong replay
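
# For microphone sources, the AI-logit bonus ramps linearly between the two thresholds:
#   replay_score <= T1     -> no bonus
#   T1 < score < T2        -> REPLAY_AI_BONUS * (score - T1) / (T2 - T1)
#   replay_score >= T2     -> full REPLAY_AI_BONUS
# e.g. a score of 0.45 with the defaults adds 1.2 * (0.45 - 0.35) / 0.20 = 0.6 to the AI logit.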

# ---------- DSP helpers ----------
def peak_normalize(y: np.ndarray, peak: float = 0.95, eps: float = 1e-9) -> np.ndarray:
    m = float(np.max(np.abs(y)) + eps)
    return (y / m) * peak if m > 0 else y

def rms_normalize(y: np.ndarray, target_rms: float = 0.03, eps: float = 1e-9) -> np.ndarray:
    rms = float(np.sqrt(np.mean(y**2)) + eps)
    g = target_rms / rms
    return np.clip(y * g, -1.0, 1.0)

def trim_silence(y: np.ndarray, sr: int, thresh_db: float = 40.0, min_ms: int = 30) -> np.ndarray:
    if y.size == 0:
        return y
    win = max(1, int(sr * 0.02))
    pad = max(1, int(sr * (min_ms / 1000.0)))
    energy = np.convolve(y ** 2, np.ones(win) / win, mode="same")
    ref = np.max(energy) + 1e-12
    mask = 10.0 * np.log10(energy / ref + 1e-12) > -thresh_db
    if not np.any(mask):
        return y
    idx = np.where(mask)[0]
    start = max(0, int(idx[0] - pad))
    end = min(len(y), int(idx[-1] + pad))
    return y[start:end]

def noise_gate(y, sr, gate_db=-42.0):
    m = np.max(np.abs(y)) + 1e-9
    thr = m * (10.0 ** (gate_db / 20.0))
    y2 = y.copy()
    y2[np.abs(y2) < thr] = 0.0
    return y2

def bandpass_fft(y: np.ndarray, sr: int, low=100.0, high=3800.0):
    n = int(2 ** np.ceil(np.log2(len(y) + 1)))
    Y = np.fft.rfft(y, n=n)
    freqs = np.fft.rfftfreq(n, d=1.0 / sr)
    mask = (freqs >= low) & (freqs <= high)
    Y_filtered = Y * mask
    y_filt = np.fft.irfft(Y_filtered, n=n)[:len(y)]
    return y_filt.astype(np.float32, copy=False)
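
# Illustrative DSP chain (mirrors Detector._preprocess below; the numbers are its defaults):
#   y = trim_silence(y, sr, 40.0, 30)        # drop leading/trailing low-energy audio
#   y = bandpass_fft(y, sr, 100.0, 3800.0)   # brick-wall FFT band-pass over the speech band
#   y = rms_normalize(y, 0.03)               # bring loudness to a common RMS level
#   y = peak_normalize(y, 0.95)              # then cap the peak amplitude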

# ---------- Replay score ----------
def replay_score(y: np.ndarray, sr: int) -> float:
    if len(y) < sr:
        y = np.pad(y, (0, sr - len(y)))
    N = 4096
    if len(y) < N:
        y = np.pad(y, (0, N - len(y)))
    w = np.hanning(N)
    seg = y[:N] * w
    X = np.abs(np.fft.rfft(seg)) + 1e-9
    cep = np.fft.irfft(np.log(X))
    qmin = max(1, int(0.0003 * sr))
    qmax = min(len(cep) - 1, int(0.0040 * sr))
    cwin = np.abs(cep[qmin:qmax])
    c_peak = float(np.max(cwin))
    c_mean = float(np.mean(cwin) + 1e-9)
    cep_score = np.clip((c_peak - c_mean) / (c_peak + c_mean), 0.0, 1.0)
    freqs = np.fft.rfftfreq(N, 1.0 / sr)
    total = float(np.sum(X))
    hf = float(np.sum(X[freqs >= 5000.0]))
    hf_ratio = hf / (total + 1e-9)
    hf_term = np.clip((0.25 - hf_ratio) / 0.25, 0.0, 1.0)
    return float(np.clip(0.6 * cep_score + 0.4 * hf_term, 0.0, 1.0))
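
# Notes on the heuristic above (a rough cue, not a calibrated probability): a strong
# cepstral peak in the ~0.3-4 ms quefrency window suggests the short echo/comb pattern
# that re-recorded loudspeaker playback tends to leave, and such recordings also tend
# to lose energy above ~5 kHz; the final score blends the two cues 0.6 / 0.4.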

# ---------- Detector ----------
class Detector:
    def __init__(self, weights_path: str, encoder: str | None = None, unfreeze_last: int = 0):
        cfg = None
        js = weights_path.replace(".pth", ".json")
        if os.path.exists(js):
            try:
                with open(js, "r", encoding="utf-8") as f:
                    cfg = json.load(f)
            except Exception:
                cfg = None
        enc = encoder or (cfg.get("encoder") if cfg else "facebook/wav2vec2-base")
        unf = unfreeze_last or (int(cfg.get("unfreeze_last", 0)) if cfg else 0)
        self.model = Wav2VecClassifier(encoder=enc, unfreeze_last=unf).to(DEVICE)
        if weights_path and os.path.exists(weights_path):
            state = torch.load(weights_path, map_location=DEVICE)
            self.model.load_state_dict(state, strict=False)
        self.model.eval()
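
    # The constructor looks for a JSON sidecar next to the weights, e.g. detector.pth
    # plus detector.json (names illustrative) with content like:
    #   {"encoder": "facebook/wav2vec2-base", "unfreeze_last": 2}
    # Only the "encoder" and "unfreeze_last" keys are read here.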

    def _preprocess(self, y: np.ndarray, sr: int, source_hint: str | None):
        y = trim_silence(y, sr, 40.0, 30)
        y = bandpass_fft(y, sr, 100.0, 3800.0)
        if source_hint and source_hint.lower().startswith("micro"):
            y = noise_gate(y, sr, -42.0)
            y = rms_normalize(y, 0.035)
            y = peak_normalize(y, 0.95)
        else:
            y = rms_normalize(y, 0.03)
            y = peak_normalize(y, 0.95)
        y = pad_or_trim(y, duration_s=3.0, sr=sr)
        return y

    def predict_proba(self, wav_bytes_or_path, source_hint: str | None = None):
        y0, sr = load_audio(wav_bytes_or_path, target_sr=TARGET_SR)
        rscore = replay_score(y0, sr) if REPLAY_ENABLE else 0.0
        y = self._preprocess(y0, sr, source_hint)
        x = torch.from_numpy(y).float().unsqueeze(0).to(DEVICE)
        with torch.no_grad():  # inference only; no autograd graph needed
            logits, _ = self.model(x)
            logits = logits.clone()
            logits[:, 1] += AI_LOGIT_BIAS
            # Replay bonus on AI logit (microphone sources only)
            if REPLAY_ENABLE and (source_hint and source_hint.lower().startswith("micro")) and (rscore >= REPLAY_T1):
                ramp = np.clip((rscore - REPLAY_T1) / max(REPLAY_T2 - REPLAY_T1, 1e-6), 0.0, 1.0)
                logits[:, 1] += REPLAY_AI_BONUS * ramp
            probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
        p_h, p_ai = float(probs[0]), float(probs[1])
        thr_source = "mic" if (source_hint and source_hint.lower().startswith("micro")) else "upload"
        thr = MIC_THRESHOLD if thr_source == "mic" else UPLOAD_THRESHOLD
        # Labels by different rules
        label_thresh = "ai" if p_ai >= thr else "human"
        label_argmax = "ai" if p_ai > p_h else "human"
        label_hybrid = label_thresh
        if REPLAY_ENABLE and rscore >= REPLAY_T1 and p_ai >= 0.50:
            label_hybrid = "ai"
        if REPLAY_ENABLE and rscore >= REPLAY_T2 and (source_hint and source_hint.lower().startswith("micro")):
            if REPLAY_FORCE_LABEL or p_ai >= (thr - 0.05):
                label_hybrid = "ai"
        if DECISION_RULE == "argmax":
            label = label_argmax
            rule_used = "argmax"
        elif DECISION_RULE == "hybrid":
            label = label_hybrid
            rule_used = "hybrid(threshold+replay)"
        else:
            label = label_thresh
            rule_used = "threshold"
        return {
            "human": p_h,
            "ai": p_ai,
            "label": label,
            "threshold": float(thr),
            "threshold_source": thr_source,
            "backend": "wav2vec2",
            "source_hint": (source_hint or "auto"),
            "replay_score": float(rscore),
            "decision": rule_used,
            "decision_details": {
                "ai_prob": p_ai,
                "human_prob": p_h,
                "prob_margin": p_ai - p_h,
                "ai_vs_threshold_margin": p_ai - thr,
                "replay_score": rscore,
                "mic_threshold": MIC_THRESHOLD,
                "upload_threshold": UPLOAD_THRESHOLD,
                "force_label_AI": bool(REPLAY_FORCE_LABEL and rscore >= REPLAY_T2),
            },
        }

    def explain(self, wav_bytes_or_path, source_hint: str | None = None):
        self.model.eval()
        y0, sr = load_audio(wav_bytes_or_path, target_sr=TARGET_SR)
        y = self._preprocess(y0, sr, source_hint)
        x = torch.from_numpy(y).float().unsqueeze(0).to(DEVICE)
        x.requires_grad_(True)
        logits, feats = self.model(x)
        if feats.requires_grad:
            feats.retain_grad()  # non-leaf tensors only populate .grad if retain_grad() is called
        logits[:, 1].sum().backward(retain_graph=True)
        if feats.grad is None:
            # Fallback: input-gradient saliency, downsampled to at most 256 columns
            s = x.grad.detach().abs().squeeze(0)
            s = s / (s.max() + 1e-6)
            H = 64
            step = max(1, s.numel() // 256)
            s_small = s[::step][:256].cpu().numpy()
            cam = np.tile(s_small[None, :], (H, 1))
        else:
            # Feature-gradient saliency across feature frames
            g = feats.grad.detach().abs().sum(dim=-1).squeeze(0)
            g = g / (g.max() + 1e-6)
            H = 64
            cam = np.tile(g.cpu().numpy()[None, :], (H, 1))
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-6)
        return {"cam": cam.tolist(), "probs": None}