Spaces:
Running
Running
File size: 4,244 Bytes
f3a09a2 2cdada4 f3a09a2 0cbac6c 2cdada4 f3a09a2 549360a 0cbac6c 549360a 0cbac6c 549360a 0cbac6c 549360a f3a09a2 0cbac6c f3a09a2 549360a f3a09a2 2cdada4 549360a 2cdada4 549360a 2cdada4 c790fdb 2cdada4 56adaa2 2cdada4 c790fdb 2cdada4 549360a 2cdada4 f3a09a2 2cdada4 f3a09a2 2cdada4 f3a09a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import asyncio
import os
import random
from datasets import Dataset, load_dataset
from datasets_.util import _get_dataset_config_names, _load_dataset, cache, standardize_bcp47
from langcodes import Language
from models import get_google_supported_languages, translate_google
from rich import print
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
def _tags_by_standard_code(slug, **loader_kwargs):
    """Map standardize_bcp47(config name) -> original config name for a dataset.

    Extra keyword arguments are forwarded unchanged to
    `_get_dataset_config_names`.
    """
    return {
        standardize_bcp47(config_name): config_name
        for config_name in _get_dataset_config_names(slug, **loader_kwargs)
    }


# Each dataset gets a slug plus a lookup table from the standardized language
# code to the config name the dataset itself uses.
slug_mgsm = "juletxara/mgsm"
tags_mgsm = _tags_by_standard_code(slug_mgsm)

slug_afrimgsm = "masakhane/afrimgsm"
tags_afrimgsm = _tags_by_standard_code(slug_afrimgsm)

# This dataset's config names are listed with trust_remote_code=True.
slug_gsm8kx = "Eurolingua/gsm8kx"
tags_gsm8kx = _tags_by_standard_code(slug_gsm8kx, trust_remote_code=True)

slug_gsm_autotranslated = "fair-forward/gsm-autotranslated"
tags_gsm_autotranslated = _tags_by_standard_code(slug_gsm_autotranslated)
def parse_number(i):
    """Parse an integer out of a model/dataset answer value.

    Args:
        i: An int (returned unchanged) or a string such as "1,234".

    Returns:
        The parsed int, or None when the value cannot be interpreted as an
        integer (non-numeric text, None, or any other non-string type).

    Note:
        Both "," and "." are stripped before parsing, so they are treated as
        grouping separators: "1.000" -> 1000, and likewise "1.5" -> 15.
    """
    if isinstance(i, int):
        return i
    # Anything that is not text (None, float, ...) has no .replace(); treat it
    # as unparseable instead of raising AttributeError.
    if not isinstance(i, str):
        return None
    try:
        return int(i.replace(",", "").replace(".", ""))
    except ValueError:
        return None
@cache
def _get_mgsm_item(dataset_slug, subset_tag, nr, trust_remote_code=False):
    """Fetch row `nr` from the "test" split of one dataset subset.

    The call is wrapped by the project's `cache` decorator.

    Returns:
        The row at index `nr`, or None when `nr` is past the end of the split
        or when anything raises while loading or post-processing (for example
        the dataset or its test split is missing). Rows coming from
        `slug_gsm8kx` additionally get an "answer_number" field holding the
        stripped text that follows "####" in their "answer" field.
    """
    try:
        test_split = _load_dataset(
            dataset_slug,
            subset=subset_tag,
            split="test",
            trust_remote_code=trust_remote_code,
        )
        if len(test_split) <= nr:
            return None
        record = test_split[nr]
        if dataset_slug == slug_gsm8kx:
            # The final answer is the part of "answer" after the "####" marker.
            answer_parts = record["answer"].split("####")
            record["answer_number"] = answer_parts[1].strip()
        return record
    except Exception:
        # Best effort: a missing dataset/split or a malformed row yields None.
        return None
