# Training utilities for StarVector: checkpoint save/load, Hub export,
# optimizer construction, distributed init, and FSDP configuration.
| import os | |
| import torch | |
| import transformers | |
| import os | |
| from starvector.util import checkpoint_key | |
| import glob | |
| import shutil | |
| import builtins | |
| import datetime | |
| from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig | |
| from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy | |
| from torch.distributed.fsdp import ( | |
| MixedPrecision, | |
| ShardingStrategy, | |
| ) | |
| import functools | |
| from accelerate import FullyShardedDataParallelPlugin | |
| from accelerate.utils import PrecisionType | |
| from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( | |
| checkpoint_wrapper, | |
| CheckpointImpl, | |
| apply_activation_checkpointing, | |
| ) | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModelForCausalLM | |
| ) | |
| from starvector.model.starvector_arch import StarVectorConfig, StarVectorForCausalLM | |
| from starvector.train.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict | |
def is_deepspeed(checkpoint_dir):
    """Return True iff `checkpoint_dir` was produced by a DeepSpeed run."""
    # DeepSpeed training drops a zero_to_fp32.py helper script into each
    # checkpoint directory; its presence is the detection signal.
    marker = os.path.join(checkpoint_dir, 'zero_to_fp32.py')
    return os.path.exists(marker)
def consolidate_deepspeed_checkpoint(checkpoint_dir):
    """
    Consolidate DeepSpeed ZeRO shards in `checkpoint_dir` into a single
    fp32 state dict saved as weights.pt. Idempotent: a no-op when
    weights.pt already exists.
    """
    target = os.path.join(checkpoint_dir, 'weights.pt')
    if os.path.exists(target):
        return
    convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, target)
def load_checkpoint(model, checkpoint_dir):
    """
    Load model weights from `checkpoint_dir` into `model` and return it.

    Tries, in order: weights.pt (consolidated DeepSpeed), pytorch_model.bin,
    model.safetensors. Reconciles an optional 'module.' prefix mismatch
    between the checkpoint and a (possibly DDP/FSDP-wrapped) model, and
    re-ties tied weights after loading when the model supports it.

    Args:
        model: model instance to load weights into.
        checkpoint_dir (str): directory containing one of the candidate files.

    Returns:
        The same `model` instance with weights loaded.

    Raises:
        FileNotFoundError: if no candidate checkpoint file exists.
    """
    candidate_files = ['weights.pt', 'pytorch_model.bin', 'model.safetensors']
    # Determine the correct file to load
    for candidate in candidate_files:
        path_state_dict = os.path.join(checkpoint_dir, candidate)
        if os.path.exists(path_state_dict):
            break
    else:
        raise FileNotFoundError(f"No checkpoint file found in {checkpoint_dir}")
    # Load the state dict based on file type
    if path_state_dict.endswith('.safetensors'):
        import safetensors.torch
        state_dict = safetensors.torch.load_file(path_state_dict)
    else:
        # map_location='cpu' (fix): the original relied on torch.load's default,
        # which fails when the checkpoint's saving device is unavailable.
        state_dict = torch.load(path_state_dict, map_location='cpu')
    # Handle FSDP/DDP 'module.' prefix. Fix: only add the prefix when the model
    # expects it AND the checkpoint does not already carry it — the original
    # could double-prefix an already-prefixed checkpoint.
    model_keys = list(model.state_dict().keys())
    model_prefixed = bool(model_keys) and model_keys[0].startswith('module')
    state_prefixed = bool(state_dict) and next(iter(state_dict)).startswith('module')
    if model_prefixed and not state_prefixed:
        new_state_dict = {'module.' + key: val for key, val in state_dict.items()}
    else:
        new_state_dict = state_dict
    # Handle Tied Weights
    if hasattr(model, 'tie_weights'):
        # Remove the lm_head.weight key if it exists; tie_weights re-creates the tie.
        new_state_dict.pop("model.svg_transformer.transformer.lm_head.weight", None)
    # strict=False: tied/derived keys may be legitimately absent from the checkpoint.
    model.load_state_dict(new_state_dict, strict=False)
    if hasattr(model, 'tie_weights'):
        # Ensure weights are tied after loading
        model.tie_weights()
    return model
| from transformers import ( | |
| AutoConfig, | |
| AutoModelForCausalLM | |
| ) | |
| from starvector.model.starvector_arch import StarVectorConfig, StarVectorForCausalLM | |
def push_model_to_hub(model, new_model_name, tokenizer, processor):
    """
    Register the StarVector classes with transformers' Auto* factories and
    push model, tokenizer and processor to a private Hugging Face Hub repo
    named `new_model_name` (also used as the commit message).

    NOTE(review): a second `push_model_to_hub` with a different signature is
    defined later in this file; at import time that later definition shadows
    this one, so this version is effectively unreachable by name.
    """
    # Register the model for HF
    AutoConfig.register("starvector", StarVectorConfig)
    AutoModelForCausalLM.register(StarVectorConfig, StarVectorForCausalLM)
    StarVectorConfig.register_for_auto_class()
    StarVectorForCausalLM.register_for_auto_class("AutoModelForCausalLM")
    # All three pushes target the same private repo.
    model.push_to_hub(new_model_name, commit_message=new_model_name, private=True)
    tokenizer.push_to_hub(new_model_name, commit_message=new_model_name, private=True)
    processor.push_to_hub(new_model_name, commit_message=new_model_name, private=True)

