| """ | |
| Given a directory of results, plot the benchmarks for each task as a bar chart and line chart. | |
| """ | |
| import argparse | |
| import os | |
| from typing import Optional | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import seaborn as sns | |
| from dgeb import TaskResult, get_all_tasks, get_output_folder, get_tasks_by_name | |
| ALL_TASKS = [task.metadata.id for task in get_all_tasks()] | |
def plot_benchmarks(
    results_dir,
    task_ids: Optional[list[str]] = None,
    output="benchmarks.png",
    model_substring=None,
):
    """Collect the best primary-metric score per model and task, then plot bar and line charts."""
    models = os.listdir(results_dir)
    all_results = []
    tasks = get_all_tasks() if task_ids is None else get_tasks_by_name(task_ids)
    for model_name in models:
        # skip models that do not match any of the requested substrings
        if model_substring is not None and all(
            substr not in model_name for substr in model_substring
        ):
            continue
        for task in tasks:
            if task.metadata.display_name == "NoOp Task":
                continue
            filepath = get_output_folder(model_name, task, results_dir, create=False)
            # if the results file does not exist, skip
            if not os.path.exists(filepath):
                continue
            with open(filepath) as f:
                task_result = TaskResult.model_validate_json(f.read())
            num_params = task_result.model["num_params"]
            primary_metric_id = task_result.task.primary_metric_id
            # collect the primary metric across all layers and keep the best value
            main_scores = [
                metric.value
                for layer_result in task_result.results
                for metric in layer_result.metrics
                if metric.id == primary_metric_id
            ]
            best_score = max(main_scores)
            all_results.append(
                {
                    "task": task.metadata.display_name,
                    "model": model_name,
                    "num_params": num_params,
                    "score": best_score,
                }
            )
    results_df = pd.DataFrame(all_results)
    # order the models by ascending number of parameters
    results_df["num_params"] = results_df["num_params"].astype(int)
    results_df = results_df.sort_values(by="num_params")
    # number of tasks
    n_tasks = len(set(results_df["task"]))
    _, ax = plt.subplots(2, n_tasks, figsize=(5 * n_tasks, 10))
    # first row of figures: bar chart of score per model for each task
    for i, task in enumerate(set(results_df["task"])):
        if n_tasks > 1:
            sns.barplot(
                x="model",
                y="score",
                data=results_df[results_df["task"] == task],
                ax=ax[0][i],
            )
            ax[0][i].set_title(task)
            # rotate the x axis labels
            for tick in ax[0][i].get_xticklabels():
                tick.set_rotation(90)
        else:
            sns.barplot(
                x="model",
                y="score",
                data=results_df[results_df["task"] == task],
                ax=ax[0],
            )
            ax[0].set_title(task)
            # rotate the x axis labels
            for tick in ax[0].get_xticklabels():
                tick.set_rotation(90)
    # second row of figures: line graph of score vs. number of parameters for each task
    for i, task in enumerate(set(results_df["task"])):
        if n_tasks > 1:
            sns.lineplot(
                x="num_params",
                y="score",
                data=results_df[results_df["task"] == task],
                ax=ax[1][i],
            )
            ax[1][i].set_title(task)
            ax[1][i].set_xlabel("Number of parameters")
        else:
            sns.lineplot(
                x="num_params",
                y="score",
                data=results_df[results_df["task"] == task],
                ax=ax[1],
            )
            ax[1].set_title(task)
            ax[1].set_xlabel("Number of parameters")
    plt.tight_layout()
    plt.savefig(output)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--results_dir",
        type=str,
        default="results",
        help="Directory containing the results of the benchmarking",
    )
    parser.add_argument(
        "-t",
        "--tasks",
        type=lambda s: s.split(","),
        default=None,
        help=f"Comma separated list of tasks to plot. Choose from {ALL_TASKS} or do not specify to plot all tasks.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="benchmarks.png",
        help="Output file for the plot",
    )
    parser.add_argument(
        "--model_substring",
        type=lambda s: s.split(","),
        default=None,
        help="Comma separated list of model substrings. Only plot results for models containing this substring",
    )
    args = parser.parse_args()
    plot_benchmarks(args.results_dir, args.tasks, args.output, args.model_substring)
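
# Example invocation (the script filename and the "esm" substring below are illustrative
# assumptions, not defined by this file):
#
#   python plot_benchmarks.py -d results -o benchmarks.png --model_substring esm
#
# This scans every model folder under results/, keeps only models whose name contains
# "esm", and writes one bar chart (score per model) and one line chart (score vs.
# number of parameters) per task to benchmarks.png.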