Spaces:
Paused
Paused
| import json | |
| import logging | |
| from abc import ABC, abstractmethod | |
| from contextlib import AsyncExitStack | |
| from functools import cached_property | |
| from typing import Any, Optional, Type, cast | |
| from pydantic import BaseModel, Field | |
| from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential | |
| from proxy_lite.client import BaseClient, ClientConfigTypes, OpenAIClientConfig | |
| from proxy_lite.history import ( | |
| AssistantMessage, | |
| MessageHistory, | |
| MessageLabel, | |
| SystemMessage, | |
| Text, | |
| ToolCall, | |
| ToolMessage, | |
| UserMessage, | |
| ) | |
| from proxy_lite.logger import logger | |
| from proxy_lite.tools import Tool | |
| # if TYPE_CHECKING: | |
| # from proxy_lite.tools import Tool | |
class BaseAgentConfig(BaseModel):
    """Settings shared by every agent configuration.

    Attributes:
        client: Configuration used to build the completion client
            (defaults to ``OpenAIClientConfig()``).
        history_messages_limit: Per-label limits handed to
            ``MessageHistory.history_view`` when building the model's view.
        history_messages_include: Optional allow-list. When provided, it
            replaces ``history_messages_limit`` with a mapping in which every
            ``MessageLabel`` is set to 0 except the labels listed here.
    """

    client: ClientConfigTypes = Field(default_factory=OpenAIClientConfig)
    history_messages_limit: dict[MessageLabel, int] = Field(default_factory=dict)
    history_messages_include: Optional[dict[MessageLabel, int]] = Field(
        default=None,
        description="If set, overrides history_messages_limit by setting all message types to 0 except those specified",
    )

    def model_post_init(self, __context: Any) -> None:
        """Expand ``history_messages_include`` into a full ``history_messages_limit``."""
        include = self.history_messages_include
        if include is None:
            return
        # Start from "hide every label", then re-enable only the requested ones.
        limits = dict.fromkeys(MessageLabel, 0)
        limits.update(include)
        self.history_messages_limit = limits