# push_model_to_hub(self.model, new_model_name, self.tokenizer, self.processor)
def save_checkpoint(accelerator, model, global_step, logging_dir, checkpoint_limit):
    """
    Save accelerator state as `checkpoint-{global_step}` under `logging_dir`
    and prune older checkpoints down to the most recent `checkpoint_limit`.

    Args:
        accelerator: accelerate.Accelerator driving training.
        model: unused here; kept for call-site compatibility.
        global_step (int): current training step, used to name the checkpoint.
        logging_dir (str): directory holding all checkpoint-* folders.
        checkpoint_limit (int): number of newest checkpoints to keep.
    """
    print("Saving checkpoint! Global Step: " + str(global_step))
    save_checkpoint_dir = os.path.join(logging_dir, f"checkpoint-{global_step}")
    os.makedirs(save_checkpoint_dir, exist_ok=True)
    accelerator.wait_for_everyone()
    accelerator.save_state(save_checkpoint_dir)
    # Sort by step number (checkpoint_key) so the slice keeps the newest ones.
    chkp_dirs = sorted(glob.glob(os.path.join(logging_dir, "checkpoint-*")), key=checkpoint_key)
    chkp_to_remove = chkp_dirs[:-checkpoint_limit]
    # Fix: hoist the rank check out of the loop; only the main process deletes.
    if accelerator.is_main_process:
        for chkp in chkp_to_remove:
            try:
                shutil.rmtree(chkp)
            except OSError:
                # Fix: was a bare `except:` that also swallowed
                # KeyboardInterrupt/SystemExit; removal stays best-effort.
                print(f"could not remove checkpoint {chkp}")
    print(f"Saved state to {save_checkpoint_dir}")
def push_model_to_hub(model, new_model_name, hf_token=None):
    """
    Register the StarVector classes with transformers' Auto* factories and
    push the model, its tokenizer, and (when applicable) its image processor
    to a private Hugging Face Hub repo named `new_model_name`.

    NOTE(review): this redefines the earlier `push_model_to_hub` in this file
    (different signature); this later definition is the one in effect.

    Args:
        model: trained StarVector model; expected to expose
            `model.model.svg_transformer.tokenizer` and
            `model.model.image_encoder.processor`.
        new_model_name (str): target Hub repo name, also used as commit message.
        hf_token (str, optional): Hugging Face auth token passed to each push.
    """
    tokenizer = model.model.svg_transformer.tokenizer
    # Register the model for HF
    AutoConfig.register("starvector", StarVectorConfig)
    AutoModelForCausalLM.register(StarVectorConfig, StarVectorForCausalLM)
    StarVectorConfig.register_for_auto_class()
    StarVectorForCausalLM.register_for_auto_class("AutoModelForCausalLM")
    model.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)
    tokenizer.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)
    processor = model.model.image_encoder.processor
    # Local import — presumably deferred to avoid a circular import; TODO confirm.
    from starvector.data.base import ImageTrainProcessor
    # Skip the push for the local ImageTrainProcessor (not an HF processor,
    # so presumably it has no push_to_hub).
    if not isinstance(processor, ImageTrainProcessor):
        processor.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)
def get_optimizer(config, model):
    """
    Build the optimizer named by `config.training.optimizer` for `model`.

    Supported names: "adamw" (torch.optim.AdamW) and "adafactor"
    (transformers.Adafactor). Raises ValueError for anything else.
    """
    train_cfg = config.training
    opt_name = train_cfg.optimizer
    if opt_name == "adamw":
        return torch.optim.AdamW(
            model.parameters(),
            lr=train_cfg.lr,
            betas=(train_cfg.adam_beta1, train_cfg.adam_beta2),
            weight_decay=train_cfg.adam_weight_decay,
            eps=train_cfg.adam_epsilon,
        )
    if opt_name == "adafactor":
        # Fixed learning rate: disable Adafactor's relative-step schedule.
        return transformers.Adafactor(
            model.parameters(),
            lr=train_cfg.lr,
            relative_step=False,
            scale_parameter=False,
        )
    raise ValueError(f"Optimizer {opt_name} not supported")
def init_distributed_mode(config):
    """
    Initializes torch distributed.

    Reads WORLD_SIZE, RANK and LOCAL_RANK from the environment (as set by
    torchrun-style launchers), binds this process to its local GPU, creates
    the NCCL process group, synchronizes all ranks, and then silences
    `print` on non-master ranks.

    Raises:
        AssertionError: if any required environment variable is missing.
    """
    assert all(var in os.environ for var in ['WORLD_SIZE', 'LOCAL_RANK', 'RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ['LOCAL_RANK'])
    # 'env://' makes init_process_group read rendezvous info from the environment.
    dist_url = 'env://'
    torch.cuda.set_device(local_rank)
    dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        rank, dist_url, local_rank), flush=True)
    torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url,
                                         world_size=world_size, rank=rank)
    # Wait until every rank has joined the group before continuing.
    torch.distributed.barrier()
    # Monkey-patch builtins.print so only rank 0 prints from here on.
    print_only_on_master(rank == 0)
