from fastai.vision import *   # fastai's star import also brings in torch, nn and F

# made explicit so the file stands on its own if the star import changes
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.model import Model


class MultiLosses(nn.Module):
    """Sums the weighted cross-entropy losses of several decoder heads."""

    def __init__(self, one_hot=True):
        super().__init__()
        # one-hot targets need the soft variant; index targets use standard CE
        self.ce = SoftCrossEntropyLoss() if one_hot else torch.nn.CrossEntropyLoss()
        self.bce = torch.nn.BCELoss()

    def last_losses(self):
        # per-head losses recorded by the most recent forward() call
        return self.losses

    def _flatten(self, sources, lengths):
        # strip padding: keep the first `l` steps of each sequence, then
        # concatenate everything along the first dimension
        return torch.cat([t[:l] for t, l in zip(sources, lengths)])

    def _merge_list(self, all_res):
        # merge per-iteration result dicts into one: tensors are concatenated
        # along the batch dimension, any other value is taken from the first
        if not isinstance(all_res, (list, tuple)):
            return all_res

        def merge(items):
            if isinstance(items[0], torch.Tensor): return torch.cat(items, dim=0)
            else: return items[0]

        res = dict()
        for key in all_res[0].keys():
            items = [r[key] for r in all_res]
            res[key] = merge(items)
        return res

    def _ce_loss(self, output, gt_labels, gt_lengths, idx=None, record=True):
        loss_name = output.get('name')
        pt_logits, weight = output['logits'], output['loss_weight']

        # logits may stack several refinement iterations along the batch axis;
        # tile the targets so every iteration is scored against the labels
        assert pt_logits.shape[0] % gt_labels.shape[0] == 0
        iter_size = pt_logits.shape[0] // gt_labels.shape[0]
        if iter_size > 1:
            gt_labels = gt_labels.repeat(iter_size, 1, 1)
            gt_lengths = gt_lengths.repeat(iter_size)
        flat_gt_labels = self._flatten(gt_labels, gt_lengths)
        flat_pt_logits = self._flatten(pt_logits, gt_lengths)

        nll = output.get('nll')
        if nll is not None:
            # the head already emits probabilities, so skip the log_softmax
            # (only valid with SoftCrossEntropyLoss, i.e. one_hot=True)
            loss = self.ce(flat_pt_logits, flat_gt_labels, softmax=False) * weight
        else:
            loss = self.ce(flat_pt_logits, flat_gt_labels) * weight
        if record and loss_name is not None:
            self.losses[f'{loss_name}_loss'] = loss

        return loss

    def forward(self, outputs, *args):
        self.losses = {}
        if isinstance(outputs, (tuple, list)):
            outputs = [self._merge_list(o) for o in outputs]
            # total loss: sum over every head with a positive weight
            return sum(self._ce_loss(o, *args) for o in outputs if o['loss_weight'] > 0.)
        else:
            return self._ce_loss(outputs, *args, record=False)
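
# Usage note (a hedged reading of the code above, not upstream documentation):
# `MultiLosses.forward` takes one output dict or a list of them, one per
# decoder head. The shapes (B, T, C) = (batch, max length, charset size) are
# an assumption; a typical entry would look like
#
#   {'name': 'vision',             # recorded in self.losses as 'vision_loss'
#    'logits': <tensor (B, T, C)>,   # probabilities instead, if 'nll' is set
#    'loss_weight': 1.0}           # heads with weight <= 0. are skipped
#
# A runnable sketch under the same assumptions follows at the end of the file.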

class SoftCrossEntropyLoss(nn.Module):
    """Cross entropy that accepts soft (probability) targets instead of indices."""

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, input, target, softmax=True):
        # with softmax=False, `input` is taken to hold probabilities already
        log_prob = F.log_softmax(input, dim=-1) if softmax else torch.log(input)
        loss = -(target * log_prob).sum(dim=-1)   # per-sample cross entropy
        if self.reduction == "mean": return loss.mean()
        elif self.reduction == "sum": return loss.sum()
        else: return loss
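
# A minimal smoke test (a hedged sketch, not part of the upstream API): the
# head name 'vision', the shapes (B, T, C) and the lengths below are all
# illustrative assumptions.
if __name__ == "__main__":
    B, T, C = 2, 5, 10
    labels = F.one_hot(torch.randint(0, C, (B, T)), C).float()  # one-hot (B, T, C)
    lengths = torch.tensor([3, 5])          # number of steps to score per sample
    outputs = [{'name': 'vision',
                'logits': torch.randn(B, T, C),
                'loss_weight': 1.0}]
    crit = MultiLosses(one_hot=True)
    total = crit(outputs, labels, lengths)
    print(total.item(), crit.last_losses())

    # sanity check: on one-hot targets SoftCrossEntropyLoss should match the
    # standard torch.nn.CrossEntropyLoss applied to the index targets
    logits = torch.randn(4, C)
    idx = torch.randint(0, C, (4,))
    soft = SoftCrossEntropyLoss()(logits, F.one_hot(idx, C).float())
    hard = torch.nn.CrossEntropyLoss()(logits, idx)
    assert torch.allclose(soft, hard, atol=1e-5)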