add eval code
- custom_evaluation_tasks.py +650 -0
- custom_evaluation_utils.py +158 -0
- lighteval_eval_config.yaml +45 -0
- run_evals.py +442 -0
- run_train.py +2 -2
    	
        custom_evaluation_tasks.py
    ADDED
    
@@ -0,0 +1,650 @@
# ruff: noqa: F405, F403, F401
"""
Custom evaluation tasks for lighteval

This file generally creates just a TASKS_TABLE and TASKS_GROUPS, which are then imported by LightEval.
"""
import re
from dataclasses import asdict
from typing import Dict, List, Tuple

from custom_evaluation_utils import *
from lighteval.tasks.requests import Doc

# fmt: off
LETTER_INDICES = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"]
# fmt: on

_TASKS_STRINGS: List[Tuple[CustomEvaluationTask, str]] = []
_TASKS: List[CustomEvaluationTask] = []

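# Note on the two registries above (illustrative summary): every section below appends its task
# configs to _TASKS and its (task, run-string) pairs to _TASKS_STRINGS; at the end of the file these
# are flattened into TASKS_TABLE (a list of plain dicts via asdict) and TASKS_GROUPS (group name ->
# comma-separated run strings), e.g. roughly {"all": "custom|hellaswag|0|1,custom|winogrande|0|1,..."}.
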
## COMMON_SENSE_REASONING_TASKS ##
COMMON_SENSE_REASONING_TASKS = [
    CustomEvaluationTask(
        name="hellaswag",
        prompt_function="hellaswag_prompt",
        hf_repo="hellaswag",
        hf_subset="default",
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="winogrande",
        prompt_function="winogrande",
        hf_repo="winogrande",
        hf_subset="winogrande_xl",
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="piqa",
        prompt_function="piqa_harness",
        hf_repo="piqa",
        hf_subset="plain_text",
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="siqa",
        prompt_function="siqa_prompt",
        hf_repo="lighteval/siqa",
        hf_subset="default",
        hf_avail_splits=["train", "validation"],
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="openbookqa",
        prompt_function="openbookqa",
        hf_repo="openbookqa",
        hf_subset="main",
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="arc:easy",
        prompt_function="arc",
        hf_repo="ai2_arc",
        hf_subset="ARC-Easy",
        evaluation_splits=["test"],
        generation_size=1,
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="arc:challenge",
        prompt_function="arc",
        hf_repo="ai2_arc",
        hf_subset="ARC-Challenge",
        evaluation_splits=["test"],
        generation_size=1,
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="commonsense_qa",
        prompt_function="commonsense_qa_prompt",
        hf_repo="commonsense_qa",
        hf_subset="default",
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
]


def commonsense_qa_prompt(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["question"],
        choices=[f" {c}" for c in line["choices"]["text"]],
        gold_index=LETTER_INDICES.index(line["answerKey"].strip()),
        instruction="",
    )
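# For illustration, a hypothetical CommonsenseQA record such as
#   {"question": "Where would you put a plate after washing it?",
#    "choices": {"label": ["A", "B", "C", "D", "E"], "text": ["cupboard", "table", "sink", "floor", "oven"]},
#    "answerKey": "A"}
# would be mapped by commonsense_qa_prompt to Doc(query="Where would you put a plate after washing it?",
# choices=[" cupboard", " table", ...], gold_index=0): the answer letter is translated to a position
# via LETTER_INDICES and each choice keeps a leading space so it tokenizes as a continuation.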


def siqa_prompt(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["context"] + " " + line["question"],
        choices=[f" {c}" for c in [line["answerA"], line["answerB"], line["answerC"]]],
        gold_index=int(line["label"]) - 1,
        instruction="",
    )


def hellaswag_prompt(line, task_name: str = None):
    def preprocess(text):
        """Comes from AiHarness"""
        # text = text.strip()
        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
        text = text.replace(" [title]", ". ")
        text = re.sub("\\[.*?\\]", "", text)
        text = text.replace("  ", " ")
        return text

    ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} "
    return Doc(
        task_name=task_name,
        query=preprocess(line["activity_label"] + ": " + ctx),
        choices=[" " + preprocess(ending) for ending in line["endings"]],
        gold_index=int(line["label"]) if line["label"] != "" else -1,  # -1 for test
        # "metric": "choices_loglikelihood",
    )


# 0 shot for common sense
COMMON_SENSE_REASONING_STRING = [(t, f"custom|{t.name}|0|1") for t in COMMON_SENSE_REASONING_TASKS]
_TASKS_STRINGS.extend(COMMON_SENSE_REASONING_STRING)
_TASKS += COMMON_SENSE_REASONING_TASKS
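# The run strings built above follow lighteval's "suite|task|num_fewshot|flag" convention, so
# "custom|hellaswag|0|1" runs hellaswag from this custom suite zero-shot (the trailing 1 appears to
# be the switch that lets lighteval shrink the few-shot count when the prompt gets too long). The
# world-knowledge, MATH and GSM8K sections below register 5-shot, 4-shot and 8-shot strings instead.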

## WORLD_KNOWLEDGE_TASKS ##

WORLD_KNOWLEDGE_TASKS = [
    CustomEvaluationTask(
        name="trivia_qa",
        prompt_function="triviaqa",
        hf_repo="trivia_qa",
        hf_subset="rc.nocontext",
        metric=[Metrics.quasi_exact_match2],
        generation_size=20,
        stop_sequence=["\n", ".", ","],
    ),
    CustomEvaluationTask(
        name="natural_questions",
        prompt_function="natural_questions_prompt",
        hf_repo="lighteval/natural_questions_clean",
        hf_subset="default",
        metric=[Metrics.quasi_exact_match2],
        generation_size=20,
        stop_sequence=["\n", ".", ","],
    ),
]


def natural_questions_prompt(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["question"] + "?\nAnswer: ",
        choices=[line["short_answers"]],
        gold_index=0,
        instruction="",
    )


WORLD_KNOWLEDGE_STRING = [(t, f"custom|{t.name}|5|1") for t in WORLD_KNOWLEDGE_TASKS]
# WORLD_KNOWLEDGE_STRING = {t: f'custom|{t.name}|0|1' for t in WORLD_KNOWLEDGE_TASKS}
_TASKS_STRINGS.extend(WORLD_KNOWLEDGE_STRING)
_TASKS += WORLD_KNOWLEDGE_TASKS

## Reading comprehension ##

READING_COMP_TASKS = [
    CustomEvaluationTask(
        name="super_glue:boolq",
        prompt_function="boolq_prompt",
        hf_repo="super_glue",
        hf_subset="boolq",
        metric=[Metrics.target_perplexity],
    ),
    CustomEvaluationTask(
        name="quac",
        prompt_function="quac",
        hf_repo="lighteval/quac_helm",
        hf_subset="default",
        metric=[Metrics.quasi_exact_match2],
        generation_size=20,
        stop_sequence=["\n", ".", ","],
    ),
]


def boolq_prompt(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['passage']}\nQuestion: {line['question'].capitalize()}?\nAnswer:",
        choices=[" No", " Yes"],  # Only gold
        gold_index=int(line["label"]),
    )


READING_COMP_STRING = [(t, f"custom|{t.name}|0|1") for t in READING_COMP_TASKS]
_TASKS_STRINGS.extend(READING_COMP_STRING)
_TASKS += READING_COMP_TASKS

## MATH ##
class CustomMathEvaluationTask(CustomEvaluationTask):
    """Custom class for math tasks with all the defaults set"""

    def __init__(
        self,
        name,
        prompt_function="math",
        hf_repo="lighteval/MATH",
        hf_subset=None,
        metric=[Metrics.math_quasi_exact_match],
        hf_avail_splits=None,
        evaluation_splits=["test"],
        few_shots_split=None,
        few_shots_select=None,
        suite=["custom"],
        generation_size=40,
        stop_sequence=None,
        output_regex=None,
        frozen=False,
    ):
        super().__init__(
            name=name,
            prompt_function=prompt_function,
            hf_repo=hf_repo,
            hf_subset=hf_subset,
            metric=metric,
            hf_avail_splits=hf_avail_splits,
            evaluation_splits=evaluation_splits,
            few_shots_split=few_shots_split,
            few_shots_select=few_shots_select,
            suite=suite,
            generation_size=generation_size,
            stop_sequence=stop_sequence,
            output_regex=output_regex,
            frozen=frozen,
        )


MATH_TASKS = [
    CustomMathEvaluationTask(name="math:algebra", hf_subset="algebra"),
    CustomMathEvaluationTask(name="math:counting_and_probability", hf_subset="counting_and_probability"),
    CustomMathEvaluationTask(name="math:geometry", hf_subset="geometry"),
    CustomMathEvaluationTask(name="math:intermediate_algebra", hf_subset="intermediate_algebra"),
    CustomMathEvaluationTask(name="math:number_theory", hf_subset="number_theory"),
    CustomMathEvaluationTask(name="math:prealgebra", hf_subset="prealgebra"),
    CustomMathEvaluationTask(name="math:precalculus", hf_subset="precalculus"),
]
GSM8K = CustomEvaluationTask(
    name="gsm8k",
    prompt_function="gsm8k",
    hf_repo="gsm8k",
    hf_subset="main",
    hf_avail_splits=["train", "test"],
    evaluation_splits=["test"],
    metric=[Metrics.perfect_exact_match],
    generation_size=10,
    stop_sequence=["\n"],
)


MATH_STRING = [(t, f"custom|{t.name}|4|1") for t in MATH_TASKS]
GSM8K_STRING = [(GSM8K, f"custom|{GSM8K.name}|8|1")]
_TASKS_STRINGS.extend(MATH_STRING)
_TASKS_STRINGS.extend(GSM8K_STRING)
_TASKS += MATH_TASKS + [GSM8K]

## MMLU ##
class CustomMMLUEvaluationTask(CustomEvaluationTask):
    def __init__(
        self,
        name,
        prompt_function="mmlu_prompt",
        hf_repo="lighteval/mmlu",
        hf_subset=None,
        #  metric=[Metrics.loglikelihood_acc_single_token],
        metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
        hf_avail_splits=None,
        evaluation_splits=["test"],
        few_shots_split="dev",
        few_shots_select=None,
        suite=None,
        generation_size=-1,
        stop_sequence=None,
        output_regex=None,
        frozen=False,
    ):
        super().__init__(
            name=name,
            prompt_function=prompt_function,
            hf_repo=hf_repo,
            hf_subset=hf_subset,
            metric=metric,
            hf_avail_splits=hf_avail_splits,
            evaluation_splits=evaluation_splits,
            few_shots_split=few_shots_split,
            few_shots_select=few_shots_select,
            suite=suite,
            generation_size=generation_size,
            stop_sequence=stop_sequence,
            output_regex=output_regex,
            frozen=frozen,
        )

MMLU_TASKS = [
    CustomMMLUEvaluationTask(name="mmlu:abstract_algebra", hf_subset="abstract_algebra"),
    CustomMMLUEvaluationTask(name="mmlu:anatomy", hf_subset="anatomy"),
    CustomMMLUEvaluationTask(name="mmlu:astronomy", hf_subset="astronomy"),
    CustomMMLUEvaluationTask(name="mmlu:business_ethics", hf_subset="business_ethics"),
    CustomMMLUEvaluationTask(name="mmlu:clinical_knowledge", hf_subset="clinical_knowledge"),
    CustomMMLUEvaluationTask(name="mmlu:college_biology", hf_subset="college_biology"),
    CustomMMLUEvaluationTask(name="mmlu:college_chemistry", hf_subset="college_chemistry"),
    CustomMMLUEvaluationTask(name="mmlu:college_computer_science", hf_subset="college_computer_science"),
    CustomMMLUEvaluationTask(name="mmlu:college_mathematics", hf_subset="college_mathematics"),
    CustomMMLUEvaluationTask(name="mmlu:college_medicine", hf_subset="college_medicine"),
    CustomMMLUEvaluationTask(name="mmlu:college_physics", hf_subset="college_physics"),
    CustomMMLUEvaluationTask(name="mmlu:computer_security", hf_subset="computer_security"),
    CustomMMLUEvaluationTask(name="mmlu:conceptual_physics", hf_subset="conceptual_physics"),
    CustomMMLUEvaluationTask(name="mmlu:econometrics", hf_subset="econometrics"),
    CustomMMLUEvaluationTask(name="mmlu:electrical_engineering", hf_subset="electrical_engineering"),
    CustomMMLUEvaluationTask(name="mmlu:elementary_mathematics", hf_subset="elementary_mathematics"),
    CustomMMLUEvaluationTask(name="mmlu:formal_logic", hf_subset="formal_logic"),
    CustomMMLUEvaluationTask(name="mmlu:global_facts", hf_subset="global_facts"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_biology", hf_subset="high_school_biology"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_chemistry", hf_subset="high_school_chemistry"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_computer_science", hf_subset="high_school_computer_science"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_european_history", hf_subset="high_school_european_history"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_geography", hf_subset="high_school_geography"),
    CustomMMLUEvaluationTask(
        name="mmlu:high_school_government_and_politics", hf_subset="high_school_government_and_politics"
    ),
    CustomMMLUEvaluationTask(name="mmlu:high_school_macroeconomics", hf_subset="high_school_macroeconomics"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_mathematics", hf_subset="high_school_mathematics"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_microeconomics", hf_subset="high_school_microeconomics"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_physics", hf_subset="high_school_physics"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_psychology", hf_subset="high_school_psychology"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_statistics", hf_subset="high_school_statistics"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_us_history", hf_subset="high_school_us_history"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_world_history", hf_subset="high_school_world_history"),
    CustomMMLUEvaluationTask(name="mmlu:human_aging", hf_subset="human_aging"),
    CustomMMLUEvaluationTask(name="mmlu:human_sexuality", hf_subset="human_sexuality"),
    CustomMMLUEvaluationTask(name="mmlu:international_law", hf_subset="international_law"),
    CustomMMLUEvaluationTask(name="mmlu:jurisprudence", hf_subset="jurisprudence"),
    CustomMMLUEvaluationTask(name="mmlu:logical_fallacies", hf_subset="logical_fallacies"),
    CustomMMLUEvaluationTask(name="mmlu:machine_learning", hf_subset="machine_learning"),
    CustomMMLUEvaluationTask(name="mmlu:management", hf_subset="management"),
    CustomMMLUEvaluationTask(name="mmlu:marketing", hf_subset="marketing"),
    CustomMMLUEvaluationTask(name="mmlu:medical_genetics", hf_subset="medical_genetics"),
    CustomMMLUEvaluationTask(name="mmlu:miscellaneous", hf_subset="miscellaneous"),
    CustomMMLUEvaluationTask(name="mmlu:moral_disputes", hf_subset="moral_disputes"),
    CustomMMLUEvaluationTask(name="mmlu:moral_scenarios", hf_subset="moral_scenarios"),
    CustomMMLUEvaluationTask(name="mmlu:nutrition", hf_subset="nutrition"),
    CustomMMLUEvaluationTask(name="mmlu:philosophy", hf_subset="philosophy"),
    CustomMMLUEvaluationTask(name="mmlu:prehistory", hf_subset="prehistory"),
    CustomMMLUEvaluationTask(name="mmlu:professional_accounting", hf_subset="professional_accounting"),
    CustomMMLUEvaluationTask(name="mmlu:professional_law", hf_subset="professional_law"),
    CustomMMLUEvaluationTask(name="mmlu:professional_medicine", hf_subset="professional_medicine"),
    CustomMMLUEvaluationTask(name="mmlu:professional_psychology", hf_subset="professional_psychology"),
    CustomMMLUEvaluationTask(name="mmlu:public_relations", hf_subset="public_relations"),
    CustomMMLUEvaluationTask(name="mmlu:security_studies", hf_subset="security_studies"),
    CustomMMLUEvaluationTask(name="mmlu:sociology", hf_subset="sociology"),
    CustomMMLUEvaluationTask(name="mmlu:us_foreign_policy", hf_subset="us_foreign_policy"),
    CustomMMLUEvaluationTask(name="mmlu:virology", hf_subset="virology"),
    CustomMMLUEvaluationTask(name="mmlu:world_religions", hf_subset="world_religions"),
]


def mmlu_harness(line, task_name: str = None):
    topic = line["subject"]
    prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n"
    prompt += line["question"] + "\n"
    prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])])
    prompt += "Answer:"

    gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
    "__few_shots" in line and line["__few_shots"] is True  # We are adding few shots

    return Doc(
        task_name=task_name,
        query=prompt,
        choices=[" A", " B", " C", " D"],
        target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix],
        gold_index=gold_ix,
        instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
    )


def mmlu_prompt(line, task_name: str = None):
    """MMLU prompt without letters"""
    topic = line["subject"]
    prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: "
    prompt += line["question"] + "\nAnswer:"

    return Doc(
        task_name=task_name,
        query=prompt,
        choices=[f" {c}" for c in line["choices"]],
        gold_index=line["answer"],
        instruction=f"The following are questions about {topic.replace('_', ' ')}.\n",
    )
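# Illustration with a hypothetical item: for subject "anatomy" and question "What is the largest organ?",
# mmlu_prompt builds the query
#   "The following are questions about anatomy.\nQuestion: What is the largest organ?\nAnswer:"
# and scores the answer texts themselves (" skin", " liver", ...) by loglikelihood, whereas
# mmlu_harness above presents lettered options and scores only the letters " A"-" D".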
| 409 | 
            +
             | 
| 410 | 
            +
             | 
| 411 | 
            +
            # MMLU_STRING = {t: f'custom|{t.name}|5|1' for t in MMLU_TASKS}
         | 
| 412 | 
            +
            MMLU_STRING = [(t, f"custom|{t.name}|0|1") for t in MMLU_TASKS]
         | 
| 413 | 
            +
            _TASKS_STRINGS.extend(MMLU_STRING)
         | 
| 414 | 
            +
            _TASKS += MMLU_TASKS
         | 
| 415 | 
            +
             | 
| 416 | 
            +
            ## BBH ##
         | 
| 417 | 
            +
             | 
| 418 | 
            +
             | 
| 419 | 
            +
            class CustomBBHEvaluationTask(CustomEvaluationTask):
         | 
| 420 | 
            +
                def __init__(
         | 
| 421 | 
            +
                    self,
         | 
| 422 | 
            +
                    name,
         | 
| 423 | 
            +
                    prompt_function="bbh_prompt",
         | 
| 424 | 
            +
                    hf_repo="lighteval/big_bench_hard",
         | 
| 425 | 
            +
                    hf_subset=None,
         | 
| 426 | 
            +
                    metric=[Metrics.exact_match],
         | 
| 427 | 
            +
                    hf_avail_splits=["train"],
         | 
| 428 | 
            +
                    evaluation_splits=["train"],
         | 
| 429 | 
            +
                    few_shots_split="train",
         | 
| 430 | 
            +
                    few_shots_select=None,
         | 
| 431 | 
            +
                    suite=None,
         | 
| 432 | 
            +
                    generation_size=4,
         | 
| 433 | 
            +
                    stop_sequence=None,
         | 
| 434 | 
            +
                    output_regex=None,
         | 
| 435 | 
            +
                    frozen=False,
         | 
| 436 | 
            +
                ):
         | 
| 437 | 
            +
                    super().__init__(
         | 
| 438 | 
            +
                        name=name,
         | 
| 439 | 
            +
                        prompt_function=prompt_function,
         | 
| 440 | 
            +
                        hf_repo=hf_repo,
         | 
| 441 | 
            +
                        hf_subset=hf_subset,
         | 
| 442 | 
            +
                        metric=metric,
         | 
| 443 | 
            +
                        hf_avail_splits=hf_avail_splits,
         | 
| 444 | 
            +
                        evaluation_splits=evaluation_splits,
         | 
| 445 | 
            +
                        few_shots_split=few_shots_split,
         | 
| 446 | 
            +
                        few_shots_select=few_shots_select,
         | 
| 447 | 
            +
                        suite=suite,
         | 
| 448 | 
            +
                        generation_size=generation_size,
         | 
| 449 | 
            +
                        stop_sequence=stop_sequence,
         | 
| 450 | 
            +
                        output_regex=output_regex,
         | 
| 451 | 
            +
                        frozen=frozen,
         | 
| 452 | 
            +
                    )
         | 
| 453 | 
            +
             | 
| 454 | 
            +
             | 
| 455 | 
            +
            BBH_TASKS = [
         | 
| 456 | 
            +
                CustomBBHEvaluationTask(name="bbh:boolean_expressions", hf_subset="boolean_expressions"),
         | 
| 457 | 
            +
                CustomBBHEvaluationTask(name="bbh:causal_judgement", hf_subset="causal_judgement"),
         | 
| 458 | 
            +
                CustomBBHEvaluationTask(name="bbh:date_understanding", hf_subset="date_understanding"),
         | 
| 459 | 
            +
                CustomBBHEvaluationTask(name="bbh:disambiguation_qa", hf_subset="disambiguation_qa"),
         | 
| 460 | 
            +
                CustomBBHEvaluationTask(name="bbh:dyck_languages", hf_subset="dyck_languages"),
         | 
| 461 | 
            +
                CustomBBHEvaluationTask(name="bbh:formal_fallacies", hf_subset="formal_fallacies"),
         | 
| 462 | 
            +
                CustomBBHEvaluationTask(name="bbh:geometric_shapes", hf_subset="geometric_shapes"),
         | 
| 463 | 
            +
                CustomBBHEvaluationTask(name="bbh:hyperbaton", hf_subset="hyperbaton"),
         | 
| 464 | 
            +
                CustomBBHEvaluationTask(name="bbh:logical_deduction_five_objects", hf_subset="logical_deduction_five_objects"),
         | 
| 465 | 
            +
                CustomBBHEvaluationTask(name="bbh:logical_deduction_seven_objects", hf_subset="logical_deduction_seven_objects"),
         | 
| 466 | 
            +
                CustomBBHEvaluationTask(name="bbh:logical_deduction_three_objects", hf_subset="logical_deduction_three_objects"),
         | 
| 467 | 
            +
                CustomBBHEvaluationTask(name="bbh:movie_recommendation", hf_subset="movie_recommendation"),
         | 
| 468 | 
            +
                CustomBBHEvaluationTask(name="bbh:multistep_arithmetic_two", hf_subset="multistep_arithmetic_two"),
         | 
| 469 | 
            +
                CustomBBHEvaluationTask(name="bbh:navigate", hf_subset="navigate"),
         | 
| 470 | 
            +
                CustomBBHEvaluationTask(name="bbh:object_counting", hf_subset="object_counting"),
         | 
| 471 | 
            +
                CustomBBHEvaluationTask(name="bbh:penguins_in_a_table", hf_subset="penguins_in_a_table"),
         | 
| 472 | 
            +
                CustomBBHEvaluationTask(name="bbh:reasoning_about_colored_objects", hf_subset="reasoning_about_colored_objects"),
         | 
| 473 | 
            +
                CustomBBHEvaluationTask(name="bbh:ruin_names", hf_subset="ruin_names"),
         | 
| 474 | 
            +
                CustomBBHEvaluationTask(
         | 
| 475 | 
            +
                    name="bbh:salient_translation_error_detection", hf_subset="salient_translation_error_detection"
         | 
| 476 | 
            +
                ),
         | 
| 477 | 
            +
                CustomBBHEvaluationTask(name="bbh:snarks", hf_subset="snarks"),
         | 
| 478 | 
            +
                CustomBBHEvaluationTask(name="bbh:sports_understanding", hf_subset="sports_understanding"),
         | 
| 479 | 
            +
                CustomBBHEvaluationTask(name="bbh:temporal_sequences", hf_subset="temporal_sequences"),
         | 
| 480 | 
            +
                CustomBBHEvaluationTask(
         | 
| 481 | 
            +
                    name="bbh:tracking_shuffled_objects_five_objects", hf_subset="tracking_shuffled_objects_five_objects"
         | 
| 482 | 
            +
                ),
         | 
| 483 | 
            +
                CustomBBHEvaluationTask(
         | 
| 484 | 
            +
                    name="bbh:tracking_shuffled_objects_seven_objects", hf_subset="tracking_shuffled_objects_seven_objects"
         | 
| 485 | 
            +
                ),
         | 
| 486 | 
            +
                CustomBBHEvaluationTask(
         | 
| 487 | 
            +
                    name="bbh:tracking_shuffled_objects_three_objects", hf_subset="tracking_shuffled_objects_three_objects"
         | 
| 488 | 
            +
                ),
         | 
| 489 | 
            +
                CustomBBHEvaluationTask(name="bbh:web_of_lies", hf_subset="web_of_lies"),
         | 
| 490 | 
            +
                CustomBBHEvaluationTask(name="bbh:word_sorting", hf_subset="word_sorting"),
         | 
| 491 | 
            +
            ]
         | 
| 492 | 
            +
             | 
| 493 | 
            +
             | 
| 494 | 
            +
            def bbh_prompt(line, task_name: str = None):
         | 
| 495 | 
            +
                return Doc(
         | 
| 496 | 
            +
                    task_name=task_name,
         | 
| 497 | 
            +
                    query=line["input"] + "\nAnswer: ",
         | 
| 498 | 
            +
                    choices=[line["target"]],
         | 
| 499 | 
            +
                    gold_index=0,
         | 
| 500 | 
            +
                )
         | 
| 501 | 
            +
             | 
| 502 | 
            +
             | 
| 503 | 
            +
            # BBH_STRING = {t: f'custom|{t.name}|3|1' for t in BBH_TASKS}
         | 
| 504 | 
            +
            BBH_STRING = [(t, f"custom|{t.name}|0|1") for t in BBH_TASKS]
         | 
| 505 | 
            +
            _TASKS_STRINGS.extend(BBH_STRING)
         | 
| 506 | 
            +
            _TASKS += BBH_TASKS
         | 
| 507 | 
            +
             | 
| 508 | 
            +
             | 
| 509 | 
            +
            ## AGI eval ##
         | 
| 510 | 
            +
            class CustomAGIEvalEvaluationTask(CustomEvaluationTask):
         | 
| 511 | 
            +
                def __init__(
         | 
| 512 | 
            +
                    self,
         | 
| 513 | 
            +
                    name,
         | 
| 514 | 
            +
                    prompt_function="agi_eval_prompt_no_letters",
         | 
| 515 | 
            +
                    hf_repo="lighteval/agi_eval_en",
         | 
| 516 | 
            +
                    hf_subset=None,
         | 
| 517 | 
            +
                    #  metric=[Metrics.loglikelihood_acc_single_token],
         | 
| 518 | 
            +
                    metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
         | 
| 519 | 
            +
                    hf_avail_splits=["train", "validation"],
         | 
| 520 | 
            +
                    evaluation_splits=["train"],
         | 
| 521 | 
            +
                    few_shots_split="validation",
         | 
| 522 | 
            +
                    few_shots_select=None,
         | 
| 523 | 
            +
                    suite=None,
         | 
| 524 | 
            +
                    generation_size=-1,
         | 
| 525 | 
            +
                    stop_sequence=None,
         | 
| 526 | 
            +
                    output_regex=None,
         | 
| 527 | 
            +
                    frozen=False,
         | 
| 528 | 
            +
                ):
         | 
| 529 | 
            +
                    super().__init__(
         | 
| 530 | 
            +
                        name=name,
         | 
| 531 | 
            +
                        prompt_function=prompt_function,
         | 
| 532 | 
            +
                        hf_repo=hf_repo,
         | 
| 533 | 
            +
                        hf_subset=hf_subset,
         | 
| 534 | 
            +
                        metric=metric,
         | 
| 535 | 
            +
                        hf_avail_splits=hf_avail_splits,
         | 
| 536 | 
            +
                        evaluation_splits=evaluation_splits,
         | 
| 537 | 
            +
                        few_shots_split=few_shots_split,
         | 
| 538 | 
            +
                        few_shots_select=few_shots_select,
         | 
| 539 | 
            +
                        suite=suite,
         | 
| 540 | 
            +
                        generation_size=generation_size,
         | 
| 541 | 
            +
                        stop_sequence=stop_sequence,
         | 
| 542 | 
            +
                        output_regex=output_regex,
         | 
| 543 | 
            +
                        frozen=frozen,
         | 
| 544 | 
            +
                    )
         | 
| 545 | 
            +
             | 
| 546 | 
            +
             | 
| 547 | 
            +
            AGIEVAL_TASKS = [
         | 
| 548 | 
            +
                CustomAGIEvalEvaluationTask(name="agi_eval:aqua_rat", hf_subset="aqua_rat"),
         | 
| 549 | 
            +
                CustomAGIEvalEvaluationTask(name="agi_eval:logiqa-en", hf_subset="logiqa-en"),
         | 
| 550 | 
            +
                CustomAGIEvalEvaluationTask(name="agi_eval:lsat-ar", hf_subset="lsat-ar"),
         | 
| 551 | 
            +
                CustomAGIEvalEvaluationTask(name="agi_eval:lsat-lr", hf_subset="lsat-lr"),
         | 
| 552 | 
            +
                CustomAGIEvalEvaluationTask(name="agi_eval:lsat-rc", hf_subset="lsat-rc"),
         | 
| 553 | 
            +
                CustomAGIEvalEvaluationTask(
         | 
| 554 | 
            +
                    name="agi_eval:math",
         | 
| 555 | 
            +
                    hf_subset="math",
         | 
| 556 | 
            +
                    prompt_function="agi_eval_math_prompt",
         | 
| 557 | 
            +
                    metric=[Metrics.exact_match, Metrics.quasi_exact_match2],
         | 
| 558 | 
            +
                    generation_size=40,
         | 
| 559 | 
            +
                ),
         | 
| 560 | 
            +
                CustomAGIEvalEvaluationTask(name="agi_eval:sat-en", hf_subset="sat-en"),
         | 
| 561 | 
            +
    CustomAGIEvalEvaluationTask(name="agi_eval:sat-math", hf_subset="sat-math"),
]


def agi_eval_math_prompt(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["question"],
        choices=[line["answer"]],
        gold_index=0,
        instruction="",
    )


def agi_eval_prompt(line, task_name: str = None):
    cleaned_options = [o.replace("(", "").replace(")", " ") for o in line["options"]]
    prompt = "The following are multiple choice questions (with answers).\n\n"
    prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n"
    prompt += "Answer: "

    choices = LETTER_INDICES[: len(line["options"])]

    output = Doc(
        query=prompt,
        instruction="The following are multiple choice questions (with answers).\n\n",
    )

    if line["label"]:
        output.choices = choices
        output.gold_index = LETTER_INDICES.index(line["label"].strip())
    else:
        output.choices = [line["answer"]]
        output.gold_index = 0

    return output


def agi_eval_prompt_no_letters(line, task_name: str = None):
    cleaned_options = [
        " " + o.replace("(A)", "").replace("(B)", "").replace("(C)", "").replace("(D)", "").replace("(E)", "")
        for o in line["options"]
    ]

    output = Doc(
        query=line["question"],
        choices=cleaned_options,
        gold_index=LETTER_INDICES.index(line["label"].strip()),
        instruction="",
    )

    return output

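As a rough illustration (not part of the commit), here is what `agi_eval_prompt` as defined above produces for one record; the field names mirror the `line[...]` accesses in the function, the question/options/label are invented, and `Doc` and `LETTER_INDICES` are assumed to be the ones already defined earlier in this file:

```python
# Illustrative only: a fake AGIEval-style record run through agi_eval_prompt.
line = {
    "question": "If 2x + 3 = 11, what is x?",
    "options": ["(A)2", "(B)3", "(C)4", "(D)5", "(E)6"],
    "label": "C",
    "answer": "4",
}

doc = agi_eval_prompt(line, task_name="agi_eval:sat-math")
print(doc.query)       # instruction header, the question, the cleaned options, then "Answer: "
print(doc.choices)     # letter choices, e.g. ["A", "B", "C", "D", "E"]
print(doc.gold_index)  # 2, i.e. LETTER_INDICES.index("C")
```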
# AGIEVAL_STRING = {t: f'custom|{t.name}|5|1' for t in AGIEVAL_TASKS}
AGIEVAL_STRING = [(t, f"custom|{t.name}|0|1") for t in AGIEVAL_TASKS]
_TASKS_STRINGS.extend(AGIEVAL_STRING)
_TASKS += AGIEVAL_TASKS


## HUMAN EVAL ##
# human_eval = CustomEvaluationTask(
#         name="human_eval",
#         prompt_function="human_eval",
#         hf_repo="lighteval/human_eval",
#         metric=["human_eval_pass_at_1"],
#     ),


def has_generative_metrics(task: CustomEvaluationTask) -> bool:
    for metric in task.metric:
        if metric in NEEDS_GENERATION_ONLY:
            return True
    return False


EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING])

# Convert to dict for lighteval
TASKS_TABLE = [asdict(task) for task in _TASKS]
# You can have a few pre-organised groups of tasks
TASKS_GROUPS = {
    "all": ",".join(t[1] for t in _TASKS_STRINGS),
    "early-signal": EARLY_SIGNAL_TASKS,
    "non-generatives": ",".join(t for k, t in _TASKS_STRINGS if not has_generative_metrics(k)),
    "generatives": ",".join(t for k, t in _TASKS_STRINGS if has_generative_metrics(k)),
}

if __name__ == "__main__":
    print([t["name"] for t in TASKS_TABLE])
    print(len(TASKS_TABLE))
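A quick, hypothetical sanity check of what this module exports (it assumes the file is importable as `custom_evaluation_tasks`; the trailing `|0|1` in the task strings is taken to be the few-shot count plus a flag, as in the f-strings above):

```python
# Hypothetical check of the exported tables/groups; not part of the diff.
import custom_evaluation_tasks as cet

print(len(cet.TASKS_TABLE))                             # one asdict(task) per task
print(cet.TASKS_GROUPS["early-signal"].split(",")[:3])  # first few early-signal task strings
# AGIEval task strings follow the "custom|<task_name>|0|1" pattern built above.
assert all(s.startswith("custom|") for _, s in cet.AGIEVAL_STRING)
```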
    	
custom_evaluation_utils.py ADDED
@@ -0,0 +1,158 @@
"""
Custom evaluation tasks for lighteval
"""
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional, Tuple, Union


class Metrics(Enum):
    any_target_loglikelihood_acc = auto()
    bert_score = auto()
    bias = auto()
    bits_per_byte = auto()
    bleu = auto()
    bleu_1 = auto()
    bleu_4 = auto()
    byte_perplexity = auto()
    chrf = auto()
    code_eval_APPS = auto()
    code_eval_HE = auto()
    copyright = auto()
    disinformation = auto()
    exact_match = auto()
    exact_set_match = auto()
    extractiveness = auto()
    f1_from_bags = auto()
    f1_quasi = auto()
    f1_sequence = auto()
    f1_set_match = auto()
    faithfulness = auto()
    iou_set_match = auto()
    log_prob = auto()
    loglikelihood_acc = auto()
    loglikelihood_acc_norm = auto()
    loglikelihood_acc_norm_nospace = auto()
    loglikelihood_acc_norm_single_token = auto()
    loglikelihood_acc_single_token = auto()
    loglikelihood_f1 = auto()
    loglikelihood_f1_single_token = auto()
    math_quasi_exact_match = auto()
    mc_taco = auto()
    mcc = auto()
    mcc_single_token = auto()
    mrr = auto()
    mrr_single_token = auto()
    multi_fi_numeric = auto()
    one_choice_loglikelihood_acc = auto()
    perfect_exact_match = auto()
    prediction_perplexity = auto()
    prefix_exact_match = auto()
    prefix_quasi_exact_match = auto()
    quasi_exact_match = auto()
    ranking = auto()
    recall_at_1_single_token = auto()
    recall_at_2_single_token = auto()
    recall_at_1 = auto()
    recall_at_2 = auto()
    rouge = auto()
    rouge_1 = auto()
    rouge_2 = auto()
    rouge_l = auto()
    target_perplexity = auto()
    ter = auto()
    toxicity = auto()
    truthfulqa_mc_metrics = auto()
    word_perplexity = auto()

    def __str__(self):
        return self.name.replace("_at_", "@")


NEEDS_GENERATION_ONLY = [
    "perfect_exact_match",
    "exact_match",
    "quasi_exact_match",
    "quasi_exact_match2",
    "prefix_exact_match",
    "prefix_quasi_exact_match",
    "math_quasi_exact_match",
    "iou_set_match",
    "exact_set_match",
    "f1_sequence",
    "f1_quasi",
    "f1_set_match",
    "f1_from_bags",
    "chrf",
    "ter",
    "rouge",
    "rouge_1",
    "rouge_2",
    "rouge_l",
    "faithfulness",
    "extractiveness",
    "bert_score",
    "bleu",
    "bleu_1",
    "bleu_4",
    "bias",
    "toxicity",
    "code_eval_HE",
    "code_eval_APPS",
    "copyright",
]


@dataclass(unsafe_hash=True)
class CustomEvaluationTask:
    name: str
    prompt_function: str
    hf_repo: str
    hf_subset: str
    metric: Tuple[Union[str, Metrics]]
    hf_avail_splits: Optional[Tuple[str]] = None
    evaluation_splits: Optional[Tuple[str]] = None
    few_shots_split: Optional[str] = None
    few_shots_select: Optional[str] = None
    generation_size: int = -1
    stop_sequence: Optional[Tuple[str]] = None
    output_regex: Optional[str] = None

    frozen: bool = False
    suite: Optional[Tuple[str]] = None  # we use this to know if we should use a custom lighteval or bigcode task

    def __post_init__(self):
        self.metric = [str(m) for m in self.metric]
        if self.suite is None:
            self.suite = ["custom"]
        if self.hf_avail_splits is None:
            self.hf_avail_splits = ["train", "validation", "test"]
        if self.evaluation_splits is None:
            self.evaluation_splits = ["validation"]
        if self.stop_sequence is None:
            self.stop_sequence = ["\n"]

        # Convert list to tuple for hashing
        self.metric = tuple(self.metric)
        self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits else None
        self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits else None
        self.suite = tuple(self.suite) if self.suite else None
        self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence else None


@dataclass(unsafe_hash=True)
class BigCodeEvaluationTask:
    name: str
    bigcode_task: str
    bigcode_task_kwargs: Optional[dict] = None
    n_samples: int = 1
    prefix: Optional[str] = None

    suite: Tuple[str] = None

    def __post_init__(self):
        if self.suite is None:
            self.suite = ("bigcode",)

        # Convert list to tuple for hashing
        self.suite = tuple(self.suite)
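For reference, a minimal sketch (not from the commit) of how `CustomEvaluationTask.__post_init__` fills defaults and converts fields into hashable tuples; the repo, subset and prompt-function names below are placeholders:

```python
from custom_evaluation_utils import CustomEvaluationTask, Metrics

# Placeholder task definition, purely to show the normalization behaviour.
task = CustomEvaluationTask(
    name="hellaswag",
    prompt_function="hellaswag_prompt",   # hypothetical prompt function name
    hf_repo="hellaswag",
    hf_subset="default",
    metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm],
)

# Metrics enum members are stringified, lists become tuples, unset fields get defaults.
print(task.metric)             # ('loglikelihood_acc', 'loglikelihood_acc_norm')
print(task.suite)              # ('custom',)
print(task.evaluation_splits)  # ('validation',)
print(task.stop_sequence)      # ('\n',)
```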
    	
lighteval_eval_config.yaml ADDED
@@ -0,0 +1,45 @@
checkpoints: null
data: null
experiment_logger: null
general: null
kill_switch_path: null
lighteval:
  batch_size: 24
  checkpoints_path: null
  generation: null
  logging:
    hub_repo_details: null
    hub_repo_results: null
    hub_repo_tensorboard: HuggingFaceBR4/thomwolf-nanotron-mistral-7b
    local_output_path: /scratch/thomwolf/lighteval/nanotron-mistral-7b
    push_details_to_hub: false
    push_results_to_hub: false
    push_results_to_tensorboard: true
    tensorboard_metric_prefix: e
  parallelism:
    dp: 4
    pp: 1
    pp_engine: 1f1b
    recompute_granularity: null
    tp: 2
    tp_linear_async_communication: false
    tp_mode: ALL_REDUCE
  slurm: null
  slurm_script_dir: null
  slurm_template: null
  tasks:
    custom_tasks_file: ./custom_evaluation_tasks.py
    dataset_loading_processes: 8
    max_samples: 1000
    multichoice_continuations_start_space: null
    no_multichoice_continuations_start_space: null
    num_fewshot_seeds: null
    tasks: early-signal
logging: null
model: null
optimizer: null
parallelism: null
profiler: null
s3_upload: null
tokenizer: null
tokens: null
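A small sketch of reading this override file with plain PyYAML, just to check the fields run_evals.py cares about; the actual script loads it through the nanotron/brrr config classes instead:

```python
import yaml  # assumes PyYAML is available

with open("lighteval_eval_config.yaml") as f:
    cfg = yaml.safe_load(f)

lighteval = cfg["lighteval"]
print(lighteval["tasks"]["tasks"])   # "early-signal", resolved via TASKS_GROUPS above
par = lighteval["parallelism"]
# dp * pp * tp = 4 * 1 * 2 = 8 ranks, matching torchrun --nproc_per_node=8 below
assert par["dp"] * par["pp"] * par["tp"] == 8
```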
    	
run_evals.py ADDED
@@ -0,0 +1,442 @@
"""
Nanotron Inference Script

Usage:
```
export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
torchrun --nproc_per_node=8 run_evals.py --checkpoint-config-path ./pretrained/Mistral-7B-v0.1/config.yaml \
    --lighteval-override ./lighteval_eval_config.yaml
```
"""
# flake8: noqa: C901
import argparse
import os
import random
import time
from dataclasses import asdict
from pathlib import Path

import numpy as np
import torch
from huggingface_hub import HFSummaryWriter
from lighteval.evaluator import evaluate, make_results_table
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.hierarchical_logger import hlog, htrack, htrack_block
from lighteval.logging.info_loggers import (
    DetailsLogger,
)
from lighteval.models.model_loader import ModelInfo
from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
from lighteval.tasks.registry import Registry, get_custom_tasks, taskinfo_selector
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import get_config_from_file
from nanotron.logging import get_logger, log_rank
from nanotron.parallel.context import ParallelContext
from nanotron.utils import local_ranks_zero_first

from brrr.config import BrrrConfig
from brrr.experiment_loggers import flatten_dict, obj_to_markdown
from brrr.s3_checkpoints import fs_copy
from brrr.utils import check_env

from lighteval.models.brrr_models import BRRRModel

from modeling_mistral import MistralForTraining
from config_mistral import MistralConfig

logger = get_logger(__name__)

TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = os.getenv("HF_HOME", "/scratch")


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint-config-path",
        type=str,
        required=True,
        help="Path to the brrr checkpoint YAML or python config file, potentially on S3",
    )
    parser.add_argument(
        "--lighteval-override",
        type=str,
        help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Local or hub path of an optional tokenizer (if not indicated in the checkpoint)",
    )
    parser.add_argument(
        "--s5cmd-path",
        type=str,
        default="/admin/home/thomwolf/miniconda3/envs/b4r/bin/s5cmd",
        help="Path to s5cmd install",
    )
    parser.add_argument(
        "--s5cmd-numworkers",
        type=int,
        default=64,
        help="s5cmd num workers (optional)",
    )
    parser.add_argument(
        "--s5cmd-concurrency",
        type=int,
        default=10,
        help="s5cmd concurrency (optional)",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default="",
        help="Cache directory",
    )

    return parser

def push_results_to_wandb(  # noqa: C901
    config: BrrrConfig, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
):
    # config: BrrrConfig = get_config_from_dict(config, config_class=BrrrConfig)
    lighteval_config = config.lighteval
    try:
        global_step = config.general.step
    except ValueError:
        global_step = 0
    if config.lighteval.logging.tensorboard_metric_prefix is not None:
        prefix = config.lighteval.logging.tensorboard_metric_prefix
    else:
        prefix = "eval"
    output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix)
    output_dir_tb.mkdir(parents=True, exist_ok=True)

    os.environ["WANDB_DISABLE_SERVICE"] = "True"
    import wandb

    wandb.tensorboard.patch(root_logdir=config.lighteval.logging.local_output_path)
    hlog("Starting wandb with WANDB_DISABLE_SERVICE=True")
    wandb.init(
        project=config.lighteval.wandb.wandb_project,
        entity=config.lighteval.wandb.wandb_entity,
        name=config.lighteval.wandb.wandb_run_name,
        config=config.as_dict(),
        # sync_tensorboard=True,
        resume=True,
    )
    wb_dict = {}
    bench_averages = {}
    for name, values in results.items():
        splited_name = name.split("|")
        if len(splited_name) == 3:
            _, task_name, _ = splited_name
        else:
            task_name = name
        bench_suite = None
        if ":" in task_name:
            bench_suite = task_name.split(":")[0]  # e.g. MMLU
            hlog(f"bench_suite {bench_suite} in {task_name}")
            for metric, value in values.items():
                if "stderr" in metric:
                    continue
                if bench_suite not in bench_averages:
                    bench_averages[bench_suite] = {}
                bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)]
        hlog(f"Pushing {task_name} {values} to tensorboard")
        for metric, value in values.items():
            if "stderr" in metric:
                wb_dict[f"stderr_{metric}/{task_name}"] = value
            elif bench_suite is not None:
                wb_dict[f"{bench_suite}-{metric}/{task_name}"] = value
            else:
                wb_dict[f"{metric}/{task_name}"] = value
    # e.g. MMLU
    for name, values in bench_averages.items():
        for metric, values in values.items():
            hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard")
            wb_dict[f"{metric}/{name}"] = sum(values) / len(values)

    for task_name, task_details in details.items():
        if len(task_details) <= 1:
            continue
        columns = list(flatten_dict(asdict(task_details[0])).keys())
        table = wandb.Table(columns=columns)
        table.add_data(*[str(v) for v in flatten_dict(asdict(task_details[0])).values()])
        table.add_data(*[str(v) for v in flatten_dict(asdict(task_details[1])).values()])
        wandb.log({f"eval_details_{task_name}": table}, step=global_step, commit=False)

    wandb.log(dict(wb_dict.items()), step=global_step, commit=True)

    # tb_context.add_text("eval_sizes", obj_to_markdown(sizes), global_step=global_step)

    # We are doing parallel evaluations of multiple checkpoints and recording the steps not in order
    # This messes up tensorboard, so the easiest fix is to rename files in the order of the checkpoints
    # See: https://github.com/tensorflow/tensorboard/issues/5958
    # But tensorboardX doesn't let us control the prefix of the files (only the suffix), so we need to do it ourselves before committing the files

    hlog(f"Pushed to wandb" f" at {output_dir_tb} and global_step {global_step}")

def push_results_to_tensorboard(  # noqa: C901
    config: BrrrConfig, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
):
    # config: BrrrConfig = get_config_from_dict(config, config_class=BrrrConfig)
    lighteval_config = config.lighteval
    try:
        global_step = config.general.step
    except ValueError:
        global_step = 0
    if config.lighteval.logging.tensorboard_metric_prefix is not None:
        prefix = config.lighteval.logging.tensorboard_metric_prefix
    else:
        prefix = "eval"
    output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix)
    output_dir_tb.mkdir(parents=True, exist_ok=True)
    tb_context = HFSummaryWriter(
        logdir=str(output_dir_tb),
        repo_id=lighteval_config.logging.hub_repo_tensorboard,
        repo_private=True,
        path_in_repo="tb",
        commit_every=6000,  # Very long time so that we can change our files names and trigger push ourselves (see below)
    )
    bench_averages = {}
    for name, values in results.items():
        splited_name = name.split("|")
        if len(splited_name) == 3:
            _, task_name, _ = splited_name
        else:
            task_name = name
        bench_suite = None
        if ":" in task_name:
            bench_suite = task_name.split(":")[0]  # e.g. MMLU
            hlog(f"bench_suite {bench_suite} in {task_name}")
            for metric, value in values.items():
                if "stderr" in metric:
                    continue
                if bench_suite not in bench_averages:
                    bench_averages[bench_suite] = {}
                bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)]
        hlog(f"Pushing {task_name} {values} to tensorboard")
        for metric, value in values.items():
            if "stderr" in metric:
                tb_context.add_scalar(f"stderr_{prefix}/{task_name}/{metric}", value, global_step=global_step)
            elif bench_suite is not None:
                tb_context.add_scalar(f"{prefix}_{bench_suite}/{task_name}/{metric}", value, global_step=global_step)
            else:
                tb_context.add_scalar(f"{prefix}/{task_name}/{metric}", value, global_step=global_step)
    # e.g. MMLU
    for name, values in bench_averages.items():
        for metric, values in values.items():
            hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard")
            tb_context.add_scalar(f"{prefix}/{name}/{metric}", sum(values) / len(values), global_step=global_step)

    tb_context.add_text("eval_config", obj_to_markdown(results), global_step=global_step)
    # tb_context.add_text("eval_sizes", obj_to_markdown(sizes), global_step=global_step)

    for task_name, task_details in details.items():
        tb_context.add_text(
            f"eval_details_{task_name}",
            obj_to_markdown({"0": task_details[0], "1": task_details[1] if len(task_details) > 1 else {}}),
            global_step=global_step,
        )

    # We are doing parallel evaluations of multiple checkpoints and recording the steps not in order
    # This messes up tensorboard, so the easiest fix is to rename files in the order of the checkpoints
    # See: https://github.com/tensorflow/tensorboard/issues/5958
    # But tensorboardX doesn't let us control the prefix of the files (only the suffix), so we need to do it ourselves before committing the files

    tb_context.close()  # flushes the unfinished write operations
    time.sleep(5)
    files = os.listdir(output_dir_tb)
    for file in files:
        os.rename(os.path.join(output_dir_tb, file), os.path.join(output_dir_tb, f"{global_step:07d}_{file}"))

    # Now we can push to the hub
    tb_context.scheduler.trigger()
    hlog(
        f"Pushed to tensorboard at https://huggingface.co/tensorboard/{lighteval_config.logging.hub_repo_tensorboard}/"
        f" at {output_dir_tb} and global_step {global_step}"
    )

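To make the metric-name handling in the two push functions above concrete, here is a toy aside (not part of run_evals.py) with invented scores, assuming lighteval keys results as "suite|task|num_fewshot", which is what the `name.split("|")` logic implies:

```python
# Toy results dict: tasks with a ":" belong to a benchmark suite and get averaged.
results = {
    "custom|mmlu:abstract_algebra|0": {"acc": 0.31, "acc_stderr": 0.02},
    "custom|mmlu:anatomy|0": {"acc": 0.45, "acc_stderr": 0.03},
    "custom|hellaswag|0": {"acc_norm": 0.55},
}

bench_averages = {}
for name, values in results.items():
    _, task_name, _ = name.split("|")
    if ":" in task_name:                       # "mmlu:anatomy" -> suite "mmlu"
        bench_suite = task_name.split(":")[0]
        for metric, value in values.items():
            if "stderr" in metric:             # stderr values are logged but never averaged
                continue
            bench_averages.setdefault(bench_suite, {}).setdefault(metric, []).append(float(value))

print({s: {m: sum(v) / len(v) for m, v in d.items()} for s, d in bench_averages.items()})
# -> roughly {'mmlu': {'acc': 0.38}}
```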
| 264 | 
            +
            @htrack()
         | 
| 265 | 
            +
            def main(args):
         | 
| 266 | 
            +
                cache_dir = args.cache_dir or CACHE_DIR
         | 
| 267 | 
            +
                check_env()
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                dist.initialize_torch_distributed()
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                with htrack_block("get config"):
         | 
| 272 | 
            +
                    if not args.checkpoint_config_path.endswith(".yaml"):
         | 
| 273 | 
            +
                        raise ValueError("The checkpoint path should point to a YAML file")
         | 
| 274 | 
            +
                    local_config_path = args.checkpoint_config_path
         | 
| 275 | 
            +
                    if args.checkpoint_config_path.startswith("s3:/"):
         | 
| 276 | 
            +
                        local_config_path = args.checkpoint_config_path.replace("s3:/", cache_dir)
         | 
| 277 | 
            +
                        with local_ranks_zero_first():
         | 
| 278 | 
            +
                            if os.environ.get("LOCAL_RANK", None) == "0":
         | 
| 279 | 
            +
                                os.makedirs(os.path.dirname(local_config_path), exist_ok=True)
         | 
| 280 | 
            +
                                fs_copy(args.checkpoint_config_path, local_config_path)
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    brrr_config: BrrrConfig = get_config_from_file(local_config_path, config_class=BrrrConfig, model_config_class=MistralConfig)
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                    if args.lighteval_override:
         | 
| 285 | 
            +
                        local_override_path = args.lighteval_override.replace("s3:/", cache_dir)
         | 
| 286 | 
            +
                        if args.lighteval_override.startswith("s3:/"):
         | 
| 287 | 
            +
                            local_override_path = args.lighteval_override.replace("s3:/", cache_dir)
         | 
| 288 | 
            +
                            with local_ranks_zero_first():
         | 
| 289 | 
            +
                                if os.environ.get("LOCAL_RANK", None) == "0":
         | 
| 290 | 
            +
                                    os.makedirs(os.path.dirname(local_override_path), exist_ok=True)
         | 
| 291 | 
            +
                                    fs_copy(args.lighteval_override, local_override_path)
         | 
| 292 | 
            +
                        lighteval_brrr_config: BrrrConfig = get_config_from_file(local_override_path, config_class=BrrrConfig)
         | 
| 293 | 
            +
                        lighteval_config = lighteval_brrr_config.lighteval
         | 
| 294 | 
            +
                        brrr_config.lighteval = lighteval_config
         | 
| 295 | 
            +
                    else:
         | 
| 296 | 
            +
                        local_override_path = ""
         | 
| 297 | 
            +
                        lighteval_config = brrr_config.lighteval
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                    parallel_context = ParallelContext(
         | 
| 300 | 
            +
                        tensor_parallel_size=lighteval_config.parallelism.tp,
         | 
| 301 | 
            +
                        pipeline_parallel_size=lighteval_config.parallelism.pp,
         | 
| 302 | 
            +
                        data_parallel_size=lighteval_config.parallelism.dp,
         | 
| 303 | 
            +
                    )
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                    evaluation_tracker = EvaluationTracker(token=TOKEN)
         | 
| 306 | 
            +
                    evaluation_tracker.general_config_logger.log_args_info(
         | 
| 307 | 
            +
                        num_fewshot_seeds=1,
         | 
| 308 | 
            +
                        override_batch_size=None,
         | 
| 309 | 
            +
                        max_samples=lighteval_config.tasks.max_samples,
         | 
| 310 | 
            +
                        job_id=os.environ.get("SLURM_JOB_ID", None),
         | 
| 311 | 
            +
                        config=brrr_config.as_dict(),
         | 
| 312 | 
            +
                    )
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                with htrack_block("Test all gather"):
         | 
| 315 | 
            +
                    hlog("Test gather tensor")
         | 
| 316 | 
            +
                    # Do a first NCCL sync to warmup and try to avoid Timeout after model/data loading
         | 
| 317 | 
            +
                    log_rank(
         | 
| 318 | 
            +
                        f"[TEST] Running NCCL sync for ranks {list(range(parallel_context.world_pg.size()))}",
         | 
| 319 | 
            +
                        logger=logger,
         | 
| 320 | 
            +
                        level=logging.WARNING,
         | 
| 321 | 
            +
                        group=parallel_context.dp_pg,
         | 
| 322 | 
            +
                        rank=0,
         | 
| 323 | 
            +
                    )
         | 
| 324 | 
            +
                    test_tensor = torch.tensor([dist.get_rank(parallel_context.world_pg)], device=torch.device("cuda"))
         | 
| 325 | 
            +
                    test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(parallel_context.world_pg.size())]
         | 
| 326 | 
            +
                    dist.all_gather(test_tensor_list, test_tensor, group=parallel_context.world_pg, async_op=False)
         | 
| 327 | 
            +
                    dist.barrier()
         | 
| 328 | 
            +
                    log_rank(
         | 
| 329 | 
            +
                        f"[TEST] NCCL sync for ranks {[t.item() for t in test_tensor_list]}",
         | 
| 330 | 
            +
                        logger=logger,
         | 
| 331 | 
            +
                        level=logging.WARNING,
         | 
| 332 | 
            +
                        group=parallel_context.dp_pg,
         | 
| 333 | 
            +
                        rank=0,
         | 
| 334 | 
            +
                    )
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                    del test_tensor_list
         | 
| 337 | 
            +
                    del test_tensor
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                with htrack_block("Model loading"):
         | 
| 340 | 
            +
                    # We need to load the model in the main process first to avoid downloading the model multiple times
         | 
| 341 | 
            +
                    model = BRRRModel(
         | 
| 342 | 
            +
                        checkpoint_path=args.checkpoint_config_path.replace("config.yaml", ""),
         | 
| 343 | 
            +
                        model_args=brrr_config.model,
         | 
| 344 | 
            +
                        tokenizer=brrr_config.tokenizer,
         | 
| 345 | 
            +
                        parallel_context=parallel_context,
         | 
| 346 | 
            +
                        parallel_config=lighteval_config.parallelism,
         | 
| 347 | 
            +
                        lighteval_config=lighteval_config,
         | 
| 348 | 
            +
                        batch_size=lighteval_config.batch_size,
         | 
| 349 | 
            +
            cache_dir=os.environ.get("HF_HOME", "/scratch"),
            debug_one_layer_model=False,
            s5cmd_path=args.s5cmd_path,
            s5cmd_numworkers=args.s5cmd_numworkers,
            s5cmd_concurrency=args.s5cmd_concurrency,
            model_class=MistralForTraining
        )
        model_info = ModelInfo(model_name=f"{brrr_config.general.run}/{brrr_config.general.step}")
        evaluation_tracker.general_config_logger.log_model_info(model_info)

    with htrack_block("Tasks loading"):
        with local_ranks_zero_first():
            tasks_selection = lighteval_config.tasks.tasks
            if lighteval_config.tasks.custom_tasks_file:
                _, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks_file)
                if tasks_groups_dict and lighteval_config.tasks.tasks in tasks_groups_dict:
                    tasks_selection = tasks_groups_dict[lighteval_config.tasks.tasks]

            task_names_list, few_shots_dict = taskinfo_selector(tasks_selection)
            task_dict = Registry(cache_dir=cache_dir).get_task_dict(
                task_names_list, custom_tasks_file=lighteval_config.tasks.custom_tasks_file
            )
            # Loading all the dataset in a distributed manner
            LightevalTask.load_datasets(task_dict.values(), lighteval_config.tasks.dataset_loading_processes)

            evaluation_tracker.task_config_logger.log(task_dict)

            hlog("Loading documents, and requests")
            requests, docs = create_requests_from_tasks(
                task_dict=task_dict,
                fewshot_dict=few_shots_dict,
                num_fewshot_seeds=lighteval_config.tasks.num_fewshot_seeds or 1,
                lm=model,
                max_samples=lighteval_config.tasks.max_samples,
                evaluation_tracker=evaluation_tracker,
                use_chat_template=False
            )

    with htrack_block("Setting seeds and waiting for all processes"):
        hlog(f"setting seed to {1234} for random and numpy")
        random.seed(1234)
        np.random.seed(1234)
        dist.barrier()

    with htrack_block("Evaluation"):
        hlog(f"Evaluate on {len(task_names_list)} tasks.")
        evaluation_tracker = evaluate(
            lm=model,
            requests_dict=requests,
            docs=docs,
            task_dict=task_dict,
            override_bs=lighteval_config.batch_size,
            evaluation_tracker=evaluation_tracker,
        )

    if dist.get_rank(parallel_context.world_pg) == 0:
        with htrack_block("Compiling and saving results"):
            evaluation_tracker.general_config_logger.log_end_time()
            evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict, bootstrap_iters=1000)
            evaluation_tracker.details_logger.aggregate()

            if lighteval_config.logging.local_output_path:
                evaluation_tracker.save(
                    output_dir=lighteval_config.logging.local_output_path,
                    push_results_to_hub=lighteval_config.logging.push_results_to_hub,
                    push_details_to_hub=lighteval_config.logging.push_details_to_hub,
                    public=False,
                    push_results_to_tensorboard=lighteval_config.logging.push_results_to_tensorboard,
                )

            if lighteval_config.logging.push_results_to_tensorboard:
                push_results_to_tensorboard(
                    config=brrr_config,
                    results=evaluation_tracker.metrics_logger.metric_aggregated,
                    details=evaluation_tracker.details_logger.details,
                )
            if lighteval_config.wandb is not None:
                push_results_to_wandb(
                    config=brrr_config,
                    results=evaluation_tracker.metrics_logger.metric_aggregated,
                    details=evaluation_tracker.details_logger.details,
                )

            final_dict = evaluation_tracker.generate_final_dict()

        hlog(make_results_table(final_dict))

        return final_dict


if __name__ == "__main__":
    parser = get_parser()
    args, unknowns = parser.parse_known_args()
    main(args)
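For anyone adapting this script to their own custom tasks module: the group lookup in the "Tasks loading" block only needs get_custom_tasks to surface a mapping from a group name to a lighteval task string. Below is a minimal illustrative sketch of such a mapping, assuming the convention of a module-level TASKS_GROUPS dict; the group and task names are placeholders, not taken from this commit.

# Illustrative sketch only (group and task names are placeholders, not from this commit).
# Assumption: get_custom_tasks() surfaces a module-level mapping like this one,
# keyed by group name, with comma-separated "suite|task|fewshot|truncate" strings as values.
TASKS_GROUPS = {
    "early-signal": "custom|hellaswag|0|1,custom|winogrande|0|1",
}

With such a group defined, lighteval_config.tasks.tasks can be set to the group name ("early-signal" here), and the lookup above expands it into the full task selection before taskinfo_selector runs.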
    	
        run_train.py
    CHANGED
    
@@ -8,11 +8,11 @@ torchrun --nproc_per_node=8 run_train.py --config-file config_tiny_mistral.yaml
 ```
 """
 import argparse
+from nanotron.trainer import DistributedTrainer
 
-from config_tiny_mistral import MistralConfig
 from dataloader import get_dataloader
 from modeling_mistral import MistralForTraining
-from nanotron.trainer import DistributedTrainer
+from config_tiny_mistral import MistralConfig
 
 
 def get_args():

