| from .dataset import Dataset, ValDataset, TestDataset | |
| from torch.utils.data import DataLoader | |
def find_dataset_using_name(name):
    """Return the dataset class registered under ``name``.

    Recognised names are ``"Video"``, ``"VideoVal"`` and ``"VideoTest"``,
    which resolve to ``Dataset``, ``ValDataset`` and ``TestDataset``.

    Args:
        name: Registry key identifying the dataset class.

    Returns:
        The dataset class (not an instance) registered for ``name``.

    Raises:
        ValueError: If ``name`` is not one of the recognised keys.
    """
    registry = {
        "Video": Dataset,
        "VideoVal": ValDataset,
        "VideoTest": TestDataset,
    }
    # Guard clause: reject unknown keys before indexing the registry.
    if name not in registry:
        raise ValueError(f"Fail to find dataset {name}")
    return registry[name]
def create_dataset(metainfo, split):
    """Instantiate the dataset selected by ``split.type`` and wrap it in a loader.

    Despite the name, the return value is a ``DataLoader`` around the dataset,
    not the dataset itself.

    Args:
        metainfo: Passed unchanged as the first argument to the dataset class.
        split: Configuration object. Its ``type`` attribute selects the dataset
            class; ``batch_size``, ``drop_last``, ``shuffle`` and ``worker``
            configure the loader. It is also passed as the second argument to
            the dataset class.

    Returns:
        A ``torch.utils.data.DataLoader`` over the newly built dataset.

    Raises:
        ValueError: Propagated from ``find_dataset_using_name`` when
            ``split.type`` is not a registered dataset name.
    """
    data = find_dataset_using_name(split.type)(metainfo, split)
    loader_options = {
        "batch_size": split.batch_size,
        "drop_last": split.drop_last,
        "shuffle": split.shuffle,
        # The config attribute is ``worker``; DataLoader calls it ``num_workers``.
        "num_workers": split.worker,
        # Pinned memory is always enabled; it is not read from ``split``.
        "pin_memory": True,
    }
    return DataLoader(data, **loader_options)