Spaces:
Running
on
Zero
Running
on
Zero
| """Post process after transformation.""" | |
| from __future__ import annotations | |
| import torch | |
| from vis4d.common.typing import NDArrayF32, NDArrayI64 | |
| from vis4d.data.const import CommonKeys as K | |
| from vis4d.op.box.box2d import bbox_area, bbox_clip | |
| from .base import Transform | |
class PostProcessBoxes2D:
    """Post process after transformation.

    Optionally clips the 2D boxes to the input image size, then drops every
    box whose area is below ``min_area``. The same keep-mask is applied to
    all per-box companion arrays (classes, track ids, 3D boxes and their
    classes / track ids) so they stay aligned with the kept 2D boxes.
    """

    def __init__(
        self, min_area: float = 7.0 * 7.0, clip_bboxes_to_image: bool = True
    ) -> None:
        """Creates an instance of the class.

        Args:
            min_area (float): Minimum area of the bounding box. Defaults to
                7.0 * 7.0.
            clip_bboxes_to_image (bool): Whether to clip the bounding boxes to
                the image size. Defaults to True.
        """
        self.min_area = min_area
        self.clip_bboxes_to_image = clip_bboxes_to_image

    def __call__(
        self,
        boxes_list: list[NDArrayF32],
        classes_list: list[NDArrayI64],
        track_ids_list: list[NDArrayI64] | None,
        input_hw_list: list[tuple[int, int]],
        boxes3d_list: list[NDArrayF32] | None,
        boxes3d_classes_list: list[NDArrayI64] | None,
        boxes3d_track_ids_list: list[NDArrayI64] | None,
    ) -> tuple[
        list[NDArrayF32],
        list[NDArrayI64],
        list[NDArrayI64] | None,
        list[NDArrayF32] | None,
        list[NDArrayI64] | None,
        list[NDArrayI64] | None,
    ]:
        """Post process according to boxes2D after transformation.

        Args:
            boxes_list (list[NDArrayF32]): The bounding boxes to be post
                processed.
            classes_list (list[NDArrayI64]): The classes of the bounding boxes.
            track_ids_list (list[NDArrayI64] | None): The track ids of the
                bounding boxes.
            input_hw_list (list[tuple[int, int]]): The height and width of the
                input image. Only read when clipping is enabled.
            boxes3d_list (list[NDArrayF32] | None): The 3D bounding boxes to be
                post processed.
            boxes3d_classes_list (list[NDArrayI64] | None): The classes of the
                3D bounding boxes.
            boxes3d_track_ids_list (list[NDArrayI64] | None): The track ids of
                the 3D bounding boxes.

        Returns:
            tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None,
                list[NDArrayF32] | None, list[NDArrayI64] | None,
                list[NDArrayI64] | None]: The post processed results, as new
                lists. Each optional output is None exactly when the
                corresponding input is None. The input lists are not modified.
        """
        new_boxes: list[NDArrayF32] = []
        new_classes: list[NDArrayI64] = []
        new_track_ids: list[NDArrayI64] | None = (
            [] if track_ids_list is not None else None
        )
        new_boxes3d: list[NDArrayF32] | None = (
            [] if boxes3d_list is not None else None
        )
        new_boxes3d_classes: list[NDArrayI64] | None = (
            [] if boxes3d_classes_list is not None else None
        )
        new_boxes3d_track_ids: list[NDArrayI64] | None = (
            [] if boxes3d_track_ids_list is not None else None
        )

        for i, (boxes, classes) in enumerate(zip(boxes_list, classes_list)):
            boxes_ = torch.from_numpy(boxes)
            if self.clip_bboxes_to_image:
                boxes_ = bbox_clip(boxes_, input_hw_list[i])
            keep = (bbox_area(boxes_) >= self.min_area).numpy()

            # Take the kept boxes from the (possibly clipped) tensor. Whether
            # bbox_clip also writes through to `boxes` (which shares memory
            # via from_numpy) is not visible here, so the output must not
            # depend on that.
            new_boxes.append(boxes_.numpy()[keep])
            new_classes.append(classes[keep])

            # The asserts only narrow Optional types for the type checker;
            # each new_* list is non-None exactly when its input is non-None.
            if track_ids_list is not None:
                assert new_track_ids is not None
                new_track_ids.append(track_ids_list[i][keep])
            if boxes3d_list is not None:
                assert new_boxes3d is not None
                new_boxes3d.append(boxes3d_list[i][keep])
            if boxes3d_classes_list is not None:
                assert new_boxes3d_classes is not None
                new_boxes3d_classes.append(boxes3d_classes_list[i][keep])
            if boxes3d_track_ids_list is not None:
                assert new_boxes3d_track_ids is not None
                new_boxes3d_track_ids.append(boxes3d_track_ids_list[i][keep])

        return (
            new_boxes,
            new_classes,
            new_track_ids,
            new_boxes3d,
            new_boxes3d_classes,
            new_boxes3d_track_ids,
        )
class RescaleTrackIDs:
    """Rescale track ids.

    Maps every distinct track id to a consecutive integer (0, 1, 2, ...),
    numbered in order of first appearance across the whole list of arrays.
    """

    def __call__(self, track_ids_list: list[NDArrayI64]) -> list[NDArrayI64]:
        """Rescale the track ids.

        Args:
            track_ids_list (list[NDArrayI64]): The track ids to be
                rescaled. Each array is iterated element by element, so it is
                presumably 1D (TODO confirm against callers).

        Returns:
            list[NDArrayI64]: The rescaled track ids, as new arrays with the
                same dtype and length as the corresponding inputs. The input
                arrays are left unmodified.
        """
        # First pass: number each distinct id by its first appearance.
        # tolist() yields plain Python scalars, which hash and compare equal
        # to the numpy scalars they came from.
        id_map: dict[int, int] = {}
        for track_ids in track_ids_list:
            for track_id in track_ids.tolist():
                id_map.setdefault(track_id, len(id_map))

        # Second pass: write the mapped ids into copies. Remapping in place
        # would corrupt results when one array object appears in the list
        # more than once (its already-remapped values would be looked up).
        rescaled: list[NDArrayI64] = []
        for track_ids in track_ids_list:
            new_ids = track_ids.copy()
            new_ids[:] = [id_map[track_id] for track_id in track_ids.tolist()]
            rescaled.append(new_ids)
        return rescaled