Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Base classes and interfaces for GAIA tools. | |
| """ | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, Optional, Union, List | |
| from enum import Enum | |
| import time | |
| import functools | |
| from ..utils.exceptions import ToolError, ToolValidationError, ToolExecutionError, ToolTimeoutError | |
class ToolStatus(Enum):
    """Tool execution status."""

    SUCCESS = "success"
    ERROR = "error"
    TIMEOUT = "timeout"
    VALIDATION_FAILED = "validation_failed"


@dataclass
class ToolResult:
    """Standardized tool result format.

    Built by keyword (``ToolResult(status=..., output=...)``); the
    ``@dataclass`` decorator generates that constructor and gives every
    instance its own ``metadata`` dict.
    """

    # Outcome of the call.
    status: ToolStatus
    # Value produced by the tool; may be None when the call failed.
    output: Any
    # Human-readable failure description, if any.
    error_message: Optional[str] = None
    # Wall-clock duration of the call in seconds, if measured.
    execution_time: Optional[float] = None
    # Free-form extra information; a fresh dict per instance.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def is_success(self) -> bool:
        """Return True if the tool execution succeeded."""
        return self.status == ToolStatus.SUCCESS

    def is_error(self) -> bool:
        """Return True if the status is an error, timeout or validation failure."""
        return self.status in (
            ToolStatus.ERROR,
            ToolStatus.TIMEOUT,
            ToolStatus.VALIDATION_FAILED,
        )

    def get_output_or_error(self) -> str:
        """Return the output as a string on success, else the error message.

        Falls back to ``"Unknown error"`` when a failed result carries no
        ``error_message``.
        """
        # is_success must be *called*: the bound method object is always truthy.
        if self.is_success():
            return str(self.output)
        return self.error_message or "Unknown error"
class GAIATool(ABC):
    """Abstract base class for all GAIA tools.

    Subclasses implement ``_execute`` (and optionally ``_validate_input``).
    Callers use ``execute`` / ``__call__``, which wrap the call with input
    validation, a timeout, timing, and conversion into a ``ToolResult``.
    """

    def __init__(self, name: str, description: str, timeout: int = 60):
        """Initialise the tool.

        Args:
            name: Tool name, used in error messages and metadata.
            description: Human-readable description (shown by ``__str__``).
            timeout: Seconds allowed for one ``_execute`` call; ``None`` or a
                value ``<= 0`` disables the limit.
        """
        self.name = name
        self.description = description
        self.timeout = timeout
        self._execution_count = 0
        self._total_execution_time = 0.0

    @abstractmethod
    def _execute(self, **kwargs) -> Any:
        """Execute the tool logic. Must be implemented by subclasses."""

    def _validate_input(self, **kwargs) -> None:
        """Validate input parameters.

        Override to reject bad input by raising ``ToolValidationError``.
        The default implementation accepts everything.
        """

    def execute(self, **kwargs) -> ToolResult:
        """Execute the tool with standardized error handling and timing.

        Never raises for tool failures: validation errors, timeouts and any
        other ``Exception`` are reported through the returned result's status.
        Only successful runs are counted in the execution statistics.
        """
        # perf_counter is monotonic, so durations survive wall-clock changes.
        start_time = time.perf_counter()
        try:
            self._validate_input(**kwargs)
            result = self._execute_with_timeout(**kwargs)
            execution_time = time.perf_counter() - start_time
            self._record_execution(execution_time)
            return ToolResult(
                status=ToolStatus.SUCCESS,
                output=result,
                execution_time=execution_time,
                metadata=self._get_execution_metadata(),
            )
        except ToolValidationError as e:
            return ToolResult(
                status=ToolStatus.VALIDATION_FAILED,
                output=None,
                error_message=str(e),
                execution_time=time.perf_counter() - start_time,
            )
        except ToolTimeoutError as e:
            return ToolResult(
                status=ToolStatus.TIMEOUT,
                output=None,
                error_message=str(e),
                execution_time=time.perf_counter() - start_time,
            )
        except Exception as e:
            return ToolResult(
                status=ToolStatus.ERROR,
                output=None,
                error_message=f"{self.name} execution failed: {str(e)}",
                execution_time=time.perf_counter() - start_time,
            )

    def _execute_with_timeout(self, **kwargs) -> Any:
        """Run ``_execute`` with ``self.timeout`` seconds as the limit.

        Uses ``SIGALRM`` when the platform provides it and this is the main
        thread, because Python only lets the main thread install signal
        handlers. Otherwise ``_execute`` runs on a daemon worker thread that is
        waited on for at most ``self.timeout`` seconds. A ``None`` or
        non-positive timeout runs ``_execute`` directly with no limit.

        Raises:
            ToolTimeoutError: if ``_execute`` has not finished in time.
        """
        import signal
        import threading

        if self.timeout is None or self.timeout <= 0:
            return self._execute(**kwargs)
        on_main_thread = threading.current_thread() is threading.main_thread()
        if hasattr(signal, "SIGALRM") and on_main_thread:
            return self._execute_with_alarm(**kwargs)
        return self._execute_in_worker(**kwargs)

    def _timeout_error(self) -> Exception:
        """Build the ToolTimeoutError reported when ``_execute`` overruns."""
        return ToolTimeoutError(f"Tool {self.name} timed out after {self.timeout} seconds")

    def _execute_with_alarm(self, **kwargs) -> Any:
        """Run ``_execute`` under a SIGALRM-based timer (main thread, Unix only)."""
        import signal

        def timeout_handler(signum, frame):
            raise self._timeout_error()

        old_handler = signal.signal(signal.SIGALRM, timeout_handler)
        # setitimer accepts fractional seconds, unlike signal.alarm.
        signal.setitimer(signal.ITIMER_REAL, self.timeout)
        try:
            return self._execute(**kwargs)
        finally:
            # Disarm the timer on every exit path, including exceptions, BEFORE
            # restoring the old handler. Otherwise a still-pending alarm would
            # later fire into that handler, and SIGALRM's default action
            # terminates the process.
            signal.setitimer(signal.ITIMER_REAL, 0)
            signal.signal(signal.SIGALRM, old_handler)

    def _execute_in_worker(self, **kwargs) -> Any:
        """Run ``_execute`` on a daemon thread, waiting at most ``self.timeout`` seconds."""
        import threading

        outcome: Dict[str, Any] = {}

        def run() -> None:
            try:
                outcome["result"] = self._execute(**kwargs)
            except BaseException as exc:  # re-raised in the calling thread below
                outcome["error"] = exc

        worker = threading.Thread(target=run, name=f"gaia-tool-{self.name}", daemon=True)
        worker.start()
        worker.join(self.timeout)
        if worker.is_alive():
            # Python threads cannot be killed, so the worker is abandoned to
            # finish on its own. Being a daemon, it does not block interpreter exit.
            raise self._timeout_error()
        if "error" in outcome:
            raise outcome["error"]
        return outcome.get("result")

    def _record_execution(self, execution_time: float) -> None:
        """Add one successful run of ``execution_time`` seconds to the statistics."""
        self._execution_count += 1
        self._total_execution_time += execution_time

    def _get_execution_metadata(self) -> Dict[str, Any]:
        """Return the tool name, run count and mean run time (0.0 before any run)."""
        return {
            "tool_name": self.name,
            "execution_count": self._execution_count,
            "average_execution_time": self._total_execution_time / max(1, self._execution_count),
        }

    def __call__(self, **kwargs) -> ToolResult:
        """Make the tool callable; equivalent to ``execute(**kwargs)``."""
        return self.execute(**kwargs)

    def __str__(self) -> str:
        return f"{self.name}: {self.description}"
class AsyncGAIATool(GAIATool):
    """Base class for tools whose logic is a coroutine.

    Subclasses implement ``_execute_async``. The inherited ``execute`` drives
    it synchronously through ``_execute``.
    """

    @abstractmethod
    async def _execute_async(self, **kwargs) -> Any:
        """Async execute method. Must be implemented by subclasses."""

    def _execute(self, **kwargs) -> Any:
        """Run ``_execute_async`` to completion and return its result.

        ``asyncio.run`` refuses to start when the current thread already has a
        running event loop. In that case the coroutine is run on a fresh loop
        in a short-lived helper thread, and this thread blocks until it finishes.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so asyncio.run is safe here.
            return asyncio.run(self._execute_async(**kwargs))

        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, self._execute_async(**kwargs)).result()
