Feat: Add save_safetensors
Changed files:
- README.md (+3, -0)
- src/axolotl/utils/trainer.py (+3, -0)
README.md
CHANGED
|
@@ -411,6 +411,9 @@ logging_steps:
 save_steps:
 eval_steps:
 
+# save model as safetensors (require safetensors package)
+save_safetensors:
+
 # whether to mask out or include the human's prompt from the training labels
 train_on_inputs: false
 # don't use this, leads to wonky training (according to someone on the internet)
src/axolotl/utils/trainer.py
CHANGED
|
@@ -182,6 +182,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
     training_arguments_kwargs["push_to_hub"] = True
 
+    if cfg.save_safetensors:
+        training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
+
     training_args = AxolotlTrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size