Tokenization open assistant (#1)
* refactor prompt tokenization to more easily support open assistant
* add open assistant handling, more logging, black formatting
- scripts/finetune.py +115 -39
- src/axolotl/prompt_tokenizers.py +34 -12
scripts/finetune.py
CHANGED

@@ -37,6 +37,7 @@ from axolotl.prompt_tokenizers import (
     ShareGPTPromptTokenizingStrategy,
     LLAMA_DEFAULT_PAD_TOKEN,
     GPTeacherPromptTokenizingStrategy,
+    OpenAssistantPromptTokenizingStrategy,
 )
 from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
 
@@ -56,7 +57,15 @@ def setup_wandb_env_vars(cfg):
         os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
 
 
-def load_model(
+def load_model(
+    base_model,
+    base_model_config,
+    model_type,
+    tokenizer_type,
+    cfg,
+    adapter="lora",
+    inference: bool = False,
+):
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
     tokenizer = None
@@ -67,13 +76,17 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     if is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and inference is False:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+
             logging.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
 
-    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
+    torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
     try:
         if cfg.load_4bit:
-            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import
+            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
+                replace_peft_model_with_int4_lora_model,
+            )
+
             replace_peft_model_with_int4_lora_model()
 
         from peft import (
@@ -92,18 +105,26 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
             from huggingface_hub import snapshot_download
 
             cache_model_path = Path(snapshot_download(base_model))
-            files =
+            files = (
+                list(cache_model_path.glob("*.pt"))
+                + list(cache_model_path.glob("*.safetensors"))
+                + list(cache_model_path.glob("*.bin"))
+            )
             if len(files) > 0:
                 model_path = str(files[0])
             else:
-                logging.warning(
+                logging.warning(
+                    "unable to find a cached model file, this will likely fail..."
+                )
                 model_path = str(cache_model_path)
             model, tokenizer = load_llama_model_4bit_low_ram(
                 base_model_config if base_model_config else base_model,
                 model_path,
                 device_map=cfg.device_map,
                 groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
-                is_v1_model=cfg.gptq_model_v1
+                is_v1_model=cfg.gptq_model_v1
+                if cfg.gptq_model_v1 is not None
+                else True,
             )
             load_in_8bit = False
         elif is_llama_derived_model:
@@ -120,7 +141,11 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
-    except:
+    except Exception as e:
+        logging.error(
+            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
+        )
+        logging.exception(e)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
@@ -145,7 +170,6 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
 
-
     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -165,7 +189,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     )
 
     if cfg.lora_model_dir:
-        model = PeftModel.from_pretrained(
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.lora_model_dir,
+            device_map=cfg.device_map,
+            torch_dtype=torch.float16,
+        )
     else:
         model = get_peft_model(model, lora_config)
 
@@ -174,9 +203,11 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
 
     if cfg.load_4bit:
         # Scales to half
-        logging.info(
+        logging.info("Fitting 4bit scales and zeros to half")
        for n, m in model.named_modules():
-            if
+            if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
+                type(m)
+            ):
                 if hasattr(m, "is_v1_model") and m.is_v1_model:
                     m.zeros = m.zeros.half()
                     m.scales = m.scales.half()
@@ -236,37 +267,44 @@ def check_dataset_labels(dataset, tokenizer):
 
 
 def do_inference(cfg, model, tokenizer):
-    tokenizer.add_special_tokens({
-    tokenizer.add_special_tokens({
-    tokenizer.add_special_tokens({
+    tokenizer.add_special_tokens({"unk_token": "<unk>"})
+    tokenizer.add_special_tokens({"bos_token": "<s>"})
+    tokenizer.add_special_tokens({"eos_token": "</s>"})
 
     instruction = "Tell me a joke about dromedaries."
     input = ""
-    prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(
+    prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(
+        instruction=instruction, input=input
+    )
     batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
     model.eval()
     with torch.no_grad():
         # gc = GenerationConfig() # TODO swap out and use this
-        generated = model.generate(
+        generated = model.generate(
+            inputs=batch["input_ids"].to("cuda"),
+            do_sample=True,
+            use_cache=True,
+            repetition_penalty=1.1,
+            max_new_tokens=100,
+            temperature=0.9,
+            top_p=0.95,
+            top_k=40,
+            return_dict_in_generate=True,
+            output_attentions=False,
+            output_hidden_states=False,
+            output_scores=False,
+        )
+        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
 
 
 def choose_config(path: Path):
     yaml_files = [file for file in path.glob("*.yml")]
 
     if not yaml_files:
-        raise ValueError(
+        raise ValueError(
+            "No YAML config files found in the specified directory. Are you using a .yml extension?"
+        )
 
     print("Choose a YAML file:")
     for idx, file in enumerate(yaml_files):
@@ -376,6 +414,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 
     return trainer
 
+
 def train(
     config: Path = Path("configs/"),
     prepare_ds_only: bool = False,
@@ -420,7 +459,13 @@ def train(
     # Load the model and tokenizer
     logging.info("loading model, tokenizer, and lora_config...")
     model, tokenizer, lora_config = load_model(
-        cfg.base_model,
+        cfg.base_model,
+        cfg.base_model_config,
+        cfg.model_type,
+        cfg.tokenizer_type,
+        cfg,
+        adapter=cfg.adapter,
+        inference=("inference" in kwargs),
     )
 
     if "inference" in kwargs:
@@ -428,10 +473,26 @@ def train(
         do_inference(cfg, model, tokenizer)
         return
 
-    max_packed_sequence_len =
+    max_packed_sequence_len = (
+        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    )
+    max_packed_sequence_len = min(
+        max_packed_sequence_len, cfg.sequence_len
+    )  # make sure we don't accidentally set it larger than sequence_len
+    ds_hash = str(
+        md5(
+            (
+                str(max_packed_sequence_len)
+                + "@"
+                + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+            ).encode("utf-8")
+        ).hexdigest()
+    )
+    prepared_ds_path = (
+        Path(cfg.dataset_prepared_path) / ds_hash
+        if cfg.dataset_prepared_path
+        else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+    )
 
     if any(prepared_ds_path.glob("*")):
         logging.info("Loading prepared dataset from disk...")
@@ -464,9 +525,18 @@ def train(
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
+            elif d.type == "oasst":
+                ds_strategy = OpenAssistantPromptTokenizingStrategy(
+                    AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                )
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                datasets.append(ds_wrapper)
             elif d.type == "gpteacher":
                 ds_strategy = GPTeacherPromptTokenizingStrategy(
-                    GPTeacherPrompter(),
+                    GPTeacherPrompter(),
+                    tokenizer,
+                    cfg.train_on_inputs,
+                    cfg.sequence_len,
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
@@ -476,13 +546,17 @@ def train(
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
+            else:
+                logging.error(f"unhandled prompt tokenization strategy: {d.type}")
         constant_len_dataset = ConstantLengthDataset(
-            tokenizer,
+            tokenizer,
+            datasets,
+            seq_length=max_packed_sequence_len,
        )
        logging.info("merging, packing, shuffling, and splitting master dataset")
-        dataset = Dataset.from_list(
-        )
+        dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
+            test_size=cfg.val_set_size, shuffle=True, seed=42
+        )
 
         if cfg.local_rank == 0:
             logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
@@ -525,7 +599,9 @@ def train(
 
     if cfg.local_rank == 0:
         # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-        logging.info(
+        logging.info(
+            f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
+        )
         model.save_pretrained(cfg.output_dir)
 
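For reference, a minimal sketch of how the new "oasst" dataset type is exercised end to end. This is not part of the diff: the data file name, tokenizer checkpoint, and the `axolotl.datasets` import path are assumptions for illustration; the strategy constructor arguments mirror the call added to `train()` above.

# Hypothetical usage of the new OpenAssistant code path (illustration only).
from datasets import load_dataset
from transformers import AutoTokenizer

from axolotl.datasets import TokenizedPromptDataset  # assumed module path
from axolotl.prompt_tokenizers import OpenAssistantPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # example checkpoint
# records are expected to carry "INSTRUCTION" and "RESPONSE" keys
ds = load_dataset("json", data_files="oasst_export.jsonl")

ds_strategy = OpenAssistantPromptTokenizingStrategy(
    AlpacaPrompter(),  # renders the Alpaca instruction template
    tokenizer,
    False,             # train_on_inputs: mask prompt tokens out of the labels
    2048,              # sequence_len
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])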
src/axolotl/prompt_tokenizers.py
CHANGED

@@ -31,14 +31,18 @@ class PromptTokenizingStrategy(abc.ABC):
         pass
 
 
-class
+class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        raise NotImplementedError
+
     def tokenize_prompt(self, prompt):
+        instruction, input, response = self.parse_instruction_fields(prompt)
+        full_prompt = self._build_full_prompt(instruction, input, response)
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
             user_prompt = self.prompter.build_prompt(
+                instruction,
+                input,
             )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
@@ -49,11 +53,11 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
 
         return tokenized_full_prompt
 
-    def
+    def _build_full_prompt(self, instruction, input, response):
         return self.prompter.build_prompt(
+            instruction,
+            input,
+            response,
         )
 
     def _tokenize(self, prompt, add_eos_token=True):
@@ -76,11 +80,29 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
         return result
 
 
-class
-    def
-        return
+class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
             prompt["instruction"],
-            prompt["input"],
+            prompt["input"] if "input" in prompt else "",
+            prompt["output"],
+        )
+
+
+class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["INSTRUCTION"],
+            "",
+            prompt["RESPONSE"],
+        )
+
+
+class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["instruction"],
+            prompt["input"] if "input" in prompt else "",
             prompt["response"],
         )
 
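The payoff of the refactor is that another instruction-style dataset now only needs a `parse_instruction_fields` override on `InstructionPromptTokenizingStrategy`; prompt building, tokenization, and input masking are inherited. A hedged sketch with made-up field names (a "context" field is not a format handled by this commit):

# Illustration only: a hypothetical strategy for records shaped like
# {"instruction": ..., "context": ..., "response": ...}.
class ContextualPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    def parse_instruction_fields(self, prompt) -> (str, str, str):
        return (
            prompt["instruction"],
            prompt.get("context", ""),
            prompt["response"],
        )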