Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| """ | |
| Implements the Generalized R-CNN framework | |
| """ | |
| import torch | |
| from torch import nn | |
| from maskrcnn_benchmark.structures.image_list import to_image_list | |
| from ..backbone import build_backbone | |
| from ..rpn import build_rpn | |
| from ..roi_heads import build_roi_heads | |
| import timeit | |
class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN. Currently supports boxes and masks.
    It consists of three main parts:
    - backbone
    - rpn
    - heads: takes the features + the proposals from the RPN and computes
        detections / masks from it.

    Optional behaviours read from ``cfg.MODEL``:
    - ``BACKBONE.FREEZE`` / ``FPN.FREEZE`` / ``RPN.FREEZE``: the corresponding
      sub-module is kept in eval mode with gradients disabled, even after
      ``train()`` is called.
    - ``LINEAR_PROB``: only prediction-layer parameters of the rpn / roi_heads
      (see ``_LINEAR_PROB_TRAINABLE_KEYS``) stay trainable.
    - ``ONNX``: ``images`` are passed to the backbone as-is (no ImageList conversion).
    - ``DEBUG``: ``forward`` additionally returns per-stage timings and sizes.
    """

    # Substrings identifying the prediction-layer parameters that remain trainable
    # during linear probing; every other rpn / roi_heads parameter is frozen.
    _LINEAR_PROB_TRAINABLE_KEYS = ('bbox_pred', 'cls_logits', 'centerness', 'cosine_scale')

    def __init__(self, cfg):
        super().__init__()
        self.backbone = build_backbone(cfg)
        self.rpn = build_rpn(cfg)
        self.roi_heads = build_roi_heads(cfg)

        self.DEBUG = cfg.MODEL.DEBUG
        self.ONNX = cfg.MODEL.ONNX

        self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE
        self.freeze_fpn = cfg.MODEL.FPN.FREEZE
        self.freeze_rpn = cfg.MODEL.RPN.FREEZE

        if cfg.MODEL.LINEAR_PROB:
            # Linear probing only makes sense on top of frozen features.
            assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!"
            if hasattr(self.backbone, 'fpn'):
                assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!"
        self.linear_prob = cfg.MODEL.LINEAR_PROB

    @staticmethod
    def _freeze(module):
        """Put ``module`` in eval mode and disable gradients for all its parameters."""
        module.eval()
        for param in module.parameters():
            param.requires_grad = False

    @classmethod
    def _freeze_all_but_predictors(cls, module):
        """Disable gradients for every parameter of ``module`` except prediction layers."""
        for name, param in module.named_parameters():
            if not any(key in name for key in cls._LINEAR_PROB_TRAINABLE_KEYS):
                param.requires_grad = False

    def train(self, mode=True):
        """Convert the model into training mode while keeping frozen layers frozen.

        Args:
            mode (bool): forwarded to ``nn.Module.train``.

        Returns:
            GeneralizedRCNN: ``self``, as ``nn.Module.train`` does, so that
            ``model.train()`` / ``model.eval()`` keep returning the module.
        """
        super().train(mode)
        if self.freeze_backbone:
            self._freeze(self.backbone.body)
        # __init__ treats the FPN as optional (hasattr check); mirror that here
        # instead of raising AttributeError for backbones without an FPN.
        if self.freeze_fpn and hasattr(self.backbone, 'fpn'):
            self._freeze(self.backbone.fpn)
        if self.freeze_rpn and self.rpn is not None:
            self._freeze(self.rpn)
        if self.linear_prob:
            # Linear probing only toggles requires_grad; it does not switch to eval().
            if self.rpn is not None:
                self._freeze_all_but_predictors(self.rpn)
            # Same truthiness test forward() uses to detect RPN-only models, so an
            # absent head (None or an empty container) is skipped.
            if self.roi_heads:
                self._freeze_all_but_predictors(self.roi_heads)
        return self

    @staticmethod
    def _first_image_size(images):
        """Return ``size()`` of the first image, accepting ImageList inputs too."""
        # NOTE(review): an ImageList keeps its (padded) batch in ``.tensors`` and is
        # presumably not indexable itself — TODO confirm against image_list.py.
        first = images.tensors[0] if hasattr(images, 'tensors') else images[0]
        return first.size()

    def forward(self, images, targets=None):
        """
        Arguments:
            images (list[Tensor] or ImageList): images to be processed. In ONNX mode
                they are handed to the backbone unchanged.
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
            When ``self.DEBUG`` is set, returns ``(result, debug_info)`` instead, in
            both training and testing; losses are not returned in that case.

        Raises:
            ValueError: in training mode when ``targets`` is None.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        if self.DEBUG:
            debug_info = {'input_size': self._first_image_size(images)}
            tic = timeit.default_timer()

        if self.ONNX:
            features = self.backbone(images)
        else:
            images = to_image_list(images)
            features = self.backbone(images.tensors)

        if self.DEBUG:
            debug_info['feat_time'] = timeit.default_timer() - tic
            debug_info['feat_size'] = [feat.size() for feat in features]
            tic = timeit.default_timer()

        proposals, proposal_losses = self.rpn(images, features, targets)

        if self.DEBUG:
            debug_info['rpn_time'] = timeit.default_timer() - tic
            debug_info['#rpn'] = list(proposals)
            tic = timeit.default_timer()

        if self.roi_heads:
            _, result, detector_losses = self.roi_heads(features, proposals, targets)
        else:
            # RPN-only models don't have roi_heads
            result = proposals
            detector_losses = {}

        if self.DEBUG:
            debug_info['rcnn_time'] = timeit.default_timer() - tic
            debug_info['#rcnn'] = result
            return result, debug_info

        if self.training:
            # RPN losses are applied last, so they win on any duplicate key.
            losses = {}
            losses.update(detector_losses)
            losses.update(proposal_losses)
            return losses
        return result