|
|
import functools |
|
|
import pathlib |
|
|
import shutil |
|
|
import time |
|
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union |
|
|
|
|
|
import torch |
|
|
import torch.distributed.checkpoint |
|
|
from torch.distributed.checkpoint.state_dict import ( |
|
|
StateDictOptions, |
|
|
get_model_state_dict, |
|
|
set_model_state_dict, |
|
|
) |
|
|
from torch.distributed.checkpoint.stateful import Stateful |
|
|
|
|
|
from ..logging import get_logger |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from .. import optimizer |
|
|
|
|
|
|
|
|
logger = get_logger() |
|
|
|
|
|
|
|
|
class ModelWrapper(Stateful): |
|
|
def __init__(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None: |
|
|
self.model = [model] if isinstance(model, torch.nn.Module) else model |
|
|
|
|
|
def state_dict(self) -> Dict[str, Any]: |
|
|
return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()} |
|
|
|
|
|
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: |
|
|
func = functools.partial( |
|
|
set_model_state_dict, |
|
|
model_state_dict=state_dict, |
|
|
options=StateDictOptions(strict=False), |
|
|
) |
|
|
list(map(func, self.model)) |
|
|
|
|
|
|
|
|
class PTDCheckpointManager: |
|
|
def __init__( |
|
|
self, |
|
|
dataloader: torch.utils.data.DataLoader, |
|
|
model_parts: List[torch.nn.Module], |
|
|
optimizers: "optimizer.OptimizerWrapper", |
|
|
schedulers: "optimizer.SchedulerWrapper", |
|
|
states: Dict[str, Any], |
|
|
checkpointing_steps: int, |
|
|
checkpointing_limit: int, |
|
|
output_dir: str, |
|
|
enable: bool = True, |
|
|
_callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None, |
|
|
_prefix: str = "finetrainers_step", |
|
|
) -> None: |
|
|
self.states = states |
|
|
self.states.update( |
|
|
{ |
|
|
"model": ModelWrapper(model_parts), |
|
|
"optimizer": optimizers, |
|
|
"dataloader": dataloader, |
|
|
} |
|
|
) |
|
|
self.states.update(schedulers.get_lr_scheduler_state()) |
|
|
|
|
|
self.checkpointing_steps = checkpointing_steps |
|
|
self.checkpointing_limit = checkpointing_limit |
|
|
self.output_dir = pathlib.Path(output_dir) |
|
|
self.enable = enable |
|
|
self._callback_fn = _callback_fn |
|
|
self._prefix = _prefix |
|
|
|
|
|
logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'") |
|
|
|
|
|
def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str: |
|
|
if not self._should_checkpoint(step, force): |
|
|
return None |
|
|
|
|
|
checkpoint_dir = self._get_checkpoint_dir(step) |
|
|
begin_time = time.monotonic() |
|
|
torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix()) |
|
|
end_time = time.monotonic() |
|
|
logger.info( |
|
|
f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}" |
|
|
) |
|
|
self._purge_stale_checkpoints() |
|
|
|
|
|
state_dicts = [ |
|
|
gather_state_dict_on_cpu_rank0(model, _device, is_main_process=_is_main_process) |
|
|
for model in self.states["model"].model |
|
|
] |
|
|
if self._callback_fn is not None: |
|
|
list(map(self._callback_fn, state_dicts)) |
|
|
|
|
|
return checkpoint_dir.as_posix() |
|
|
|
|
|
def load(self, step: int = -1) -> bool: |
|
|
if not self.enable: |
|
|
return False |
|
|
if not self.output_dir.exists(): |
|
|
return False |
|
|
if step != -1 and not self._get_checkpoint_dir(step).exists(): |
|
|
return False |
|
|
|
|
|
if step == -1: |
|
|
latest_checkpoint_dir = self._find_latest_checkpoint_dir() |
|
|
if latest_checkpoint_dir is None: |
|
|
return False |
|
|
step = int(latest_checkpoint_dir.name.split("_")[-1]) |
|
|
|
|
|
checkpoint_dir = self._get_checkpoint_dir(step) |
|
|
logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}") |
|
|
|
|
|
|
|
|
states = {"model": self.states["model"]} if step == 0 else self.states |
|
|
|
|
|
|
|
|
original_stateful_states = {k: v for k, v in states.items() if isinstance(v, Stateful)} |
|
|
begin_time = time.monotonic() |
|
|
torch.distributed.checkpoint.load(states, checkpoint_id=checkpoint_dir.as_posix()) |
|
|
end_time = time.monotonic() |
|
|
logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.") |
|
|
|
|
|
|
|
|
states.update(original_stateful_states) |
|
|
|
|
|
return True |
|
|
|
|
|
def _should_checkpoint(self, step: int, force: bool) -> bool: |
|
|
if not self.enable: |
|
|
return False |
|
|
|
|
|
if not force: |
|
|
if step % self.checkpointing_steps != 0: |
|
|
return False |
|
|
|
|
|
return True |
|
|
|
|
|
def _get_checkpoint_dir(self, step: int) -> pathlib.Path: |
|
|
return self.output_dir / f"{self._prefix}_{step}" |
|
|
|
|
|
def _find_latest_checkpoint_dir(self) -> Union[pathlib.Path, None]: |
|
|
checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1])) |
|
|
return checkpoints[-1] if len(checkpoints) > 0 else None |
|
|
|
|
|
def _purge_stale_checkpoints(self) -> None: |
|
|
if self.checkpointing_limit is None or self.checkpointing_limit <= 0: |
|
|
return |
|
|
checkpoints = sorted( |
|
|
self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True |
|
|
) |
|
|
for checkpoint in checkpoints[self.checkpointing_limit :]: |
|
|
logger.info(f"Deleting stale checkpoint: {checkpoint}") |
|
|
shutil.rmtree(checkpoint, ignore_errors=True) |
|
|
|
|
|
|
|
|
def gather_state_dict_on_cpu_rank0(
    model, device: Optional[torch.device] = None, *, is_main_process: bool
) -> Dict[str, Any]:
    """Collect ``model``'s state dict as CPU tensors on the main process.

    Entries are handled one at a time: each is moved to ``device`` if it
    currently lives on CPU, materialised with ``full_tensor()`` if it exposes
    ``_local_tensor`` (presumably a sharded DTensor), and copied to CPU on the
    main process. When a default process group is initialized, a barrier
    follows every entry.

    Args:
        model: Object providing ``state_dict()``.
        device: Device that CPU-resident entries are moved to before they are
            gathered. ``None`` leaves them where they are.
        is_main_process: Whether the calling process should keep the result.

    Returns:
        A dict mapping names to CPU tensors on the main process, and an empty
        dict on every other process.
    """
    cpu_state_dict: Dict[str, Any] = {}
    # A barrier is only possible once a default process group exists; without
    # one there are no other ranks to wait for.
    use_barrier = torch.distributed.is_available() and torch.distributed.is_initialized()

    for param_name, param in model.state_dict().items():
        if param.is_cpu and device is not None:
            # Entry currently lives on CPU: move it to the target device.
            param = param.to(device)
        if hasattr(param, "_local_tensor"):
            # Materialise the full tensor from its shards.
            param = param.full_tensor()
        if is_main_process:
            cpu_state_dict[param_name] = param.cpu()
        if use_barrier:
            torch.distributed.barrier()
    return cpu_state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|