Spaces:
Paused
Paused
| import asyncio | |
| import logging | |
| from collections.abc import AsyncIterator | |
| from contextlib import asynccontextmanager | |
| from typing import Any, Literal, Self | |
| from omegaconf import OmegaConf | |
| from pydantic import BaseModel | |
| from proxy_lite.environments import ( | |
| Action, | |
| BaseEnvironment, | |
| EnvironmentConfigTypes, | |
| Environments, | |
| EventType, | |
| Observation, | |
| ) | |
| from proxy_lite.logger import create_logger | |
| from proxy_lite.recorder import DataRecorder, Run | |
| from proxy_lite.solvers import ( | |
| BaseSolver, | |
| SolverConfigTypes, | |
| Solvers, | |
| ) | |
| async def async_timeout(timeout: float, task_name: str = "timeout"): | |
| try: | |
| async with asyncio.TaskGroup() as tg: | |
| async def timeout_task(): | |
| await asyncio.sleep(timeout) | |
| raise TimeoutError( | |
| f"Operation {task_name} timed out after {timeout} seconds", | |
| ) | |
| # Create the timeout task | |
| timeout_handle = tg.create_task(timeout_task()) | |
| try: | |
| yield | |
| finally: | |
| timeout_handle.cancel() | |
| except* asyncio.TimeoutError as eg: | |
| for e in eg.exceptions: | |
| raise e | |
| except* Exception as eg: | |
| for e in eg.exceptions: | |
| raise e | |
class RunnerConfig(BaseModel):
    """Configuration for a ``Runner``: which environment and solver to use,
    plus step limits, timeouts and logging options.
    """

    # Environment configuration; its ``name`` selects the environment class.
    environment: EnvironmentConfigTypes
    # Solver configuration; its ``name`` selects the solver class.
    solver: SolverConfigTypes
    # When true, the run is saved through the recorder after every action.
    save_every_step: bool = True
    # Upper bound on the number of executed actions in one run.
    max_steps: int = 50
    # Seconds the solver may take to decide on a single action.
    action_timeout: float = 600.0
    # Seconds the environment may take to execute a single action.
    environment_timeout: float = 300.0
    # Seconds allowed for the run as a whole.
    task_timeout: float = 1800.0
    # Level applied to the runner's logger.
    logger_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
    # Forwarded to ``create_logger`` as ``detailed_name``.
    detailed_logger_name: bool = False

    @classmethod
    def from_dict(cls, config_dict: dict) -> Self:
        """Build a config from a plain dictionary.

        The dictionary is passed through OmegaConf so that interpolations are
        resolved before the model is validated.

        Args:
            config_dict: Mapping whose keys are the fields of this model.

        Returns:
            A validated instance of ``cls``.
        """
        conf = OmegaConf.create(config_dict)
        config_dict = OmegaConf.to_container(conf, resolve=True)
        return cls(**config_dict)

    @classmethod
    def from_yaml(cls, yaml_path: str) -> Self:
        """Build a config from a YAML file.

        Args:
            yaml_path: Path of the YAML file, loaded with ``OmegaConf.load``.

        Returns:
            A validated instance of ``cls``.
        """
        conf = OmegaConf.load(yaml_path)
        config_dict = OmegaConf.to_container(conf, resolve=True)
        return cls(**config_dict)
