import asyncio
import contextvars
from functools import partial
from typing import Any, Coroutine, Dict, List, Literal, Optional, Union

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
from litellm.rerank_api.rerank_utils import get_optional_rerank_params
from litellm.secret_managers.main import get_secret, get_secret_str
from litellm.types.rerank import OptionalRerankParams, RerankResponse
from litellm.types.router import *
from litellm.utils import ProviderConfigManager, client, exception_type

####### ENVIRONMENT VARIABLES ###################
# Initialize any necessary instances or variables here
together_rerank = TogetherAIRerank()
bedrock_rerank = BedrockRerankHandler()
base_llm_http_handler = BaseLLMHTTPHandler()
#################################################
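
# Typical usage (illustrative sketch -- the model name and API key source are
# assumptions, not defaults of this module):
#
#   import litellm
#
#   response = litellm.rerank(
#       model="cohere/rerank-english-v3.0",   # any supported "provider/model"
#       query="What is the capital of France?",
#       documents=["Paris is the capital of France.", "Berlin is in Germany."],
#       top_n=1,
#   )
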
@client
async def arerank(
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[
        Literal[
            "cohere",
            "together_ai",
            "azure_ai",
            "infinity",
            "litellm_proxy",
            "jina_ai",
            "bedrock",
        ]
    ] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = None,
    max_chunks_per_doc: Optional[int] = None,
    max_tokens_per_doc: Optional[int] = None,
    **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
| """ | |
| Async: Reranks a list of documents based on their relevance to the query | |
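
    Example (illustrative; assumes COHERE_API_KEY is set in the environment):

        response = await litellm.arerank(
            model="cohere/rerank-english-v3.0",
            query="hello",
            documents=["hello", "world"],
            top_n=2,
        )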
| """ | |
    loop = asyncio.get_event_loop()
    kwargs["arerank"] = True  # tell rerank() to return the handler's coroutine

    func = partial(
        rerank,
        model,
        query,
        documents,
        custom_llm_provider,
        top_n,
        rank_fields,
        return_documents,
        max_chunks_per_doc,
        max_tokens_per_doc,
        **kwargs,
    )

    # Copy the current context so contextvars survive the thread-pool hop
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)

    init_response = await loop.run_in_executor(None, func_with_context)
    if asyncio.iscoroutine(init_response):
        # with arerank=True the provider handler returns a coroutine; await it
        response = await init_response
    else:
        response = init_response
    return response
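
# Note: arerank() runs the synchronous rerank() in a thread-pool executor; the
# "arerank" flag makes the provider handler hand back an awaitable instead of
# blocking, which arerank() then awaits.
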
@client
def rerank(  # noqa: PLR0915
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[
        Literal[
            "cohere",
            "together_ai",
            "azure_ai",
            "infinity",
            "litellm_proxy",
            "jina_ai",
            "bedrock",
        ]
    ] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = True,
    max_chunks_per_doc: Optional[int] = None,
    max_tokens_per_doc: Optional[int] = None,
    **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
| """ | |
| Reranks a list of documents based on their relevance to the query | |
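
    Returns a RerankResponse, or a coroutine resolving to one when called via
    ``arerank``.

    Example (illustrative sketch; the model name and env var are assumptions):

        # export TOGETHERAI_API_KEY=...
        response = litellm.rerank(
            model="together_ai/Salesforce/Llama-Rank-V1",
            query="best pizza in town",
            documents=["a review of local pizzerias", "a car repair manual"],
            top_n=1,
        )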
| """ | |
    headers: Optional[dict] = kwargs.get("headers")  # type: ignore
    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
    litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
    proxy_server_request = kwargs.get("proxy_server_request", None)
    model_info = kwargs.get("model_info", None)
    metadata = kwargs.get("metadata", {})
    user = kwargs.get("user", None)
    client = kwargs.get("client", None)
    try:
        _is_async = kwargs.pop("arerank", False) is True

        optional_params = GenericLiteLLMParams(**kwargs)

        # Params that are unique to specific versions of the client for the rerank call
        unique_version_params = {
            "max_chunks_per_doc": max_chunks_per_doc,
            "max_tokens_per_doc": max_tokens_per_doc,
        }
        present_version_params = [
            k for k, v in unique_version_params.items() if v is not None
        ]
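
        # Which version-specific params are present drives provider-config
        # selection below (e.g. Cohere's v1 rerank API accepts
        # max_chunks_per_doc, while v2 accepts max_tokens_per_doc).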
        (
            model,
            _custom_llm_provider,
            dynamic_api_key,
            dynamic_api_base,
        ) = litellm.get_llm_provider(
            model=model,
            custom_llm_provider=custom_llm_provider,
            api_base=optional_params.api_base,
            api_key=optional_params.api_key,
        )
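
        # model now has any "provider/" prefix stripped; _custom_llm_provider,
        # dynamic_api_key and dynamic_api_base come from the explicit argument,
        # the model-string prefix, or provider-specific env vars, when available.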
        rerank_provider_config: BaseRerankConfig = (
            ProviderConfigManager.get_provider_rerank_config(
                model=model,
                provider=litellm.LlmProviders(_custom_llm_provider),
                api_base=optional_params.api_base,
                present_version_params=present_version_params,
            )
        )
        optional_rerank_params: OptionalRerankParams = get_optional_rerank_params(
            rerank_provider_config=rerank_provider_config,
            model=model,
            drop_params=kwargs.get("drop_params") or litellm.drop_params or False,
            query=query,
            documents=documents,
            custom_llm_provider=_custom_llm_provider,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
            max_chunks_per_doc=max_chunks_per_doc,
            max_tokens_per_doc=max_tokens_per_doc,
            non_default_params=kwargs,
        )
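
        # With drop_params=True, params the selected provider does not support
        # are silently dropped instead of raising an error.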
        if isinstance(optional_params.timeout, str):
            optional_params.timeout = float(optional_params.timeout)

        model_response = RerankResponse()

        litellm_logging_obj.update_environment_variables(
            model=model,
            user=user,
            optional_params=dict(optional_rerank_params),
            litellm_params={
                "litellm_call_id": litellm_call_id,
                "proxy_server_request": proxy_server_request,
                "model_info": model_info,
                "metadata": metadata,
                "preset_cache_key": None,
                "stream_response": {},
                **optional_params.model_dump(exclude_unset=True),
            },
            custom_llm_provider=_custom_llm_provider,
        )
        # Route the call based on the resolved provider. Each branch resolves
        # credentials in the same order: per-call args, module-level litellm
        # settings, then environment variables.
        if _custom_llm_provider == "cohere" or _custom_llm_provider == "litellm_proxy":
            # Cohere (and the litellm proxy) go through the shared HTTP handler
            api_key: Optional[str] = (
                dynamic_api_key or optional_params.api_key or litellm.api_key
            )

            api_base: Optional[str] = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("COHERE_API_BASE")  # type: ignore
                or "https://api.cohere.com"
            )

            if api_base is None:
                raise Exception(
                    "Invalid api base. api_base=None. Set in call or via `COHERE_API_BASE` env var."
                )

            response = base_llm_http_handler.rerank(
                model=model,
                custom_llm_provider=_custom_llm_provider,
                provider_config=rerank_provider_config,
                optional_rerank_params=optional_rerank_params,
                logging_obj=litellm_logging_obj,
                timeout=optional_params.timeout,
                api_key=api_key,
                api_base=api_base,
                _is_async=_is_async,
                headers=headers or litellm.headers or {},
                client=client,
                model_response=model_response,
            )
| elif _custom_llm_provider == "azure_ai": | |
| api_base = ( | |
| dynamic_api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there | |
| or optional_params.api_base | |
| or litellm.api_base | |
| or get_secret("AZURE_AI_API_BASE") # type: ignore | |
| ) | |
| response = base_llm_http_handler.rerank( | |
| model=model, | |
| custom_llm_provider=_custom_llm_provider, | |
| optional_rerank_params=optional_rerank_params, | |
| provider_config=rerank_provider_config, | |
| logging_obj=litellm_logging_obj, | |
| timeout=optional_params.timeout, | |
| api_key=dynamic_api_key or optional_params.api_key, | |
| api_base=api_base, | |
| _is_async=_is_async, | |
| headers=headers or litellm.headers or {}, | |
| client=client, | |
| model_response=model_response, | |
| ) | |
| elif _custom_llm_provider == "infinity": | |
| # Implement Infinity rerank logic | |
| api_key = dynamic_api_key or optional_params.api_key or litellm.api_key | |
| api_base = ( | |
| dynamic_api_base | |
| or optional_params.api_base | |
| or litellm.api_base | |
| or get_secret_str("INFINITY_API_BASE") | |
| ) | |
| if api_base is None: | |
| raise Exception( | |
| "Invalid api base. api_base=None. Set in call or via `INFINITY_API_BASE` env var." | |
| ) | |
| response = base_llm_http_handler.rerank( | |
| model=model, | |
| custom_llm_provider=_custom_llm_provider, | |
| provider_config=rerank_provider_config, | |
| optional_rerank_params=optional_rerank_params, | |
| logging_obj=litellm_logging_obj, | |
| timeout=optional_params.timeout, | |
| api_key=dynamic_api_key or optional_params.api_key, | |
| api_base=api_base, | |
| _is_async=_is_async, | |
| headers=headers or litellm.headers or {}, | |
| client=client, | |
| model_response=model_response, | |
| ) | |
| elif _custom_llm_provider == "together_ai": | |
| # Implement Together AI rerank logic | |
| api_key = ( | |
| dynamic_api_key | |
| or optional_params.api_key | |
| or litellm.togetherai_api_key | |
| or get_secret("TOGETHERAI_API_KEY") # type: ignore | |
| or litellm.api_key | |
| ) | |
| if api_key is None: | |
| raise ValueError( | |
| "TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment" | |
| ) | |
| response = together_rerank.rerank( | |
| model=model, | |
| query=query, | |
| documents=documents, | |
| top_n=top_n, | |
| rank_fields=rank_fields, | |
| return_documents=return_documents, | |
| max_chunks_per_doc=max_chunks_per_doc, | |
| api_key=api_key, | |
| _is_async=_is_async, | |
| ) | |
| elif _custom_llm_provider == "jina_ai": | |
| if dynamic_api_key is None: | |
| raise ValueError( | |
| "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment" | |
| ) | |
| api_base = ( | |
| dynamic_api_base | |
| or optional_params.api_base | |
| or litellm.api_base | |
| or get_secret("BEDROCK_API_BASE") # type: ignore | |
| ) | |
| response = base_llm_http_handler.rerank( | |
| model=model, | |
| custom_llm_provider=_custom_llm_provider, | |
| optional_rerank_params=optional_rerank_params, | |
| logging_obj=litellm_logging_obj, | |
| provider_config=rerank_provider_config, | |
| timeout=optional_params.timeout, | |
| api_key=dynamic_api_key or optional_params.api_key, | |
| api_base=api_base, | |
| _is_async=_is_async, | |
| headers=headers or litellm.headers or {}, | |
| client=client, | |
| model_response=model_response, | |
| ) | |
| elif _custom_llm_provider == "bedrock": | |
| api_base = ( | |
| dynamic_api_base | |
| or optional_params.api_base | |
| or litellm.api_base | |
| or get_secret("BEDROCK_API_BASE") # type: ignore | |
| ) | |
| response = bedrock_rerank.rerank( | |
| model=model, | |
| query=query, | |
| documents=documents, | |
| top_n=top_n, | |
| rank_fields=rank_fields, | |
| return_documents=return_documents, | |
| max_chunks_per_doc=max_chunks_per_doc, | |
| _is_async=_is_async, | |
| optional_params=optional_params.model_dump(exclude_unset=True), | |
| api_base=api_base, | |
| logging_obj=litellm_logging_obj, | |
| client=client, | |
| ) | |
        else:
            raise ValueError(f"Unsupported provider: {_custom_llm_provider}")

        return response
    except Exception as e:
        verbose_logger.error(f"Error in rerank: {str(e)}")
        raise exception_type(
            model=model, custom_llm_provider=custom_llm_provider, original_exception=e
        )