import os, json, argparse, random
from pathlib import Path
from contextlib import nullcontext

import numpy as np
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

try:
    from .models.wav2vec_detector import Wav2VecClassifier
    from .utils.audio import load_audio, pad_or_trim, TARGET_SR
except ImportError:
    from app.models.wav2vec_detector import Wav2VecClassifier
    from app.utils.audio import load_audio, pad_or_trim, TARGET_SR

from audiomentations import Compose, AddGaussianNoise, BandPassFilter

def make_gain(min_db, max_db, p):
    """Build a Gain transform that works across audiomentations versions."""
    from audiomentations import Gain as _Gain
    try:
        return _Gain(min_gain_in_db=min_db, max_gain_in_db=max_db, p=p)
    except TypeError:  # newer audiomentations renamed *_gain_in_db to *_gain_db
        return _Gain(min_gain_db=min_db, max_gain_db=max_db, p=p)

def set_seed(s=42):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)

def peak_normalize(y: np.ndarray, peak: float = 0.95, eps: float = 1e-9) -> np.ndarray:
    if y.size == 0:  # np.max would raise on an empty array
        return y
    m = float(np.max(np.abs(y)) + eps)
    return (y / m) * peak

def trim_silence(y: np.ndarray, sr: int, thresh_db: float = 40.0, min_ms: int = 30) -> np.ndarray:
    if y.size == 0:
        return y
    win = max(1, int(sr * 0.02))               # 20 ms energy-smoothing window
    pad = max(1, int(sr * (min_ms / 1000.0)))  # margin kept around detected speech
    energy = np.convolve(y ** 2, np.ones(win) / win, mode="same")
    ref = np.max(energy) + 1e-12
    mask = 10 * np.log10(energy / ref + 1e-12) > -thresh_db
    if not np.any(mask):
        return y
    idx = np.where(mask)[0]
    start = max(0, int(idx[0] - pad)); end = min(len(y), int(idx[-1] + pad))
    return y[start:end]
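
# Example: at sr=16000 with the defaults, energy is smoothed over a 320-sample
# (20 ms) window; frames more than 40 dB below the loudest frame count as silence,
# and a 480-sample (30 ms) margin is kept on each side of the detected speech.
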
class WavDataset(Dataset):
    """data_dir/{human,ai}/*.wav; label 0 = human, 1 = ai."""
    def __init__(self, root, split="train", val_ratio=0.15, seed=42, clip_seconds=3.0):
        self.root = Path(root); self.clip = float(clip_seconds)
        human = sorted((self.root / "human").glob("*.wav"))
        ai = sorted((self.root / "ai").glob("*.wav"))
        items = [(p, 0) for p in human] + [(p, 1) for p in ai]
        rng = random.Random(seed); rng.shuffle(items)  # seeded shuffle keeps train/val splits disjoint and reproducible
        n_val = int(len(items) * val_ratio)
        self.items = items[n_val:] if split == "train" else items[:n_val]
        self.is_train = split == "train"
        self.nh = sum(1 for _, y in self.items if y == 0)
        self.na = sum(1 for _, y in self.items if y == 1)
        # Class-specific train-time augmentations (heavier for AI clips)
        self.aug_h = Compose([AddGaussianNoise(0.001, 0.01, p=0.3), make_gain(-4, 4, p=0.3)])
        self.aug_a = Compose([BandPassFilter(200.0, 3500.0, p=0.5), AddGaussianNoise(0.001, 0.01, p=0.3), make_gain(-6, 6, p=0.3)])

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        path, label = self.items[idx]
        y, sr = load_audio(str(path), TARGET_SR)
        y = trim_silence(y, sr, 40.0, 30)
        y = peak_normalize(y, 0.95)
        y = pad_or_trim(y, duration_s=self.clip, sr=sr)
        if self.is_train:
            y = (self.aug_a if label == 1 else self.aug_h)(samples=y, sample_rate=sr)
        return torch.from_numpy(y).float(), torch.tensor(label, dtype=torch.long)
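
# Quick sanity check (hypothetical path; assumes data/voices/{human,ai}/*.wav):
#   ds = WavDataset("data/voices", split="val")
#   x, y = ds[0]  # x: float waveform of clip_seconds * TARGET_SR samples, y: 0 or 1
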
def make_loaders(args):
    ds_tr = WavDataset(args.data_dir, "train", args.val_ratio, args.seed, args.clip_seconds)
    ds_va = WavDataset(args.data_dir, "val", args.val_ratio, args.seed, args.clip_seconds)
    # Weighted sampler to balance classes
    labels = [y for _, y in ds_tr.items]
    n0 = max(1, labels.count(0)); n1 = max(1, labels.count(1))
    w0 = (n0 + n1) / (2 * n0); w1 = (n0 + n1) / (2 * n1)
    sample_weights = [w0 if y == 0 else w1 for y in labels]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
    workers = args.workers if args.workers >= 0 else (0 if os.name == "nt" else max(1, (os.cpu_count() or 4) // 2))
    pin = (not args.cpu) and torch.cuda.is_available()
    dl_tr = DataLoader(ds_tr, batch_size=args.batch_size, sampler=sampler,
                       num_workers=workers, pin_memory=pin, drop_last=True)
    dl_va = DataLoader(ds_va, batch_size=max(1, args.batch_size // 2), shuffle=False,
                       num_workers=workers, pin_memory=pin)
    return ds_tr, ds_va, dl_tr, dl_va
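
# The sampler weights equalize expected class mass per epoch: with n0 human and n1 ai
# clips, w0 = (n0+n1)/(2*n0) and w1 = (n0+n1)/(2*n1); e.g. n0=300, n1=100 gives
# w0 ≈ 0.67 and w1 = 2.0, so each ai clip is drawn ~3x as often as a human clip.
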
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=1.5):
        super().__init__()
        self.gamma = gamma
        # reduction="none" keeps per-sample losses so each example gets its own modulator
        self.ce = torch.nn.CrossEntropyLoss(weight=alpha, reduction="none")

    def forward(self, logits, target):
        ce = self.ce(logits, target)
        pt = torch.exp(-ce.detach())  # p_t per sample; detached so gradients flow only through ce
        return (((1.0 - pt) ** self.gamma) * ce).mean()
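
# Reference form (Lin et al., 2017): FL(p_t) = (1 - p_t)^gamma * CE. With gamma = 1.5,
# an easy sample (p_t = 0.9) keeps (0.1)^1.5 ≈ 3% of its loss, while a hard one
# (p_t = 0.3) keeps (0.7)^1.5 ≈ 59%, focusing training on the hard examples.
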
def train_one_epoch(model, dl, device, opt, scaler, autocast_ctx, loss_fn, grad_accum=1):
    model.train(); total = 0.0; correct = 0; seen = 0
    use_scaler = scaler.is_enabled()
    opt.zero_grad(set_to_none=True)
    step = -1
    for step, (x, y) in enumerate(dl):
        x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
        with autocast_ctx:
            logits, _ = model(x)
            loss = loss_fn(logits, y) / grad_accum
        (scaler.scale(loss) if use_scaler else loss).backward()
        if (step + 1) % grad_accum == 0:
            if use_scaler:
                scaler.step(opt); scaler.update()
            else:
                opt.step()
            opt.zero_grad(set_to_none=True)
        total += float(loss) * x.size(0) * grad_accum  # undo the accumulation scaling for logging
        correct += int((logits.argmax(1) == y).sum().item()); seen += x.size(0)
    if (step + 1) % grad_accum != 0:  # step on gradients left over from a partial accumulation window
        if use_scaler:
            scaler.step(opt); scaler.update()
        else:
            opt.step()
        opt.zero_grad(set_to_none=True)
    return total / max(seen, 1), correct / max(seen, 1)
@torch.no_grad()
def evaluate(model, dl, device, loss_fn):
    model.eval(); total = 0.0; correct = 0; seen = 0
    for x, y in dl:
        x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
        logits, _ = model(x)
        loss = loss_fn(logits, y)
        total += float(loss) * x.size(0)
        correct += int((logits.argmax(1) == y).sum().item()); seen += x.size(0)
    return total / max(seen, 1), correct / max(seen, 1)
def main(args):
    set_seed(args.seed)
    device = "cuda" if (torch.cuda.is_available() and not args.cpu) else "cpu"
    cudnn.benchmark = True  # trade strict determinism for cuDNN autotuning speed
    ds_tr, ds_va, dl_tr, dl_va = make_loaders(args)
    print(f"Train items: {len(ds_tr)} (human={ds_tr.nh}, ai={ds_tr.na})")
    print(f"Val items: {len(ds_va)}")
    model = Wav2VecClassifier(
        encoder=args.encoder,
        unfreeze_last=args.unfreeze_last,
        dropout=args.dropout,
        hidden=args.hidden
    ).to(device)
    # Focal loss with class weights
    nh, na = ds_tr.nh, ds_tr.na
    w = torch.tensor([(nh + na) / (2 * nh + 1e-6), (nh + na) / (2 * na + 1e-6)], dtype=torch.float32).to(device)
    loss_fn = FocalLoss(alpha=w, gamma=1.5)
    # Separate learning rates: fast for the new head, slow for any unfrozen encoder layers
    head_params = list(model.head.parameters())
    enc_params = [p for p in model.encoder.parameters() if p.requires_grad]
    param_groups = [{"params": head_params, "lr": args.lr_head}]
    if enc_params:
        param_groups.append({"params": enc_params, "lr": args.lr_encoder})
    opt = torch.optim.AdamW(param_groups, weight_decay=1e-4)
    try:  # torch >= 2.x AMP API
        from torch.amp import GradScaler, autocast as amp_autocast
        scaler = GradScaler("cuda", enabled=(device == "cuda" and args.amp))
        autocast_ctx = amp_autocast("cuda") if (device == "cuda" and args.amp) else nullcontext()
    except Exception:  # older torch
        from torch.cuda.amp import GradScaler, autocast as amp_autocast
        scaler = GradScaler(enabled=(device == "cuda" and args.amp))
        autocast_ctx = amp_autocast() if (device == "cuda" and args.amp) else nullcontext()
    best = -1.0; patience = 0
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    # Save the config needed to rebuild the model at inference time
    with open(args.out.replace(".pth", ".json"), "w", encoding="utf-8") as f:
        json.dump({"encoder": args.encoder, "unfreeze_last": args.unfreeze_last}, f)
    for epoch in range(args.epochs):
        tr_loss, tr_acc = train_one_epoch(model, dl_tr, device, opt, scaler, autocast_ctx, loss_fn, args.grad_accum)
        va_loss, va_acc = evaluate(model, dl_va, device, loss_fn)
        print(f"epoch {epoch+1:02d}/{args.epochs} | train {tr_loss:.3f}/{tr_acc:.3f} | val {va_loss:.3f}/{va_acc:.3f}")
        torch.save(model.state_dict(), args.out.replace(".pth", ".last.pth"))
        if va_acc > best + 1e-4:
            best = va_acc; patience = 0
            torch.save(model.state_dict(), args.out)
            print(f"✅ Saved best to {args.out} (val_acc={best:.3f})")
        else:
            patience += 1
            if args.early_stop > 0 and patience >= args.early_stop:
                print(f"⏹️ Early stopping at epoch {epoch+1} (best={best:.3f})")
                break
    print("Done.")
if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Train Wav2Vec2-based AI Voice Detector (balanced)")
    ap.add_argument("--data_dir", required=True, help="Folder with human/ and ai/ WAVs")
    ap.add_argument("--out", default="app/models/weights/wav2vec2_classifier.pth")
    ap.add_argument("--encoder", default="facebook/wav2vec2-base")
    ap.add_argument("--unfreeze_last", type=int, default=0)
    ap.add_argument("--epochs", type=int, default=8)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--grad_accum", type=int, default=2)
    ap.add_argument("--lr_head", type=float, default=1e-3)
    ap.add_argument("--lr_encoder", type=float, default=1e-5)
    ap.add_argument("--val_ratio", type=float, default=0.15)
    ap.add_argument("--clip_seconds", type=float, default=3.0)
    ap.add_argument("--workers", type=int, default=-1, help="-1 = auto (0 on Windows)")
    # BooleanOptionalAction (Python 3.9+) gives --amp/--no-amp; a store_true flag with default=True could never be disabled
    ap.add_argument("--amp", action=argparse.BooleanOptionalAction, default=True)
    ap.add_argument("--cpu", action="store_true")
    ap.add_argument("--dropout", type=float, default=0.2)
    ap.add_argument("--hidden", type=int, default=256)
    ap.add_argument("--early_stop", type=int, default=0)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    main(args)
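
# Example invocation (hypothetical paths; assumes data/voices/{human,ai}/*.wav and
# that this file is saved as train.py):
#   python train.py --data_dir data/voices --epochs 10 --unfreeze_last 2 --early_stop 3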