def print_only_on_master(is_master):
    """
    Replace builtins.print with a wrapper that is silent on non-master ranks.

    The wrapper prints (with a timestamp prefix and flush=True) only when
    `is_master` is True or the call passes `force=True`.
    """
    original_print = builtins.print

    def timestamped_print(*args, **kwargs):
        force = kwargs.pop('force', False)
        kwargs['flush'] = True
        if not (is_master or force):
            return
        stamp = datetime.datetime.now().time()
        original_print('[{}] '.format(stamp), end='')  # timestamp prefix
        original_print(*args, **kwargs)

    builtins.print = timestamped_print
def setup_train_env_variables(config):
    """
    Export the environment variables that accelerate / FSDP read at startup.

    Validates `config.training.model_precision` against accelerate's
    PrecisionType, then (when FSDP is enabled) exports the FSDP flags that
    the FSDP Plugin API does not expose directly.

    Raises:
        ValueError: on an unknown precision mode, or when
            `fsdp.cpu_ram_efficient_loading` is set without
            `fsdp.sync_module_states`.
    """
    precision_name = config.training.model_precision.lower()
    try:
        precision = PrecisionType(precision_name)
    except ValueError:
        raise ValueError(f"Unknown mixed_precision mode: {precision_name}. Choose between {PrecisionType.list()}.")
    os.environ["ACCELERATE_MIXED_PRECISION"] = str(precision)
    if not config.fsdp.enable:
        return
    # These FSDP arguments are not exposed by the FSDP Plugin API, so they
    # must be communicated to accelerate via environment variables.
    fsdp_cfg = config.fsdp
    os.environ['ACCELERATE_USE_FSDP'] = "true"
    os.environ['FSDP_USE_ORIG_PARAMS'] = str(fsdp_cfg.use_orig_params).lower()
    os.environ['FSDP_FORWARD_PREFETCH'] = str(fsdp_cfg.forward_prefetch).lower()
    if fsdp_cfg.cpu_ram_efficient_loading and not fsdp_cfg.sync_module_states:
        raise ValueError("When using `fsdp.cpu_ram_efficient_loading` set `fsdp.sync_module_states` to `True`")
    os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = str(fsdp_cfg.cpu_ram_efficient_loading).lower()
    os.environ['FSDP_SYNC_MODULE_STATES'] = str(fsdp_cfg.sync_module_states).lower()
def load_fsdp_plugin(config, model):
    """
    Build the accelerate FullyShardedDataParallelPlugin from `config`, or
    return None when FSDP is disabled.

    Uses full (rank0, CPU-offloaded) state-dict configs, the model's own
    auto-wrap policy, and a uniform mixed-precision dtype for params,
    gradients and buffers.
    """
    if not config.fsdp.enable:
        return None
    # Map the configured precision / sharding names onto torch FSDP enums.
    dtype_by_precision = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "tf32": torch.float32,
    }
    strategy_by_name = {
        "sdp": ShardingStrategy.SHARD_GRAD_OP,
        "ddp": ShardingStrategy.NO_SHARD,
        "fsdp": ShardingStrategy.FULL_SHARD,
        "hsdp": ShardingStrategy.HYBRID_SHARD,
    }
    dtype = dtype_by_precision[config.training.model_precision]
    return FullyShardedDataParallelPlugin(
        state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
        auto_wrap_policy=model.model.get_fsdp_wrapping_policy(),
        mixed_precision_policy=MixedPrecision(
            param_dtype=dtype,
            reduce_dtype=dtype,
            buffer_dtype=dtype,
        ),
        sharding_strategy=strategy_by_name[config.fsdp.sharding_strategy],
        backward_prefetch=config.fsdp.backward_prefetch,
        cpu_offload=config.fsdp.cpu_offload,
    )
def apply_gradient_checkpointing(model):
    """
    Wrap the LLM's transformer layers with non-reentrant activation
    checkpointing and return the model.
    """
    layer_cls = model.model.svg_transformer.transformer_layer_cls
    wrapper_fn = functools.partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrapper_fn,
        check_fn=lambda submodule: isinstance(submodule, layer_cls),
    )
    # Wait for all processes to finish
    torch.distributed.barrier()
    return model
def get_module_class_from_name(module, name):
    """
    Depth-first search `module` for a (sub)module whose class name equals
    `name`.

    Args:
        module (`torch.nn.Module`): The module to get the class from.
        name (`str`): The name of the class.

    Returns:
        The matching class, or None when no submodule's class matches.
    """
    if module.__class__.__name__ == name:
        return module.__class__
    for child in module.children():
        found = get_module_class_from_name(child, name)
        if found is not None:
            return found
    return None