# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import functools
import os
import wandb
import yaml
from copy import deepcopy
from dataclasses import dataclass, field
from time import time

import torch
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.utils.data import DataLoader
from transformers import HfArgumentParser, set_seed
from transformers.optimization import (
    get_constant_schedule_with_warmup,
    get_cosine_with_min_lr_schedule_with_warmup,
)

from data.dataset_base import DataConfig, PackedDataset, collate_wrapper
from data.data_utils import add_special_tokens
from modeling.autoencoder import load_ae
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from train.train_utils import create_logger, get_latest_ckpt
from train.fsdp_utils import (
    FSDPCheckpoint, FSDPConfig, grad_checkpoint_check_fn, fsdp_wrapper,
    fsdp_ema_setup, fsdp_ema_update,
)
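
# The three dataclasses below are parsed by HfArgumentParser, so every field is
# exposed as a CLI flag of the same name (e.g. --llm_path, --lr, --total_steps).
# Illustrative launch only (script path and GPU count are placeholders, not taken
# from this file); any launcher that sets up NCCL process groups works the same way:
#   torchrun --nproc_per_node=8 path/to/this_script.py \
#       --dataset_config_file data/configs/example.yaml \
#       --results_dir results --checkpoint_dir results/checkpoints
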
@dataclass
class ModelArguments:
    llm_path: str = field(
        default="hf/Qwen2.5-0.5B-Instruct/",
        metadata={"help": "Path or HuggingFace repo ID of the pretrained Qwen2-style language model."}
    )
    llm_qk_norm: bool = field(
        default=True,
        metadata={"help": "Enable QK normalization (qk_norm) inside the attention blocks."}
    )
    tie_word_embeddings: bool = field(
        default=False,
        metadata={"help": "Share input and output word embeddings (tied embeddings)."}
    )
    layer_module: str = field(
        default="Qwen2DecoderLayer",
        metadata={"help": "Python class name of the decoder layer to instantiate."}
    )
    vae_path: str = field(
        default="flux/vae/ae.safetensors",
        metadata={"help": "Path to the pretrained VAE checkpoint for latent-space image generation."}
    )
    vit_path: str = field(
        default="hf/siglip-so400m-14-980-flash-attn2-navit/",
        metadata={"help": "Path or repo ID of the SigLIP Vision Transformer used for image understanding."}
    )
    max_latent_size: int = field(
        default=32,
        metadata={"help": "Maximum latent grid size (patches per side) for the VAE latent tensor."}
    )
    latent_patch_size: int = field(
        default=2,
        metadata={"help": "Number of latent-grid cells per side grouped into one latent patch token."}
    )
    vit_patch_size: int = field(
        default=14,
        metadata={"help": "Patch size (pixels) for the Vision Transformer encoder."}
    )
    vit_max_num_patch_per_side: int = field(
        default=70,
        metadata={"help": "Maximum number of ViT patches along one image side after cropping / resize."}
    )
    connector_act: str = field(
        default="gelu_pytorch_tanh",
        metadata={"help": "Activation function used in the latent-to-text connector MLP."}
    )
    interpolate_pos: bool = field(
        default=False,
        metadata={"help": "Interpolate positional embeddings when image resolution differs from pre-training."}
    )
    vit_select_layer: int = field(
        default=-2,
        metadata={"help": "Which hidden layer of the ViT to take as the visual feature (negative = from the end)."}
    )
    vit_rope: bool = field(
        default=False,
        metadata={"help": "Replace ViT positional encodings with RoPE."}
    )
    text_cond_dropout_prob: float = field(
        default=0.1,
        metadata={"help": "Probability of dropping text conditioning during training."}
    )
    vae_cond_dropout_prob: float = field(
        default=0.3,
        metadata={"help": "Probability of dropping VAE latent inputs during training."}
    )
    vit_cond_dropout_prob: float = field(
        default=0.3,
        metadata={"help": "Probability of dropping ViT visual features during training."}
    )

@dataclass
class DataArguments:
    dataset_config_file: str = field(
        default="data/configs/example.yaml",
        metadata={"help": "YAML file specifying dataset groups, weights, and preprocessing rules."}
    )
    prefetch_factor: int = field(
        default=2,
        metadata={"help": "How many batches each DataLoader worker pre-loads in advance."}
    )
    num_workers: int = field(
        default=4,
        metadata={"help": "Number of background workers for the PyTorch DataLoader."}
    )
    max_num_tokens_per_sample: int = field(
        default=16384,
        metadata={"help": "Maximum tokens allowed in one raw sample; longer samples are skipped."}
    )
    max_num_tokens: int = field(
        default=36864,
        metadata={"help": "Hard limit on tokens in a packed batch; flush if adding a sample would exceed it."}
    )
    prefer_buffer_before: int = field(
        default=16384,
        metadata={"help": "While the packed batch is below this many tokens, pop from the overflow buffer before sampling new data."}
    )
    max_buffer_size: int = field(
        default=50,
        metadata={"help": "Maximum number of oversized samples kept in the overflow buffer."}
    )
    data_seed: int = field(
        default=42,
        metadata={"help": "Seed used when shuffling / sampling data shards to ensure reproducibility."}
    )

@dataclass
class TrainingArguments:
    # --- modality switches ---
    visual_gen: bool = field(
        default=True,
        metadata={"help": "Train the image generation branch."}
    )
    visual_und: bool = field(
        default=True,
        metadata={"help": "Train the image understanding branch."}
    )
    # --- bookkeeping & logging ---
    results_dir: str = field(
        default="results",
        metadata={"help": "Root directory for logs."}
    )
    checkpoint_dir: str = field(
        default="results/checkpoints",
        metadata={"help": "Root directory for model checkpoints."}
    )
    wandb_project: str = field(
        default="bagel",
        metadata={"help": "Weights & Biases project name."}
    )
    wandb_name: str = field(
        default="run",
        metadata={"help": "Name shown in the Weights & Biases UI for this run."}
    )
    wandb_runid: str = field(
        default="0",
        metadata={"help": "Unique identifier to resume a previous W&B run, if desired."}
    )
    wandb_resume: str = field(
        default="allow",
        metadata={"help": "W&B resume mode: 'allow', 'must', or 'never'."}
    )
    wandb_offline: bool = field(
        default=False,
        metadata={"help": "Run W&B in offline mode (logs locally, sync later)."}
    )
    # --- reproducibility & resume ---
    global_seed: int = field(
        default=4396,
        metadata={"help": "Base random seed; the actual seed is offset by rank so each process differs."}
    )
    auto_resume: bool = field(
        default=False,
        metadata={"help": "Automatically pick up the latest checkpoint found in checkpoint_dir."}
    )
    resume_from: str = field(
        default=None,
        metadata={"help": "Explicit checkpoint path to resume from; used when auto_resume is off or finds no checkpoint."}
    )
    resume_model_only: bool = field(
        default=False,
        metadata={"help": "Load only model weights, ignoring optimizer/scheduler states."}
    )
    finetune_from_ema: bool = field(
        default=False,
        metadata={"help": "When resume_model_only=True, load the EMA (exponential moving average) weights instead of the raw weights."}
    )
    # --- reporting frequency ---
    log_every: int = field(
        default=10,
        metadata={"help": "Print / log every N training steps."}
    )
    save_every: int = field(
        default=2000,
        metadata={"help": "Save a checkpoint every N training steps."}
    )
    total_steps: int = field(
        default=500_000,
        metadata={"help": "Total number of optimizer steps to train for."}
    )
    # --- optimization & scheduler ---
    warmup_steps: int = field(
        default=2000,
        metadata={"help": "Linear warm-up steps before applying the main LR schedule."}
    )
    lr_scheduler: str = field(
        default="constant",
        metadata={"help": "Type of LR schedule: 'constant' or 'cosine'."}
    )
    lr: float = field(
        default=1e-4,
        metadata={"help": "Peak learning rate after warm-up."}
    )
    min_lr: float = field(
        default=1e-7,
        metadata={"help": "Minimum learning rate for the cosine schedule (ignored for constant)."}
    )
    beta1: float = field(
        default=0.9,
        metadata={"help": "AdamW β₁ coefficient."}
    )
    beta2: float = field(
        default=0.95,
        metadata={"help": "AdamW β₂ coefficient."}
    )
    eps: float = field(
        default=1e-15,
        metadata={"help": "AdamW ε for numerical stability."}
    )
    ema: float = field(
        default=0.9999,
        metadata={"help": "Decay rate for the exponential moving average of model weights."}
    )
    max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Gradient clipping threshold (L2 norm)."}
    )
    timestep_shift: float = field(
        default=1.0,
        metadata={"help": "Shift applied to the diffusion timestep schedule for latent prediction."}
    )
    mse_weight: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for the image-reconstruction MSE loss term."}
    )
    ce_weight: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for the language cross-entropy loss term."}
    )
    ce_loss_reweighting: bool = field(
        default=False,
        metadata={"help": "Reweight CE loss by per-token weights (provided via ce_loss_weights)."}
    )
    expected_num_tokens: int = field(
        default=32768,
        metadata={"help": "Soft target token count; yield the packed batch once it reaches or exceeds this size."}
    )
    # --- distributed training / FSDP ---
    num_replicate: int = field(
        default=1,
        metadata={"help": "Number of replica groups for FSDP HYBRID_SHARD (data-parallel dimension of the device mesh)."}
    )
    num_shard: int = field(
        default=8,
        metadata={"help": "Number of ranks each replica's parameters are sharded across under FSDP HYBRID_SHARD."}
    )
    sharding_strategy: str = field(
        default="HYBRID_SHARD",
        metadata={"help": "FSDP sharding strategy: FULL_SHARD, SHARD_GRAD_OP, HYBRID_SHARD, etc."}
    )
    backward_prefetch: str = field(
        default="BACKWARD_PRE",
        metadata={"help": "FSDP backward prefetch strategy (BACKWARD_PRE or NO_PREFETCH)."}
    )
    cpu_offload: bool = field(
        default=False,
        metadata={"help": "Enable FSDP parameter offload to CPU."}
    )
    # --- module freezing ---
    freeze_llm: bool = field(
        default=False,
        metadata={"help": "Keep language-model weights fixed (no gradient updates)."}
    )
    freeze_vit: bool = field(
        default=False,
        metadata={"help": "Keep ViT weights fixed during training."}
    )
    freeze_vae: bool = field(
        default=True,
        metadata={"help": "Keep VAE weights fixed; only predict latents, don't fine-tune the encoder/decoder."}
    )
    freeze_und: bool = field(
        default=False,
        metadata={"help": "Freeze the visual understanding connector layers."}
    )
    copy_init_moe: bool = field(
        default=True,
        metadata={"help": "Duplicate the initial MoE experts so each starts from identical weights."}
    )
    use_flex: bool = field(
        default=False,
        metadata={"help": "Enable the FLEX packing algorithm for sequence data."}
    )

def main():
    assert torch.cuda.is_available()
    dist.init_process_group("nccl")
    device = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(device)

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # Setup logging:
    if dist.get_rank() == 0:
        os.makedirs(training_args.results_dir, exist_ok=True)
        os.makedirs(training_args.checkpoint_dir, exist_ok=True)
        logger = create_logger(training_args.results_dir, dist.get_rank())
        wandb.init(
            project=training_args.wandb_project,
            id=f"{training_args.wandb_name}-run{training_args.wandb_runid}",
            name=training_args.wandb_name,
            resume=training_args.wandb_resume,
            mode="offline" if training_args.wandb_offline else "online",
        )
        wandb.config.update(training_args)
        wandb.config.update(model_args)
        wandb.config.update(data_args)
    else:
        logger = create_logger(None, dist.get_rank())
    dist.barrier()

    logger.info(f'Training arguments {training_args}')
    logger.info(f'Model arguments {model_args}')
    logger.info(f'Data arguments {data_args}')

    # prepare auto resume logic:
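    # Precedence: with auto_resume, the newest checkpoint in checkpoint_dir wins and
    # resume_from is only a fallback; without auto_resume, resume_from (if any) is used.
    # finetune_from_ema only matters when resume_model_only is set.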
    if training_args.auto_resume:
        resume_from = get_latest_ckpt(training_args.checkpoint_dir)
        if resume_from is None:
            resume_from = training_args.resume_from
            resume_model_only = training_args.resume_model_only
            if resume_model_only:
                finetune_from_ema = training_args.finetune_from_ema
            else:
                finetune_from_ema = False
        else:
            resume_model_only = False
            finetune_from_ema = False
    else:
        resume_from = training_args.resume_from
        resume_model_only = training_args.resume_model_only
        if resume_model_only:
            finetune_from_ema = training_args.finetune_from_ema
        else:
            finetune_from_ema = False

    # Set seed:
    seed = training_args.global_seed * dist.get_world_size() + dist.get_rank()
    set_seed(seed)

    # Setup model:
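    # Bagel couples a Qwen2 language model (optionally duplicated into MoE experts),
    # an optional SigLIP ViT branch for image understanding, and an optional frozen
    # VAE whose latents are the target space for image generation.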
    llm_config = Qwen2Config.from_pretrained(model_args.llm_path)
    llm_config.layer_module = model_args.layer_module
    llm_config.qk_norm = model_args.llm_qk_norm
    llm_config.tie_word_embeddings = model_args.tie_word_embeddings
    llm_config.freeze_und = training_args.freeze_und
    language_model = Qwen2ForCausalLM.from_pretrained(model_args.llm_path, config=llm_config)
    if training_args.copy_init_moe:
        language_model.init_moe()
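
    # Truncate the ViT so its last kept hidden layer is vit_select_layer
    # (e.g. -2 keeps num_hidden_layers - 1 layers), instead of running the full
    # encoder and discarding the deeper features.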
    if training_args.visual_und:
        vit_config = SiglipVisionConfig.from_pretrained(model_args.vit_path)
        vit_config.num_hidden_layers = vit_config.num_hidden_layers + 1 + model_args.vit_select_layer
        vit_config.rope = model_args.vit_rope
        vit_model = SiglipVisionModel.from_pretrained(model_args.vit_path, config=vit_config)

    if training_args.visual_gen:
        vae_model, vae_config = load_ae(local_path=model_args.vae_path)

    config = BagelConfig(
        visual_gen=training_args.visual_gen,
        visual_und=training_args.visual_und,
        llm_config=llm_config,
        vit_config=vit_config if training_args.visual_und else None,
        vae_config=vae_config if training_args.visual_gen else None,
        latent_patch_size=model_args.latent_patch_size,
        max_latent_size=model_args.max_latent_size,
        vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
        connector_act=model_args.connector_act,
        interpolate_pos=model_args.interpolate_pos,
        timestep_shift=training_args.timestep_shift,
    )
    model = Bagel(
        language_model,
        vit_model if training_args.visual_und else None,
        config,
    )
    if training_args.visual_und:
        model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config)

    # Setup tokenizer for model:
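    # add_special_tokens registers the extra special tokens Bagel needs; if any are
    # new, the embedding table (and the recorded vocab size) must grow to match.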
    tokenizer = Qwen2Tokenizer.from_pretrained(model_args.llm_path)
    tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
    if num_new_tokens > 0:
        model.language_model.resize_token_embeddings(len(tokenizer))
        model.config.llm_config.vocab_size = len(tokenizer)
        model.language_model.config.vocab_size = len(tokenizer)

    # maybe freeze something:
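    # Freezing sets requires_grad=False (the LLM and ViT are also put in eval mode);
    # the VAE is frozen by default and is only used to produce target latents.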
    if training_args.freeze_vae and training_args.visual_gen:
        for param in vae_model.parameters():
            param.requires_grad = False
    if training_args.freeze_llm:
        model.language_model.eval()
        for param in model.language_model.parameters():
            param.requires_grad = False
    if training_args.freeze_vit and training_args.visual_und:
        model.vit_model.eval()
        for param in model.vit_model.parameters():
            param.requires_grad = False

    # Setup FSDP and load pretrained model:
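    # The EMA copy is deepcopied before FSDP wrapping so both models start from the
    # same (possibly checkpoint-initialized) weights; activation checkpointing is
    # applied to the FSDP-wrapped model with the non-reentrant implementation.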
    fsdp_config = FSDPConfig(
        sharding_strategy=training_args.sharding_strategy,
        backward_prefetch=training_args.backward_prefetch,
        cpu_offload=training_args.cpu_offload,
        num_replicate=training_args.num_replicate,
        num_shard=training_args.num_shard,
    )
    ema_model = deepcopy(model)
    model, ema_model = FSDPCheckpoint.try_load_ckpt(
        resume_from, logger, model, ema_model, resume_from_ema=finetune_from_ema
    )
    ema_model = fsdp_ema_setup(ema_model, fsdp_config)
    fsdp_model = fsdp_wrapper(model, fsdp_config)
    apply_activation_checkpointing(
        fsdp_model,
        checkpoint_wrapper_fn=functools.partial(
            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
        ),
        check_fn=grad_checkpoint_check_fn,
    )

    if dist.get_rank() == 0:
        print(fsdp_model)
        for name, param in model.named_parameters():
            print(name, param.requires_grad)

    # Setup optimizer and scheduler
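    # AdamW with no weight decay; the LR follows either a constant-with-warmup or a
    # cosine-with-minimum-LR schedule over total_steps.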
    optimizer = torch.optim.AdamW(
        fsdp_model.parameters(),
        lr=training_args.lr,
        betas=(training_args.beta1, training_args.beta2),
        eps=training_args.eps,
        weight_decay=0,
    )
    if training_args.lr_scheduler == 'cosine':
        scheduler = get_cosine_with_min_lr_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=training_args.warmup_steps,
            num_training_steps=training_args.total_steps,
            min_lr=training_args.min_lr,
        )
    elif training_args.lr_scheduler == 'constant':
        scheduler = get_constant_schedule_with_warmup(
            optimizer=optimizer, num_warmup_steps=training_args.warmup_steps
        )
    else:
        raise ValueError(f"Unsupported lr_scheduler: {training_args.lr_scheduler}")

    # maybe resume optimizer, scheduler, and train_steps
    if resume_model_only:
        train_step = 0
        data_status = None
    else:
        optimizer, scheduler, train_step, data_status = FSDPCheckpoint.try_load_train_state(
            resume_from, optimizer, scheduler, fsdp_config,
        )

    # Setup packed dataloader
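    # PackedDataset packs variable-length samples into token-budgeted batches
    # (expected_num_tokens / max_num_tokens), so the DataLoader uses batch_size=1:
    # each yielded item is already a full packed batch. data_status lets a resumed
    # run continue from the recorded per-worker positions.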
    with open(data_args.dataset_config_file, "r") as stream:
        dataset_meta = yaml.safe_load(stream)
    dataset_config = DataConfig(grouped_datasets=dataset_meta)
    if training_args.visual_und:
        dataset_config.vit_patch_size = model_args.vit_patch_size
        dataset_config.max_num_patch_per_side = model_args.vit_max_num_patch_per_side
    if training_args.visual_gen:
        vae_image_downsample = model_args.latent_patch_size * vae_config.downsample
        dataset_config.vae_image_downsample = vae_image_downsample
        dataset_config.max_latent_size = model_args.max_latent_size
        dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
        dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
        dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob
    train_dataset = PackedDataset(
        dataset_config,
        tokenizer=tokenizer,
        special_tokens=new_token_ids,
        local_rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        num_workers=data_args.num_workers,
        expected_num_tokens=training_args.expected_num_tokens,
        max_num_tokens_per_sample=data_args.max_num_tokens_per_sample,
        max_num_tokens=data_args.max_num_tokens,
        max_buffer_size=data_args.max_buffer_size,
        prefer_buffer_before=data_args.prefer_buffer_before,
        interpolate_pos=model_args.interpolate_pos,
        use_flex=training_args.use_flex,
        data_status=data_status,
    )
    train_dataset.set_epoch(data_args.data_seed)
    train_loader = DataLoader(
        train_dataset,
        batch_size=1,  # each item from PackedDataset is already a packed batch
        num_workers=data_args.num_workers,
        pin_memory=True,
        collate_fn=collate_wrapper(),
        drop_last=True,
        prefetch_factor=data_args.prefetch_factor,
    )

    # Prepare models for training:
    if training_args.visual_gen:
        vae_model.to(device).eval()
    fsdp_model.train()
    ema_model.eval()

    # train loop
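    # Each step: run the packed batch under bf16 autocast (VAE encoding under
    # no_grad), then normalize the CE and MSE sums by the *global* token counts
    # (all_reduce) so the loss is a true per-token mean across all ranks, and
    # finally clip gradients, step the optimizer/scheduler, and update the EMA copy.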
    start_time = time()
    logger.info(f"Training for {training_args.total_steps} steps, starting at {train_step}...")
    for curr_step, data in enumerate(train_loader, start=train_step):
        data = data.cuda(device).to_dict()
        data_indexes = data.pop('batch_data_indexes', None)
        ce_loss_weights = data.pop('ce_loss_weights', None)
        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            if training_args.visual_gen:
                with torch.no_grad():
                    data['padded_latent'] = vae_model.encode(data.pop('padded_images'))
            loss_dict = fsdp_model(**data)

        loss = 0
        ce = loss_dict["ce"]
        if ce is not None:
            total_ce_tokens = torch.tensor(len(data['ce_loss_indexes']), device=device)
            dist.all_reduce(total_ce_tokens, op=dist.ReduceOp.SUM)
            if training_args.ce_loss_reweighting:
                ce = ce * ce_loss_weights
                total_ce_loss_weights = ce_loss_weights.sum()
                dist.all_reduce(total_ce_loss_weights, op=dist.ReduceOp.SUM)
                ce = ce.sum() * dist.get_world_size() / total_ce_loss_weights
            else:
                ce = ce.sum() * dist.get_world_size() / total_ce_tokens
            loss_dict["ce"] = ce.detach()
            loss = loss + ce * training_args.ce_weight
        else:
            assert not training_args.visual_und
            loss_dict["ce"] = torch.tensor(0, device=device)
            total_ce_tokens = torch.tensor(0, device=device)

        if training_args.visual_gen:
            mse = loss_dict["mse"]
            total_mse_tokens = torch.tensor(len(data['mse_loss_indexes']), device=device)
            dist.all_reduce(total_mse_tokens, op=dist.ReduceOp.SUM)
            mse = mse.mean(dim=-1).sum() * dist.get_world_size() / total_mse_tokens
            loss_dict["mse"] = mse.detach()
            loss = loss + mse * training_args.mse_weight
        else:
            assert not training_args.visual_gen
            loss_dict["mse"] = torch.tensor(0, device=device)
            total_mse_tokens = torch.tensor(0, device=device)

        optimizer.zero_grad()
        loss.backward()
        total_norm = fsdp_model.clip_grad_norm_(training_args.max_grad_norm)
        optimizer.step()
        scheduler.step()
        fsdp_ema_update(ema_model, fsdp_model, decay=training_args.ema)

        # Log loss values:
        if curr_step % training_args.log_every == 0:
            total_samples = torch.tensor(len(data['sample_lens']), device=device)
            dist.all_reduce(total_samples, op=dist.ReduceOp.SUM)

            # Measure training speed:
            torch.cuda.synchronize()
            end_time = time()
            steps_per_sec = training_args.log_every / (end_time - start_time)
            message = f"(step={curr_step:07d}) "
            wandb_log = {}
            for key, value in loss_dict.items():
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(value.item(), device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                message += f"Train Loss {key}: {avg_loss:.4f}, "
                wandb_log[key] = avg_loss
            message += f"Train Steps/Sec: {steps_per_sec:.2f}, "
            logger.info(message)

            wandb_log['lr'] = optimizer.param_groups[0]['lr']
            wandb_log['total_mse_tokens'] = total_mse_tokens.item()
            wandb_log['total_ce_tokens'] = total_ce_tokens.item()
            wandb_log['total_norm'] = total_norm.item()
            wandb_log['total_samples'] = total_samples.item()

            mem_allocated = torch.tensor(torch.cuda.max_memory_allocated() / 1024**2, device=device)
            dist.all_reduce(mem_allocated, op=dist.ReduceOp.MAX)
            wandb_log['mem_allocated'] = mem_allocated
            mem_cache = torch.tensor(torch.cuda.max_memory_reserved() / 1024**2, device=device)
            dist.all_reduce(mem_cache, op=dist.ReduceOp.MAX)
            wandb_log['mem_cache'] = mem_cache

            if dist.get_rank() == 0:
                wandb.log(wandb_log, step=curr_step)

            start_time = time()
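
        # Track how far each dataloader worker has advanced in each dataset so a
        # resumed run can skip already-seen samples; at save time every rank's
        # status is gathered to rank 0 and stored with the checkpoint.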
        if data_status is None:
            data_status = {}
        for item in data_indexes:
            if item['dataset_name'] not in data_status.keys():
                data_status[item['dataset_name']] = {}
            data_status[item['dataset_name']][item['worker_id']] = item['data_indexes']

        if curr_step > 0 and curr_step % training_args.save_every == 0:
            if dist.get_rank() == 0:
                gather_list = [None] * dist.get_world_size()
            else:
                gather_list = None
            dist.gather_object(data_status, gather_list, dst=0)

            FSDPCheckpoint.fsdp_save_ckpt(
                ckpt_dir=training_args.checkpoint_dir,
                train_steps=curr_step,
                model=fsdp_model,
                ema_model=ema_model,
                optimizer=optimizer,
                scheduler=scheduler,
                logger=logger,
                fsdp_config=fsdp_config,
                data_status=gather_list,
            )

    logger.info("Done!")
    if dist.get_rank() == 0:
        wandb.finish()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()