Spaces:
Runtime error
Runtime error
from typing import Optional

from lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader
class Datamodule(LightningDataModule):
    """Lightning data module built around two ready-made datasets.

    ``train_dataset`` feeds the training loader; ``eval_dataset`` is shared
    by the validation, prediction and test loaders.
    """

    def __init__(
        self,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        batch_train_size: int,
        num_workers: int,
        eval_batch_size: Optional[int] = None,
    ):
        """Store the datasets and the loader settings.

        Args:
            train_dataset: Dataset served by ``train_dataloader``.
            eval_dataset: Dataset served by the val/predict/test loaders.
            batch_train_size: Batch size of the training loader.
            num_workers: Worker count passed to every ``DataLoader``.
            eval_batch_size: Batch size of the val/predict/test loaders.
                Falls back to ``batch_train_size`` when ``None``.
        """
        super().__init__()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.batch_train_size = batch_train_size
        self.eval_batch_size = (
            batch_train_size if eval_batch_size is None else eval_batch_size
        )
        self.num_workers = num_workers

    def _make_loader(
        self,
        dataset: Dataset,
        batch_size: int,
        *,
        pin_and_persist: bool,
    ) -> DataLoader:
        """Build a ``DataLoader`` over *dataset* with the shared settings.

        Args:
            dataset: Dataset to wrap.
            batch_size: Number of samples per batch.
            pin_and_persist: When true, enable ``pin_memory`` and, if there
                is at least one worker, ``persistent_workers``. When false
                both stay ``False``, which is the ``DataLoader`` default.

        Returns:
            The configured ``DataLoader``.
        """
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=self.num_workers,
            pin_memory=pin_and_persist,
            # DataLoader rejects persistent_workers=True with zero workers,
            # so only switch it on when worker processes actually exist.
            persistent_workers=pin_and_persist and self.num_workers > 0,
        )

    def train_dataloader(self) -> DataLoader:
        """Load train set loader."""
        return self._make_loader(
            self.train_dataset,
            self.batch_train_size,
            pin_and_persist=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Load val set loader."""
        return self._make_loader(
            self.eval_dataset,
            self.eval_batch_size,
            pin_and_persist=True,
        )

    def predict_dataloader(self) -> DataLoader:
        """Load predict set loader."""
        # NOTE(review): the original left pin_memory/persistent_workers at
        # their defaults for this loader; kept as-is — confirm if intended.
        return self._make_loader(
            self.eval_dataset,
            self.eval_batch_size,
            pin_and_persist=False,
        )

    def test_dataloader(self) -> DataLoader:
        """Load test set loader."""
        # NOTE(review): the original left pin_memory/persistent_workers at
        # their defaults for this loader; kept as-is — confirm if intended.
        return self._make_loader(
            self.eval_dataset,
            self.eval_batch_size,
            pin_and_persist=False,
        )