RoyYang0714's picture
feat: Try to build everything locally.
9b33fca
"""Object detection and visual grounding dataset."""
from __future__ import annotations
import json
import os.path as osp
import numpy as np
from tqdm import tqdm
from vis4d.common.logging import rank_zero_info
from vis4d.common.typing import ArgsType, DictStrAny
from vis4d.data.const import CommonKeys as K
from vis4d.data.datasets.base import Dataset
from vis4d.data.datasets.util import (
CacheMappingMixin,
im_decode,
print_class_histogram,
)
from vis4d.data.typing import DictData
class ODVGDataset(CacheMappingMixin, Dataset):
    """Object detection and visual grounding dataset.

    Annotations are read from a JSON-lines file where every line describes
    one image (``filename``, ``height``, ``width``) plus either
    ``detection.instances`` (object detection, "OD") or a ``grounding``
    entry with a caption and phrase regions (visual grounding, "VG").
    The mode is inferred from ``label_map_file``: OD when a label map is
    given, VG otherwise.
    """

    def __init__(
        self,
        data_root: str,
        ann_file: str,
        label_map_file: str | None = None,
        dataset_type: str = "VG",
        dataset_prefix: str | None = None,
        remove_empty: bool = False,
        cache_as_binary: bool = False,
        cached_file_path: str | None = None,
        with_camera: bool = False,
        **kwargs: ArgsType,
    ) -> None:
        """Create an object detection and visual grounding dataset.

        Args:
            data_root (str): Root directory. The annotation file, label map,
                camera info and image paths are resolved relative to it.
            ann_file (str): JSON-lines annotation file, relative to
                ``data_root``.
            label_map_file (str | None): JSON file mapping class id (str) to
                class name (str), relative to ``data_root``. If given, the
                dataset is loaded in "OD" mode, otherwise in "VG" mode.
            dataset_type (str): Kept for backward compatibility only; the
                actual type is always inferred from ``label_map_file``.
            dataset_prefix (str | None): Optional sub-directory of
                ``data_root`` holding the images. Depth maps are then looked
                up in ``<dataset_prefix>_depth``.
            remove_empty (bool): Drop samples without valid annotations.
            cache_as_binary (bool): Passed through to ``_load_mapping``.
            cached_file_path (str | None): Passed through to
                ``_load_mapping``.
            with_camera (bool): If True, read ``cam_info.json`` from
                ``data_root`` and store each image's entry under ``"K"``.
            **kwargs (ArgsType): Forwarded to the base dataset.
        """
        super().__init__(**kwargs)
        self.data_root = data_root
        self.ann_file = ann_file
        self.dataset_prefix = dataset_prefix
        self.remove_empty = remove_empty
        self.with_camera = with_camera

        if label_map_file is not None:
            label_map_path = osp.join(self.data_root, label_map_file)
            with open(label_map_path, "r", encoding="utf-8") as file:
                # dict[class_id (str): class_name (str)]
                self.label_map = json.load(file)
            self.dataset_type = "OD"
            # class_name -> class_id, and class names ordered by class id.
            self.det_map = {v: int(k) for k, v in self.label_map.items()}
            self.categories = sorted(self.det_map, key=self.det_map.get)
        else:
            self.label_map = None
            self.dataset_type = "VG"

        # Load annotations
        self.samples, _ = self._load_mapping(
            self._generate_data_mapping,
            self._filter_data,
            cache_as_binary=cache_as_binary,
            cached_file_path=cached_file_path,
        )

    def __repr__(self) -> str:
        """Concise representation of the dataset."""
        return f"ODVGDataset({self.ann_file})"

    def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]:
        """Optionally remove empty samples and log dataset statistics.

        Args:
            data (list[DictStrAny]): Samples produced by
                ``_generate_data_mapping``.

        Returns:
            list[DictStrAny]: Kept samples. Samples with no annotations are
                dropped only when ``remove_empty`` is set.
        """
        is_od = self.dataset_type == "OD"
        # Instance count per class name, used for the class histogram.
        frequencies: dict[str, int] = (
            {cat: 0 for cat in self.label_map.values()} if is_od else {}
        )

        samples = []
        empty_samples = 0
        for sample in data:
            if self.remove_empty and not sample["anns"]:
                empty_samples += 1
                continue
            if is_od:
                for ann in sample["anns"]:
                    frequencies[ann["category"]] += 1
            samples.append(sample)

        rank_zero_info(f"Preprocessing {self} with {len(samples)} samples.")
        rank_zero_info(f"Filtered {empty_samples} empty samples")
        if is_od:
            print_class_histogram(dict(sorted(frequencies.items())))
        return samples

    @staticmethod
    def _is_valid_box(bbox: list[float], width: int, height: int) -> bool:
        """Return whether a box overlaps the image and is >= 1 px per side.

        Args:
            bbox (list[float]): Box as ``[x1, y1, x2, y2]``.
            width (int): Image width.
            height (int): Image height.

        Returns:
            bool: False if the box has zero overlap with the image area or
                if its width or height is below 1.
        """
        x1, y1, x2, y2 = bbox
        inter_w = max(0, min(x2, width) - max(x1, 0))
        inter_h = max(0, min(y2, height) - max(y1, 0))
        if inter_w * inter_h == 0:
            return False
        return not ((x2 - x1) < 1 or (y2 - y1) < 1)

    def _parse_detection(
        self, data: DictStrAny
    ) -> tuple[list[list[float]], list[int], list[DictStrAny]]:
        """Collect valid detection boxes, labels and annotations of an image.

        Args:
            data (DictStrAny): One parsed line of the annotation file.

        Returns:
            tuple: Valid boxes, their ``label`` values and the corresponding
                raw instance annotations.
        """
        boxes: list[list[float]] = []
        class_ids: list[int] = []
        valid_anns: list[DictStrAny] = []
        for ann in data.get("detection", {}).get("instances", []):
            bbox = ann["bbox"]
            if not self._is_valid_box(bbox, data["width"], data["height"]):
                continue
            boxes.append(bbox)
            class_ids.append(ann["label"])
            valid_anns.append(ann)
        return boxes, class_ids, valid_anns

    def _parse_grounding(
        self, data: DictStrAny, data_info: DictStrAny
    ) -> tuple[list[list[float]], list[int], list[DictStrAny]]:
        """Collect valid grounding boxes of an image.

        Also writes ``caption``, ``phrases`` and ``positive_positions`` into
        ``data_info``.

        Args:
            data (DictStrAny): One parsed line of the annotation file.
            data_info (DictStrAny): Sample dict being built; updated in
                place.

        Returns:
            tuple: Valid boxes, the index of the region each box belongs to
                (used as class id) and the corresponding region annotations.
        """
        anno = data["grounding"]
        caption = anno["caption"].lower().strip()
        if not caption.endswith("."):
            caption = caption + ". "
        data_info["caption"] = caption

        boxes: list[list[float]] = []
        class_ids: list[int] = []
        valid_anns: list[DictStrAny] = []
        phrases = []
        positive_positions = []
        for region_id, region in enumerate(anno["regions"]):
            bboxes = region["bbox"]
            # A region holds either a single box or a list of boxes.
            if not isinstance(bboxes[0], list):
                bboxes = [bboxes]
            for bbox in bboxes:
                if not self._is_valid_box(
                    bbox, data["width"], data["height"]
                ):
                    continue
                boxes.append(bbox)
                phrases.append(region["phrase"])
                positive_positions.append(region["tokens_positive"])
                valid_anns.append(region)
                class_ids.append(region_id)

        data_info["phrases"] = phrases
        data_info["positive_positions"] = positive_positions
        return boxes, class_ids, valid_anns

    def _generate_data_mapping(self) -> list[DictStrAny]:
        """Generate the list of samples from the annotation file.

        Returns:
            list[DictStrAny]: One dict per image with ``img_path``,
                ``depth_path``, ``height``, ``width``, ``boxes2d`` (float32,
                shape ``(N, 4)``), ``class_ids`` (int64, shape ``(N,)``) and
                ``anns``. It also has ``K`` when ``with_camera`` is set, and
                ``caption``/``phrases``/``positive_positions`` in VG mode.
        """
        ann_path = osp.join(self.data_root, self.ann_file)
        with open(ann_path, "r", encoding="utf-8") as f:
            # One JSON object per line; blank lines are skipped.
            data_list = [json.loads(line) for line in f if line.strip()]

        cameras: DictStrAny = {}
        if self.with_camera:
            cam_path = osp.join(self.data_root, "cam_info.json")
            with open(cam_path, "r", encoding="utf-8") as f:
                cameras = json.load(f)

        samples = []
        for data in tqdm(data_list):
            filename = data["filename"]
            # Depth map for "<name>.<ext>" is "<name>_depth.png".
            depth_name = osp.splitext(filename)[0] + "_depth.png"
            if self.dataset_prefix is not None:
                img_path = osp.join(
                    self.data_root, self.dataset_prefix, filename
                )
                depth_path = osp.join(
                    self.data_root, f"{self.dataset_prefix}_depth", depth_name
                )
            else:
                img_path = osp.join(self.data_root, filename)
                depth_path = osp.join(self.data_root, depth_name)

            data_info: DictStrAny = {"img_path": img_path}
            # Pseudo K, keyed in cam_info.json by the full image path.
            if self.with_camera:
                data_info["K"] = cameras[img_path][0]
            # Pseudo Depth Path
            data_info["depth_path"] = depth_path
            data_info["height"] = data["height"]
            data_info["width"] = data["width"]

            if self.dataset_type == "OD":
                boxes, class_ids, valid_anns = self._parse_detection(data)
            else:
                boxes, class_ids, valid_anns = self._parse_grounding(
                    data, data_info
                )

            data_info["boxes2d"] = (
                np.array(boxes, dtype=np.float32)
                if boxes
                else np.empty((0, 4), dtype=np.float32)
            )
            data_info["class_ids"] = np.asarray(class_ids, dtype=np.int64)
            data_info["anns"] = valid_anns
            samples.append(data_info)
        return samples

    def get_cat_ids(self, idx: int) -> list[int]:
        """Return the class ids of the boxes in sample ``idx``."""
        return self.samples[idx]["class_ids"].tolist()

    def __len__(self) -> int:
        """Total number of samples of data."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> DictData:
        """Get single sample.

        Args:
            idx (int): Index of sample.

        Returns:
            DictData: sample at index in Vis4D input format.
        """
        sample = self.samples[idx]
        data_dict: DictData = {}

        # Get image info
        data_dict[K.sample_names] = sample["img_path"].split("/")[-1]
        im_bytes = self.data_backend.get(sample["img_path"])
        # [None] adds a leading axis; axes 1 and 2 are used as (H, W) below.
        image = np.ascontiguousarray(
            im_decode(im_bytes, mode=self.image_channel_mode),
            dtype=np.float32,
        )[None]
        image_hw = (image.shape[1], image.shape[2])
        data_dict[K.images] = image
        data_dict[K.input_hw] = image_hw
        data_dict[K.original_images] = image
        data_dict[K.original_hw] = image_hw

        # Boxes; class ids are label ids (OD) or region indices (VG).
        data_dict[K.boxes2d] = sample["boxes2d"]
        data_dict[K.boxes2d_classes] = sample["class_ids"]

        if self.dataset_type == "OD":
            # All class names, ordered by class id.
            data_dict[K.boxes2d_names] = self.categories
            data_dict["phrases"] = None
            data_dict["positive_positions"] = None
        else:
            # For grounding the normalized caption is used as names.
            data_dict[K.boxes2d_names] = sample["caption"]
            data_dict["phrases"] = sample["phrases"]
            data_dict["positive_positions"] = sample["positive_positions"]

        data_dict["dataset_type"] = self.dataset_type
        data_dict["label_map"] = self.label_map

        self.data_backend.close()
        return data_dict