Spaces:
Runtime error
Runtime error
File size: 1,853 Bytes
d572f56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from typing import Dict, Optional
from torch import nn
import torchmetrics as tm
from core.types import BatchedInputOutput, OperationMode
class BaseMetricHandler(nn.Module):
    """Feed one stem/modality of a batch into a single wrapped metric.

    The handler looks up ``batch.sources[stem][modality]`` (target) and
    ``batch.estimates[stem][modality]`` (prediction) and forwards them to
    the wrapped metric's ``update``. Results from ``compute`` are keyed by
    the handler's ``name``.

    Args:
        stem: Key used to index ``batch.sources`` and ``batch.estimates``.
        metric: Metric object exposing ``update``, ``compute`` and ``reset``.
        modality: Key used to index the per-stem entries of the batch.
        name: Key under which results are reported. ``None`` or the
            sentinel ``"__auto__"`` falls back to the metric's class name.
    """

    def __init__(
        self, stem: str, metric: tm.Metric, modality: str, name: Optional[str] = None
    ) -> None:
        super().__init__()
        self.metric = metric
        self.modality = modality
        self.stem = stem
        if name is None or name == "__auto__":
            name = self.metric.__class__.__name__
        self.name = name

    def _metric_device(self):
        """Return the device of the wrapped metric, or ``None`` if unknown."""
        device = getattr(self.metric, "device", None)
        if device is not None:
            return device
        # Containers of metrics may not expose ``device`` themselves; fall
        # back to the first nested module that does.
        iter_modules = getattr(self.metric, "modules", None)
        if callable(iter_modules):
            for module in iter_modules():
                device = getattr(module, "device", None)
                if device is not None:
                    return device
        return None

    def update(self, batch: BatchedInputOutput):
        """Accumulate the metric with this handler's slice of ``batch``.

        Inputs are moved to the device the metric lives on rather than
        unconditionally to CUDA, so the handler also works on CPU-only
        hosts and follows the metric when the parent module is moved.
        If no device can be determined the tensors are passed through
        unchanged.
        """
        y_true = batch.sources[self.stem]
        y_pred = batch.estimates[self.stem]
        prediction = y_pred[self.modality]
        target = y_true[self.modality]

        device = self._metric_device()
        if device is not None:
            prediction = prediction.to(device)
            target = target.to(device)

        # Argument order is (prediction, target), as in the original call.
        self.metric.update(prediction, target)

    def compute(self) -> Dict[str, float]:
        """Return the metric result(s) keyed by this handler's name.

        A dict result is flattened to ``"<name>/<key>"`` entries; any other
        result is returned as ``{name: result}``. Values are whatever the
        wrapped metric's ``compute`` returns.
        """
        result = self.metric.compute()
        if isinstance(result, dict):
            return {f"{self.name}/{k}": v for k, v in result.items()}
        # Reuse the value computed above instead of calling compute() again.
        return {self.name: result}

    def reset(self):
        """Reset the wrapped metric's accumulated state."""
        self.metric.reset()
class MultiModeMetricHandler(nn.Module):
    """Hold separate collections of metric handlers for train, val and test.

    Each mapping passed to the constructor is wrapped in an
    ``nn.ModuleDict`` so the contained handlers are registered as
    submodules of this module.

    Args:
        train_metrics: Handlers used in training mode, keyed by name.
        val_metrics: Handlers used in validation mode, keyed by name.
        test_metrics: Handlers used in test mode, keyed by name.
    """

    def __init__(
        self,
        train_metrics: Dict[str, BaseMetricHandler],
        val_metrics: Dict[str, BaseMetricHandler],
        test_metrics: Dict[str, BaseMetricHandler],
    ):
        super().__init__()
        self.train_metrics = nn.ModuleDict(train_metrics)
        self.val_metrics = nn.ModuleDict(val_metrics)
        self.test_metrics = nn.ModuleDict(test_metrics)

    def get_mode(self, mode: OperationMode) -> nn.ModuleDict:
        """Return the handler collection registered for ``mode``.

        The result is the whole ``nn.ModuleDict`` for that mode, not a
        single handler.

        Raises:
            ValueError: If ``mode`` is not TRAIN, VAL or TEST.
        """
        if mode == OperationMode.TRAIN:
            return self.train_metrics
        if mode == OperationMode.VAL:
            return self.val_metrics
        if mode == OperationMode.TEST:
            return self.test_metrics
        raise ValueError(f"Unknown mode: {mode}")
|