# DexinedApp / losses.py
# Uploaded by Dinars34 ("Upload 60 files", commit 89c5d90, verified).
import torch
import torch.nn.functional as F
from dexi_utils import *
def hed_loss2(inputs, targets, l_weight=1.1):
    """BDCN-style class-balanced BCE loss (RCF weighting), summed over pixels.

    Targets are cast to integers. Pixels with target > 0.1 (i.e. >= 1) are
    positives and get weight num_negative / total. Pixels with target <= 0
    are negatives and get weight 1.1 * num_positive / total, where total is
    num_positive + num_negative.

    Args:
        inputs: logits; a sigmoid is applied here.
        targets: label map with the same shape as ``inputs``.
        l_weight: scalar multiplier applied to the summed loss.

    Returns:
        Scalar tensor: ``l_weight`` times the weighted BCE summed over all
        elements.
    """
    targets = targets.long()
    mask = targets.float()
    # Select both pixel groups from the labels *before* writing any weight.
    # Otherwise, when there are no negatives, the positive weight is 0 and
    # the second assignment would overwrite it with the negative weight.
    is_pos = mask > 0.1
    is_neg = mask <= 0.
    num_positive = is_pos.float().sum()
    num_negative = is_neg.float().sum()
    total = num_positive + num_negative
    mask[is_pos] = 1.0 * num_negative / total
    mask[is_neg] = 1.1 * num_positive / total
    probs = torch.sigmoid(inputs)
    cost = F.binary_cross_entropy(probs.float(), targets.float(),
                                  weight=mask, reduction='sum')
    return l_weight * cost
def bdcn_loss2(inputs, targets, l_weight=1.1):
    """BDCN-style class-balanced BCE loss (RCF weighting), mean per sample.

    Targets are cast to integers. Pixels with target > 0 get weight
    num_negative / total. Pixels with target <= 0 get weight
    1.1 * num_positive / total. The counts are taken over the whole batch.
    The weighted per-pixel BCE is averaged over dims (1, 2, 3) of each sample
    and then summed over the batch.

    Args:
        inputs: logits; a sigmoid is applied here. Must be 4-D, since the
            mean is taken over dims (1, 2, 3).
        targets: label map with the same shape as ``inputs``.
        l_weight: scalar multiplier applied to the result.

    Returns:
        Scalar tensor: ``l_weight`` times the sum over samples of each
        sample's mean weighted BCE.
    """
    targets = targets.long()
    mask = targets.float()
    # Select pixel groups up front so the second assignment cannot hit
    # positives whose weight was just set to 0 (the case with no negatives).
    is_pos = mask > 0.
    is_neg = mask <= 0.
    num_positive = is_pos.float().sum()
    num_negative = is_neg.float().sum()
    total = num_positive + num_negative
    mask[is_pos] = 1.0 * num_negative / total
    mask[is_neg] = 1.1 * num_positive / total
    probs = torch.sigmoid(inputs)
    cost = F.binary_cross_entropy(probs, targets.float(), weight=mask,
                                  reduction='none')
    cost = torch.sum(cost.float().mean((1, 2, 3)))
    return l_weight * cost
def bdcn_lossORI(inputs, targets, l_weigts=1.1, cuda=False):
    """
    Class-balanced BCE loss with per-sample weights (original BDCN recipe).

    For each sample i:
    - pixels with target 1 get weight neg_i / valid_i;
    - pixels with target 0 get weight 1.1 * pos_i / valid_i;
    - pixels with any other target value keep weight 0.
    Here pos_i and neg_i count the 1s and 0s in that sample, and
    valid_i = pos_i + neg_i.

    :param inputs: logits, a 4 dimensional tensor n x c x h x w
        (a sigmoid is applied here)
    :param targets: a 4 dimensional tensor n x c x h x w, same shape as inputs
    :param l_weigts: scalar multiplier applied to the summed loss
    :param cuda: unused, kept for backward compatibility. The weights are
        built on ``inputs.device`` rather than always moved to CUDA, so the
        loss also runs on CPU.
    :return: scalar tensor, ``l_weigts`` times the weighted BCE summed over
        all elements
    """
    t = targets.detach().to(inputs.device)
    is_pos = t == 1
    is_neg = t == 0
    # Per-sample counts shaped (n, 1, 1, 1) so they broadcast over c, h, w.
    pos = is_pos.sum(dim=(1, 2, 3), keepdim=True).float()
    neg = is_neg.sum(dim=(1, 2, 3), keepdim=True).float()
    valid = neg + pos
    weights = torch.zeros(inputs.shape, dtype=torch.float32,
                          device=inputs.device)
    weights = torch.where(is_pos, neg * 1. / valid, weights)
    weights = torch.where(is_neg, pos * 1.1 / valid, weights)  # balance = 1.1
    probs = torch.sigmoid(inputs)
    loss = F.binary_cross_entropy(probs.float(), t.float(), weight=weights,
                                  reduction='sum')
    return l_weigts * loss
