| from __future__ import annotations | |
| import os | |
| from collections import defaultdict | |
| from collections.abc import Callable | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import wandb | |
| from torch.cuda.amp import GradScaler, autocast | |
| from tqdm import tqdm | |
| from utmosv2.utils import calc_metrics, print_metrics | |
def _train_1epoch(
    cfg,
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    device: torch.device,
) -> dict[str, float]:
    """Train ``model`` for one epoch under mixed precision.

    Supports optional mixup (``cfg.run.mixup``) and either a single criterion
    or a weighted list of criteria (``cfg.loss`` as a list, with ``criterion``
    returning ``[(weight, loss_tensor), ...]``).

    Returns a dict of per-epoch mean losses keyed by ``"loss"`` plus, for a
    list criterion, one entry per ``cfg.loss`` component name.
    """
    model.train()
    epoch_loss = defaultdict(float)
    scaler = GradScaler()
    print(f" (lr: {scheduler.get_last_lr()[0]:.6f})")
    progress = tqdm(train_dataloader, total=len(train_dataloader))
    for step, batch in enumerate(progress):
        # Batch layout: all leading tensors are model inputs, the last is the target.
        inputs, target = batch[:-1], batch[-1]
        inputs = [inp.to(device, non_blocking=True) for inp in inputs]
        target = target.to(device, non_blocking=True)
        if cfg.run.mixup:
            # Sample the mixing coefficient and a batch permutation once per step.
            lam = np.random.beta(cfg.run.mixup_alpha, cfg.run.mixup_alpha)
            perm = torch.randperm(inputs[0].shape[0]).to(device)
            inputs_perm = [inp[perm, :] for inp in inputs]
            target_perm = target[perm]
        optimizer.zero_grad()
        with autocast():
            if cfg.run.mixup:
                mixed = [lam * a + (1 - lam) * b for a, b in zip(inputs, inputs_perm)]
                output = model(*mixed).squeeze(1)
                if isinstance(cfg.loss, list):
                    # Mix each weighted loss component against both targets.
                    loss = [
                        (w, lam * la + (1 - lam) * lb)
                        for (w, la), (_, lb) in zip(
                            criterion(output, target), criterion(output, target_perm)
                        )
                    ]
                else:
                    loss = lam * criterion(output, target) + (1 - lam) * criterion(
                        output, target_perm
                    )
            else:
                output = model(*inputs).squeeze(1)
                loss = criterion(output, target)
        loss_total = sum(w * ls for w, ls in loss) if isinstance(loss, list) else loss
        scaler.scale(loss_total).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # per-batch LR schedule
        epoch_loss["loss"] += loss_total.detach().float().cpu().item()
        if isinstance(loss, list):
            # Track each loss component separately for reporting.
            for (cl, _), (_, ls) in zip(cfg.loss, loss):
                epoch_loss[cl.name] += ls.detach().float().cpu().item()
        progress.set_description(
            f' loss: {epoch_loss["loss"] / (step + 1):.4f}'
            + (
                f' ({", ".join([f"{cl.name}: {epoch_loss[cl.name] / (step + 1):.4f}" for cl, _ in cfg.loss])})'
                if isinstance(loss, list)
                else ""
            )
        )
    return {name: v / len(train_dataloader) for name, v in epoch_loss.items()}
def _validate_1epoch(
    cfg,
    model: torch.nn.Module,
    valid_dataloader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    metrics: dict[str, Callable[[np.ndarray, np.ndarray], float]],
    device: torch.device,
) -> tuple[dict[str, float], dict[str, float], np.ndarray]:
    """Evaluate ``model`` over the validation loader for one epoch.

    Returns ``(mean_losses, mean_metrics, predictions)`` where the dicts hold
    per-batch averages and ``predictions`` is the concatenated model output
    for the whole validation set.
    """
    model.eval()
    running_loss = defaultdict(float)
    running_metrics = {name: 0.0 for name in metrics}
    pred_chunks = []
    progress = tqdm(valid_dataloader, total=len(valid_dataloader))
    with torch.no_grad():
        for step, batch in enumerate(progress):
            inputs, target = batch[:-1], batch[-1]
            inputs = [inp.to(device, non_blocking=True) for inp in inputs]
            # Keep a CPU copy of the target for numpy-based metric functions.
            target_cpu = target
            target = target.to(device, non_blocking=True)
            with autocast():
                output = model(*inputs).squeeze(1)
                loss = criterion(output, target)
            loss_total = (
                sum(w * ls for w, ls in loss) if isinstance(loss, list) else loss
            )
            running_loss["loss"] += loss_total.detach().float().cpu().item()
            if isinstance(loss, list):
                for (cl, _), (_, ls) in zip(cfg.loss, loss):
                    running_loss[cl.name] += ls.detach().float().cpu().item()
            output = output.cpu().numpy()
            for name, metric in metrics.items():
                running_metrics[name] += metric(output, target_cpu.numpy())
            progress.set_description(
                f' val_loss: {running_loss["loss"] / (step + 1):.4f} '
                + (
                    f'({", ".join([f"{cl.name}: {running_loss[cl.name] / (step + 1):.4f}" for cl, _ in cfg.loss])}) '
                    if isinstance(loss, list)
                    else ""
                )
                + " - ".join(
                    [
                        f"val_{name}: {v / (step + 1):.4f}"
                        for name, v in running_metrics.items()
                    ]
                )
            )
            pred_chunks.append(output)
    n_batches = len(valid_dataloader)
    mean_loss = {name: v / n_batches for name, v in running_loss.items()}
    mean_metrics = {name: v / n_batches for name, v in running_metrics.items()}
    return mean_loss, mean_metrics, np.concatenate(pred_chunks)
