|
|
import logging |
|
|
import os |
|
|
from typing import TYPE_CHECKING, Union |
|
|
|
|
|
from .constants import FINETRAINERS_LOG_LEVEL |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from .parallel import ParallelBackendType |
|
|
|
|
|
|
|
|
class FinetrainersLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that decides, per call, which processes emit a record.

    When a parallel backend is attached, each call to `log` can be restricted to the
    main process, to the local main process, emitted rank by rank (`in_order`), or
    emitted by every process. When no backend is attached, only the process whose
    ``RANK`` environment variable is ``0`` (or unset) emits records.
    """

    def __init__(self, logger: logging.Logger, parallel_backend: Union["ParallelBackendType", None] = None) -> None:
        """Wrap `logger`, optionally binding a parallel backend used for rank filtering.

        Args:
            logger: The underlying logger that records are forwarded to.
            parallel_backend: Backend queried for process/rank information. May be
                ``None`` and assigned later via the `parallel_backend` attribute.
        """
        super().__init__(logger, {})
        self.parallel_backend = parallel_backend
        # State for `log_freq`, keyed by message name: the frequency registered on
        # first use, and the number of calls seen so far.
        self._log_freq = {}
        self._log_freq_counter = {}

    def log(
        self,
        level,
        msg,
        *args,
        main_process_only: bool = False,
        local_main_process_only: bool = True,
        in_order: bool = False,
        **kwargs,
    ):
        """Log `msg` at `level` on the processes selected by the keyword flags.

        Args:
            level: Numeric logging level.
            msg: Message (format string) to log.
            *args: Arguments merged into `msg` by the logging machinery.
            main_process_only: Emit when the backend reports `is_main_process`.
            local_main_process_only: Emit when the backend reports
                `is_local_main_process`. Because this defaults to ``True`` and the
                two flags are OR-ed, pass ``False`` here to restrict output to the
                main process alone.
            in_order: Emit on every rank, one rank at a time, calling the backend's
                `wait_for_everyone` after each turn. Requires both of the flags
                above to be ``False``.
            **kwargs: Forwarded to the underlying logger (e.g. ``exc_info``).

        Raises:
            ValueError: If `in_order` is combined with `main_process_only` or
                `local_main_process_only` (only checked when a backend is attached
                and `level` is enabled).
        """
        # Attribute the record to the caller of this method rather than to this
        # method itself. NOTE: this relies on `self.logger.log` being invoked
        # directly from this frame, which is why the emit code below is not
        # factored into a helper.
        kwargs.setdefault("stacklevel", 2)

        if not self.isEnabledFor(level):
            return

        backend = self.parallel_backend

        if backend is None:
            # No backend yet: fall back to the RANK environment variable and let
            # only rank 0 (or a process without RANK set) emit.
            if int(os.environ.get("RANK", 0)) == 0:
                msg, kwargs = self.process(msg, kwargs)
                self.logger.log(level, msg, *args, **kwargs)
            return

        if (main_process_only or local_main_process_only) and in_order:
            raise ValueError(
                "Cannot set `main_process_only` or `local_main_process_only` to True while `in_order` is True."
            )

        if (main_process_only and backend.is_main_process) or (
            local_main_process_only and backend.is_local_main_process
        ):
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)
            return

        if in_order:
            # The adapter itself has no `rank` attribute, so the rank must be read
            # from the backend.
            # NOTE(review): assumes the backend exposes `rank` alongside
            # `world_size` -- confirm against the ParallelBackendType definition.
            current_rank = backend.rank
            for turn in range(backend.world_size):
                if current_rank == turn:
                    msg, kwargs = self.process(msg, kwargs)
                    self.logger.log(level, msg, *args, **kwargs)
                # Presumably a barrier so that ranks emit one after another.
                backend.wait_for_everyone()
            return

        if not main_process_only and not local_main_process_only:
            # No restriction requested: every process emits.
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)
            return

    def log_freq(
        self,
        level: int,
        name: str,
        msg: str,
        frequency: int,
        *,
        main_process_only: bool = False,
        local_main_process_only: bool = True,
        in_order: bool = False,
        **kwargs,
    ) -> None:
        """Log `msg` on the first call for `name` and on every `frequency`-th call after.

        Args:
            level: Numeric logging level, forwarded to `log`.
            name: Key identifying this recurring message; call counts are tracked
                per name.
            msg: Message to log.
            frequency: Emit once every `frequency` calls. A value ``<= 0`` makes the
                call a no-op that is not counted. The frequency supplied on the
                first counted call for `name` is the one that stays in effect.
            main_process_only: Forwarded to `log`.
            local_main_process_only: Forwarded to `log`.
            in_order: Forwarded to `log`.
            **kwargs: Forwarded to `log`.
        """
        if frequency <= 0:
            return

        if name not in self._log_freq_counter:
            self._log_freq[name] = frequency
            self._log_freq_counter[name] = 0

        if self._log_freq_counter[name] % self._log_freq[name] == 0:
            self.log(
                level,
                msg,
                main_process_only=main_process_only,
                local_main_process_only=local_main_process_only,
                in_order=in_order,
                **kwargs,
            )

        # Count every call, whether or not it emitted.
        self._log_freq_counter[name] += 1
|
|
|
|
|
|
|
|
def get_logger() -> Union[logging.Logger, FinetrainersLoggerAdapter]:
    """Return the shared module-level logger object (`_logger`)."""
    # `_logger` is only read here, so a `global` declaration is not required.
    return _logger
|
|
|
|
|
|
|
|
def _set_parallel_backend(parallel_backend: "ParallelBackendType") -> FinetrainersLoggerAdapter:
    """Attach `parallel_backend` to the module-level logger and return that logger.

    Args:
        parallel_backend: Backend the logger will consult for process/rank
            information on subsequent log calls.

    Returns:
        The module-level logger object, matching the declared return type
        (previously the function fell through and returned ``None``).
    """
    _logger.parallel_backend = parallel_backend
    return _logger
|
|
|
|
|
|
|
|
# Console handler: writes formatted records to the default stream at the
# configured level.
_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
_console_handler = logging.StreamHandler()
_console_handler.setLevel(FINETRAINERS_LOG_LEVEL)
_console_handler.setFormatter(_formatter)

# Package logger: configured with the same level, wired to the console handler,
# then wrapped so that `_logger` ends up bound to the adapter.
_logger = logging.getLogger("finetrainers")
_logger.setLevel(FINETRAINERS_LOG_LEVEL)
_logger.addHandler(_console_handler)
_logger = FinetrainersLoggerAdapter(_logger)
|
|
|