add suite
tlem.py CHANGED
@@ -6,6 +6,8 @@ except Exception as e:
 import logging
 
 from typing import Any, Optional, Protocol, Iterable, Callable
+from tqdm.auto import tqdm
+from evaluate.evaluation_suite import EvaluationSuite
 
 # %%
 
@@ -33,14 +35,18 @@ TextGenerationPipeline = Callable[[Iterable[str]], list[str]]
 from evaluate import load
 
 
+def fake_pipeline(prompts: Iterable[str]) -> list[str]:
+    return [prompt for prompt in tqdm(prompts)]
+
+
 @dataclass
 class Task:
-    dataset_name: str = "gsm8k"
-
+    dataset_name: str | tuple[str, str] = ("gsm8k", "main")
+    split: str = "test"
     # metrics: list[str] = field(default_factory=list)
-    metric_name: str | tuple[str, str] = "gsm8k"
+    metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
     input_column: str = "question"
-    label_column: str = "
+    label_column: str = "answer"
     prompt: Optional[Callable | str] = None
 
     @cached_property
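
Not part of the diff: `Task.run` accepts any `TextGenerationPipeline`, i.e. a callable from an `Iterable[str]` of prompts to a `list[str]` of completions, of which `fake_pipeline` above is the trivial echo case. A minimal sketch of adapting a real model to that shape, assuming the `transformers` library and a placeholder model name:

# Illustrative sketch (not part of the diff): wrap a transformers
# text-generation pipeline so it matches the TextGenerationPipeline protocol.
# "gpt2" and the generation arguments are placeholder assumptions.
from typing import Iterable
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")

def hf_pipeline(prompts: Iterable[str]) -> list[str]:
    outputs = generator(list(prompts), max_new_tokens=256, return_full_text=False)
    # One list of candidate generations per prompt; keep the first of each.
    return [candidates[0]["generated_text"] for candidates in outputs]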
@@ -49,7 +55,12 @@ class Task:
 
     @cached_property
     def dataset(self):
-        ds = load_dataset(
+        ds = load_dataset(
+            *self.dataset_name
+            if isinstance(self.dataset_name, tuple)
+            else self.dataset_name,
+            split=self.split
+        )
         if self.prompt is not None:
             ds = ds.map(
                 lambda example: {
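
Not part of the diff: with the new defaults, the `(name, config)` tuple is unpacked into two positional arguments for `datasets.load_dataset`. A standalone sketch of the equivalent call for the default task:

# Illustrative sketch (not part of the diff): the tuple-unpacked call is
# equivalent to naming the config positionally.
from datasets import load_dataset

ds = load_dataset(*("gsm8k", "main"), split="test")  # same as load_dataset("gsm8k", "main", split="test")
print(ds.column_names)  # expected to include "question" and "answer"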
@@ -72,9 +83,11 @@ class Task:
         )
         return metric
 
-    def run(self, pipeline: TextGenerationPipeline):
+    def run(self, pipeline: TextGenerationPipeline = fake_pipeline):
         outputs = pipeline(self.samples)
-        return self.metric.compute(
+        return self.metric.compute(
+            responses=outputs, references=self.dataset[self.label_column]
+        )
 
 
 class Metrics:
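
Not part of the diff: `Task.run` now hands the pipeline outputs to the metric under the `responses`/`references` keywords. The commented-out lines in the `__main__` block further down show the same call shape; a sketch of invoking the metric directly:

# Illustrative sketch (not part of the diff): calling the gsm8k metric with
# the same keyword names Task.run uses. The exact result keys depend on the
# metric implementation and are an assumption here.
from evaluate import load

metric = load("sustech/tlem", "gsm8k")
result = metric.compute(responses=["answer is 2", "1+2"], references=["2", "3"])
print(result)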
@@ -224,7 +237,41 @@ class ReasoningMetric(evaluate.Metric):
 
         return results
 
-# %%
 
-
+class Suite(EvaluationSuite):
+    def run(
+        self, model_or_pipeline: Any, prompt: str = "{instruction}"
+    ) -> dict[str, float]:
+        self.assert_suite_nonempty()
+
+        results_all = {}
+        for task in tqdm(self.suite, desc="Running tasks"):
+            task_name = task.name
+            results = task.run(model_or_pipeline)
+            results_all[task_name] = results
+        return results_all
+
+    def __init__(self, name):
+        super().__init__(name)
+
+        self.suite = [
+            Task(
+                dataset_name=("gsm8k", "main"),
+                metric_name=("sustech/tlem", "gsm8k"),
+                input_column="question",
+                label_column="answer",
+            )
+            # TASK_REGISTRY["gsm8k"],
+            # TASK_REGISTRY["competition_math"],
+        ]
+
 
+# %%
+
+if __name__ == "__main__":
+    # metric = load("sustech/tlem", "gsm8k")
+    # output = metric.compute(responses=["answer is 2", "1+2"], references=["2", "3"])
+    # logging.info(output)
+    suite = EvaluationSuite.load("sustech/tlem")
+    suite.run(fake_pipeline)
+    # %%
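
Not part of the diff: the new `Suite` collects `Task` objects and its `run` method returns one result entry per task, keyed by task name, as in the `__main__` block above. A hedged usage sketch with a custom pipeline standing in for `fake_pipeline`:

# Illustrative sketch (not part of the diff): run the suite with any callable
# mapping an iterable of prompts to a list of strings. `my_pipeline` is a
# placeholder stand-in for a real model.
from evaluate.evaluation_suite import EvaluationSuite

def my_pipeline(prompts):
    return ["The answer is 42" for _ in prompts]

suite = EvaluationSuite.load("sustech/tlem")
results = suite.run(my_pipeline)  # dict keyed by task name, per Suite.run above
print(results)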