"""Voice-guard / app/train.py

Train the AI voice detector: replay-aware augmentations for the AI class,
version-robust across audiomentations releases, no fast_mp3 backend required.
"""
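
# Example invocation (illustrative paths; assumes the app/ package layout
# implied by the fallback imports below):
#   python -m app.train --data_dir data --epochs 15 --batch_size 32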
import os
import argparse
import random
from pathlib import Path
from contextlib import nullcontext
import importlib.util
import numpy as np
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
# ---------- Local imports ----------
try:
from .models.cnn_melspec import TinyMelCNN
from .utils.audio import load_audio, pad_or_trim, logmel, TARGET_SR
except ImportError:
from app.models.cnn_melspec import TinyMelCNN
from app.utils.audio import load_audio, pad_or_trim, logmel, TARGET_SR
# ---------- Augmentations (robust across versions) ----------
from audiomentations import (
Compose, AddGaussianNoise, TimeStretch, PitchShift, BandPassFilter
)
def make_gain(min_db, max_db, p):
"""Handle both min_gain_in_db/max_gain_in_db and min_gain_db/max_gain_db."""
from audiomentations import Gain as _Gain
try:
return _Gain(min_gain_in_db=min_db, max_gain_in_db=max_db, p=p)
except TypeError:
return _Gain(min_gain_db=min_db, max_gain_db=max_db, p=p)
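# e.g. make_gain(-6, 6, p=0.3) returns a Gain transform under either
# audiomentations keyword-argument style.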
def make_clipping(p=0.3):
"""
Build ClippingDistortion across versions.
Newer: min_percent/max_percent (0..20 typical)
Older: min_percentile_threshold/max_percentile_threshold in [0..100]
Returns None if not available.
"""
try:
from audiomentations import ClippingDistortion as _Clip
except Exception:
return None
# Try newer signature
for kwargs in (
dict(min_percent=0.0, max_percent=20.0, p=p),
dict(min_percent=5.0, max_percent=30.0, p=p),
):
try:
return _Clip(**kwargs)
except Exception:
pass
# Try older signature
for kwargs in (
dict(min_percentile_threshold=95, max_percentile_threshold=100, p=p),
dict(min_percentile_threshold=90, max_percentile_threshold=99, p=p),
):
try:
return _Clip(**kwargs)
except Exception:
pass
return None
def have_fast_mp3():
return importlib.util.find_spec("fast_mp3_augment") is not None
def make_mp3_compression(min_bitrate=48, max_bitrate=96, p=0.6):
"""
Only enable Mp3Compression when the fast backend is present.
On Windows without the extra package this often breaks; we skip it.
"""
if not have_fast_mp3():
return None
try:
from audiomentations import Mp3Compression as _Mp3
# Prefer the fast backend; if API lacks backend arg, constructor still works.
try:
return _Mp3(min_bitrate=min_bitrate, max_bitrate=max_bitrate, p=p, backend="fast_mp3_augment")
except TypeError:
return _Mp3(min_bitrate=min_bitrate, max_bitrate=max_bitrate, p=p)
except Exception:
return None
# ---------- Repro ----------
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
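
# Note: cudnn.benchmark (enabled in main) favors speed over bitwise
# reproducibility, so runs are seeded but not strictly deterministic.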
# ---------- Dataset ----------
class FolderDataset(Dataset):
"""
data_dir/
human/*.wav
ai/*.wav
"""
def __init__(self, root: str, split: str = "train", val_ratio: float = 0.15,
seed: int = 42, clip_seconds: float = 3.0):
self.root = Path(root)
self.clip_seconds = float(clip_seconds)
human = sorted((self.root / "human").glob("*.wav"))
ai = sorted((self.root / "ai").glob("*.wav"))
pairs = [(p, 0) for p in human] + [(p, 1) for p in ai]
rng = random.Random(seed)
rng.shuffle(pairs)
n_val = int(len(pairs) * val_ratio)
self.items = pairs[n_val:] if split == "train" else pairs[:n_val]
self.is_train = split == "train"
self._len_h = sum(1 for _, y in self.items if y == 0)
self._len_a = sum(1 for _, y in self.items if y == 1)
# Human: mild, natural perturbations
self.aug_human = Compose([
AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.01, p=0.4),
TimeStretch(min_rate=0.96, max_rate=1.04, p=0.3),
PitchShift(min_semitones=-1, max_semitones=1, p=0.2),
make_gain(-4, 4, p=0.3),
])
# AI: replay-aware chain (speaker/room/mic simulation)
ai_transforms = [
BandPassFilter(min_center_freq=200.0, max_center_freq=3500.0, p=0.5),
AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.01, p=0.3),
TimeStretch(min_rate=0.95, max_rate=1.05, p=0.25),
make_gain(-6, 6, p=0.3),
]
clip = make_clipping(p=0.3)
if clip is not None:
ai_transforms.insert(1, clip)
mp3 = make_mp3_compression()
if mp3 is not None:
ai_transforms.insert(0, mp3)
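        # Resulting chain when both optional transforms are available:
        # mp3 -> bandpass -> clipping -> noise -> time-stretch -> gain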
self.aug_ai = Compose(ai_transforms)
def __len__(self):
return len(self.items)
def __getitem__(self, idx: int):
path, label = self.items[idx]
y, sr = load_audio(str(path), TARGET_SR)
y = pad_or_trim(y, duration_s=self.clip_seconds, sr=sr)
if self.is_train:
if label == 1:
y = self.aug_ai(samples=y, sample_rate=sr)
else:
y = self.aug_human(samples=y, sample_rate=sr)
mel = logmel(y, sr) # (n_mels, T)
x = torch.from_numpy(mel).unsqueeze(0) # (1, n_mels, T)
y_t = torch.tensor(label, dtype=torch.long)
return x, y_t
# ---------- Dataloaders ----------
def make_dataloaders(args):
ds_tr = FolderDataset(args.data_dir, split="train", val_ratio=args.val_ratio,
seed=args.seed, clip_seconds=args.clip_seconds)
ds_va = FolderDataset(args.data_dir, split="val", val_ratio=args.val_ratio,
seed=args.seed, clip_seconds=args.clip_seconds)
# Windows is happier with workers=0; keep configurable
workers = args.workers if args.workers >= 0 else (0 if os.name == "nt" else max(1, (os.cpu_count() or 4)//2))
pin = (not args.cpu) and torch.cuda.is_available()
dl_tr = DataLoader(
ds_tr, batch_size=args.batch_size, shuffle=True,
num_workers=workers, pin_memory=pin,
persistent_workers=(workers > 0), drop_last=True,
)
dl_va = DataLoader(
ds_va, batch_size=max(1, args.batch_size // 2), shuffle=False,
num_workers=workers, pin_memory=pin,
persistent_workers=(workers > 0),
)
return ds_tr, ds_va, dl_tr, dl_va
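
# Inverse-frequency class weights, normalized so a balanced set yields [1, 1].
# Worked example (hypothetical counts): 800 human / 200 ai clips ->
# weights [1000/1600, 1000/400] = [0.625, 2.5].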
def class_weights_from_dataset(ds: FolderDataset, eps: float = 1e-6):
n_h, n_a = max(ds._len_h, eps), max(ds._len_a, eps)
w_h = (n_h + n_a) / (2 * n_h)
w_a = (n_h + n_a) / (2 * n_a)
return torch.tensor([w_h, w_a], dtype=torch.float32)
# ---------- Training / Eval ----------
def train_one_epoch(model, dl, device, opt, scaler, autocast_ctx, loss_fn, grad_accum=1):
model.train()
total_loss = 0.0
correct = 0
seen = 0
opt.zero_grad(set_to_none=True)
for step, (x, y) in enumerate(dl):
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
with autocast_ctx:
logits = model(x)
loss = loss_fn(logits, y)
loss = loss / grad_accum
if getattr(scaler, "is_enabled", lambda: False)():
scaler.scale(loss).backward()
else:
loss.backward()
if (step + 1) % grad_accum == 0:
if getattr(scaler, "is_enabled", lambda: False)():
scaler.step(opt)
scaler.update()
else:
opt.step()
opt.zero_grad(set_to_none=True)
total_loss += float(loss) * x.size(0) * grad_accum
correct += int((logits.argmax(1) == y).sum().item())
seen += x.size(0)
    # Flush any trailing partial accumulation window so its gradients are
    # not silently dropped at the end of the epoch.
    if len(dl) % grad_accum != 0:
        if getattr(scaler, "is_enabled", lambda: False)():
            scaler.step(opt)
            scaler.update()
        else:
            opt.step()
        opt.zero_grad(set_to_none=True)
    return total_loss / max(seen, 1), correct / max(seen, 1)
@torch.no_grad()
def evaluate(model, dl, device, loss_fn):
model.eval()
total_loss = 0.0
correct = 0
seen = 0
for x, y in dl:
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
logits = model(x)
loss = loss_fn(logits, y)
total_loss += float(loss) * x.size(0)
correct += int((logits.argmax(1) == y).sum().item())
seen += x.size(0)
return total_loss / max(seen, 1), correct / max(seen, 1)
def main(args):
set_seed(args.seed)
device = "cuda" if (torch.cuda.is_available() and not args.cpu) else "cpu"
    cudnn.benchmark = True  # cuDNN autotuner; only affects CUDA convolutions
ds_tr, ds_va, dl_tr, dl_va = make_dataloaders(args)
print(f"Train items: {len(ds_tr)} (human={ds_tr._len_h}, ai={ds_tr._len_a})")
print(f"Val items: {len(ds_va)}")
model = TinyMelCNN().to(device)
weights = class_weights_from_dataset(ds_tr).to(device)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
# AMP (use new torch.amp if available, else fallback)
try:
from torch.amp import GradScaler, autocast as amp_autocast
scaler = GradScaler("cuda", enabled=(device == "cuda" and args.amp))
autocast_ctx = amp_autocast("cuda") if (device == "cuda" and args.amp) else nullcontext()
except Exception:
from torch.cuda.amp import GradScaler, autocast as amp_autocast # deprecated but works
scaler = GradScaler(enabled=(device == "cuda" and args.amp))
autocast_ctx = amp_autocast() if (device == "cuda" and args.amp) else nullcontext()
best_va = -1.0
patience_counter = 0
Path(args.out).parent.mkdir(parents=True, exist_ok=True)
for epoch in range(args.epochs):
tr_loss, tr_acc = train_one_epoch(
model, dl_tr, device, opt, scaler, autocast_ctx, loss_fn,
grad_accum=args.grad_accum
)
va_loss, va_acc = evaluate(model, dl_va, device, loss_fn)
print(f"epoch {epoch+1:02d}/{args.epochs} | train {tr_loss:.3f}/{tr_acc:.3f} | val {va_loss:.3f}/{va_acc:.3f}")
# Save "last" every epoch
torch.save(model.state_dict(), args.out.replace(".pth", ".last.pth"))
if va_acc > best_va + 1e-4:
best_va = va_acc
torch.save(model.state_dict(), args.out)
patience_counter = 0
print(f"✅ Saved best to {args.out} (val_acc={best_va:.3f})")
else:
patience_counter += 1
if args.early_stop > 0 and patience_counter >= args.early_stop:
print(f"⏹️ Early stopping at epoch {epoch+1} (best val_acc={best_va:.3f})")
break
print("Done.")
if __name__ == "__main__":
p = argparse.ArgumentParser(description="Train AI Voice Detector (replay-aware, version-robust, no fast_mp3 required)")
p.add_argument("--data_dir", type=str, required=True, help="Folder with subfolders human/ and ai/")
p.add_argument("--out", type=str, default="app/models/weights/cnn_melspec.pth")
p.add_argument("--epochs", type=int, default=10)
p.add_argument("--batch_size", type=int, default=32)
p.add_argument("--grad_accum", type=int, default=2)
p.add_argument("--lr", type=float, default=1e-3)
p.add_argument("--val_ratio", type=float, default=0.15)
p.add_argument("--clip_seconds", type=float, default=3.0)
p.add_argument("--workers", type=int, default=-1) # try --workers 0 on Windows if you see issues
p.add_argument("--amp", action="store_true", default=True)
p.add_argument("--cpu", action="store_true")
p.add_argument("--early_stop", type=int, default=0)
p.add_argument("--seed", type=int, default=42)
args = p.parse_args()
main(args)