def run_train(
    cfg,
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    valid_dataloader: torch.utils.data.DataLoader,
    valid_data: pd.DataFrame,
    oof_preds: np.ndarray,
    now_fold: int,
    criterion: torch.nn.Module,
    metrics: dict[str, Callable[[np.ndarray, np.ndarray], float]],
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    device: torch.device,
) -> None:
    """Run the full training loop for one fold.

    For each epoch: train, validate, compute validation metrics according to
    ``cfg.validation_dataset`` ("each" = average of per-dataset metrics,
    "all" = metrics over the whole validation set, otherwise metrics over the
    single named dataset), checkpoint the best model by ``cfg.main_metric``,
    and always save a last-epoch checkpoint. Out-of-fold predictions are
    written into ``oof_preds`` (mutated in place) whenever the best metric
    improves. Optionally logs to wandb.
    """
    best_metric = 0.0
    os.makedirs(cfg.save_path, exist_ok=True)
    for epoch in range(cfg.run.num_epochs):
        print(f"[Epoch {epoch + 1}/{cfg.run.num_epochs}]")
        train_loss = _train_1epoch(
            cfg, model, train_dataloader, criterion, optimizer, scheduler, device
        )
        valid_loss, _, valid_preds = _validate_1epoch(
            cfg, model, valid_dataloader, criterion, metrics, device
        )
        print(f"Validation dataset: {cfg.validation_dataset}")
        if cfg.validation_dataset == "each":
            # Average the metrics computed independently on each source dataset.
            dataset = valid_data["dataset"].unique()
            val_metrics = [
                calc_metrics(
                    valid_data[valid_data["dataset"] == ds],
                    valid_preds[valid_data["dataset"] == ds],
                )
                for ds in dataset
            ]
            val_metrics = {
                name: sum([m[name] for m in val_metrics]) / len(val_metrics)
                for name in val_metrics[0].keys()
            }
        # BUG FIX: this was a second `if`, so the "each" result above fell
        # through into the `else` branch and was overwritten by metrics
        # filtered on dataset == "each" (an empty selection).
        elif cfg.validation_dataset == "all":
            print("Validation dataset: ALL")
            val_metrics = calc_metrics(valid_data, valid_preds)
        else:
            # Evaluate on the single named dataset only.
            val_metrics = calc_metrics(
                valid_data[valid_data["dataset"] == cfg.validation_dataset],
                valid_preds[valid_data["dataset"] == cfg.validation_dataset],
            )
        print_metrics(val_metrics)
        if val_metrics[cfg.main_metric] > best_metric:
            new_metric = val_metrics[cfg.main_metric]
            print(f"(Found best metric: {best_metric:.4f} -> {new_metric:.4f})")
            best_metric = new_metric
            save_path = (
                cfg.save_path / f"fold{now_fold}_s{cfg.split.seed}_best_model.pth"
            )
            torch.save(model.state_dict(), save_path)
            print(f"Save best model: {save_path}")
            # Only the best epoch's predictions go into the OOF array.
            oof_preds[valid_data.index] = valid_preds
        # The last-epoch checkpoint is refreshed every epoch regardless.
        save_path = cfg.save_path / f"fold{now_fold}_s{cfg.split.seed}_last_model.pth"
        torch.save(model.state_dict(), save_path)
        print()
        val_metrics["train_loss"] = train_loss["loss"]
        val_metrics["val_loss"] = valid_loss["loss"]
        # BUG FIX: per-component losses only exist when cfg.loss is a list
        # (matching the isinstance checks in the train/validate loops);
        # iterating a single loss config here raised / produced KeyErrors.
        if isinstance(cfg.loss, list):
            for cl, _ in cfg.loss:
                val_metrics[f"train_loss_{cl.name}"] = train_loss[cl.name]
                val_metrics[f"val_loss_{cl.name}"] = valid_loss[cl.name]
        if cfg.wandb:
            wandb.log(val_metrics)