# Voice-guard: app/train_wav2vec.py
import os, json, argparse, random
from pathlib import Path
from contextlib import nullcontext
import numpy as np
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
try:
    from .models.wav2vec_detector import Wav2VecClassifier
    from .utils.audio import load_audio, pad_or_trim, TARGET_SR
except ImportError:
    from app.models.wav2vec_detector import Wav2VecClassifier
    from app.utils.audio import load_audio, pad_or_trim, TARGET_SR
from audiomentations import Compose, AddGaussianNoise, BandPassFilter
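# Compatibility shim: newer audiomentations releases renamed Gain's keyword
# arguments (min_gain_in_db -> min_gain_db). The helper below tries the older
# spelling first and falls back on TypeError, so either version of the library works.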
def make_gain(min_db, max_db, p):
    from audiomentations import Gain as _Gain
    try:
        return _Gain(min_gain_in_db=min_db, max_gain_in_db=max_db, p=p)
    except TypeError:
        return _Gain(min_gain_db=min_db, max_gain_db=max_db, p=p)
def set_seed(s=42):
    random.seed(s); np.random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s)
def peak_normalize(y: np.ndarray, peak: float = 0.95, eps: float = 1e-9) -> np.ndarray:
    if y.size == 0: return y
    m = float(np.max(np.abs(y)) + eps)  # m >= eps > 0, so the division is always safe
    return (y / m) * peak
def trim_silence(y: np.ndarray, sr: int, thresh_db: float = 40.0, min_ms: int = 30) -> np.ndarray:
    if y.size == 0: return y
    win = max(1, int(sr * 0.02)); pad = max(1, int(sr * (min_ms / 1000.0)))
    energy = np.convolve(y**2, np.ones(win)/win, mode="same")  # 20 ms moving-average energy
    ref = np.max(energy) + 1e-12
    mask = 10*np.log10(energy/ref + 1e-12) > -thresh_db  # keep samples within thresh_db of the peak
    if not np.any(mask): return y
    idx = np.where(mask)[0]; start = max(0, int(idx[0]-pad)); end = min(len(y), int(idx[-1]+pad))
    return y[start:end]
class WavDataset(Dataset):
    """data_dir/{human,ai}/*.wav -> (waveform, label), with human=0 and ai=1."""
    def __init__(self, root, split="train", val_ratio=0.15, seed=42, clip_seconds=3.0):
        self.root = Path(root); self.clip = float(clip_seconds)
        human = sorted((self.root/"human").glob("*.wav"))
        ai = sorted((self.root/"ai").glob("*.wav"))
        items = [(p, 0) for p in human] + [(p, 1) for p in ai]
        rng = random.Random(seed); rng.shuffle(items)  # same seed -> same train/val split
        n_val = int(len(items)*val_ratio)
        self.items = items[n_val:] if split=="train" else items[:n_val]
        self.is_train = split=="train"
        self.nh = sum(1 for _, y in self.items if y==0)
        self.na = sum(1 for _, y in self.items if y==1)
        # Noise/gain augmentation for human clips; AI clips additionally get a band-pass filter.
        self.aug_h = Compose([AddGaussianNoise(0.001, 0.01, p=0.3), make_gain(-4, 4, p=0.3)])
        self.aug_a = Compose([BandPassFilter(200.0, 3500.0, p=0.5), AddGaussianNoise(0.001, 0.01, p=0.3), make_gain(-6, 6, p=0.3)])
    def __len__(self): return len(self.items)
    def __getitem__(self, idx):
        path, label = self.items[idx]
        y, sr = load_audio(str(path), TARGET_SR)
        y = trim_silence(y, sr, 40.0, 30)
        y = peak_normalize(y, 0.95)
        y = pad_or_trim(y, duration_s=self.clip, sr=sr)
        if self.is_train:
            y = (self.aug_a if label==1 else self.aug_h)(samples=y, sample_rate=sr)
        return torch.from_numpy(y).float(), torch.tensor(label, dtype=torch.long)
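# Expected on-disk layout and a minimal usage sketch (paths are illustrative):
#   data_dir/human/*.wav -> label 0, data_dir/ai/*.wav -> label 1
#   ds = WavDataset("data/voices", split="train", clip_seconds=3.0)
#   wav, label = ds[0]   # wav: float32 tensor of clip_seconds * TARGET_SR samples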
def make_loaders(args):
    ds_tr = WavDataset(args.data_dir, "train", args.val_ratio, args.seed, args.clip_seconds)
    ds_va = WavDataset(args.data_dir, "val", args.val_ratio, args.seed, args.clip_seconds)
    # Weighted sampler to balance classes: w_c = N / (2 * n_c) gives each class
    # an expected half of every batch.
    labels = [y for _, y in ds_tr.items]
    n0 = max(1, labels.count(0)); n1 = max(1, labels.count(1))
    w0 = (n0 + n1) / (2 * n0); w1 = (n0 + n1) / (2 * n1)
    sample_weights = [w0 if y == 0 else w1 for y in labels]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
    # -1 = auto: 0 workers on Windows (fragile multiprocessing), else half the cores.
    workers = args.workers if args.workers >= 0 else (0 if os.name=="nt" else max(1, (os.cpu_count() or 4)//2))
    pin = (not args.cpu) and torch.cuda.is_available()
    dl_tr = DataLoader(ds_tr, batch_size=args.batch_size, sampler=sampler,
                       num_workers=workers, pin_memory=pin, drop_last=True)
    dl_va = DataLoader(ds_va, batch_size=max(1, args.batch_size//2), shuffle=False,
                       num_workers=workers, pin_memory=pin)
    return ds_tr, ds_va, dl_tr, dl_va
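# Note: with replacement=True an "epoch" is len(labels) weighted draws, so minority-class
# files repeat within an epoch while some majority-class files may not appear at all.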
class FocalLoss(torch.nn.Module):
    """Focal loss: FL = (1 - p_t)^gamma * CE, applied per sample so that easy
    examples are down-weighted individually (gamma=0 recovers weighted CE)."""
    def __init__(self, alpha=None, gamma=1.5):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    def forward(self, logits, target):
        ce = F.cross_entropy(logits, target, weight=self.alpha, reduction="none")
        pt = torch.exp(-ce.detach())  # p_t of the true class (up to the alpha weighting)
        return (((1 - pt) ** self.gamma) * ce).mean()
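# Effect of gamma=1.5 (ignoring the alpha class weights): an easy sample with
# p_t = 0.9 keeps (1-0.9)^1.5 ~ 0.03 of its cross-entropy, while a hard one with
# p_t = 0.3 keeps (1-0.3)^1.5 ~ 0.59 of it.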
def train_one_epoch(model, dl, device, opt, scaler, autocast_ctx, loss_fn, grad_accum=1):
    model.train(); total = 0.0; correct = 0; seen = 0
    opt.zero_grad(set_to_none=True)
    for step, (x, y) in enumerate(dl):
        x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
        with autocast_ctx:
            logits, _ = model(x); loss = loss_fn(logits, y)
        loss = loss / grad_accum
        if scaler.is_enabled():
            scaler.scale(loss).backward()
        else:
            loss.backward()
        if (step + 1) % grad_accum == 0:
            if scaler.is_enabled():
                scaler.step(opt); scaler.update()
            else:
                opt.step()
            opt.zero_grad(set_to_none=True)
        total += float(loss) * x.size(0) * grad_accum  # undo the grad_accum scaling for logging
        correct += int((logits.argmax(1)==y).sum().item()); seen += x.size(0)
    return total/max(seen, 1), correct/max(seen, 1)
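# Effective batch size is batch_size * grad_accum. If the number of loader steps is not
# a multiple of grad_accum, the final partial accumulation never reaches opt.step(), and
# those leftover gradients are discarded by zero_grad at the start of the next epoch.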
@torch.no_grad()
def evaluate(model, dl, device, loss_fn):
    model.eval(); total = 0.0; correct = 0; seen = 0
    for x, y in dl:
        x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
        logits, _ = model(x); loss = loss_fn(logits, y)
        total += float(loss) * x.size(0); correct += int((logits.argmax(1)==y).sum().item()); seen += x.size(0)
    return total/max(seen, 1), correct/max(seen, 1)
def main(args):
    set_seed(args.seed)
    device = "cuda" if (torch.cuda.is_available() and not args.cpu) else "cpu"
    cudnn.benchmark = True
    ds_tr, ds_va, dl_tr, dl_va = make_loaders(args)
    print(f"Train items: {len(ds_tr)} (human={ds_tr.nh}, ai={ds_tr.na})")
    print(f"Val items: {len(ds_va)}")
    model = Wav2VecClassifier(
        encoder=args.encoder,
        unfreeze_last=args.unfreeze_last,
        dropout=args.dropout,
        hidden=args.hidden,
    ).to(device)
    # Focal loss with inverse-frequency class weights
    nh, na = ds_tr.nh, ds_tr.na
    w = torch.tensor([(nh+na)/(2*nh+1e-6), (nh+na)/(2*na+1e-6)], dtype=torch.float32).to(device)
    loss_fn = FocalLoss(alpha=w, gamma=1.5)
    # Two LR groups: a fast head and (if any layers are unfrozen) a slow encoder.
    head_params = list(model.head.parameters())
    enc_params = [p for p in model.encoder.parameters() if p.requires_grad]
    param_groups = [{"params": head_params, "lr": args.lr_head}]
    if enc_params:
        param_groups.append({"params": enc_params, "lr": args.lr_encoder})
    opt = torch.optim.AdamW(param_groups, weight_decay=1e-4)
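    # AMP setup: newer PyTorch exposes GradScaler/autocast under torch.amp (with the
    # device passed explicitly); older releases only have the torch.cuda.amp variants,
    # hence the fallback below.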
    try:
        from torch.amp import GradScaler, autocast as amp_autocast
        scaler = GradScaler("cuda", enabled=(device=="cuda" and args.amp))
        autocast_ctx = amp_autocast("cuda") if (device=="cuda" and args.amp) else nullcontext()
    except Exception:
        from torch.cuda.amp import GradScaler, autocast as amp_autocast
        scaler = GradScaler(enabled=(device=="cuda" and args.amp))
        autocast_ctx = amp_autocast() if (device=="cuda" and args.amp) else nullcontext()
    best = -1.0; patience = 0
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    # Persist the encoder config next to the weights so inference can rebuild the model.
    with open(args.out.replace(".pth", ".json"), "w", encoding="utf-8") as f:
        json.dump({"encoder": args.encoder, "unfreeze_last": args.unfreeze_last}, f)
    for epoch in range(args.epochs):
        tr_loss, tr_acc = train_one_epoch(model, dl_tr, device, opt, scaler, autocast_ctx, loss_fn, args.grad_accum)
        va_loss, va_acc = evaluate(model, dl_va, device, loss_fn)
        print(f"epoch {epoch+1:02d}/{args.epochs} | train {tr_loss:.3f}/{tr_acc:.3f} | val {va_loss:.3f}/{va_acc:.3f}")
        torch.save(model.state_dict(), args.out.replace(".pth", ".last.pth"))
        if va_acc > best + 1e-4:
            best = va_acc; patience = 0
            torch.save(model.state_dict(), args.out)
            print(f"✅ Saved best to {args.out} (val_acc={best:.3f})")
        else:
            patience += 1
            if args.early_stop > 0 and patience >= args.early_stop:
                print(f"⏹️ Early stopping at epoch {epoch+1} (best={best:.3f})")
                break
    print("Done.")
if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Train Wav2Vec2-based AI Voice Detector (balanced)")
    ap.add_argument("--data_dir", required=True, help="Folder with human/ and ai/ WAVs")
    ap.add_argument("--out", default="app/models/weights/wav2vec2_classifier.pth")
    ap.add_argument("--encoder", default="facebook/wav2vec2-base")
    ap.add_argument("--unfreeze_last", type=int, default=0)
    ap.add_argument("--epochs", type=int, default=8)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--grad_accum", type=int, default=2)
    ap.add_argument("--lr_head", type=float, default=1e-3)
    ap.add_argument("--lr_encoder", type=float, default=1e-5)
    ap.add_argument("--val_ratio", type=float, default=0.15)
    ap.add_argument("--clip_seconds", type=float, default=3.0)
    ap.add_argument("--workers", type=int, default=-1, help="-1 = auto (0 on Windows)")
    # AMP is on by default; BooleanOptionalAction (Python 3.9+) adds --no-amp so it
    # can actually be disabled (a plain store_true with default=True cannot be).
    ap.add_argument("--amp", action=argparse.BooleanOptionalAction, default=True)
    ap.add_argument("--cpu", action="store_true")
    ap.add_argument("--dropout", type=float, default=0.2)
    ap.add_argument("--hidden", type=int, default=256)
    ap.add_argument("--early_stop", type=int, default=0, help="0 disables early stopping")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    main(args)
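# Example invocation, assuming app/ is a package and this is run from the repo root
# (dataset path and flag values are illustrative):
#   python -m app.train_wav2vec --data_dir data/voices --epochs 10 --unfreeze_last 2 --early_stop 3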