Spaces:
Sleeping
Sleeping
| from typing import Dict, Any | |
| from .openai_api import OpenAIAPI | |
| from .anthropic_api import AnthropicAPI | |
| from .grok_api import GrokAPI | |
| from .base_api import BaseAPI | |
class APIFactory:
    """Factory class to create API instances based on model name.

    A model name is resolved to a provider key through ``MODEL_PROVIDERS``.
    The provider key is resolved to a concrete API class through
    ``PROVIDER_APIS``.
    """

    # Model name -> provider key. Only model names listed here are accepted.
    MODEL_PROVIDERS = {
        # OpenAI models
        'gpt-4o': 'openai',
        'gpt-4-turbo': 'openai',
        'gpt-3.5-turbo': 'openai',
        # Anthropic models
        'claude-3-5-sonnet-20241022': 'anthropic',
        'claude-3-opus-20240229': 'anthropic',
        'claude-3-haiku-20240307': 'anthropic',
        # Grok models
        'grok-4-0709': 'grok',
        'grok-beta': 'grok',
        'grok-2-latest': 'grok',
        'grok-vision-beta': 'grok',
    }

    # Provider key -> API client class.
    PROVIDER_APIS = {
        'openai': OpenAIAPI,
        'anthropic': AnthropicAPI,
        'grok': GrokAPI,
    }

    # Base URL passed to the Grok client when its provider config sets none.
    DEFAULT_GROK_BASE_URL = 'https://api.x.ai/v1'

    @classmethod
    def create_api(cls, model_name: str, config: Dict[str, Any]) -> BaseAPI:
        """Create an API instance for the given model.

        Args:
            model_name: A key of ``MODEL_PROVIDERS``.
            config: Application configuration mapping.
                ``config['models'][<provider>]`` must be a non-empty mapping
                with a truthy ``'api_key'``. For Grok, it may also carry a
                ``'base_url'`` override. The optional ``config['evaluation']``
                mapping may set ``rate_limit_delay`` (default 1.0),
                ``max_retries`` (default 3) and ``timeout`` (default 30).

        Returns:
            ``api_class(api_key, model_name, **kwargs)`` for the provider's
            registered API class.

        Raises:
            ValueError: If the model is unknown, the provider has no
                configuration or API key, or no API class is registered for
                the provider.
        """
        # Determine provider
        provider = cls.MODEL_PROVIDERS.get(model_name)
        if not provider:
            raise ValueError(f"Unknown model: {model_name}")

        # Get provider config. A missing or null 'models' section is reported
        # with the same ValueError as a missing provider entry, rather than
        # escaping as a bare KeyError.
        models_config = config.get('models') or {}
        provider_config = models_config.get(provider)
        if not provider_config:
            raise ValueError(f"No configuration found for provider: {provider}")

        # Get API key
        api_key = provider_config.get('api_key')
        if not api_key:
            raise ValueError(f"No API key found for provider: {provider}")

        # Get API class
        api_class = cls.PROVIDER_APIS.get(provider)
        if not api_class:
            raise ValueError(f"No API implementation for provider: {provider}")

        # Every evaluation setting has a default, so a missing or null
        # 'evaluation' section falls back to those defaults.
        evaluation_config = config.get('evaluation') or {}
        kwargs = {
            'rate_limit_delay': evaluation_config.get('rate_limit_delay', 1.0),
            'max_retries': evaluation_config.get('max_retries', 3),
            'timeout': evaluation_config.get('timeout', 30),
        }

        # Add provider-specific config: only the Grok client receives a base_url.
        if provider == 'grok':
            kwargs['base_url'] = provider_config.get(
                'base_url', cls.DEFAULT_GROK_BASE_URL
            )

        return api_class(api_key, model_name, **kwargs)