import os
import math
import sys
import time
from datetime import datetime as dt
from itertools import chain

import torch
from torch.utils.data import DataLoader, Dataset
from termcolor import colored
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from accelerate import Accelerator

sys.path.append("..")

from params import *
from utils.logger import get_logger
from models.model import ModelWrapper
from models.sampler import RandomBatchSampler, BucketBatchSampler
from utils.metrics import get_metric_for_tfm
from dataset.autocorrect_dataset import SpellCorrectDataset
from dataset.util import load_epoch_dataset
class Trainer:
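    """Training loop for the spell-correction model.

    Handles lazy per-epoch dataset loading, batch sampling (random or
    bucketed), periodic validation, and mid-epoch checkpointing with
    resume support.
    """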
    def __init__(self, model_wrapper: ModelWrapper, data_path, dataset_name, valid_dataset: Dataset):
        self.model_wrapper = model_wrapper
        self.model = model_wrapper.model
        self.model_name = model_wrapper.model_name
        self.data_path = data_path
        self.incorrect_file = f'{dataset_name}.train.noise'
        self.correct_file = f'{dataset_name}.train'
        self.length_file = f'{dataset_name}.length.train'

        # Load only the first epoch's shard; later epochs are loaded lazily.
        train_dataset = load_epoch_dataset(data_path, self.correct_file,
                                           self.incorrect_file, self.length_file, 1, EPOCHS)
        train_dataset = SpellCorrectDataset(dataset=train_dataset)
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset

        if not BUCKET_SAMPLING:
            self.train_sampler = RandomBatchSampler(train_dataset, TRAIN_BATCH_SIZE)
            self.valid_sampler = RandomBatchSampler(valid_dataset, VALID_BATCH_SIZE, shuffle=False)
        else:
            self.train_sampler = BucketBatchSampler(train_dataset)
            self.valid_sampler = BucketBatchSampler(valid_dataset, shuffle=False)

        self.train_data = DataLoader(dataset=train_dataset, batch_sampler=self.train_sampler,
                                     collate_fn=model_wrapper.collator.collate,
                                     num_workers=2, pin_memory=True)
        self.valid_data = DataLoader(dataset=valid_dataset, batch_sampler=self.valid_sampler,
                                     collate_fn=model_wrapper.collator.collate,
                                     num_workers=2, pin_memory=True)

        # "Steps" are counted in examples rather than batches: self.iter advances
        # by the batch size each step, so the scheduler horizon uses the same unit.
        self.total_training_steps = len(self.train_dataset) * EPOCHS
        # Round the checkpoint interval up to a multiple of PRINT_PER_ITER so
        # checkpoint steps always coincide with logging steps.
        self.checkpoint_cycle = math.ceil(
            (len(self.train_data) * EPOCHS / CHECKPOINT_FREQ) / PRINT_PER_ITER) * PRINT_PER_ITER
        self.print_every = PRINT_PER_ITER
        self.iter = 0
        self.scratch_iter = 0
        self.start_epoch = 1
        self.best_F1 = -1
        self.current_epoch = 1
        self.progress_epoch = None
        self.max_epochs = EPOCHS
        self.learning_rate = MAX_LR

        self.optimizer = AdamW(self.model.parameters(),
                               lr=self.learning_rate,
                               weight_decay=0.01,
                               correct_bias=False)
        self.num_warmup_steps = WARMUP_PERCENT * self.total_training_steps
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=self.num_warmup_steps,
            num_training_steps=self.total_training_steps)
        self.train_losses = []
        self.accelerator = Accelerator(cpu=(DEVICE == "cpu"))
        self.device = self.accelerator.device
        self.total_fw_time = 0

        log_path = LOG + f'/pytorch.{self.model_name}.lr.{self.learning_rate}.train.log'
        self.logger = get_logger(log_path)
        self.logger.log(f'DEVICE is: {self.device}')
        self.logger.log(f"============TOTAL TRAINING STEPS===========\n{self.total_training_steps}")
        self.logger.log(f"CHECKPOINT CYCLE: {self.checkpoint_cycle} ITER")
    def load_lazy_dataset(self, epoch):
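        """Load the dataset shard for `epoch` and rebuild the sampler and dataloader."""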
        train_dataset = load_epoch_dataset(self.data_path, self.correct_file,
                                           self.incorrect_file, self.length_file, epoch, EPOCHS)
        self.train_dataset = SpellCorrectDataset(dataset=train_dataset)
        if not BUCKET_SAMPLING:
            self.train_sampler = RandomBatchSampler(self.train_dataset, TRAIN_BATCH_SIZE)
        else:
            self.train_sampler = BucketBatchSampler(self.train_dataset)
        self.train_data = DataLoader(dataset=self.train_dataset, batch_sampler=self.train_sampler,
                                     collate_fn=self.model_wrapper.collator.collate,
                                     num_workers=2, pin_memory=True)
    def step(self, batch, training=True):
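        """Run one forward (and, in training mode, backward/optimizer) pass on `batch`.

        Returns the batch loss as a numpy value in training mode; in eval mode
        returns (loss, predictions, label ids, sequence lengths).
        """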
        if training:
            self.model.train()
            start = time.time()
            # The wrapped model returns a dict with at least 'loss' and 'preds'.
            outputs = self.model(batch['batch_src'], batch['attn_masks'], batch['batch_tgt'])
            self.total_fw_time += time.time() - start
            loss = outputs['loss']
            batch_loss = outputs['loss'].cpu().detach().numpy()
            self.optimizer.zero_grad()
            self.accelerator.backward(loss)
            # Gradient clipping is not built into AdamW anymore, so clip here
            # (this also plays well with amp).
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            # self.iter counts examples, matching the scheduler's step unit.
            self.scheduler.step(self.iter)
            return batch_loss
        else:
            self.model.eval()
            outputs = self.model(batch['batch_src'], batch['attn_masks'], batch['batch_tgt'])
            return outputs['loss'], outputs['preds'], \
                batch['batch_tgt'].cpu().detach().numpy(), batch['lengths']
    def train(self):
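        """Main training loop: lazy per-epoch loading, periodic logging,
        checkpoint-cycle validation, and early stopping on PATIENCE."""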
| self.logger.log("Loading model to device") | |
| self.model, self.optimizer, self.scheduler = self.accelerator.prepare( | |
| self.model, self.optimizer, self.scheduler) | |
| self.logger.log(f"Begin training from epoch: {self.start_epoch}") | |
| total_time = 0 | |
| total_loss = 0 | |
| overall_loss, overall_iter = 0, 0 | |
| patience = 0 | |
| for epoch_id in range(self.start_epoch, self.max_epochs + 1): | |
| self.current_epoch = epoch_id | |
| if self.progress_epoch and self.progress_epoch == epoch_id: | |
| self.progress_epoch = None | |
| elif self.current_epoch != 1: | |
| self.load_lazy_dataset(epoch_id) | |
| self.logger.log(f"Loaded lazy dataset {epoch_id} / {self.max_epochs}") | |
| else: | |
| pass | |
| self.logger.log(f"START OF EPOCH {epoch_id}") | |
| for step, batch in enumerate(self.train_data): | |
| start = time.time() | |
| self.iter += batch['batch_tgt'].size(0) | |
| self.scratch_iter += batch['batch_tgt'].size(0) | |
| overall_iter += batch['batch_tgt'].size(0) | |
| batch_loss = self.step(batch) | |
| total_time += time.time() - start | |
| total_loss += batch_loss | |
| overall_loss += batch_loss | |
| if step % self.print_every == 0: | |
| info = '{} - epoch: {} - step: {} - iter: {:08d}/{:08d} - train loss: {:.5f} - lr: {:.5e} - {} time: {:.2f}s'.format( | |
| colored(str(dt.now()),"green"), | |
| epoch_id, | |
| step, | |
| self.iter, | |
| self.total_training_steps, | |
| total_loss / self.print_every, | |
| self.optimizer.param_groups[0]['lr'], | |
| self.device, | |
| total_time) | |
| total_loss = 0 | |
| total_time = 0 | |
| self.logger.log(info) | |
| if step % self.checkpoint_cycle == 0: | |
| torch.cuda.empty_cache() | |
| if step == 0: | |
| continue | |
| # <---- validate -----> | |
| val_loss, val_accu, val_mean_time = self.validate() | |
| info = '{} - epoch: {} - valid loss: {:.5f} - valid accuracy: {:.4f}'.format( | |
| colored(str(dt.now()),"green"), epoch_id, val_loss, val_accu) | |
| self.logger.log(info) | |
| if overall_iter != 0 and overall_loss != 0: | |
| self.logger.log(f"Overall trainning loss between two checkpoints: {overall_loss / overall_iter}") | |
| overall_loss, overall_iter = 0, 0 | |
| if val_accu > self.best_F1: | |
| self.best_F1 = val_accu | |
| info = 'Saving weights to disk......' | |
| self.logger.log(info) | |
| self.save_weights(self.checkpoint_dir, epoch_id, self.best_F1) | |
| info = 'Saving checkpoint to disk......' | |
| self.logger.log(info) | |
| self.save_checkpoint( | |
| self.checkpoint_dir, epoch_id, self.best_F1) | |
| patience = 0 | |
| else: | |
| patience += 1 | |
| self.logger.log("Mean forward time: {:.5f}".format( | |
| self.total_fw_time / VALID_BATCH_SIZE)) | |
| self.total_fw_time = 0 | |
| if patience >= PATIENCE: | |
| break | |
| torch.cuda.empty_cache() | |
| ## Validation before next epoch | |
| torch.cuda.empty_cache() | |
| val_loss, val_accu, val_mean_time = self.validate() | |
| info = '{} - epoch: {} - valid loss: {:.5f} - valid accuracy: {:.4f}'.format( | |
| colored(str(dt.now()),"green"), epoch_id, val_loss, val_accu) | |
| self.logger.log(info) | |
| if overall_iter != 0 and overall_loss != 0: | |
| self.logger.log(f"Overall trainning loss between two checkpoints: {overall_loss / overall_iter}") | |
| overall_loss, overall_iter = 0, 0 | |
| if val_accu > self.best_F1: | |
| self.best_F1 = val_accu | |
| info = 'Saving weights to disk......' | |
| self.logger.log(info) | |
| self.save_weights(self.checkpoint_dir, epoch_id, self.best_F1) | |
| info = 'Saving checkpoint to disk......' | |
| self.logger.log(info) | |
| self.save_checkpoint( | |
| self.checkpoint_dir, epoch_id, self.best_F1) | |
| patience = 0 | |
| else: | |
| patience += 1 | |
| self.logger.log("Mean forward time: {:.5f}".format( | |
| self.total_fw_time / VALID_BATCH_SIZE)) | |
| self.total_fw_time = 0 | |
| if patience >= PATIENCE: | |
| break | |
| torch.cuda.empty_cache() | |
| self.scratch_iter = 0 | |
| self.logger.log(f"END OF EPOCH {epoch_id}") | |
| self.logger.log("Train complete!") | |
    def validate(self):
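        """Run the full validation set; return (mean loss, accuracy, mean time per example)."""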
        total_loss = 0
        valid_loss = 0
        valid_time = 0
        total_time = 0
        total_examples = 0
        num_correct, num_wrong = 0, 0
        with torch.no_grad():
            for step, batch in enumerate(self.valid_data):
                start = time.time()
                total_examples += batch['batch_tgt'].size(0)
                batch_loss, batch_predictions, \
                    batch_label_ids, batch_lengths = self.step(
                        batch, training=False)
                valid_time += time.time() - start
                _num_correct, _num_wrong = get_metric_for_tfm(
                    batch_predictions, batch_label_ids, batch_lengths)
                num_correct += _num_correct
                num_wrong += _num_wrong
                valid_loss += batch_loss
                total_loss += batch_loss
                if step % self.print_every == 0:
                    info = '{} Validation - iter: {:08d}/{:08d} - valid loss: {:.5f} - {} time: {:.2f}s'.format(
                        colored(str(dt.now()), "green"),
                        step,
                        len(self.valid_data),
                        valid_loss / self.print_every,
                        self.device,
                        valid_time / self.print_every)
                    valid_loss = 0
                    total_time += valid_time
                    valid_time = 0
                    self.logger.log(info)
                del batch_loss
        avg_loss = total_loss / len(self.valid_data)
        avg_accu = num_correct / (num_correct + num_wrong)
        avg_time = total_time / total_examples
        return avg_loss, avg_accu, avg_time
    def load_checkpoint(self, checkpoint_dir, dataset_name, start_epoch=0):
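        """Restore training state from `{dataset_name}.model.epoch_{start_epoch - 1}.pth`.

        Must be called before train() even on a fresh run: it sets
        self.checkpoint_dir and self.dataset_name, which the save methods use.
        """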
        self.checkpoint_dir = checkpoint_dir
        self.dataset_name = dataset_name
        checkpoint_path = checkpoint_dir + \
            f'/{dataset_name}.model.epoch_{start_epoch - 1}.pth'
        if start_epoch > 0 and os.path.exists(checkpoint_path):
            checkpoint = torch.load(
                checkpoint_path, map_location=torch.device('cpu'))
            assert EPOCHS == checkpoint['num_epochs']
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            # base_lrs lives on the scheduler, not the optimizer; reset it so a
            # changed MAX_LR takes effect after loading the state dict.
            self.scheduler.base_lrs = [MAX_LR]
            self.model.load_state_dict(checkpoint['state_dict'])
            self.iter = checkpoint['iter']
            self.remained_indies = checkpoint['remained_indies']
            self.start_epoch = checkpoint['epoch']
            self.progress_epoch = self.start_epoch
            self.scratch_iter = checkpoint['scratch_iter']
            train_dataset = load_epoch_dataset(self.data_path, self.correct_file,
                                               self.incorrect_file, self.length_file,
                                               self.start_epoch, EPOCHS)
            self.train_dataset = SpellCorrectDataset(dataset=train_dataset)
            if not BUCKET_SAMPLING:
                assert checkpoint['strategy'] == "random_sampling"
                self.train_sampler = RandomBatchSampler(self.train_dataset, TRAIN_BATCH_SIZE)
                self.train_sampler.load_checkpoints(self.scratch_iter)
            else:
                assert checkpoint['strategy'] == "bucket_sampling"
                self.train_sampler = BucketBatchSampler(self.train_dataset)
                self.train_sampler.load_checkpoints(self.remained_indies)
            self.train_data = DataLoader(dataset=self.train_dataset, batch_sampler=self.train_sampler,
                                         collate_fn=self.model_wrapper.collator.collate,
                                         num_workers=2, pin_memory=True)
            self.best_F1 = checkpoint['best_F1']
    def save_checkpoint(self, checkpoint_dir, epoch, best_F1):
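        """Serialize full training state (model, optimizer, scheduler, sampler
        position) so training can resume mid-epoch."""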
        checkpoint_path = checkpoint_dir + \
            f'/{self.dataset_name}.model.epoch_{epoch}.pth'
        # Record which sample indices are still unvisited in this epoch so the
        # sampler can resume from the same position.
        flatten_iterator_indies = list(chain.from_iterable(self.train_sampler.seq))
        remained_indies = flatten_iterator_indies[self.scratch_iter:]
        self.logger.log(f"Traversed iter from beginning: {self.scratch_iter}")
        state = {
            'epoch': epoch,
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'scratch_iter': self.scratch_iter,
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'best_F1': best_F1,
            'remained_indies': remained_indies,
            'strategy': 'bucket_sampling' if BUCKET_SAMPLING else 'random_sampling',
            'num_epochs': EPOCHS
        }
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.logger.log(f'Saving model checkpoint to: {checkpoint_path}')
        torch.save(state, checkpoint_path)
    def save_weights(self, checkpoint_dir, epoch, best_F1):
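        """Save model weights only (no optimizer/scheduler state), for inference."""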
        weight_path = checkpoint_dir + \
            f'/{self.dataset_name}.weights.pth'
        os.makedirs(checkpoint_dir, exist_ok=True)
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'best_F1': best_F1
        }
        self.logger.log(f'Saving model weights to: {weight_path}')
        torch.save(state, weight_path)
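

# ---------------------------------------------------------------------------
# Usage sketch (hedged): the ModelWrapper constructor arguments, data paths,
# and dataset name below are hypothetical placeholders, not this repo's actual
# entry-point code; adapt them to how the project builds its model and splits.
# Note that load_checkpoint() must be called before train() even on a fresh
# run, because it sets self.checkpoint_dir and self.dataset_name, which
# save_weights()/save_checkpoint() rely on.
#
# if __name__ == '__main__':
#     model_wrapper = ModelWrapper()                      # hypothetical: pass the repo's model config
#     valid_dataset = SpellCorrectDataset(dataset=...)    # hypothetical: a preloaded validation split
#     trainer = Trainer(model_wrapper, data_path='data',  # hypothetical paths/names
#                       dataset_name='corpus', valid_dataset=valid_dataset)
#     trainer.load_checkpoint('checkpoints', 'corpus', start_epoch=0)
#     trainer.train()
# ---------------------------------------------------------------------------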