File size: 4,816 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""Class-balanced Grouping and Sampling for 3D Object Detection.

Implementation of `Class-balanced Grouping and Sampling for Point Cloud 3D
Object Detection <https://arxiv.org/abs/1908.09492>`_.
"""

from __future__ import annotations

import numpy as np
from torch.utils.data import Dataset

from vis4d.common.distributed import broadcast, rank_zero_only
from vis4d.common.logging import rank_zero_info
from vis4d.common.time import Timer

from .datasets.util import print_class_histogram
from .reference import MultiViewDataset
from .typing import DictDataOrList


# TODO: Support sensor selection.
class CBGSDataset(Dataset[DictDataOrList]):
    """Balance the number of scenes under different classes.

    Wraps another dataset and exposes a re-sampled list of its indices in
    which every class of ``class_map`` is targeted to contribute an equal
    share of the samples. Samples of a class are drawn with replacement from
    the indices of the wrapped dataset that contain this class, so an index
    may appear several times.
    """

    def __init__(
        self,
        dataset: Dataset[DictDataOrList],
        class_map: dict[str, int],
        ignore: int = -1,
    ) -> None:
        """Creates an instance of the class.

        Args:
            dataset (Dataset[DictDataOrList]): Dataset to wrap. It (or, if it
                has a ``dataset`` attribute, that inner dataset) must provide
                a ``get_cat_ids(idx)`` method.
            class_map (dict[str, int]): Mapping from category name to
                category id. It is stored sorted by category id.
            ignore (int): Category id that is skipped when collecting the
                per-class indices. Defaults to -1.
        """
        super().__init__()
        self.dataset = dataset
        self.has_reference = isinstance(dataset, MultiViewDataset)
        self.cat2id = dict(sorted(class_map.items(), key=lambda x: x[1]))
        self.ignore = ignore

        rank_zero_info("Wrapping dataset with CBGS...")
        # NOTE(review): `_get_sample_indices` is wrapped by `rank_zero_only`,
        # so presumably only rank zero computes the indices and `broadcast`
        # hands the same list to every other rank - confirm against
        # vis4d.common.distributed.
        sample_indices = self._get_sample_indices()
        self.sample_indices = broadcast(sample_indices)

    def _show_histogram(
        self,
        sample_indices: list[int],
        sample_frequencies: list[dict[str, int]],
    ) -> None:
        """Show class histogram.

        Args:
            sample_indices (list[int]): Indices into the wrapped dataset
                after class-balanced sampling (may contain repetitions).
            sample_frequencies (list[dict[str, int]]): For every index of the
                wrapped dataset, the number of annotations per category name.
        """
        frequencies = {cat: 0 for cat in self.cat2id}

        # Repeated indices are counted once per occurrence on purpose: the
        # histogram describes the re-sampled dataset, not the original one.
        for idx in sample_indices:
            for box3d_class, count in sample_frequencies[idx].items():
                frequencies[box3d_class] += count

        print_class_histogram(frequencies)

    def _get_class_sample_indices(
        self,
    ) -> tuple[dict[int, list[int]], list[dict[str, int]]]:
        """Get sample indices.

        Returns:
            tuple[dict[int, list[int]], list[dict[str, int]]]: A mapping from
                category id to the dataset indices that contain at least one
                annotation of that category, and, for every dataset index,
                the number of annotations per category name.
        """
        class_sample_indices: dict[int, list[int]] = {
            cat_id: [] for cat_id in self.cat2id.values()
        }
        sample_frequencies: list[dict[str, int]] = []
        inv_class_map = {v: k for k, v in self.cat2id.items()}

        # Handle the case that dataset is already wrapped.
        if hasattr(self.dataset, "dataset"):
            dataset = self.dataset.dataset
        else:
            dataset = self.dataset

        # The capability check does not depend on the index, so it is done
        # once up front instead of once per sample.
        assert hasattr(
            dataset, "get_cat_ids"
        ), "The dataset must have a method `get_cat_ids` to get cat ids."

        for idx in range(len(dataset)):
            valid_cat_ids = [
                cat_id
                for cat_id in dataset.get_cat_ids(idx)
                if cat_id != self.ignore
            ]

            # Every annotation counts towards the per-sample frequencies ...
            frequencies = {cat: 0 for cat in self.cat2id}
            for cat_id in valid_cat_ids:
                frequencies[inv_class_map[cat_id]] += 1
            sample_frequencies.append(frequencies)

            # ... but a sample is listed at most once per category.
            for cat_id in dict.fromkeys(valid_cat_ids):
                class_sample_indices[cat_id].append(idx)

        return class_sample_indices, sample_frequencies

    @rank_zero_only
    def _get_sample_indices(self) -> list[int]:
        """Load sample indices.

        Each category is re-sampled (with replacement) so that it is targeted
        to make up ``1 / len(self.cat2id)`` of all per-class index entries.
        If no annotated sample exists at all, an empty list is returned
        instead of failing with a division by zero.

        Returns:
            list[int]: List of indices after class sampling.
        """
        t = Timer()
        (
            class_sample_indices,
            sample_frequencies,
        ) = self._get_class_sample_indices()

        # A sample containing several categories is counted once for each of
        # them, hence "duplicated".
        duplicated_samples = sum(
            len(cls_inds) for cls_inds in class_sample_indices.values()
        )

        sample_indices: list[int] = []

        if duplicated_samples > 0:
            frac = 1.0 / len(self.cat2id)
            for cls_inds in class_sample_indices.values():
                # Current share of this category among all entries.
                share = len(cls_inds) / duplicated_samples
                # Scale factor that moves the category to the target share;
                # categories without samples keep a neutral factor.
                ratio = frac / share if share > 0 else 1
                sample_indices += np.random.choice(
                    cls_inds, int(len(cls_inds) * ratio)
                ).tolist()
        else:
            # Empty dataset or only ignored categories: nothing to balance.
            rank_zero_info(
                "No annotated samples found, CBGS sample list is empty."
            )

        self._show_histogram(sample_indices, sample_frequencies)

        rank_zero_info(
            f"Generating {len(sample_indices)} CBGS samples takes "
            + f"{t.time():.2f} seconds."
        )

        return sample_indices

    def __len__(self) -> int:
        """Return the length of sample indices.

        Returns:
            int: Length of sample indices.
        """
        return len(self.sample_indices)

    def __getitem__(self, idx: int) -> DictDataOrList:
        """Get original dataset idx according to the given index.

        Args:
            idx (int): The index of self.sample_indices.

        Returns:
            DictDataOrList: Data of the corresponding index.
        """
        ori_index = self.sample_indices[idx]
        return self.dataset[ori_index]