optionally configure sample packing for evals (#589)
src/axolotl/utils/trainer.py (+11 -2)
@@ -117,6 +117,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "Use sample packing for efficient training."},
     )
+    eval_sample_packing: Optional[bool] = field(
+        default=None,
+        metadata={"help": "Use sample packing for efficient evals."},
+    )
     sample_packing_efficiency: float = field(
         default=1.0,
         metadata={"help": "Sample packing efficiency for calculating batch length."},

@@ -212,7 +216,11 @@ class AxolotlTrainer(Trainer):
     def _get_eval_sampler(
         self, eval_dataset: Dataset
     ) -> Optional[torch.utils.data.Sampler]:
-        if self.args.world_size > 1 and self.args.sample_packing:
+        if (
+            self.args.world_size > 1
+            and self.args.sample_packing
+            and self.args.eval_sample_packing is not False
+        ):
             return SequentialDistributedSampler(
                 eval_dataset,
                 num_replicas=self.args.world_size,

@@ -241,7 +249,7 @@ class AxolotlTrainer(Trainer):
     def get_eval_dataloader(
         self, eval_dataset: Optional[Dataset] = None
     ) -> Union[DataLoader, MultipackDistributedDataloader]:
-        if self.args.sample_packing:
+        if self.args.sample_packing and self.args.eval_sample_packing is not False:
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
             )

@@ -659,6 +667,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         sample_packing=cfg.sample_packing if cfg.sample_packing else False,
+        eval_sample_packing=cfg.eval_sample_packing,
         sample_packing_seq_len_multiplier=cfg.micro_batch_size,
         relora_steps=cfg.relora_steps,
         relora_warmup_steps=cfg.relora_warmup_steps,
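For reference, a minimal sketch (not part of the diff) of how the new tri-state option resolves. `eval_sample_packing` is `Optional[bool]` with a default of `None`, so evals inherit packing whenever `sample_packing` is enabled, and only an explicit `False` turns it off for evals. The helper name below is hypothetical.

from typing import Optional

def packed_eval_enabled(sample_packing: bool,
                        eval_sample_packing: Optional[bool]) -> bool:
    # Mirrors the guard used in get_eval_dataloader / _get_eval_sampler:
    # packing applies to evals unless eval_sample_packing is explicitly False.
    return sample_packing and eval_sample_packing is not False

assert packed_eval_enabled(True, None) is True    # default: evals follow sample_packing
assert packed_eval_enabled(True, True) is True    # explicit opt-in
assert packed_eval_enabled(True, False) is False  # opt out of packing for evals only
assert packed_eval_enabled(False, None) is False  # no packing anywhere

Since setup_trainer now forwards `cfg.eval_sample_packing` into the training arguments, disabling packing just for evals presumably comes down to setting `eval_sample_packing: false` in the config while leaving `sample_packing: true` for training.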