Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Model management system for GAIA agent. | |
| Handles model initialization, fallback chains, and lifecycle management. | |
| """ | |
| import os | |
| import time | |
| import random | |
| from typing import Optional, List, Dict, Any, Union | |
| from abc import ABC, abstractmethod | |
| from enum import Enum | |
| from ..config.settings import Config, ModelType, config | |
| from ..utils.exceptions import ( | |
| ModelError, ModelNotAvailableError, ModelAuthenticationError, | |
| ModelOverloadedError, create_error | |
| ) | |
class ModelStatus(Enum):
    """Lifecycle states a model provider can be in."""

    AVAILABLE = "available"            # initialized and ready to serve
    UNAVAILABLE = "unavailable"        # not yet initialized, or reset
    OVERLOADED = "overloaded"          # upstream reported overload / 503
    AUTHENTICATING = "authenticating"  # initialization in progress
    ERROR = "error"                    # initialization or runtime failure
class ModelProvider(ABC):
    """Abstract base class for model providers.

    Tracks provider lifecycle state (status, last error, retry count,
    last-used timestamp). Concrete subclasses must implement
    ``initialize()``, ``is_available()`` and ``create_model()``.

    Fix: the three core methods were plain ``pass`` stubs even though
    ``abstractmethod`` is imported at the top of the file; they are now
    decorated with ``@abstractmethod`` so an incomplete subclass fails
    at instantiation instead of silently returning ``None``.
    """

    def __init__(self, name: str, model_type: ModelType):
        self.name = name
        self.model_type = model_type
        self.status = ModelStatus.UNAVAILABLE
        self.last_error: Optional[str] = None
        self.retry_count = 0
        # time.time() of the most recent create_model() call, or None.
        self.last_used: Optional[float] = None

    @abstractmethod
    def initialize(self) -> bool:
        """Initialize the model provider. Returns True if successful."""

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the model is available for use."""

    @abstractmethod
    def create_model(self, **kwargs):
        """Create (or return) a usable model instance."""

    def reset_error_state(self) -> None:
        """Reset error state so the provider can be retried."""
        self.retry_count = 0
        self.last_error = None
        self.status = ModelStatus.UNAVAILABLE

    def record_usage(self) -> None:
        """Record model usage timestamp."""
        self.last_used = time.time()

    def handle_error(self, error: Exception) -> None:
        """Categorize *error*, update status/last_error, bump retry_count.

        Classification is a substring match on the lowercased error text:
        overload/503 -> OVERLOADED, auth/401/403 -> ERROR with a fixed
        message, anything else -> ERROR with the raw message.
        """
        error_str = str(error).lower()
        if "overloaded" in error_str or "503" in error_str:
            self.status = ModelStatus.OVERLOADED
            self.last_error = "Model overloaded"
        elif "authentication" in error_str or "401" in error_str or "403" in error_str:
            self.status = ModelStatus.ERROR
            self.last_error = "Authentication failed"
        else:
            self.status = ModelStatus.ERROR
            self.last_error = str(error)
        self.retry_count += 1
class LiteLLMProvider(ModelProvider):
    """Provider for LiteLLM-based models (Gemini, Kluster.ai)."""

    def __init__(self, model_name: str, api_key: str, api_base: Optional[str] = None):
        # Attributes are assigned before _determine_model_type() runs,
        # since that helper inspects self.api_base.
        self.model_name = model_name
        self.api_key = api_key
        self.api_base = api_base
        self._model_instance = None
        super().__init__(model_name, self._determine_model_type(model_name))

    def _determine_model_type(self, model_name: str) -> ModelType:
        """Classify the model from its name (Gemini) or API base (Kluster)."""
        if "gemini" in model_name.lower():
            return ModelType.GEMINI
        api_base = getattr(self, 'api_base', None)
        if api_base and "kluster" in str(api_base).lower():
            return ModelType.KLUSTER
        return ModelType.QWEN

    def initialize(self) -> bool:
        """Build the LiteLLM model instance; returns True on success."""
        try:
            # Imported lazily from the sibling module to avoid cycles.
            from .providers import LiteLLMModel
            self.status = ModelStatus.AUTHENTICATING
            # Export credentials the way LiteLLM expects them.
            if self.model_type == ModelType.GEMINI:
                os.environ["GEMINI_API_KEY"] = self.api_key
            elif self.api_base:
                os.environ["OPENAI_API_KEY"] = self.api_key
                os.environ["OPENAI_API_BASE"] = self.api_base
            self._model_instance = LiteLLMModel(
                model_name=self.model_name,
                api_key=self.api_key,
                api_base=self.api_base,
            )
            self.status = ModelStatus.AVAILABLE
            return True
        except Exception as exc:
            self.handle_error(exc)
            return False

    def is_available(self) -> bool:
        """True when initialized and a model instance exists."""
        return self._model_instance is not None and self.status == ModelStatus.AVAILABLE

    def create_model(self, **kwargs):
        """Return the cached model instance, recording the usage time."""
        if self.is_available():
            self.record_usage()
            return self._model_instance
        raise ModelNotAvailableError(f"Model {self.name} is not available")
