"""Train the TinyMelCNN human-vs-AI voice classifier on log-mel spectrograms.

Expects a data directory with human/*.wav and ai/*.wav subfolders and applies
version-robust audiomentations chains (replay-aware for the "ai" class).
"""
import os
import argparse
import random
from pathlib import Path
from contextlib import nullcontext
import importlib.util

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader

# ---------- Local imports ----------
try:
    from .models.cnn_melspec import TinyMelCNN
    from .utils.audio import load_audio, pad_or_trim, logmel, TARGET_SR
except ImportError:
    from app.models.cnn_melspec import TinyMelCNN
    from app.utils.audio import load_audio, pad_or_trim, logmel, TARGET_SR

# ---------- Augmentations (robust across versions) ----------
from audiomentations import (
    Compose, AddGaussianNoise, TimeStretch, PitchShift, BandPassFilter
)


def make_gain(min_db, max_db, p):
    """Handle both min_gain_in_db/max_gain_in_db and min_gain_db/max_gain_db."""
    from audiomentations import Gain as _Gain
    try:
        return _Gain(min_gain_in_db=min_db, max_gain_in_db=max_db, p=p)
    except TypeError:
        return _Gain(min_gain_db=min_db, max_gain_db=max_db, p=p)


def make_clipping(p=0.3):
    """
    Build ClippingDistortion across versions.
    Newer: min_percent/max_percent (0..20 typical)
    Older: min_percentile_threshold/max_percentile_threshold in [0..100]
    Returns None if not available.
    """
    try:
        from audiomentations import ClippingDistortion as _Clip
    except Exception:
        return None
    # Try the newer signature first
    for kwargs in (
        dict(min_percent=0.0, max_percent=20.0, p=p),
        dict(min_percent=5.0, max_percent=30.0, p=p),
    ):
        try:
            return _Clip(**kwargs)
        except Exception:
            pass
    # Fall back to the older signature
    for kwargs in (
        dict(min_percentile_threshold=95, max_percentile_threshold=100, p=p),
        dict(min_percentile_threshold=90, max_percentile_threshold=99, p=p),
    ):
        try:
            return _Clip(**kwargs)
        except Exception:
            pass
    return None


def have_fast_mp3():
    return importlib.util.find_spec("fast_mp3_augment") is not None
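

# These factories return either a transform or None when one can't be
# constructed, so augmentation chains can be assembled conditionally.
# A minimal sketch of the pattern (transform choices and probability are
# illustrative, mirroring what FolderDataset does below):
#
#   candidates = (make_mp3_compression(), make_clipping(p=0.3))
#   transforms = [t for t in candidates if t is not None]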


def make_mp3_compression(min_bitrate=48, max_bitrate=96, p=0.6):
    """
    Only enable Mp3Compression when the fast backend is present.
    On Windows without the extra package this often breaks, so we skip it.
    """
    if not have_fast_mp3():
        return None
    try:
        from audiomentations import Mp3Compression as _Mp3
        # Prefer the fast backend; if this audiomentations version lacks or
        # rejects the backend arg, the plain constructor still works.
        try:
            return _Mp3(min_bitrate=min_bitrate, max_bitrate=max_bitrate, p=p,
                        backend="fast_mp3_augment")
        except Exception:
            return _Mp3(min_bitrate=min_bitrate, max_bitrate=max_bitrate, p=p)
    except Exception:
        return None


# ---------- Repro ----------
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# ---------- Dataset ----------
class FolderDataset(Dataset):
    """
    data_dir/
      human/*.wav
      ai/*.wav
    """
    def __init__(self, root: str, split: str = "train", val_ratio: float = 0.15,
                 seed: int = 42, clip_seconds: float = 3.0):
        self.root = Path(root)
        self.clip_seconds = float(clip_seconds)

        human = sorted((self.root / "human").glob("*.wav"))
        ai = sorted((self.root / "ai").glob("*.wav"))
        pairs = [(p, 0) for p in human] + [(p, 1) for p in ai]

        # Deterministic shuffle/split so train and val datasets agree
        rng = random.Random(seed)
        rng.shuffle(pairs)
        n_val = int(len(pairs) * val_ratio)
        self.items = pairs[n_val:] if split == "train" else pairs[:n_val]
        self.is_train = split == "train"

        self._len_h = sum(1 for _, y in self.items if y == 0)
        self._len_a = sum(1 for _, y in self.items if y == 1)

        # Human: mild, natural perturbations
        self.aug_human = Compose([
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.01, p=0.4),
            TimeStretch(min_rate=0.96, max_rate=1.04, p=0.3),
            PitchShift(min_semitones=-1, max_semitones=1, p=0.2),
            make_gain(-4, 4, p=0.3),
        ])

        # AI: replay-aware chain (speaker/room/mic simulation)
        ai_transforms = [
            BandPassFilter(min_center_freq=200.0, max_center_freq=3500.0, p=0.5),
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.01, p=0.3),
            TimeStretch(min_rate=0.95, max_rate=1.05, p=0.25),
            make_gain(-6, 6, p=0.3),
        ]
        clip = make_clipping(p=0.3)
        if clip is not None:
            ai_transforms.insert(1, clip)
        mp3 = make_mp3_compression()
        if mp3 is not None:
            ai_transforms.insert(0, mp3)
        self.aug_ai = Compose(ai_transforms)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx: int):
        path, label = self.items[idx]
        y, sr = load_audio(str(path), TARGET_SR)
        y = pad_or_trim(y, duration_s=self.clip_seconds, sr=sr)

        if self.is_train:
            if label == 1:
                y = self.aug_ai(samples=y, sample_rate=sr)
            else:
                y = self.aug_human(samples=y, sample_rate=sr)

        mel = logmel(y, sr)                     # (n_mels, T)
        x = torch.from_numpy(mel).unsqueeze(0)  # (1, n_mels, T)
        y_t = torch.tensor(label, dtype=torch.long)
        return x, y_t


# ---------- Dataloaders ----------
def make_dataloaders(args):
    ds_tr = FolderDataset(args.data_dir, split="train", val_ratio=args.val_ratio,
                          seed=args.seed, clip_seconds=args.clip_seconds)
    ds_va = FolderDataset(args.data_dir, split="val", val_ratio=args.val_ratio,
                          seed=args.seed, clip_seconds=args.clip_seconds)

    # Windows is happier with workers=0; keep it configurable
    workers = args.workers if args.workers >= 0 else (
        0 if os.name == "nt" else max(1, (os.cpu_count() or 4) // 2)
    )
    pin = (not args.cpu) and torch.cuda.is_available()

    dl_tr = DataLoader(
        ds_tr, batch_size=args.batch_size, shuffle=True,
        num_workers=workers, pin_memory=pin,
        persistent_workers=(workers > 0), drop_last=True,
    )
    dl_va = DataLoader(
        ds_va, batch_size=max(1, args.batch_size // 2), shuffle=False,
        num_workers=workers, pin_memory=pin,
        persistent_workers=(workers > 0),
    )
    return ds_tr, ds_va, dl_tr, dl_va


def class_weights_from_dataset(ds: FolderDataset, eps: float = 1e-6):
    n_h, n_a = max(ds._len_h, eps), max(ds._len_a, eps)
    w_h = (n_h + n_a) / (2 * n_h)
    w_a = (n_h + n_a) / (2 * n_a)
    return torch.tensor([w_h, w_a], dtype=torch.float32)
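

# Worked example for class_weights_from_dataset (counts are hypothetical):
# with 300 human and 100 ai training clips,
#   w_h = (300 + 100) / (2 * 300) ≈ 0.667
#   w_a = (300 + 100) / (2 * 100) = 2.0
# so the rarer "ai" class contributes proportionally more to the loss.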


# ---------- Training / Eval ----------
def train_one_epoch(model, dl, device, opt, scaler, autocast_ctx, loss_fn,
                    grad_accum=1):
    model.train()
    total_loss = 0.0
    correct = 0
    seen = 0
    opt.zero_grad(set_to_none=True)
    for step, (x, y) in enumerate(dl):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        with autocast_ctx:
            logits = model(x)
            loss = loss_fn(logits, y)
        loss = loss / grad_accum
        if getattr(scaler, "is_enabled", lambda: False)():
            scaler.scale(loss).backward()
        else:
            loss.backward()
        if (step + 1) % grad_accum == 0:
            if getattr(scaler, "is_enabled", lambda: False)():
                scaler.step(opt)
                scaler.update()
            else:
                opt.step()
            opt.zero_grad(set_to_none=True)
        # Undo the grad_accum scaling so the running loss stays comparable
        total_loss += float(loss) * x.size(0) * grad_accum
        correct += int((logits.argmax(1) == y).sum().item())
        seen += x.size(0)
    # Apply any leftover accumulated gradients (batch count not divisible by grad_accum)
    if len(dl) % grad_accum != 0:
        if getattr(scaler, "is_enabled", lambda: False)():
            scaler.step(opt)
            scaler.update()
        else:
            opt.step()
        opt.zero_grad(set_to_none=True)
    return total_loss / max(seen, 1), correct / max(seen, 1)


@torch.no_grad()
def evaluate(model, dl, device, loss_fn):
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    for x, y in dl:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        loss = loss_fn(logits, y)
        total_loss += float(loss) * x.size(0)
        correct += int((logits.argmax(1) == y).sum().item())
        seen += x.size(0)
    return total_loss / max(seen, 1), correct / max(seen, 1)


def main(args):
    set_seed(args.seed)
    device = "cuda" if (torch.cuda.is_available() and not args.cpu) else "cpu"
    cudnn.benchmark = True

    ds_tr, ds_va, dl_tr, dl_va = make_dataloaders(args)
    print(f"Train items: {len(ds_tr)} (human={ds_tr._len_h}, ai={ds_tr._len_a})")
    print(f"Val items: {len(ds_va)}")

    model = TinyMelCNN().to(device)
    weights = class_weights_from_dataset(ds_tr).to(device)
    loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)

    # AMP (use the new torch.amp if available, else fall back)
    try:
        from torch.amp import GradScaler, autocast as amp_autocast
        scaler = GradScaler("cuda", enabled=(device == "cuda" and args.amp))
        autocast_ctx = amp_autocast("cuda") if (device == "cuda" and args.amp) else nullcontext()
    except Exception:
        from torch.cuda.amp import GradScaler, autocast as amp_autocast  # deprecated but works
        scaler = GradScaler(enabled=(device == "cuda" and args.amp))
        autocast_ctx = amp_autocast() if (device == "cuda" and args.amp) else nullcontext()

    best_va = -1.0
    patience_counter = 0
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)

    for epoch in range(args.epochs):
        tr_loss, tr_acc = train_one_epoch(
            model, dl_tr, device, opt, scaler, autocast_ctx, loss_fn,
            grad_accum=args.grad_accum
        )
        va_loss, va_acc = evaluate(model, dl_va, device, loss_fn)
        print(f"epoch {epoch+1:02d}/{args.epochs} | train {tr_loss:.3f}/{tr_acc:.3f} | val {va_loss:.3f}/{va_acc:.3f}")

        # Save "last" every epoch
        torch.save(model.state_dict(), args.out.replace(".pth", ".last.pth"))

        if va_acc > best_va + 1e-4:
            best_va = va_acc
            torch.save(model.state_dict(), args.out)
            patience_counter = 0
            print(f"✅ Saved best to {args.out} (val_acc={best_va:.3f})")
        else:
            patience_counter += 1
            if args.early_stop > 0 and patience_counter >= args.early_stop:
                print(f"⏹️ Early stopping at epoch {epoch+1} (best val_acc={best_va:.3f})")
                break

    print("Done.")


if __name__ == "__main__":
    p = argparse.ArgumentParser(
        description="Train AI Voice Detector (replay-aware, version-robust, no fast_mp3 required)"
    )
    p.add_argument("--data_dir", type=str, required=True,
                   help="Folder with subfolders human/ and ai/")
    p.add_argument("--out", type=str, default="app/models/weights/cnn_melspec.pth")
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--grad_accum", type=int, default=2)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--val_ratio", type=float, default=0.15)
    p.add_argument("--clip_seconds", type=float, default=3.0)
    p.add_argument("--workers", type=int, default=-1)  # try --workers 0 on Windows if you see issues
    # AMP is enabled by default; pass --no-amp to disable it
    p.add_argument("--amp", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--cpu", action="store_true")
    p.add_argument("--early_stop", type=int, default=0)
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()
    main(args)
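
# Example invocations (paths and the module name are illustrative; adjust to
# how this file sits in your project, e.g. app/train.py):
#
#   # GPU with AMP (default), early stopping after 3 stale epochs:
#   python -m app.train --data_dir data/clips --epochs 20 --early_stop 3
#
#   # CPU-only debugging run (e.g. on Windows):
#   python -m app.train --data_dir data/clips --cpu --no-amp --workers 0 --epochs 2
#
# Effective batch size is batch_size * grad_accum (default 32 * 2 = 64).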