Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import os | |
| import pickle | |
| import sys | |
| import unittest | |
| from functools import partial | |
| import torch | |
| from iopath.common.file_io import LazyPath | |
| from detectron2 import model_zoo | |
| from detectron2.config import get_cfg, instantiate | |
| from detectron2.data import ( | |
| DatasetCatalog, | |
| DatasetFromList, | |
| MapDataset, | |
| ToIterableDataset, | |
| build_batch_data_loader, | |
| build_detection_test_loader, | |
| build_detection_train_loader, | |
| ) | |
| from detectron2.data.common import ( | |
| AspectRatioGroupedDataset, | |
| set_default_dataset_from_list_serialize_method, | |
| ) | |
| from detectron2.data.samplers import InferenceSampler, TrainingSampler | |
def _a_slow_func(x):
    """Return the string ``"path/<x>"`` built from ``x``.

    Used by the tests as the deferred callable wrapped in a ``LazyPath``.
    """
    template = "path/{}"
    return template.format(x)
class TestDatasetFromList(unittest.TestCase):
    """Tests for ``DatasetFromList`` element access and serialization options."""

    # Failing for py3.6, likely due to pickle
    @unittest.skipIf(sys.version_info < (3, 7), "Not supported in Python 3.6")
    def test_using_lazy_path(self):
        """``LazyPath`` values must come back as ``LazyPath`` and resolve correctly."""
        records = [{"file_name": LazyPath(partial(_a_slow_func, i))} for i in range(10)]
        dataset = DatasetFromList(records)

        for i in range(10):
            path = dataset[i]["file_name"]
            self.assertIsInstance(path, LazyPath)
            self.assertEqual(os.fspath(path), _a_slow_func(i))

    def test_alternative_serialize_method(self):
        """A callable passed as ``serialize`` is applied to the stored elements."""
        dataset = DatasetFromList([1, 2, 3], serialize=torch.tensor)
        self.assertEqual(dataset[2], torch.tensor(3))

    def test_change_default_serialize_method(self):
        """The default serialize method can be overridden inside a ``with`` block only."""
        dataset = [1, 2, 3]

        with set_default_dataset_from_list_serialize_method(torch.tensor):
            dataset_1 = DatasetFromList(dataset, serialize=True)
            self.assertEqual(dataset_1[2], torch.tensor(3))

        # Outside the context manager the test expects the plain element back.
        dataset_2 = DatasetFromList(dataset, serialize=True)
        self.assertEqual(dataset_2[2], 3)
class TestMapDataset(unittest.TestCase):
    """Tests for ``MapDataset`` over map-style and iterable-style datasets."""

    @staticmethod
    def map_func(x):
        """Double ``x``, except return ``None`` when ``x`` is ``2``.

        The ``None`` result lets the tests below check what ``MapDataset``
        does with a sample the mapper rejects.
        """
        if x == 2:
            return None
        return x * 2

    def test_map_style(self):
        """Index access applies the mapper; the rejected index yields another sample."""
        ds = DatasetFromList([1, 2, 3])
        ds = MapDataset(ds, TestMapDataset.map_func)
        self.assertEqual(ds[0], 2)
        self.assertEqual(ds[2], 6)
        # Index 1 maps to None, so the test accepts the result of either
        # remaining element.
        self.assertIn(ds[1], [2, 6])

    def test_iter_style(self):
        """Wrapping an iterable dataset keeps it iterable and omits ``None`` results."""

        class DS(torch.utils.data.IterableDataset):
            def __iter__(self):
                yield from [1, 2, 3]

        ds = MapDataset(DS(), TestMapDataset.map_func)
        self.assertIsInstance(ds, torch.utils.data.IterableDataset)

        data = list(iter(ds))
        self.assertEqual(data, [2, 6])

    def test_pickleability(self):
        """A ``MapDataset`` built with a lambda mapper survives a pickle round trip."""
        ds = DatasetFromList([1, 2, 3])
        ds = MapDataset(ds, lambda x: x * 2)
        ds = pickle.loads(pickle.dumps(ds))
        self.assertEqual(ds[0], 2)
class TestAspectRatioGrouping(unittest.TestCase):
    """Tests for ``AspectRatioGroupedDataset``."""

    def test_reiter_leak(self):
        """Abandoning an iteration early must not leave a full bucket behind."""
        shapes = [(1, 0), (0, 1), (1, 0), (0, 1)]
        records = [{"width": w, "height": h} for w, h in shapes]
        batchsize = 2
        grouped = AspectRatioGroupedDataset(records, batchsize)

        for _ in range(5):
            for position, _batch in enumerate(grouped):
                # manually break, so the iterator does not stop by itself
                if position == 1:
                    break
            # check that bucket sizes are valid
            for bucket in grouped._buckets:
                self.assertLess(len(bucket), batchsize)
class _MyData(torch.utils.data.IterableDataset):
    """Iterable dataset that yields the constant ``1`` without ever stopping."""

    def __iter__(self):
        constant = 1
        while True:
            yield constant
