Spaces: Runtime error
| import os | |
| from torchvision.datasets import VisionDataset | |
| from PIL import Image | |
| from sklearn.model_selection import train_test_split | |
class CustomDataset(VisionDataset):
    """Class-per-folder image dataset with a seeded, stratified train/val/test split.

    ``root_path`` must contain one sub-directory per class. Every file found
    (recursively) below a class directory becomes a ``(path, class_index)``
    sample; class indices follow the sorted order of the directory names.
    The complete sample list is split twice with
    ``sklearn.model_selection.train_test_split`` (stratified by label, seeded
    with ``seed``) and only the part named by ``subset`` is kept in
    ``self.samples``.

    Args:
        root_path: Dataset root directory; a leading ``~`` is expanded.
        subset: Split to expose: ``"train"``, ``"val"`` or ``"test"``.
        transform: Optional callable applied to each loaded RGB PIL image.
        target_transform: Optional callable applied to each integer label.
        split_ratios: ``(train, val, test)`` fractions. ``train`` is the share
            of all samples used for training; the remaining samples are
            divided between val and test in the proportion ``val : test``.
        seed: ``random_state`` passed to both ``train_test_split`` calls.
        is_valid_file: Optional predicate ``path -> bool``; files for which it
            returns False are not used as samples. ``None`` (the default)
            keeps every file, so non-image files such as ``.DS_Store`` are
            included and will fail to load in ``__getitem__``.

    Raises:
        ValueError: If ``subset`` or ``split_ratios`` is invalid, or if no
            samples are found under ``root_path``. sklearn may also raise
            ValueError when a class is too small to stratify.
    """

    _SUBSETS = ("train", "val", "test")

    def __init__(self, root_path, subset="train", transform=None, target_transform=None,
                 split_ratios=(0.7, 0.15, 0.15), seed=42, is_valid_file=None):
        super().__init__(root_path, transform=transform, target_transform=target_transform)
        # VisionDataset expands "~" in ``root``; re-assigning the raw argument
        # would undo that, so keep the expanded form explicitly.
        if isinstance(root_path, (str, os.PathLike)):
            root_path = os.path.expanduser(root_path)
        self.root = root_path
        # Fail fast on bad arguments, before walking the filesystem.
        if subset not in self._SUBSETS:
            raise ValueError(f"Unknown subset: {subset}")
        self._check_split_ratios(split_ratios)
        self.subset = subset  # Can be "train", "val", or "test"
        self.split_ratios = split_ratios
        self.seed = seed
        self.is_valid_file = is_valid_file
        self.classes, self.class_idx = self._find_classes()
        self.samples = self._make_dataset()

    @staticmethod
    def _check_split_ratios(split_ratios):
        """Raise ValueError unless ``split_ratios`` can drive both splits."""
        train_ratio, val_ratio, test_ratio = split_ratios[0], split_ratios[1], split_ratios[2]
        if not 0 < train_ratio < 1:
            raise ValueError(f"split_ratios[0] (train) must be in (0, 1), got {train_ratio}")
        # Both held-out shares must be positive, otherwise the second split
        # divides by zero or asks sklearn for an empty val or test set.
        if val_ratio <= 0 or test_ratio <= 0:
            raise ValueError(
                "split_ratios[1] (val) and split_ratios[2] (test) must be > 0, "
                f"got {val_ratio} and {test_ratio}"
            )

    def _find_classes(self):
        """Return the sorted class-directory names and a ``name -> index`` map."""
        with os.scandir(self.root) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        class_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_idx

    def _collect_samples(self):
        """Return ``(path, class_index)`` for every accepted file, in sorted order."""
        samples = []
        for class_name in sorted(self.class_idx):
            class_index = self.class_idx[class_name]
            class_dir = os.path.join(self.root, class_name)
            # Sorting the walk and the file names makes the sample order (and
            # therefore the seeded split) independent of filesystem ordering.
            for dirpath, _, fnames in sorted(os.walk(class_dir)):
                for fname in sorted(fnames):
                    path = os.path.join(dirpath, fname)
                    if self.is_valid_file is None or self.is_valid_file(path):
                        samples.append((path, class_index))
        return samples

    def _make_dataset(self):
        """Collect all samples, split them, and return the ``self.subset`` part.

        Raises:
            ValueError: If no samples are found or ``self.subset`` is unknown.
        """
        samples = self._collect_samples()
        if not samples:
            raise ValueError(f"No samples found under {self.root!r}")
        # NOTE(review): 1 - 0.7 evaluates to 0.30000000000000004 and
        # train_test_split appears to round a float test_size up, so e.g. 100
        # samples may split 69/31 rather than 70/30. Left unchanged on purpose:
        # altering it would move samples between existing seeded train/test
        # splits (leaking test data into models trained on the old split).
        train_samples, holdout = train_test_split(
            samples,
            test_size=1 - self.split_ratios[0],
            random_state=self.seed,
            stratify=[label for _, label in samples],
        )
        # Divide the held-out samples between val and test as val : test.
        val_samples, test_samples = train_test_split(
            holdout,
            test_size=self.split_ratios[2] / (self.split_ratios[1] + self.split_ratios[2]),
            random_state=self.seed,
            stratify=[label for _, label in holdout],
        )
        splits = {"train": train_samples, "val": val_samples, "test": test_samples}
        if self.subset not in splits:
            raise ValueError(f"Unknown subset: {self.subset}")
        return splits[self.subset]

    def __len__(self):
        """Return the number of samples in the selected subset."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, target)``.

        ``image`` is the file decoded as an RGB PIL image, passed through
        ``self.transform`` when set; ``target`` is the integer class index,
        passed through ``self.target_transform`` when set.
        """
        path, target = self.samples[index]
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so ``img`` stays usable after the file is closed.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        # VisionDataset only stores target_transform; applying it is our job.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target