File size: 4,244 Bytes
f3a09a2
 
2cdada4
f3a09a2
 
0cbac6c
 
2cdada4
 
f3a09a2
 
549360a
 
 
0cbac6c
549360a
 
 
0cbac6c
549360a
 
 
0cbac6c
549360a
 
f3a09a2
 
0cbac6c
f3a09a2
 
 
549360a
 
 
 
 
 
 
 
 
f3a09a2
2cdada4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549360a
 
2cdada4
 
549360a
2cdada4
 
 
 
 
c790fdb
2cdada4
56adaa2
2cdada4
 
c790fdb
2cdada4
549360a
2cdada4
f3a09a2
 
 
 
 
 
2cdada4
 
f3a09a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cdada4
 
 
 
f3a09a2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import asyncio
import os
import random

from datasets import Dataset, load_dataset
from datasets_.util import _get_dataset_config_names, _load_dataset, cache, standardize_bcp47
from langcodes import Language
from models import get_google_supported_languages, translate_google
from rich import print
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

# Dataset registry: for each GSM-family dataset, map standardized BCP-47
# language codes to the dataset's own subset/config tag, so lookups by our
# language identifiers resolve to the right Hub config.

# MGSM — human-translated grade-school math problems.
slug_mgsm = "juletxara/mgsm"
tags_mgsm = {
    standardize_bcp47(a): a for a in _get_dataset_config_names(slug_mgsm)
}
# AfriMGSM — human translations for African languages.
slug_afrimgsm = "masakhane/afrimgsm"
tags_afrimgsm = {
    standardize_bcp47(a): a for a in _get_dataset_config_names(slug_afrimgsm)
}
# gsm8kx — machine translations; the dataset script requires trust_remote_code.
slug_gsm8kx = "Eurolingua/gsm8kx"
tags_gsm8kx = {
    standardize_bcp47(a): a
    for a in _get_dataset_config_names(slug_gsm8kx, trust_remote_code=True)
}
# Our own auto-translated fallback (produced by translate_mgsm below).
slug_gsm_autotranslated = "fair-forward/gsm-autotranslated"
tags_gsm_autotranslated = {
    standardize_bcp47(a): a
    for a in _get_dataset_config_names(slug_gsm_autotranslated)
}


def parse_number(i):
    """Parse an answer value into an int, tolerating separator characters.

    Ints are returned as-is. Strings have "," and "." stripped (these appear
    as thousand separators in the source datasets) before conversion.

    Returns:
        The parsed int, or None when the value cannot be parsed
        (non-numeric string, None, or any non-string/non-int input).
    """
    if isinstance(i, int):
        return i
    try:
        return int(i.replace(",", "").replace(".", ""))
    except (ValueError, AttributeError):
        # AttributeError: input has no .replace (e.g. None or a float) —
        # treat it as unparseable instead of crashing.
        return None


@cache
def _get_mgsm_item(dataset_slug, subset_tag, nr, trust_remote_code=False):
    """Fetch row *nr* from the test split of one MGSM-family dataset.

    Results are memoized per (dataset, subset, index, trust flag) via @cache.

    Returns:
        The dataset row (dict-like), with ``answer_number`` extracted from the
        ``answer`` text for gsm8kx rows, or None when the dataset/split is
        unavailable or *nr* is out of range.
    """
    try:
        test_split = _load_dataset(
            dataset_slug,
            subset=subset_tag,
            split="test",
            trust_remote_code=trust_remote_code,
        )
        if nr >= len(test_split):
            return None
        item = test_split[nr]
        # gsm8kx stores the numeric answer after "####" inside the answer text.
        if dataset_slug == slug_gsm8kx:
            item["answer_number"] = item["answer"].split("####")[1].strip()
        return item
    except Exception:
        # Missing dataset, absent test split, or malformed row — best-effort
        # lookup, so signal "not available" rather than raising.
        return None


def load_mgsm(language_bcp_47, nr):
    """Load the nr-th test item for a language from the best available dataset.

    Datasets are tried in priority order: human translations (MGSM, AfriMGSM)
    before machine translations (gsm8kx, our auto-translated set).

    Returns:
        (dataset_slug, item, origin) where origin is "human" or "machine",
        or (None, None, None) when the language is unsupported or the item
        could not be loaded.
    """
    # (slug, tag map, trust_remote_code, origin label) — checked in order.
    sources = [
        (slug_mgsm, tags_mgsm, False, "human"),
        (slug_afrimgsm, tags_afrimgsm, False, "human"),
        (slug_gsm8kx, tags_gsm8kx, True, "machine"),
        (slug_gsm_autotranslated, tags_gsm_autotranslated, False, "machine"),
    ]
    for slug, tags, trust, origin in sources:
        if language_bcp_47 in tags:
            item = _get_mgsm_item(
                slug, tags[language_bcp_47], nr, trust_remote_code=trust
            )
            # BUG FIX: the old code wrote
            #   return slug, item, "human" if item else (None, None, None)
            # which binds the conditional to the THIRD element only, so a
            # missing item returned (slug, None, (None, None, None)) instead
            # of the documented (None, None, None) sentinel.
            return (slug, item, origin) if item else (None, None, None)
    return None, None, None


def translate_mgsm(languages):
    """Machine-translate the English MGSM test split into every language that
    lacks a human translation, then publish each translation to the Hub and
    mirror it to a local JSON file.

    Args:
        languages: table with a "bcp_47" column of language codes
                   (accessed via ``languages["bcp_47"].values`` —
                   presumably a pandas DataFrame; TODO confirm with caller).
    """
    # Languages already covered by MGSM or AfriMGSM have human translations
    # and are skipped; also skip anything Google Translate can't handle.
    human_translated = [*tags_mgsm.keys(), *tags_afrimgsm.keys()]
    untranslated = [
        lang
        for lang in languages["bcp_47"].values
        if lang not in human_translated and lang in get_google_supported_languages()
    ]
    # English source split — questions get translated, answers are reused as-is.
    en = _load_dataset(slug_mgsm, subset=tags_mgsm["en"], split="test")
    slug = "fair-forward/gsm-autotranslated"
    for lang in tqdm(untranslated):
        # check if already exists on hub
        try:
            ds_lang = load_dataset(slug, lang, split="test")
        except ValueError:
            # ValueError here means the config doesn't exist yet — translate it.
            print(f"Translating {lang}...")
            # NOTE(review): translate_google appears to return awaitables that
            # are gathered concurrently below — confirm it is an async function.
            questions_tr = [translate_google(q, "en", lang) for q in en["question"]]
            questions_tr = asyncio.run(tqdm_asyncio.gather(*questions_tr))
            # Only the questions are language-specific; answers and solutions
            # are copied unchanged from the English split.
            ds_lang = Dataset.from_dict(
                {
                    "question": questions_tr,
                    "answer": en["answer"],
                    "answer_number": en["answer_number"],
                    "equation_solution": en["equation_solution"],
                }
            )
            ds_lang.push_to_hub(
                slug,
                split="test",
                config_name=lang,
                token=os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
            )
            # Local mirror of what was pushed, for inspection/versioning.
            ds_lang.to_json(
                f"data/translations/mgsm/{lang}.json",
                lines=False,
                force_ascii=False,
                indent=2,
            )