def tool_with_retry(max_retries: int = 3, backoff_factor: float = 2.0):
    """Class decorator that adds retry logic to a tool's ``_execute``.

    The wrapped ``_execute`` is attempted up to ``max_retries + 1`` times (and
    always at least once). Between attempts it sleeps ``backoff_factor **
    attempt`` seconds, i.e. 1, backoff_factor, backoff_factor**2, ... The last
    failure is re-raised.

    ``ToolTimeoutError`` and ``ToolValidationError`` are re-raised immediately
    without retrying. A validation failure will not change on retry. After a
    timeout, the one-shot alarm used by ``GAIATool._execute_with_timeout`` has
    already fired, so retries would run with no time limit at all.

    Args:
        max_retries: Number of additional attempts after the first one.
        backoff_factor: Base of the exponential back-off between attempts.

    Returns:
        A decorator that patches ``tool_class._execute`` in place and returns
        the class.
    """
    attempts = max(1, max_retries + 1)

    def decorator(tool_class):
        original_execute = tool_class._execute

        @functools.wraps(original_execute)
        def execute_with_retry(self, **kwargs):
            for attempt in range(attempts):
                try:
                    return original_execute(self, **kwargs)
                except (ToolTimeoutError, ToolValidationError):
                    raise
                except Exception:
                    if attempt == attempts - 1:
                        raise
                    time.sleep(backoff_factor ** attempt)

        tool_class._execute = execute_with_retry
        return tool_class

    return decorator
