Feat: Add warmup_ratio (#893)
Browse files
* Feat: Add warmup_ratio
* fix: update readme with more details on conflict
- README.md +2 -1
- src/axolotl/core/trainer_builder.py +8 -5
- src/axolotl/utils/config.py +3 -0
- tests/test_validation.py +30 -0
README.md
CHANGED
|
@@ -675,7 +675,8 @@ gradient_accumulation_steps: 1
|
|
| 675 |
micro_batch_size: 2
|
| 676 |
eval_batch_size:
|
| 677 |
num_epochs: 4
|
| 678 |
-
warmup_steps: 100
|
|
|
|
| 679 |
learning_rate: 0.00003
|
| 680 |
lr_quadratic_warmup:
|
| 681 |
logging_steps:
|
|
|
|
| 675 |
micro_batch_size: 2
|
| 676 |
eval_batch_size:
|
| 677 |
num_epochs: 4
|
| 678 |
+
warmup_steps: 100 # cannot use with warmup_ratio
|
| 679 |
+
warmup_ratio: 0.05 # cannot use with warmup_steps
|
| 680 |
learning_rate: 0.00003
|
| 681 |
lr_quadratic_warmup:
|
| 682 |
logging_steps:
|
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -461,11 +461,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
| 461 |
return AxolotlTrainer
|
| 462 |
|
| 463 |
def build(self, total_num_steps):
|
| 464 |
-
warmup_steps =
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
|
|
|
|
|
|
|
|
|
| 469 |
logging_steps = (
|
| 470 |
self.cfg.logging_steps
|
| 471 |
if self.cfg.logging_steps is not None
|
|
|
|
| 461 |
return AxolotlTrainer
|
| 462 |
|
| 463 |
def build(self, total_num_steps):
|
| 464 |
+
warmup_steps = None
|
| 465 |
+
if self.cfg.warmup_steps is not None:
|
| 466 |
+
warmup_steps = self.cfg.warmup_steps
|
| 467 |
+
elif self.cfg.warmup_ratio is not None:
|
| 468 |
+
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
| 469 |
+
else:
|
| 470 |
+
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
| 471 |
+
|
| 472 |
logging_steps = (
|
| 473 |
self.cfg.logging_steps
|
| 474 |
if self.cfg.logging_steps is not None
|
src/axolotl/utils/config.py
CHANGED
|
@@ -372,6 +372,9 @@ def validate_config(cfg):
|
|
| 372 |
if cfg.rope_scaling:
|
| 373 |
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
| 374 |
|
|
|
|
|
|
|
|
|
|
| 375 |
# TODO
|
| 376 |
# MPT 7b
|
| 377 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
|
| 372 |
if cfg.rope_scaling:
|
| 373 |
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
| 374 |
|
| 375 |
+
if cfg.warmup_steps and cfg.warmup_ratio:
|
| 376 |
+
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
| 377 |
+
|
| 378 |
# TODO
|
| 379 |
# MPT 7b
|
| 380 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
tests/test_validation.py
CHANGED
|
@@ -649,3 +649,33 @@ class ValidationTest(unittest.TestCase):
|
|
| 649 |
)
|
| 650 |
|
| 651 |
validate_config(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 649 |
)
|
| 650 |
|
| 651 |
validate_config(cfg)
|
| 652 |
+
|
| 653 |
+
def test_warmup_step_no_conflict(self):
|
| 654 |
+
cfg = DictDefault(
|
| 655 |
+
{
|
| 656 |
+
"warmup_steps": 10,
|
| 657 |
+
"warmup_ratio": 0.1,
|
| 658 |
+
}
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
with pytest.raises(
|
| 662 |
+
ValueError,
|
| 663 |
+
match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
|
| 664 |
+
):
|
| 665 |
+
validate_config(cfg)
|
| 666 |
+
|
| 667 |
+
cfg = DictDefault(
|
| 668 |
+
{
|
| 669 |
+
"warmup_steps": 10,
|
| 670 |
+
}
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
validate_config(cfg)
|
| 674 |
+
|
| 675 |
+
cfg = DictDefault(
|
| 676 |
+
{
|
| 677 |
+
"warmup_ratio": 0.1,
|
| 678 |
+
}
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
validate_config(cfg)
|