Spaces:
Runtime error
Runtime error
| # from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import functools | |
| import logging | |
| import os | |
| import sys | |
| import time | |
| import wandb | |
| from typing import Any, Dict, Union | |
| import torch | |
| from .distributed import get_rank, is_main_process | |
| from termcolor import colored | |
def log_dict_to_wandb(log_dict, step, prefix=""):
    """Send a dict of values to Weights & Biases, from the main process only.

    ``prefix`` is prepended verbatim to every key, so include a trailing
    separator such as ``/`` in it if one is wanted; none is added here.

    Args:
        log_dict: mapping whose keys are prefixed and then logged.
        step: forwarded to ``wandb.log`` as its second positional argument.
        prefix: string placed in front of every key.
    """
    if is_main_process():
        prefixed = {}
        for key, value in log_dict.items():
            prefixed[f"{prefix}{key}"] = value
        wandb.log(prefixed, step)
def setup_wandb(config):
    """Start a wandb run when wandb is enabled and this is the main process.

    Args:
        config: object exposing ``wandb.enable``, ``wandb.project``,
            ``wandb.entity`` and ``output_dir``; it is also passed to
            ``wandb.init`` as the run config.

    Returns:
        The value returned by ``wandb.init``, or ``None`` when wandb is
        disabled or the caller is not the main process.
    """
    if not config.wandb.enable:
        return None
    if not is_main_process():
        return None

    # The run is named after the last path component of the output directory.
    init_kwargs = dict(
        config=config,
        project=config.wandb.project,
        entity=config.wandb.entity,
        name=os.path.basename(config.output_dir),
        reinit=True,
    )
    return wandb.init(**init_kwargs)
def setup_output_folder(save_dir: str, folder_only: bool = False):
    """Create ``<save_dir>/logs`` and return it, or a log file path inside it.

    The file path has the form ``<save_dir>/logs/train_<timestamp>.log``,
    where the timestamp is formatted as ``%Y_%m_%dT%H_%M_%S``. Only the
    folder is created here; the log file itself is not.

    Args:
        save_dir (str): base directory that will contain the ``logs`` folder.
        folder_only (bool, optional): If the folder should be returned and
            not the file. Defaults to False.

    Returns:
        str: folder or file path depending on the ``folder_only`` flag.
    """
    log_folder = os.path.join(save_dir, "logs")
    # `os.path` has no `mkdirs`; `os.makedirs` is the recursive creator.
    # `exist_ok=True` also removes the check-then-create race when several
    # processes set up the same folder.
    os.makedirs(log_folder, exist_ok=True)

    if folder_only:
        return log_folder

    timestamp = time.strftime("%Y_%m_%dT%H_%M_%S")
    return os.path.join(log_folder, f"train_{timestamp}.log")
