Spaces:
Configuration error
Configuration error
| import torch | |
| import torch.nn.functional as F | |
| from dexi_utils import * | |
def hed_loss2(inputs, targets, l_weight=1.1):
    """Class-balanced binary cross-entropy on logits (RCF-style weighting).

    ``targets`` is truncated to integers with ``.long()``. Pixels whose integer
    label is > 0.1 (i.e. >= 1) are treated as edges, and pixels whose label is
    <= 0 as non-edges. Edge pixels are weighted by
    ``num_negative / (num_positive + num_negative)`` and non-edge pixels by
    ``1.1 * num_positive / (num_positive + num_negative)``. Both counts are
    taken over the whole batch.

    :param inputs: raw logits; ``sigmoid`` is applied here.
    :param targets: ground truth with the same shape as ``inputs``.
    :param l_weight: scalar multiplier applied to the summed loss.
    :return: scalar tensor ``l_weight * sum(weighted BCE)``.
    """
    labels = targets.long().float()
    # Classify pixels from the untouched labels *before* writing any weights.
    # Otherwise the second assignment could re-match a weight written by the
    # first one (e.g. edge weight 0 when there are no negative pixels).
    is_pos = labels > 0.1
    is_neg = labels <= 0.
    num_positive = is_pos.sum().float()
    num_negative = is_neg.sum().float()
    total = num_positive + num_negative
    mask = labels.clone()
    mask[is_pos] = 1.0 * num_negative / total
    mask[is_neg] = 1.1 * num_positive / total
    probs = torch.sigmoid(inputs)
    cost = torch.nn.BCELoss(weight=mask, reduction='sum')(probs.float(), labels)
    return l_weight * cost
def bdcn_loss2(inputs, targets, l_weight=1.1):
    """Class-balanced BCE on logits, averaged per image and summed over the batch.

    ``targets`` is truncated with ``.long()``. Pixels with integer label > 0
    are edges, and pixels with label <= 0 are non-edges. Edge pixels are
    weighted by ``num_negative / (num_positive + num_negative)`` and non-edge
    pixels by ``1.1 * num_positive / (num_positive + num_negative)``. Both
    counts are taken over the whole batch.

    :param inputs: raw logits. The per-image mean over dims (1, 2, 3) requires
        a 4-D tensor. ``sigmoid`` is applied here.
    :param targets: ground truth with the same shape as ``inputs``.
    :param l_weight: scalar multiplier applied to the result.
    :return: scalar tensor ``l_weight * sum_over_batch(mean weighted BCE)``.
    """
    labels = targets.long().float()
    # Select both classes from the original labels before assigning weights,
    # so the non-edge assignment cannot overwrite edge weights that happen to
    # be <= 0 (which occurs when there are no negative pixels).
    is_pos = labels > 0.
    is_neg = labels <= 0.
    num_positive = is_pos.sum().float()
    num_negative = is_neg.sum().float()
    total = num_positive + num_negative
    mask = labels.clone()
    mask[is_pos] = 1.0 * num_negative / total
    mask[is_neg] = 1.1 * num_positive / total
    probs = torch.sigmoid(inputs)
    cost = torch.nn.BCELoss(weight=mask, reduction='none')(probs, labels)
    cost = torch.sum(cost.float().mean((1, 2, 3)))
    return l_weight * cost
