Feat(config): add max steps (#387)

Files changed:
- scripts/finetune.py (+7 -1)
- src/axolotl/utils/trainer.py (+1 -1)

scripts/finetune.py
@@ -209,7 +209,13 @@ def train(
         cfg, train_dataset, eval_dataset
     )
     barrier()
-    total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+    if cfg.max_steps:
+        total_num_steps = min(
+            calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
+        )
+        LOG.info(f"Maximum number of steps set at {total_num_steps}")
+    else:
+        total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
 
     if cfg.debug or "debug" in kwargs:
         LOG.info("check_dataset_labels...")
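With this change, training length is capped by a single config key: when `max_steps` is set (an axolotl YAML key, e.g. `max_steps: 100`), the calculated step count is clamped to it; otherwise the full calculated count is used. A standalone sketch of that selection logic, with a hypothetical helper name in place of axolotl's `cfg` and `calculate_total_num_steps`:

from typing import Optional

# Standalone sketch of the capping logic above; resolve_total_steps and the
# example numbers are hypothetical, not axolotl's actual API.
def resolve_total_steps(calculated_steps: int, max_steps: Optional[int]) -> int:
    if max_steps:
        # Never schedule more steps than the configured cap.
        return min(calculated_steps, max_steps)
    # No cap configured: use the full calculated count.
    return calculated_steps

assert resolve_total_steps(1000, 100) == 100    # capped by max_steps
assert resolve_total_steps(1000, None) == 1000  # uncapped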
src/axolotl/utils/trainer.py

@@ -461,7 -461,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         evaluation_strategy = "steps"
 
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
-
+        max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
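Passing -1 here is deliberate: in Hugging Face `transformers`, `TrainingArguments.max_steps` defaults to -1, and only a positive value overrides `num_train_epochs`, so the epoch count keeps driving training length when no cap is configured. A minimal sketch of the same pattern against plain `TrainingArguments` (hypothetical `output_dir` and values; axolotl actually uses its `AxolotlTrainingArguments` subclass):

from transformers import TrainingArguments

max_steps_cfg = 100  # hypothetical config value; 0/None would mean "not set"

args = TrainingArguments(
    output_dir="./out",  # hypothetical path
    num_train_epochs=3,
    # A positive max_steps overrides num_train_epochs; -1 (the default)
    # leaves training length to the epoch count.
    max_steps=max_steps_cfg if max_steps_cfg else -1,
)
print(args.max_steps)  # 100 here; would be -1 with no cap configured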