Spaces: Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss on raw logits with label smoothing.

    The per-position loss is
    ``(1 - smoothing) * NLL(target) + smoothing * mean_c(-log_softmax(x)_c)``,
    i.e. cross-entropy against a mixture of the one-hot label and a uniform
    distribution over all classes. Per-position losses are averaged.
    """

    def __init__(self, smoothing=0.1):
        """Construct the loss.

        :param smoothing: label smoothing factor, must lie in ``[0, 1)``.
        :raises ValueError: if ``smoothing`` is outside ``[0, 1)``.
        """
        super().__init__()
        # Explicit check rather than ``assert``: asserts vanish under ``python -O``.
        if not 0.0 <= smoothing < 1.0:
            raise ValueError(f"smoothing must be in [0, 1), got {smoothing!r}")
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, x, target):
        """Compute the mean label-smoothed cross-entropy.

        :param x: unnormalized logits with classes on the last dimension,
            e.g. ``(N, C)`` or ``(N, ..., C)``.
        :param target: integer class indices with shape ``x.shape[:-1]``.
        :returns: scalar tensor, the loss averaged over all positions.
        """
        logprobs = F.log_softmax(x, dim=-1)
        # Index along the class (last) dimension so any number of leading
        # dims works; unsqueeze(1)/squeeze(1) is only correct for 2-D logits
        # and silently gathers the wrong entries for higher-rank input.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        # Cross-entropy against a uniform distribution over the classes.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
class SoftTargetCrossEntropy(nn.Module):
    """Cross-entropy loss between raw logits and soft (per-class weight) targets.

    Each position contributes ``-sum_c target[..., c] * log_softmax(x)[..., c]``;
    contributions are averaged over all positions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, target):
        """Return the mean soft-target cross-entropy.

        :param x: unnormalized logits with classes on the last dimension.
        :param target: per-class weights broadcastable against ``x``;
            presumably a probability distribution over the last dimension
            (not validated here).
        :returns: scalar tensor.
        """
        log_probs = F.log_softmax(x, dim=-1)
        per_position = (-target * log_probs).sum(dim=-1)
        return per_position.mean()