class TestDataLoader(unittest.TestCase):
    """Tests for the train, test and batch data loader builders."""

    def _get_kwargs(self):
        """Return instantiated keyword arguments for ``build_detection_train_loader``.

        The arguments are read from the ``common/data/coco.py`` model-zoo
        config, with the dataset name switched to ``coco_2017_val_100``.
        """
        cfg = model_zoo.get_config("common/data/coco.py").dataloader.train
        cfg.dataset.names = "coco_2017_val_100"
        # Drop the target entry so that only keyword arguments remain.
        cfg.pop("_target_")
        return {k: instantiate(v) for k, v in cfg.items()}

    def test_build_dataloader_train(self):
        """The train loader built from config kwargs yields at least one batch."""
        kwargs = self._get_kwargs()
        dl = build_detection_train_loader(**kwargs)
        next(iter(dl))

    def test_build_iterable_dataloader_train(self):
        """The train loader accepts a dataset already converted to iterable style."""
        kwargs = self._get_kwargs()
        ds = DatasetFromList(kwargs.pop("dataset"))
        ds = ToIterableDataset(ds, TrainingSampler(len(ds)))
        dl = build_detection_train_loader(dataset=ds, **kwargs)
        next(iter(dl))

    def test_build_iterable_dataloader_from_cfg(self):
        """Train and test loaders can be built from a cfg naming an iterable dataset."""
        cfg = get_cfg()
        cfg.DATASETS.TRAIN = ["iter_data"]
        DatasetCatalog.register("iter_data", _MyData)
        # DatasetCatalog is shared module-level state: remove the entry when
        # the test ends so the name does not leak into other tests or into a
        # repeated run of this one.
        self.addCleanup(DatasetCatalog.remove, "iter_data")

        dl = build_detection_train_loader(cfg, mapper=lambda x: x, aspect_ratio_grouping=False)
        next(iter(dl))

        dl = build_detection_test_loader(cfg, "iter_data", mapper=lambda x: x)
        next(iter(dl))

    def _check_is_range(self, data_loader, N):
        """Assert that ``data_loader`` produces exactly the values of ``range(N)``.

        Order is ignored; both the element count and the set of values are
        checked, so duplicates and omissions are caught.
        """
        data = [x for batch in data_loader for x in batch]  # flatten the batches
        self.assertEqual(len(data), N)
        self.assertEqual(set(data), set(range(N)))

    def test_build_batch_dataloader_inference(self):
        """``build_batch_data_loader`` can be used for inference."""
        N = 96
        ds = DatasetFromList(list(range(N)))
        sampler = InferenceSampler(len(ds))
        dl = build_batch_data_loader(ds, sampler, 8, num_workers=3)
        self._check_is_range(dl, N)

    def test_build_batch_dataloader_inference_incomplete_batch(self):
        """``build_batch_data_loader`` copes with sizes not divisible by batch size or workers."""

        def _test(N, batch_size, num_workers):
            ds = DatasetFromList(list(range(N)))
            sampler = InferenceSampler(len(ds))

            # Default settings: only complete batches are expected.
            dl = build_batch_data_loader(ds, sampler, batch_size, num_workers=num_workers)
            data = list(iter(dl))
            self.assertEqual(len(data), len(dl))  # floor(N / batch_size)
            self._check_is_range(dl, N // batch_size * batch_size)

            # With drop_last=False every element must be produced.
            dl = build_batch_data_loader(
                ds, sampler, batch_size, num_workers=num_workers, drop_last=False
            )
            data = list(iter(dl))
            self.assertEqual(len(data), len(dl))  # ceil(N / batch_size)
            self._check_is_range(dl, N)

        _test(48, batch_size=8, num_workers=3)
        _test(47, batch_size=8, num_workers=3)
        _test(46, batch_size=8, num_workers=3)
        _test(40, batch_size=8, num_workers=3)
        _test(39, batch_size=8, num_workers=3)

    def test_build_dataloader_inference(self):
        """The test loader yields every element, with workers and with batching."""
        N = 50
        ds = DatasetFromList(list(range(N)))
        sampler = InferenceSampler(len(ds))

        # test that parallel loader works correctly
        dl = build_detection_test_loader(
            dataset=ds, sampler=sampler, mapper=lambda x: x, num_workers=3
        )
        self._check_is_range(dl, N)

        # test that batch_size works correctly
        dl = build_detection_test_loader(
            dataset=ds, sampler=sampler, mapper=lambda x: x, batch_size=4, num_workers=0
        )
        self._check_is_range(dl, N)

    def test_build_iterable_dataloader_inference(self):
        """``build_detection_test_loader`` supports an iterable dataset."""
        N = 50
        ds = DatasetFromList(list(range(N)))
        ds = ToIterableDataset(ds, InferenceSampler(len(ds)))
        dl = build_detection_test_loader(dataset=ds, mapper=lambda x: x, num_workers=3)
        self._check_is_range(dl, N)