"""Train a Wav2Vec2-based AI voice detector on data_dir/{human,ai}/*.wav."""

import os
import json
import argparse
import random
from pathlib import Path
from contextlib import nullcontext

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

# Support running both as a package module and as a plain script.
try:
    from .models.wav2vec_detector import Wav2VecClassifier
    from .utils.audio import load_audio, pad_or_trim, TARGET_SR
except ImportError:
    from app.models.wav2vec_detector import Wav2VecClassifier
    from app.utils.audio import load_audio, pad_or_trim, TARGET_SR

from audiomentations import Compose, AddGaussianNoise, BandPassFilter


def make_gain(min_db, max_db, p):
    """Build a Gain transform across audiomentations versions (the keywords
    were renamed from min/max_gain_in_db to min/max_gain_db)."""
    from audiomentations import Gain as _Gain
    try:
        return _Gain(min_gain_in_db=min_db, max_gain_in_db=max_db, p=p)
    except TypeError:
        return _Gain(min_gain_db=min_db, max_gain_db=max_db, p=p)


def set_seed(s=42):
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)


def peak_normalize(y: np.ndarray, peak: float = 0.95, eps: float = 1e-9) -> np.ndarray:
    if y.size == 0:  # guard: np.max would raise on an empty array
        return y
    m = float(np.max(np.abs(y)) + eps)  # eps keeps m strictly positive
    return (y / m) * peak


def trim_silence(y: np.ndarray, sr: int, thresh_db: float = 40.0, min_ms: int = 30) -> np.ndarray:
    """Trim leading/trailing silence using a moving-average energy gate."""
    if y.size == 0:
        return y
    win = max(1, int(sr * 0.02))               # 20 ms energy window
    pad = max(1, int(sr * (min_ms / 1000.0)))  # keep a margin around speech
    energy = np.convolve(y ** 2, np.ones(win) / win, mode="same")
    ref = np.max(energy) + 1e-12
    mask = 10 * np.log10(energy / ref + 1e-12) > -thresh_db
    if not np.any(mask):
        return y
    idx = np.where(mask)[0]
    start = max(0, int(idx[0] - pad))
    end = min(len(y), int(idx[-1] + pad))
    return y[start:end]


class WavDataset(Dataset):
    """Expects data_dir/{human,ai}/*.wav; human -> label 0, ai -> label 1."""

    def __init__(self, root, split="train", val_ratio=0.15, seed=42, clip_seconds=3.0):
        self.root = Path(root)
        self.clip = float(clip_seconds)
        human = sorted((self.root / "human").glob("*.wav"))
        ai = sorted((self.root / "ai").glob("*.wav"))
        items = [(p, 0) for p in human] + [(p, 1) for p in ai]
        rng = random.Random(seed)  # same seed => consistent train/val partition
        rng.shuffle(items)
        n_val = int(len(items) * val_ratio)
        self.items = items[n_val:] if split == "train" else items[:n_val]
        self.is_train = split == "train"
        self.nh = sum(1 for _, y in self.items if y == 0)
        self.na = sum(1 for _, y in self.items if y == 1)
        # Slightly stronger augmentation for AI clips to discourage the model
        # from keying on codec/bandwidth artifacts.
        self.aug_h = Compose([
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.01, p=0.3),
            make_gain(-4, 4, p=0.3),
        ])
        self.aug_a = Compose([
            BandPassFilter(min_center_freq=200.0, max_center_freq=3500.0, p=0.5),
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.01, p=0.3),
            make_gain(-6, 6, p=0.3),
        ])

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        path, label = self.items[idx]
        y, sr = load_audio(str(path), TARGET_SR)
        y = trim_silence(y, sr, 40.0, 30)
        y = peak_normalize(y, 0.95)
        y = pad_or_trim(y, duration_s=self.clip, sr=sr)
        if self.is_train:
            y = (self.aug_a if label == 1 else self.aug_h)(samples=y, sample_rate=sr)
        return torch.from_numpy(y).float(), torch.tensor(label, dtype=torch.long)


def make_loaders(args):
    ds_tr = WavDataset(args.data_dir, "train", args.val_ratio, args.seed, args.clip_seconds)
    ds_va = WavDataset(args.data_dir, "val", args.val_ratio, args.seed, args.clip_seconds)
    # Weighted sampler so each class is drawn with equal probability.
    # e.g. with n0=300 human and n1=100 ai clips: w0 = 400/600 ~= 0.67 and
    # w1 = 400/200 = 2.0, so ai clips are sampled 3x as often.
    labels = [y for _, y in ds_tr.items]
    n0 = max(1, labels.count(0))
    n1 = max(1, labels.count(1))
    w0 = (n0 + n1) / (2 * n0)
    w1 = (n0 + n1) / (2 * n1)
    sample_weights = [w0 if y == 0 else w1 for y in labels]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
    # workers == -1 => auto: no worker processes on Windows, else half the CPUs.
    workers = args.workers if args.workers >= 0 else (0 if os.name == "nt" else max(1, (os.cpu_count() or 4) // 2))
    pin = (not args.cpu) and torch.cuda.is_available()
    dl_tr = DataLoader(ds_tr, batch_size=args.batch_size, sampler=sampler,
                       num_workers=workers, pin_memory=pin, drop_last=True)
    dl_va = DataLoader(ds_va, batch_size=max(1, args.batch_size // 2), shuffle=False,
                       num_workers=workers, pin_memory=pin)
    return ds_tr, ds_va, dl_tr, dl_va
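
# Focal loss (Lin et al., 2017): FL(pt) = -(1 - pt)^gamma * log(pt). With
# gamma = 1.5, an easy example with pt = 0.9 keeps only (1 - 0.9)^1.5 ~= 0.03
# of its cross-entropy, while a hard one with pt = 0.5 keeps ~= 0.35, so
# training gradient concentrates on the hard examples.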
class FocalLoss(torch.nn.Module):
    """Focal loss on top of (optionally class-weighted) cross-entropy."""

    def __init__(self, alpha=None, gamma=1.5):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        # reduction="none" keeps per-sample losses so the focal term can
        # down-weight easy examples individually; modulating the batch-mean
        # CE (as a single scalar) would lose the per-example focusing effect.
        self.ce = torch.nn.CrossEntropyLoss(weight=alpha, reduction="none")

    def forward(self, logits, target):
        ce = self.ce(logits, target)
        with torch.no_grad():
            pt = torch.exp(-ce)  # approx. probability assigned to the true class
        return (((1 - pt) ** self.gamma) * ce).mean()


def train_one_epoch(model, dl, device, opt, scaler, autocast_ctx, loss_fn, grad_accum=1):
    model.train()
    total = 0.0
    correct = 0
    seen = 0
    opt.zero_grad(set_to_none=True)
    for step, (x, y) in enumerate(dl):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        with autocast_ctx:
            logits, _ = model(x)
            loss = loss_fn(logits, y)
        loss = loss / grad_accum
        if scaler.is_enabled():
            scaler.scale(loss).backward()
        else:
            loss.backward()
        # Step the optimizer every grad_accum micro-batches. Note: a trailing
        # partial accumulation at the end of the epoch is discarded.
        if (step + 1) % grad_accum == 0:
            if scaler.is_enabled():
                scaler.step(opt)
                scaler.update()
            else:
                opt.step()
            opt.zero_grad(set_to_none=True)
        total += float(loss) * x.size(0) * grad_accum  # undo the /grad_accum for logging
        correct += int((logits.argmax(1) == y).sum().item())
        seen += x.size(0)
    return total / max(seen, 1), correct / max(seen, 1)


@torch.no_grad()
def evaluate(model, dl, device, loss_fn):
    model.eval()
    total = 0.0
    correct = 0
    seen = 0
    for x, y in dl:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits, _ = model(x)
        loss = loss_fn(logits, y)
        total += float(loss) * x.size(0)
        correct += int((logits.argmax(1) == y).sum().item())
        seen += x.size(0)
    return total / max(seen, 1), correct / max(seen, 1)


def main(args):
    set_seed(args.seed)
    device = "cuda" if (torch.cuda.is_available() and not args.cpu) else "cpu"
    cudnn.benchmark = True
    ds_tr, ds_va, dl_tr, dl_va = make_loaders(args)
    print(f"Train items: {len(ds_tr)} (human={ds_tr.nh}, ai={ds_tr.na})")
    print(f"Val items: {len(ds_va)}")
    model = Wav2VecClassifier(
        encoder=args.encoder,
        unfreeze_last=args.unfreeze_last,
        dropout=args.dropout,
        hidden=args.hidden,
    ).to(device)
    # Inverse-frequency class weights for the focal loss.
    nh, na = ds_tr.nh, ds_tr.na
    w = torch.tensor([(nh + na) / (2 * nh + 1e-6), (nh + na) / (2 * na + 1e-6)],
                     dtype=torch.float32).to(device)
    loss_fn = FocalLoss(alpha=w, gamma=1.5)
    # Separate learning rates: a fast head and a gentle (optionally unfrozen) encoder.
    head_params = list(model.head.parameters())
    enc_params = [p for p in model.encoder.parameters() if p.requires_grad]
    param_groups = [{"params": head_params, "lr": args.lr_head}]
    if enc_params:
        param_groups.append({"params": enc_params, "lr": args.lr_encoder})
    opt = torch.optim.AdamW(param_groups, weight_decay=1e-4)
    # torch.amp is the current API; fall back to torch.cuda.amp on older PyTorch.
    try:
        from torch.amp import GradScaler, autocast as amp_autocast
        scaler = GradScaler("cuda", enabled=(device == "cuda" and args.amp))
        autocast_ctx = amp_autocast("cuda") if (device == "cuda" and args.amp) else nullcontext()
    except Exception:
        from torch.cuda.amp import GradScaler, autocast as amp_autocast
        scaler = GradScaler(enabled=(device == "cuda" and args.amp))
        autocast_ctx = amp_autocast() if (device == "cuda" and args.amp) else nullcontext()
    best = -1.0
    patience = 0
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    # Persist the encoder config next to the weights so inference can rebuild the model.
    with open(args.out.replace(".pth", ".json"), "w", encoding="utf-8") as f:
        json.dump({"encoder": args.encoder, "unfreeze_last": args.unfreeze_last}, f)
    for epoch in range(args.epochs):
        tr_loss, tr_acc = train_one_epoch(model, dl_tr, device, opt, scaler,
                                          autocast_ctx, loss_fn, args.grad_accum)
        va_loss, va_acc = evaluate(model, dl_va, device, loss_fn)
        print(f"epoch {epoch+1:02d}/{args.epochs} | train {tr_loss:.3f}/{tr_acc:.3f}"
              f" | val {va_loss:.3f}/{va_acc:.3f}")
        # Always keep the most recent weights alongside the best checkpoint.
        torch.save(model.state_dict(), args.out.replace(".pth", ".last.pth"))
        if va_acc > best + 1e-4:
            best = va_acc
            patience = 0
            torch.save(model.state_dict(), args.out)
            print(f"✅ Saved best to {args.out} (val_acc={best:.3f})")
        else:
            patience += 1
            if args.early_stop > 0 and patience >= args.early_stop:
                print(f"⏹️ Early stopping at epoch {epoch+1} (best={best:.3f})")
                break
    print("Done.")


if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Train Wav2Vec2-based AI Voice Detector (balanced)")
    ap.add_argument("--data_dir", required=True, help="Folder with human/ and ai/ WAVs")
    ap.add_argument("--out", default="app/models/weights/wav2vec2_classifier.pth")
    ap.add_argument("--encoder", default="facebook/wav2vec2-base")
    ap.add_argument("--unfreeze_last", type=int, default=0)
    ap.add_argument("--epochs", type=int, default=8)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--grad_accum", type=int, default=2)
    ap.add_argument("--lr_head", type=float, default=1e-3)
    ap.add_argument("--lr_encoder", type=float, default=1e-5)
    ap.add_argument("--val_ratio", type=float, default=0.15)
    ap.add_argument("--clip_seconds", type=float, default=3.0)
    ap.add_argument("--workers", type=int, default=-1, help="-1 = auto")
    # BooleanOptionalAction (Python 3.9+) provides --amp/--no-amp; the previous
    # store_true with default=True made the flag impossible to turn off.
    ap.add_argument("--amp", action=argparse.BooleanOptionalAction, default=True)
    ap.add_argument("--cpu", action="store_true")
    ap.add_argument("--dropout", type=float, default=0.2)
    ap.add_argument("--hidden", type=int, default=256)
    ap.add_argument("--early_stop", type=int, default=0, help="0 disables early stopping")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    main(args)
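
# Example invocation (a sketch: it assumes this file lives at app/train.py so
# the package-relative imports above resolve, and that data/ holds human/*.wav
# and ai/*.wav; all flags shown are defined by the argparse block above):
#
#   python -m app.train --data_dir data --epochs 10 --early_stop 3 \
#       --unfreeze_last 2 --batch_size 8 --grad_accum 4 --no-amp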