update all
- config_mistral_7b.py +2 -1
- config_mistral_7b.yaml +1 -1
- config_tiny_mistral.py +2 -1
- custom_evaluation_tasks.py +0 -650
- custom_evaluation_utils.py +0 -158
- lighteval_eval_config.yaml +6 -20
- modeling_mistral.py +1 -2
- pretrained/Mistral-7B-v0.1/config.yaml +1 -1
- run_evals.py +11 -394
- run_generate.py +2 -3
config_mistral_7b.py
CHANGED
@@ -66,7 +66,7 @@ PARALLELISM = ParallelismArgs(
 )

 CONFIG = Config(
-    general=GeneralArgs(project="mistralai", run="Mistral-7B-v0.1", seed=42),
+    general=GeneralArgs(project="mistralai", run="Mistral-7B-v0.1", seed=42, step=0),
     checkpoints=None,
     parallelism=PARALLELISM,
     model=ModelArgs(init_method=RandomInit(std=0.025), model_config=MODEL_CONFIG),
@@ -76,6 +76,7 @@ CONFIG = Config(
     tokens=None,
     data=None,
     profiler=None,
+    lighteval=None,
 )

 if __name__ == "__main__":
config_mistral_7b.yaml
CHANGED
@@ -7,7 +7,7 @@ general:
   project: mistralai
   run: Mistral-7B-v0.1
   seed: 42
-  step:
+  step: 0
 logging: null
 model:
   ddp_bucket_cap_mb: 25
config_tiny_mistral.py
CHANGED
@@ -92,7 +92,7 @@ checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints"
 os.makedirs(checkpoints_path, exist_ok=True)

 config = Config(
-    general=GeneralArgs(project="debug", run="tiny_mistral", seed=seed),
+    general=GeneralArgs(project="debug", run="tiny_mistral", seed=seed, step=0),
     checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10),
     parallelism=parallelism,
     model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config),
@@ -102,6 +102,7 @@ config = Config(
     tokens=tokens,
     data=DataArgs(dataset=dataset, seed=seed),
     profiler=None,
+    lighteval=None,
 )

 if __name__ == "__main__":
custom_evaluation_tasks.py
DELETED
@@ -1,650 +0,0 @@
-# ruff: noqa: F405, F403, F401
-"""
-Custom evaluation tasks for lighteval
-
-This file generally create just a TASKS_TABLE and TASKS_GROUPS which are then imported by LightEval.
-"""
-import re
-from dataclasses import asdict
-from typing import Dict, List, Tuple
-
-from custom_evaluation_utils import *
-from lighteval.tasks.requests import Doc
-
-# fmt: off
-LETTER_INDICES = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"]
-# fmt: on
-
-_TASKS_STRINGS: List[Tuple[CustomEvaluationTask, str]] = []
-_TASKS: List[CustomEvaluationTask] = []
-
-## COMMON_SENSE_REASONING_TASKS ##
-COMMON_SENSE_REASONING_TASKS = [
-    CustomEvaluationTask(
-        name="hellaswag",
-        prompt_function="hellaswag_prompt",
-        hf_repo="hellaswag",
-        hf_subset="default",
-        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
-    ),
-    CustomEvaluationTask(
-        name="winogrande",
-        prompt_function="winogrande",
-        hf_repo="winogrande",
-        hf_subset="winogrande_xl",
-        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
-    ),
-    CustomEvaluationTask(
-        name="piqa",
-        prompt_function="piqa_harness",
-        hf_repo="piqa",
-        hf_subset="plain_text",
-        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
-    ),
-    CustomEvaluationTask(
-        name="siqa",
-        prompt_function="siqa_prompt",
-        hf_repo="lighteval/siqa",
-        hf_subset="default",
-        hf_avail_splits=["train", "validation"],
-        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
-    ),
-    CustomEvaluationTask(
-        name="openbookqa",
-        prompt_function="openbookqa",
-        hf_repo="openbookqa",
-        hf_subset="main",
-        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
-    ),
-    CustomEvaluationTask(
-        name="arc:easy",
-        prompt_function="arc",
-        hf_repo="ai2_arc",
-        hf_subset="ARC-Easy",
-        evaluation_splits=["test"],
-        generation_size=1,
-        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
-    ),
-    CustomEvaluationTask(
-        name="arc:challenge",
-        prompt_function="arc",
-        hf_repo="ai2_arc",
-        hf_subset="ARC-Challenge",
-        evaluation_splits=["test"],
-        generation_size=1,
-        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
-    ),
-    CustomEvaluationTask(
-        name="commonsense_qa",
-        prompt_function="commonsense_qa_prompt",
-        hf_repo="commonsense_qa",
-        hf_subset="default",
-        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
-    ),
-]
-
-
-def commonsense_qa_prompt(line, task_name: str = None):
-    return Doc(
-        task_name=task_name,
-        query=line["question"],
-        choices=[f" {c}" for c in line["choices"]["text"]],
-        gold_index=LETTER_INDICES.index(line["answerKey"].strip()),
-        instruction="",
-    )
-
-
-def siqa_prompt(line, task_name: str = None):
-    return Doc(
-        task_name=task_name,
-        query=line["context"] + " " + line["question"],
-        choices=[f" {c}" for c in [line["answerA"], line["answerB"], line["answerC"]]],
-        gold_index=int(line["label"]) - 1,
-        instruction="",
-    )
-
-
-def hellaswag_prompt(line, task_name: str = None):
-    def preprocess(text):
-        """Comes from AiHarness"""
-        # text = text.strip()
-        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
-        text = text.replace(" [title]", ". ")
-        text = re.sub("\\[.*?\\]", "", text)
-        text = text.replace("  ", " ")
-        return text
-
-    ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} "
-    return Doc(
-        task_name=task_name,
-        query=preprocess(line["activity_label"] + ": " + ctx),
-        choices=[" " + preprocess(ending) for ending in line["endings"]],
-        gold_index=int(line["label"]) if line["label"] != "" else -1,  # -1 for test
-        # "metric": "choices_loglikelihood",
-    )
-
-
-# 0 short for common sense
-COMMON_SENSE_REASONING_STRING = [(t, f"custom|{t.name}|0|1") for t in COMMON_SENSE_REASONING_TASKS]
-_TASKS_STRINGS.extend(COMMON_SENSE_REASONING_STRING)
-_TASKS += COMMON_SENSE_REASONING_TASKS
-
-## WORLD_KNOWLEDGE_TASKS ##
-
-WORLD_KNOWLEDGE_TASKS = [
-    CustomEvaluationTask(
-        name="trivia_qa",
-        prompt_function="triviaqa",
-        hf_repo="trivia_qa",
-        hf_subset="rc.nocontext",
-        metric=[Metrics.quasi_exact_match2],
-        generation_size=20,
-        stop_sequence=["\n", ".", ","],
-    ),
-    CustomEvaluationTask(
-        name="natural_questions",
-        prompt_function="natural_questions_prompt",
-        hf_repo="lighteval/natural_questions_clean",
-        hf_subset="default",
-        metric=[Metrics.quasi_exact_match2],
-        generation_size=20,
-        stop_sequence=["\n", ".", ","],
-    ),
-]
-
-
-def natural_questions_prompt(line, task_name: str = None):
-    return Doc(
-        task_name=task_name,
-        query=line["question"] + "?\nAnswer: ",
-        choices=[line["short_answers"]],
-        gold_index=0,
-        instruction="",
-    )
-
-
-WORLD_KNOWLEDGE_STRING = [(t, f"custom|{t.name}|5|1") for t in WORLD_KNOWLEDGE_TASKS]
-# WORLD_KNOWLEDGE_STRING = {t: f'custom|{t.name}|0|1' for t in WORLD_KNOWLEDGE_TASKS}
-_TASKS_STRINGS.extend(WORLD_KNOWLEDGE_STRING)
-_TASKS += WORLD_KNOWLEDGE_TASKS
-
-## Reading comprehension ##
-
-READING_COMP_TASKS = [
-    CustomEvaluationTask(
-        name="super_glue:boolq",
-        prompt_function="boolq_prompt",
-        hf_repo="super_glue",
-        hf_subset="boolq",
-        metric=[Metrics.target_perplexity],
-    ),
-    CustomEvaluationTask(
-        name="quac",
-        prompt_function="quac",
-        hf_repo="lighteval/quac_helm",
-        hf_subset="default",
-        metric=[Metrics.quasi_exact_match2],
-        generation_size=20,
-        stop_sequence=["\n", ".", ","],
-    ),
-]
-
-
-def boolq_prompt(line, task_name: str = None):
-    return Doc(
-        task_name=task_name,
-        query=f"{line['passage']}\nQuestion: {line['question'].capitalize()}?\nAnswer:",
-        choices=[" No", " Yes"],  # Only gold
-        gold_index=int(line["label"]),
-    )
-
-
-READING_COMP_STRING = [(t, f"custom|{t.name}|0|1") for t in READING_COMP_TASKS]
-_TASKS_STRINGS.extend(READING_COMP_STRING)
-_TASKS += READING_COMP_TASKS
-
-
-## MATH ##
-class CustomMathEvaluationTask(CustomEvaluationTask):
-    """Custom class for math tasks with all the defaults set"""
-
-    def __init__(
-        self,
-        name,
-        prompt_function="math",
-        hf_repo="lighteval/MATH",
-        hf_subset=None,
-        metric=[Metrics.math_quasi_exact_match],
-        hf_avail_splits=None,
-        evaluation_splits=["test"],
-        few_shots_split=None,
-        few_shots_select=None,
-        suite=["custom"],
-        generation_size=40,
-        stop_sequence=None,
-        output_regex=None,
-        frozen=False,
-    ):
-        super().__init__(
-            name=name,
-            prompt_function=prompt_function,
-            hf_repo=hf_repo,
-            hf_subset=hf_subset,
-            metric=metric,
-            hf_avail_splits=hf_avail_splits,
-            evaluation_splits=evaluation_splits,
-            few_shots_split=few_shots_split,
-            few_shots_select=few_shots_select,
-            suite=suite,
-            generation_size=generation_size,
-            stop_sequence=stop_sequence,
-            output_regex=output_regex,
-            frozen=frozen,
-        )
-
-
-MATH_TASKS = [
-    CustomMathEvaluationTask(name="math:algebra", hf_subset="algebra"),
-    CustomMathEvaluationTask(name="math:counting_and_probability", hf_subset="counting_and_probability"),
-    CustomMathEvaluationTask(name="math:geometry", hf_subset="geometry"),
-    CustomMathEvaluationTask(name="math:intermediate_algebra", hf_subset="intermediate_algebra"),
-    CustomMathEvaluationTask(name="math:number_theory", hf_subset="number_theory"),
-    CustomMathEvaluationTask(name="math:prealgebra", hf_subset="prealgebra"),
-    CustomMathEvaluationTask(name="math:precalculus", hf_subset="precalculus"),
-]
-GSM8K = CustomEvaluationTask(
-    name="gsm8k",
-    prompt_function="gsm8k",
-    hf_repo="gsm8k",
-    hf_subset="main",
-    hf_avail_splits=["train", "test"],
-    evaluation_splits=["test"],
-    metric=[Metrics.perfect_exact_match],
-    generation_size=10,
-    stop_sequence=["\n"],
-)
-
-
-MATH_STRING = [(t, f"custom|{t.name}|4|1") for t in MATH_TASKS]
-GSM8K_STRING = [(GSM8K, f"custom|{GSM8K.name}|8|1")]
-_TASKS_STRINGS.extend(MATH_STRING)
-_TASKS_STRINGS.extend(GSM8K_STRING)
-_TASKS += MATH_TASKS + [GSM8K]
-
-
-## MMLU ##
-class CustomMMLUEvaluationTask(CustomEvaluationTask):
-    def __init__(
-        self,
-        name,
-        prompt_function="mmlu_prompt",
-        hf_repo="lighteval/mmlu",
-        hf_subset=None,
-        # metric=[Metrics.loglikelihood_acc_single_token],
-        metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
-        hf_avail_splits=None,
-        evaluation_splits=["test"],
-        few_shots_split="dev",
-        few_shots_select=None,
-        suite=None,
-        generation_size=-1,
-        stop_sequence=None,
-        output_regex=None,
-        frozen=False,
-    ):
-        super().__init__(
-            name=name,
-            prompt_function=prompt_function,
-            hf_repo=hf_repo,
-            hf_subset=hf_subset,
-            metric=metric,
-            hf_avail_splits=hf_avail_splits,
-            evaluation_splits=evaluation_splits,
-            few_shots_split=few_shots_split,
-            few_shots_select=few_shots_select,
-            suite=suite,
-            generation_size=generation_size,
-            stop_sequence=stop_sequence,
-            output_regex=output_regex,
-            frozen=frozen,
-        )
-
-
-MMLU_TASKS = [
-    CustomMMLUEvaluationTask(name="mmlu:abstract_algebra", hf_subset="abstract_algebra"),
-    CustomMMLUEvaluationTask(name="mmlu:anatomy", hf_subset="anatomy"),
-    CustomMMLUEvaluationTask(name="mmlu:astronomy", hf_subset="astronomy"),
-    CustomMMLUEvaluationTask(name="mmlu:business_ethics", hf_subset="business_ethics"),
-    CustomMMLUEvaluationTask(name="mmlu:clinical_knowledge", hf_subset="clinical_knowledge"),
-    CustomMMLUEvaluationTask(name="mmlu:college_biology", hf_subset="college_biology"),
-    CustomMMLUEvaluationTask(name="mmlu:college_chemistry", hf_subset="college_chemistry"),
-    CustomMMLUEvaluationTask(name="mmlu:college_computer_science", hf_subset="college_computer_science"),
-    CustomMMLUEvaluationTask(name="mmlu:college_mathematics", hf_subset="college_mathematics"),
-    CustomMMLUEvaluationTask(name="mmlu:college_medicine", hf_subset="college_medicine"),
-    CustomMMLUEvaluationTask(name="mmlu:college_physics", hf_subset="college_physics"),
-    CustomMMLUEvaluationTask(name="mmlu:computer_security", hf_subset="computer_security"),
-    CustomMMLUEvaluationTask(name="mmlu:conceptual_physics", hf_subset="conceptual_physics"),
-    CustomMMLUEvaluationTask(name="mmlu:econometrics", hf_subset="econometrics"),
-    CustomMMLUEvaluationTask(name="mmlu:electrical_engineering", hf_subset="electrical_engineering"),
-    CustomMMLUEvaluationTask(name="mmlu:elementary_mathematics", hf_subset="elementary_mathematics"),
-    CustomMMLUEvaluationTask(name="mmlu:formal_logic", hf_subset="formal_logic"),
-    CustomMMLUEvaluationTask(name="mmlu:global_facts", hf_subset="global_facts"),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_biology", hf_subset="high_school_biology"),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_chemistry", hf_subset="high_school_chemistry"),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_computer_science", hf_subset="high_school_computer_science"),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_european_history", hf_subset="high_school_european_history"),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_geography", hf_subset="high_school_geography"),
-    CustomMMLUEvaluationTask(
-        name="mmlu:high_school_government_and_politics", hf_subset="high_school_government_and_politics"
-    ),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_macroeconomics", hf_subset="high_school_macroeconomics"),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_mathematics", hf_subset="high_school_mathematics"),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_microeconomics", hf_subset="high_school_microeconomics"),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_physics", hf_subset="high_school_physics"),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_psychology", hf_subset="high_school_psychology"),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_statistics", hf_subset="high_school_statistics"),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_us_history", hf_subset="high_school_us_history"),
-    CustomMMLUEvaluationTask(name="mmlu:high_school_world_history", hf_subset="high_school_world_history"),
-    CustomMMLUEvaluationTask(name="mmlu:human_aging", hf_subset="human_aging"),
-    CustomMMLUEvaluationTask(name="mmlu:human_sexuality", hf_subset="human_sexuality"),
-    CustomMMLUEvaluationTask(name="mmlu:international_law", hf_subset="international_law"),
-    CustomMMLUEvaluationTask(name="mmlu:jurisprudence", hf_subset="jurisprudence"),
-    CustomMMLUEvaluationTask(name="mmlu:logical_fallacies", hf_subset="logical_fallacies"),
-    CustomMMLUEvaluationTask(name="mmlu:machine_learning", hf_subset="machine_learning"),
-    CustomMMLUEvaluationTask(name="mmlu:management", hf_subset="management"),
-    CustomMMLUEvaluationTask(name="mmlu:marketing", hf_subset="marketing"),
-    CustomMMLUEvaluationTask(name="mmlu:medical_genetics", hf_subset="medical_genetics"),
-    CustomMMLUEvaluationTask(name="mmlu:miscellaneous", hf_subset="miscellaneous"),
-    CustomMMLUEvaluationTask(name="mmlu:moral_disputes", hf_subset="moral_disputes"),
-    CustomMMLUEvaluationTask(name="mmlu:moral_scenarios", hf_subset="moral_scenarios"),
-    CustomMMLUEvaluationTask(name="mmlu:nutrition", hf_subset="nutrition"),
-    CustomMMLUEvaluationTask(name="mmlu:philosophy", hf_subset="philosophy"),
-    CustomMMLUEvaluationTask(name="mmlu:prehistory", hf_subset="prehistory"),
-    CustomMMLUEvaluationTask(name="mmlu:professional_accounting", hf_subset="professional_accounting"),
-    CustomMMLUEvaluationTask(name="mmlu:professional_law", hf_subset="professional_law"),
-    CustomMMLUEvaluationTask(name="mmlu:professional_medicine", hf_subset="professional_medicine"),
-    CustomMMLUEvaluationTask(name="mmlu:professional_psychology", hf_subset="professional_psychology"),
-    CustomMMLUEvaluationTask(name="mmlu:public_relations", hf_subset="public_relations"),
-    CustomMMLUEvaluationTask(name="mmlu:security_studies", hf_subset="security_studies"),
-    CustomMMLUEvaluationTask(name="mmlu:sociology", hf_subset="sociology"),
-    CustomMMLUEvaluationTask(name="mmlu:us_foreign_policy", hf_subset="us_foreign_policy"),
-    CustomMMLUEvaluationTask(name="mmlu:virology", hf_subset="virology"),
-    CustomMMLUEvaluationTask(name="mmlu:world_religions", hf_subset="world_religions"),
-]
-
-
-def mmlu_harness(line, task_name: str = None):
-    topic = line["subject"]
-    prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n"
-    prompt += line["question"] + "\n"
-    prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])])
-    prompt += "Answer:"
-
-    gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
-    "__few_shots" in line and line["__few_shots"] is True  # We are adding few shots
-
-    return Doc(
-        task_name=task_name,
-        query=prompt,
-        choices=[" A", " B", " C", " D"],
-        target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix],
-        gold_index=gold_ix,
-        instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
-    )
-
-
-def mmlu_prompt(line, task_name: str = None):
-    """MMLU prompt without letters"""
-    topic = line["subject"]
-    prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: "
-    prompt += line["question"] + "\nAnswer:"
-
-    return Doc(
-        task_name=task_name,
-        query=prompt,
-        choices=[f" {c}" for c in line["choices"]],
-        gold_index=line["answer"],
-        instruction=f"The following are questions about {topic.replace('_', ' ')}.\n",
-    )
-
-
-# MMLU_STRING = {t: f'custom|{t.name}|5|1' for t in MMLU_TASKS}
-MMLU_STRING = [(t, f"custom|{t.name}|0|1") for t in MMLU_TASKS]
-_TASKS_STRINGS.extend(MMLU_STRING)
-_TASKS += MMLU_TASKS
-
-## BBH ##
-
-
-class CustomBBHEvaluationTask(CustomEvaluationTask):
-    def __init__(
-        self,
-        name,
-        prompt_function="bbh_prompt",
-        hf_repo="lighteval/big_bench_hard",
-        hf_subset=None,
-        metric=[Metrics.exact_match],
-        hf_avail_splits=["train"],
-        evaluation_splits=["train"],
-        few_shots_split="train",
-        few_shots_select=None,
-        suite=None,
-        generation_size=4,
-        stop_sequence=None,
-        output_regex=None,
-        frozen=False,
-    ):
-        super().__init__(
-            name=name,
-            prompt_function=prompt_function,
-            hf_repo=hf_repo,
-            hf_subset=hf_subset,
-            metric=metric,
-            hf_avail_splits=hf_avail_splits,
-            evaluation_splits=evaluation_splits,
-            few_shots_split=few_shots_split,
-            few_shots_select=few_shots_select,
-            suite=suite,
-            generation_size=generation_size,
-            stop_sequence=stop_sequence,
-            output_regex=output_regex,
-            frozen=frozen,
-        )
-
-
-BBH_TASKS = [
-    CustomBBHEvaluationTask(name="bbh:boolean_expressions", hf_subset="boolean_expressions"),
-    CustomBBHEvaluationTask(name="bbh:causal_judgement", hf_subset="causal_judgement"),
-    CustomBBHEvaluationTask(name="bbh:date_understanding", hf_subset="date_understanding"),
-    CustomBBHEvaluationTask(name="bbh:disambiguation_qa", hf_subset="disambiguation_qa"),
-    CustomBBHEvaluationTask(name="bbh:dyck_languages", hf_subset="dyck_languages"),
-    CustomBBHEvaluationTask(name="bbh:formal_fallacies", hf_subset="formal_fallacies"),
-    CustomBBHEvaluationTask(name="bbh:geometric_shapes", hf_subset="geometric_shapes"),
-    CustomBBHEvaluationTask(name="bbh:hyperbaton", hf_subset="hyperbaton"),
-    CustomBBHEvaluationTask(name="bbh:logical_deduction_five_objects", hf_subset="logical_deduction_five_objects"),
-    CustomBBHEvaluationTask(name="bbh:logical_deduction_seven_objects", hf_subset="logical_deduction_seven_objects"),
-    CustomBBHEvaluationTask(name="bbh:logical_deduction_three_objects", hf_subset="logical_deduction_three_objects"),
-    CustomBBHEvaluationTask(name="bbh:movie_recommendation", hf_subset="movie_recommendation"),
-    CustomBBHEvaluationTask(name="bbh:multistep_arithmetic_two", hf_subset="multistep_arithmetic_two"),
-    CustomBBHEvaluationTask(name="bbh:navigate", hf_subset="navigate"),
-    CustomBBHEvaluationTask(name="bbh:object_counting", hf_subset="object_counting"),
-    CustomBBHEvaluationTask(name="bbh:penguins_in_a_table", hf_subset="penguins_in_a_table"),
-    CustomBBHEvaluationTask(name="bbh:reasoning_about_colored_objects", hf_subset="reasoning_about_colored_objects"),
-    CustomBBHEvaluationTask(name="bbh:ruin_names", hf_subset="ruin_names"),
-    CustomBBHEvaluationTask(
-        name="bbh:salient_translation_error_detection", hf_subset="salient_translation_error_detection"
-    ),
-    CustomBBHEvaluationTask(name="bbh:snarks", hf_subset="snarks"),
-    CustomBBHEvaluationTask(name="bbh:sports_understanding", hf_subset="sports_understanding"),
-    CustomBBHEvaluationTask(name="bbh:temporal_sequences", hf_subset="temporal_sequences"),
-    CustomBBHEvaluationTask(
-        name="bbh:tracking_shuffled_objects_five_objects", hf_subset="tracking_shuffled_objects_five_objects"
-    ),
-    CustomBBHEvaluationTask(
-        name="bbh:tracking_shuffled_objects_seven_objects", hf_subset="tracking_shuffled_objects_seven_objects"
-    ),
-    CustomBBHEvaluationTask(
-        name="bbh:tracking_shuffled_objects_three_objects", hf_subset="tracking_shuffled_objects_three_objects"
-    ),
-    CustomBBHEvaluationTask(name="bbh:web_of_lies", hf_subset="web_of_lies"),
-    CustomBBHEvaluationTask(name="bbh:word_sorting", hf_subset="word_sorting"),
-]
-
-
-def bbh_prompt(line, task_name: str = None):
-    return Doc(
-        task_name=task_name,
-        query=line["input"] + "\nAnswer: ",
-        choices=[line["target"]],
-        gold_index=0,
-    )
-
-
-# BBH_STRING = {t: f'custom|{t.name}|3|1' for t in BBH_TASKS}
-BBH_STRING = [(t, f"custom|{t.name}|0|1") for t in BBH_TASKS]
-_TASKS_STRINGS.extend(BBH_STRING)
-_TASKS += BBH_TASKS
-
-
-## AGI eval ##
-class CustomAGIEvalEvaluationTask(CustomEvaluationTask):
-    def __init__(
-        self,
-        name,
-        prompt_function="agi_eval_prompt_no_letters",
-        hf_repo="lighteval/agi_eval_en",
-        hf_subset=None,
-        # metric=[Metrics.loglikelihood_acc_single_token],
-        metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
-        hf_avail_splits=["train", "validation"],
-        evaluation_splits=["train"],
-        few_shots_split="validation",
-        few_shots_select=None,
-        suite=None,
-        generation_size=-1,
-        stop_sequence=None,
-        output_regex=None,
-        frozen=False,
-    ):
-        super().__init__(
-            name=name,
-            prompt_function=prompt_function,
-            hf_repo=hf_repo,
-            hf_subset=hf_subset,
-            metric=metric,
-            hf_avail_splits=hf_avail_splits,
-            evaluation_splits=evaluation_splits,
-            few_shots_split=few_shots_split,
-            few_shots_select=few_shots_select,
-            suite=suite,
-            generation_size=generation_size,
-            stop_sequence=stop_sequence,
-            output_regex=output_regex,
-            frozen=frozen,
-        )
-
-
-AGIEVAL_TASKS = [
-    CustomAGIEvalEvaluationTask(name="agi_eval:aqua_rat", hf_subset="aqua_rat"),
-    CustomAGIEvalEvaluationTask(name="agi_eval:logiqa-en", hf_subset="logiqa-en"),
-    CustomAGIEvalEvaluationTask(name="agi_eval:lsat-ar", hf_subset="lsat-ar"),
-    CustomAGIEvalEvaluationTask(name="agi_eval:lsat-lr", hf_subset="lsat-lr"),
-    CustomAGIEvalEvaluationTask(name="agi_eval:lsat-rc", hf_subset="lsat-rc"),
-    CustomAGIEvalEvaluationTask(
-        name="agi_eval:math",
-        hf_subset="math",
-        prompt_function="agi_eval_math_prompt",
-        metric=[Metrics.exact_match, Metrics.quasi_exact_match2],
-        generation_size=40,
-    ),
-    CustomAGIEvalEvaluationTask(name="agi_eval:sat-en", hf_subset="sat-en"),
-    CustomAGIEvalEvaluationTask(name="agi_eval:sat-math", hf_subset="sat-math"),
-]
-
-
-def agi_eval_math_prompt(line, task_name: str = None):
-    return Doc(
-        task_name=task_name,
-        query=line["question"],
-        choices=[line["answer"]],
-        gold_index=0,
-        instruction="",
-    )
-
-
-def agi_eval_prompt(line, task_name: str = None):
-    cleaned_options = [o.replace("(", "").replace(")", " ") for o in line["options"]]
-    prompt = "The following are multiple choice questions (with answers).\n\n"
-    prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n"
-    prompt += "Answer: "
-
-    choices = LETTER_INDICES[: len(line["options"])]
-
-    output = Doc(
-        query=prompt,
-        instruction="The following are multiple choice questions (with answers).\n\n",
-    )
-
-    if line["label"]:
-        output.choices = choices
-        output.gold_index = LETTER_INDICES.index(line["label"].strip())
-    else:
-        output.choices = [line["answer"]]
-        output.gold_index = 0
-
-    return output
-
-
-def agi_eval_prompt_no_letters(line, task_name: str = None):
-    cleaned_options = [
-        " " + o.replace("(A)", "").replace("(B)", "").replace("(C)", "").replace("(D)", "").replace("(E)", "")
-        for o in line["options"]
-    ]
-
-    output = Doc(
-        query=line["question"],
-        choices=cleaned_options,
-        gold_index=LETTER_INDICES.index(line["label"].strip()),
-        instruction="",
-    )
-
-    return output
-
-
-# AGIEVAL_STRING = {t: f'custom|{t.name}|5|1' for t in AGIEVAL_TASKS}
-AGIEVAL_STRING = [(t, f"custom|{t.name}|0|1") for t in AGIEVAL_TASKS]
-_TASKS_STRINGS.extend(AGIEVAL_STRING)
-_TASKS += AGIEVAL_TASKS
-
-
-## HUMAN EVAL ##
-# human_eval = CustomEvaluationTask(
-#     name="human_eval",
-#     prompt_function="human_eval",
-#     hf_repo="lighteval/human_eval",
-#     metric=["human_eval_pass_at_1"],
-# ),
-
-
-def has_generative_metrics(task: CustomEvaluationTask) -> bool:
-    for metric in task.metric:
-        if metric in NEEDS_GENERATION_ONLY:
-            return True
-    return False
-
-
-EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING])
-
-# Convert to dict for lighteval
-TASKS_TABLE = [asdict(task) for task in _TASKS]
-# You can have a few pre-organised groups of tasks
-TASKS_GROUPS = {
-    "all": ",".join(t[1] for t in _TASKS_STRINGS),
-    "early-signal": EARLY_SIGNAL_TASKS,
-    "non-generatives": ",".join(t for k, t in _TASKS_STRINGS if not has_generative_metrics(k)),
-    "generatives": ",".join(t for k, t in _TASKS_STRINGS if has_generative_metrics(k)),
-}
-
-if __name__ == "__main__":
-    print(t["name"] for t in TASKS_TABLE)
-    print(len(TASKS_TABLE))
custom_evaluation_utils.py
DELETED
@@ -1,158 +0,0 @@
-"""
-Custom evaluation tasks for lighteval
-"""
-from dataclasses import dataclass
-from enum import Enum, auto
-from typing import Optional, Tuple, Union
-
-
-class Metrics(Enum):
-    any_target_loglikelihood_acc = auto()
-    bert_score = auto()
-    bias = auto()
-    bits_per_byte = auto()
-    bleu = auto()
-    bleu_1 = auto()
-    bleu_4 = auto()
-    byte_perplexity = auto()
-    chrf = auto()
-    code_eval_APPS = auto()
-    code_eval_HE = auto()
-    copyright = auto()
-    disinformation = auto()
-    exact_match = auto()
-    exact_set_match = auto()
-    extractiveness = auto()
-    f1_from_bags = auto()
-    f1_quasi = auto()
-    f1_sequence = auto()
-    f1_set_match = auto()
-    faithfulness = auto()
-    iou_set_match = auto()
-    log_prob = auto()
-    loglikelihood_acc = auto()
-    loglikelihood_acc_norm = auto()
-    loglikelihood_acc_norm_nospace = auto()
-    loglikelihood_acc_norm_single_token = auto()
-    loglikelihood_acc_single_token = auto()
-    loglikelihood_f1 = auto()
-    loglikelihood_f1_single_token = auto()
-    math_quasi_exact_match = auto()
-    mc_taco = auto()
-    mcc = auto()
-    mcc_single_token = auto()
-    mrr = auto()
-    mrr_single_token = auto()
-    multi_fi_numeric = auto()
-    one_choice_loglikelihood_acc = auto()
-    perfect_exact_match = auto()
-    prediction_perplexity = auto()
-    prefix_exact_match = auto()
-    prefix_quasi_exact_match = auto()
-    quasi_exact_match = auto()
-    ranking = auto()
-    recall_at_1_single_token = auto()
-    recall_at_2_single_token = auto()
-    recall_at_1 = auto()
-    recall_at_2 = auto()
-    rouge = auto()
-    rouge_1 = auto()
-    rouge_2 = auto()
-    rouge_l = auto()
-    target_perplexity = auto()
-    ter = auto()
-    toxicity = auto()
-    truthfulqa_mc_metrics = auto()
-    word_perplexity = auto()
-
-    def __str__(self):
-        return self.name.replace("_at_", "@")
-
-
-NEEDS_GENERATION_ONLY = [
-    "perfect_exact_match",
-    "exact_match",
-    "quasi_exact_match",
-    "quasi_exact_match2",
-    "prefix_exact_match",
-    "prefix_quasi_exact_match",
-    "math_quasi_exact_match",
-    "iou_set_match",
-    "exact_set_match",
-    "f1_sequence",
-    "f1_quasi",
-    "f1_set_match",
-    "f1_from_bags",
-    "chrf",
-    "ter",
-    "rouge",
-    "rouge_1",
-    "rouge_2",
-    "rouge_l",
-    "faithfulness",
-    "extractiveness",
-    "bert_score",
-    "bleu",
-    "bleu_1",
-    "bleu_4",
-    "bias",
-    "toxicity",
-    "code_eval_HE",
-    "code_eval_APPS",
-    "copyright",
-]
-
-
-@dataclass(unsafe_hash=True)
-class CustomEvaluationTask:
-    name: str
-    prompt_function: str
-    hf_repo: str
-    hf_subset: str
-    metric: Tuple[Union[str, Metrics]]
-    hf_avail_splits: Optional[Tuple[str]] = None
-    evaluation_splits: Optional[Tuple[str]] = None
-    few_shots_split: Optional[str] = None
-    few_shots_select: Optional[str] = None
-    generation_size: int = -1
-    stop_sequence: Optional[Tuple[str]] = None
-    output_regex: Optional[str] = None
-
-    frozen: bool = False
-    suite: Optional[Tuple[str]] = None  # we use this to know if we should use a custom lighteval or bigcode task
-
-    def __post_init__(self):
-        self.metric = [str(m) for m in self.metric]
-        if self.suite is None:
-            self.suite = ["custom"]
-        if self.hf_avail_splits is None:
-            self.hf_avail_splits = ["train", "validation", "test"]
-        if self.evaluation_splits is None:
-            self.evaluation_splits = ["validation"]
-        if self.stop_sequence is None:
-            self.stop_sequence = ["\n"]
-
-        # Convert list to tuple for hashing
-        self.metric = tuple(self.metric)
-        self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits else None
-        self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits else None
-        self.suite = tuple(self.suite) if self.suite else None
-        self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence else None
-
-
-@dataclass(unsafe_hash=True)
-class BigCodeEvaluationTask:
-    name: str
-    bigcode_task: str
-    bigcode_task_kwargs: Optional[dict] = None
-    n_samples: int = 1
-    prefix: Optional[str] = None
-
-    suite: Tuple[str] = None
-
-    def __post_init__(self):
-        if self.suite is None:
-            self.suite = ("bigcode",)
-
-        # Convert list to tuple for hashing
-        self.suite = tuple(self.suite)
lighteval_eval_config.yaml
CHANGED
@@ -1,10 +1,5 @@
-checkpoints: null
-data: null
-experiment_logger: null
-general: null
-kill_switch_path: null
 lighteval:
-  batch_size:
+  batch_size: 4
   checkpoints_path: null
   generation: null
   logging:
@@ -17,29 +12,20 @@ lighteval:
     push_results_to_tensorboard: true
     tensorboard_metric_prefix: e
   parallelism:
-    dp:
+    dp: 8
     pp: 1
     pp_engine: 1f1b
     recompute_granularity: null
-    tp:
+    tp: 1
     tp_linear_async_communication: false
     tp_mode: ALL_REDUCE
-  slurm: null
   slurm_script_dir: null
   slurm_template: null
   tasks:
-
+    custom_tasks: brrr.lighteval.custom_tasks
     dataset_loading_processes: 8
-    max_samples:
+    max_samples: 10000
     multichoice_continuations_start_space: null
     no_multichoice_continuations_start_space: null
     num_fewshot_seeds: null
-    tasks:
+    tasks: open-llm-leaderboard
-logging: null
-model: null
-optimizer: null
-parallelism: null
-profiler: null
-s3_upload: null
-tokenizer: null
-tokens: null
modeling_mistral.py
CHANGED
@@ -106,7 +106,7 @@ class RotaryEmbedding(nn.Module):
             self.end *= 2
             self._initialized_buffer = False
         if self._initialized_buffer is False:
-            print(f"Initializing rotary embeddings with end={self.end}")
+            # print(f"Initializing rotary embeddings with end={self.end}")
             self.init_rotary_embeddings()
         dtype = x.dtype
         assert inner_dim % 2 == 0
@@ -397,7 +397,6 @@ class CausalSelfAttention(nn.Module, AttachableStore):
             # Double check that we use store only at inference time
             assert key_states.requires_grad is False
             assert value_states.requires_grad is False
-            print("Using store")
             if "position_offsets" in store:
                 old_position_offsets = store["position_offsets"]
                 position_ids = old_position_offsets[:, None] + sequence_mask
pretrained/Mistral-7B-v0.1/config.yaml
CHANGED
@@ -7,7 +7,7 @@ general:
   project: mistralai
   run: Mistral-7B-v0.1
   seed: 42
-  step:
+  step: 0
 logging: null
 model:
   ddp_bucket_cap_mb: 25
run_evals.py
CHANGED
@@ -10,46 +10,12 @@ torchrun --nproc_per_node=8 run_evals.py --checkpoint-config-path ./pretrained/M
 """
 # flake8: noqa: C901
 import argparse
-import os
-import random
-import time
-from dataclasses import asdict
-from pathlib import Path
-
-import numpy as np
-import torch
-from huggingface_hub import HFSummaryWriter
-from lighteval.evaluator import evaluate, make_results_table
-from lighteval.logging.evaluation_tracker import EvaluationTracker
-from lighteval.logging.hierarchical_logger import hlog, htrack, htrack_block
-from lighteval.logging.info_loggers import (
-    DetailsLogger,
-)
-from lighteval.models.model_loader import ModelInfo
-from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
-from lighteval.tasks.registry import Registry, get_custom_tasks, taskinfo_selector
-from nanotron import distributed as dist
-from nanotron import logging
-from nanotron.config import get_config_from_file
-from nanotron.logging import get_logger, log_rank
-from nanotron.parallel.context import ParallelContext
-from nanotron.utils import local_ranks_zero_first
-
-from brrr.config import BrrrConfig
-from brrr.experiment_loggers import flatten_dict, obj_to_markdown
-from brrr.s3_checkpoints import fs_copy
-from brrr.utils import check_env
-
-from lighteval.models.brrr_models import BRRRModel

 from modeling_mistral import MistralForTraining
 from config_mistral import MistralConfig

-
-
-TOKEN = os.getenv("HF_TOKEN")
-CACHE_DIR = os.getenv("HF_HOME", "/scratch")
-

 def get_parser():
     parser = argparse.ArgumentParser()
@@ -69,374 +35,25 @@ def get_parser():
         type=str,
         help="Local or hub path of an optional tokenizer (if not indicated in the checkpoint)",
     )
-    parser.add_argument(
-        "--s5cmd-path",
-        type=str,
-        default="/admin/home/thomwolf/miniconda3/envs/b4r/bin/s5cmd",
-        help="Path to s5cmd install",
-    )
-    parser.add_argument(
-        "--s5cmd-numworkers",
-        type=int,
-        default=64,
-        help="s5cmd num workers (optional)",
-    )
-    parser.add_argument(
-        "--s5cmd-concurrency",
-        type=int,
-        default=10,
-        help="s5cmd concurrency (optional)",
-    )
     parser.add_argument(
         "--cache-dir",
         type=str,
-        default=
         help="Cache directory",
     )

     return parser


-def push_results_to_wandb(  # noqa: C901
-    config: BrrrConfig, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
-):
-    # config: BrrrConfig = get_config_from_dict(config, config_class=BrrrConfig)
-    lighteval_config = config.lighteval
-    try:
-        global_step = config.general.step
-    except ValueError:
-        global_step = 0
-    if config.lighteval.logging.tensorboard_metric_prefix is not None:
-        prefix = config.lighteval.logging.tensorboard_metric_prefix
-    else:
-        prefix = "eval"
-    output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix)
-    output_dir_tb.mkdir(parents=True, exist_ok=True)
-
-    os.environ["WANDB_DISABLE_SERVICE"] = "True"
-    import wandb
-
-    wandb.tensorboard.patch(root_logdir=config.lighteval.logging.local_output_path)
-    hlog("Starting wandb with WANDB_DISABLE_SERVICE=True")
-    wandb.init(
-        project=config.lighteval.wandb.wandb_project,
-        entity=config.lighteval.wandb.wandb_entity,
-        name=config.lighteval.wandb.wandb_run_name,
-        config=config.as_dict(),
-        # sync_tensorboard=True,
-        resume=True,
-    )
-    wb_dict = {}
-    bench_averages = {}
-    for name, values in results.items():
-        splited_name = name.split("|")
-        if len(splited_name) == 3:
-            _, task_name, _ = splited_name
-        else:
-            task_name = name
-        bench_suite = None
-        if ":" in task_name:
-            bench_suite = task_name.split(":")[0]  # e.g. MMLU
-            hlog(f"bench_suite {bench_suite} in {task_name}")
-            for metric, value in values.items():
-                if "stderr" in metric:
-                    continue
-                if bench_suite not in bench_averages:
-                    bench_averages[bench_suite] = {}
-                bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)]
-        hlog(f"Pushing {task_name} {values} to tensorboard")
-        for metric, value in values.items():
-            if "stderr" in metric:
-                wb_dict[f"stderr_{metric}/{task_name}"] = value
-            elif bench_suite is not None:
-                wb_dict[f"{bench_suite}-{metric}/{task_name}"] = value
-            else:
-                wb_dict[f"{metric}/{task_name}"] = value
-    # e.g. MMLU
-    for name, values in bench_averages.items():
-        for metric, values in values.items():
-            hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard")
-            wb_dict[f"{metric}/{name}"] = sum(values) / len(values)
-
-    for task_name, task_details in details.items():
-        if len(task_details) <= 1:
-            continue
-        columns = list(flatten_dict(asdict(task_details[0])).keys())
-        table = wandb.Table(columns=columns)
-        table.add_data(*[str(v) for v in flatten_dict(asdict(task_details[0])).values()])
-        table.add_data(*[str(v) for v in flatten_dict(asdict(task_details[1])).values()])
-        wandb.log({f"eval_details_{task_name}": table}, step=global_step, commit=False)
-
-    wandb.log(dict(wb_dict.items()), step=global_step, commit=True)
-
-    # tb_context.add_text("eval_sizes", obj_to_markdown(sizes), global_step=global_step)
-
-    # We are doing parallel evaluations of multiple checkpoints and recording the steps not in order
-    # This messes up with tensorboard, so the easiest is to rename files in the order of the checkpoints
-    # See: https://github.com/tensorflow/tensorboard/issues/5958
-    # But tensorboardX don't let us control the prefix of the files (only the suffix), so we need to do it ourselves before commiting the files
-
-    hlog(f"Pushed to wandb" f" at {output_dir_tb} and global_step {global_step}")
-
-
-def push_results_to_tensorboard(  # noqa: C901
-    config: BrrrConfig, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
-):
-    # config: BrrrConfig = get_config_from_dict(config, config_class=BrrrConfig)
-    lighteval_config = config.lighteval
-    try:
-        global_step = config.general.step
-    except ValueError:
-        global_step = 0
-    if config.lighteval.logging.tensorboard_metric_prefix is not None:
-        prefix = config.lighteval.logging.tensorboard_metric_prefix
-    else:
-        prefix = "eval"
-    output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix)
-    output_dir_tb.mkdir(parents=True, exist_ok=True)
-    tb_context = HFSummaryWriter(
-        logdir=str(output_dir_tb),
-        repo_id=lighteval_config.logging.hub_repo_tensorboard,
-        repo_private=True,
-        path_in_repo="tb",
-        commit_every=6000,  # Very long time so that we can change our files names and trigger push ourselves (see below)
-    )
-    bench_averages = {}
-    for name, values in results.items():
-        splited_name = name.split("|")
-        if len(splited_name) == 3:
-            _, task_name, _ = splited_name
-        else:
-            task_name = name
-        bench_suite = None
-        if ":" in task_name:
-            bench_suite = task_name.split(":")[0]  # e.g. MMLU
-            hlog(f"bench_suite {bench_suite} in {task_name}")
-            for metric, value in values.items():
-                if "stderr" in metric:
-                    continue
-                if bench_suite not in bench_averages:
-                    bench_averages[bench_suite] = {}
-                bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)]
-        hlog(f"Pushing {task_name} {values} to tensorboard")
-        for metric, value in values.items():
-            if "stderr" in metric:
-                tb_context.add_scalar(f"stderr_{prefix}/{task_name}/{metric}", value, global_step=global_step)
-            elif bench_suite is not None:
-                tb_context.add_scalar(f"{prefix}_{bench_suite}/{task_name}/{metric}", value, global_step=global_step)
-            else:
-                tb_context.add_scalar(f"{prefix}/{task_name}/{metric}", value, global_step=global_step)
-    # e.g. MMLU
-    for name, values in bench_averages.items():
-        for metric, values in values.items():
-            hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard")
-            tb_context.add_scalar(f"{prefix}/{name}/{metric}", sum(values) / len(values), global_step=global_step)
-
-    tb_context.add_text("eval_config", obj_to_markdown(results), global_step=global_step)
-    # tb_context.add_text("eval_sizes", obj_to_markdown(sizes), global_step=global_step)
-
-    for task_name, task_details in details.items():
-        tb_context.add_text(
-            f"eval_details_{task_name}",
-            obj_to_markdown({"0": task_details[0], "1": task_details[1] if len(task_details) > 1 else {}}),
-            global_step=global_step,
-        )
-
-    # We are doing parallel evaluations of multiple checkpoints and recording the steps not in order
-    # This messes up with tensorboard, so the easiest is to rename files in the order of the checkpoints
-    # See: https://github.com/tensorflow/tensorboard/issues/5958
-    # But tensorboardX don't let us control the prefix of the files (only the suffix), so we need to do it ourselves before commiting the files
-
-    tb_context.close()  # flushes the unfinished write operations
-    time.sleep(5)
-    files = os.listdir(output_dir_tb)
-    for file in files:
-        os.rename(os.path.join(output_dir_tb, file), os.path.join(output_dir_tb, f"{global_step:07d}_{file}"))
-
-    # Now we can push to the hub
-    tb_context.scheduler.trigger()
-    hlog(
-        f"Pushed to tensorboard at https://huggingface.co/tensorboard/{lighteval_config.logging.hub_repo_tensorboard}/"
-        f" at {output_dir_tb} and global_step {global_step}"
-    )
-
-
-@htrack()
-def main(args):
-    cache_dir = args.cache_dir or CACHE_DIR
-    check_env()
-
-    dist.initialize_torch_distributed()
-
-    with htrack_block("get config"):
-        if not args.checkpoint_config_path.endswith(".yaml"):
-            raise ValueError("The checkpoint path should point to a YAML file")
-        local_config_path = args.checkpoint_config_path
-        if args.checkpoint_config_path.startswith("s3:/"):
-            local_config_path = args.checkpoint_config_path.replace("s3:/", cache_dir)
-            with local_ranks_zero_first():
-                if os.environ.get("LOCAL_RANK", None) == "0":
-                    os.makedirs(os.path.dirname(local_config_path), exist_ok=True)
-                    fs_copy(args.checkpoint_config_path, local_config_path)
-
-        brrr_config: BrrrConfig = get_config_from_file(local_config_path, config_class=BrrrConfig, model_config_class=MistralConfig)
-
-        if args.lighteval_override:
-            local_override_path = args.lighteval_override.replace("s3:/", cache_dir)
-            if args.lighteval_override.startswith("s3:/"):
-                local_override_path = args.lighteval_override.replace("s3:/", cache_dir)
-                with local_ranks_zero_first():
-                    if os.environ.get("LOCAL_RANK", None) == "0":
-                        os.makedirs(os.path.dirname(local_override_path), exist_ok=True)
-                        fs_copy(args.lighteval_override, local_override_path)
-            lighteval_brrr_config: BrrrConfig = get_config_from_file(local_override_path, config_class=BrrrConfig)
|
| 293 |
-
lighteval_config = lighteval_brrr_config.lighteval
|
| 294 |
-
brrr_config.lighteval = lighteval_config
|
| 295 |
-
else:
|
| 296 |
-
local_override_path = ""
|
| 297 |
-
lighteval_config = brrr_config.lighteval
|
| 298 |
-
|
| 299 |
-
parallel_context = ParallelContext(
|
| 300 |
-
tensor_parallel_size=lighteval_config.parallelism.tp,
|
| 301 |
-
pipeline_parallel_size=lighteval_config.parallelism.pp,
|
| 302 |
-
data_parallel_size=lighteval_config.parallelism.dp,
|
| 303 |
-
)
|
| 304 |
-
|
| 305 |
-
evaluation_tracker = EvaluationTracker(token=TOKEN)
|
| 306 |
-
evaluation_tracker.general_config_logger.log_args_info(
|
| 307 |
-
num_fewshot_seeds=1,
|
| 308 |
-
override_batch_size=None,
|
| 309 |
-
max_samples=lighteval_config.tasks.max_samples,
|
| 310 |
-
job_id=os.environ.get("SLURM_JOB_ID", None),
|
| 311 |
-
config=brrr_config.as_dict(),
|
| 312 |
-
)
|
| 313 |
-
|
| 314 |
-
with htrack_block("Test all gather"):
|
| 315 |
-
hlog("Test gather tensor")
|
| 316 |
-
# Do a first NCCL sync to warmup and try to avoid Timeout after model/data loading
|
| 317 |
-
log_rank(
|
| 318 |
-
f"[TEST] Running NCCL sync for ranks {list(range(parallel_context.world_pg.size()))}",
|
| 319 |
-
logger=logger,
|
| 320 |
-
level=logging.WARNING,
|
| 321 |
-
group=parallel_context.dp_pg,
|
| 322 |
-
rank=0,
|
| 323 |
-
)
|
| 324 |
-
test_tensor = torch.tensor([dist.get_rank(parallel_context.world_pg)], device=torch.device("cuda"))
|
| 325 |
-
test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(parallel_context.world_pg.size())]
|
| 326 |
-
dist.all_gather(test_tensor_list, test_tensor, group=parallel_context.world_pg, async_op=False)
|
| 327 |
-
dist.barrier()
|
| 328 |
-
log_rank(
|
| 329 |
-
f"[TEST] NCCL sync for ranks {[t.item() for t in test_tensor_list]}",
|
| 330 |
-
logger=logger,
|
| 331 |
-
level=logging.WARNING,
|
| 332 |
-
group=parallel_context.dp_pg,
|
| 333 |
-
rank=0,
|
| 334 |
-
)
|
| 335 |
-
|
| 336 |
-
del test_tensor_list
|
| 337 |
-
del test_tensor
|
| 338 |
-
|
| 339 |
-
with htrack_block("Model loading"):
|
| 340 |
-
# We need to load the model in the main process first to avoid downloading the model multiple times
|
| 341 |
-
model = BRRRModel(
|
| 342 |
-
checkpoint_path=args.checkpoint_config_path.replace("config.yaml", ""),
|
| 343 |
-
model_args=brrr_config.model,
|
| 344 |
-
tokenizer=brrr_config.tokenizer,
|
| 345 |
-
parallel_context=parallel_context,
|
| 346 |
-
parallel_config=lighteval_config.parallelism,
|
| 347 |
-
lighteval_config=lighteval_config,
|
| 348 |
-
batch_size=lighteval_config.batch_size,
|
| 349 |
-
cache_dir=os.environ.get("HF_HOME", "/scratch"),
|
| 350 |
-
debug_one_layer_model=False,
|
| 351 |
-
s5cmd_path=args.s5cmd_path,
|
| 352 |
-
s5cmd_numworkers=args.s5cmd_numworkers,
|
| 353 |
-
s5cmd_concurrency=args.s5cmd_concurrency,
|
| 354 |
-
model_class=MistralForTraining
|
| 355 |
-
)
|
| 356 |
-
model_info = ModelInfo(model_name=f"{brrr_config.general.run}/{brrr_config.general.step}")
|
| 357 |
-
evaluation_tracker.general_config_logger.log_model_info(model_info)
|
| 358 |
-
|
| 359 |
-
with htrack_block("Tasks loading"):
|
| 360 |
-
with local_ranks_zero_first():
|
| 361 |
-
tasks_selection = lighteval_config.tasks.tasks
|
| 362 |
-
if lighteval_config.tasks.custom_tasks_file:
|
| 363 |
-
_, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks_file)
|
| 364 |
-
if tasks_groups_dict and lighteval_config.tasks.tasks in tasks_groups_dict:
|
| 365 |
-
tasks_selection = tasks_groups_dict[lighteval_config.tasks.tasks]
|
| 366 |
-
|
| 367 |
-
task_names_list, few_shots_dict = taskinfo_selector(tasks_selection)
|
| 368 |
-
task_dict = Registry(cache_dir=cache_dir).get_task_dict(
|
| 369 |
-
task_names_list, custom_tasks_file=lighteval_config.tasks.custom_tasks_file
|
| 370 |
-
)
|
| 371 |
-
# Loading all the dataset in a distributed manner
|
| 372 |
-
LightevalTask.load_datasets(task_dict.values(), lighteval_config.tasks.dataset_loading_processes)
|
| 373 |
-
|
| 374 |
-
evaluation_tracker.task_config_logger.log(task_dict)
|
| 375 |
-
|
| 376 |
-
hlog("Loading documents, and requests")
|
| 377 |
-
requests, docs = create_requests_from_tasks(
|
| 378 |
-
task_dict=task_dict,
|
| 379 |
-
fewshot_dict=few_shots_dict,
|
| 380 |
-
num_fewshot_seeds=lighteval_config.tasks.num_fewshot_seeds or 1,
|
| 381 |
-
lm=model,
|
| 382 |
-
max_samples=lighteval_config.tasks.max_samples,
|
| 383 |
-
evaluation_tracker=evaluation_tracker,
|
| 384 |
-
use_chat_template=False
|
| 385 |
-
)
|
| 386 |
-
|
| 387 |
-
with htrack_block("Setting seeds and waiting for all processes"):
|
| 388 |
-
hlog(f"setting seed to {1234} for random and numpy")
|
| 389 |
-
random.seed(1234)
|
| 390 |
-
np.random.seed(1234)
|
| 391 |
-
dist.barrier()
|
| 392 |
-
|
| 393 |
-
with htrack_block("Evaluation"):
|
| 394 |
-
hlog(f"Evaluate on {len(task_names_list)} tasks.")
|
| 395 |
-
evaluation_tracker = evaluate(
|
| 396 |
-
lm=model,
|
| 397 |
-
requests_dict=requests,
|
| 398 |
-
docs=docs,
|
| 399 |
-
task_dict=task_dict,
|
| 400 |
-
override_bs=lighteval_config.batch_size,
|
| 401 |
-
evaluation_tracker=evaluation_tracker,
|
| 402 |
-
)
|
| 403 |
-
|
| 404 |
-
if dist.get_rank(parallel_context.world_pg) == 0:
|
| 405 |
-
with htrack_block("Compiling and saving results"):
|
| 406 |
-
evaluation_tracker.general_config_logger.log_end_time()
|
| 407 |
-
evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict, bootstrap_iters=1000)
|
| 408 |
-
evaluation_tracker.details_logger.aggregate()
|
| 409 |
-
|
| 410 |
-
if lighteval_config.logging.local_output_path:
|
| 411 |
-
evaluation_tracker.save(
|
| 412 |
-
output_dir=lighteval_config.logging.local_output_path,
|
| 413 |
-
push_results_to_hub=lighteval_config.logging.push_results_to_hub,
|
| 414 |
-
push_details_to_hub=lighteval_config.logging.push_details_to_hub,
|
| 415 |
-
public=False,
|
| 416 |
-
push_results_to_tensorboard=lighteval_config.logging.push_results_to_tensorboard,
|
| 417 |
-
)
|
| 418 |
-
|
| 419 |
-
if lighteval_config.logging.push_results_to_tensorboard:
|
| 420 |
-
push_results_to_tensorboard(
|
| 421 |
-
config=brrr_config,
|
| 422 |
-
results=evaluation_tracker.metrics_logger.metric_aggregated,
|
| 423 |
-
details=evaluation_tracker.details_logger.details,
|
| 424 |
-
)
|
| 425 |
-
if lighteval_config.wandb is not None:
|
| 426 |
-
push_results_to_wandb(
|
| 427 |
-
config=brrr_config,
|
| 428 |
-
results=evaluation_tracker.metrics_logger.metric_aggregated,
|
| 429 |
-
details=evaluation_tracker.details_logger.details,
|
| 430 |
-
)
|
| 431 |
-
|
| 432 |
-
final_dict = evaluation_tracker.generate_final_dict()
|
| 433 |
-
|
| 434 |
-
hlog(make_results_table(final_dict))
|
| 435 |
-
|
| 436 |
-
return final_dict
|
| 437 |
-
|
| 438 |
|
| 439 |
if __name__ == "__main__":
|
| 440 |
parser = get_parser()
|
| 441 |
args, unknowns = parser.parse_known_args()
|
| 442 |
-
main(
|
| 10 |
"""
|
| 11 |
# flake8: noqa: C901
|
| 12 |
import argparse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
from nanotron.config import Config
|
| 15 |
from modeling_mistral import MistralForTraining
|
| 16 |
from config_mistral import MistralConfig
|
| 17 |
|
| 18 |
+
from lighteval.main_nanotron import main
|
| 19 |
|
| 20 |
def get_parser():
|
| 21 |
parser = argparse.ArgumentParser()
|
|
| 35 |
type=str,
|
| 36 |
help="Local or hub path of an optional tokenizer (if not indicated in the checkpoint)",
|
| 37 |
)
|
| 38 |
parser.add_argument(
|
| 39 |
"--cache-dir",
|
| 40 |
type=str,
|
| 41 |
+
default=None,
|
| 42 |
help="Cache directory",
|
| 43 |
)
|
| 44 |
|
| 45 |
return parser
|
| 46 |
|
| 47 |
|
| 48 |
|
| 49 |
if __name__ == "__main__":
|
| 50 |
parser = get_parser()
|
| 51 |
args, unknowns = parser.parse_known_args()
|
| 52 |
+
main(
|
| 53 |
+
args.checkpoint_config_path,
|
| 54 |
+
args.lighteval_override,
|
| 55 |
+
args.cache_dir,
|
| 56 |
+
config_cls=Config,
|
| 57 |
+
model_config_cls=MistralConfig,
|
| 58 |
+
model_cls=MistralForTraining
|
| 59 |
+
)
|
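Note: the removed push_results_to_tensorboard helper worked around TensorBoard's handling of out-of-order steps (see the tensorflow/tensorboard issue #5958 referenced in the deleted comments) by prefixing every event file with the zero-padded global step before triggering the push. A minimal sketch of that renaming step, with output_dir_tb and global_step as placeholder values, could look like this:

import os

def prefix_event_files_with_step(output_dir_tb: str, global_step: int) -> None:
    # Prefix each event file with the zero-padded step so that checkpoints
    # evaluated in parallel (and therefore logged out of order) still sort
    # correctly once the directory is pushed to the Hub.
    for file in os.listdir(output_dir_tb):
        os.rename(
            os.path.join(output_dir_tb, file),
            os.path.join(output_dir_tb, f"{global_step:07d}_{file}"),
        )

With this commit the script delegates evaluation and results logging to lighteval.main_nanotron.main, so run_evals.py no longer defines this helper itself and only forwards the config, model config, and model classes.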
run_generate.py
CHANGED
|
@@ -35,9 +35,8 @@ from nanotron.random import (
|
|
| 35 |
from nanotron.serialize import (
|
| 36 |
load_weights,
|
| 37 |
)
|
| 38 |
-
from nanotron.trainer import
|
| 39 |
|
| 40 |
-
from brrr.config import BrrrConfig
|
| 41 |
from config_mistral_7b import MistralConfig
|
| 42 |
from modeling_mistral import MistralForTraining
|
| 43 |
|
|
@@ -64,7 +63,7 @@ def main():
|
|
| 64 |
|
| 65 |
assert args.ckpt_path.exists(), f"Checkpoint path {args.ckpt_path} does not exist"
|
| 66 |
|
| 67 |
-
config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix(),
|
| 68 |
model_config = config.model.model_config
|
| 69 |
tokenizer_path = config.tokenizer.tokenizer_name_or_path
|
| 70 |
|
| 35 |
from nanotron.serialize import (
|
| 36 |
load_weights,
|
| 37 |
)
|
| 38 |
+
from nanotron.trainer import mark_tied_parameters
|
| 39 |
|
| 40 |
from config_mistral_7b import MistralConfig
|
| 41 |
from modeling_mistral import MistralForTraining
|
| 42 |
|
| 63 |
|
| 64 |
assert args.ckpt_path.exists(), f"Checkpoint path {args.ckpt_path} does not exist"
|
| 65 |
|
| 66 |
+
config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix(), model_config_class=MistralConfig, skip_unused_config_keys=True)
|
| 67 |
model_config = config.model.model_config
|
| 68 |
tokenizer_path = config.tokenizer.tokenizer_name_or_path
|
| 69 |
|
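Note: run_generate.py now loads the checkpoint's config.yaml directly into nanotron's config classes instead of BrrrConfig. A minimal sketch of the new loading path, assuming get_config_from_file is imported from nanotron.config and using a placeholder checkpoint path:

from pathlib import Path

from nanotron.config import get_config_from_file
from config_mistral_7b import MistralConfig

ckpt_path = Path("checkpoints/10")  # placeholder checkpoint directory
config = get_config_from_file(
    (ckpt_path / "config.yaml").as_posix(),
    model_config_class=MistralConfig,
    skip_unused_config_keys=True,  # skip YAML keys the config class does not define
)
model_config = config.model.model_config
tokenizer_path = config.tokenizer.tokenizer_name_or_path

Passing skip_unused_config_keys=True lets the same config.yaml written at training time be reused for generation even if it contains sections the generation script does not need.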