Spaces:
Sleeping
Sleeping
| from lpm_kernel.api.services.user_llm_config_service import UserLLMConfigService | |
| from lpm_kernel.configs.config import Config | |
| from typing import List, Union | |
| import requests | |
| import numpy as np | |
| from lpm_kernel.configs.logging import get_train_process_logger | |
| logger = get_train_process_logger() | |
| import lpm_kernel.common.strategy.classification as classification | |
| from sentence_transformers import SentenceTransformer | |
| import json | |
class EmbeddingError(Exception):
    """Raised when computing an embedding fails.

    The exception that triggered the failure, if there was one, is kept on
    ``original_error`` so callers can inspect the underlying cause.
    """

    def __init__(self, message, original_error=None):
        # Remember the underlying cause (None when the error originates here).
        self.original_error = original_error
        super(EmbeddingError, self).__init__(message)
class LLMClient:
    """LLM client utility class.

    Provides text embeddings (with automatic splitting of over-long inputs)
    and the credentials needed to talk to the chat model. Provider settings
    are not cached on the instance: they are fetched through
    ``UserLLMConfigService.get_available_llm()`` each time they are needed.
    """

    def __init__(self):
        """Load the environment configuration and the LLM config service."""
        self.config = Config.from_env()
        self.user_llm_config_service = UserLLMConfigService()
        # Maximum number of characters per embedding input; longer texts are
        # split into chunks of this size in get_embedding().
        self.embedding_max_text_length = int(
            self.config.get('EMBEDDING_MAX_TEXT_LENGTH', 8000)
        )
        # Chat/embedding endpoints, keys and model names are intentionally not
        # stored here; they are read from the user LLM configuration on demand.

    def get_embedding(self, texts: Union[str, List[str]]) -> np.ndarray:
        """Calculate text embedding.

        Texts longer than ``embedding_max_text_length`` characters are split
        into consecutive chunks; the chunk embeddings of one text are averaged
        so that the result has one entry per input text.

        Args:
            texts (str or list): Input text or list of texts.

        Returns:
            numpy.ndarray: One embedding per input text. When no text had to
            be split, the value returned by
            ``classification.strategy_classification`` is passed through
            unchanged (presumably already array-like; verify against that
            helper).

        Raises:
            EmbeddingError: If the chunk size is not positive, if no LLM
                configuration is available, or if the embedding backend
                fails or returns an unexpected structure.
        """
        # Ensure texts is in list format
        if isinstance(texts, str):
            texts = [texts]

        chunked_texts, text_chunk_counts = self._split_into_chunks(texts)

        user_llm_config = self.user_llm_config_service.get_available_llm()
        if not user_llm_config:
            raise EmbeddingError("No LLM configuration found")

        try:
            # Send request to embedding endpoint
            embeddings_array = classification.strategy_classification(
                user_llm_config, chunked_texts
            )
            # If we split any texts, we need to merge their embeddings back
            if len(chunked_texts) > len(texts):
                return self._merge_chunk_embeddings(embeddings_array, text_chunk_counts)
            return embeddings_array
        except requests.exceptions.RequestException as e:
            # Handle request errors
            error_msg = f"Request error getting embeddings: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingError(error_msg, e) from e
        except json.JSONDecodeError as e:
            # Handle JSON parsing errors
            error_msg = f"Invalid JSON response from embedding API: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingError(error_msg, e) from e
        except (KeyError, IndexError, ValueError) as e:
            # Handle response structure errors
            error_msg = f"Invalid response structure from embedding API: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingError(error_msg, e) from e
        except Exception as e:
            # Fallback for any other errors
            error_msg = f"Unexpected error getting embeddings: {str(e)}"
            logger.error(error_msg, exc_info=True)
            raise EmbeddingError(error_msg, e) from e

    def _split_into_chunks(self, texts):
        """Split over-long texts; return (flat chunk list, chunks-per-text counts)."""
        max_len = self.embedding_max_text_length
        if max_len <= 0:
            # A zero step makes range() raise and a negative step yields no
            # chunks at all, which would silently drop the text.
            raise EmbeddingError(
                f"EMBEDDING_MAX_TEXT_LENGTH must be positive, got {max_len}"
            )

        chunked_texts = []
        text_chunk_counts = []  # how many chunks each input text produced
        for text in texts:
            if len(text) > max_len:
                chunks = [text[i:i + max_len] for i in range(0, len(text), max_len)]
            else:
                chunks = [text]
            chunked_texts.extend(chunks)
            text_chunk_counts.append(len(chunks))
        return chunked_texts, text_chunk_counts

    @staticmethod
    def _merge_chunk_embeddings(embeddings_array, text_chunk_counts) -> np.ndarray:
        """Average chunk embeddings so there is one vector per original text."""
        expected = sum(text_chunk_counts)
        if len(embeddings_array) != expected:
            # Without this check a short response would be averaged over
            # truncated slices and return plausible-looking but wrong vectors.
            raise ValueError(
                f"expected {expected} embeddings, got {len(embeddings_array)}"
            )

        final_embeddings = []
        start_idx = 0
        for chunk_count in text_chunk_counts:
            if chunk_count > 1:
                # Average embeddings for split text
                chunk_embeddings = embeddings_array[start_idx:start_idx + chunk_count]
                final_embeddings.append(np.mean(chunk_embeddings, axis=0))
            else:
                final_embeddings.append(embeddings_array[start_idx])
            start_idx += chunk_count
        return np.array(final_embeddings)

    def chat_credentials(self):
        """Get LLM authentication information.

        Returns:
            dict: ``{"api_key": ..., "base_url": ...}`` taken from the
            ``chat_api_key`` and ``chat_endpoint`` attributes of the currently
            available user LLM configuration.

        Raises:
            RuntimeError: If no LLM configuration is available.
        """
        # The instance never stores chat_api_key / chat_base_url, so resolve
        # them from the configuration at call time, like get_embedding() does.
        user_llm_config = self.user_llm_config_service.get_available_llm()
        if not user_llm_config:
            raise RuntimeError("No LLM configuration found")
        return {
            "api_key": user_llm_config.chat_api_key,
            "base_url": user_llm_config.chat_endpoint,
        }