Commit: fix _temp_run can't be pickled; pass indices to allow evaluation on a subset
Files changed:
- apps_metric.py  +2 -2
- tests.py        +20 -10
- utils.py        +17 -10
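The commit message points at a pickling problem: `multiprocessing.Process` has to pickle its `target` under the "spawn" start method (the default on Windows and macOS), and a function defined inside another function cannot be pickled, which is why the old nested `_temp_run` failed. Below is a minimal sketch of that failure and of the module-level workaround, kept separate from the repository code; `_module_level`, `make_nested` and `_nested` are illustrative names, not part of the commit.

```python
import multiprocessing
import pickle


def _module_level(result):
    # Module-level functions are pickled by reference (module + qualified name),
    # so they can be handed to a child process under the "spawn" start method.
    result.append("ok")


def make_nested():
    def _nested(result):
        # Defined inside another function: pickling this fails, which is the
        # kind of error the commit message describes for the old _temp_run.
        result.append("ok")
    return _nested


if __name__ == "__main__":
    pickle.dumps(_module_level)        # fine
    try:
        pickle.dumps(make_nested())    # fails: local objects are not picklable
    except Exception as err:
        print("cannot pickle nested function:", err)

    multiprocessing.set_start_method("spawn", force=True)
    with multiprocessing.Manager() as manager:
        shared = manager.list()
        proc = multiprocessing.Process(target=_module_level, args=(shared,))
        proc.start()
        proc.join()
        print(list(shared))            # ['ok']
```

Moving `_temp_run` to module scope, as the utils.py diff below does, sidesteps this error.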
apps_metric.py:

@@ -76,7 +76,7 @@ class apps_metric(evaluate.EvaluationModule):
 
 
 
-    def _compute(self, predictions, k_list=[1, 10, 100], count_errors=True, level="all", debug=False):
+    def _compute(self, predictions, indices=None, k_list=[1, 10, 100], count_errors=True, level="all", debug=False):
         """Returns the scores"""
-        metrics = compute_metrics(predictions, k_list=k_list, count_errors=count_errors, level=level, debug=debug)
+        metrics = compute_metrics(predictions, indices=indices, k_list=k_list, count_errors=count_errors, level=level, debug=debug)
         return metrics
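The second half of the fix threads an `indices` argument from `_compute` down to `compute_metrics`, so a caller can score generations for a subset of APPS problems instead of the whole test split. A hedged usage sketch follows, assuming (as the old tests.py did) that the metric is loaded with `evaluate.load` and that extra keyword arguments passed to `compute` are forwarded to `_compute`; the two candidate programs and the chosen indices are made up.

```python
from evaluate import load

metric = load("codeparrot/apps_metric")

# One list of candidate solutions (code strings) per evaluated problem.
generations = [
    ["print(sum(map(int, input().split())))"],            # candidates for problem 0
    ["a, b = map(int, input().split())\nprint(a * b)"],   # candidates for problem 5
]

# With this change, indices=[0, 5] tells the metric which rows of the APPS
# test split (for the chosen difficulty level) these generations belong to.
results = metric.compute(
    predictions=generations,
    indices=[0, 5],
    level="all",
    k_list=[1],
)
print(results)
```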
tests.py:

@@ -1,14 +1,24 @@
 import json
-from evaluate import load
+from multiprocessing import freeze_support
 
-solution_sample1 = json.load(open("test_examples/solutions_problem_1.json", "r"))
-solution_sample2 = json.load(open("test_examples/solutions_problem_2.json", "r"))
-single_solutions = [solution_sample1[:1], solution_sample2[:1]]
-multiple_solutions = [solution_sample1[:3], solution_sample2[:3]]
+from apps_metric import apps_metric
 
-metric = load("codeparrot/apps_metric")
-result_1 = metric.compute(predictions=single_solutions, level="all")
-result_2 = metric.compute(predictions=multiple_solutions, level="all", k_list=[1, 2, 3])
 
-
-
+if __name__ == '__main__':
+    """
+    Verify by checking if reference solutions pass all test cases (with strict accuracy == 1).
+    Note that some reference solutions may not pass all test cases. So only throw a warning.
+    """
+    freeze_support()
+
+    solution_sample1 = json.load(open("test_examples/solutions_problem_1.json", "r"))
+    solution_sample2 = json.load(open("test_examples/solutions_problem_2.json", "r"))
+    single_solutions = [solution_sample1[:1], solution_sample2[:1]]
+    multiple_solutions = [solution_sample1[:3], solution_sample2[:3]]
+
+    metric = apps_metric()
+    result_1 = metric.compute(predictions=single_solutions, level="all")
+    result_2 = metric.compute(predictions=multiple_solutions, level="all", k_list=[1, 2, 3])
+
+    assert result_1 == {'avg_accuracy': 1.0, 'strict_accuracy': 1.0, 'pass_at_k': None}
+    assert result_2 == {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}}
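tests.py now instantiates the metric from the local module and runs everything under an `if __name__ == '__main__':` guard with `freeze_support()`. Under the "spawn" start method each worker re-imports the main module, so unguarded top-level code would try to start workers again before bootstrapping finishes. A minimal sketch of the required shape, with a toy `work` function and queue rather than the repository's test:

```python
import multiprocessing


def work(queue):
    # Toy stand-in for the evaluation work done in a child process.
    queue.put(42)


if __name__ == "__main__":
    # A no-op unless the script is frozen into a Windows executable, but it is
    # the usual companion of the __main__ guard in multiprocessing scripts.
    multiprocessing.freeze_support()

    queue = multiprocessing.Queue()
    proc = multiprocessing.Process(target=work, args=(queue,))
    proc.start()
    assert queue.get() == 42
    proc.join()
```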
utils.py:

@@ -9,13 +9,14 @@ from .testing_util import run_test
 DATASET = "codeparrot/apps"
 TIMEOUT = 10
 
+
+def _temp_run(sample, generation, debug, result):
+    result.append(run_test(sample, test=generation, debug=debug))
+
 def check_correctness(sample, generation, timeout, debug=True):
     """Check correctness of code generation with a global timeout.
     The global timeout is to catch some extreme/rare cases not handled by the timeouts
     inside `run_test`"""
-    def _temp_run(sample, generation, debug, result):
-        result.append(run_test(sample, test=generation, debug=debug))
-
     manager = multiprocessing.Manager()
     result = manager.list()
     p = multiprocessing.Process(target=_temp_run, args=(sample, generation, debug, result))

@@ -32,12 +33,13 @@ def check_correctness(sample, generation, timeout, debug=True):
     return result[0]
 
 
-def evaluate_generations(generations: list, level: str = "all", debug: bool = False):
+def evaluate_generations(generations: list, indices: list = [], level: str = "all", debug: bool = False):
     """We take the list of code generations and try to compile them
     and the run their corresponding unit tests which are retrieved from the APPS dataset.
 
     Args:
         generations: list of code generations (same order as samples in APPS dataset)
+        indices: list of indicies of problems to evaluate, if empty, evaluate all problems
         level: difficulty level used in the generation, can be "all", "introductory", "interview" or "competition"
 
     Returns:

@@ -47,10 +49,14 @@ def evaluate_generations(generations: list, level: str = "all", debug: bool = False):
 
     # generations are code generations in the same order of the dataset
     apps_eval = load_dataset(DATASET, split="test", difficulties=[level])
+
+    if indices is None:
+        indices = range(len(generations))
+
     results = {}
-    for index in range(len(generations)):
+    for index, generation in zip(indices, generations):
         # code generations for problem (index)
-        problem_generations = generations[index]
+        problem_generations = generation
         # get corresponding samples from APPS dataset
         sample = apps_eval[index]
         res = []

@@ -74,7 +80,7 @@ def evaluate_generations(generations: list, level: str = "all", debug: bool = False):
                     print(f"Results were not True for all test cases")
             except Exception as e:
                 if debug:
-                    print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
+                    print(f"Compilation failed, test framework exception = {repr(e)}\n")
                 break
             finally:
                 assert isinstance(curr_res, list)

@@ -125,7 +131,7 @@ def get_results(results: Dict[int, list], count_errors: bool = False, k_list: list
 
     metrics = {"avg_accuracy": None, "strict_accuracy": None, "pass_at_k": None}
 
-    if len(results[0]) == 1:
+    if len(list(results.values())[0]) == 1:
         # for single generations we compute average accuracy and stric accuracy: original APPS metrics
         print("Computing accuracy metrics...")
         res = []

@@ -173,10 +179,11 @@ def get_results(results: Dict[int, list], count_errors: bool = False, k_list: list
     metrics["pass_at_k"] = pass_at_k
     return metrics
 
-def compute_metrics(generations, level="all", k_list=[1, 10, 100], count_errors=True, debug=False):
+def compute_metrics(generations, indices=None, level="all", k_list=[1, 10, 100], count_errors=True, debug=False):
     """Return metrics for the given generations.
     Args:
         generations: list of code generations for each problem (each generation is a list of generations)
+        indices: list of indices of problems (if None, generations are all problems)
         k_list: list of k values to compute pass@k when using multiple generations
         count_errors: whether to count compilation and runtime errors when using single generations
         level: difficulty level in APPS dataset that was used for the given generations (from: "all", "introductory", "interview", "competition")

@@ -204,7 +211,7 @@ def compute_metrics(generations, level="all", k_list=[1, 10, 100], count_errors=True, debug=False):
         {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}
         {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}}
     """
-    results = evaluate_generations(generations, level=level, debug=debug)
+    results = evaluate_generations(generations, indices=indices, level=level, debug=debug)
     metrics = get_results(results, count_errors=count_errors, k_list=k_list)
     return metrics
 
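With this change `evaluate_generations` keys its results by the original dataset index rather than by position, which is why `get_results` now reads `list(results.values())[0]` instead of `results[0]`. The following self-contained sketch isolates just that subset-selection pattern; `evaluate_subset`, `toy_dataset` and the stand-in check (in place of `run_test`) are illustrative and not part of the repository.

```python
from typing import Dict, List


def evaluate_subset(generations: List[List[str]], dataset: List[dict], indices=None) -> Dict[int, list]:
    # Default: generations[i] corresponds to dataset row i (the old behaviour).
    if indices is None:
        indices = range(len(generations))
    results = {}
    for index, problem_generations in zip(indices, generations):
        sample = dataset[index]  # the dataset row these generations were written for
        # Stand-in for run_test(sample, test=code, ...): record which problem was checked
        # and whether the candidate is non-empty.
        results[index] = [(sample["problem_id"], bool(code.strip())) for code in problem_generations]
    return results


toy_dataset = [{"problem_id": i} for i in range(10)]
print(evaluate_subset([["print(1)"], ["print(2)"]], toy_dataset, indices=[3, 7]))
# {3: [(3, True)], 7: [(7, True)]}
```

The sketch defaults `indices` to None because that is what `compute_metrics` passes when no subset is requested and what the `if indices is None` check in the committed code expects.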