| import torch | |
| import torchvision.transforms as transforms | |
| import torchvision.datasets as dset | |
class Invert:
    """Callable transform that maps a value ``x`` to ``1 - x``."""

    def __call__(self, x):
        """Return the complement of *x* with respect to 1."""
        inverted = 1 - x
        return inverted
class Gray:
    """Callable transform that keeps only the leading slice ``[0:1]`` of its input."""

    def __call__(self, x):
        """Return ``x[0:1]``: the first entry along the first axis, axis preserved."""
        leading = slice(0, 1)
        return x[leading]
def _resize(size):
    """Return a resize transform for *size*.

    Prefers ``transforms.Resize`` and falls back to the legacy
    ``transforms.Scale`` name so that both old and recent torchvision
    releases work.
    """
    resize_cls = getattr(transforms, 'Resize', None)
    if resize_cls is None:
        resize_cls = transforms.Scale
    return resize_cls(size)


def _image_folder(root, size=None, extra=()):
    """Build an ``ImageFolder`` over *root* with the shared preprocessing.

    Args:
        root: directory handed to ``ImageFolder``.
        size: when not ``None``, a resize with ``size`` followed by a
            ``CenterCrop(size)`` is applied before ``ToTensor``.
        extra: additional transforms applied after ``ToTensor``, in order.
    """
    steps = []
    if size is not None:
        steps.append(_resize(size))
        steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    steps.extend(extra)
    return dset.ImageFolder(root=root, transform=transforms.Compose(steps))


def load_dataset(dataset_name, split='full'):
    """Return the dataset registered under *dataset_name*.

    Args:
        dataset_name: one of ``'mnist'``, ``'coco'``, ``'quickdraw'``,
            ``'shoes'``, ``'footwear'``, ``'celeba'``, ``'birds'``,
            ``'sketchy'`` or ``'fonts'``.
        split: sub-directory name appended to the data root. Only the
            ``'birds'``, ``'sketchy'`` and ``'fonts'`` datasets use it;
            every other dataset ignores it.

    Returns:
        A torchvision dataset (``MNIST`` / ``ImageFolder``), or for
        ``'quickdraw'`` a ``TensorDataset`` pairing each image with itself.

    Raises:
        ValueError: if *dataset_name* is not one of the names above.
    """
    if dataset_name == 'mnist':
        return dset.MNIST(
            root='data/mnist',
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
            ]),
        )
    if dataset_name == 'coco':
        return _image_folder('data/coco', size=64)
    if dataset_name == 'quickdraw':
        # Imported locally: the module header does not import these names.
        import numpy as np
        from torch.utils.data import TensorDataset

        X = np.load('data/quickdraw/teapot.npy')
        X = X.reshape((X.shape[0], 28, 28))
        X = (X / 255.).astype(np.float32)
        X = torch.from_numpy(X)
        # The same tensor is used as both input and target.
        return TensorDataset(X, X)
    if dataset_name == 'shoes':
        return _image_folder('data/shoes/ut-zap50k-images/Shoes', size=64)
    if dataset_name == 'footwear':
        return _image_folder('data/shoes/ut-zap50k-images', size=64)
    if dataset_name == 'celeba':
        return _image_folder('data/celeba', size=32)
    if dataset_name == 'birds':
        return _image_folder('data/birds/' + split, size=32)
    if dataset_name == 'sketchy':
        return _image_folder('data/sketchy/' + split, size=64, extra=(Gray(),))
    if dataset_name == 'fonts':
        # No resize/crop here: images are only converted, inverted and
        # reduced to their first channel.
        return _image_folder('data/fonts/' + split, extra=(Invert(), Gray()))
    raise ValueError('Error : unknown dataset {!r}'.format(dataset_name))