Spaces:
Running
Running
| """ | |
| LLM Client Abstraction Layer | |
| Supports multiple LLM providers without hardcoding | |
| """ | |
| from abc import ABC, abstractmethod | |
| from typing import Dict, Any, Optional | |
| import os | |
| # Import configuration utility | |
| from src.utils.config import get_gemini_api_key, get_openai_api_key | |
class BaseLLMClient(ABC):
    """Abstract base class for LLM clients.

    Concrete providers must implement :meth:`generate` and
    :meth:`is_available`. Because both are declared with ``@abstractmethod``,
    instantiating this class (or a subclass that omits either method)
    raises ``TypeError`` instead of silently returning ``None`` at call time.
    """

    def __init__(self, **kwargs):
        """Accept and ignore keyword arguments; the base class holds no state."""

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """Generate a response from ``prompt``.

        Args:
            prompt: The user prompt text.
            **kwargs: Provider-specific generation options.

        Returns:
            The generated text.
        """

    @abstractmethod
    def is_available(self) -> bool:
        """Return True if the LLM service can currently be used."""
class GeminiClient(BaseLLMClient):
    """Google Gemini client implementation (``google.generativeai`` SDK)."""

    def __init__(self, api_key: Optional[str] = None, model: str = "gemini-2.0-flash"):
        """Configure the Gemini SDK and build a ``GenerativeModel``.

        Args:
            api_key: Gemini API key. Falls back to ``get_gemini_api_key()``
                when omitted or empty.
            model: Gemini model name.

        Raises:
            ValueError: If no API key is available.
            ImportError: If ``google-generativeai`` is not installed.
        """
        self.api_key = api_key or get_gemini_api_key()
        self.model = model
        if not self.api_key:
            raise ValueError("Gemini API key not provided")
        # Import lazily so the package is only required when Gemini is used.
        try:
            import google.generativeai as genai
        except ImportError as err:
            raise ImportError("google-generativeai package not installed") from err
        genai.configure(api_key=self.api_key)
        self.client = genai.GenerativeModel(self.model)
        print(f"✅ Gemini client initialized with model: {self.model}")

    def generate(self, prompt: str, **kwargs) -> str:
        """Generate a response using Gemini.

        Args:
            prompt: The user prompt text.
            **kwargs: Optional ``temperature`` (default 0.1), ``top_p``
                (default 0.8), ``top_k`` (default 40) and
                ``max_output_tokens`` (default 2048). The OpenAI-style
                ``max_tokens`` is accepted as an alias for
                ``max_output_tokens``; if both are given,
                ``max_output_tokens`` wins.

        Returns:
            ``response.text`` from the Gemini SDK.

        Raises:
            Exception: Any SDK error is printed and re-raised unchanged.
        """
        # Low default temperature (0.1) keeps answers consistent across calls.
        generation_config = {
            "temperature": kwargs.get("temperature", 0.1),
            "top_p": kwargs.get("top_p", 0.8),
            "top_k": kwargs.get("top_k", 40),
            "max_output_tokens": kwargs.get(
                "max_output_tokens", kwargs.get("max_tokens", 2048)
            ),
        }
        try:
            response = self.client.generate_content(
                prompt, generation_config=generation_config
            )
            return response.text
        except Exception as e:
            print(f"❌ Gemini generation error: {e}")
            raise

    def is_available(self) -> bool:
        """Probe Gemini with a tiny prompt; return True if it yields text.

        This makes a real (billable) API request.
        """
        try:
            test_response = self.client.generate_content("Hello")
            return bool(test_response.text)
        except Exception:
            # Any failure while probing means "not available". Unlike a bare
            # ``except:``, KeyboardInterrupt/SystemExit still propagate.
            return False
