allow the optimizer prune ratio for ReLoRA to be configurable (#1287)
* allow the optimizer prune ratio for relora to be configurable
* update docs for relora
* prevent circular imports
- README.md +2 -0
- src/axolotl/core/trainer_builder.py +18 -3
- src/axolotl/monkeypatch/relora.py +3 -1
README.md
@@ -734,6 +734,8 @@ peft:
 # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
 relora_steps: # Number of steps per ReLoRA restart
 relora_warmup_steps: # Number of per-restart warmup steps
+relora_anneal_steps: # Number of anneal steps for each relora cycle
+relora_prune_ratio: # threshold for optimizer magnitude when pruning
 relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
 
 # wandb configuration if you're using it
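For intuition, the three step counts above partition each ReLoRA cycle: a warmup right after each restart, a stable middle, and an anneal leading into the next restart. The sketch below is purely illustrative of that relationship; the function name and the linear ramps are assumptions, not axolotl's actual scheduler.

```python
def relora_lr_scale(step: int, relora_steps: int, warmup_steps: int, anneal_steps: int) -> float:
    """Illustrative only: an LR multiplier within one ReLoRA cycle, with a
    linear warmup after each restart, a flat middle, and a linear anneal
    before the next restart."""
    pos = step % relora_steps  # position inside the current cycle
    if pos < warmup_steps:  # ramp back up after the merge/reset
        return pos / max(warmup_steps, 1)
    if pos >= relora_steps - anneal_steps:  # wind down toward the restart
        return (relora_steps - pos) / max(anneal_steps, 1)
    return 1.0
```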
src/axolotl/core/trainer_builder.py
@@ -131,6 +131,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
     )
+    relora_prune_ratio: Optional[float] = field(
+        default=0.9,
+        metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
@@ -900,9 +904,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "sample_packing_seq_len_multiplier"
             ] = self.cfg.micro_batch_size
-        training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
-        training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
-
+        if self.cfg.relora_steps:
+            training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
+            training_arguments_kwargs[
+                "relora_warmup_steps"
+            ] = self.cfg.relora_warmup_steps
+            if self.cfg.relora_anneal_steps:
+                training_arguments_kwargs[
+                    "relora_anneal_steps"
+                ] = self.cfg.relora_anneal_steps
+            if self.cfg.relora_prune_ratio:
+                training_arguments_kwargs[
+                    "relora_prune_ratio"
+                ] = self.cfg.relora_prune_ratio
+
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
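One subtlety in the guards above: `if self.cfg.relora_prune_ratio:` relies on Python truthiness, so an explicit `0.0` in the config behaves like an unset value and the dataclass default of 0.9 applies. A minimal demonstration:

```python
cfg_prune_ratio = 0.0  # user explicitly requests no pruning
kwargs = {}
if cfg_prune_ratio:  # 0.0 is falsy, so the key is never set ...
    kwargs["relora_prune_ratio"] = cfg_prune_ratio
print(kwargs)  # {} ... and the field default of 0.9 applies instead
```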
src/axolotl/monkeypatch/relora.py
@@ -46,8 +46,9 @@ def reset_optimizer(
     *,
     reset_params: list[str],  # where str is the key to a torch.nn.Parameter
     optimizer_state_keys: list[str],
+    prune_ratio: float = 0.9,
 ):
-    pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
+    pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
     n_zeros = 0
     n_total = 0
 
@@ -159,6 +160,7 @@ class ReLoRACallback(TrainerCallback):
             optimizer,
             reset_params=lora_params,
             optimizer_state_keys=optimizer_state_keys,
+            prune_ratio=args.relora_prune_ratio,
         )
 
         if self.quantized:
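The hunks above reference `magnitude_pruning_`, whose body lies outside the diff. As a rough sketch of what an in-place magnitude-pruning helper of this shape might look like (an assumption for illustration, not necessarily the file's actual implementation):

```python
import torch

def magnitude_pruning_(tensor: torch.Tensor, prune_ratio: float) -> None:
    """Sketch only: zero, in place, roughly the `prune_ratio` fraction of
    entries with the smallest absolute magnitude."""
    # magnitude threshold below which entries are dropped
    threshold = torch.quantile(tensor.abs().flatten().float(), prune_ratio)
    tensor.mul_((tensor.abs() > threshold).to(tensor.dtype))

state = torch.randn(64, 64)
magnitude_pruning_(state, prune_ratio=0.9)
print((state == 0).float().mean())  # roughly 0.9 of entries are zeroed
```

Because `reset_optimizer` takes `prune_ratio` as a keyword argument defaulting to the previously hardcoded 0.9, callers that do not pass it keep the old behavior.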