class HuggingFaceProvider(ModelProvider):
    """Provider for HuggingFace models."""

    def __init__(self, model_name: str, api_key: str):
        super().__init__(model_name, ModelType.QWEN)
        self.model_name = model_name
        self.api_key = api_key
        self._model_instance = None

    def initialize(self) -> bool:
        """Build the InferenceClientModel; returns True on success."""
        try:
            # Imported lazily so the module loads even without smolagents.
            from smolagents import InferenceClientModel
            self.status = ModelStatus.AUTHENTICATING
            self._model_instance = InferenceClientModel(
                model_id=self.model_name,
                token=self.api_key,
            )
            self.status = ModelStatus.AVAILABLE
            return True
        except Exception as exc:
            self.handle_error(exc)
            return False

    def is_available(self) -> bool:
        """True when initialized and a model instance exists."""
        return self._model_instance is not None and self.status == ModelStatus.AVAILABLE

    def create_model(self, **kwargs):
        """Return the cached model instance, recording the usage time."""
        if self.is_available():
            self.record_usage()
            return self._model_instance
        raise ModelNotAvailableError(f"Model {self.name} is not available")
class ModelManager:
    """Manages model providers and fallback chains.

    One provider is registered per configured API key; providers are
    ordered into a priority fallback chain and the manager fails over
    to the next provider when the current one errors.
    """

    def __init__(self, config_instance: Optional[Config] = None):
        # Fall back to the module-level config singleton when no
        # explicit Config is supplied.
        self.config = config_instance or config
        self.providers: Dict[str, ModelProvider] = {}
        # Provider names, highest priority first; built in
        # _setup_fallback_chain().
        self.fallback_chain: List[str] = []
        self.current_provider: Optional[str] = None
        self._initialize_providers()

    def _initialize_providers(self) -> None:
        """Initialize all available model providers.

        Registers one LiteLLM provider per Kluster model, one Gemini
        provider, and one HuggingFace (Qwen) provider -- each only when
        its API key is configured -- then builds the fallback chain.
        """
        # Kluster.ai models: one provider per entry in KLUSTER_MODELS.
        if self.config.has_api_key("kluster"):
            kluster_key = self.config.get_api_key("kluster")
            for model_key, model_name in self.config.model.KLUSTER_MODELS.items():
                provider_name = f"kluster_{model_key}"
                provider = LiteLLMProvider(
                    model_name=model_name,
                    api_key=kluster_key,
                    api_base=self.config.model.KLUSTER_API_BASE
                )
                self.providers[provider_name] = provider
        # Gemini models
        if self.config.has_api_key("gemini"):
            gemini_key = self.config.get_api_key("gemini")
            provider = LiteLLMProvider(
                model_name=self.config.model.GEMINI_MODEL,
                api_key=gemini_key
            )
            self.providers["gemini"] = provider
        # HuggingFace models
        if self.config.has_api_key("huggingface"):
            hf_key = self.config.get_api_key("huggingface")
            provider = HuggingFaceProvider(
                model_name=self.config.model.QWEN_MODEL,
                api_key=hf_key
            )
            self.providers["qwen"] = provider
        # Set up fallback chain
        self._setup_fallback_chain()

    def _setup_fallback_chain(self) -> None:
        """Set up model fallback chain based on availability and preference.

        Raises:
            ModelNotAvailableError: if no providers were registered.
        """
        # Priority order: Kluster.ai (highest tier) -> Gemini -> Qwen
        priority_providers = []
        # Add Kluster.ai models (prefer qwen3-235b); at most one Kluster
        # model enters the chain.
        if "kluster_qwen3-235b" in self.providers:
            priority_providers.append("kluster_qwen3-235b")
        elif "kluster_gemma3-27b" in self.providers:
            priority_providers.append("kluster_gemma3-27b")
        # Add other available providers
        if "gemini" in self.providers:
            priority_providers.append("gemini")
        if "qwen" in self.providers:
            priority_providers.append("qwen")
        self.fallback_chain = priority_providers
        if not self.fallback_chain:
            raise ModelNotAvailableError("No model providers available")

    def initialize_all(self) -> Dict[str, bool]:
        """Initialize all model providers.

        Returns:
            Mapping of provider name -> True if it initialized cleanly.
        """
        results = {}
        for name, provider in self.providers.items():
            try:
                success = provider.initialize()
                results[name] = success
                # First provider to initialize (in registration order)
                # becomes the current one.
                if success and self.current_provider is None:
                    self.current_provider = name
            except Exception as e:
                results[name] = False
                provider.handle_error(e)
        return results

    def get_current_model(self, **kwargs):
        """Get current active model.

        On failure the error is recorded on the provider and the manager
        retries with the next provider; the recursion is bounded because
        _switch_to_fallback() only moves forward in the chain.

        Raises:
            ModelNotAvailableError: if no provider can be selected.
            ModelError: if every provider in the chain fails.
        """
        if self.current_provider is None:
            self._select_best_provider()
        if self.current_provider is None:
            raise ModelNotAvailableError("No models available")
        provider = self.providers[self.current_provider]
        try:
            return provider.create_model(**kwargs)
        except Exception as e:
            provider.handle_error(e)
            # Try to switch to fallback
            if self._switch_to_fallback():
                return self.get_current_model(**kwargs)
            else:
                raise ModelError(f"All models failed: {str(e)}")

    def _select_best_provider(self) -> None:
        """Select the best available provider from fallback chain.

        Lazily initializes providers still in the UNAVAILABLE state;
        leaves current_provider as None when nothing works.
        """
        for provider_name in self.fallback_chain:
            provider = self.providers.get(provider_name)
            if provider and provider.is_available():
                self.current_provider = provider_name
                return
            elif provider and provider.status == ModelStatus.UNAVAILABLE:
                # Try to initialize
                if provider.initialize():
                    self.current_provider = provider_name
                    return
        self.current_provider = None

    def _switch_to_fallback(self) -> bool:
        """Switch to next available model in fallback chain.

        Returns:
            True if a later provider was activated; False otherwise
            (current_provider is then cleared).
        """
        if self.current_provider is None:
            return False
        try:
            current_index = self.fallback_chain.index(self.current_provider)
            # Try next providers in chain
            for i in range(current_index + 1, len(self.fallback_chain)):
                provider_name = self.fallback_chain[i]
                provider = self.providers[provider_name]
                if provider.is_available() or provider.initialize():
                    self.current_provider = provider_name
                    return True
        except ValueError:
            # current_provider is not in the chain (e.g. set manually
            # via switch_to_provider); fall through to clear it.
            pass
        # No fallback available
        self.current_provider = None
        return False

    def retry_current_model(self, max_retries: int = 3) -> bool:
        """Retry current model with exponential backoff.

        Sleeps only when the provider reported overload; every attempt
        resets the error state and re-initializes.
        """
        if self.current_provider is None:
            return False
        provider = self.providers[self.current_provider]
        for attempt in range(max_retries):
            if provider.status == ModelStatus.OVERLOADED:
                # Exponential backoff with jitter: ~1s, ~2s, ~4s, ...
                wait_time = (2 ** attempt) + random.random()
                time.sleep(wait_time)
            # Reset error state and try to reinitialize
            provider.reset_error_state()
            if provider.initialize():
                return True
        return False

    def get_model_status(self) -> Dict[str, Dict[str, Any]]:
        """Get status of all model providers (for diagnostics/UI)."""
        status = {}
        for name, provider in self.providers.items():
            status[name] = {
                "status": provider.status.value,
                "model_type": provider.model_type.value,
                "last_error": provider.last_error,
                "retry_count": provider.retry_count,
                "last_used": provider.last_used,
                "is_current": name == self.current_provider
            }
        return status

    def switch_to_provider(self, provider_name: str) -> bool:
        """Manually switch to specific provider.

        Returns True if the provider is (or became) available.

        Raises:
            ModelNotAvailableError: if the provider name is unknown.
        """
        if provider_name not in self.providers:
            raise ModelNotAvailableError(f"Provider {provider_name} not found")
        provider = self.providers[provider_name]
        if provider.is_available() or provider.initialize():
            self.current_provider = provider_name
            return True
        return False

    def get_available_providers(self) -> List[str]:
        """Get list of available providers."""
        available = []
        for name, provider in self.providers.items():
            if provider.is_available():
                available.append(name)
        return available

    def reset_all_providers(self) -> None:
        """Reset all providers to allow retry, then re-select the best one."""
        for provider in self.providers.values():
            provider.reset_error_state()
        self.current_provider = None
        self._select_best_provider()
