address PR feedback
- examples/pythia-12b/README.md +1 -1
- examples/pythia-12b/config.yml +2 -2
- scripts/finetune.py +4 -1
- src/axolotl/utils/data.py +2 -2
- src/axolotl/utils/trainer.py +0 -2
examples/pythia-12b/README.md

@@ -1,4 +1,4 @@
-#
+# Pythia 12B
 
 - Single-GPU A100 only (?)
 
examples/pythia-12b/config.yml

@@ -22,7 +22,7 @@ lora_dropout: 0.0
 lora_target_modules:
 lora_target_linear: true
 lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
-wandb_project:
+wandb_project:
 wandb_watch:
 wandb_run_id:
 wandb_log_model:

@@ -45,5 +45,5 @@ resume_from_checkpoint:
 local_rank:
 gradient_checkpointing: true
 fsdp:
-
+fsdp_config:
 collator_pad_to_longest: true
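The second hunk adds an explicit `fsdp_config:` key after `fsdp:`. As a quick sketch (not part of the PR) of how such blank keys behave once the YAML is parsed: empty values load as `None`, which is why leaving `fsdp:` and `fsdp_config:` unset leaves FSDP effectively disabled.

```python
# Illustration only: blank YAML keys parse to None when the config is loaded.
import yaml

raw = """
fsdp:
fsdp_config:
collator_pad_to_longest: true
"""

cfg = yaml.safe_load(raw)
print(cfg["fsdp"])                     # None -> no FSDP settings supplied
print(cfg["fsdp_config"])              # None -> no FSDP options supplied
print(cfg["collator_pad_to_longest"])  # True
```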
scripts/finetune.py

@@ -208,7 +208,10 @@ def train(
         )
     else:
         train_dataset = load_pretraining_dataset(
-            cfg.pretraining_dataset,
+            cfg.pretraining_dataset,
+            tokenizer,
+            max_tokens=cfg.sequence_len,
+            seed=cfg.seed,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
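The call site now forwards the tokenizer, the configured sequence length, and the seed into `load_pretraining_dataset`. A hedged sketch of how the resulting streaming dataset is consumed; the model name and dataset path below are placeholders, not values taken from this PR:

```python
# Sketch only: stand-in values, not the PR's actual cfg object.
from transformers import AutoTokenizer

from axolotl.utils.data import load_pretraining_dataset

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-12b-deduped")  # placeholder model

train_dataset = load_pretraining_dataset(
    "togethercomputer/RedPajama-Data-1T",  # placeholder for cfg.pretraining_dataset
    tokenizer,
    max_tokens=2048,  # cfg.sequence_len
    seed=42,          # cfg.seed
)

# Streaming datasets are IterableDatasets; with_format("torch") lets the
# Hugging Face Trainer iterate them directly (see the forum link above).
train_dataset = train_dataset.with_format("torch")
```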
src/axolotl/utils/data.py

@@ -505,10 +505,10 @@ def encode_pretraining(tokenizer, max_tokens, examples):
     return ret


-def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
+def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
     encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
     dataset = load_dataset(path, streaming=True, split="train")
-    dataset = dataset.shuffle(seed=
+    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
     # TODO dynamically figure out which columns/features to remove
     dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
     return dataset
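The new `seed` parameter is threaded into the streaming shuffle, and `buffer_size=10_000` bounds the reservoir it samples from: the stream fills a buffer of that many examples and draws from it, so the shuffle is approximate but reproducible for a fixed seed. A small self-contained illustration (not from the PR) of that behaviour:

```python
# Toy demo of shuffling a streaming dataset: same seed + buffer_size -> same order.
from datasets import Dataset

ds = Dataset.from_dict({"text": [f"doc {i}" for i in range(100)]}).to_iterable_dataset()

pass_a = [row["text"] for row in ds.shuffle(seed=42, buffer_size=10)]
pass_b = [row["text"] for row in ds.shuffle(seed=42, buffer_size=10)]

assert pass_a == pass_b  # reproducible given the same seed and buffer size
assert sorted(pass_a) == sorted(f"doc {i}" for i in range(100))  # nothing dropped
```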
src/axolotl/utils/trainer.py

@@ -1,7 +1,6 @@
 """Module containing the Trainer class and related functions"""
 
 import importlib
-import logging
 import math
 import os
 import sys

@@ -232,7 +231,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         callbacks.append(SavePeftModelCallback)
 
     if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
-        logging.info("Setting up SaveBetterTransformerModelCallback.")
         callbacks.append(SaveBetterTransformerModelCallback)
 
     data_collator_kwargs = {
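The dropped `logging.info` call was presumably the module's only remaining use of `logging`, so the import goes with it. For context, the callback it announced exists because a BetterTransformer-transformed model should be reverted to the stock transformers implementation before it is saved; a background sketch under that assumption (this code is not part of the diff):

```python
# Background sketch: reverting a BetterTransformer model before saving.
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
model = BetterTransformer.transform(model)  # swap in fused attention kernels

# ... training happens here ...

model = BetterTransformer.reverse(model)    # restore the original modules
model.save_pretrained("./checkpoint")
```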