```python
# -*- coding: utf-8 -*-
from pytorch_lightning.callbacks import Callback


class DatasetCallback(Callback):
    """Saves and restores the training sampler position across checkpoints."""

    def __init__(self):
        self.sampler_pos_start = 0
        self.preload_used_idx_flag = False

    def on_train_start(self, trainer, pl_module):
        # Push the restored position into the batch sampler once, before training starts
        if not self.preload_used_idx_flag:
            self.preload_used_idx_flag = True
            trainer.train_dataloader.batch_sampler.sampler_pos_reload = self.sampler_pos_start

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Save the current sampler position in the checkpoint
        if trainer.train_dataloader is not None:
            checkpoint['sampler_pos_start'] = trainer.train_dataloader.batch_sampler.sampler_pos_start

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        # Restore the sampler position from the checkpoint
        if 'sampler_pos_start' in checkpoint:
            self.sampler_pos_start = checkpoint.get('sampler_pos_start', 0)
            print('Loaded sampler_pos_start from checkpoint: sampler_pos_start = %d' % self.sampler_pos_start)
        else:
            print('sampler_pos_start not found in checkpoint')
```
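The callback assumes the train dataloader's batch sampler exposes two custom attributes: `sampler_pos_start` (how far into the data the sampler has advanced, saved at checkpoint time) and `sampler_pos_reload` (the position to resume from, injected at train start). Those attributes are not part of PyTorch; below is a minimal sketch, not from the original, of a batch sampler that provides them, plus hypothetical Trainer wiring. The class name `ResumableBatchSampler` and the sequential (non-shuffling) iteration are assumptions for illustration only.

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Sampler


class ResumableBatchSampler(Sampler):
    """Sketch of a batch sampler compatible with DatasetCallback (assumed design)."""

    def __init__(self, dataset_len, batch_size):
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.sampler_pos_start = 0   # current position, saved by the callback
        self.sampler_pos_reload = 0  # resume position, set by the callback on train start

    def __iter__(self):
        # Skip the samples that were already consumed before the checkpoint,
        # then clear the reload marker so later epochs start from the beginning.
        start = self.sampler_pos_reload
        self.sampler_pos_reload = 0
        for pos in range(start, self.dataset_len, self.batch_size):
            self.sampler_pos_start = pos
            yield list(range(pos, min(pos + self.batch_size, self.dataset_len)))

    def __len__(self):
        return (self.dataset_len + self.batch_size - 1) // self.batch_size


# Hypothetical wiring (train_dataset and model are assumed to exist):
# batch_sampler = ResumableBatchSampler(len(train_dataset), batch_size=32)
# train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)
# trainer = pl.Trainer(callbacks=[DatasetCallback()], max_epochs=10)
# trainer.fit(model, train_dataloaders=train_loader, ckpt_path='last.ckpt')
```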