Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Wrapper around various loggers and progress bars (e.g., tqdm). | |
| """ | |
| import atexit | |
| import json | |
| import logging | |
| import os | |
| import sys | |
| from collections import OrderedDict | |
| from contextlib import contextmanager | |
| from numbers import Number | |
| from typing import Optional | |
| import torch | |
| from .meters import AverageMeter, StopwatchMeter, TimeMeter | |
# Shared module logger used by the progress bars and wrappers in this file.
logger = logging.getLogger(name=__name__)
def progress_bar(
    iterator,
    log_format: Optional[str] = None,
    log_interval: int = 100,
    log_file: Optional[str] = None,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    aim_repo: Optional[str] = None,
    aim_run_hash: Optional[str] = None,
    aim_param_checkpoint_dir: Optional[str] = None,
    tensorboard_logdir: Optional[str] = None,
    default_log_format: str = "tqdm",
    wandb_project: Optional[str] = None,
    wandb_run_name: Optional[str] = None,
    azureml_logging: Optional[bool] = False,
):
    """Build a progress bar for ``iterator`` and stack optional logging wrappers.

    Args:
        iterator: iterable to wrap.
        log_format: one of ``"json"``, ``"none"``, ``"simple"`` or ``"tqdm"``;
            ``None`` selects ``default_log_format``. ``"tqdm"`` is downgraded
            to ``"simple"`` when stderr is not a TTY.
        log_interval: passed to the JSON and simple bars (and to the FB-only
            TensorBoard wrapper).
        log_file: if given, a ``logging.FileHandler`` for this path is attached
            to the module logger (at most once per path).
        epoch: epoch number used in the bar's prefix.
        prefix: extra text appended to the bar's prefix.
        aim_repo: enables the Aim wrapper when truthy.
        aim_run_hash: Aim run to append to.
        aim_param_checkpoint_dir: used to look up an existing Aim run when no
            ``aim_run_hash`` is given.
        tensorboard_logdir: enables the TensorBoard wrapper when truthy.
        default_log_format: format used when ``log_format`` is None.
        wandb_project: enables the W&B wrapper when truthy.
        wandb_run_name: W&B run name.
        azureml_logging: enables the AzureML wrapper when truthy.

    Returns:
        The base bar, wrapped in turn by the Aim, TensorBoard, W&B and AzureML
        wrappers as requested. The last wrapper applied is the outermost.

    Raises:
        ValueError: if ``log_format`` is not a recognized format.
    """
    if log_format is None:
        log_format = default_log_format
    if log_file is not None:
        # progress_bar may be called many times in one process (e.g. once per
        # epoch). Attach at most one FileHandler per path so records are not
        # duplicated and file descriptors are not leaked.
        # FileHandler stores os.path.abspath(filename) as baseFilename.
        log_path = os.path.abspath(log_file)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in logger.handlers
        )
        if not already_attached:
            logger.addHandler(logging.FileHandler(filename=log_file))
    if log_format == "tqdm" and not sys.stderr.isatty():
        # Fall back to line-based logging when stderr is not a terminal.
        log_format = "simple"

    if log_format == "json":
        bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "none":
        bar = NoopProgressBar(iterator, epoch, prefix)
    elif log_format == "simple":
        bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "tqdm":
        bar = TqdmProgressBar(iterator, epoch, prefix)
    else:
        raise ValueError("Unknown log format: {}".format(log_format))

    if aim_repo:
        bar = AimProgressBarWrapper(
            bar,
            aim_repo=aim_repo,
            aim_run_hash=aim_run_hash,
            aim_param_checkpoint_dir=aim_param_checkpoint_dir,
        )

    if tensorboard_logdir:
        try:
            # [FB only] custom wrapper for TensorBoard
            import palaas  # noqa
            from .fb_tbmf_wrapper import FbTbmfWrapper

            bar = FbTbmfWrapper(bar, log_interval)
        except ImportError:
            bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)

    if wandb_project:
        bar = WandBProgressBarWrapper(bar, wandb_project, run_name=wandb_run_name)

    if azureml_logging:
        bar = AzureMLProgressBarWrapper(bar)

    return bar
def build_progress_bar(
    args,
    iterator,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    default: str = "tqdm",
    no_progress_bar: str = "none",
):
    """Legacy wrapper that takes an argparse.Namespace.

    ``args.log_format`` and ``args.log_interval`` are required. The optional
    ``no_progress_bar`` flag switches the default format to the
    ``no_progress_bar`` argument. ``tensorboard_logdir`` is only forwarded
    when ``distributed_rank`` is 0 (or absent).
    """
    fallback_format = (
        no_progress_bar if getattr(args, "no_progress_bar", False) else default
    )
    on_rank_zero = getattr(args, "distributed_rank", 0) == 0
    logdir = getattr(args, "tensorboard_logdir", None) if on_rank_zero else None
    return progress_bar(
        iterator,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch,
        prefix=prefix,
        tensorboard_logdir=logdir,
        default_log_format=fallback_format,
    )
def format_stat(stat):
    """Convert one stat value into a display-friendly form.

    Numbers become ``"{:g}"`` strings. An ``AverageMeter`` is shown as its
    average with three decimals, a ``TimeMeter`` as its rounded average and a
    ``StopwatchMeter`` as its rounded sum. Tensors are converted with
    ``tolist()``. Any other value is returned unchanged.
    """
    if isinstance(stat, Number):
        return "{:g}".format(stat)
    if isinstance(stat, AverageMeter):
        return "{:.3f}".format(stat.avg)
    if isinstance(stat, TimeMeter):
        return "{:g}".format(round(stat.avg))
    if isinstance(stat, StopwatchMeter):
        return "{:g}".format(round(stat.sum))
    if torch.is_tensor(stat):
        return stat.tolist()
    return stat
class BaseProgressBar(object):
    """Abstract class for progress bars.

    Subclasses implement ``__iter__``, ``log`` and ``print``. The base class
    stores the wrapped iterable and its starting offset ``n``, which is read
    from ``iterable.n`` when present and is otherwise 0. It also stores the
    epoch and a display prefix joined with ``" | "``, for example
    ``"epoch 003 | train"``.
    """

    def __init__(self, iterable, epoch=None, prefix=None):
        self.iterable = iterable
        self.n = getattr(iterable, "n", 0)
        self.epoch = epoch
        parts = []
        if epoch is not None:
            parts.append("epoch {:03d}".format(epoch))
        if prefix is not None:
            parts.append(prefix)
        self.prefix = " | ".join(parts)

    def __len__(self):
        return len(self.iterable)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Never suppress exceptions raised inside the ``with`` block.
        return False

    def __iter__(self):
        raise NotImplementedError

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        raise NotImplementedError

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        raise NotImplementedError

    def update_config(self, config):
        """Log latest configuration."""
        # No-op by default; the logging wrappers override this.

    def _str_commas(self, stats):
        """Render formatted stats as ``"k1=v1, k2=v2"``."""
        return ", ".join(k + "=" + v.strip() for k, v in stats.items())

    def _str_pipes(self, stats):
        """Render formatted stats as ``"k1 v1 | k2 v2"``."""
        return " | ".join(k + " " + v.strip() for k, v in stats.items())

    def _format_stats(self, stats):
        """Return an OrderedDict mapping each key to ``str(format_stat(value))``."""
        return OrderedDict(
            (key, str(format_stat(value))) for key, value in OrderedDict(stats).items()
        )
@contextmanager
def rename_logger(logger, new_name):
    """Temporarily rename ``logger`` for the duration of a ``with`` block.

    Records emitted inside the block carry ``new_name`` as their logger name.
    If ``new_name`` is None, the logger is left untouched.

    Every caller in this module uses ``with rename_logger(...)``. The
    ``@contextmanager`` decorator is required because a bare generator
    function does not implement the context-manager protocol. The original
    name is restored in ``finally`` so an exception raised inside the block
    cannot leave the logger permanently renamed.

    Yields:
        The same ``logger`` object.
    """
    old_name = logger.name
    if new_name is not None:
        logger.name = new_name
    try:
        yield logger
    finally:
        logger.name = old_name
class JsonProgressBar(BaseProgressBar):
    """Log output in JSON format.

    Each emitted record is a single ``json.dumps`` line written through the
    module logger, which is renamed to ``tag`` for the duration of the call.
    """

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        # Index of the item most recently yielded by __iter__ (starts at self.n).
        self.i = None
        # len(self.iterable), captured when iteration begins.
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for i, obj in enumerate(self.iterable, start=self.n):
            self.i = i
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval.

        ``step`` falls back to the current iteration index. Nothing is logged
        for step 0 or when the step is not a multiple of ``log_interval``.
        The fractional ``update`` field is only computed when the epoch and
        the iteration position are both known. This means calling ``log``
        with an explicit ``step`` before iteration has started does not raise.
        """
        step = step or self.i or 0
        if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
            update = None
            if self.epoch is not None and self.i is not None and self.size:
                # Epoch progress as a float, e.g. 2.5 halfway through epoch 3.
                update = self.epoch - 1 + (self.i + 1) / float(self.size)
            stats = self._format_stats(stats, epoch=self.epoch, update=update)
            with rename_logger(logger, tag):
                logger.info(json.dumps(stats))

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats.

        When ``tag`` is given, every key is prefixed with ``tag + "_"``. The
        (possibly re-keyed) stats are also kept on ``self.stats``.
        """
        self.stats = stats
        if tag is not None:
            self.stats = OrderedDict(
                [(tag + "_" + k, v) for k, v in self.stats.items()]
            )
        stats = self._format_stats(self.stats, epoch=self.epoch)
        with rename_logger(logger, tag):
            logger.info(json.dumps(stats))

    def _format_stats(self, stats, epoch=None, update=None):
        """Return an OrderedDict of optional ``epoch``/``update`` fields followed by the formatted stats."""
        postfix = OrderedDict()
        if epoch is not None:
            postfix["epoch"] = epoch
        if update is not None:
            postfix["update"] = round(update, 3)
        # Preprocess stats according to datatype
        for key in stats.keys():
            postfix[key] = format_stat(stats[key])
        return postfix
class NoopProgressBar(BaseProgressBar):
    """No logging: iterates the wrapped iterable and discards every stat."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)

    def __iter__(self):
        yield from self.iterable

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval (no-op)."""

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats (no-op)."""
class SimpleProgressBar(BaseProgressBar):
    """A minimal logger for non-TTY environments.

    Writes plain text lines through the module logger, which is renamed to
    ``tag`` for the duration of each call.
    """

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        # Index of the item most recently yielded by __iter__ (starts at self.n).
        self.i = None
        # len(self.iterable), captured when iteration begins.
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for i, obj in enumerate(self.iterable, start=self.n):
            self.i = i
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval.

        The line is ``"<prefix>: <pos> / <size> k=v, ..."``. If ``log`` is
        called with an explicit ``step`` before iteration has started (so the
        position and size are unknown), only the step is printed instead of
        raising on ``None``.
        """
        step = step or self.i or 0
        if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
            postfix = self._str_commas(self._format_stats(stats))
            if self.i is not None and self.size is not None:
                position = "{:5d} / {:d}".format(self.i + 1, self.size)
            else:
                position = "{:>5}".format(step)
            with rename_logger(logger, tag):
                logger.info("{}: {} {}".format(self.prefix, position, postfix))

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats as ``"<prefix> | k1 v1 | k2 v2"``."""
        postfix = self._str_pipes(self._format_stats(stats))
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, postfix))
class TqdmProgressBar(BaseProgressBar):
    """Log to tqdm.

    Intermediate stats are shown in the tqdm postfix. End-of-epoch stats go
    through the module logger. The bar is disabled when the module logger's
    effective level is above INFO.
    """

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)
        from tqdm import tqdm

        quiet = logger.getEffectiveLevel() > logging.INFO
        self.tqdm = tqdm(iterable, desc=self.prefix, leave=False, disable=quiet)

    def __iter__(self):
        return iter(self.tqdm)

    def log(self, stats, tag=None, step=None):
        """Show intermediate stats in the tqdm postfix without forcing a refresh."""
        postfix = self._format_stats(stats)
        self.tqdm.set_postfix(postfix, refresh=False)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats as ``"<prefix> | k1 v1 | k2 v2"``."""
        line = "{} | {}".format(self.prefix, self._str_pipes(self._format_stats(stats)))
        with rename_logger(logger, tag):
            logger.info(line)
try:
    import functools  # noqa: F401
    from aim import Repo as AimRepo
except ImportError:
    # Aim is optional; AimProgressBarWrapper checks get_aim_run for None.
    get_aim_run = None
    AimRepo = None
else:

    def get_aim_run(repo, run_hash):
        """Return ``aim.Run(run_hash=run_hash, repo=repo)``."""
        from aim import Run

        return Run(run_hash=run_hash, repo=repo)
class AimProgressBarWrapper(BaseProgressBar):
    """Log to Aim.

    Forwards every call to ``wrapped_bar``. When the ``aim`` package is
    importable, it also tracks stats and configuration in an Aim run.
    """

    def __init__(self, wrapped_bar, aim_repo, aim_run_hash, aim_param_checkpoint_dir):
        self.wrapped_bar = wrapped_bar
        if get_aim_run is None:
            self.run = None
            logger.warning("Aim not found, please install with: pip install aim")
            return
        logger.info(f"Storing logs at Aim repo: {aim_repo}")
        if not aim_run_hash:
            # Find run based on save_dir parameter
            aim_run_hash = self._find_run_hash(aim_repo, aim_param_checkpoint_dir)
        if aim_run_hash:
            logger.info(f"Appending to run: {aim_run_hash}")
        self.run = get_aim_run(aim_repo, aim_run_hash)

    @staticmethod
    def _find_run_hash(aim_repo, aim_param_checkpoint_dir):
        """Return the hash of the first run whose ``checkpoint.save_dir`` matches, else None."""
        # {!r} quotes and escapes the path, so quotes or backslashes in it
        # cannot break the query. str() keeps None rendered as 'None'.
        query = "run.checkpoint.save_dir == {!r}".format(str(aim_param_checkpoint_dir))
        try:
            runs_generator = AimRepo(aim_repo).query_runs(query)
            return next(runs_generator.iter_runs()).run.hash
        except StopIteration:
            return None  # no matching run in the repo
        except Exception as e:
            # Best effort: continue without a hash, but do not hide the reason.
            logger.warning(f"Could not look up an existing Aim run ({query}): {e}")
            return None

    def __len__(self):
        # BaseProgressBar.__len__ reads self.iterable, which this wrapper never sets.
        return len(self.wrapped_bar)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to Aim, then forward to the wrapped bar."""
        self._log_to_aim(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats (tracked in Aim and forwarded)."""
        self._log_to_aim(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        if self.run is not None:
            for key in config:
                self.run.set(key, config[key], strict=False)
        self.wrapped_bar.update_config(config)

    def _log_to_aim(self, stats, tag=None, step=None):
        """Track every stat except ``num_updates`` at ``step`` (default: ``stats["num_updates"]``)."""
        if self.run is None:
            return
        if step is None:
            step = stats["num_updates"]
        # tag defaults to None in log()/print(); substring tests need a str.
        tag_text = tag or ""
        if "train" in tag_text:
            context = {"tag": tag, "subset": "train"}
        elif "val" in tag_text:
            context = {"tag": tag, "subset": "val"}
        else:
            context = {"tag": tag}
        for key in stats.keys() - {"num_updates"}:
            self.run.track(stats[key], name=key, step=step, context=context)
# TensorBoard writers cached by TensorboardProgressBarWrapper._writer; they
# are closed by _close_writers, which is registered with atexit below.
_tensorboard_writers = {}
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    try:
        from tensorboardX import SummaryWriter
    except ImportError:
        SummaryWriter = None


def _close_writers():
    """Close every cached TensorBoard writer."""
    for _key, writer in _tensorboard_writers.items():
        writer.close()


atexit.register(_close_writers)
class TensorboardProgressBarWrapper(BaseProgressBar):
    """Log to tensorboard.

    Forwards every call to ``wrapped_bar`` and also writes scalar stats to one
    ``SummaryWriter`` per tag under ``tensorboard_logdir``.
    """

    def __init__(self, wrapped_bar, tensorboard_logdir):
        self.wrapped_bar = wrapped_bar
        self.tensorboard_logdir = tensorboard_logdir
        if SummaryWriter is None:
            logger.warning(
                "tensorboard not found, please install with: pip install tensorboard"
            )

    def _writer(self, key):
        """Return the cached writer for tag ``key``, creating it on first use.

        Returns None when no SummaryWriter implementation is installed.
        Writers are cached by their full log directory, so two wrappers with
        different ``tensorboard_logdir`` values do not share a writer for the
        same tag.
        """
        if SummaryWriter is None:
            return None
        log_dir = os.path.join(self.tensorboard_logdir, key)
        writer = _tensorboard_writers.get(log_dir)
        if writer is None:
            writer = SummaryWriter(log_dir)
            # Record the command line once per writer.
            writer.add_text("sys.argv", " ".join(sys.argv))
            _tensorboard_writers[log_dir] = writer
        return writer

    def __len__(self):
        # BaseProgressBar.__len__ reads self.iterable, which this wrapper never sets.
        return len(self.wrapped_bar)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to tensorboard."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        # TODO add hparams to Tensorboard
        self.wrapped_bar.update_config(config)

    def _log_to_tensorboard(self, stats, tag=None, step=None):
        """Write scalar stats (meters' last ``val``, numbers, 1-element tensors) at ``step``."""
        writer = self._writer(tag or "")
        if writer is None:
            return
        if step is None:
            step = stats["num_updates"]
        for key in stats.keys() - {"num_updates"}:
            if isinstance(stats[key], AverageMeter):
                writer.add_scalar(key, stats[key].val, step)
            elif isinstance(stats[key], Number):
                writer.add_scalar(key, stats[key], step)
            elif torch.is_tensor(stats[key]) and stats[key].numel() == 1:
                writer.add_scalar(key, stats[key].item(), step)
        writer.flush()
try:
    import wandb
except ImportError:
    # W&B is optional; WandBProgressBarWrapper then only forwards calls.
    wandb = None


class WandBProgressBarWrapper(BaseProgressBar):
    """Log to Weights & Biases.

    Forwards every call to ``wrapped_bar``. When the ``wandb`` package is
    importable, it also logs scalar stats to W&B under ``"<tag>/<key>"``.
    """

    def __init__(self, wrapped_bar, wandb_project, run_name=None):
        self.wrapped_bar = wrapped_bar
        if wandb is None:
            logger.warning("wandb not found, pip install wandb")
            return
        # reinit=False to ensure if wandb.init() is called multiple times
        # within one process it still references the same run
        wandb.init(project=wandb_project, reinit=False, name=run_name)

    def __len__(self):
        # BaseProgressBar.__len__ reads self.iterable, which this wrapper never sets.
        return len(self.wrapped_bar)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to W&B, then forward to the wrapped bar."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats (logged to W&B and forwarded)."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        if wandb is not None:
            wandb.config.update(config)
        self.wrapped_bar.update_config(config)

    def _log_to_wandb(self, stats, tag=None, step=None):
        """Log scalar stats at ``step`` (default: ``stats["num_updates"]``) in one call.

        Meters contribute their last ``val``. Numbers and 1-element tensors
        are logged as-is, matching the TensorBoard wrapper. Other values are
        skipped.
        """
        if wandb is None:
            return
        if step is None:
            step = stats["num_updates"]
        prefix = "" if tag is None else tag + "/"
        payload = {}
        for key in stats.keys() - {"num_updates"}:
            value = stats[key]
            if isinstance(value, AverageMeter):
                value = value.val
            elif torch.is_tensor(value) and value.numel() == 1:
                value = value.item()
            elif not isinstance(value, Number):
                continue  # only scalar stats are logged
            payload[prefix + key] = value
        if payload:
            # One call per step instead of one per key.
            wandb.log(payload, step=step)
try:
    from azureml.core import Run
except ImportError:
    # azureml-core is optional; AzureMLProgressBarWrapper then only forwards calls.
    Run = None


class AzureMLProgressBarWrapper(BaseProgressBar):
    """Log to Azure ML.

    Forwards every call to ``wrapped_bar``. When ``azureml.core`` is
    importable, it also logs scalar stats to ``Run.get_context()`` via
    ``log_row`` and calls ``complete()`` on the run in ``__exit__``.
    """

    def __init__(self, wrapped_bar):
        self.wrapped_bar = wrapped_bar
        if Run is None:
            logger.warning("azureml.core not found, pip install azureml-core")
            return
        self.run = Run.get_context()

    def __exit__(self, *exc):
        if Run is not None:
            self.run.complete()
        return False

    def __len__(self):
        # BaseProgressBar.__len__ reads self.iterable, which this wrapper never sets.
        return len(self.wrapped_bar)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to AzureML, then forward to the wrapped bar."""
        self._log_to_azureml(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats (logged to AzureML and forwarded)."""
        self._log_to_azureml(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        self.wrapped_bar.update_config(config)

    def _log_to_azureml(self, stats, tag=None, step=None):
        """Log each scalar stat as a row named ``"<tag>/<key>"`` with its ``step``.

        Meters contribute their last ``val``. Numbers and 1-element tensors
        are logged as-is, matching the TensorBoard wrapper. Other values are
        skipped.
        """
        if Run is None:
            return
        if step is None:
            step = stats["num_updates"]
        prefix = "" if tag is None else tag + "/"
        for key in stats.keys() - {"num_updates"}:
            value = stats[key]
            if isinstance(value, AverageMeter):
                value = value.val
            elif torch.is_tensor(value) and value.numel() == 1:
                value = value.item()
            elif not isinstance(value, Number):
                continue  # only scalar stats are logged
            self.run.log_row(name=prefix + key, **{"step": step, key: value})