Add seq2seq eval benchmark callback (#1274)
* Add CausalLMBenchEvalCallback for measuring seq2seq performance
* Fix code for pre-commit
* Fix typing and improve logging
* eval_sample_packing must be false with CausalLMBenchEvalCallback
- README.md +2 -1
- examples/llama-2/loftq.yml +1 -1
- examples/llama-2/lora.yml +1 -1
- examples/mamba/config.yml +1 -1
- examples/mistral/Mistral-7b-example/config.yml +1 -1
- examples/mistral/config.yml +1 -1
- examples/mistral/mixtral.yml +1 -1
- examples/mistral/qlora.yml +1 -1
- examples/qwen/lora.yml +1 -1
- examples/qwen/qlora.yml +1 -1
- examples/yi-34B-chat/qlora.yml +1 -1
- requirements.txt +1 -1
- src/axolotl/core/trainer_builder.py +11 -0
- src/axolotl/utils/callbacks.py +182 -1
- src/axolotl/utils/config.py +22 -1
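
For reference, a minimal sketch of how the new options fit together in a user config (hypothetical values; model, dataset, and adapter settings omitted). Note that `eval_sample_packing` must be disabled when `do_causal_lm_eval` is enabled:

```yaml
# Hypothetical excerpt of an axolotl config enabling the new causal LM eval
do_causal_lm_eval: true
eval_causal_lm_metrics:       # defaults to ["sacrebleu", "comet", "ter", "chrf"]
  - sacrebleu
  - chrf
eval_max_new_tokens: 128      # tokens generated per prediction during eval
eval_sample_packing: false    # required when do_causal_lm_eval is enabled
evals_per_epoch: 4
```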
README.md
CHANGED
@@ -784,7 +784,8 @@ save_total_limit: # Checkpoints saved at a time
 max_steps:

 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
-eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
+eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
+eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf"]

 loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
 loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
examples/llama-2/loftq.yml
CHANGED
@@ -60,7 +60,7 @@ s2_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/llama-2/lora.yml
CHANGED
@@ -57,7 +57,7 @@ s2_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/mamba/config.yml
CHANGED
@@ -49,7 +49,7 @@ flash_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/mistral/Mistral-7b-example/config.yml
CHANGED
@@ -61,7 +61,7 @@ flash_attention: true
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 #default deepspeed, can use more aggresive if needed like zero2, zero3
examples/mistral/config.yml
CHANGED
@@ -49,7 +49,7 @@ flash_attention: true
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/mistral/mixtral.yml
CHANGED
@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed: deepspeed_configs/zero2.json
examples/mistral/qlora.yml
CHANGED
@@ -68,7 +68,7 @@ loss_watchdog_patience: 3
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/qwen/lora.yml
CHANGED
@@ -58,7 +58,7 @@ flash_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/qwen/qlora.yml
CHANGED
@@ -58,7 +58,7 @@ flash_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/yi-34B-chat/qlora.yml
CHANGED
@@ -29,7 +29,7 @@ num_epochs: 1
 val_set_size: 0.1
 evals_per_epoch: 5
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 eval_sample_packing: false
 eval_batch_size: 1

requirements.txt
CHANGED
@@ -23,7 +23,7 @@ numba
 numpy>=1.24.4
 mlflow
 # qlora things
-evaluate==0.4.0
+evaluate==0.4.1
 scipy
 scikit-learn==1.2.2
 pynvml
src/axolotl/core/trainer_builder.py
CHANGED
@@ -38,6 +38,7 @@ from axolotl.utils.callbacks import (
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
+    causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
 from axolotl.utils.collators import (
@@ -148,6 +149,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     do_bench_eval: Optional[bool] = field(
         default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
     )
+    do_causal_lm_eval: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
+    )
     max_bench_samples: Optional[int] = field(
         default=None,
         metadata={
@@ -664,6 +668,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

         if self.cfg.do_bench_eval:
             callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
+        if self.cfg.do_causal_lm_eval:
+            CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
+                trainer, self.tokenizer
+            )
+            callbacks.append(CausalLMBenchEvalCallback(self.cfg))

         if self.cfg.early_stopping_patience:
             early_stop_cb = EarlyStoppingCallback(
@@ -812,6 +821,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
         if self.cfg.bench_dataset:
             training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
+        if self.cfg.do_causal_lm_eval:
+            training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
         if self.cfg.metric_for_best_model:
             training_arguments_kwargs[
                 "metric_for_best_model"
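
The wiring above follows the same factory pattern as bench_eval_callback_factory: a factory closes over the trainer and tokenizer and returns a TrainerCallback subclass, which the builder then instantiates with the config. A minimal, generic sketch of that pattern (illustrative names only, not the axolotl implementation):

```python
# Generic sketch of the factory/closure pattern used above (illustrative only):
# the factory captures trainer and tokenizer so the returned TrainerCallback
# subclass can use them inside its hooks.
from transformers import TrainerCallback


def example_eval_callback_factory(trainer, tokenizer):
    class ExampleEvalCallback(TrainerCallback):
        """Runs extra evaluation logic with the captured trainer/tokenizer."""

        def __init__(self, cfg):
            self.cfg = cfg

        def on_evaluate(self, args, state, control, **kwargs):
            # trainer and tokenizer are visible here via the enclosing scope,
            # so the callback itself only needs the config to construct.
            _ = (trainer, tokenizer, self.cfg)
            return control

    return ExampleEvalCallback


# Usage mirrors the builder code above:
#   ExampleEvalCallback = example_eval_callback_factory(trainer, tokenizer)
#   callbacks.append(ExampleEvalCallback(cfg))
```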
src/axolotl/utils/callbacks.py
CHANGED
@@ -361,6 +361,187 @@ def bench_eval_callback_factory(trainer, tokenizer):
     return BenchEvalCallback


+def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
+    class CausalLMBenchEvalCallback(TrainerCallback):
+        """Callback to log prediction values during each evaluation"""
+
+        def __init__(self, cfg):
+            self.cfg = cfg
+            self.logged = False
+            self.metrics = self.__maybe_load_metrics()
+
+        def __maybe_load_metrics(self):
+            metrics = {}
+            for metric in self.cfg.eval_causal_lm_metrics:
+                try:
+                    metrics[metric] = evaluate.load(metric)
+                except Exception as exc:  # pylint: disable=broad-exception-caught
+                    LOG.warning(f"{metric}: {exc.args}")
+            return metrics
+
+        def on_evaluate(
+            self,
+            args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
+            state: TrainerState,
+            control: TrainerControl,
+            train_dataloader,  # pylint: disable=unused-argument
+            eval_dataloader,
+            **kwargs,  # pylint: disable=unused-argument
+        ):
+            trainer.model.eval()
+            device = torch.device(self.cfg.device)
+
+            # pylint: disable=duplicate-code
+            generation_config = GenerationConfig(
+                max_new_tokens=self.cfg.eval_max_new_tokens,
+                bos_token_id=tokenizer.bos_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                pad_token_id=tokenizer.pad_token_id,
+                do_sample=False,
+                use_cache=True,
+                return_dict_in_generate=True,
+                output_attentions=False,
+                output_hidden_states=False,
+                output_scores=False,
+            )
+
+            def find_ranges(lst):
+                ranges = []
+                start = 0
+                for i in range(1, len(lst)):
+                    if lst[i] == 0:
+                        ranges.append((start, i - 1))
+                        start = i
+                end = len(lst) - 1
+                ranges.append((start, end))
+                return ranges
+
+            def compute(metric: evaluate.Metric, **kwargs):
+                # safely compute a metric and return the score if the format is correct
+                metric_score = None
+                try:
+                    metric_score = metric.compute(**kwargs)
+                    return (
+                        metric_score["score"]
+                        if "score" in metric_score
+                        else metric_score["mean_score"]
+                    )
+                except Exception:  # pylint: disable=broad-exception-caught
+                    LOG.debug(
+                        f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
+                    )
+                return metric_score
+
+            def evaluate_preds(sources, predictions, references):
+                scores = {}
+
+                for metric_name, metric in self.metrics.items():
+                    score = compute(
+                        metric,
+                        references=references,
+                        predictions=predictions,
+                        sources=sources,
+                    )
+                    score = score or compute(
+                        metric,
+                        references=[[r] for r in references],
+                        predictions=predictions,
+                    )
+                    scores[metric_name] = score
+                return scores
+
+            def predict_with_generate():
+                eval_src, eval_pred, eval_ref = [], [], []
+
+                for batch in tqdm(eval_dataloader):
+                    batch_labels = batch["labels"].to(device)
+                    batch_input_ids = batch["input_ids"].to(device)
+
+                    if "position_ids" in batch:
+                        batch_pos_ids = batch["position_ids"].tolist()
+                    else:
+                        batch_pos_ids = [None] * len(batch["input_ids"])
+
+                    prompt_token_ids_list = []
+                    completion_token_ids_list = []
+
+                    for input_ids_all, labels_all, pos_ids in zip(
+                        batch_input_ids,
+                        batch_labels,
+                        batch_pos_ids,
+                    ):
+                        if pos_ids is None:
+                            pos_ranges = [(0, len(input_ids_all) - 1)]
+                        else:
+                            pos_ranges = find_ranges(pos_ids)
+
+                        for pos_range in pos_ranges:
+                            start, end = pos_range
+                            if start == end:
+                                continue
+
+                            input_ids = input_ids_all[start : end + 1]
+                            labels = labels_all[start : end + 1]
+
+                            tokens_without_loss = labels == IGNORE_INDEX
+                            tokens_with_loss = labels != IGNORE_INDEX
+                            tokens_exclude_padding = input_ids != tokenizer.pad_token_id
+                            prompt_token_includes = (
+                                tokens_without_loss & tokens_exclude_padding
+                            )
+
+                            prompt_token_ids = input_ids[prompt_token_includes]
+                            prompt_token_ids_list.append(prompt_token_ids)
+
+                            completion_token_ids = input_ids[tokens_with_loss]
+                            completion_token_ids_list.append(completion_token_ids)
+
+                    prompt_texts = tokenizer.batch_decode(
+                        prompt_token_ids_list, skip_special_tokens=True
+                    )
+                    completion_texts = tokenizer.batch_decode(
+                        completion_token_ids_list, skip_special_tokens=True
+                    )
+
+                    with torch.no_grad():
+                        prompt_encoding = tokenizer(
+                            prompt_texts, padding=True, return_tensors="pt"
+                        ).to(self.cfg.device)
+                        predictions = trainer.model.generate(
+                            **prompt_encoding, generation_config=generation_config
+                        )
+
+                    prediction_all_tokens = predictions["sequences"].cpu().tolist()
+                    prediction_without_prompt_tokens_list = []
+                    for prompt_token_ids, prediction_tokens in zip(
+                        prompt_token_ids_list, prediction_all_tokens
+                    ):
+                        prediction_without_prompt_tokens = prediction_tokens[
+                            len(prompt_token_ids) :
+                        ]
+                        prediction_without_prompt_tokens_list.append(
+                            prediction_without_prompt_tokens
+                        )
+
+                    predicted_texts = tokenizer.batch_decode(
+                        prediction_without_prompt_tokens_list, skip_special_tokens=True
+                    )
+
+                    eval_src.extend(prompt_texts)
+                    eval_pred.extend(predicted_texts)
+                    eval_ref.extend(completion_texts)
+
+                return eval_src, eval_pred, eval_ref
+
+            if is_main_process():
+                eval_preds = predict_with_generate()
+                trainer.log(evaluate_preds(*eval_preds))
+
+            return control
+
+    return CausalLMBenchEvalCallback
+
+
 def log_prediction_callback_factory(trainer: Trainer, tokenizer):
     class LogPredictionCallback(TrainerCallback):
         """Callback to log prediction values during each evaluation"""
@@ -388,7 +569,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):

             # pylint: disable=duplicate-code
             generation_config = GenerationConfig(
-                max_new_tokens=self.cfg.eval_table_max_new_tokens,
+                max_new_tokens=self.cfg.eval_max_new_tokens,
                 bos_token_id=tokenizer.bos_token_id,
                 eos_token_id=tokenizer.eos_token_id,
                 pad_token_id=tokenizer.pad_token_id,
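
For background on the compute() fallback above: in HF evaluate, sacrebleu, ter, and chrf typically report their result under a "score" key and expect references as a list of reference lists, while comet additionally takes sources and reports a "mean_score". The callback therefore tries the sources/flat-references call first and retries with wrapped references. A small sketch of the two call shapes, assuming the evaluate package and the sacrebleu backend are installed (comet is omitted since it downloads a large model):

```python
# Sketch of the metric call shapes the callback's compute() helper handles.
# Assumes `pip install evaluate sacrebleu` (chrf is also backed by sacrebleu).
import evaluate

predictions = ["The cat sat on the mat."]
references = ["The cat is sitting on the mat."]

# sacrebleu/ter/chrf style: references wrapped as a list of reference lists,
# result exposed under the "score" key.
sacrebleu = evaluate.load("sacrebleu")
result = sacrebleu.compute(
    predictions=predictions, references=[[ref] for ref in references]
)
print("sacrebleu:", result["score"])

chrf = evaluate.load("chrf")
result = chrf.compute(
    predictions=predictions, references=[[ref] for ref in references]
)
print("chrf:", result["score"])

# comet (not run here) additionally takes sources= and reports "mean_score",
# which is why compute() first tries the sources/flat-references call and then
# falls back to wrapped references without sources.
```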
src/axolotl/utils/config.py
CHANGED
@@ -56,7 +56,13 @@ def normalize_config(cfg):
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     cfg.eval_table_size = cfg.eval_table_size or 0
-    cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
+    cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
+    cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
+        "sacrebleu",
+        "comet",
+        "ter",
+        "chrf",
+    ]
     choose_device(cfg)
     cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.ddp:
@@ -550,6 +556,21 @@ def validate_config(cfg):
     if cfg.fsdp and "bnb" in cfg.optimizer:
         raise ValueError(f"FSDP not compatible with {cfg.optimizer}")

+    if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
+        raise ValueError(
+            "do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
+        )
+
+    if cfg.eval_causal_lm_metrics:
+        supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
+        if not isinstance(cfg.eval_causal_lm_metrics, list):
+            raise ValueError("eval_causal_lm_metrics must be a list")
+        # only ["sacrebleu", "comet", "ter", "chrf"] supported
+        if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
+            raise ValueError(
+                f"eval_causal_lm_metrics must be one of {supported_metrics}"
+            )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25