Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import bisect | |
| import time | |
| from collections import OrderedDict | |
| from typing import Dict, Optional | |
try:
    import torch
except ImportError:
    torch = None


def type_as(a, b):
    """Return ``a`` cast to match ``b`` when both are torch tensors.

    When torch could not be imported, or when either argument is not a
    tensor, ``a`` is returned unchanged.
    """
    both_are_tensors = (
        torch is not None and torch.is_tensor(a) and torch.is_tensor(b)
    )
    # Tensor.to(other) adopts other's dtype and device.
    return a.to(b) if both_are_tensors else a
| try: | |
| import numpy as np | |
| except ImportError: | |
| np = None | |
class Meter(object):
    """Base class for Meters.

    Subclasses override :meth:`reset`, the state (de)serialization hooks and
    the :attr:`smoothed_value` property.
    """

    def __init__(self):
        pass

    def state_dict(self):
        """Return a snapshot of the meter's state (empty for the base class)."""
        return {}

    def load_state_dict(self, state_dict):
        """Restore a snapshot produced by :meth:`state_dict` (no-op here)."""
        pass

    def reset(self):
        """Clear accumulated state; must be implemented by subclasses."""
        raise NotImplementedError

    @property
    def smoothed_value(self) -> float:
        """Smoothed value used for logging.

        A property because callers read it as an attribute
        (``meter.smoothed_value``) rather than calling it.
        """
        raise NotImplementedError
def safe_round(number, ndigits):
    """Round ``number`` to ``ndigits`` decimals when it can be rounded.

    Values implementing ``__round__`` go straight to :func:`round`. A
    one-element torch tensor, or a zero-dimensional numpy value that exposes
    ``.item()``, is unwrapped to a Python scalar and rounded recursively.
    Anything else is returned untouched.
    """
    if hasattr(number, "__round__"):
        return round(number, ndigits)

    # Short-circuit order matters: the numpy check only runs if the torch
    # check fails, and each backend is skipped when it was not importable.
    unwrappable = (
        torch is not None and torch.is_tensor(number) and number.numel() == 1
    ) or (
        np is not None and np.ndim(number) == 0 and hasattr(number, "item")
    )
    if unwrappable:
        return safe_round(number.item(), ndigits)

    return number
