update table for rwkv4 support, fix process count for dataset (#822)
Files changed:
- README.md (+1 -0)
- src/axolotl/datasets.py (+8 -2)
- src/axolotl/utils/data.py (+30 -10)
README.md

```diff
@@ -74,6 +74,7 @@ Features:
 | gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
 | XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
 | phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
+| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
 
 
 ## Quickstart ⚡
```
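The new row claims only the first support column for RWKV and marks the remaining feature columns untested (❓). As a hedged sketch of what that support builds on, and not part of this PR's changes: axolotl loads models through Hugging Face transformers, so an RWKV-4 run starts from an HF-hosted checkpoint via the upstream Rwkv model class. The checkpoint id below is an illustrative assumption, not one this PR specifies.

```python
# Hedged sketch, not from this PR: the checkpoint id is an assumed example.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "RWKV/rwkv-4-169m-pile"  # assumption: any HF-hosted RWKV-4 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
print(model.config.model_type)  # -> "rwkv"
```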
src/axolotl/datasets.py

```diff
@@ -2,7 +2,7 @@
 
 import logging
 import os
-from typing import List
+from typing import List, Optional
 
 import torch
 from datasets import Dataset, IterableDataset
@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
         self,
         prompt_tokenizer: PromptTokenizingStrategy,
         dataset: IterableDataset,
+        process_count: Optional[int] = None,
         **kwargs,
     ):
         self.prompt_tokenizer = prompt_tokenizer
+        self.process_count = process_count
         super().__init__(self.process(dataset).data, **kwargs)
 
     def process(self, dataset):
         features = dataset.features.keys()
-        num_proc = min(64, os.cpu_count())
+        num_proc = (
+            min(64, self.process_count)
+            if self.process_count
+            else min(64, os.cpu_count())
+        )
         map_kwargs = {}
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
```
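Pulled out of the class, the new worker-count selection reduces to the helper below, a standalone restatement of the logic in the hunk above rather than additional PR code. Note the check is on truthiness, so an unset (or zero) `process_count` falls through to the CPU-count default.

```python
import os
from typing import Optional

def resolve_num_proc(process_count: Optional[int] = None) -> int:
    """Pick a datasets.map() worker count: honor an explicit process_count
    when given, otherwise fall back to os.cpu_count(); cap both at 64."""
    return min(64, process_count) if process_count else min(64, os.cpu_count())

# e.g. on a 96-core machine: resolve_num_proc() -> 64, resolve_num_proc(8) -> 8
```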
src/axolotl/utils/data.py

```diff
@@ -482,10 +482,14 @@ def get_dataset_wrapper(
             "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
         )
         dataset_prompter = UnsupportedPrompter()
-        dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
     elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
         dataset_prompter = UnsupportedPrompter()
-        dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
     elif d_base_type == "alpaca":
         dataset_prompter = AlpacaPrompter(d_prompt_style)
         ds_strategy = AlpacaPromptTokenizingStrategy(
@@ -494,7 +498,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "explainchoice":
         dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
@@ -504,7 +510,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "concisechoice":
         dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
@@ -514,7 +522,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "summarizetldr":
         dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
@@ -524,7 +534,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "jeopardy":
         dataset_prompter = JeopardyPrompter(d_prompt_style)
@@ -534,7 +546,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "oasst":
         dataset_prompter = AlpacaPrompter(d_prompt_style)
@@ -544,7 +558,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "gpteacher":
         dataset_prompter = GPTeacherPrompter(d_prompt_style)
@@ -554,7 +570,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "reflection":
         dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
@@ -564,7 +582,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     else:
         suffix = ""
```
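Every branch of get_dataset_wrapper now threads cfg.dataset_processes (the user-facing config knob) into TokenizedPromptDataset, which turns it into the capped num_proc handed to datasets.map(). Below is a minimal runnable sketch of what that knob ultimately bounds, with a stand-in map function in place of real prompt tokenization:

```python
from datasets import Dataset

def fake_tokenize(row):
    # stand-in for a PromptTokenizingStrategy's real work
    return {"n_chars": len(row["text"])}

data = Dataset.from_dict({"text": ["hello world"] * 1000})
# num_proc plays the role of min(64, cfg.dataset_processes)
tokenized = data.map(fake_tokenize, num_proc=2)
print(tokenized[0])  # {'text': 'hello world', 'n_chars': 11}
```

With this wiring, setting dataset_processes in an axolotl config bounds preprocessing parallelism instead of always fanning out to every available core, which is the "fix process count" half of this PR's title.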