File size: 5,210 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""DataPipe wraps datasets to share the prepossessing pipeline."""

from __future__ import annotations

import random
from collections.abc import Callable, Iterable

from torch.utils.data import ConcatDataset, Dataset

from .reference import MultiViewDataset
from .transforms.base import TFunctor
from .typing import DictData, DictDataOrList


class DataPipe(ConcatDataset[DictDataOrList]):
    """DataPipe class.

    This class wraps one or multiple instances of a PyTorch Dataset so that the
    preprocessing steps can be shared across those datasets. Composes dataset
    and the preprocessing pipeline.
    """

    def __init__(
        self,
        datasets: Dataset[DictDataOrList] | Iterable[Dataset[DictDataOrList]],
        preprocess_fn: Callable[
            [list[DictData]], list[DictData]
        ] = lambda x: x,
    ):
        """Creates an instance of the class.

        Args:
            datasets (Dataset | Iterable[Dataset]): Dataset(s) to be wrapped by
                this data pipeline.
            preprocess_fn (Callable[[list[DictData]], list[DictData]]):
                Preprocessing function of a single sample. It takes a list of
                samples and returns a list of samples. Defaults to identity
                function.

        Raises:
            ValueError: If some, but not all, of the wrapped datasets have
                reference views.
        """
        if isinstance(datasets, Dataset):
            datasets = [datasets]
        super().__init__(datasets)
        self.preprocess_fn = preprocess_fn

        # Iterate self.datasets (materialized as a list by
        # ConcatDataset.__init__) rather than the `datasets` argument: the
        # argument may be a one-shot iterable (e.g. a generator) already
        # consumed by super().__init__, which would silently make
        # has_reference False. Checking each dataset once also avoids the
        # duplicated any()/all() passes over _check_reference.
        ref_flags = [_check_reference(dataset) for dataset in self.datasets]
        self.has_reference = any(ref_flags)

        if self.has_reference and not all(ref_flags):
            raise ValueError(
                "All datasets must be MultiViewDataset / has reference if "
                + "one of them is."
            )

    def __getitem__(self, idx: int) -> DictDataOrList:
        """Wrap getitem to apply augmentations.

        Args:
            idx (int): Index of the sample in the concatenated dataset.

        Returns:
            DictDataOrList: The preprocessed sample, or the preprocessed list
                of samples when the underlying dataset yields multiple views.
        """
        samples = super().__getitem__(idx)
        if isinstance(samples, list):
            return self.preprocess_fn(samples)

        # Single sample: wrap in a list for the pipeline, then unwrap.
        return self.preprocess_fn([samples])[0]


class MultiSampleDataPipe(DataPipe):
    """MultiSampleDataPipe class.

    This class wraps DataPipe to support augmentations that require multiple
    images (e.g., Mosaic and Mixup) by sampling additional indices for each
    image. NUM_SAMPLES needs to be defined as a class attribute for transforms
    that require multi-sample augmentation.
    """

    def __init__(
        self,
        datasets: Dataset[DictDataOrList] | Iterable[Dataset[DictDataOrList]],
        preprocess_fn: list[list[TFunctor]],
    ):
        """Creates an instance of the class.

        Args:
            datasets (Dataset | Iterable[Dataset]): Dataset(s) to be wrapped by
                this data pipeline.
            preprocess_fn (list[list[TFunctor]]): Preprocessing functions of a
                single sample. Different than DataPipe, this is a list of lists
                of transformation functions. The inner list is for transforms
                that needs to share the same sampled indices (e.g.,
                GenMosaicParameters and MosaicImages), and the outer list is
                for different transforms.
        """
        # NOTE: DataPipe's preprocess_fn is left at its identity default; all
        # preprocessing here is driven by self.preprocess_fns in __getitem__.
        super().__init__(datasets)
        self.preprocess_fns = preprocess_fn

    def _sample_indices(self, idx: int, num_samples: int) -> list[int]:
        """Sample additional indices for multi-sample augmentation."""
        # The requested index always comes first; the remaining
        # num_samples - 1 indices are drawn uniformly at random (duplicates
        # allowed, idx itself may reoccur) over the whole concatenated
        # dataset.
        indices = [idx]
        for _ in range(1, num_samples):
            indices.append(random.randint(0, len(self) - 1))
        return indices

    def __getitem__(self, idx: int) -> DictDataOrList:
        """Wrap getitem to apply augmentations."""
        # Bypass DataPipe.__getitem__ (which would apply its preprocess_fn)
        # and fetch the raw sample(s) straight from ConcatDataset.
        samples = super(DataPipe, self).__getitem__(idx)
        # Normalize to a list of views; remember the original shape so the
        # return value matches what the wrapped dataset produced.
        if not isinstance(samples, list):
            samples = [samples]
            single_view = True
        else:
            single_view = False

        for preprocess_fn in self.preprocess_fns:
            # Transform groups that declare NUM_SAMPLES need extra samples;
            # the first transform in the group determines the count so that
            # all transforms in the group share the same sampled indices.
            if hasattr(preprocess_fn[0], "NUM_SAMPLES"):
                num_samples = preprocess_fn[0].NUM_SAMPLES
                aug_inds = self._sample_indices(idx, num_samples)
                # Fetch only the additionally sampled items (aug_inds[0] is
                # idx, whose data is already in `samples`).
                add_samples = [
                    super(DataPipe, self).__getitem__(ind)
                    for ind in aug_inds[1:]
                ]
                # Interleave per view: each primary sample is immediately
                # followed by its (num_samples - 1) extra samples, so each
                # group of NUM_SAMPLES consecutive entries belongs together.
                prep_samples = []
                for i, samp in enumerate(samples):
                    prep_samples.append(samp)
                    prep_samples += [
                        s[i] if isinstance(s, list) else s for s in add_samples
                    ]
            else:
                num_samples = 1
                prep_samples = samples
            for prep_fn in preprocess_fn:
                prep_samples = prep_fn.apply_to_data(prep_samples)  # type: ignore # pylint: disable=line-too-long
            # De-interleave: keep only the first (primary) entry of each
            # NUM_SAMPLES-sized group as input to the next transform group.
            samples = prep_samples[::num_samples]
        return samples[0] if single_view else samples


def _check_reference(dataset: Dataset[DictDataOrList]) -> bool:
    """Check if the datasets have reference."""
    has_reference = (
        dataset.has_reference if hasattr(dataset, "has_reference") else False
    )
    return has_reference or isinstance(dataset, MultiViewDataset)