def bdcn_lossORI(inputs, targets, l_weigts=1.1, cuda=False):
    """Class-balanced BCE with per-image weights (original BDCN formulation).

    For each image ``i`` in the batch, pixels whose target equals 1 get weight
    ``neg_i / (pos_i + neg_i)``. Pixels whose target equals 0 get weight
    ``1.1 * pos_i / (pos_i + neg_i)``. Every other pixel gets weight 0.

    :param inputs: raw logits, a 4-dimensional tensor n x c x h x w.
        ``sigmoid`` is applied here.
    :param targets: a 4-dimensional tensor n x c x h x w (same shape as
        ``inputs``).
    :param l_weigts: scalar multiplier applied to the summed loss.
    :param cuda: unused, kept for backward compatibility. The weights are
        created on ``inputs.device``, so the loss works on CPU and GPU alike
        (previously they were always moved to CUDA).
    :return: scalar tensor ``l_weigts * sum(weighted BCE)``.
    """
    n, _, _, _ = inputs.size()  # enforces a 4-D input, as before
    with torch.no_grad():
        t = targets.float()
        is_pos = t == 1
        is_neg = t == 0
        # Per-image class counts, shaped (n, 1, 1, 1) to broadcast over pixels.
        pos = is_pos.reshape(n, -1).sum(dim=1).float().view(n, 1, 1, 1)
        neg = is_neg.reshape(n, -1).sum(dim=1).float().view(n, 1, 1, 1)
        # The clamp only matters for images with no 0/1 pixels, whose weights
        # are all zero anyway. It avoids computing 0/0.
        valid = (pos + neg).clamp(min=1.0)
        weights = torch.where(
            is_pos,
            (neg / valid).expand_as(t),
            torch.where(is_neg, (1.1 * pos / valid).expand_as(t), torch.zeros_like(t)),
        )
        weights = weights.to(inputs.device)
    probs = torch.sigmoid(inputs)
    loss = torch.nn.BCELoss(weight=weights, reduction='sum')(probs.float(), targets.float())
    return l_weigts * loss
def rcf_loss(inputs, label):
    """Class-balanced BCE following the RCF weighting scheme.

    ``label`` is truncated with ``.long()``. Pixel weights are assigned as
    follows:

    - label 1: ``num_negative / (num_positive + num_negative)``
    - label 0: ``1.1 * num_positive / (num_positive + num_negative)``
    - label 2: 0 (ignored)

    Only labels equal to 1 count as positives, so ignored pixels do not
    inflate ``num_positive``.

    :param inputs: raw logits; ``sigmoid`` is applied here.
    :param label: ground truth with the same shape as ``inputs``.
    :return: scalar tensor, the summed weighted BCE.
    """
    labels = label.long().float()
    # Classify every pixel from the original labels before writing weights,
    # so no assignment can re-match a weight written by an earlier one.
    is_pos = labels == 1
    is_neg = labels == 0
    is_ignored = labels == 2
    num_positive = is_pos.sum().float()
    num_negative = is_neg.sum().float()
    total = num_positive + num_negative
    mask = labels.clone()
    mask[is_pos] = 1.0 * num_negative / total
    mask[is_neg] = 1.1 * num_positive / total
    mask[is_ignored] = 0.
    probs = torch.sigmoid(inputs)
    return torch.nn.BCELoss(weight=mask, reduction='sum')(probs.float(), labels)
| # ------------ cats losses ---------- | |
def bdrloss(prediction, label, radius, device='cpu'):
    """Boundary tracing loss (CATS) that handles confusing pixels near edges.

    For each pixel, a (2*radius+1) x (2*radius+1) box sum is taken over:

    - the predictions of edge pixels (``prediction * label``), and
    - the predictions of non-edge pixels that lie inside some edge pixel's
      window.

    At pixels where ``label`` is non-zero, the cost is
    ``-label * log(edge_sum / (edge_sum + texture_sum))``. The ratio is clamped
    to ``[1e-10, 1 - 1e-10]`` before the log. All other pixels contribute 0.

    :param prediction: 4-D tensor (n, 1, h, w). The single-channel filter
        requires one input channel.
    :param label: tensor with the same shape as ``prediction``.
    :param radius: half-size of the square summation window.
    :param device: device on which the all-ones filter is created. It must
        match the device of ``prediction``.
    :return: scalar tensor, the summed cost.
    """
    side = 2 * radius + 1
    kernel = torch.ones(1, 1, side, side, requires_grad=False).to(device)

    def box_sum(x):
        # Sum of x over the square window centred on each pixel (zero padded).
        return F.conv2d(x, kernel, bias=None, stride=1, padding=radius)

    edge_sum = label * box_sum(prediction * label)
    # Non-edge pixels that fall inside at least one edge pixel's window.
    near_edge = (box_sum(label.float()) != 0).float()
    near_edge[label == 1] = 0
    texture_sum = box_sum(prediction * (1 - label) * near_edge)

    ratio = (edge_sum / (texture_sum + edge_sum + 1e-10)).clamp(1e-10, 1 - 1e-10)
    cost = -label * ratio.log()
    cost[label == 0] = 0
    return cost.sum()
