feat: add Metharme prompt strategy (#446)
Browse files* Add Metharme tokenizing strategy
This strategy accounts for how the Metharme JSONLs are formatted and appends duplicated EOS tokens, which can help trim model output length.
I haven't gotten the chance to test this yet, and probably won't have the chance for quite a bit, so I'm committing this now.
* Redo Metharme tokenizing strategy
lol
* fix: oops
* Rearrange a conditional
* chore: reformat code in accordance with linter
* chore: Make lint not freak out
* chore: fix lint
---------
Co-authored-by: NanoCode012 <[email protected]>
- README.md +4 -0
- src/axolotl/prompt_strategies/metharme.py +76 -0
README.md
CHANGED
|
@@ -257,6 +257,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|
| 257 |
```json
|
| 258 |
{"conversations": [{"role": "...", "value": "..."}]}
|
| 259 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
| 261 |
```json
|
| 262 |
{"conversations": [{"role": "...", "value": "..."}]}
|
|
|
|
| 257 |
```json
|
| 258 |
{"conversations": [{"role": "...", "value": "..."}]}
|
| 259 |
```
|
| 260 |
+
- `metharme`: instruction, adds additional eos tokens
|
| 261 |
+
```json
|
| 262 |
+
{"prompt": "...", "generation": "..."}
|
| 263 |
+
```
|
| 264 |
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
| 265 |
```json
|
| 266 |
{"conversations": [{"role": "...", "value": "..."}]}
|
src/axolotl/prompt_strategies/metharme.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
| 7 |
+
from axolotl.prompters import AlpacaPrompter
|
| 8 |
+
|
| 9 |
+
LOG = logging.getLogger("axolotl")
|
| 10 |
+
|
| 11 |
+
IGNORE_TOKEN_ID = -100
|
| 12 |
+
|
| 13 |
+
# pylint: disable=duplicate-code
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for the Metharme models.

    Metharme samples are JSONL rows of the form
    ``{"prompt": "...", "generation": "..."}``; this strategy maps them onto
    the (instruction, input, response) triple expected by the base class and
    appends extra EOS tokens, which can help trim model output length.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Map a Metharme sample dict to (instruction, input, response).

        The "input" slot is always empty for this format.
        """
        return (prompt["prompt"], "", prompt["generation"])

    def _tokenize(
        self,
        prompt: str,
        add_eos_token: bool = True,
        strip_bos_token: bool = False,
        num_eos_tokens: int = 3,
    ):
        """Tokenize ``prompt`` and post-process the token ids.

        Appends up to ``num_eos_tokens`` EOS tokens (one fewer if the
        truncated output already ends in EOS), never growing past
        ``self.sequence_len``; optionally strips a leading BOS token.
        Returns the tokenizer result dict with ``labels`` set to a copy
        of ``input_ids``.
        """
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        if len(result["input_ids"]) == 0:
            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
            # Bug fix: the original fell through and indexed input_ids[-1],
            # raising IndexError on an empty result. With no tokens there is
            # nothing to trim or extend, so return the empty result early.
            result["labels"] = result["input_ids"].copy()
            return result

        # If there's already an EOS token there, subtract from the number added
        if result["input_ids"][-1] == self.tokenizer.eos_token_id:
            num_eos_tokens -= 1

        # The empty case returned above, so input_ids is known non-empty here.
        if num_eos_tokens > 0 and add_eos_token:
            for _ in range(num_eos_tokens):
                if len(result["input_ids"]) < self.sequence_len:
                    result["input_ids"].append(self.tokenizer.eos_token_id)
                    result["attention_mask"].append(1)

        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class MetharmePrompter(AlpacaPrompter):
    """
    Prompter for the Metharme models.

    All system/format templates are blanked out so the instruction text is
    passed through verbatim with no surrounding boilerplate.
    """

    # Pass-through templates: no system preamble, no wrapping.
    turn_format = "{instruction}"
    turn_no_input_format = "{instruction}"
    system_prompt = ""
    system_no_input_prompt = ""
    system_format = ""

    def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
        # Deliberately skip AlpacaPrompter.__init__: the class attributes
        # above already describe the (empty) templates in full.
        pass
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def load(tokenizer, cfg):
    """Entry point for axolotl's prompt-strategy loader.

    Wires a MetharmePromptTokenizingStrategy to ``tokenizer`` plus the
    ``train_on_inputs`` and ``sequence_len`` fields of ``cfg``.
    """
    prompter = MetharmePrompter()
    return MetharmePromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
|