|
|
import os |
|
|
from typing import Tuple |
|
|
|
|
|
from accelerate.logging import get_logger |
|
|
|
|
|
from ..constants import FINETRAINERS_LOG_LEVEL |
|
|
from ..utils.file_utils import delete_files, find_files |
|
|
|
|
|
|
|
|
# Module-level logger shared by the checkpoint helpers in this file. It is
# created through accelerate's `get_logger` under the name "finetrainers",
# and its level is taken from the FINETRAINERS_LOG_LEVEL constant.
logger = get_logger("finetrainers") |
|
|
logger.setLevel(FINETRAINERS_LOG_LEVEL) |
|
|
|
|
|
|
|
|
def get_latest_ckpt_path_to_resume_from(
    resume_from_checkpoint: str, num_update_steps_per_epoch: int, output_dir: str
) -> Tuple[str, int, int, int]:
    """Resolve which checkpoint, if any, training should resume from.

    Args:
        resume_from_checkpoint: A falsy value to start a fresh run, the string
            ``"latest"`` to pick the ``checkpoint-<step>`` entry with the highest
            step inside ``output_dir``, or a checkpoint name/path whose basename
            is looked up inside ``output_dir``.
        num_update_steps_per_epoch: Number of update steps in one epoch; used to
            derive the epoch to resume at from the checkpoint's step.
        output_dir: Directory holding the ``checkpoint-<step>`` entries.

    Returns:
        A tuple ``(resume_from_checkpoint_path, initial_global_step, global_step,
        first_epoch)``. When there is nothing to resume from, the path is ``None``
        and all three counters are ``0``.

    Raises:
        IndexError, ValueError: If an explicitly given checkpoint name does not
            follow the ``<prefix>-<step>`` pattern with an integer step.
    """
    if not resume_from_checkpoint:
        return None, 0, 0, 0

    if resume_from_checkpoint != "latest":
        # normpath drops a trailing separator ("out/checkpoint-500/"), which
        # would otherwise make basename() return an empty string.
        path = os.path.basename(os.path.normpath(resume_from_checkpoint))
    else:
        path = _find_latest_checkpoint_name(output_dir)

    if path is None:
        logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
        return None, 0, 0, 0

    logger.info(f"Resuming from checkpoint {path}")
    resume_from_checkpoint_path = os.path.join(output_dir, path)
    # The step count is encoded in the directory name: "checkpoint-<step>".
    global_step = int(path.split("-")[1])
    first_epoch = global_step // num_update_steps_per_epoch
    return resume_from_checkpoint_path, global_step, global_step, first_epoch


def _find_latest_checkpoint_name(output_dir: str):
    """Return the ``checkpoint-<step>`` entry name with the highest step, or ``None``.

    Entries that start with "checkpoint" but do not carry an integer step after
    the first dash are ignored instead of aborting the lookup. A missing
    ``output_dir`` is treated the same as a directory without checkpoints.
    """
    try:
        entries = os.listdir(output_dir)
    except FileNotFoundError:
        return None

    candidates = []
    for entry in entries:
        if not entry.startswith("checkpoint"):
            continue
        pieces = entry.split("-")
        if len(pieces) < 2:
            continue
        try:
            step = int(pieces[1])
        except ValueError:
            continue
        candidates.append((step, entry))

    if not candidates:
        return None
    # Stable sort on the step only, so ties resolve to the later listing entry.
    candidates.sort(key=lambda item: item[0])
    return candidates[-1][1]
|
|
|
|
|
|
|
|
def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str:
    """Make room for a new checkpoint and return the path it should be saved to.

    When ``checkpointing_limit`` is not ``None``, the entries that ``find_files``
    reports for the "checkpoint" prefix in ``output_dir`` are counted. If the
    count has reached the limit, entries from the front of that list are passed
    to ``delete_files`` so that fewer than ``checkpointing_limit`` remain before
    the new one is written.

    Args:
        checkpointing_limit: Maximum number of checkpoints to keep, or ``None``
            to skip pruning entirely.
        step: Training step the new checkpoint belongs to.
        output_dir: Directory in which checkpoints are stored.

    Returns:
        ``<output_dir>/checkpoint-<step>``.
    """
    if checkpointing_limit is not None:
        existing = find_files(output_dir, prefix="checkpoint")
        # One extra slot is freed for the checkpoint about to be saved.
        surplus = len(existing) - checkpointing_limit + 1
        if surplus > 0:
            # NOTE(review): assumes find_files returns oldest checkpoints first,
            # so the head of the list is what should be pruned; confirm.
            stale_paths = []
            for name in existing[:surplus]:
                stale_paths.append(os.path.join(output_dir, name))
            delete_files(stale_paths)

    logger.info(f"Checkpointing at step {step}")
    target = os.path.join(output_dir, f"checkpoint-{step}")
    logger.info(f"Saving state to {target}")
    return target
|
|
|