| # Monkey patch for smolagents compatibility | |
def monkey_patch_smolagents():
    """Apply compatibility patches for smolagents.

    Wraps Monitor.update_metrics so a dict-valued token_usage on a step
    log is converted to a TokenUsage object before the original metrics
    code runs. Safe to call when smolagents is not installed.
    """
    try:
        import smolagents.monitoring
        from smolagents.monitoring import TokenUsage

        # Keep a reference to the unpatched implementation.
        original_update_metrics = smolagents.monitoring.Monitor.update_metrics

        def patched_update_metrics(self, step_log):
            """Patched version that handles dict token_usage"""
            try:
                # Normalize a dict token_usage into a TokenUsage object.
                usage = getattr(step_log, 'token_usage', None)
                if isinstance(usage, dict):
                    step_log.token_usage = TokenUsage(
                        input_tokens=usage.get('prompt_tokens', 0),
                        output_tokens=usage.get('completion_tokens', 0),
                    )
                return original_update_metrics(self, step_log)
            except Exception as e:
                # Best effort: fall back to the original behavior.
                print(f"Token usage patch warning: {e}")
                return original_update_metrics(self, step_log)

        smolagents.monitoring.Monitor.update_metrics = patched_update_metrics
        print("✅ Applied smolagents token usage compatibility patch")
    except ImportError:
        print("⚠️ smolagents not available, skipping compatibility patch")
    except Exception as e:
        print(f"⚠️ Failed to apply smolagents patch: {e}")
# Apply the monkey patch at import time so any later smolagents usage
# in this process already sees the token-usage fix.
monkey_patch_smolagents()