import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch import nn
from scipy.optimize import linear_sum_assignment
from torch.cuda.amp import custom_fwd, custom_bwd


def box_area(boxes):
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format.

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2).
    """
    # degenerate boxes give inf / nan results, so an early check is usually done here;
    # this variant leaves the asserts disabled and instead zeroes NaN/Inf costs in the matcher
    # assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    # assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]  # area of the smallest enclosing box

    return iou - (area - union) / area
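

# Illustrative sketch (not part of the original file): a quick sanity check of
# box_iou / generalized_box_iou on hand-made [x0, y0, x1, y1] boxes. GIoU matches
# IoU for identical boxes and drops below zero as boxes move apart, which is why
# it remains a usable matching cost even for non-overlapping pairs.
def _demo_generalized_box_iou():
    boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
    boxes2 = torch.tensor([[0.0, 0.0, 2.0, 2.0],    # identical
                           [1.0, 1.0, 3.0, 3.0],    # partial overlap
                           [4.0, 4.0, 6.0, 6.0]])   # disjoint
    iou, _ = box_iou(boxes1, boxes2)
    giou = generalized_box_iou(boxes1, boxes2)
    print(iou)   # tensor([[1.0000, 0.1429, 0.0000]])
    print(giou)  # tensor([[1.0000, -0.0794, -0.7778]])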


def dice_loss(inputs, targets, num_boxes):
    """
    Compute the DICE loss, similar to generalized IoU for masks

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        num_boxes: Normalization constant (typically the number of boxes/masks).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    targets = targets.flatten(1)  # flatten targets alongside inputs so shapes stay aligned
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes
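

# Illustrative sketch (assumption, not from the original file): dice_loss consumes
# raw logits and binary targets. Confident, correct logits drive the loss towards
# zero; num_boxes is only a normalization constant.
def _demo_dice_loss():
    logits = torch.tensor([[10.0, -10.0, 10.0, -10.0]])  # near-perfect predictions
    targets = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
    print(dice_loss(logits, targets, num_boxes=1))  # ~0.0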


def sigmoid_focal_loss(inputs: torch.Tensor, targets: torch.Tensor, alpha: float = -1,
                       gamma: float = 2, reduction: str = "none"):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                   'none': No reduction will be applied to the output.
                   'mean': The output will be averaged.
                   'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    """
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss


sigmoid_focal_loss_jit = torch.jit.script(
    sigmoid_focal_loss
)  # type: torch.jit.ScriptFunction
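

# Illustrative sketch: the (1 - p_t) ** gamma modulating factor is what makes the
# focal loss "focal". With alpha=0.25 and gamma=2.0 (assumed here, matching the
# defaults used by the classes below), a confidently-correct prediction contributes
# orders of magnitude less loss than a confidently-wrong one.
def _demo_sigmoid_focal_loss():
    logits = torch.tensor([4.0, -4.0])  # easy positive vs. hard positive
    targets = torch.tensor([1.0, 1.0])
    loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0)
    print(loss)  # ~tensor([1.5e-06, 9.7e-01])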


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best
    predictions, while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1,
                 use_focal: bool = False, focal_loss_alpha: float = 0.25, focal_loss_gamma: float = 2.0,
                 **kwargs):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
            use_focal: If True, score the classification cost with the sigmoid focal loss instead of softmax probabilities
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.use_focal = use_focal
        if self.use_focal:
            self.focal_loss_alpha = focal_loss_alpha
            self.focal_loss_gamma = focal_loss_gamma
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"

    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
                 "boxes_xyxy": Tensor of dim [num_target_boxes, 4] with the target boxes in absolute xyxy format
                 "image_size_xyxy": Tensor of dim [4] with the image size as [w, h, w, h]
                 "image_size_xyxy_tgt": Tensor of dim [num_target_boxes, 4], the image size repeated per target box

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        if self.use_focal:
            out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
        else:
            out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes_xyxy"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it as 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, so it can be omitted.
        if self.use_focal:
            alpha = self.focal_loss_alpha
            gamma = self.focal_loss_gamma
            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
        else:
            cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes, normalized by the per-image size
        image_size_out = torch.cat([v["image_size_xyxy"].unsqueeze(0) for v in targets])
        image_size_out = image_size_out.unsqueeze(1).repeat(1, num_queries, 1).flatten(0, 1)
        image_size_tgt = torch.cat([v["image_size_xyxy_tgt"] for v in targets])

        out_bbox_ = out_bbox / image_size_out
        tgt_bbox_ = tgt_bbox / image_size_tgt
        cost_bbox = torch.cdist(out_bbox_, tgt_bbox_, p=1)

        # Compute the giou cost between boxes (already in absolute xyxy format here)
        cost_giou = -generalized_box_iou(out_bbox, tgt_bbox)

        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()
        # Degenerate boxes can produce NaN/Inf costs; zero them out before the assignment
        C[torch.isnan(C)] = 0.0
        C[torch.isinf(C)] = 0.0

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
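

# Illustrative sketch (all shapes and sizes below are assumptions): running the
# matcher on dummy predictions. Note that this variant reads absolute xyxy boxes
# from "boxes_xyxy" and per-image sizes from "image_size_xyxy" /
# "image_size_xyxy_tgt"; "boxes" is only used for its length.
def _demo_hungarian_matcher():
    matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2, use_focal=True)
    xy = torch.rand(1, 100, 2) * 400
    wh = torch.rand(1, 100, 2) * 200 + 1  # keep predicted boxes non-degenerate
    outputs = {
        "pred_logits": torch.randn(1, 100, 80),          # [bs, num_queries, num_classes]
        "pred_boxes": torch.cat([xy, xy + wh], dim=-1),  # absolute xyxy
    }
    size = torch.tensor([800.0, 800.0, 800.0, 800.0])    # image size as [w, h, w, h]
    targets = [{
        "labels": torch.tensor([3, 7]),
        "boxes": torch.rand(2, 4),
        "boxes_xyxy": torch.tensor([[10.0, 10.0, 100.0, 100.0],
                                    [200.0, 200.0, 400.0, 400.0]]),
        "image_size_xyxy": size,
        "image_size_xyxy_tgt": size.unsqueeze(0).repeat(2, 1),
    }]
    # One (pred_idx, tgt_idx) tuple per image, each of length num_target_boxes
    print(matcher(outputs, targets))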


class SetCriterion(nn.Module):
    """This class computes the loss for the model.

    The process happens in two steps:
        1) we compute a hungarian assignment between the ground-truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervising class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
                 use_focal, focal_loss_alpha=0.25, focal_loss_gamma=2.0):
        """ Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for the list of available losses.
            use_focal: if True, use the sigmoid focal loss for classification instead of cross-entropy
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.use_focal = use_focal
        if self.use_focal:
            self.focal_loss_alpha = focal_loss_alpha
            self.focal_loss_gamma = focal_loss_gamma
        else:
            # weight the no-object class down by eos_coef in the cross-entropy loss
            empty_weight = torch.ones(self.num_classes + 1)
            empty_weight[-1] = self.eos_coef
            self.register_buffer('empty_weight', empty_weight)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=False):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        if self.use_focal:
            src_logits = src_logits.flatten(0, 1)
            # prepare one-hot targets
            target_classes = target_classes.flatten(0, 1)
            pos_inds = torch.nonzero(target_classes != self.num_classes, as_tuple=True)[0]
            labels = torch.zeros_like(src_logits)
            labels[pos_inds, target_classes[pos_inds]] = 1
            # compute the focal loss
            class_loss = sigmoid_focal_loss_jit(
                src_logits,
                labels,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / num_boxes
            losses = {'loss_ce': class_loss}
        else:
            loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
            losses = {'loss_ce': loss_ce}

        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss
        targets dicts must contain the key "boxes_xyxy" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in absolute [x0, y0, x1, y1] format; they are normalized
        by the image size before the L1 loss is computed.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes_xyxy'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        losses = {}
        loss_giou = 1 - torch.diag(generalized_box_iou(src_boxes, target_boxes))
        losses['loss_giou'] = loss_giou.sum() / num_boxes

        image_size = torch.cat([v["image_size_xyxy_tgt"] for v in targets])
        src_boxes_ = src_boxes / image_size
        target_boxes_ = target_boxes / image_size
        loss_bbox = F.l1_loss(src_boxes_, target_boxes_, reduction='none')
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'boxes': self.loss_boxes,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets, *args, **kwargs):
        """ This performs the loss computation.

        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(num_boxes)
            world_size = dist.get_world_size()
        else:
            world_size = 1
        num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, so we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
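

# Illustrative sketch (the class count and loss weights are assumptions): wiring
# the matcher into SetCriterion the way a DETR-style training loop would, then
# combining the returned loss dict with weight_dict into a single scalar.
def _demo_set_criterion():
    matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2, use_focal=True)
    weight_dict = {"loss_ce": 2.0, "loss_bbox": 5.0, "loss_giou": 2.0}
    criterion = SetCriterion(num_classes=80, matcher=matcher, weight_dict=weight_dict,
                             eos_coef=0.1, losses=["labels", "boxes"], use_focal=True)

    xy = torch.rand(1, 10, 2) * 400
    wh = torch.rand(1, 10, 2) * 200 + 1
    outputs = {
        "pred_logits": torch.randn(1, 10, 80),
        "pred_boxes": torch.cat([xy, xy + wh], dim=-1),
    }
    size = torch.tensor([800.0, 800.0, 800.0, 800.0])
    targets = [{
        "labels": torch.tensor([3]),
        "boxes": torch.rand(1, 4),
        "boxes_xyxy": torch.tensor([[10.0, 10.0, 100.0, 100.0]]),
        "image_size_xyxy": size,
        "image_size_xyxy_tgt": size.unsqueeze(0),
    }]
    loss_dict = criterion(outputs, targets)
    total = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
    print(loss_dict, total)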