| class Runner(BaseModel): | |
| config: RunnerConfig | |
| recorder: DataRecorder | None = None | |
| environment: type[BaseEnvironment] | None = None | |
| solver: type[BaseSolver] | None = None | |
| logger: logging.Logger | None = None | |
| _run: Run | None = None | |
| class Config: | |
| arbitrary_types_allowed = True | |
| def model_post_init(self, __context: Any) -> None: | |
| super().model_post_init(__context) | |
| self.environment = Environments.get(self.config.environment.name) | |
| self.solver = Solvers.get(self.config.solver.name) | |
| self.recorder = DataRecorder() | |
| self.logger = create_logger( | |
| name=f"([bold purple]{self.config.solver.name}[/]-[bold blue]{self.config.environment.name}[/])", | |
| level=self.config.logger_level, | |
| detailed_name=self.config.detailed_logger_name, | |
| ) | |
| async def run_generator(self, task: str) -> AsyncIterator[Run]: | |
| async with ( | |
| async_timeout(self.config.task_timeout, "Task"), | |
| ): | |
| if self.config.logger_level is not None: | |
| self.logger.setLevel(self.config.logger_level) | |
| run = self.recorder.initialise_run(task) | |
| run.environment = self.config.environment | |
| run.solver = self.config.solver | |
| self.logger.debug(f"Run intialised: {run.run_id}") | |
| event_queue = asyncio.Queue() | |
| async with ( | |
| self.environment( | |
| config=self.config.environment, | |
| logger=self.logger, | |
| ) as environment, | |
| self.solver(config=self.config.solver, logger=self.logger) as solver, | |
| ): | |
| run.env_info = await environment.get_info() | |
| await solver.initialise( | |
| task, | |
| environment.tools, | |
| environment.info_for_user, | |
| ) | |
| self.logger.debug("Solver initialised.") | |
| run.solver_history = solver.history | |
| observation: Observation = await environment.initialise() | |
| await event_queue.put(observation) | |
| self.logger.debug("Environment initialised.") | |
| step_count = 0 | |
| while step_count < self.config.max_steps: | |
| event = await event_queue.get() | |
| self.logger.debug(f"π€ [bold purple]Processing event:[/] {event.type}") | |
| match event.type: | |
| case EventType.OBSERVATION: | |
| observation: Observation = event | |
| run.record( | |
| observation=observation, | |
| solver_history=solver.history, | |
| ) | |
| async with async_timeout( | |
| self.config.action_timeout, | |
| "Action decision", | |
| ): | |
| action: Action = await solver.act(observation) | |
| await event_queue.put(action) | |
| case EventType.ACTION: | |
| action: Action = event | |
| self.logger.debug(f"Tool calls: {action.tool_calls}") | |
| run.record(action=action, solver_history=solver.history) | |
| run.complete = await solver.is_complete(observation) | |
| if self.config.save_every_step: | |
| await self.recorder.save(run) | |
| if run.complete: | |
| run.result = action.text | |
| self.logger.info(f"π€ [bold purple]Task complete.[/] β¨ \n{run.result}") | |
| break | |
| self.logger.debug(f"DEBUG: Using environment_timeout: {self.config.environment_timeout} seconds") | |
| async with async_timeout( | |
| self.config.environment_timeout, | |
| "Environment response", | |
| ): | |
| observation: Observation = await environment.execute_action(action) | |
| step_count += 1 | |
| await event_queue.put(observation) | |
| yield run | |
| if not run.complete: | |
| self.logger.warning("π€ [bold purple]Ran out of steps!") | |
| await self.recorder.terminate(run, save=True) | |
| yield run | |
| async def run(self, task: str) -> Run: | |
| async for run in self.run_generator(task): | |
| self._run = run | |
| return run | |
| def run_concurrent(self, tasks: list[str]) -> list[Run]: | |
| async def gather_runs(): | |
| return await asyncio.gather( | |
| *[self.run(task) for task in tasks], | |
| return_exceptions=True, | |
| ) | |
| return asyncio.run(gather_runs()) | |
| def complete(self) -> bool: | |
| if self._run is None: | |
| raise RuntimeError("Run not initialised") | |
| return self._run.complete | |
| def run_id(self) -> str: | |
| if self._run is None: | |
| raise RuntimeError("Run not initialised") | |
| return self._run.run_id | |
| def run_result(self) -> str: | |
| if self._run is None: | |
| raise RuntimeError("Run not initialised") | |
| return self._run.result | |
if __name__ == "__main__":
    # Manual smoke run: drive the web-browser environment with the simple
    # solver against the hosted demo endpoint and print the outcome.
    from proxy_lite.logger import logger

    config = RunnerConfig.from_dict(
        {
            "environment": {
                "name": "webbrowser",
                "homepage": "https://www.google.com",
                "viewport_width": 1280,
                "viewport_height": 1920,
                "screenshot_delay": 1,
                "headless": False,
            },
            "solver": {
                "name": "simple",
                "agent": {
                    "name": "proxy_lite",
                    "client": {
                        "name": "convergence",
                        "model_id": "convergence-ai/proxy-lite",
                        "api_base": "https://convergence-ai-demo-api.hf.space/v1",
                    },
                },
            },
            "max_steps": 150,
            "action_timeout": 1800,
            "environment_timeout": 1800,
            "task_timeout": 18000,
            "logger_level": "DEBUG",
        },
    )
    logger.info(f"🤖 [bold purple]Config:[/] {config}")
    runner = Runner(config=config)
    result = asyncio.run(runner.run("Tell me the tesla stock price."))
    # Report from the Run returned by ``run`` so the values themselves are
    # printed, independent of how the runner exposes its accessors.
    print(result.result)
    print(result.complete)