Spaces:
Paused
Paused
| import logging | |
| from abc import ABC, abstractmethod | |
| from functools import cached_property | |
| from typing import Optional, Self, Type, cast | |
| from pydantic import BaseModel, Field | |
| from proxy_lite.environments.environment_base import Action, Observation | |
| from proxy_lite.tools import Tool | |
| class BaseSolverConfig(BaseModel): | |
| pass | |
| class BaseSolver(BaseModel, ABC): | |
| task: Optional[str] = None | |
| env_tools: list[Tool] = Field(default_factory=list) | |
| config: BaseSolverConfig | |
| logger: logging.Logger | None = None | |
| class Config: | |
| arbitrary_types_allowed = True | |
| async def __aenter__(self) -> Self: | |
| return self | |
| async def __aexit__(self, exc_type, exc_value, traceback) -> None: | |
| pass | |
| def tools(self) -> list[Tool]: ... | |
| async def initialise( | |
| self, | |
| task: str, | |
| env_tools: list[Tool], | |
| env_info: str, | |
| ) -> None: | |
| """ | |
| Initialise the solution with the given task. | |
| """ | |
| ... | |
| async def act(self, observation: Observation) -> Action: | |
| """ | |
| Return an action for interacting with the environment. | |
| """ | |
| ... | |
| async def is_complete(self, observation: Observation) -> bool: | |
| """ | |
| Return a boolean indicating if the task is complete. | |
| """ | |
| return observation.terminated | |
| class Solvers: | |
| _solver_registry: dict[str, type[BaseSolver]] = {} | |
| _solver_config_registry: dict[str, type[BaseSolverConfig]] = {} | |
| def register_solver(cls, name: str): | |
| """ | |
| Decorator to register a Solver class under a given name. | |
| Example: | |
| @Solvers.register_solver("my_solver") | |
| class MySolver(BaseSolver): | |
| ... | |
| """ | |
| def decorator(solver_cls: type[BaseSolver]) -> type[BaseSolver]: | |
| cls._solver_registry[name] = solver_cls | |
| return solver_cls | |
| return decorator | |
| def register_solver_config(cls, name: str): | |
| """ | |
| Decorator to register a Solver configuration class under a given name. | |
| Example: | |
| @Solvers.register_solver_config("my_solver") | |
| class MySolverConfig(BaseSolverConfig): | |
| ... | |
| """ | |
| def decorator(config_cls: type[BaseSolverConfig]) -> type[BaseSolverConfig]: | |
| cls._solver_config_registry[name] = config_cls | |
| return config_cls | |
| return decorator | |
| def get(cls, name: str) -> type[BaseSolver]: | |
| """ | |
| Retrieve a registered Solver class by its name. | |
| Raises: | |
| ValueError: If no such solver is found. | |
| """ | |
| try: | |
| return cast(Type[BaseSolver], cls._solver_registry[name]) | |
| except KeyError: | |
| raise ValueError(f"Solver '{name}' not found.") | |
| def get_config(cls, name: str) -> type[BaseSolverConfig]: | |
| """ | |
| Retrieve a registered Solver configuration class by its name. | |
| Raises: | |
| ValueError: If no such config is found. | |
| """ | |
| try: | |
| return cast(Type[BaseSolverConfig], cls._solver_config_registry[name]) | |
| except KeyError: | |
| raise ValueError(f"Solver config for '{name}' not found.") | |