Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Tool registry for managing and discovering GAIA tools. | |
| """ | |
| from typing import Dict, List, Optional, Type, Any | |
| from dataclasses import dataclass, field | |
| from .base import GAIATool, ToolCategory, ToolMetadata | |
| from ..utils.exceptions import ToolNotFoundError | |
class ToolRegistry:
    """Registry for managing GAIA tools.

    Maps a tool name (taken from ``ToolMetadata.name``) to its tool class and
    metadata, and caches tool instances created through :meth:`get_tool`.
    """

    def __init__(self):
        # Tool name -> registered tool class.
        self._tools: Dict[str, Type[GAIATool]] = {}
        # Tool name -> metadata supplied at registration time.
        self._metadata: Dict[str, ToolMetadata] = {}
        # (tool name, frozenset of init-kwarg items) -> cached instance.
        # Keying on the kwargs themselves (not only on their hash) means two
        # different kwarg sets can never share an instance via a hash collision.
        self._instances: Dict[tuple, GAIATool] = {}

    def register(self, tool_class: Type[GAIATool], metadata: ToolMetadata) -> None:
        """Register ``tool_class`` under ``metadata.name``.

        Re-registering an existing name replaces its class and metadata and
        discards every instance cached for that name, so subsequent
        :meth:`get_tool` calls build instances of the newly registered class.

        Args:
            tool_class: Class instantiated by :meth:`get_tool`.
            metadata: Tool metadata; ``metadata.name`` is the registry key.
        """
        name = metadata.name
        self._tools[name] = tool_class
        self._metadata[name] = metadata
        # Drop instances built from a previously registered class of this name.
        stale_keys = [key for key in self._instances if key[0] == name]
        for key in stale_keys:
            del self._instances[key]

    def get_tool(self, name: str, **init_kwargs) -> GAIATool:
        """Return an instance of the tool registered as ``name``.

        Instances are cached per ``(name, init_kwargs)``: repeated calls with
        equal keyword arguments return the same object. If any keyword value
        is unhashable (e.g. a list or dict), it cannot form a cache key, so a
        fresh, uncached instance is returned instead.

        Args:
            name: Registered tool name.
            **init_kwargs: Keyword arguments forwarded to the tool constructor.

        Returns:
            The (possibly cached) tool instance.

        Raises:
            ToolNotFoundError: If ``name`` has not been registered.
        """
        if name not in self._tools:
            raise ToolNotFoundError(f"Tool '{name}' not found in registry")
        tool_class = self._tools[name]
        try:
            cache_key = (name, frozenset(init_kwargs.items()))
        except TypeError:
            # Unhashable kwarg value: build the tool but do not cache it.
            return tool_class(**init_kwargs)
        if cache_key not in self._instances:
            self._instances[cache_key] = tool_class(**init_kwargs)
        return self._instances[cache_key]

    def get_tools_by_category(self, category: ToolCategory) -> List[str]:
        """Return the names of tools whose metadata category equals ``category``.

        Names are returned in the registry's insertion order.
        """
        return [
            name for name, metadata in self._metadata.items()
            if metadata.category == category
        ]

    def get_all_tools(self) -> List[str]:
        """Return the names of all registered tools, in insertion order."""
        return list(self._tools.keys())

    def get_metadata(self, name: str) -> ToolMetadata:
        """Return the metadata registered for ``name``.

        Raises:
            ToolNotFoundError: If ``name`` has not been registered.
        """
        try:
            return self._metadata[name]
        except KeyError:
            raise ToolNotFoundError(f"Tool '{name}' not found in registry") from None

    def search_tools(self, query: str) -> List[str]:
        """Return names of tools whose name or description contains ``query``.

        Matching is a case-insensitive substring test.
        """
        query_lower = query.lower()
        return [
            name for name, metadata in self._metadata.items()
            if query_lower in name.lower()
            or query_lower in metadata.description.lower()
        ]

    def validate_dependencies(self, name: str) -> bool:
        """Return True if every dependency listed for ``name`` is a registered tool.

        Raises:
            ToolNotFoundError: If ``name`` itself has not been registered.
        """
        metadata = self.get_metadata(name)
        return all(dep in self._tools for dep in metadata.dependencies)

    def get_tool_info(self, name: str) -> Dict[str, Any]:
        """Return a dict describing ``name``.

        The dict contains its metadata fields (with the category's ``.value``)
        plus ``dependencies_satisfied`` from :meth:`validate_dependencies`.

        Raises:
            ToolNotFoundError: If ``name`` has not been registered.
        """
        metadata = self.get_metadata(name)
        return {
            "name": metadata.name,
            "description": metadata.description,
            "category": metadata.category.value,
            "version": metadata.version,
            "author": metadata.author,
            "input_schema": metadata.input_schema,
            "output_schema": metadata.output_schema,
            "examples": metadata.examples,
            "dependencies": metadata.dependencies,
            "dependencies_satisfied": self.validate_dependencies(name),
        }
# Process-wide default registry; the ``register_tool`` decorator adds
# tool classes to this instance.
tool_registry: ToolRegistry = ToolRegistry()
def register_tool(metadata: ToolMetadata):
    """Create a class decorator that registers a tool in ``tool_registry``.

    Usage::

        @register_tool(some_metadata)
        class MyTool(GAIATool):
            ...

    Args:
        metadata: Metadata stored alongside the decorated class; the class is
            registered under ``metadata.name``.

    Returns:
        A decorator that registers the class it is applied to and hands that
        same class back unchanged.
    """
    def _register_and_return(tool_class: Type[GAIATool]) -> Type[GAIATool]:
        tool_registry.register(tool_class, metadata)
        return tool_class

    return _register_and_return