# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch

from detectron2.layers import cat


def get_point_coords_from_point_annotation(instances):
    """
    Load point coords and their corresponding labels from point annotation.

    Args:
        instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. These instances are in 1:1 correspondence with the pred_mask_logits.
            The ground-truth labels (class, box, mask, ...) associated with each instance are
            stored in fields.

    Returns:
        point_coords (Tensor): A tensor of shape (R, P, 2), where R is the total number of
            instances in the batch, that contains the box-normalized coordinates of the P
            sampled points per instance.
        point_labels (Tensor): A tensor of shape (R, P) that contains the labels of the P
            sampled points. `point_labels` takes 3 possible values:
            - 0: the point belongs to background
            - 1: the point belongs to the object
            - -1: the point is ignored during training
    """
    point_coords_list = []
    point_labels_list = []
    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue
        point_coords = instances_per_image.gt_point_coords.to(torch.float32)
        point_labels = instances_per_image.gt_point_labels.to(torch.float32).clone()
        proposal_boxes_per_image = instances_per_image.proposal_boxes.tensor

        # Convert the point coordinate system; ground-truth points are in image coordinates.
        point_coords_wrt_box = get_point_coords_wrt_box(proposal_boxes_per_image, point_coords)

        # Ignore points that fall outside the predicted (proposal) boxes.
        point_ignores = (
            (point_coords_wrt_box[:, :, 0] < 0)
            | (point_coords_wrt_box[:, :, 0] > 1)
            | (point_coords_wrt_box[:, :, 1] < 0)
            | (point_coords_wrt_box[:, :, 1] > 1)
        )
        point_labels[point_ignores] = -1

        point_coords_list.append(point_coords_wrt_box)
        point_labels_list.append(point_labels)

    return (
        cat(point_coords_list, dim=0),
        cat(point_labels_list, dim=0),
    )
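

# Illustrative usage sketch, added for exposition; not part of the original PointSup module.
# It assumes detectron2's Instances/Boxes structures and shows how an annotated point that
# falls outside an instance's proposal box ends up with the ignore label (-1).
def _example_point_annotation_usage():
    from detectron2.structures import Boxes, Instances

    instances = Instances((480, 640))  # (height, width) of a hypothetical image
    instances.proposal_boxes = Boxes(torch.tensor([[10.0, 20.0, 110.0, 220.0]]))
    # One instance with P=2 annotated points, given as absolute (x, y) pixel coordinates.
    instances.gt_point_coords = torch.tensor([[[60.0, 120.0], [500.0, 400.0]]])
    instances.gt_point_labels = torch.tensor([[1.0, 0.0]])  # 1: object, 0: background

    point_coords, point_labels = get_point_coords_from_point_annotation([instances])
    # point_coords: shape (1, 2, 2), box-normalized; the first point maps to (0.5, 0.5).
    # point_labels: tensor([[1., -1.]]); the second point lies outside the box, so it is ignored.
    return point_coords, point_labels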


def get_point_coords_wrt_box(boxes_coords, point_coords):
    """
    Convert image-level absolute coordinates to box-normalized [0, 1] x [0, 1] point coordinates.

    Args:
        boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding box
            coordinates.
        point_coords (Tensor): A tensor of shape (R, P, 2) that contains
            image-level absolute coordinates of P sampled points per box.

    Returns:
        point_coords_wrt_box (Tensor): A tensor of shape (R, P, 2) that contains
            [0, 1] x [0, 1] box-normalized coordinates of the P sampled points.
    """
    with torch.no_grad():
        point_coords_wrt_box = point_coords.clone()
        point_coords_wrt_box[:, :, 0] -= boxes_coords[:, None, 0]
        point_coords_wrt_box[:, :, 1] -= boxes_coords[:, None, 1]
        point_coords_wrt_box[:, :, 0] = point_coords_wrt_box[:, :, 0] / (
            boxes_coords[:, None, 2] - boxes_coords[:, None, 0]
        )
        point_coords_wrt_box[:, :, 1] = point_coords_wrt_box[:, :, 1] / (
            boxes_coords[:, None, 3] - boxes_coords[:, None, 1]
        )
    return point_coords_wrt_box
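

if __name__ == "__main__":
    # Hedged sanity check, added for illustration; not part of the original module.
    # The corners of a box should map to the corners of the unit square.
    boxes = torch.tensor([[10.0, 20.0, 110.0, 220.0]])       # (R=1, 4), XYXY in absolute pixels
    points = torch.tensor([[[10.0, 20.0], [110.0, 220.0]]])  # (R=1, P=2, 2), absolute (x, y)
    normalized = get_point_coords_wrt_box(boxes, points)
    assert torch.allclose(normalized, torch.tensor([[[0.0, 0.0], [1.0, 1.0]]]))
    print(normalized)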