def validate_required_params(*required_params):
    """Decorator for ``_validate_input`` methods that enforces required kwargs.

    Before delegating to the decorated method, this checks that every name in
    ``required_params`` was passed as a keyword argument and that none of them
    is ``None``.

    Raises:
        ToolValidationError: listing the missing parameters, or else the
            parameters that were passed as ``None``.
    """
    def decorator(validate_method):
        @functools.wraps(validate_method)
        def wrapper(self, **kwargs):
            missing_params = [param for param in required_params if param not in kwargs]
            if missing_params:
                raise ToolValidationError(
                    f"Missing required parameters for {self.name}: {missing_params}"
                )
            # All names are present at this point, so only explicit None values remain.
            none_params = [param for param in required_params if kwargs[param] is None]
            if none_params:
                raise ToolValidationError(
                    f"Required parameters cannot be None for {self.name}: {none_params}"
                )
            return validate_method(self, **kwargs)

        return wrapper

    return decorator
class ToolCategory(Enum):
    """Tool categories for organization."""

    MULTIMEDIA = "multimedia"
    RESEARCH = "research"
    FILE_PROCESSING = "file_processing"
    CHESS = "chess"
    MATH = "math"
    UTILITY = "utility"


@dataclass
class ToolMetadata:
    """Metadata for tool registration and discovery.

    ``@dataclass`` generates the keyword constructor and gives every
    instance its own ``examples`` and ``dependencies`` lists.
    """

    name: str
    description: str
    category: ToolCategory
    # Schema describing accepted inputs (structure not fixed here).
    input_schema: Dict[str, Any]
    # Schema describing produced outputs (structure not fixed here).
    output_schema: Dict[str, Any]
    # Example invocations; a fresh list per instance.
    examples: List[Dict[str, Any]] = field(default_factory=list)
    version: str = "1.0.0"
    author: Optional[str] = None
    # Names of required dependencies; a fresh list per instance.
    dependencies: List[str] = field(default_factory=list)