import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from maskrcnn_benchmark import _C


# TODO: Use JIT to replace the CUDA implementation in the future.
class _SigmoidFocalLoss(Function):
    @staticmethod
    def forward(ctx, logits, targets, gamma, alpha):
        ctx.save_for_backward(logits, targets)
        num_classes = logits.shape[1]
        ctx.num_classes = num_classes
        ctx.gamma = gamma
        ctx.alpha = alpha

        losses = _C.sigmoid_focalloss_forward(
            logits, targets, num_classes, gamma, alpha
        )
        return losses

    @staticmethod
    @once_differentiable
    def backward(ctx, d_loss):
        logits, targets = ctx.saved_tensors
        num_classes = ctx.num_classes
        gamma = ctx.gamma
        alpha = ctx.alpha
        d_loss = d_loss.contiguous()
        d_logits = _C.sigmoid_focalloss_backward(
            logits, targets, d_loss, num_classes, gamma, alpha
        )
        # Gradients flow only to the logits; the remaining inputs
        # (targets, gamma, alpha) receive None.
        return d_logits, None, None, None, None


sigmoid_focal_loss_cuda = _SigmoidFocalLoss.apply


def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha):
    # Targets are 1-based class indices: 0 means background (every class is a
    # negative) and negative values mark ignored samples.
    num_classes = logits.shape[1]
    dtype = targets.dtype
    device = targets.device
    class_range = torch.arange(1, num_classes + 1, dtype=dtype, device=device).unsqueeze(0)

    t = targets.unsqueeze(1)
    p = torch.sigmoid(logits)
    term1 = (1 - p) ** gamma * torch.log(p)
    term2 = p ** gamma * torch.log(1 - p)
    # Focal loss: -alpha * (1 - p)^gamma * log(p) for positives and
    # -(1 - alpha) * p^gamma * log(1 - p) for non-ignored negatives.
    return (
        -(t == class_range).float() * term1 * alpha
        - ((t != class_range) * (t >= 0)).float() * term2 * (1 - alpha)
    )


class SigmoidFocalLoss(nn.Module):
    def __init__(self, gamma, alpha):
        super(SigmoidFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, targets):
        # Dispatch to the CUDA kernel when possible, otherwise fall back to
        # the pure-PyTorch CPU implementation.
        if logits.is_cuda:
            loss_func = sigmoid_focal_loss_cuda
        else:
            loss_func = sigmoid_focal_loss_cpu
        loss = loss_func(logits, targets, self.gamma, self.alpha)
        return loss.sum()

    def __repr__(self):
        tmpstr = self.__class__.__name__ + "("
        tmpstr += "gamma=" + str(self.gamma)
        tmpstr += ", alpha=" + str(self.alpha)
        tmpstr += ")"
        return tmpstr
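

# A minimal usage sketch (illustrative only; the shapes and hyperparameters
# below are assumptions for demonstration, not values taken from this repo):
#
#   loss_fn = SigmoidFocalLoss(gamma=2.0, alpha=0.25)
#   logits = torch.randn(8, 80)           # 8 anchors, 80 classes
#   targets = torch.randint(0, 81, (8,))  # 0 = background, 1..80 = class ids
#   loss = loss_fn(logits, targets)       # scalar, summed over all elements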


def token_sigmoid_softmax_focal_loss(pred_logits, targets, alpha, gamma, text_mask=None):
    # Because this is the cross-entropy (softmax) version, there is no notion
    # of frequent vs. infrequent classes, so the alpha weighting is not used
    # here (the design of alpha is temporarily retired).
    assert targets.dim() == 3
    assert pred_logits.dim() == 3  # batch x from x to

    # Reprocess the targets into a probability map ready for softmax.
    targets = targets.float()
    target_num = targets.sum(-1) + 1e-8  # numerical stability
    targets = targets / target_num.unsqueeze(-1)  # T(x)

    if text_mask is not None:
        # Reserve the last token for the no-object class.
        assert text_mask.dim() == 2
        text_mask[:, -1] = 1
        text_mask = (text_mask > 0).unsqueeze(1).repeat(1, pred_logits.size(1), 1)  # copy along the image channel
        # Masked-out tokens get a very large negative logit, so their softmax
        # probability is effectively zero.
        pred_logits = pred_logits.masked_fill(~text_mask, -1000000)

    out_prob = pred_logits.softmax(-1)

    filled_targets = targets.clone()
    filled_targets[filled_targets == 0] = 1.0

    weight = torch.clamp(targets - out_prob, min=0.001) / filled_targets
    weight = torch.pow(weight, gamma)  # weight = torch.pow(torch.clamp(target - out_prob, min=0.01), gamma)

    # Only positions with positive target probability contribute to the loss.
    loss_ce = -targets * weight * pred_logits.log_softmax(-1)
    return loss_ce


def token_sigmoid_binary_focal_loss_v2(pred_logits, targets, alpha, gamma, text_mask=None):
    assert targets.dim() == 3
    assert pred_logits.dim() == 3  # batch x from x to

    if text_mask is not None:
        assert text_mask.dim() == 2

    # Convert everything into binary: stack the negative and positive
    # probabilities along a new last dimension.
    out_prob = pred_logits.sigmoid()
    out_prob_neg_pos = torch.stack([1 - out_prob, out_prob], dim=-1) + 1e-8  # batch x boxes x 256 x 2
    weight = torch.pow(-out_prob_neg_pos + 1.0, gamma)

    focal_zero = -weight[:, :, :, 0] * torch.log(out_prob_neg_pos[:, :, :, 0]) * (1 - alpha)  # negative class
    focal_one = -weight[:, :, :, 1] * torch.log(out_prob_neg_pos[:, :, :, 1]) * alpha  # positive class
    focal = torch.stack([focal_zero, focal_one], dim=-1)
    # Pick the negative or positive focal term per position according to the
    # binary target.
    loss_ce = torch.gather(focal, index=targets.long().unsqueeze(-1), dim=-1)
    return loss_ce
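

# Sanity-check sketch (illustrative; shapes are assumptions): for hard binary
# targets and no text mask, "binaryv2" should agree with the "binary" version
# below up to the 1e-8 stabilizer, e.g.:
#
#   logits = torch.randn(2, 4, 256)
#   targets = torch.randint(0, 2, (2, 4, 256)).float()
#   l1 = token_sigmoid_binary_focal_loss(logits, targets, 0.25, 2.0)
#   l2 = token_sigmoid_binary_focal_loss_v2(logits, targets, 0.25, 2.0)
#   # l1.sum() and l2.sum() should be approximately equal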


def token_sigmoid_binary_focal_loss(pred_logits, targets, alpha, gamma, text_mask=None):
    # Binary version of focal loss, adapted from
    # https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        pred_logits: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as pred_logits. Stores the
            binary classification label for each element in pred_logits
            (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0, 1) to balance
            positive vs. negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
            balance easy vs. hard examples.
    Returns:
        Loss tensor with no reduction applied.
    """
    assert targets.dim() == 3
    assert pred_logits.dim() == 3  # batch x from x to

    bs, n, _ = pred_logits.shape
    if text_mask is not None:
        assert text_mask.dim() == 2
        text_mask = (text_mask > 0).unsqueeze(1)
        text_mask = text_mask.repeat(1, pred_logits.size(1), 1)  # copy along the image channel dimension
        pred_logits = torch.masked_select(pred_logits, text_mask)
        targets = torch.masked_select(targets, text_mask)

    p = torch.sigmoid(pred_logits)
    ce_loss = F.binary_cross_entropy_with_logits(pred_logits, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss
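

# Usage sketch (illustrative; shapes are assumptions): with a text mask, the
# loss is returned as a flat 1-D tensor over the selected valid-text positions:
#
#   logits = torch.randn(2, 4, 256)
#   targets = torch.randint(0, 2, (2, 4, 256)).float()
#   mask = torch.ones(2, 256, dtype=torch.bool)
#   loss = token_sigmoid_binary_focal_loss(logits, targets, 0.25, 2.0, mask)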


class TokenSigmoidFocalLoss(nn.Module):
    def __init__(self, alpha, gamma):
        super(TokenSigmoidFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets, text_masks=None, version="binary", **kwargs):
        if version == "binary":
            loss_func = token_sigmoid_binary_focal_loss
        elif version == "softmax":
            loss_func = token_sigmoid_softmax_focal_loss
        elif version == "binaryv2":
            loss_func = token_sigmoid_binary_focal_loss_v2
        else:
            raise NotImplementedError
        loss = loss_func(logits, targets, self.alpha, self.gamma, text_masks, **kwargs)
        return loss.sum()

    def __repr__(self):
        tmpstr = self.__class__.__name__ + "("
        tmpstr += "gamma=" + str(self.gamma)
        tmpstr += ", alpha=" + str(self.alpha)
        tmpstr += ")"
        return tmpstr
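

# Smoke test (illustrative only): the batch/box/token sizes below are
# assumptions chosen for demonstration, not values used by this repo.
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 4, 256)                    # batch x boxes x text tokens
    targets = torch.zeros(2, 4, 256)
    targets[:, :, 0] = 1.0                             # one positive token per box
    text_masks = torch.ones(2, 256, dtype=torch.long)  # every token is valid text
    loss_fn = TokenSigmoidFocalLoss(alpha=0.25, gamma=2.0)
    for version in ("binary", "binaryv2", "softmax"):
        print(version, loss_fn(logits, targets, text_masks, version=version).item())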