# Source: 3D-MOOD / vis4d/data/cbgs.py (author: RoyYang0714)
# Commit 9b33fca: "feat: Try to build everything locally."
"""Class-balanced Grouping and Sampling for 3D Object Detection.
Implementation of `Class-balanced Grouping and Sampling for Point Cloud 3D
Object Detection <https://arxiv.org/abs/1908.09492>`_.
"""
from __future__ import annotations
import numpy as np
from torch.utils.data import Dataset
from vis4d.common.distributed import broadcast, rank_zero_only
from vis4d.common.logging import rank_zero_info
from vis4d.common.time import Timer
from .datasets.util import print_class_histogram
from .reference import MultiViewDataset
from .typing import DictDataOrList
# TODO: Support sensor selection.
class CBGSDataset(Dataset[DictDataOrList]):
    """Balance the number of scenes under different classes.

    Wraps a dataset and re-samples its indices so that every class in
    ``class_map`` contributes roughly the same number of samples. The
    re-sampled index list is computed in ``_get_sample_indices`` (decorated
    with ``rank_zero_only``) and then passed through ``broadcast`` so that
    all processes use the same list.
    """

    def __init__(
        self,
        dataset: Dataset[DictDataOrList],
        class_map: dict[str, int],
        ignore: int = -1,
    ) -> None:
        """Creates an instance of the class.

        Args:
            dataset (Dataset[DictDataOrList]): Dataset to wrap. Either it,
                or the dataset it wraps (``dataset.dataset``), must provide
                a ``get_cat_ids(idx)`` method.
            class_map (dict[str, int]): Mapping from class name to class id.
            ignore (int, optional): Category id skipped when counting
                classes. Defaults to -1.
        """
        super().__init__()
        self.dataset = dataset
        self.has_reference = isinstance(dataset, MultiViewDataset)
        # Order classes by id so every per-class structure below iterates
        # in a deterministic, id-sorted order.
        self.cat2id = dict(sorted(class_map.items(), key=lambda x: x[1]))
        self.ignore = ignore

        rank_zero_info("Wrapping dataset with CBGS...")
        # The sampling is random; broadcasting the result keeps every
        # process on the same index list.
        sample_indices = self._get_sample_indices()
        self.sample_indices = broadcast(sample_indices)

    def _show_histogram(
        self,
        sample_indices: list[int],
        sample_frequencies: list[dict[str, int]],
    ) -> None:
        """Show the class histogram of the re-sampled dataset.

        Args:
            sample_indices (list[int]): Indices selected by CBGS; an index
                may appear multiple times.
            sample_frequencies (list[dict[str, int]]): For each original
                dataset index, how often each class name occurs in that
                sample's category ids.
        """
        frequencies = dict.fromkeys(self.cat2id, 0)
        for idx in sample_indices:
            for class_name, count in sample_frequencies[idx].items():
                frequencies[class_name] += count
        print_class_histogram(frequencies)

    def _get_class_sample_indices(
        self,
    ) -> tuple[dict[int, list[int]], list[dict[str, int]]]:
        """Group dataset indices by the classes they contain.

        Returns:
            tuple[dict[int, list[int]], list[dict[str, int]]]: First, a
                mapping from class id to the indices of all samples that
                contain that class (each sample listed at most once per
                class). Second, for every sample, the number of occurrences
                of each class name in its category ids.

        Raises:
            AssertionError: If the (unwrapped) dataset has no
                ``get_cat_ids`` method.
        """
        class_sample_indices: dict[int, list[int]] = {
            cat_id: [] for cat_id in self.cat2id.values()
        }
        sample_frequencies: list[dict[str, int]] = []
        inv_class_map = {v: k for k, v in self.cat2id.items()}

        # Handle the case that dataset is already wrapped.
        dataset = getattr(self.dataset, "dataset", self.dataset)

        # Validate once up front instead of once per sample.
        assert hasattr(
            dataset, "get_cat_ids"
        ), "The dataset must have a method `get_cat_ids` to get cat ids."

        for idx in range(len(dataset)):
            frequencies = dict.fromkeys(self.cat2id, 0)
            present_cats: set[int] = set()
            for cat_id in dataset.get_cat_ids(idx):
                if cat_id == self.ignore:
                    continue
                present_cats.add(cat_id)
                frequencies[inv_class_map[cat_id]] += 1
            sample_frequencies.append(frequencies)
            # A sample is registered once per class it contains, no matter
            # how many instances of that class it has.
            for cat_id in present_cats:
                class_sample_indices[cat_id].append(idx)
        return class_sample_indices, sample_frequencies

    @rank_zero_only
    def _get_sample_indices(self) -> list[int]:
        """Load sample indices.

        Every class with at least one sample is re-sampled with replacement
        (``np.random.choice``) to ``int(total / num_classes)`` entries.
        Here ``total`` is the summed length of all per-class index lists.
        Classes without samples contribute nothing.

        Returns:
            list[int]: List of indices after class sampling.

        Raises:
            ValueError: If ``class_map`` is empty or no sample contains any
                non-ignored class, since there is nothing to balance.
        """
        t = Timer()
        class_sample_indices, sample_frequencies = (
            self._get_class_sample_indices()
        )

        # Samples containing several classes are counted once per class.
        duplicated_samples = sum(
            len(v) for v in class_sample_indices.values()
        )
        if duplicated_samples == 0:
            # Previously this surfaced as a bare ZeroDivisionError.
            raise ValueError(
                "CBGS found no samples containing any class of `class_map`; "
                "cannot balance an empty class distribution."
            )

        frac = 1.0 / len(self.cat2id)
        sample_indices: list[int] = []
        for cls_inds in class_sample_indices.values():
            if not cls_inds:
                continue
            # Up-/down-sample the class so its share becomes `frac`.
            ratio = frac / (len(cls_inds) / duplicated_samples)
            sample_indices += np.random.choice(
                cls_inds, int(len(cls_inds) * ratio)
            ).tolist()

        self._show_histogram(sample_indices, sample_frequencies)
        rank_zero_info(
            f"Generating {len(sample_indices)} CBGS samples takes "
            + f"{t.time():.2f} seconds."
        )
        return sample_indices

    def __len__(self) -> int:
        """Return the length of sample indices.

        Returns:
            int: Length of sample indices.
        """
        return len(self.sample_indices)

    def __getitem__(self, idx: int) -> DictDataOrList:
        """Get original dataset idx according to the given index.

        Args:
            idx (int): The index of self.sample_indices.

        Returns:
            DictDataOrList: Data of the corresponding index.
        """
        ori_index = self.sample_indices[idx]
        return self.dataset[ori_index]