def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "mmf",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the MMF logger and set its verbosity level to "INFO".
    Outside libraries shouldn't call this in case they have set their
    own logging handlers and setup. If they do, and don't want to
    clear handlers, pass the clear_handlers option.

    The initial version of this function was taken from D2 and adapted
    for MMF.

    Handlers installed:
      * rank 0 only: a stdout handler (colored when ``color`` is true);
      * every rank, when ``output`` is given: a file handler. Ranks > 0
        append ``.rank<N>`` to the file name;
      * rank 0 only, when ``output`` is a file whose path does not contain
        "train.log": an additional ``train.log`` next to that file.
    The same handlers are attached to the "py.warnings" logger and, via
    ``logging.basicConfig``, to the root logger.

    Args:
        output (str): a file name or a directory to save the log.
            If it ends with ".txt" or ".log", it is assumed to be a file
            name; otherwise "train.log" inside that directory is used.
            Default: None, in which case nothing is written to a file.
        color (bool): If false, won't log colored logs. Default: true
        name (str): the root module name of this logger. Defaults to "mmf".
        disable: if true, nothing is configured and None is returned.
        clear_handlers (bool): If false, won't clear existing root handlers.
        *args, **kwargs: accepted for call compatibility and ignored.

    Returns:
        logging.Logger: the configured logger, or None when ``disable`` is set.
    """
    if disable:
        return None

    logger = logging.getLogger(name)
    logger.propagate = False

    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []
    logging_level = logging.INFO

    def _attach(handler, formatter):
        # Configure one handler and register it everywhere it is needed.
        handler.setLevel(logging_level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        warnings_logger.addHandler(handler)
        handlers.append(handler)

    # Console logging: main process only.
    if distributed_rank == 0:
        logger.setLevel(logging_level)
        if color:
            console_formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            console_formatter = plain_formatter
        _attach(logging.StreamHandler(stream=sys.stdout), console_formatter)

    # File logging: all workers. Without an `output` there is no directory to
    # derive a log path from, so file logging is skipped instead of calling
    # `setup_output_folder()` without its required `save_dir` argument.
    if output is not None:
        if output.endswith((".txt", ".log")):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"

        # A bare file name has an empty dirname, which os.makedirs rejects.
        log_dir = os.path.dirname(filename)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        _attach(logging.StreamHandler(_cached_log_stream(filename)), plain_formatter)

        # Slurm/FB output, only log the main process.
        # This branch is only reachable when `output` is itself a file path,
        # so the extra train.log must live in that file's directory; joining
        # onto the file path would produce a path that cannot be opened.
        if "train.log" not in filename and distributed_rank == 0:
            filename = os.path.join(os.path.dirname(output), "train.log")
            _attach(
                logging.StreamHandler(_cached_log_stream(filename)), plain_formatter
            )

        logger.info(f"Logging to: {filename}")

    # Remove existing handlers to add MMF specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
    # Now, add our handlers. basicConfig does nothing if root still has handlers.
    logging.basicConfig(level=logging_level, handlers=handlers)

    return logger
def setup_very_basic_config(color=True):
    """Install a minimal stdout logging configuration on the root logger.

    Meant as a fallback in case something tries to log a message even before
    logging is set up by MMF.

    Args:
        color: when true, use ``ColorfulFormatter`` with a green
            timestamp/name prefix; otherwise use the plain
            ``time | level | name : message`` layout.
    """
    date_format = "%Y-%m-%dT%H:%M:%S"
    if color:
        chosen_formatter = ColorfulFormatter(
            colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
            datefmt=date_format,
        )
    else:
        chosen_formatter = logging.Formatter(
            "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
            datefmt=date_format,
        )

    console = logging.StreamHandler(stream=sys.stdout)
    console.setLevel(logging.INFO)
    console.setFormatter(chosen_formatter)

    logging.basicConfig(level=logging.INFO, handlers=[console])
# Cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """Return an append-mode stream for ``filename``, opened at most once.

    Repeated calls with the same ``filename`` return the same file object
    instead of opening another handle on the file. The cache is unbounded
    and keyed by the exact ``filename`` string passed in.

    Args:
        filename: path of the log file to open in append mode.

    Returns:
        The open text stream for ``filename``.
    """
    return open(filename, "a")
# ColorfulFormatter is adopted from Detectron2 and adapted for MMF
class ColorfulFormatter(logging.Formatter):
    """Formatter that puts a colored severity tag in front of some messages.

    WARNING records get a red blinking "WARNING" tag; ERROR and CRITICAL
    records get a red blinking, underlined "ERROR" tag. Records at any other
    level are formatted exactly as ``logging.Formatter`` would format them.
    """

    def formatMessage(self, record):
        rendered = super().formatMessage(record)
        level = record.levelno

        if level == logging.WARNING:
            return colored("WARNING", "red", attrs=["blink"]) + " " + rendered
        if level in (logging.ERROR, logging.CRITICAL):
            return (
                colored("ERROR", "red", attrs=["blink", "underline"]) + " " + rendered
            )
        return rendered
class TensorboardLogger:
    """Thin wrapper around a tensorboard ``SummaryWriter``.

    A writer is created only when ``is_main_process()`` is true at
    construction time; otherwise every logging method returns without
    doing anything.
    """

    def __init__(self, log_folder="./logs", iteration=0):
        """Create the writer under ``<log_folder>/tensorboard_<timestamp>``.

        Args:
            log_folder: directory that will contain the tensorboard folder.
            iteration: accepted but not used by this class.
        """
        # Imported inside the constructor rather than at module import time.
        from torch.utils.tensorboard import SummaryWriter

        self.summary_writer = None
        self.log_folder = log_folder
        self._is_master = is_main_process()
        if not self._is_master:
            return

        stamp = time.strftime("%Y-%m-%dT%H:%M:%S")
        self.summary_writer = SummaryWriter(
            os.path.join(self.log_folder, f"tensorboard_{stamp}")
        )

    def __del__(self):
        # getattr guards against __init__ failing before the attribute is set.
        writer = getattr(self, "summary_writer", None)
        if writer is not None:
            writer.close()

    def _should_log_tensorboard(self):
        """Return True only when a writer exists and this is the master."""
        return not (self.summary_writer is None or not self._is_master)

    def add_scalar(self, key, value, iteration):
        """Log one scalar ``value`` under ``key`` at ``iteration``."""
        if self._should_log_tensorboard():
            self.summary_writer.add_scalar(key, value, iteration)

    def add_scalars(self, scalar_dict, iteration):
        """Log every ``key: value`` pair of ``scalar_dict`` at ``iteration``."""
        if self._should_log_tensorboard():
            writer = self.summary_writer
            for tag, scalar in scalar_dict.items():
                writer.add_scalar(tag, scalar, iteration)

    def add_histogram_for_model(self, model, iteration):
        """Log a histogram of each named parameter of ``model``."""
        if self._should_log_tensorboard():
            for param_name, tensor in model.named_parameters():
                # Copy to CPU and convert to a numpy array before logging.
                cpu_copy = tensor.clone().cpu()
                values = cpu_copy.data.numpy()
                self.summary_writer.add_histogram(param_name, values, iteration)