class AverageMeter(Meter):
    """Computes and stores the average and current value.

    Args:
        round: if not None, number of decimal places :attr:`smoothed_value`
            is rounded to.
    """

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.reset()

    def reset(self):
        self.val = None  # most recent update
        self.sum = 0  # sum from all updates
        self.count = 0  # total n from all updates

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``; ``None`` values are ignored.

        ``val`` always becomes the most recent value, but it only contributes
        to the running sum/count when ``n > 0``.
        """
        if val is not None:
            self.val = val
            if n > 0:
                # type_as aligns the accumulator with val when both are tensors.
                self.sum = type_as(self.sum, val) + (val * n)
                self.count = type_as(self.count, n) + n

    def state_dict(self):
        return {
            "val": self.val,
            "sum": self.sum,
            "count": self.count,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        self.val = state_dict["val"]
        self.sum = state_dict["sum"]
        self.count = state_dict["count"]
        # "round" is optional in the incoming dict and defaults to None.
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        """Weighted mean of all updates, or the last value if none were counted."""
        return self.sum / self.count if self.count > 0 else self.val

    @property
    def smoothed_value(self) -> float:
        """:attr:`avg`, rounded to ``self.round`` decimals when configured."""
        val = self.avg
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val
class SumMeter(Meter):
    """Computes and stores the sum.

    Args:
        round: if not None, number of decimal places :attr:`smoothed_value`
            is rounded to.
    """

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.reset()

    def reset(self):
        self.sum = 0  # sum from all updates

    def update(self, val):
        """Add ``val`` to the running sum; ``None`` values are ignored."""
        if val is not None:
            # type_as aligns the accumulator with val when both are tensors.
            self.sum = type_as(self.sum, val) + val

    def state_dict(self):
        return {
            "sum": self.sum,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        self.sum = state_dict["sum"]
        # "round" is optional in the incoming dict and defaults to None.
        self.round = state_dict.get("round", None)

    @property
    def smoothed_value(self) -> float:
        """The running sum, rounded to ``self.round`` decimals when configured."""
        val = self.sum
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val
class ConcatTensorMeter(Meter):
    """Concatenates tensors passed to :meth:`update` along dimension ``dim``."""

    def __init__(self, dim=0):
        super().__init__()
        self.reset()
        self.dim = dim

    def reset(self):
        self.tensor = None

    def update(self, val):
        """Append ``val`` to the stored tensor (the first value is stored as-is)."""
        if self.tensor is None:
            self.tensor = val
        else:
            self.tensor = torch.cat([self.tensor, val], dim=self.dim)

    def state_dict(self):
        # "dim" is saved so a meter rebuilt with default constructor arguments
        # keeps concatenating along the original dimension.
        return {
            "tensor": self.tensor,
            "dim": self.dim,
        }

    def load_state_dict(self, state_dict):
        self.tensor = state_dict["tensor"]
        # Dicts without "dim" keep the current setting.
        self.dim = state_dict.get("dim", self.dim)

    @property
    def smoothed_value(self) -> float:
        """Placeholder; a concatenated tensor has no scalar value to log."""
        return []  # return a dummy value
class TimeMeter(Meter):
    """Computes the average occurrence of some event per second.

    Args:
        init: seconds to count as already elapsed before timing starts.
        n: number of events already counted.
        round: if not None, number of decimal places :attr:`smoothed_value`
            is rounded to.
    """

    def __init__(
        self,
        init: int = 0,
        n: int = 0,
        round: Optional[int] = None,
    ):
        self.round = round
        self.reset(init, n)

    def reset(self, init=0, n=0):
        self.init = init
        self.start = time.perf_counter()
        self.n = n
        self.i = 0  # number of update() calls since reset

    def update(self, val=1):
        """Count ``val`` more events."""
        self.n = type_as(self.n, val) + val
        self.i += 1

    def state_dict(self):
        # Total elapsed time is saved as "init" so timing resumes from it on load.
        return {
            "init": self.elapsed_time,
            "n": self.n,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        if "start" in state_dict:
            # backwards compatibility for old state_dicts
            self.reset(init=state_dict["init"])
        else:
            self.reset(init=state_dict["init"], n=state_dict["n"])
        # "round" is optional in the incoming dict and defaults to None.
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        """Events per second over :attr:`elapsed_time`."""
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        """Seconds since the last reset, plus the ``init`` offset."""
        return self.init + (time.perf_counter() - self.start)

    @property
    def smoothed_value(self) -> float:
        """:attr:`avg`, rounded to ``self.round`` decimals when configured."""
        val = self.avg
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val
class StopwatchMeter(Meter):
    """Computes the sum/avg duration of some event in seconds.

    Call :meth:`start` before and :meth:`stop` after each timed event.

    Args:
        round: if not None, number of decimal places :attr:`smoothed_value`
            is rounded to.
    """

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.sum = 0
        self.n = 0
        self.start_time = None

    def start(self):
        self.start_time = time.perf_counter()

    def stop(self, n=1, prehook=None):
        """Add the time since :meth:`start` to the total and count ``n`` events.

        No-op if the stopwatch was never started. ``prehook``, if given, is
        called before the clock is read. ``start_time`` is not cleared, so a
        second ``stop()`` without ``start()`` measures from the same start.
        """
        if self.start_time is not None:
            if prehook is not None:
                prehook()
            delta = time.perf_counter() - self.start_time
            self.sum = self.sum + delta
            self.n = type_as(self.n, n) + n

    def reset(self):
        """Zero the totals and immediately restart the stopwatch."""
        self.sum = 0  # cumulative time during which stopwatch was active
        self.n = 0  # total n across all start/stop
        self.start()

    def state_dict(self):
        return {
            "sum": self.sum,
            "n": self.n,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        self.sum = state_dict["sum"]
        self.n = state_dict["n"]
        self.start_time = None  # a restored stopwatch is not running
        # "round" is optional in the incoming dict and defaults to None.
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        """Mean duration per counted event, or the raw total if n == 0."""
        return self.sum / self.n if self.n > 0 else self.sum

    @property
    def elapsed_time(self):
        """Seconds since the last :meth:`start`, or 0.0 if never started."""
        if self.start_time is None:
            return 0.0
        return time.perf_counter() - self.start_time

    @property
    def smoothed_value(self) -> float:
        """:attr:`avg` once time has been recorded, else the running elapsed time."""
        val = self.avg if self.sum > 0 else self.elapsed_time
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val
class MetersDict(OrderedDict):
    """A sorted dictionary of :class:`Meters`.

    Meters are sorted according to a priority that is given when the
    meter is first added to the dictionary. Items are assigned as
    ``(priority, meter)`` pairs (see :meth:`add_meter`); lookups return the
    bare meter.
    """

    def __init__(self, *args, **kwargs):
        # Must exist before OrderedDict.__init__ runs: any initial items are
        # inserted through the overridden __setitem__, which appends here.
        self.priorities = []
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        assert key not in self, "MetersDict doesn't support reassignment"
        priority, value = value
        # The insertion counter breaks priority ties (earlier meters first),
        # so keys are not compared while the counter values stay distinct.
        bisect.insort(self.priorities, (priority, len(self.priorities), key))
        super().__setitem__(key, value)
        for _, _, ordered_key in self.priorities:  # reorder dict to match priorities
            self.move_to_end(ordered_key)

    def __delitem__(self, key):
        super().__delitem__(key)
        # Drop the matching priority entry. A stale entry would make the
        # reorder loop in __setitem__ call move_to_end() on a missing key.
        self.priorities[:] = [entry for entry in self.priorities if entry[2] != key]

    def add_meter(self, key, meter, priority):
        """Insert ``meter`` under ``key`` with the given sort ``priority``."""
        self.__setitem__(key, (priority, meter))

    def state_dict(self):
        """Return ``(priority, key, class name, meter state)`` tuples in order."""
        return [
            (pri, key, self[key].__class__.__name__, self[key].state_dict())
            for pri, _, key in self.priorities
            # can't serialize DerivedMeter instances
            if not isinstance(self[key], MetersDict._DerivedMeter)
        ]

    def load_state_dict(self, state_dict):
        """Replace all meters with those described by :meth:`state_dict` output."""
        self.clear()
        self.priorities.clear()
        for pri, key, meter_cls, meter_state in state_dict:
            # NOTE(review): the class name is resolved through this module's
            # globals() and called with no arguments, so any module-level
            # callable can be invoked. Only load state dicts from trusted
            # sources.
            meter = globals()[meter_cls]()
            meter.load_state_dict(meter_state)
            self.add_meter(key, meter, pri)

    def get_smoothed_value(self, key: str) -> float:
        """Get a single smoothed value."""
        meter = self[key]
        if isinstance(meter, MetersDict._DerivedMeter):
            return meter.fn(self)
        else:
            return meter.smoothed_value

    def get_smoothed_values(self) -> Dict[str, float]:
        """Get all smoothed values, skipping keys that start with an underscore."""
        return OrderedDict(
            [
                (key, self.get_smoothed_value(key))
                for key in self.keys()
                if not key.startswith("_")
            ]
        )

    def reset(self):
        """Reset Meter instances (derived meters have no state to reset)."""
        for meter in self.values():
            if isinstance(meter, MetersDict._DerivedMeter):
                continue
            meter.reset()

    class _DerivedMeter(Meter):
        """A Meter whose values are derived from other Meters.

        ``fn`` is called with the owning MetersDict to compute the value.
        """

        def __init__(self, fn):
            self.fn = fn

        def reset(self):
            pass