import os, numpy, torch, csv, re, zipfile
from collections import OrderedDict
from torchvision.datasets.folder import default_loader
from torchvision import transforms
from scipy import ndimage
from urllib.request import urlopen

class BrodenDataset(torch.utils.data.Dataset):
    '''
    A multicategory segmentation data set.
    Returns three streams:
    (1) The image (3, h, w).
    (2) The multicategory segmentation (labelcount, h, w).
    (3) A bincount of pixels in the segmentation (labelcount).
    Net dissect also assumes that the dataset object has three properties
    with human-readable labels:
    ds.labels = ['red', 'black', 'car', 'tree', 'grid', ...]
    ds.categories = ['color', 'part', 'object', 'texture']
    ds.label_category = [0, 0, 2, 2, 3, ...] # The category for each label
    '''
    def __init__(self, directory='dataset/broden', resolution=384,
                 split='train', categories=None,
                 transform=None, transform_segment=None,
                 download=False, size=None, include_bincount=True,
                 broden_version=1, max_segment_depth=6):
        assert resolution in [224, 227, 384]
        if download:
            ensure_broden_downloaded(directory, resolution, broden_version)
        self.directory = directory
        self.resolution = resolution
        self.resdir = os.path.join(directory, 'broden%d_%d' %
                                   (broden_version, resolution))
        self.loader = default_loader
        self.transform = transform
        self.transform_segment = transform_segment
        self.include_bincount = include_bincount
        # The maximum number of multilabel layers that can coexist in an image.
        self.max_segment_depth = max_segment_depth
        with open(os.path.join(self.resdir, 'category.csv'),
                  encoding='utf-8') as f:
            self.category_info = OrderedDict()
            for row in csv.DictReader(f):
                self.category_info[row['name']] = row
        if categories is not None:
            # Filter out unused categories.
            categories = set([c for c in categories if c in self.category_info])
            for cat in list(self.category_info.keys()):
                if cat not in categories:
                    del self.category_info[cat]
        categories = list(self.category_info.keys())
        self.categories = categories
        # Filter out unneeded images.
        with open(os.path.join(self.resdir, 'index.csv'),
                  encoding='utf-8') as f:
            all_images = [decode_index_dict(r) for r in csv.DictReader(f)]
        self.image = [row for row in all_images
                      if index_has_any_data(row, categories)
                      and row['split'] == split]
        if size is not None:
            self.image = self.image[:size]
        with open(os.path.join(self.resdir, 'label.csv'),
                  encoding='utf-8') as f:
            self.label_info = build_dense_label_array([
                decode_label_dict(r) for r in csv.DictReader(f)])
        self.labels = [l['name'] for l in self.label_info]
        # Build dense remapping arrays for labels, so that you can
        # get dense ranges of labels for each category.
        self.category_map = {}
        self.category_unmap = {}
        self.category_label = {}
        for cat in self.categories:
            with open(os.path.join(self.resdir, 'c_%s.csv' % cat),
                      encoding='utf-8') as f:
                c_data = [decode_label_dict(r) for r in csv.DictReader(f)]
            self.category_unmap[cat], self.category_map[cat] = (
                build_numpy_category_map(c_data))
            self.category_label[cat] = build_dense_label_array(
                c_data, key='code')
        self.num_labels = len(self.labels)
        # The primary category for each label is the category in which
        # the label appears with the maximum coverage.
        self.label_category = numpy.zeros(self.num_labels, dtype=int)
        for i in range(self.num_labels):
            maxcoverage, self.label_category[i] = max(
                (self.category_label[cat][self.category_map[cat][i]]['coverage']
                 if i < len(self.category_map[cat])
                 and self.category_map[cat][i] else 0, ic)
                for ic, cat in enumerate(categories))

    def __len__(self):
        return len(self.image)

    def __getitem__(self, idx):
        record = self.image[idx]
        # example record: {
        #   'image': 'opensurfaces/25605.jpg', 'split': 'train',
        #   'ih': 384, 'iw': 384, 'sh': 192, 'sw': 192,
        #   'color': ['opensurfaces/25605_color.png'],
        #   'object': [], 'part': [],
        #   'material': ['opensurfaces/25605_material.png'],
        #   'scene': [], 'texture': []}
        image = self.loader(os.path.join(self.resdir, 'images',
                                         record['image']))
        segment = numpy.zeros(shape=(self.max_segment_depth,
                                     record['sh'], record['sw']), dtype=int)
        if self.include_bincount:
            bincount = numpy.zeros(shape=(self.num_labels,), dtype=int)
        depth = 0
        for cat in self.categories:
            for layer in record[cat]:
                if isinstance(layer, int):
                    # A whole-image label: every pixel gets this label number.
                    segment[depth, :, :] = layer
                    if self.include_bincount:
                        bincount[layer] += segment.shape[1] * segment.shape[2]
                else:
                    # A png-encoded segmentation: label numbers are packed in
                    # the first two channels (low byte, high byte).
                    png = numpy.asarray(self.loader(os.path.join(
                        self.resdir, 'images', layer)))
                    segment[depth, :, :] = png[:, :, 0] + png[:, :, 1] * 256
                    if self.include_bincount:
                        bincount += numpy.bincount(
                            segment[depth, :, :].flatten(),
                            minlength=self.num_labels)
                depth += 1
        if self.transform:
            image = self.transform(image)
        if self.transform_segment:
            segment = self.transform_segment(segment)
        if self.include_bincount:
            bincount[0] = 0
            return (image, segment, bincount)
        else:
            return (image, segment)

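# Example usage (a minimal sketch, assuming the Broden data has been
# downloaded under 'dataset/broden'; paths and sizes are illustrative):
#
#   ds = BrodenDataset('dataset/broden', resolution=384,
#                      transform=transforms.ToTensor(),
#                      transform_segment=ScaleSegmentation(384, 384),
#                      download=True)
#   image, segment, bincount = ds[0]
#   # image: (3, 384, 384) float tensor; segment: (max_segment_depth, 384, 384)
#   # int array; bincount: per-label pixel counts of length ds.num_labels.
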
def build_dense_label_array(label_data, key='number', allow_none=False):
    '''
    Input: set of rows with 'number' fields (or another field name key).
    Output: array such that a[number] = the row with the given number.
    '''
    result = [None] * (max([d[key] for d in label_data]) + 1)
    for d in label_data:
        result[d[key]] = d
    # Fill in any gaps with empty placeholder rows of the same shape.
    if not allow_none:
        example = label_data[0]
        def make_empty(k):
            return dict((c, k if c == key else type(v)())
                        for c, v in example.items())
        for i, d in enumerate(result):
            if d is None:
                result[i] = make_empty(i)
    return result

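# Example (hypothetical rows): label numbers 0 and 2 are present, so the
# gap at index 1 is filled with an empty placeholder row:
#
#   rows = [{'number': 0, 'name': '-'}, {'number': 2, 'name': 'car'}]
#   dense = build_dense_label_array(rows)
#   dense[2]['name']    # -> 'car'
#   dense[1]            # -> {'number': 1, 'name': ''}
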
def build_numpy_category_map(map_data, key1='code', key2='number'):
    '''
    Input: set of rows with 'code' and 'number' fields (or other field
    names given by key1 and key2).
    Output: a pair of numpy arrays (a, b) such that a[code] = number
    and b[number] = code.
    '''
    results = list(numpy.zeros((max([d[key] for d in map_data]) + 1),
                               dtype=numpy.int16) for key in (key1, key2))
    for d in map_data:
        results[0][d[key1]] = d[key2]
        results[1][d[key2]] = d[key1]
    return results

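# Example (hypothetical rows pairing per-category codes with global label
# numbers):
#
#   rows = [{'code': 1, 'number': 10}, {'code': 2, 'number': 42}]
#   code_to_number, number_to_code = build_numpy_category_map(rows)
#   code_to_number[2]    # -> 42
#   number_to_code[42]   # -> 2
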
def index_has_any_data(row, categories):
    for c in categories:
        for data in row[c]:
            if data:
                return True
    return False

def decode_label_dict(row):
    result = {}
    for key, val in row.items():
        if key == 'category':
            result[key] = dict((c, int(n))
                for c, n in [re.match(r'^([^(]*)\(([^)]*)\)$', f).groups()
                             for f in val.split(';')])
        elif key == 'name':
            result[key] = val
        elif key == 'syns':
            result[key] = val.split(';')
        elif re.match(r'^\d+$', val):
            result[key] = int(val)
        elif re.match(r'^\d+\.\d*$', val):
            result[key] = float(val)
        else:
            result[key] = val
    return result

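# Example (hypothetical csv values): the 'category' field packs
# 'name(number)' pairs separated by ';', and numeric strings elsewhere
# are converted to ints or floats:
#
#   decode_label_dict({'number': '5', 'name': 'car',
#                      'category': 'object(187);part(25)'})
#   # -> {'number': 5, 'name': 'car',
#   #     'category': {'object': 187, 'part': 25}}
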
def decode_index_dict(row):
    result = {}
    for key, val in row.items():
        if key in ['image', 'split']:
            result[key] = val
        elif key in ['sw', 'sh', 'iw', 'ih']:
            result[key] = int(val)
        else:
            item = [s for s in val.split(';') if s]
            for i, v in enumerate(item):
                if re.match(r'^\d+$', v):
                    item[i] = int(v)
            result[key] = item
    return result

class ScaleSegmentation:
    '''
    Utility for scaling segmentations, using nearest-neighbor zooming.
    '''
    def __init__(self, target_height, target_width):
        self.target_height = target_height
        self.target_width = target_width

    def __call__(self, seg):
        ratio = (1, self.target_height / float(seg.shape[1]),
                 self.target_width / float(seg.shape[2]))
        return ndimage.zoom(seg, ratio, order=0)

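# Example (a minimal sketch): rescale a (depth, 192, 192) label array to
# (depth, 224, 224) with nearest-neighbor sampling, so no new label
# values are invented by interpolation:
#
#   scale = ScaleSegmentation(224, 224)
#   seg = numpy.zeros((6, 192, 192), dtype=int)
#   scale(seg).shape    # -> (6, 224, 224)
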
def scatter_batch(seg, num_labels, omit_zero=True, dtype=torch.uint8):
    '''
    Utility for scattering segmentations into a one-hot representation.
    '''
    result = torch.zeros(*((seg.shape[0], num_labels,) + seg.shape[2:]),
                         dtype=dtype, device=seg.device)
    result.scatter_(1, seg, 1)
    if omit_zero:
        result[:, 0] = 0
    return result

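# Example (a minimal sketch): a (batch, depth, h, w) tensor of integer
# label indices becomes a (batch, num_labels, h, w) one-hot mask, with
# channel 0 zeroed out when omit_zero is True:
#
#   seg = torch.randint(0, 10, (2, 6, 4, 4), dtype=torch.long)
#   mask = scatter_batch(seg, num_labels=10)
#   mask.shape    # -> torch.Size([2, 10, 4, 4])
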
def ensure_broden_downloaded(directory, resolution, broden_version=1):
    assert resolution in [224, 227, 384]
    baseurl = 'http://netdissect.csail.mit.edu/data/'
    dirname = 'broden%d_%d' % (broden_version, resolution)
    if os.path.isfile(os.path.join(directory, dirname, 'index.csv')):
        return  # Already downloaded
    zipfilename = 'broden1_%d.zip' % resolution
    download_dir = os.path.join(directory, 'download')
    os.makedirs(download_dir, exist_ok=True)
    full_zipfilename = os.path.join(download_dir, zipfilename)
    if not os.path.exists(full_zipfilename):
        url = baseurl + zipfilename
        print('Downloading %s' % url)
        data = urlopen(url)
        with open(full_zipfilename, 'wb') as f:
            f.write(data.read())
    print('Unzipping %s' % zipfilename)
    with zipfile.ZipFile(full_zipfilename, 'r') as zip_ref:
        zip_ref.extractall(directory)
    assert os.path.isfile(os.path.join(directory, dirname, 'index.csv'))

def test_broden_dataset():
    '''
    Testing code.
    '''
    bds = BrodenDataset('dataset/broden', resolution=384,
                        transform=transforms.Compose([
                            transforms.Resize(224),
                            transforms.ToTensor()]),
                        transform_segment=transforms.Compose([
                            ScaleSegmentation(224, 224)
                        ]),
                        include_bincount=True)
    loader = torch.utils.data.DataLoader(bds, batch_size=100, num_workers=24)
    for i in range(1, 20):
        print(bds.label_info[i]['name'],
              bds.categories[bds.label_category[i]])
    for i, (im, seg, bc) in enumerate(loader):
        print(i, im.shape, seg.shape, seg.max(), bc.shape)

if __name__ == '__main__':
    test_broden_dataset()