Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,816 Bytes
9b33fca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
"""Class-balanced Grouping and Sampling for 3D Object Detection.
Implementation of `Class-balanced Grouping and Sampling for Point Cloud 3D
Object Detection <https://arxiv.org/abs/1908.09492>`_.
"""
from __future__ import annotations
import numpy as np
from torch.utils.data import Dataset
from vis4d.common.distributed import broadcast, rank_zero_only
from vis4d.common.logging import rank_zero_info
from vis4d.common.time import Timer
from .datasets.util import print_class_histogram
from .reference import MultiViewDataset
from .typing import DictDataOrList
# TODO: Support sensor selection.
class CBGSDataset(Dataset[DictDataOrList]):
    """Balance the number of scenes under different classes.

    Wraps a dataset and exposes a resampled view of it. For every category
    the samples containing that category are collected, and each non-empty
    category is then re-drawn (with replacement) so that it contributes
    roughly ``1 / num_classes`` of the total number of (sample, category)
    pairs. Indexing this dataset maps the resampled position back to an
    index of the wrapped dataset.
    """

    def __init__(
        self,
        dataset: Dataset[DictDataOrList],
        class_map: dict[str, int],
        ignore: int = -1,
    ) -> None:
        """Creates an instance of the class.

        Args:
            dataset (Dataset[DictDataOrList]): Dataset to wrap. It, or the
                object stored in its ``dataset`` attribute if it has one,
                must provide ``get_cat_ids(idx)``.
            class_map (dict[str, int]): Mapping from category name to
                category id.
            ignore (int): Category id that is skipped when collecting the
                categories of a sample. Defaults to -1.
        """
        super().__init__()
        self.dataset = dataset
        # Recorded for external use; not read anywhere inside this class.
        self.has_reference = isinstance(dataset, MultiViewDataset)
        # Keep the categories ordered by id so that every per-class
        # container built from this mapping iterates in the same order.
        self.cat2id = dict(sorted(class_map.items(), key=lambda x: x[1]))
        self.ignore = ignore

        rank_zero_info("Wrapping dataset with CBGS...")
        # NOTE(review): `_get_sample_indices` is wrapped in `rank_zero_only`,
        # so presumably only rank zero computes the indices and `broadcast`
        # shares them with the other ranks - confirm against vis4d.
        sample_indices = self._get_sample_indices()
        self.sample_indices = broadcast(sample_indices)

    def _show_histogram(
        self,
        sample_indices: list[int],
        sample_frequencies: list[dict[str, int]],
    ) -> None:
        """Show class histogram.

        Args:
            sample_indices (list[int]): Resampled indices; an index that
                appears several times is counted several times.
            sample_frequencies (list[dict[str, int]]): Per-sample mapping
                from category name to number of occurrences, indexed by
                sample index.
        """
        frequencies = {cat: 0 for cat in self.cat2id}
        for idx in sample_indices:
            for box3d_class, count in sample_frequencies[idx].items():
                frequencies[box3d_class] += count
        print_class_histogram(frequencies)

    def _get_class_sample_indices(
        self,
    ) -> tuple[dict[int, list[int]], list[dict[str, int]]]:
        """Get sample indices.

        Returns:
            tuple[dict[int, list[int]], list[dict[str, int]]]: A mapping
                from category id to the indices of all samples containing
                that category (each sample listed at most once per
                category), and a per-sample list of category name to
                occurrence count.
        """
        class_sample_indices: dict[int, list[int]] = {
            cat_id: [] for cat_id in self.cat2id.values()
        }
        sample_frequencies = []
        inv_class_map = {v: k for k, v in self.cat2id.items()}

        # Handle the case that dataset is already wrapped.
        # NOTE(review): indices are computed on the unwrapped dataset but
        # `__getitem__` looks them up on `self.dataset`; this assumes the
        # wrapper preserves the indexing of the inner dataset - confirm.
        if hasattr(self.dataset, "dataset"):
            dataset = self.dataset.dataset
        else:
            dataset = self.dataset

        # The requirement does not depend on the sample, so check it once
        # up front instead of on every iteration.
        assert hasattr(
            dataset, "get_cat_ids"
        ), "The dataset must have a method `get_cat_ids` to get cat ids."

        for idx in range(len(dataset)):
            frequencies = {cat: 0 for cat in self.cat2id}
            seen_cat_ids = set()
            for cat_id in dataset.get_cat_ids(idx):
                if cat_id == self.ignore:
                    continue
                frequencies[inv_class_map[cat_id]] += 1
                # A sample is listed once per category, no matter how many
                # instances of that category it contains.
                if cat_id not in seen_cat_ids:
                    seen_cat_ids.add(cat_id)
                    class_sample_indices[cat_id].append(idx)
            sample_frequencies.append(frequencies)

        return class_sample_indices, sample_frequencies

    @rank_zero_only
    def _get_sample_indices(self) -> list[int]:
        """Load sample indices.

        Returns:
            list[int]: List of indices after class sampling. Empty if no
                sample contains any non-ignored category.
        """
        t = Timer()
        (
            class_sample_indices,
            sample_frequencies,
        ) = self._get_class_sample_indices()

        # Total number of (sample, category) pairs; a sample containing
        # several categories is counted once for each of them.
        duplicated_samples = sum(
            len(cls_inds) for cls_inds in class_sample_indices.values()
        )

        sample_indices: list[int] = []
        # Without any labelled sample there is no distribution to balance;
        # skipping the loop avoids dividing by zero below.
        if duplicated_samples > 0:
            # Target share of every class after balancing.
            frac = 1.0 / len(self.cat2id)
            for cls_inds in class_sample_indices.values():
                if not cls_inds:
                    # Nothing to draw from; also avoids handing an empty
                    # population to np.random.choice.
                    continue
                share = len(cls_inds) / duplicated_samples
                ratio = frac / share
                # np.random.choice draws with replacement by default, so
                # rare classes are oversampled by repeating their samples.
                sample_indices += np.random.choice(
                    cls_inds, int(len(cls_inds) * ratio)
                ).tolist()

        self._show_histogram(sample_indices, sample_frequencies)

        rank_zero_info(
            f"Generating {len(sample_indices)} CBGS samples takes "
            + f"{t.time():.2f} seconds."
        )

        return sample_indices

    def __len__(self) -> int:
        """Return the length of sample indices.

        Returns:
            int: Length of sample indices.
        """
        return len(self.sample_indices)

    def __getitem__(self, idx: int) -> DictDataOrList:
        """Get original dataset idx according to the given index.

        Args:
            idx (int): The index of self.sample_indices.

        Returns:
            DictDataOrList: Data of the corresponding index.
        """
        ori_index = self.sample_indices[idx]
        return self.dataset[ori_index]
|