def rcf_loss(inputs, label):
    """RCF-style class-balanced BCE loss, summed over all pixels.

    Labels are cast to integers. Weights are assigned as follows:
    - pixels labelled 1 get weight num_negative / total;
    - pixels labelled 0 get weight 1.1 * num_positive / total;
    - pixels labelled 2 get weight 0, i.e. they are ignored.
    Here total = num_positive + num_negative.

    NOTE(review): num_positive counts every label > 0.5, which also
    includes pixels labelled 2 that are later ignored. Confirm whether only
    label == 1 should be counted; kept as-is to preserve loss values.

    Args:
        inputs: logits; a sigmoid is applied here.
        label: label map with the same shape as ``inputs``.

    Returns:
        Scalar tensor with the summed weighted BCE.
    """
    label = label.long()
    mask = label.float()
    # Select every pixel group from the labels before writing weights, so a
    # weight of 0 written for positives (no negatives present) is not later
    # mistaken for a negative pixel and overwritten.
    is_pos = mask == 1
    is_neg = mask == 0
    is_ignored = mask == 2
    num_positive = (mask > 0.5).float().sum()
    num_negative = is_neg.float().sum()
    total = num_positive + num_negative
    mask[is_pos] = 1.0 * num_negative / total
    mask[is_neg] = 1.1 * num_positive / total
    mask[is_ignored] = 0.
    probs = torch.sigmoid(inputs)
    cost = F.binary_cross_entropy(probs.float(), label.float(), weight=mask,
                                  reduction='sum')
    return cost
# ------------ cats losses ----------
def bdrloss(prediction, label, radius, device='cpu'):
    '''
    The boundary tracing loss that handles the confusing pixels.

    For every boundary pixel (label == 1), it compares two quantities inside
    a (2*radius+1)^2 window:
    - the predicted mass on boundary pixels;
    - the predicted mass on non-boundary pixels that lie near some boundary.
    It then penalises -log of the boundary share (clamped to [1e-10, 1-1e-10]).

    Args:
        prediction: predicted probabilities. The box filter is single-channel
            (1, 1, k, k), so this is presumably (N, 1, H, W) -- TODO confirm.
        label: boundary map with the same shape as ``prediction``.
        radius: half-size of the square neighbourhood window.
        device: kept for backward compatibility. The filter is now created
            on ``prediction``'s device and dtype, so a mismatched value no
            longer breaks the convolutions.

    Returns:
        Scalar tensor: the loss summed over all boundary pixels.
    '''
    # Put the label in the prediction's dtype/device so every conv2d below
    # sees matching operands.
    label = label.to(device=prediction.device, dtype=prediction.dtype)
    size = 2 * radius + 1
    # Constant box filter that sums each (size x size) window; no grad.
    filt = torch.ones(1, 1, size, size, dtype=prediction.dtype,
                      device=prediction.device)
    bdr_pred = prediction * label
    pred_bdr_sum = label * F.conv2d(bdr_pred, filt, bias=None, stride=1,
                                    padding=radius)
    # Non-zero wherever some label pixel lies inside the window.
    texture_mask = F.conv2d(label, filt, bias=None, stride=1, padding=radius)
    mask = (texture_mask != 0).to(prediction.dtype)
    mask[label == 1] = 0
    pred_texture_sum = F.conv2d(prediction * (1 - label) * mask, filt,
                                bias=None, stride=1, padding=radius)
    softmax_map = torch.clamp(
        pred_bdr_sum / (pred_texture_sum + pred_bdr_sum + 1e-10),
        1e-10, 1 - 1e-10)
    cost = -label * torch.log(softmax_map)
    cost[label == 0] = 0
    return cost.sum()
