Spaces:
Running
Running
| import logging | |
| import os | |
| import traceback | |
| from itertools import chain | |
| from typing import Any, List | |
| from rich.console import Console | |
| from .eval_utils import set_all_seeds | |
| from .modality import Modality | |
| from .models import BioSeqTransformer | |
| from .tasks.tasks import Task | |
# Configure the root logger at import time so INFO-level records are emitted.
# `basicConfig` has no effect if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the functions below.
logger = logging.getLogger(__name__)
class DGEB:
    """DGEB class to run the evaluation pipeline.

    Holds a list of task *classes*; each one is instantiated and run against
    the supplied model in :meth:`run`.
    """

    def __init__(self, tasks: List[type[Task]], seed: int = 42):
        """Store the selected tasks and forward ``seed`` to ``set_all_seeds``.

        Args:
            tasks: Task classes to evaluate, in the order they will be run.
            seed: Value passed to ``set_all_seeds``.
        """
        self.tasks = tasks
        set_all_seeds(seed)

    def print_selected_tasks(self):
        """Print the display name and type of every selected task."""
        console = Console()
        console.rule("[bold]Selected Tasks\n", style="grey15")
        for task in self.tasks:
            prefix = " - "
            name = f"{task.metadata.display_name}"
            category = f", [italic grey39]{task.metadata.type}[/]"
            console.print(f"{prefix}{name}{category}")
        console.print("\n")

    def run(
        self,
        model,  # type encoder
        output_folder: str = "results",
    ):
        """Run the evaluation pipeline on the selected tasks.

        Each task class is instantiated and run with ``model``. A task that
        raises is logged and skipped, so the remaining tasks still run.

        Args:
            model: Model to be used for evaluation. Its ``hf_name`` attribute
                is used to build the output path.
            output_folder: Folder where the results will be saved. Defaults to
                'results'. Each result is written to the path returned by
                ``get_output_folder(model.hf_name, task, output_folder)``.

        Returns:
            A list with the result object of every task that completed
            without raising, in execution order.
        """
        # Run selected tasks
        self.print_selected_tasks()
        results = []
        for task in self.tasks:
            logger.info(
                f"\n\n********************** Evaluating {task.metadata.display_name} **********************"
            )
            try:
                result = task().run(model)
            except Exception:
                # One failing task must not abort the rest of the run;
                # `logger.exception` records the message plus the traceback.
                logger.exception(f"Error running task {task}")
                continue
            results.append(result)
            save_path = get_output_folder(model.hf_name, task, output_folder)
            # Explicit UTF-8 so the JSON file does not depend on the
            # platform's default text encoding.
            with open(save_path, "w", encoding="utf-8") as f_out:
                f_out.write(result.model_dump_json(indent=2))
        return results
def get_model(model_name: str, **kwargs: Any) -> BioSeqTransformer:
    """Instantiate the model class that lists ``model_name``.

    The direct subclasses of ``BioSeqTransformer`` are searched in order and
    the first one whose ``MODEL_NAMES`` contains ``model_name`` is used.

    Args:
        model_name: Name to look up in each subclass's ``MODEL_NAMES``.
        **kwargs: Extra keyword arguments forwarded to the model constructor.

    Returns:
        An instance of the matching subclass, built as
        ``cls(model_name, **kwargs)``.

    Raises:
        ValueError: If no subclass lists ``model_name``.
    """
    for cls in BioSeqTransformer.__subclasses__():
        if model_name in cls.MODEL_NAMES:
            return cls(model_name, **kwargs)
    # The full list of names is only needed for the error message, so it is
    # built on the failure path rather than on every call.
    raise ValueError(f"Model {model_name} not found in {get_all_model_names()}.")
def get_all_model_names() -> List[str]:
    """Collect the ``MODEL_NAMES`` entries of every direct ``BioSeqTransformer`` subclass.

    Returns:
        A flat list of names, in subclass order.
    """
    names: List[str] = []
    for model_cls in BioSeqTransformer.__subclasses__():
        names.extend(model_cls.MODEL_NAMES)
    return names
def get_all_task_names() -> List[str]:
    """Return the ``metadata.id`` of every task yielded by ``get_all_tasks``."""
    task_ids = []
    for task_cls in get_all_tasks():
        task_ids.append(task_cls.metadata.id)
    return task_ids
def get_tasks_by_name(tasks: List[str]) -> List[type[Task]]:
    """Resolve each task id in ``tasks`` to its task class, preserving order.

    Raises:
        ValueError: Propagated from ``_get_task`` when an id is unknown.
    """
    return list(map(_get_task, tasks))
def get_tasks_by_modality(modality: Modality) -> List[type[Task]]:
    """Return the task classes whose ``metadata.modality`` equals ``modality``."""
    matching = []
    for task_cls in get_all_tasks():
        if task_cls.metadata.modality == modality:
            matching.append(task_cls)
    return matching
def get_all_tasks() -> List[type[Task]]:
    """Return the classes that directly subclass ``Task``.

    Only immediate subclasses are included, as reported by
    ``Task.__subclasses__()``.
    """
    direct_subclasses = Task.__subclasses__()
    return direct_subclasses
def _get_task(task_name: str) -> type[Task]:
    """Look up a task class by its ``metadata.id``.

    Args:
        task_name: The id to search for.

    Returns:
        The first task class from ``get_all_tasks`` whose id matches.

    Raises:
        ValueError: If no task has the given id; the message lists all
            available ids.
    """
    logger.info(f"Getting task {task_name}")
    found = next(
        (
            candidate
            for candidate in get_all_tasks()
            if candidate.metadata.id == task_name
        ),
        None,
    )
    if found is not None:
        return found
    available = [candidate.metadata.id for candidate in get_all_tasks()]
    raise ValueError(
        f"Task {task_name} not found, available tasks are: {available}"
    )
def get_output_folder(
    model_hf_name: str, task: type[Task], output_folder: str, create: bool = True
) -> str:
    """Build the JSON result path for ``task`` under a per-model folder.

    The per-model folder is ``output_folder`` joined with the basename of
    ``model_hf_name`` (so ``"org/model"`` maps to ``"model"``).

    Args:
        model_hf_name: Model name; only its final path component is used.
        task: Task class whose ``metadata.id`` names the result file.
        output_folder: Root folder for all results.
        create: When true, create the per-model folder if it is missing.

    Returns:
        ``{output_folder}/{basename(model_hf_name)}/{task.metadata.id}.json``.

    Raises:
        FileExistsError: If ``create`` is true and the per-model path exists
            but is not a directory.
    """
    output_folder = os.path.join(output_folder, os.path.basename(model_hf_name))
    if create:
        # `exist_ok=True` avoids the race between an `exists()` check and
        # `makedirs()` when several runs write to the same folder.
        os.makedirs(output_folder, exist_ok=True)
    return os.path.join(
        output_folder,
        f"{task.metadata.id}.json",
    )