def textureloss(prediction, label, mask_radius, device='cpu'):
    """Texture suppression loss (CATS) that smooths predictions away from edges.

    For every pixel, the mean prediction over its 3x3 neighbourhood (zero
    padded, always divided by 9) is penalised with ``-log(1 - mean)``. The
    argument is clamped to ``[1e-10, 1 - 1e-10]`` before the log. Pixels whose
    (2*mask_radius+1)-sized window has a positive label sum are excluded
    (their loss is set to 0).

    :param prediction: 4-D tensor (n, 1, h, w). The single-channel filters
        require one input channel. Presumably probabilities in [0, 1]
        (cats_loss passes sigmoid outputs).
    :param label: tensor with the same shape as ``prediction``.
    :param mask_radius: half-size of the window used to exclude edge areas.
    :param device: device on which the all-ones filters are created. It must
        match the device of ``prediction``.
    :return: scalar tensor, the summed loss.
    """
    def ones_kernel(side):
        k = torch.ones(1, 1, side, side)
        k.requires_grad = False
        return k.to(device)

    local_mean = F.conv2d(prediction.float(), ones_kernel(3), bias=None, stride=1, padding=1) / 9
    edge_window = ones_kernel(2 * mask_radius + 1)
    near_edge = F.conv2d(label.float(), edge_window, bias=None, stride=1, padding=mask_radius) > 0

    loss = -(1 - local_mean).clamp(1e-10, 1 - 1e-10).log()
    loss[near_edge] = 0
    return loss.sum()
def cats_loss(prediction, label, l_weight=(0., 0.), device='cpu'):
    """Combined CATS loss: class-balanced BCE plus texture and boundary terms.

    BCE weights are assigned as follows, with
    ``beta = num_negative / (num_positive + num_negative)``:

    - label 1: ``beta``
    - label 0: ``1.1 * (1 - beta)``
    - label 2: 0 (ignored)
    - any other label value: keeps that value as its weight (unchanged
      behaviour).

    :param prediction: raw logits; ``sigmoid`` is applied here.
    :param label: ground truth with the same shape as ``prediction``.
    :param l_weight: pair ``(tex_factor, bdr_factor)`` scaling the texture
        suppression and boundary tracing terms. A tuple is used as the default
        to avoid a shared mutable default. Lists are still accepted.
    :param device: device passed to ``textureloss`` / ``bdrloss`` for their
        convolution filters.
    :return: scalar tensor ``bce + bdr_factor * bdrcost + tex_factor * textcost``.
    """
    tex_factor, bdr_factor = l_weight
    balanced_w = 1.1
    label = label.float()
    prediction = prediction.float()
    with torch.no_grad():
        # Classify pixels from the untouched labels before writing weights, so
        # a freshly written weight (e.g. beta == 0 when there are no
        # negatives) cannot be re-matched by a later assignment.
        is_pos = label == 1
        is_neg = label == 0
        is_ignored = label == 2
        num_positive = is_pos.sum().float()
        num_negative = is_neg.sum().float()
        beta = num_negative / (num_positive + num_negative)
        mask = label.clone()
        mask[is_pos] = beta
        mask[is_neg] = balanced_w * (1 - beta)
        mask[is_ignored] = 0
    prediction = torch.sigmoid(prediction)
    # reduction='sum' replaces the deprecated ``reduce=False`` followed by torch.sum.
    cost = torch.nn.functional.binary_cross_entropy(
        prediction, label, weight=mask, reduction='sum')
    label_w = (label != 0).float()
    textcost = textureloss(prediction, label_w, mask_radius=4, device=device)
    bdrcost = bdrloss(prediction, label_w, radius=4, device=device)
    return cost + bdr_factor * bdrcost + tex_factor * textcost