import os
import logging
import math
import time
import torch
from torch.utils.data import DataLoader
from transformers import get_scheduler
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from tqdm.auto import tqdm
from omegaconf import OmegaConf
from safetensors.torch import load_file as load_safetensors
from starvector.util import (
    set_env_vars,
    flatten_dict,
    get_exp_id,
    instantiate_from_config,
    generate_id_name_eval,
    get_last_checkpoint,
    model_summary_table,
    copy_code,
    get_output_dir,
    get_config,
)
from starvector.train.util import (
    save_checkpoint,
    get_optimizer,
    init_distributed_mode,
    setup_train_env_variables,
    load_fsdp_plugin,
    apply_gradient_checkpointing,
    load_checkpoint,
    is_deepspeed,
    consolidate_deepspeed_checkpoint,
)
from starvector.metrics.util import AverageMeter
from starvector.model.builder import model_builder

logger = get_logger(__name__, log_level="INFO")
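
# Note on validate() below: it gathers each process's running-mean loss and averages
# them across processes, which weights processes equally rather than by exact sample
# count; that is fine when validation shards are balanced.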

def validate(model, dataloader, accelerator):
    """Run a no-grad pass over the validation set and return the loss averaged across processes."""
    loss_meter = AverageMeter()
    model.eval()
    pbar = tqdm(total=len(dataloader), ncols=100, desc="Validating", disable=not accelerator.is_local_main_process)
    with torch.no_grad():
        for batch in dataloader:
            batch_size = len(batch["image"])
            loss = model(batch)
            loss_meter.update(loss.detach().item(), batch_size)
            pbar.update(1)
    val_loss = (
        accelerator.gather(torch.tensor(loss_meter.avg).to(accelerator.device))
        .float()
        .mean()
        .item()
    )
    accelerator.wait_for_everyone()
    pbar.close()
    return val_loss
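
# main() wires up the full run: datasets and dataloaders, model construction,
# optional FSDP/DeepSpeed checkpoint resume, the Accelerator, optimizer and LR
# scheduler, then the epoch loop with periodic validation and checkpointing.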

def main(config=None):
    print(f"Experiment config: {config}")
    set_env_vars()
    exp_id = get_exp_id(config)
    output_dir = get_output_dir()
    logging_dir = os.path.join(output_dir, config.data.train.params.dataset_name, exp_id)

    # If the experiment directory already exists, resume from its last checkpoint
    if os.path.exists(logging_dir) and not config.training.resume_from_checkpoint:
        config.training.resume_from_checkpoint = get_last_checkpoint(logging_dir)
        config.training.continue_training = True

    # Flatten the config dict for logging it
    log_config = flatten_dict(OmegaConf.to_container(config, resolve=True))
    log_config['logging_dir'] = logging_dir  # Add logging dir to config

    if config.fsdp.enable:
        init_distributed_mode(config)
    setup_train_env_variables(config)

    # --------------- Datasets ---------------
    train_dataset = instantiate_from_config(config.data.train)
    test_dataset = instantiate_from_config(config.data.test)
    train_dataloader = DataLoader(train_dataset, batch_size=config.data.train.batch_size, shuffle=True, num_workers=config.data.num_workers, pin_memory=True)
    test_dataloader = DataLoader(test_dataset, batch_size=config.data.test.batch_size, shuffle=False, num_workers=config.data.num_workers, pin_memory=True)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.training.gradient_accumulation_steps)
    max_train_steps = config.training.n_epochs * num_update_steps_per_epoch
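    # One "update step" = one optimizer step (a full gradient-accumulation window):
    # e.g. 1000 batches per epoch with gradient_accumulation_steps=4 gives
    # ceil(1000 / 4) = 250 update steps per epoch.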
    global_step = 0
    first_epoch = 0

    model = model_builder(config)

    # Instantiate the FSDP plugin and accelerator before loading any checkpoint,
    # so `accelerator` is defined when the resume logic below needs it.
    if config.fsdp.enable:
        fsdp_plugin = load_fsdp_plugin(config, model)
    else:
        fsdp_plugin = None

    # Define accelerator
    kwargs_handler = None
    accelerator = Accelerator(
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
        mixed_precision=config.training.model_precision,
        log_with="wandb" if config.project.use_wandb else None,
        project_dir=logging_dir,
        project_config=ProjectConfiguration(logging_dir=logging_dir),
        step_scheduler_with_optimizer=False,
        fsdp_plugin=fsdp_plugin,
        kwargs_handlers=kwargs_handler,
    )

    if config.training.resume_from_checkpoint:
        if not config.fsdp.enable:
            if is_deepspeed(config.training.resume_from_checkpoint):
                # DeepSpeed shards must be consolidated into a single file first
                if accelerator.is_main_process:
                    consolidate_deepspeed_checkpoint(config.training.resume_from_checkpoint)
                accelerator.wait_for_everyone()
            model = load_checkpoint(model, config.training.resume_from_checkpoint)
        else:
            model.load_state_dict(torch.load(os.path.join(config.training.resume_from_checkpoint, "pytorch_model_fsdp.bin")), strict=False)
        if config.training.continue_training:
            # Checkpoint directories are named "checkpoint-<global_step>"
            global_step = int(os.path.basename(config.training.resume_from_checkpoint).split("-")[1])
            resume_global_step = global_step * config.training.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * config.training.gradient_accumulation_steps)
        else:
            global_step = 0
            first_epoch = 0
            resume_step = 0
            print("Loaded checkpoint but not updating global step")
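
    # Note: `resume_step` above is measured in micro-batches, while `global_step`
    # counts optimizer updates; the training loop uses resume_step to skip batches
    # already consumed in the first resumed epoch.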

    # --------------- Logging ---------------
    if accelerator.is_main_process:
        if config.project.use_wandb:
            import wandb
            wandb.init(name=exp_id, project=config.project.project, entity=config.project.entity, config=log_config)
            accelerator.init_trackers(
                project_name=config.project.project,
            )
            config.project.wandb_run_id = wandb.run.id
        else:
            run = os.path.split(__file__)[-1].split(".")[0]
            accelerator.init_trackers(run)
        if logging_dir is not None:
            os.makedirs(logging_dir, exist_ok=True)
        # Copy code and dependency versions
        if config.project.copy_code:
            out_dir = os.path.join(logging_dir, "code")
            copy_code(os.path.join(os.path.dirname(__file__), "..", ".."), out_dir)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=True)

    total_batch_size = config.data.train.batch_size * accelerator.num_processes * config.training.gradient_accumulation_steps
    if accelerator.is_main_process and config.project.use_wandb:
        wandb.log({
            "total_batch_size": total_batch_size,
            "num_update_steps_per_epoch": num_update_steps_per_epoch,
            "max_train_steps": max_train_steps,
        })

    # Accelerate prepare model
    model = accelerator.prepare(model)

    # Activation/gradient checkpointing
    if config.training.use_gradient_checkpointing:
        print("apply gradient checkpointing")
        model = apply_gradient_checkpointing(model)

    optimizer = get_optimizer(config, model)

    if accelerator.is_main_process:
        print("Train dataset length: ", len(train_dataset))
        print("Test dataset length: ", len(test_dataset))

    # --------------- Training config ---------------
    lr_scheduler = get_scheduler(
        config.training.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.training.lr_warmup_steps * config.training.gradient_accumulation_steps,
        num_training_steps=len(train_dataloader) * config.training.n_epochs,
    )
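    # The scheduler horizon is counted in micro-batches (dataloader length x epochs),
    # not optimizer updates, because the loop below calls lr_scheduler.step() on every
    # iteration and the Accelerator was built with step_scheduler_with_optimizer=False;
    # the warmup is scaled by gradient_accumulation_steps for the same reason.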
    optimizer, train_dataloader, test_dataloader, lr_scheduler = accelerator.prepare(
        optimizer, train_dataloader, test_dataloader, lr_scheduler
    )
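    # prepare() shards the dataloaders across processes and wraps the optimizer and
    # scheduler so their state is handled correctly under mixed precision and FSDP.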

    loss_meter = AverageMeter()
    if accelerator.is_main_process:
        model_summary_table(model)
        if not os.path.exists(os.path.join(logging_dir, 'config.yaml')):
            with open(os.path.join(logging_dir, 'config.yaml'), 'w') as f:
                OmegaConf.save(config, f)

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num epochs = {config.training.n_epochs}")
    logger.info(f"  Instantaneous batch size per device = {config.data.train.batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient accumulation steps = {config.training.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")

    # --------------- Generation/Validation arguments ---------------
    generation_args = config.generation
    # Set some experiment-specific arguments
    generation_args.project_name = config.project.project
    generation_args.use_wandb = config.project.use_wandb
    generation_args.id = generate_id_name_eval(generation_args)
    generation_args.out_path = os.path.join(logging_dir, generation_args.id)
    generation_args.start_generation_at_step = config.generation.start_generation_at_step
    generation_args.metrics = config.metrics
    os.makedirs(generation_args.out_path, exist_ok=True)
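    # These generation settings are not consumed in the loop below; presumably the
    # periodic SVG-generation evaluation that uses them lives elsewhere in the repo.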

    # --------------- Training loop ---------------
    total_steps = num_update_steps_per_epoch * config.training.n_epochs
    progress_bar = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Training Progress")

    for epoch in range(config.training.n_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            s_time = time.time()
            # When resuming mid-epoch, skip batches that were already consumed
            if config.training.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % config.training.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(model):
                loss = model(batch)
                accelerator.backward(loss)
                loss_meter.update(loss.detach().item(), batch['image'].shape[0])
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
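            # sync_gradients is True only on the micro-batch that completes an
            # accumulation window, so global_step counts optimizer updates, not batches.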
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % config.training.checkpointing_steps == 0:
                    accelerator.wait_for_everyone()
                    val_loss = validate(model, test_dataloader, accelerator)
                    accelerator.log({"val_loss": val_loss}, step=global_step)
                    save_checkpoint(accelerator, model, global_step, logging_dir, config.training.checkpoints_total_limit)
                    model.train()

            logs = {
                "loss": loss_meter.val,
                "last_lr": lr_scheduler.get_last_lr()[0],
                "step": global_step,
                "step_time": time.time() - s_time,
                "epoch": epoch,
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

    accelerator.end_training()


if __name__ == "__main__":
    main(config=get_config())