use context manager to run things on rank0 before others (#397)
Browse files- scripts/finetune.py +2 -9
- src/axolotl/utils/data.py +2 -12
- src/axolotl/utils/distributed.py +14 -0
scripts/finetune.py
CHANGED
|
@@ -21,7 +21,7 @@ from axolotl.logging_config import configure_logging
|
|
| 21 |
from axolotl.utils.config import normalize_config, validate_config
|
| 22 |
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
| 23 |
from axolotl.utils.dict import DictDefault
|
| 24 |
-
from axolotl.utils.distributed import barrier, is_main_process
|
| 25 |
from axolotl.utils.models import load_model, load_tokenizer
|
| 26 |
from axolotl.utils.tokenization import check_dataset_labels
|
| 27 |
from axolotl.utils.trainer import (
|
|
@@ -198,17 +198,10 @@ def train(
|
|
| 198 |
train_dataset = train_dataset.with_format("torch")
|
| 199 |
eval_dataset = None
|
| 200 |
|
| 201 |
-
|
| 202 |
-
# process on rank 0 first so it gets cached so other ranks load from cache
|
| 203 |
train_dataset, eval_dataset = process_datasets_for_packing(
|
| 204 |
cfg, train_dataset, eval_dataset
|
| 205 |
)
|
| 206 |
-
barrier()
|
| 207 |
-
if not is_main_process():
|
| 208 |
-
train_dataset, eval_dataset = process_datasets_for_packing(
|
| 209 |
-
cfg, train_dataset, eval_dataset
|
| 210 |
-
)
|
| 211 |
-
barrier()
|
| 212 |
if cfg.max_steps:
|
| 213 |
total_num_steps = min(
|
| 214 |
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
|
|
|
| 21 |
from axolotl.utils.config import normalize_config, validate_config
|
| 22 |
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
| 23 |
from axolotl.utils.dict import DictDefault
|
| 24 |
+
from axolotl.utils.distributed import is_main_process, zero_first
|
| 25 |
from axolotl.utils.models import load_model, load_tokenizer
|
| 26 |
from axolotl.utils.tokenization import check_dataset_labels
|
| 27 |
from axolotl.utils.trainer import (
|
|
|
|
| 198 |
train_dataset = train_dataset.with_format("torch")
|
| 199 |
eval_dataset = None
|
| 200 |
|
| 201 |
+
with zero_first(is_main_process()):
|
|
|
|
| 202 |
train_dataset, eval_dataset = process_datasets_for_packing(
|
| 203 |
cfg, train_dataset, eval_dataset
|
| 204 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
if cfg.max_steps:
|
| 206 |
total_num_steps = min(
|
| 207 |
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
src/axolotl/utils/data.py
CHANGED
|
@@ -41,7 +41,7 @@ from axolotl.prompters import (
|
|
| 41 |
ShareGPTPrompter,
|
| 42 |
SummarizeTLDRPrompter,
|
| 43 |
)
|
| 44 |
-
from axolotl.utils.distributed import barrier, is_main_process
|
| 45 |
|
| 46 |
LOG = logging.getLogger("axolotl")
|
| 47 |
|
|
@@ -440,7 +440,7 @@ def load_prepare_datasets(
|
|
| 440 |
to_hash_test.encode(), usedforsecurity=False
|
| 441 |
).hexdigest()
|
| 442 |
|
| 443 |
-
|
| 444 |
dataset = dataset.train_test_split(
|
| 445 |
test_size=cfg.val_set_size,
|
| 446 |
shuffle=False,
|
|
@@ -448,16 +448,6 @@ def load_prepare_datasets(
|
|
| 448 |
train_new_fingerprint=train_fingerprint,
|
| 449 |
test_new_fingerprint=test_fingerprint,
|
| 450 |
)
|
| 451 |
-
barrier()
|
| 452 |
-
if not is_main_process():
|
| 453 |
-
dataset = dataset.train_test_split(
|
| 454 |
-
test_size=cfg.val_set_size,
|
| 455 |
-
shuffle=False,
|
| 456 |
-
seed=cfg.seed or 42,
|
| 457 |
-
train_new_fingerprint=train_fingerprint,
|
| 458 |
-
test_new_fingerprint=test_fingerprint,
|
| 459 |
-
)
|
| 460 |
-
barrier()
|
| 461 |
|
| 462 |
train_dataset = dataset["train"]
|
| 463 |
eval_dataset = dataset["test"]
|
|
|
|
| 41 |
ShareGPTPrompter,
|
| 42 |
SummarizeTLDRPrompter,
|
| 43 |
)
|
| 44 |
+
from axolotl.utils.distributed import is_main_process, zero_first
|
| 45 |
|
| 46 |
LOG = logging.getLogger("axolotl")
|
| 47 |
|
|
|
|
| 440 |
to_hash_test.encode(), usedforsecurity=False
|
| 441 |
).hexdigest()
|
| 442 |
|
| 443 |
+
with zero_first(is_main_process()):
|
| 444 |
dataset = dataset.train_test_split(
|
| 445 |
test_size=cfg.val_set_size,
|
| 446 |
shuffle=False,
|
|
|
|
| 448 |
train_new_fingerprint=train_fingerprint,
|
| 449 |
test_new_fingerprint=test_fingerprint,
|
| 450 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
train_dataset = dataset["train"]
|
| 453 |
eval_dataset = dataset["test"]
|
src/axolotl/utils/distributed.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
"""
|
| 2 |
utility helpers for distributed checks
|
| 3 |
"""
|
|
|
|
|
|
|
| 4 |
import torch.distributed as dist
|
| 5 |
from accelerate import Accelerator
|
| 6 |
|
|
@@ -39,3 +41,15 @@ def is_main_process():
|
|
| 39 |
if not is_distributed():
|
| 40 |
return True
|
| 41 |
return dist.get_rank() == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
utility helpers for distributed checks
|
| 3 |
"""
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
|
| 6 |
import torch.distributed as dist
|
| 7 |
from accelerate import Accelerator
|
| 8 |
|
|
|
|
| 41 |
if not is_distributed():
|
| 42 |
return True
|
| 43 |
return dist.get_rank() == 0
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@contextmanager
|
| 47 |
+
def zero_first(is_main):
|
| 48 |
+
"""
|
| 49 |
+
runs the wrapped context so that rank 0 runs first before other ranks
|
| 50 |
+
"""
|
| 51 |
+
if not is_main: # other ranks wait first
|
| 52 |
+
barrier()
|
| 53 |
+
yield
|
| 54 |
+
if is_main: # then rank 0 waits after it has run the context
|
| 55 |
+
barrier()
|