import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from tqdm import tqdm

# Make the parent package importable when this file is run as a script
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dataset2 import MedicalImageDatasetBalancedIntensity3D, TransformationMedicalImageDatasetBalancedIntensity3D
from model import Backbone, SingleScanModel, Classifier
from utils import BaseConfig


def calculate_metrics(pred_probs, pred_labels, true_labels):
    """
    Compute multi-class classification metrics.

    Args:
        pred_probs (numpy.ndarray): Predicted probabilities per class, shape (N, num_classes)
        pred_labels (numpy.ndarray): Predicted class indices, shape (N,)
        true_labels (numpy.ndarray): Ground-truth class indices, shape (N,)

    Returns:
        dict: Accuracy, weighted precision/recall/F1, and one-vs-rest AUC
    """
    accuracy = accuracy_score(true_labels, pred_labels)
    # zero_division=0 avoids undefined-metric warnings when a class is never predicted
    precision = precision_score(true_labels, pred_labels, average='weighted', zero_division=0)
    recall = recall_score(true_labels, pred_labels, average='weighted', zero_division=0)
    f1 = f1_score(true_labels, pred_labels, average='weighted')
    auc = roc_auc_score(true_labels, pred_probs, multi_class='ovr')
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc
    }
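# Example usage (hypothetical values, for illustration only). Note that
# roc_auc_score(..., multi_class='ovr') raises a ValueError if any class is
# missing from `true_labels`, so the validation split should cover all 4 classes:
#
#   probs = np.array([[0.7, 0.1, 0.1, 0.1],
#                     [0.1, 0.6, 0.2, 0.1],
#                     [0.2, 0.2, 0.5, 0.1],
#                     [0.1, 0.1, 0.2, 0.6]])
#   calculate_metrics(probs, probs.argmax(axis=1), np.array([0, 1, 2, 3]))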
#============================
# TRAINER CLASS
#============================

class SequenceTrainer(BaseConfig):
    """
    Trainer class for sequence classification.
    """

    def __init__(self):
        super().__init__()
        self.setup_wandb()
        self.setup_model()
        self.setup_data()
        self.setup_training()

    def setup_wandb(self):
        config = self.get_config()
        wandb.init(
            project=config['logger']['project_name'],
            name=config['logger']['run_name'],
            config=config
        )

    def setup_model(self):
        self.backbone = Backbone()
        # The classifier outputs 4 logits for 4-way multi-class classification
        self.classifier = Classifier(d_model=2048, num_classes=4)
        self.model = SingleScanModel(self.backbone, self.classifier)

        # Optionally initialize the backbone from a pretrained BrainIAC checkpoint
        config = self.get_config()
        if config["train"]["finetune"] == "yes":
            checkpoint = torch.load(config["train"]["weights"], map_location=self.device)
            state_dict = checkpoint["state_dict"]
            filtered_state_dict = {}
            for key, value in state_dict.items():
                # Checkpoints saved with DataParallel/DDP prefix keys with "module."
                new_key = key.replace("module.", "backbone.") if key.startswith("module.") else key
                filtered_state_dict[new_key] = value
            # strict=False silently skips keys that do not match the backbone's
            # parameter names, so verify the checkpoint layout matches expectations
            self.model.backbone.load_state_dict(filtered_state_dict, strict=False)
            print("Pretrained weights loaded!")

        if config["train"]["freeze"] == "yes":
            for param in self.model.backbone.parameters():
                param.requires_grad = False
            print("Backbone weights frozen!")

        self.model = self.model.to(self.device)

    ## spin up dataloaders
    def setup_data(self):
        config = self.get_config()

        self.train_dataset = TransformationMedicalImageDatasetBalancedIntensity3D(
            csv_path=config['data']['train_csv'],
            root_dir=config["data"]["root_dir"]
        )
        self.val_dataset = MedicalImageDatasetBalancedIntensity3D(
            csv_path=config['data']['val_csv'],
            root_dir=config["data"]["root_dir"]
        )

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config["data"]["batch_size"],
            shuffle=True,
            collate_fn=self.custom_collate,
            num_workers=config["data"]["num_workers"]
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self.custom_collate,
            num_workers=1
        )

    def setup_training(self):
        """
        Loss, optimizer, LR scheduler, and AMP gradient scaler.
        """
        config = self.get_config()

        # Cross-entropy loss for multi-class classification
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config['optim']['lr'],
            weight_decay=config["optim"]["weight_decay"]
        )
        # OneCycleLR is stepped once per batch, hence steps_per_epoch below
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config['optim']['lr'],
            epochs=config['optim']['max_epochs'],
            steps_per_epoch=len(self.train_loader)
        )
        self.scaler = GradScaler()
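    # For reference, a sketch of the config this trainer assumes. The real schema
    # comes from BaseConfig.get_config(); the keys below are inferred from usage in
    # this file and the values are placeholders, not the project's actual settings:
    #
    #   logger:
    #     project_name: sequence-classification
    #     run_name: baseline
    #     save_dir: ./checkpoints
    #     save_name: "epoch{epoch}_loss{loss:.4f}_f1{metric:.4f}.pth"
    #   data:
    #     train_csv: train.csv
    #     val_csv: val.csv
    #     root_dir: /path/to/scans
    #     batch_size: 4
    #     num_workers: 4
    #   train:
    #     finetune: "yes"
    #     weights: /path/to/brainiac_checkpoint.pth
    #     freeze: "no"
    #   optim:
    #     lr: 0.0001
    #     weight_decay: 0.01
    #     max_epochs: 50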
    ## main training loop
    def train(self):
        config = self.get_config()
        max_epochs = config['optim']['max_epochs']
        best_metrics = {
            'val_loss': float('inf'),
            'accuracy': 0,
            'precision': 0,
            'recall': 0,
            'f1': 0,
            'auc': 0
        }

        for epoch in range(max_epochs):
            train_loss = self.train_epoch(epoch, max_epochs)
            val_loss, metrics = self.validate_epoch(epoch, max_epochs)

            # Checkpoint selection is based on validation AUC
            if metrics['auc'] > best_metrics['auc']:
                print("New best model found!")
                print(f"AUC improved from {best_metrics['auc']:.4f} to {metrics['auc']:.4f}")
                print(f"Val Loss: {best_metrics['val_loss']:.4f} -> {val_loss:.4f}")
                print(f"F1 Score: {best_metrics['f1']:.4f} -> {metrics['f1']:.4f}")
                best_metrics.update(metrics)
                best_metrics['val_loss'] = val_loss
                self.save_checkpoint(epoch, val_loss, metrics)

        wandb.finish()

    ## training pass
    def train_epoch(self, epoch, max_epochs):
        self.model.train()
        train_loss = 0.0

        for sample in tqdm(self.train_loader, desc=f"Training Epoch {epoch}/{max_epochs-1}"):
            inputs = sample['image'].to(self.device)
            # Labels are integer class indices; CrossEntropyLoss needs no float() conversion
            labels = sample['label'].to(self.device)

            self.optimizer.zero_grad(set_to_none=True)

            with autocast():
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)  # CrossEntropyLoss expects raw logits

            self.scaler.scale(loss).backward()
            # Unscale before clipping so the norm is computed on true gradients
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()

            train_loss += loss.item() * inputs.size(0)

        train_loss = train_loss / len(self.train_loader.dataset)
        wandb.log({"Train Loss": train_loss})
        return train_loss

    ## validation pass
    def validate_epoch(self, epoch, max_epochs):
        self.model.eval()
        val_loss = 0.0
        all_labels = []
        all_preds = []
        all_probs = []

        with torch.no_grad():
            for sample in tqdm(self.val_loader, desc=f"Validation Epoch {epoch}/{max_epochs-1}"):
                inputs = sample['image'].to(self.device)
                labels = sample['label'].to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)  # CrossEntropyLoss expects raw logits

                # Softmax probabilities and argmax predictions for multi-class metrics
                probs = torch.softmax(outputs, dim=1).cpu().numpy()
                preds = np.argmax(probs, axis=1)

                val_loss += loss.item() * inputs.size(0)
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds)
                all_probs.extend(probs)

        val_loss = val_loss / len(self.val_loader.dataset)
        metrics = calculate_metrics(
            np.array(all_probs),
            np.array(all_preds),
            np.array(all_labels)
        )

        wandb.log({
            "Val Loss": val_loss,
            "Accuracy": metrics['accuracy'],
            "Precision": metrics['precision'],
            "Recall": metrics['recall'],
            "F1 Score": metrics['f1'],
            "AUC": metrics['auc']
        })

        print(f"Epoch {epoch}/{max_epochs-1}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"Accuracy: {metrics['accuracy']:.4f}")
        print(f"Precision: {metrics['precision']:.4f}")
        print(f"Recall: {metrics['recall']:.4f}")
        print(f"F1 Score: {metrics['f1']:.4f}")
        print(f"AUC: {metrics['auc']:.4f}")

        return val_loss, metrics

    ## save checkpoint
    def save_checkpoint(self, epoch, loss, metrics):
        config = self.get_config()
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'metrics': metrics
        }
        save_path = os.path.join(
            config['logger']['save_dir'],
            config['logger']['save_name'].format(epoch=epoch, loss=loss, metric=metrics['f1'])
        )
        torch.save(checkpoint, save_path)


if __name__ == "__main__":
    trainer = SequenceTrainer()
    trainer.train()
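# A minimal sketch for reloading a saved checkpoint for inference (the path is
# hypothetical; it assumes the same model definition as above):
#
#   model = SingleScanModel(Backbone(), Classifier(d_model=2048, num_classes=4))
#   ckpt = torch.load("checkpoints/best.pth", map_location="cpu")
#   model.load_state_dict(ckpt["model_state_dict"])
#   model.eval()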