Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| import numpy as np | |
| from typing import Any, List | |
| from detectron2.modeling import ROI_MASK_HEAD_REGISTRY | |
| from detectron2.modeling.roi_heads.mask_head import MaskRCNNConvUpsampleHead, mask_rcnn_inference | |
| from detectron2.projects.point_rend import ImplicitPointRendMaskHead | |
| from detectron2.projects.point_rend.point_features import point_sample | |
| from detectron2.projects.point_rend.point_head import roi_mask_point_loss | |
| from detectron2.structures import Instances | |
| from .point_utils import get_point_coords_from_point_annotation | |
| __all__ = [ | |
| "ImplicitPointRendPointSupHead", | |
| "MaskRCNNConvUpsamplePointSupHead", | |
| ] | |
class MaskRCNNConvUpsamplePointSupHead(MaskRCNNConvUpsampleHead):
    """
    A mask head with several conv layers, plus an upsample layer (with `ConvTranspose2d`).
    Predictions are made with a final 1x1 conv layer.

    The difference with `MaskRCNNConvUpsampleHead` is that this head is trained
    with point supervision. Please use the `MaskRCNNConvUpsampleHead` if you want
    to train the model with mask supervision.
    """

    def forward(self, x, instances: List[Instances]) -> Any:
        """
        Run the head on region features and either compute the point-supervised
        mask loss (training) or attach mask predictions (inference).

        Args:
            x: input region feature(s) provided by :class:`ROIHeads`.
            instances (list[Instances]): contains the boxes & labels corresponding
                to the input features.
                Exact format is up to its caller to decide.
                Typically, this is the foreground instances in training, with
                "proposal_boxes" field and other gt annotations.
                In inference, it contains boxes that are already predicted.

        Returns:
            A dict of losses in training. The predicted "instances" in inference.
        """
        x = self.layers(x)

        if not self.training:
            # The return value of `mask_rcnn_inference` is discarded and the same
            # `instances` list is handed back, so the predictions are expected to
            # be written onto `instances` in place.
            mask_rcnn_inference(x, instances)
            return instances

        N, _, H, W = x.shape
        assert H == W

        # Use the builtin `sum`: calling `np.sum` on a generator is deprecated in
        # NumPy and only works through a fallback to the Python builtin.
        num_proposals = sum(len(per_image.proposal_boxes) for per_image in instances)
        assert N == num_proposals

        if N == 0:
            # Zero loss derived from `x` instead of a bare constant (presumably so
            # the returned loss stays attached to the head's outputs).
            return {"loss_mask": x.sum() * 0}

        # Training with point supervision
        point_coords, point_labels = get_point_coords_from_point_annotation(instances)

        mask_logits = point_sample(
            x,
            point_coords,
            align_corners=False,
        )

        return {"loss_mask": roi_mask_point_loss(mask_logits, instances, point_labels)}
class ImplicitPointRendPointSupHead(ImplicitPointRendMaskHead):
    """
    Variant of `ImplicitPointRendMaskHead` whose training points are taken from
    point annotations via `get_point_coords_from_point_annotation`.
    """

    def _uniform_sample_train_points(self, instances):
        """
        Return the training point coordinates and their labels for `instances`.

        Only valid in training mode (asserted below).
        """
        assert self.training
        # Please keep in mind that "gt_masks" is not used in this mask head.
        coords, labels = get_point_coords_from_point_annotation(instances)
        return coords, labels