Fix Deepspeed loading (#950)
Browse files* add check for zero3
* freeze parameters
* fixes for deepspeed loading
* fix model parameter check
* unfrozen parameters in example mixtral and logging when unfreezing
- deepspeed/zero3_bf16.json +39 -0
- examples/mistral/mixtral.yml +9 -0
- src/axolotl/cli/train.py +1 -1
- src/axolotl/train.py +4 -0
- src/axolotl/utils/freeze.py +38 -0
- src/axolotl/utils/models.py +4 -0
- src/axolotl/utils/trainer.py +1 -0
deepspeed/zero3_bf16.json
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"zero_optimization": {
|
| 3 |
+
"stage": 3,
|
| 4 |
+
"overlap_comm": true,
|
| 5 |
+
"contiguous_gradients": true,
|
| 6 |
+
"sub_group_size": 0,
|
| 7 |
+
"reduce_bucket_size": "auto",
|
| 8 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 9 |
+
"stage3_param_persistence_threshold": "auto",
|
| 10 |
+
"stage3_max_live_parameters": 0,
|
| 11 |
+
"stage3_max_reuse_distance": 0,
|
| 12 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 13 |
+
},
|
| 14 |
+
"bf16": {
|
| 15 |
+
"enabled": true
|
| 16 |
+
},
|
| 17 |
+
"fp16": {
|
| 18 |
+
"enabled": "auto",
|
| 19 |
+
"auto_cast": false,
|
| 20 |
+
"loss_scale": 0,
|
| 21 |
+
"initial_scale_power": 32,
|
| 22 |
+
"loss_scale_window": 1000,
|
| 23 |
+
"hysteresis": 2,
|
| 24 |
+
"min_loss_scale": 1
|
| 25 |
+
},
|
| 26 |
+
"optimizer": {
|
| 27 |
+
"type": "AdamW",
|
| 28 |
+
"params": {
|
| 29 |
+
"lr": "auto",
|
| 30 |
+
"betas": "auto",
|
| 31 |
+
"eps": "auto",
|
| 32 |
+
"weight_decay": "auto"
|
| 33 |
+
}
|
| 34 |
+
},
|
| 35 |
+
"gradient_accumulation_steps": "auto",
|
| 36 |
+
"train_batch_size": "auto",
|
| 37 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 38 |
+
"wall_clock_breakdown": false
|
| 39 |
+
}
|
examples/mistral/mixtral.yml
CHANGED
|
@@ -14,6 +14,15 @@ dataset_prepared_path: last_run_prepared
|
|
| 14 |
val_set_size: 0.0
|
| 15 |
output_dir: ./qlora-out
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
adapter: qlora
|
| 18 |
lora_model_dir:
|
| 19 |
|
|
|
|
| 14 |
val_set_size: 0.0
|
| 15 |
output_dir: ./qlora-out
|
| 16 |
|
| 17 |
+
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
| 18 |
+
unfrozen_parameters:
|
| 19 |
+
# - lm_head.*
|
| 20 |
+
# - model.embed_tokens.*
|
| 21 |
+
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
|
| 22 |
+
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
|
| 23 |
+
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
| 24 |
+
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
| 25 |
+
|
| 26 |
adapter: qlora
|
| 27 |
lora_model_dir:
|
| 28 |
|
src/axolotl/cli/train.py
CHANGED
|
@@ -22,8 +22,8 @@ LOG = logging.getLogger("axolotl.cli.train")
|
|
| 22 |
|
| 23 |
def do_cli(config: Path = Path("examples/"), **kwargs):
|
| 24 |
# pylint: disable=duplicate-code
|
| 25 |
-
print_axolotl_text_art()
|
| 26 |
parsed_cfg = load_cfg(config, **kwargs)
|
|
|
|
| 27 |
check_accelerate_default_config()
|
| 28 |
check_user_token()
|
| 29 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
|
|
|
| 22 |
|
| 23 |
def do_cli(config: Path = Path("examples/"), **kwargs):
|
| 24 |
# pylint: disable=duplicate-code
|
|
|
|
| 25 |
parsed_cfg = load_cfg(config, **kwargs)
|
| 26 |
+
print_axolotl_text_art()
|
| 27 |
check_accelerate_default_config()
|
| 28 |
check_user_token()
|
| 29 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
src/axolotl/train.py
CHANGED
|
@@ -18,6 +18,7 @@ from axolotl.common.cli import TrainerCliArgs
|
|
| 18 |
from axolotl.logging_config import configure_logging
|
| 19 |
from axolotl.monkeypatch import neft_embeddings
|
| 20 |
from axolotl.utils.dict import DictDefault
|
|
|
|
| 21 |
from axolotl.utils.models import load_model, load_tokenizer
|
| 22 |
from axolotl.utils.trainer import setup_trainer
|
| 23 |
|
|
@@ -78,6 +79,9 @@ def train(
|
|
| 78 |
)
|
| 79 |
resume_from_checkpoint = cfg.resume_from_checkpoint
|
| 80 |
|
|
|
|
|
|
|
|
|
|
| 81 |
trainer = setup_trainer(
|
| 82 |
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
| 83 |
)
|
|
|
|
| 18 |
from axolotl.logging_config import configure_logging
|
| 19 |
from axolotl.monkeypatch import neft_embeddings
|
| 20 |
from axolotl.utils.dict import DictDefault
|
| 21 |
+
from axolotl.utils.freeze import freeze_parameters_except
|
| 22 |
from axolotl.utils.models import load_model, load_tokenizer
|
| 23 |
from axolotl.utils.trainer import setup_trainer
|
| 24 |
|
|
|
|
| 79 |
)
|
| 80 |
resume_from_checkpoint = cfg.resume_from_checkpoint
|
| 81 |
|
| 82 |
+
if cfg.unfrozen_parameters:
|
| 83 |
+
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
| 84 |
+
|
| 85 |
trainer = setup_trainer(
|
| 86 |
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
| 87 |
)
|
src/axolotl/utils/freeze.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
module to freeze/unfreeze parameters by name
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
from axolotl.utils.distributed import is_main_process
|
| 8 |
+
|
| 9 |
+
LOG = logging.getLogger("axolotl.utils.freeze")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def freeze_parameters_except(model, regex_patterns):
|
| 13 |
+
"""
|
| 14 |
+
Freezes all layers of the given model except for the layers that match given regex patterns.
|
| 15 |
+
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
| 16 |
+
|
| 17 |
+
Parameters:
|
| 18 |
+
- model (nn.Module): The PyTorch model to be modified.
|
| 19 |
+
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
None; the model is modified in place.
|
| 23 |
+
"""
|
| 24 |
+
# Escape periods and compile the regex patterns
|
| 25 |
+
compiled_patterns = [
|
| 26 |
+
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
# First, freeze all parameters in the model
|
| 30 |
+
for param in model.parameters():
|
| 31 |
+
param.requires_grad = False
|
| 32 |
+
|
| 33 |
+
# Unfreeze layers that match the regex patterns
|
| 34 |
+
for name, param in model.named_parameters():
|
| 35 |
+
if any(pattern.match(name) for pattern in compiled_patterns):
|
| 36 |
+
if is_main_process():
|
| 37 |
+
LOG.debug(f"unfreezing {name}")
|
| 38 |
+
param.requires_grad = True
|
src/axolotl/utils/models.py
CHANGED
|
@@ -21,6 +21,7 @@ from transformers import ( # noqa: F401
|
|
| 21 |
PreTrainedModel,
|
| 22 |
PreTrainedTokenizerBase,
|
| 23 |
)
|
|
|
|
| 24 |
|
| 25 |
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
| 26 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
|
@@ -285,6 +286,9 @@ def load_model(
|
|
| 285 |
model_kwargs["max_memory"] = cfg.max_memory
|
| 286 |
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
| 287 |
|
|
|
|
|
|
|
|
|
|
| 288 |
if cfg.model_revision:
|
| 289 |
model_kwargs["revision"] = cfg.model_revision
|
| 290 |
if cfg.gptq:
|
|
|
|
| 21 |
PreTrainedModel,
|
| 22 |
PreTrainedTokenizerBase,
|
| 23 |
)
|
| 24 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
| 25 |
|
| 26 |
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
| 27 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
|
|
|
| 286 |
model_kwargs["max_memory"] = cfg.max_memory
|
| 287 |
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
| 288 |
|
| 289 |
+
if is_deepspeed_zero3_enabled():
|
| 290 |
+
del model_kwargs["device_map"]
|
| 291 |
+
|
| 292 |
if cfg.model_revision:
|
| 293 |
model_kwargs["revision"] = cfg.model_revision
|
| 294 |
if cfg.gptq:
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -276,6 +276,7 @@ def prepare_optim_env(cfg):
|
|
| 276 |
setup_fsdp_envs(cfg)
|
| 277 |
elif cfg.deepspeed:
|
| 278 |
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
|
|
|
| 279 |
|
| 280 |
|
| 281 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
|
|
|
| 276 |
setup_fsdp_envs(cfg)
|
| 277 |
elif cfg.deepspeed:
|
| 278 |
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
| 279 |
+
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
| 280 |
|
| 281 |
|
| 282 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|