Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from maskrcnn_benchmark.layers import ROIAlign, ROIAlignV2 | |
| from .utils import cat | |
| class LevelMapper(object): | |
| """Determine which FPN level each RoI in a set of RoIs should map to based | |
| on the heuristic in the FPN paper. | |
| """ | |
| def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6): | |
| """ | |
| Arguments: | |
| k_min (int) | |
| k_max (int) | |
| canonical_scale (int) | |
| canonical_level (int) | |
| eps (float) | |
| """ | |
| self.k_min = k_min | |
| self.k_max = k_max | |
| self.s0 = canonical_scale | |
| self.lvl0 = canonical_level | |
| self.eps = eps | |
| def __call__(self, boxlists): | |
| """ | |
| Arguments: | |
| boxlists (list[BoxList]) | |
| """ | |
| # Compute level ids | |
| s = torch.sqrt(cat([boxlist.area() for boxlist in boxlists])) | |
| # Eqn.(1) in FPN paper | |
| target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps)) | |
| target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max) | |
| return target_lvls.to(torch.int64) - self.k_min | |
| class Pooler(nn.Module): | |
| """ | |
| Pooler for Detection with or without FPN. | |
| It currently hard-code ROIAlign in the implementation, | |
| but that can be made more generic later on. | |
| Also, the requirement of passing the scales is not strictly necessary, as they | |
| can be inferred from the size of the feature map / size of original image, | |
| which is available thanks to the BoxList. | |
| """ | |
| def __init__(self, output_size, scales, sampling_ratio, use_v2=False): | |
| """ | |
| Arguments: | |
| output_size (list[tuple[int]] or list[int]): output size for the pooled region | |
| scales (list[float]): scales for each Pooler | |
| sampling_ratio (int): sampling ratio for ROIAlign | |
| """ | |
| super(Pooler, self).__init__() | |
| poolers = [] | |
| for scale in scales: | |
| poolers.append( | |
| ROIAlignV2( | |
| output_size, spatial_scale=scale, sampling_ratio=sampling_ratio | |
| ) | |
| if use_v2 else | |
| ROIAlign( | |
| output_size, spatial_scale=scale, sampling_ratio=sampling_ratio | |
| ) | |
| ) | |
| self.poolers = nn.ModuleList(poolers) | |
| self.output_size = output_size | |
| # get the levels in the feature map by leveraging the fact that the network always | |
| # downsamples by a factor of 2 at each level. | |
| lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item() | |
| lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item() | |
| self.map_levels = LevelMapper(lvl_min, lvl_max) | |
| def convert_to_roi_format(self, boxes): | |
| concat_boxes = cat([b.bbox for b in boxes], dim=0) | |
| device, dtype = concat_boxes.device, concat_boxes.dtype | |
| ids = cat( | |
| [ | |
| torch.full((len(b), 1), i, dtype=dtype, device=device) | |
| for i, b in enumerate(boxes) | |
| ], | |
| dim=0, | |
| ) | |
| rois = torch.cat([ids, concat_boxes], dim=1) | |
| return rois | |
| def forward(self, x, boxes): | |
| """ | |
| Arguments: | |
| x (list[Tensor]): feature maps for each level | |
| boxes (list[BoxList]): boxes to be used to perform the pooling operation. | |
| Returns: | |
| result (Tensor) | |
| """ | |
| num_levels = len(self.poolers) | |
| rois = self.convert_to_roi_format(boxes) | |
| if num_levels == 1: | |
| return self.poolers[0](x[0], rois) | |
| levels = self.map_levels(boxes) | |
| num_rois = len(rois) | |
| num_channels = x[0].shape[1] | |
| output_size = self.output_size[0] | |
| dtype, device = x[0].dtype, x[0].device | |
| result = torch.zeros( | |
| (num_rois, num_channels, output_size, output_size), | |
| dtype=dtype, | |
| device=device, | |
| ) | |
| for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)): | |
| idx_in_level = torch.nonzero(levels == level).squeeze(1) | |
| rois_per_level = rois[idx_in_level] | |
| result[idx_in_level] = pooler(per_level_feature, rois_per_level) | |
| return result | |