def load_mgsm(language_bcp_47, nr):
    """Load MGSM item `nr` for a language from the first dataset that lists it.

    Datasets are tried in this order: `slug_mgsm` and `slug_afrimgsm`
    (labelled "human"), then `slug_gsm8kx` and `slug_gsm_autotranslated`
    (labelled "machine"). Only the first dataset whose tag table contains the
    language is queried; there is no fall-through to later datasets when that
    one has no item at index `nr`.

    Args:
        language_bcp_47: Language code used as key in the module's tag tables.
        nr: Row index within the dataset's test split.

    Returns:
        `(dataset_slug, item, origin)` with origin "human" or "machine", or
        `(None, None, None)` when the language is not covered or the item
        could not be loaded.
    """
    # (dataset slug, tag table, origin label, extra kwargs for the item loader).
    # Extra kwargs are only passed where the original call passed them, so the
    # call shape seen by the cached loader is unchanged.
    sources = (
        (slug_mgsm, tags_mgsm, "human", {}),
        (slug_afrimgsm, tags_afrimgsm, "human", {}),
        (slug_gsm8kx, tags_gsm8kx, "machine", {"trust_remote_code": True}),
        (slug_gsm_autotranslated, tags_gsm_autotranslated, "machine", {}),
    )
    for slug, tags, origin, extra in sources:
        if language_bcp_47 not in tags:
            continue
        item = _get_mgsm_item(slug, tags[language_bcp_47], nr, **extra)
        # Parenthesize the success tuple: without the parentheses the
        # conditional applies only to the last element and a missing item
        # yields (slug, None, (None, None, None)).
        return (slug, item, origin) if item else (None, None, None)
    return None, None, None
def translate_mgsm(languages):
    """Machine-translate the English MGSM test questions into more languages.

    Candidate languages are those in `languages["bcp_47"]` that are covered by
    neither `tags_mgsm` nor `tags_afrimgsm` and that
    `get_google_supported_languages()` lists. For each candidate that is not
    yet available as a config of `slug_gsm_autotranslated`, the English
    questions are translated with `translate_google`, pushed to the hub as that
    language's "test" split, and written to
    `data/translations/mgsm/<lang>.json`.

    Args:
        languages: Table-like object whose "bcp_47" column exposes `.values`.

    Side effects:
        Network access, uploads to the hub (token read from the
        HUGGINGFACE_ACCESS_TOKEN environment variable) and local file writes.
    """
    human_translated = {*tags_mgsm, *tags_afrimgsm}
    # Fetch the supported-language list once instead of once per candidate.
    google_supported = get_google_supported_languages()
    untranslated = [
        lang
        for lang in languages["bcp_47"].values
        if lang not in human_translated and lang in google_supported
    ]
    en = _load_dataset(slug_mgsm, subset=tags_mgsm["en"], split="test")
    for lang in tqdm(untranslated):
        # Check whether this language already exists on the hub; only the
        # success/failure of the load matters, not the loaded data.
        try:
            load_dataset(slug_gsm_autotranslated, lang, split="test")
        except ValueError:
            print(f"Translating {lang}...")
            pending = [translate_google(q, "en", lang) for q in en["question"]]
            questions_tr = asyncio.run(tqdm_asyncio.gather(*pending))
            # Only the question text is translated; answers are copied from
            # the English source rows.
            ds_lang = Dataset.from_dict(
                {
                    "question": questions_tr,
                    "answer": en["answer"],
                    "answer_number": en["answer_number"],
                    "equation_solution": en["equation_solution"],
                }
            )
            ds_lang.push_to_hub(
                slug_gsm_autotranslated,
                split="test",
                config_name=lang,
                token=os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
            )
            # NOTE(review): the local JSON copy is assumed to belong to the
            # "freshly translated" branch — confirm against the original
            # indentation.
            ds_lang.to_json(
                f"data/translations/mgsm/{lang}.json",
                lines=False,
                force_ascii=False,
                indent=2,
            )
|