workaround for md5 variations (#533)

* workaround for md5 variations
* refactor the prepared hash too

Files changed:
- src/axolotl/utils/data.py  +15 -13
- tests/test_data.py  +64 -0
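Background for the workaround: the `usedforsecurity=False` keyword that keeps security scanners happy is only accepted by `hashlib.md5()` on Python 3.9+ (and some vendor-patched builds), so passing it on an older interpreter raises a `TypeError`. The new `md5()` helper tries the keyword first, falls back to the plain call, and returns the hex digest as a string, so call sites no longer need `.encode()`/`.hexdigest()`. A rough usage sketch (the hash input string below is made up for illustration):

    from axolotl.utils.data import md5

    # the helper takes a plain string and returns a 32-character hex digest
    ds_hash = md5("2048@alpaca:alpaca|LlamaTokenizer")
    assert ds_hash == md5("2048@alpaca:alpaca|LlamaTokenizer", "utf-8")  # utf-8 is the default encoding
    assert len(ds_hash) == 32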
src/axolotl/utils/data.py  CHANGED

@@ -2,7 +2,6 @@
 import functools
 import hashlib
 import logging
-from hashlib import md5
 from pathlib import Path
 from typing import Tuple, Union

@@ -52,6 +51,13 @@ LOG = logging.getLogger("axolotl")
 DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"


+def md5(to_hash: str, encoding: str = "utf-8") -> str:
+    try:
+        return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
+    except TypeError:
+        return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec
+
+
 def prepare_dataset(cfg, tokenizer):
     if not cfg.pretraining_dataset:
         with zero_first(is_main_process()):
@@ -88,7 +94,7 @@ def load_tokenized_prepared_datasets(
 ) -> DatasetDict:
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
-        md5(
+        md5(
             (
                 str(cfg.sequence_len)
                 + "@"
@@ -97,8 +103,8 @@ def load_tokenized_prepared_datasets(
                 )
                 + "|"
                 + tokenizer_name
-            )
-        )
+            )
+        )
     )
     prepared_ds_path = (
         Path(cfg.dataset_prepared_path) / ds_hash
@@ -374,7 +380,7 @@ def load_prepare_datasets(
     # see if we can go ahead and load the stacked dataset
     seed = f"@{str(cfg.seed)}" if cfg.seed else ""
     ds_hash = str(
-        md5(
+        md5(
             (
                 str(cfg.sequence_len)
                 + "@"
@@ -385,8 +391,8 @@ def load_prepare_datasets(
                 )
                 + "|"
                 + tokenizer_name
-            )
-        )
+            )
+        )
     )
     prepared_ds_path = (
         Path(cfg.dataset_prepared_path) / ds_hash
@@ -500,12 +506,8 @@ def load_prepare_datasets(
             + "|"
             + str(cfg.seed or 42)
         )
-        train_fingerprint = hashlib.md5(
-            to_hash_train.encode(), usedforsecurity=False
-        ).hexdigest()
-        test_fingerprint = hashlib.md5(
-            to_hash_test.encode(), usedforsecurity=False
-        ).hexdigest()
+        train_fingerprint = md5(to_hash_train)
+        test_fingerprint = md5(to_hash_test)

         with zero_first(is_main_process()):
             dataset = dataset.train_test_split(
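For context on the second bullet, the fingerprint call sites collapse from the hand-rolled `hashlib.md5(...).hexdigest()` pattern to a single `md5(...)` call, as the last hunk above shows. A quick equivalence sketch (the input string is made up, and the `usedforsecurity` keyword is omitted so the snippet runs on any interpreter):

    import hashlib

    from axolotl.utils.data import md5

    to_hash_train = "some-fingerprint|0.05|train|42"  # made-up example input

    # the old hand-rolled pattern and the new helper should agree byte for byte
    old_style = hashlib.md5(to_hash_train.encode()).hexdigest()  # nosec
    assert md5(to_hash_train) == old_style

The same simplification applies to the `ds_hash` computations in load_tokenized_prepared_datasets and load_prepare_datasets.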
tests/test_data.py  ADDED

@@ -0,0 +1,64 @@
+"""
+test module for the axolotl.utils.data module
+"""
+import unittest
+
+from transformers import LlamaTokenizer
+
+from axolotl.utils.data import encode_pretraining, md5
+
+
+class TestEncodePretraining(unittest.TestCase):
+    """
+    test class for encode_pretraining and the md5 helper
+    """
+
+    def setUp(self):
+        self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(
+            {
+                "eos_token": "</s>",
+                "bos_token": "<s>",
+                "unk_token": "<unk>",
+                "pad_token": "<pad>",
+            }
+        )
+        self.max_tokens = 15  # set a small number for easy inspection
+
+    def test_encode_pretraining(self):
+        examples = {
+            "text": [
+                "Hello, world!",
+                "Nice to meet you.",
+                "lorem ipsum dolor sit amet.",
+                "Nice to meet you again!.",
+                "hello, hello",
+            ]
+        }
+        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
+
+        self.assertEqual(len(result["input_ids"]), 3)
+
+        # Assert the length of input_ids and attention_mask is correct
+        self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
+        self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)
+
+        # Assert EOS and PAD tokens are correctly added
+        # hello world! is 4 tokens
+        self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
+        self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
+        self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)
+        # second part, 5 tokens
+        self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
+        self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
+        self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
+
+    def test_md5(self):
+        self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
+        self.assertEqual(
+            md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()