from copy import deepcopy as dp
from pathlib import Path

from torch.utils.data import Dataset


class MultimodalDataset(Dataset):
    """Dataset that pairs trajectory samples with additional conditioning modalities."""

    def __init__(
        self,
        name,
        dataset_name,
        dataset_dir,
        trajectory,
        feature_type,
        num_rawfeats,
        num_feats,
        num_cams,
        num_cond_feats,
        standardization,
        augmentation=None,
        **modalities,
    ):
        self.dataset_dir = Path(dataset_dir)
        self.name = name
        self.dataset_name = dataset_name
        self.feature_type = feature_type
        self.num_rawfeats = num_rawfeats
        self.num_feats = num_feats
        self.num_cams = num_cams
        self.trajectory_dataset = trajectory
        self.standardization = standardization
        self.modality_datasets = modalities

        if augmentation is not None:
            self.augmentation = True
            self.augmentation_rate = augmentation.rate
            self.trajectory_dataset.set_augmentation(augmentation.trajectory)
            if hasattr(augmentation, "modalities"):
                # Assumes `augmentation.modalities` maps modality names to their
                # augmentation configs.
                for modality, augments in augmentation.modalities.items():
                    self.modality_datasets[modality].set_augmentation(augments)
        else:
            self.augmentation = False
    # --------------------------------------------------------------------------------- #

    def set_split(self, split: str, train_rate: float = 1.0):
        """Select the split and share its filenames with every modality dataset."""
        self.split = split

        # Get trajectory split
        self.trajectory_dataset = dp(self.trajectory_dataset).set_split(
            split, train_rate
        )
        self.root_filenames = self.trajectory_dataset.filenames

        # Get modality split
        for modality_name in self.modality_datasets.keys():
            self.modality_datasets[modality_name].filenames = self.root_filenames

        self.get_feature = self.trajectory_dataset.get_feature
        self.get_matrix = self.trajectory_dataset.get_matrix

        return self
    # --------------------------------------------------------------------------------- #

    def __getitem__(self, index):
        traj_out = self.trajectory_dataset[index]
        traj_filename, traj_feature, padding_mask, intrinsics = traj_out
        out = {
            "traj_filename": traj_filename,
            "traj_feat": traj_feature,
            "padding_mask": padding_mask,
            "intrinsics": intrinsics,
        }

        for modality_name, modality_dataset in self.modality_datasets.items():
            modality_filename, modality_feature, modality_raw = modality_dataset[index]
            # Trajectory and modality samples must refer to the same clip (same stem).
            assert traj_filename.split(".")[0] == modality_filename.split(".")[0]
            out[f"{modality_name}_filename"] = modality_filename
            out[f"{modality_name}_feat"] = modality_feature
            out[f"{modality_name}_raw"] = modality_raw
            out[f"{modality_name}_padding_mask"] = padding_mask

        return out

    def __len__(self):
        return len(self.trajectory_dataset)
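

# --------------------------------------------------------------------------------- #
# Usage sketch (not part of the original file): it wires MultimodalDataset to tiny
# in-memory stand-ins for the trajectory and modality datasets, assuming only the
# interface exercised above. Trajectory items are (filename, feature, padding_mask,
# intrinsics) and the trajectory dataset exposes `set_split`, `set_augmentation`,
# `filenames`, `get_feature`, `get_matrix`; modality items are (filename, feature,
# raw). All names, paths, and shapes below are illustrative placeholders.
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader

    class _StubTrajectoryDataset(Dataset):
        def __init__(self):
            self.filenames = [f"clip_{i:03d}.npy" for i in range(4)]

        def set_split(self, split, train_rate=1.0):
            return self  # this stub does no real splitting

        def set_augmentation(self, augments):
            pass

        def get_feature(self, x):
            return x

        def get_matrix(self, x):
            return x

        def __getitem__(self, index):
            feature = torch.zeros(9, 30)  # (num_feats, num_frames), placeholder
            padding_mask = torch.ones(30)
            intrinsics = torch.eye(3)
            return self.filenames[index], feature, padding_mask, intrinsics

        def __len__(self):
            return len(self.filenames)

    class _StubCaptionDataset(Dataset):
        def __init__(self):
            self.filenames = []  # filled in by MultimodalDataset.set_split

        def set_augmentation(self, augments):
            pass

        def __getitem__(self, index):
            feature = torch.zeros(512)  # e.g. a text-embedding placeholder
            return self.filenames[index], feature, "a static wide shot"

        def __len__(self):
            return len(self.filenames)

    dataset = MultimodalDataset(
        name="demo",
        dataset_name="demo-dataset",
        dataset_dir="data/demo",
        trajectory=_StubTrajectoryDataset(),
        feature_type="relative",
        num_rawfeats=12,
        num_feats=9,
        num_cams=1,
        num_cond_feats=512,
        standardization=None,
        caption=_StubCaptionDataset(),  # gathered into **modalities
    ).set_split("train")

    batch = next(iter(DataLoader(dataset, batch_size=2)))
    # Expected keys: traj_filename, traj_feat, padding_mask, intrinsics,
    # caption_filename, caption_feat, caption_raw, caption_padding_mask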