Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class PairwizeDiffLoss(nn.Module):
    """Margin-based loss on pairwise differences of predictions vs. targets.

    For every ordered pair ``(i, j)`` along dimension 0, the difference
    ``input[i] - input[j]`` is compared with ``target[i] - target[j]``. The
    discrepancy is passed through ``norm``, reduced by ``margin`` and clamped
    at zero, so discrepancies within the margin are not penalised. The result
    is averaged over all pairs (including ``i == j``) and then halved.

    Args:
        margin: Tolerance subtracted from each pairwise discrepancy before
            clamping at zero.
        norm: ``"l1"`` (absolute value) or ``"l2_squared"`` (square).

    Raises:
        ValueError: If ``norm`` is not supported. Checked at construction and
            again on every forward pass.
    """

    # Single source of truth for the supported norms: validation, dispatch and
    # the error message all read this table.
    _NORM_FNS = {
        "l1": torch.abs,
        "l2_squared": lambda x: x**2,
    }

    def __init__(self, margin: float = 0.2, norm: str = "l1"):
        super().__init__()
        self.margin = margin
        self.norm = norm
        # Fail fast: report a bad ``norm`` when the loss is built rather than
        # on the first training step.
        self._norm_fn()

    def _norm_fn(self):
        """Return the element-wise function selected by ``self.norm``.

        Raises:
            ValueError: If ``self.norm`` is not a key of ``_NORM_FNS``.
        """
        try:
            return self._NORM_FNS[self.norm]
        except (KeyError, TypeError):  # TypeError: unhashable ``norm`` value
            choices = ", ".join(f'"{name}"' for name in self._NORM_FNS)
            raise ValueError(
                f"Unknown norm: {self.norm}. Must be one of [{choices}]"
            ) from None

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the pairwise-difference loss.

        Args:
            input: Predictions; pairs are formed along dimension 0.
            target: Targets laid out like ``input``.
                NOTE(review): shapes are not checked. A mismatch such as
                ``(n, 1)`` vs ``(n,)`` would broadcast silently, so confirm
                the callers' shapes.

        Returns:
            A 0-d tensor.

        Raises:
            ValueError: If ``self.norm`` was changed to an unsupported value.
        """
        # Re-validated here because ``norm`` may be reassigned after __init__.
        norm_fn = self._norm_fn()
        # (n, 1, ...) - (1, n, ...) broadcasts to (n, n, ...): [i, j] = x[i] - x[j].
        pred_diff = input.unsqueeze(1) - input.unsqueeze(0)
        target_diff = target.unsqueeze(1) - target.unsqueeze(0)
        loss = F.relu(norm_fn(pred_diff - target_diff) - self.margin)
        # NOTE(review): the loss matrix is symmetric and the mean already
        # normalises by the pair count, so ``div(2)`` is a constant scale
        # factor (presumably intended to count each unordered pair once).
        # It is kept for backward compatibility.
        return loss.mean().div(2)
class CombinedLoss(nn.Module):
    """Evaluate several losses on the same ``(input, target)`` pair.

    The terms are not summed. ``forward`` returns each weight alongside its
    loss value so the caller decides how to reduce (and can log each term).

    Args:
        weighted_losses: ``(loss, weight)`` pairs. Each ``loss`` is called as
            ``loss(input, target)``. It is usually an ``nn.Module``, but any
            callable is accepted.
    """

    def __init__(self, weighted_losses: list[tuple[nn.Module, float]]):
        super().__init__()
        # Snapshot the pairs so the modules registered below and the pairs
        # iterated in ``forward`` stay in sync. This also makes one-shot
        # iterables safe to pass.
        self.weighted_losses = list(weighted_losses)
        # A plain list hides its modules from PyTorch: ``.to()``, ``.train()``/
        # ``.eval()``, ``state_dict()`` and ``parameters()`` would not reach
        # them (e.g. a CrossEntropyLoss ``weight`` buffer would never move
        # device). Register them in a ModuleList. Non-Module callables have no
        # state to track and are skipped.
        self.loss_modules = nn.ModuleList(
            [loss for loss, _ in self.weighted_losses if isinstance(loss, nn.Module)]
        )

    def forward(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> list[tuple[float, torch.Tensor]]:
        """Return ``[(weight, loss(input, target)), ...]`` in construction order."""
        return [
            (weight, loss_fn(input, target))
            for loss_fn, weight in self.weighted_losses
        ]