File size: 2,611 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Resample index to recover the original dataset length."""

from __future__ import annotations

import numpy as np
from torch.utils.data import Dataset

from vis4d.common.logging import rank_zero_info

from .reference import MultiViewDataset
from .typing import DictDataOrList


class ResampleDataset(Dataset[DictDataOrList]):
    """Dataset wrapper to recover the filtered samples through resampling.

    In MMEngine and Detectron2, the dataset might return None when the sample
    has no valid annotations. They then resample the index and try again to
    get valid training data. As a result, the length of the dataset differs
    depending on whether the empty samples are filtered out beforehand.

    This dataset wrapper resamples the index to recover the original dataset
    length (before filtering empty frames) to align with the other codebases'
    implementation.

    https://github.com/open-mmlab/mmengine/blob/main/mmengine/dataset/base_dataset.py#L411
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/common.py#L96

    Attributes:
        dataset: The wrapped dataset that samples are read from.
        has_reference: Whether the wrapped dataset is a MultiViewDataset.
        valid_len: Length of the wrapped dataset, i.e. ``len(dataset)``.
        original_len: Value of the ``original_len`` attribute found on the
            wrapped dataset (or on its inner ``dataset``); this is the length
            reported by this wrapper.
    """

    def __init__(self, dataset: Dataset[DictDataOrList]) -> None:
        """Creates an instance of the class.

        Args:
            dataset (Dataset[DictDataOrList]): Dataset to wrap. Either the
                dataset itself or, if it has a ``dataset`` attribute, that
                inner object must provide ``original_len``.

        Raises:
            AssertionError: If no ``original_len`` attribute is available.
        """
        super().__init__()
        self.dataset = dataset
        self.has_reference = isinstance(dataset, MultiViewDataset)
        self.valid_len = len(dataset)  # type: ignore

        # Handle the case that dataset is already wrapped: look one level
        # down for the object that carries `original_len`.
        _dataset = getattr(self.dataset, "dataset", self.dataset)

        assert hasattr(_dataset, "original_len"), (
            "The dataset must have the attribute `original_len` to resample "
            + "index to recover the original length."
        )
        self.original_len = _dataset.original_len

        rank_zero_info(
            f"Recover {_dataset} to {self.original_len} samples by resampling "
            + "index."
        )

    def __len__(self) -> int:
        """Return the length of dataset.

        Returns:
            int: The recovered (original) length, not the wrapped dataset's
                current length.
        """
        return self.original_len

    def __getitem__(self, idx: int) -> DictDataOrList:
        """Get original dataset idx according to the given index.

        Indices below the wrapped dataset's length are forwarded unchanged.
        Indices from there up to the original length are replaced by a
        uniformly drawn random index into the wrapped dataset.

        Args:
            idx (int): The index of original dataset length.

        Returns:
            DictDataOrList: Data of the corresponding index.

        Raises:
            IndexError: If ``idx`` is not smaller than the original length.
        """
        # Sequence-style iteration through `__getitem__` only stops on
        # IndexError, so indices past the reported length must be rejected
        # instead of being silently resampled.
        if idx >= self.original_len:
            raise IndexError(
                f"Index {idx} is out of range for dataset of length "
                f"{self.original_len}."
            )

        if idx < self.valid_len:
            return self.dataset[idx]

        # The index points at a sample that was filtered out; substitute a
        # random valid one (upper bound of randint is exclusive).
        return self.dataset[np.random.randint(0, self.valid_len)]