class BaseAgent(BaseModel, ABC):
    """Abstract base for model-driven agents.

    An agent holds a ``MessageHistory``, a completion client built from
    ``config.client`` and a set of tools. Subclasses must provide the
    ``system_prompt`` and ``tools`` properties. This base class implements:

    * building the history view sent to the model (``get_history_view``),
    * requesting a completion (``generate_output``),
    * recording user/system/assistant/tool messages (``receive_*_message``),
    * dispatching model-issued tool calls (``use_tool``).
    """

    config: BaseAgentConfig
    # Sampling temperature forwarded to the client; pydantic validates 0 <= t <= 2.
    temperature: float = Field(default=0.7, ge=0, le=2)
    history: MessageHistory = Field(default_factory=MessageHistory)
    # Completion client. When left as None it is built from ``config.client``
    # in ``model_post_init``; a client passed by the caller is kept as-is.
    client: Optional[BaseClient] = None
    env_tools: list[Tool] = Field(default_factory=list)
    task: Optional[str] = Field(default=None)
    # Forwarded unchanged as ``seed`` to ``create_completion``.
    seed: Optional[int] = Field(default=None)

    class Config:
        # Presumably required because some field types (e.g. BaseClient, Tool)
        # are not pydantic models — confirm.
        arbitrary_types_allowed = True

    def __init__(self, **data) -> None:
        super().__init__(**data)
        # Runtime-only state, not pydantic fields; not used by the methods of
        # this base class (presumably consumed by subclasses).
        self._exit_stack = AsyncExitStack()
        self._tools_init_task = None

    def model_post_init(self, __context: Any) -> None:
        """Build the completion client from ``config.client`` unless one was supplied."""
        super().model_post_init(__context)
        if self.client is None:
            self.client = BaseClient.create(self.config.client)

    @property
    @abstractmethod
    def system_prompt(self) -> str:
        """System prompt prepended to every history view sent to the model."""

    @property
    @abstractmethod
    def tools(self) -> list[Tool]:
        """Tools offered to the model and searched by ``use_tool``."""

    @cached_property
    def tool_descriptions(self) -> str:
        """Human-readable listing of every tool function.

        Each tool contributes one ``- name: description`` line per entry in its
        ``schema``. When more than one tool is present, each group is prefixed
        with the tool's class name. Groups are separated by a blank line.
        The result is computed once per instance and then cached.
        """
        tools = self.tools
        show_titles = len(tools) > 1
        sections = []
        for tool in tools:
            lines = "\n".join("- {name}: {description}".format(**schema) for schema in tool.schema)
            title = f"{tool.__class__.__name__}:\n" if show_titles else ""
            sections.append(f"{title}{lines}")
        return "\n\n".join(sections)

    async def get_history_view(self) -> MessageHistory:
        """Return the messages to send to the model.

        The result is a fresh system message holding ``system_prompt``,
        followed by ``self.history.history_view(...)`` called with
        ``config.history_messages_limit``.
        """
        system_view = MessageHistory(
            messages=[SystemMessage(content=[Text(text=self.system_prompt)])],
        )
        return system_view + self.history.history_view(
            limits=self.config.history_messages_limit,
        )

    async def generate_output(
        self,
        use_tool: bool = False,
        response_format: Optional[type[BaseModel]] = None,
        append_assistant_message: bool = True,
    ) -> AssistantMessage:
        """Request one completion for the current history view.

        Args:
            use_tool: If true, ``tools`` are passed to the client so the model
                may answer with tool calls.
            response_format: Optional pydantic model forwarded to the client.
            append_assistant_message: If true, the reply is appended to
                ``self.history`` under ``self.message_label``.

        Returns:
            The first choice of the completion as an ``AssistantMessage``.
        """
        messages: MessageHistory = await self.get_history_view()
        completion = await self.client.create_completion(
            messages=messages,
            temperature=self.temperature,
            seed=self.seed,
            response_format=response_format,
            tools=self.tools if use_tool else None,
        )
        # Only the first choice is used.
        reply = completion.model_dump()["choices"][0]["message"]
        assistant_message = AssistantMessage(
            role=reply["role"],
            # Empty/None content (e.g. a pure tool-call reply) becomes no text parts.
            content=[Text(text=reply["content"])] if reply["content"] else [],
            tool_calls=reply["tool_calls"],
        )
        if append_assistant_message:
            # NOTE(review): ``message_label`` is not declared on BaseAgent;
            # presumably every concrete subclass defines it — confirm.
            self.history.append(message=assistant_message, label=self.message_label)
        return assistant_message

    def receive_user_message(
        self,
        text: Optional[str] = None,
        image: Optional[list[bytes]] = None,
        label: Optional[MessageLabel] = None,
        is_base64: bool = False,
    ) -> None:
        """Append a user message built by ``UserMessage.from_media`` to the history."""
        message = UserMessage.from_media(
            text=text,
            image=image,
            is_base64=is_base64,
        )
        self.history.append(message=message, label=label)

    def receive_system_message(
        self,
        text: Optional[str] = None,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append a system message built by ``SystemMessage.from_media`` to the history."""
        message = SystemMessage.from_media(text=text)
        self.history.append(message=message, label=label)

    def receive_assistant_message(
        self,
        content: Optional[str] = None,
        tool_calls: Optional[list[ToolCall]] = None,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append an assistant message (optional text plus tool calls) to the history."""
        message = AssistantMessage(
            content=[Text(text=content)] if content else [],
            tool_calls=tool_calls,
        )
        self.history.append(message=message, label=label)

    async def use_tool(self, tool_call: ToolCall):
        """Dispatch a model-issued tool call to the first tool that exposes it.

        ``tool_call.function`` is read as a mapping with ``"name"`` and
        ``"arguments"`` (a JSON object string). The matching method is awaited
        with the decoded arguments as keyword arguments, and its result is
        returned. An empty or None arguments string is treated as ``{}``.

        Raises:
            ValueError: If no tool exposes a public callable with that name.
        """
        function = tool_call.function
        name = function["name"]
        # The name comes from model output: never dispatch to private/dunder
        # attributes, or to non-callable attributes that happen to share the name.
        if not name.startswith("_"):
            for tool in self.tools:
                method = getattr(tool, name, None)
                if callable(method):
                    arguments = json.loads(function["arguments"] or "{}")
                    return await method(**arguments)
        msg = f'No tool function with name "{name}"'
        raise ValueError(msg)

    async def receive_tool_message(
        self,
        text: str,
        tool_id: str,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append a tool result, tied to ``tool_id``, to the history."""
        self.history.append(
            message=ToolMessage(content=[Text(text=text)], tool_call_id=tool_id),
            label=label,
        )
class Agents:
    """Name-based registry of agent classes and their configuration classes.

    Both registries are class-level dicts and every method is a classmethod,
    so the registry is used through the class itself, e.g.
    ``Agents.get("browser")``. Registering a second class under an existing
    name replaces the earlier entry.
    """

    _agent_registry: dict[str, type[BaseAgent]] = {}
    _agent_config_registry: dict[str, type[BaseAgentConfig]] = {}

    @classmethod
    def register_agent(cls, name: str):
        """
        Decorator to register an Agent class under a given name.
        Example:
            @Agents.register_agent("browser")
            class BrowserAgent(BaseAgent):
                ...
        """

        def decorator(agent_cls: type[BaseAgent]) -> type[BaseAgent]:
            cls._agent_registry[name] = agent_cls
            # Return the class unchanged so the decorator is transparent.
            return agent_cls

        return decorator

    @classmethod
    def register_agent_config(cls, name: str):
        """
        Decorator to register a configuration class under a given name.
        Example:
            @Agents.register_agent_config("browser")
            class BrowserAgentConfig(BaseAgentConfig):
                ...
        """

        def decorator(config_cls: type[BaseAgentConfig]) -> type[BaseAgentConfig]:
            cls._agent_config_registry[name] = config_cls
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseAgent]:
        """
        Retrieve a registered Agent class by its name.
        Raises:
            ValueError: If no such agent is found.
        """
        try:
            return cls._agent_registry[name]
        except KeyError:
            # The internal KeyError adds nothing for callers; suppress it.
            raise ValueError(f"Agent '{name}' not found.") from None

    @classmethod
    def get_config(cls, name: str) -> type[BaseAgentConfig]:
        """
        Retrieve a registered Agent configuration class by its name.
        Raises:
            ValueError: If no such config is found.
        """
        try:
            return cls._agent_config_registry[name]
        except KeyError:
            raise ValueError(f"Agent config for '{name}' not found.") from None