#!/usr/bin/env python3
"""
Model provider implementations for GAIA agent.
Contains specific model provider classes and utilities.
"""
import os
import time
from typing import List, Dict, Optional

import litellm

from ..utils.exceptions import ModelError, ModelAuthenticationError


class LiteLLMModel:
    """Custom model adapter to use LiteLLM with smolagents"""

    def __init__(self, model_name: str, api_key: str, api_base: Optional[str] = None):
        if not api_key:
            raise ValueError(f"No API key provided for {model_name}")

        self.model_name = model_name
        self.api_key = api_key
        self.api_base = api_base

        # Configure LiteLLM based on provider
        self._configure_environment()
        self._test_authentication()

    def _configure_environment(self) -> None:
        """Configure environment variables for the model."""
        try:
            if "gemini" in self.model_name.lower():
                os.environ["GEMINI_API_KEY"] = self.api_key
            elif self.api_base:
                # For custom API endpoints like Kluster.ai
                os.environ["OPENAI_API_KEY"] = self.api_key
                os.environ["OPENAI_API_BASE"] = self.api_base

            litellm.set_verbose = False  # Reduce verbose logging
        except Exception as e:
            raise ModelError(f"Failed to configure environment for {self.model_name}: {e}")

    def _test_authentication(self) -> None:
        """Test authentication with a minimal request."""
        try:
            if "gemini" in self.model_name.lower():
                # Probe Gemini authentication with a one-token request;
                # other endpoints are only validated on their first real call.
                litellm.completion(
                    model=self.model_name,
                    messages=[{"role": "user", "content": "test"}],
                    max_tokens=1
                )
            print(f"✅ Initialized LiteLLM with {self.model_name}" +
                  (f" via {self.api_base}" if self.api_base else ""))
        except Exception as e:
            error_msg = f"Authentication failed for {self.model_name}: {str(e)}"
            print(f"❌ {error_msg}")
            raise ModelAuthenticationError(error_msg, model_name=self.model_name)

    # Nested so that instance methods can construct messages via self.ChatMessage.
    class ChatMessage:
        """Enhanced ChatMessage class for smolagents + LiteLLM compatibility"""

        def __init__(self, content: str, role: str = "assistant"):
            self.content = content
            self.role = role
            self.tool_calls = []

            # Token usage attributes - covering different naming conventions
            self.token_usage = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }

            # Additional attributes for broader compatibility
            self.input_tokens = 0   # Alternative naming for prompt_tokens
            self.output_tokens = 0  # Alternative naming for completion_tokens
            self.usage = self.token_usage  # Alternative attribute name

            # Optional metadata attributes
            self.finish_reason = "stop"
            self.model = None
            self.created = None

        def __str__(self):
            return self.content

        def __repr__(self):
            return f"ChatMessage(role='{self.role}', content='{self.content[:50]}...')"

        def __getitem__(self, key):
            """Make the object dict-like for backward compatibility"""
            if key == 'input_tokens':
                return self.input_tokens
            elif key == 'output_tokens':
                return self.output_tokens
            elif key == 'content':
                return self.content
            elif key == 'role':
                return self.role
            else:
                raise KeyError(f"Key '{key}' not found")

        def get(self, key, default=None):
            """Dict-like get method"""
            try:
                return self[key]
            except KeyError:
                return default
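
    # Illustrative only (not executed): a ChatMessage behaves as both an object
    # and a dict, e.g.
    #   msg = LiteLLMModel.ChatMessage("Paris")
    #   msg.content == msg["content"] == str(msg) == "Paris"
    #   msg.get("missing_key", 0) == 0   # .get() falls back to the default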

    def __call__(self, messages: List[Dict], **kwargs):
        """Make the model callable for smolagents compatibility"""
        try:
            # Format messages for LiteLLM
            formatted_messages = self._format_messages(messages)

            # Execute with retry logic
            return self._execute_with_retry(formatted_messages, **kwargs)
        except Exception as e:
            print(f"❌ LiteLLM error: {e}")
            print(f"Error type: {type(e)}")
            if "content" in str(e):
                print("This looks like a response parsing error - returning error as ChatMessage")
                return self.ChatMessage(f"Error in model response: {str(e)}")
            print(f"Debug - Input messages: {messages}")
            # Return error as ChatMessage instead of raising to maintain compatibility
            return self.ChatMessage(f"Error: {str(e)}")

    def _format_messages(self, messages: List[Dict]) -> List[Dict]:
        """Format messages for LiteLLM consumption."""
        formatted_messages = []

        for msg in messages:
            if isinstance(msg, dict):
                if 'content' in msg:
                    content = msg['content']
                    role = msg.get('role', 'user')

                    # Handle complex content structures
                    if isinstance(content, list):
                        text_content = self._extract_text_from_content_list(content)
                        formatted_messages.append({"role": role, "content": text_content})
                    elif isinstance(content, str):
                        formatted_messages.append({"role": role, "content": content})
                    else:
                        formatted_messages.append({"role": role, "content": str(content)})
                else:
                    # Fallback for messages without explicit content
                    formatted_messages.append({"role": "user", "content": str(msg)})
            else:
                # Handle string messages
                formatted_messages.append({"role": "user", "content": str(msg)})

        # Ensure we have at least one message
        if not formatted_messages:
            formatted_messages = [{"role": "user", "content": "Hello"}]

        return formatted_messages
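
    # Illustrative only (not executed): _format_messages flattens smolagents-style
    # structured content into plain role/content pairs, e.g.
    #   [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}]
    # becomes
    #   [{"role": "user", "content": "Hi"}]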

    def _extract_text_from_content_list(self, content_list: List) -> str:
        """Extract text content from complex content structures."""
        text_content = ""

        for item in content_list:
            if isinstance(item, dict):
                if 'content' in item and isinstance(item['content'], list):
                    # Nested content structure
                    for subitem in item['content']:
                        if isinstance(subitem, dict) and subitem.get('type') == 'text':
                            text_content += subitem.get('text', '') + "\n"
                elif item.get('type') == 'text':
                    text_content += item.get('text', '') + "\n"
            else:
                text_content += str(item) + "\n"

        return text_content.strip()

    def _execute_with_retry(self, formatted_messages: List[Dict], **kwargs):
        """Execute LiteLLM call with retry logic."""
        max_retries = 3
        base_delay = 2

        for attempt in range(max_retries):
            try:
                # Prepare completion arguments
                completion_kwargs = {
                    "model": self.model_name,
                    "messages": formatted_messages,
                    "temperature": kwargs.get('temperature', 0.7),
                    "max_tokens": kwargs.get('max_tokens', 4000)
                }

                # Add API base for custom endpoints
                if self.api_base:
                    completion_kwargs["api_base"] = self.api_base

                # Make the API call
                response = litellm.completion(**completion_kwargs)

                # Process and return response
                return self._process_response(response)

            except Exception as retry_error:
                if self._is_retryable_error(retry_error) and attempt < max_retries - 1:
                    delay = base_delay * (2 ** attempt)  # Exponential backoff: 2s, then 4s
                    print(f"⏳ Model overloaded (attempt {attempt + 1}/{max_retries}), retrying in {delay}s...")
                    time.sleep(delay)
                    continue
                else:
                    # For non-retryable errors or the final attempt, re-raise
                    raise retry_error

    def _is_retryable_error(self, error: Exception) -> bool:
        """Check if an error is retryable (overload/503 errors)."""
        error_str = str(error).lower()
        return "overloaded" in error_str or "503" in error_str

    def _process_response(self, response) -> 'LiteLLMModel.ChatMessage':
        """Process LiteLLM response and return a ChatMessage."""
        content = None

        if hasattr(response, 'choices') and len(response.choices) > 0:
            choice = response.choices[0]
            if hasattr(choice, 'message') and hasattr(choice.message, 'content'):
                content = choice.message.content
            elif hasattr(choice, 'text'):
                content = choice.text
            else:
                print(f"Warning: Unexpected choice structure: {choice}")
                content = str(choice)
        elif isinstance(response, str):
            content = response
        else:
            print(f"Warning: Unexpected response format: {type(response)}")
            content = str(response)

        # Create ChatMessage with token usage
        if content:
            chat_msg = self.ChatMessage(content)
            self._extract_token_usage(response, chat_msg)
            return chat_msg
        else:
            return self.ChatMessage("Error: No content in response")

    def _extract_token_usage(self, response, chat_msg: 'LiteLLMModel.ChatMessage') -> None:
        """Extract token usage from the response."""
        if hasattr(response, 'usage'):
            usage = response.usage
            if hasattr(usage, 'prompt_tokens'):
                chat_msg.input_tokens = usage.prompt_tokens
                chat_msg.token_usage['prompt_tokens'] = usage.prompt_tokens
            if hasattr(usage, 'completion_tokens'):
                chat_msg.output_tokens = usage.completion_tokens
                chat_msg.token_usage['completion_tokens'] = usage.completion_tokens
            if hasattr(usage, 'total_tokens'):
                chat_msg.token_usage['total_tokens'] = usage.total_tokens

    def generate(self, prompt: str, **kwargs):
        """Generate response for a single prompt"""
        messages = [{"role": "user", "content": prompt}]
        result = self(messages, **kwargs)

        # Ensure we always return a ChatMessage object
        if not isinstance(result, self.ChatMessage):
            return self.ChatMessage(str(result))
        return result


class GeminiProvider:
    """Specialized provider for Gemini models."""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.model_name = "gemini/gemini-2.0-flash"

    def create_model(self) -> LiteLLMModel:
        """Create Gemini model instance."""
        return LiteLLMModel(self.model_name, self.api_key)


class KlusterProvider:
    """Specialized provider for Kluster.ai models."""

    MODELS = {
        "gemma3-27b": "openai/google/gemma-3-27b-it",
        "qwen3-235b": "openai/Qwen/Qwen3-235B-A22B-FP8",
        "qwen2.5-72b": "openai/Qwen/Qwen2.5-72B-Instruct",
        "llama3.1-405b": "openai/meta-llama/Meta-Llama-3.1-405B-Instruct"
    }

    def __init__(self, api_key: str, model_key: str = "qwen3-235b"):
        self.api_key = api_key
        self.model_key = model_key
        self.api_base = "https://api.kluster.ai/v1"

        if model_key not in self.MODELS:
            raise ValueError(f"Model '{model_key}' not found. Available: {list(self.MODELS.keys())}")

        self.model_name = self.MODELS[model_key]

    def create_model(self) -> LiteLLMModel:
        """Create Kluster.ai model instance."""
        return LiteLLMModel(self.model_name, self.api_key, self.api_base)
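

# Minimal usage sketch, not part of the original module API. It assumes this file is
# run as a module from the package root (so the relative import resolves) and that
# GEMINI_API_KEY / KLUSTER_API_KEY hold valid keys; those environment variable names
# are illustrative choices, not something this module requires.
if __name__ == "__main__":
    gemini_key = os.getenv("GEMINI_API_KEY", "")
    if gemini_key:
        # GeminiProvider pins the model name; create_model() also runs the auth probe.
        model = GeminiProvider(gemini_key).create_model()
        reply = model.generate("In one word: what is the capital of France?")
        print(reply.content, reply.token_usage)

    kluster_key = os.getenv("KLUSTER_API_KEY", "")
    if kluster_key:
        # KlusterProvider maps a short model key to the full LiteLLM model string.
        model = KlusterProvider(kluster_key, model_key="qwen2.5-72b").create_model()
        reply = model.generate("Say hello in one word.")
        print(reply.content, reply.token_usage)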