class OpenAIClient(BaseLLMClient):
    """OpenAI chat-completions client implementation."""

    # Caller options forwarded verbatim to ``chat.completions.create``;
    # any other keyword (e.g. Gemini's ``top_k``) is dropped.
    _FORWARDED_OPTIONS = (
        "temperature",
        "top_p",
        "max_tokens",
        "frequency_penalty",
        "presence_penalty",
    )

    def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4"):
        """Create an ``openai.OpenAI`` client.

        Args:
            api_key: OpenAI API key. Falls back to ``get_openai_api_key()``
                when omitted or empty.
            model: Chat model name.

        Raises:
            ValueError: If no API key is available.
            ImportError: If the ``openai`` package is not installed.
        """
        self.api_key = api_key or get_openai_api_key()
        self.model = model
        if not self.api_key:
            raise ValueError("OpenAI API key not provided")
        # Import lazily so the package is only required when OpenAI is used.
        try:
            import openai
        except ImportError as err:
            raise ImportError("openai package not installed") from err
        self.client = openai.OpenAI(api_key=self.api_key)
        print(f"✅ OpenAI client initialized with model: {self.model}")

    def generate(self, prompt: str, **kwargs) -> str:
        """Generate a response using OpenAI chat completions.

        Args:
            prompt: Sent as a single ``user`` message.
            **kwargs: Only the names in ``_FORWARDED_OPTIONS`` are passed on.
                Defaults are ``temperature=0.1``, ``top_p=1.0`` and
                ``max_tokens=2048``. A Gemini-style ``max_output_tokens`` is
                accepted as an alias for ``max_tokens``; an explicit
                ``max_tokens`` wins.

        Returns:
            ``choices[0].message.content`` of the completion.

        Raises:
            Exception: Any SDK error is printed and re-raised unchanged.
        """
        # Defaults first; a low temperature (0.1) keeps answers consistent.
        openai_kwargs = {
            "temperature": 0.1,
            "top_p": 1.0,
            "max_tokens": kwargs.get("max_output_tokens", 2048),
        }
        # Whitelist caller options so provider-specific extras are not sent.
        openai_kwargs.update(
            {k: v for k, v in kwargs.items() if k in self._FORWARDED_OPTIONS}
        )
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                **openai_kwargs,
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"❌ OpenAI generation error: {e}")
            raise

    def is_available(self) -> bool:
        """Probe OpenAI with a 5-token request; return True if it yields text.

        This makes a real (billable) API request.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=5,
            )
            return bool(response.choices[0].message.content)
        except Exception:
            # Any failure while probing means "not available". Unlike a bare
            # ``except:``, KeyboardInterrupt/SystemExit still propagate.
            return False
class LLMClientFactory:
    """Factory for creating LLM clients by provider name.

    All methods are classmethods, so the factory can be used without an
    instance (``LLMClientFactory.create_client("gemini")``). Calling them on
    an instance also works.
    """

    # Provider name -> client class; add an entry here to support a provider.
    SUPPORTED_PROVIDERS = {
        "gemini": GeminiClient,
        "openai": OpenAIClient,
    }

    @classmethod
    def create_client(cls, provider: str = "gemini", **kwargs) -> BaseLLMClient:
        """Create an LLM client by provider name.

        Args:
            provider: A key of ``SUPPORTED_PROVIDERS``.
            **kwargs: Forwarded to the client constructor (e.g. ``api_key``,
                ``model``).

        Returns:
            A new client instance for ``provider``.

        Raises:
            ValueError: If ``provider`` is not supported, or if the client
                constructor raises it (e.g. missing API key).
            ImportError: If the provider's SDK package is not installed.
        """
        if provider not in cls.SUPPORTED_PROVIDERS:
            raise ValueError(
                f"Unsupported provider: {provider}. Supported: {list(cls.SUPPORTED_PROVIDERS.keys())}"
            )
        client_class = cls.SUPPORTED_PROVIDERS[provider]
        return client_class(**kwargs)

    @classmethod
    def get_available_providers(cls) -> list:
        """Return the names of all supported providers."""
        return list(cls.SUPPORTED_PROVIDERS.keys())
# Usage example: when run as a script, smoke-test the default Gemini provider.
if __name__ == "__main__":
    try:
        gemini = LLMClientFactory.create_client("gemini")
        if not gemini.is_available():
            print("Gemini not available")
        else:
            # Vietnamese prompt: "Hello, how are you?"
            reply = gemini.generate("Xin chào, bạn có khỏe không?")
            print(f"Response: {reply}")
    except Exception as exc:
        # Report any setup or API failure instead of a traceback.
        print(f"Error: {exc}")