David Pomerenke committed

Commit 031925d · 1 Parent(s): 4106f13

Analyze MMLU datasets

Files changed: evals/datasets_/mmlu.py (+90, -0)
evals/datasets_/mmlu.py
ADDED
@@ -0,0 +1,90 @@
from collections import Counter, defaultdict

from datasets import get_dataset_config_names, load_dataset
from joblib.memory import Memory
from langcodes import Language, standardize_tag
from rich import print

# Cache the expensive Hub calls on disk.
cache = Memory(location=".cache", verbose=0).cache


@cache
def _get_dataset_config_names(dataset):
    return get_dataset_config_names(dataset)


@cache
def _load_dataset(dataset, subset, **kwargs):
    return load_dataset(dataset, subset, **kwargs)


def print_counts(slug, subjects_dev, subjects_test):
    print(
        f"{slug:<25} {len(set(subjects_test)):>3} test categories, "
        f"{len(subjects_test):>6} test samples, "
        f"{len(set(subjects_dev)):>3} dev categories, "
        f"{len(subjects_dev):>6} dev samples"
    )


def print_datasets_analysis():
    print("Category counts and sample counts per dataset:")
    slug1 = "masakhane/afrimmlu"
    ds1 = _load_dataset(slug1, "eng")
    print_counts(slug1, ds1["dev"]["subject"], ds1["test"]["subject"])
    langs1 = _get_dataset_config_names(slug1)
    langs1 = [standardize_tag(a, macro=True) for a in langs1]

    # No dev set! But every language it covers is also present in Global-MMLU.
    slug2 = "openai/MMMLU"
    ds2 = _load_dataset(slug2, "FR_FR")
    print_counts(slug2, [], ds2["test"]["Subject"])
    langs2 = _get_dataset_config_names(slug2)
    langs2 = [a.split("_")[0].lower() for a in langs2]
    langs2.remove("default")

    slug3 = "CohereForAI/Global-MMLU"
    ds3 = _load_dataset(slug3, "en")
    print_counts(slug3, ds3["dev"]["subject"], ds3["test"]["subject"])
    langs3 = _get_dataset_config_names(slug3)
    langs3 = [standardize_tag(a, macro=True) for a in langs3]

    slug4 = "lighteval/okapi_mmlu"
    ds4 = _load_dataset(slug4, "ar", trust_remote_code=True)
    # Subjects are encoded in the sample ids ("subject/index").
    print_counts(
        slug4,
        [a.split("/")[0] for a in ds4["dev"]["id"]],
        [a.split("/")[0] for a in ds4["test"]["id"]],
    )
    langs4 = _get_dataset_config_names(slug4)

    # MMLUX has one config per subject and language ("subject_LANG-REGION").
    slug5 = "Eurolingua/mmlux"
    subsets = _get_dataset_config_names(slug5)
    subjects = set(a.rsplit("_", 1)[0] for a in subsets)
    rows_test = [
        _load_dataset(slug5, subset)["test"]["id"]
        for subset in subsets
        if "_DA" in subset
    ]
    rows_test = [a.split("/")[0] for ids in rows_test for a in ids]
    rows_dev = [
        _load_dataset(slug5, subset)["dev"]["id"]
        for subset in subsets
        if "_DA" in subset
    ]
    rows_dev = [a.split("/")[0] for ids in rows_dev for a in ids]
    print_counts(slug5, rows_dev, rows_test)
    langs5 = list(set(a.rsplit("_", 1)[1].split("-")[0].lower() for a in subsets))

    langs = langs1 + langs2 + langs3 + langs4 + langs5
    lang_datasets = defaultdict(list)
    for slug, langs_list in [
        (slug1, langs1),
        (slug2, langs2),
        (slug3, langs3),
        (slug4, langs4),
        (slug5, langs5),
    ]:
        for lang in langs_list:
            lname = Language.get(lang).display_name()
            lang_datasets[lname].append(slug)
    print("Datasets per language:")
    print(sorted(lang_datasets.items()))
    print(len(set(langs)))

    print("Datasets per language for languages that are not in Global-MMLU:")
    print(
        sorted(
            (lang, datasets)
            for lang, datasets in lang_datasets.items()
            if slug3 not in datasets
        )
    )
    print(
        Counter(
            dataset
            for ds_list in lang_datasets.values()
            for dataset in ds_list
            if slug3 not in ds_list
        )
    )
    print(list(set(ds1["test"]["subject"])))


# Based on this analysis:
# - We drop the OpenAI dataset, since it has no dev set and every language it
#   covers is also present in Global-MMLU.
# - We stick to the 5 categories of the AfriMMLU dataset: it is the most
#   restricted dataset, and its 5 categories are present in all the other
#   datasets, which is good for comparability.

# AfriMMLU is human-translated, but has only 5 task categories.
# Global-MMLU is partially human-translated: specifically, the 15 languages
# that are also present in Global-MMLU-Lite, which are mostly taken from
# MMMLU; the remaining languages are translated using Google Translate.
# Okapi-MMLU is translated using ChatGPT (version unclear).
# MMLUX is translated using DeepL.
# Therefore, the priority is: AfriMMLU, Global-MMLU, Okapi-MMLU, MMLUX.


print_datasets_analysis()


def load_mmlu(language_bcp_47):
    pass
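The commit leaves load_mmlu as a stub. A minimal sketch of how the stated priority (AfriMMLU, then Global-MMLU, Okapi-MMLU, MMLUX) could drive per-language dataset selection is shown below; everything in it is an assumption for illustration. The pick_mmlu_dataset helper and the placeholder coverage sets are hypothetical and not part of the commit — in practice the sets would come from the langs1/langs3/langs4/langs5 lists computed in print_datasets_analysis().

# Hypothetical sketch, not part of this commit: choose the highest-priority
# MMLU variant that covers a given language.
from langcodes import standardize_tag

# Placeholder coverage sets for illustration only; in practice these would be
# filled from the per-dataset language lists computed in the analysis above.
DATASET_LANGS = {
    "masakhane/afrimmlu": {"en", "am", "ha", "sw", "yo"},
    "CohereForAI/Global-MMLU": {"en", "ar", "de", "fr", "hi", "sw"},
    "lighteval/okapi_mmlu": {"ar", "bn", "de", "fr", "hi"},
    "Eurolingua/mmlux": {"da", "de", "fr", "it"},
}

# Priority order argued in the comments above: best translation quality first.
PRIORITY = [
    "masakhane/afrimmlu",       # human-translated
    "CohereForAI/Global-MMLU",  # partially human-translated
    "lighteval/okapi_mmlu",     # ChatGPT-translated
    "Eurolingua/mmlux",         # DeepL-translated
]


def pick_mmlu_dataset(language_bcp_47):
    # Normalize the tag the same way the analysis does (macrolanguage form).
    tag = standardize_tag(language_bcp_47, macro=True)
    for slug in PRIORITY:
        if tag in DATASET_LANGS[slug]:
            return slug
    return None  # language not covered by any of the four datasets


print(pick_mmlu_dataset("sw"))  # -> "masakhane/afrimmlu" with these placeholders

A real implementation would additionally have to map the chosen slug to that dataset's config naming scheme (e.g. "eng" for AfriMMLU vs. "en" for Global-MMLU vs. "FR_FR" for MMMLU), since the analysis shows that each dataset names its language subsets differently.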