Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
class IOULoss(nn.Module):
    """IoU-based box regression loss (IoU, linear IoU or GIoU), summed over boxes.

    Each row of the inputs describes a box by four side distances stored at
    column indices 0-3 in the order (left, top, right, bottom). The code
    takes a box's width as ``left + right`` and its height as
    ``top + bottom``.
    NOTE(review): these look like FCOS-style distances from a shared
    reference point and are presumably non-negative — confirm with callers.

    Args:
        loss_type: ``"iou"`` for ``-log(IoU)``, ``"linear_iou"`` for
            ``1 - IoU``, or ``"giou"`` for ``1 - GIoU``. Any other value is
            only rejected when ``forward`` runs.
    """

    def __init__(self, loss_type="iou"):
        super(IOULoss, self).__init__()
        self.loss_type = loss_type

    def forward(self, pred, target, weight=None):
        """Compute per-box losses and reduce them to a scalar sum.

        Args:
            pred: predicted side distances; column indices 0-3 are read.
            target: ground-truth side distances, same layout as ``pred``.
            weight: optional per-box weights multiplied into the losses.
                They are used only when their sum is positive; otherwise
                the plain (unweighted) sum is returned.

        Returns:
            Scalar tensor: ``sum(losses * weight)`` or ``sum(losses)``.

        Raises:
            NotImplementedError: ``self.loss_type`` is not a supported name.
            AssertionError: the unweighted path is taken with zero boxes.
        """
        p_left, p_top, p_right, p_bottom = (pred[:, i] for i in range(4))
        t_left, t_top, t_right, t_bottom = (target[:, i] for i in range(4))

        area_t = (t_left + t_right) * (t_top + t_bottom)
        area_p = (p_left + p_right) * (p_top + p_bottom)

        # Sums of per-side minima / maxima: overlap and enclosing-box extents.
        inter_w = torch.min(p_left, t_left) + torch.min(p_right, t_right)
        hull_w = torch.max(p_left, t_left) + torch.max(p_right, t_right)
        inter_h = torch.min(p_bottom, t_bottom) + torch.min(p_top, t_top)
        hull_h = torch.max(p_bottom, t_bottom) + torch.max(p_top, t_top)

        # Enclosing-box area; the epsilon guards the GIoU division below.
        hull = hull_w * hull_h + 1e-7
        inter = inter_w * inter_h
        union = area_t + area_p - inter
        # +1 smoothing keeps IoU strictly positive (for non-negative
        # distances), so -log(IoU) stays finite when boxes do not overlap.
        iou = (inter + 1.0) / (union + 1.0)

        variants = (
            ('iou', lambda: -torch.log(iou)),
            ('linear_iou', lambda: 1 - iou),
            # GIoU = IoU - (enclosing area - union) / enclosing area.
            ('giou', lambda: 1 - (iou - (hull - union) / hull)),
        )
        for name, make_losses in variants:
            if self.loss_type == name:
                losses = make_losses()
                break
        else:
            raise NotImplementedError

        weighted = weight is not None and weight.sum() > 0
        if weighted:
            return (losses * weight).sum()
        assert losses.numel() != 0
        return losses.sum()
class IOUWHLoss(nn.Module):  # used for anchor guiding
    """Shape-only IoU loss ``1 - IoU**2`` that compares boxes by width/height alone.

    Both boxes of a pair are placed at the same centre (the origin). Only
    the width/height columns (indices 2 and 3 after reshaping to
    ``(-1, 4)``) influence the result; columns 0-1 are ignored.

    Args:
        reduction: ``'mean'`` or ``'sum'`` to reduce over boxes. Any other
            value (default ``'none'``) returns the per-box losses as a 1-D
            tensor.
    """

    def __init__(self, reduction='none'):
        super(IOUWHLoss, self).__init__()
        self.reduction = reduction

    def forward(self, pred, target):
        """Return ``1 - IoU**2`` per box pair, optionally reduced.

        Args:
            pred: tensor reshaped to ``(-1, 4)``; columns 2-3 are read as
                (width, height). NOTE(review): columns 0-1 look like box
                centres; they are not used here.
            target: same layout as ``pred``. It is NOT modified (earlier
                versions zeroed ``target[:, :2]`` in place through a view,
                clobbering the caller's tensor).

        Returns:
            Per-box losses of shape ``(N,)``, or a scalar tensor when
            ``reduction`` is ``'mean'`` or ``'sum'``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        pred = pred.reshape(-1, 4)
        target = target.reshape(-1, 4)
        pred_wh = pred[:, 2:]
        target_wh = target[:, 2:]

        # Shared centre at the origin, built as a fresh tensor so the
        # caller's target (which reshape may alias) is left untouched.
        centre = torch.zeros_like(target[:, :2])
        tl = torch.max(centre - pred_wh / 2, centre - target_wh / 2)
        br = torch.min(centre + pred_wh / 2, centre + target_wh / 2)

        area_p = torch.prod(pred_wh, 1)
        area_g = torch.prod(target_wh, 1)
        # 1 where the boxes overlap along both axes, 0 otherwise.
        en = (tl < br).to(tl.dtype).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        # Tiny epsilon keeps the division finite for degenerate boxes.
        union = area_p + area_g - area_i + 1e-16
        iou = area_i / union

        loss = 1 - iou ** 2
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss