fix fsdp training args
src/axolotl/utils/trainer.py (CHANGED)

@@ -34,6 +34,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio)
         else:
             training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+    if cfg.fsdp:
+        training_arguments_kwargs["fsdp"] = cfg.fsdp.split(" ")
+    if cfg.fsdp_transformer_layer_cls_to_wrap:
+        training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = cfg.fsdp_transformer_layer_cls_to_wrap
 
 
     # deepspeed
@@ -64,8 +68,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         optim=cfg.optimizer if cfg.optimizer != "adam8bit" else cfg.optimizer,
         lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
         weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
-        fsdp=cfg.fsdp.split(" ") if cfg.fsdp else None,
-        fsdp_transformer_layer_cls_to_wrap=cfg.fsdp_transformer_layer_cls_to_wrap if cfg.fsdp_transformer_layer_cls_to_wrap else None,
         **training_arguments_kwargs,
     )
 
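The net effect is that the FSDP options are forwarded to TrainingArguments only when they are set in the config, instead of always being passed as keyword arguments (with explicit None when unset). A minimal sketch of the resulting pattern, not part of the commit: the config values ("full_shard auto_wrap", "LlamaDecoderLayer") are hypothetical stand-ins for cfg.fsdp and cfg.fsdp_transformer_layer_cls_to_wrap, and it assumes a transformers release that still accepts fsdp_transformer_layer_cls_to_wrap directly (newer releases fold this into fsdp_config).

import transformers

# Hypothetical config values, not taken from the commit.
cfg_fsdp = "full_shard auto_wrap"
cfg_fsdp_transformer_layer_cls_to_wrap = "LlamaDecoderLayer"

training_arguments_kwargs = {}

# Only add the FSDP keys when they are configured, so a run without FSDP
# leaves TrainingArguments at its own defaults instead of receiving None.
if cfg_fsdp:
    # "full_shard auto_wrap" -> ["full_shard", "auto_wrap"]
    training_arguments_kwargs["fsdp"] = cfg_fsdp.split(" ")
if cfg_fsdp_transformer_layer_cls_to_wrap:
    training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = (
        cfg_fsdp_transformer_layer_cls_to_wrap
    )

training_args = transformers.TrainingArguments(
    output_dir="./out",
    **training_arguments_kwargs,
)

Building the kwargs dict conditionally and unpacking it with ** keeps the TrainingArguments call free of FSDP-specific arguments whenever FSDP is not configured.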