def textureloss(prediction, label, mask_radius, device='cpu'):
    '''
    The texture suppression loss that smooths the texture regions.

    The penalty applies at every pixel whose (2*mask_radius+1)^2 window has
    a non-positive label sum, i.e. no positive label nearby. There it adds
    -log(1 - s/9), where s is the 3x3 windowed sum of the prediction (zero
    padded) and the result is clamped to [1e-10, 1-1e-10].

    Args:
        prediction: predicted probabilities, cast to float32. The filters
            are single-channel, so presumably (N, 1, H, W) -- TODO confirm.
        label: boundary map, same shape and device as ``prediction``.
        mask_radius: half-size of the window that excludes pixels near
            labelled boundaries.
        device: kept for backward compatibility. The filters are now
            allocated on ``prediction``'s device.

    Returns:
        Scalar tensor with the summed loss.
    '''
    pred = prediction.float()
    lab = label.float()
    dev = pred.device
    k = 2 * mask_radius + 1
    # Constant box filters (plain tensors, so they never track gradients).
    filt1 = torch.ones(1, 1, 3, 3, device=dev)
    filt2 = torch.ones(1, 1, k, k, device=dev)
    pred_sums = F.conv2d(pred, filt1, bias=None, stride=1, padding=1)
    label_sums = F.conv2d(lab, filt2, bias=None, stride=1, padding=mask_radius)
    # 1 where the windowed label sum is not positive, else 0.
    mask = 1 - torch.gt(label_sums, 0).float()
    # pred_sums / 9 is the mean over the 3x3 window.
    loss = -torch.log(torch.clamp(1 - pred_sums / 9, 1e-10, 1 - 1e-10))
    loss[mask == 0] = 0
    return torch.sum(loss)
def cats_loss(prediction, label, l_weight=(0., 0.), device='cpu'):
    '''
    CATS loss: a class-balanced BCE term plus texture-suppression and
    boundary-tracing terms.

    BCE weights:
    - pixels with label == 1 get beta = num_negative / (num_positive +
      num_negative);
    - pixels with label == 0 get 1.1 * (1 - beta);
    - pixels with label == 2 get weight 0 (ignored);
    - any other label value is used as its own weight (label is not
      integer-cast here).

    Args:
        prediction: logits; a sigmoid is applied here.
        label: label map with the same shape as ``prediction``.
        l_weight: pair ``(tex_factor, bdr_factor)`` scaling the texture and
            boundary terms. The default is a tuple rather than a shared
            mutable list.
        device: forwarded to ``textureloss`` and ``bdrloss``.

    Returns:
        Scalar tensor: bce + bdr_factor * bdr + tex_factor * texture.
    '''
    tex_factor, bdr_factor = l_weight
    balanced_w = 1.1
    label = label.float()
    prediction = prediction.float()
    with torch.no_grad():
        mask = label.clone()
        # Select pixel groups before writing weights: if there are no
        # negatives, beta == 0 and the positives would otherwise be
        # overwritten by the negative-weight assignment.
        is_pos = mask == 1
        is_neg = mask == 0
        is_ignored = mask == 2
        num_positive = is_pos.float().sum()
        num_negative = is_neg.float().sum()
        beta = num_negative / (num_positive + num_negative)
        mask[is_pos] = beta
        mask[is_neg] = balanced_w * (1 - beta)
        mask[is_ignored] = 0
    prediction = torch.sigmoid(prediction)
    # reduction='sum' replaces the deprecated reduce=False + torch.sum.
    cost = F.binary_cross_entropy(prediction, label, weight=mask,
                                  reduction='sum')
    # Binary boundary map for the two auxiliary terms.
    label_w = (label != 0).float()
    textcost = textureloss(prediction, label_w, mask_radius=4, device=device)
    bdrcost = bdrloss(prediction, label_w, radius=4, device=device)
    return cost + bdr_factor * bdrcost + tex_factor * textcost