Spaces:
Paused
Paused
| import json | |
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast | |
| import httpx | |
| import litellm | |
| from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging | |
| from litellm.llms.custom_httpx.http_handler import ( | |
| AsyncHTTPHandler, | |
| HTTPHandler, | |
| _get_httpx_client, | |
| get_async_httpx_client, | |
| ) | |
| from litellm.types.llms.bedrock import BedrockPreparedRequest | |
| from litellm.types.rerank import RerankRequest | |
| from litellm.types.utils import RerankResponse | |
| from ..base_aws_llm import BaseAWSLLM | |
| from ..common_utils import BedrockError | |
| from .transformation import BedrockRerankConfig | |
| if TYPE_CHECKING: | |
| from botocore.awsrequest import AWSPreparedRequest | |
| else: | |
| AWSPreparedRequest = Any | |
| class BedrockRerankHandler(BaseAWSLLM): | |
| async def arerank( | |
| self, | |
| prepared_request: BedrockPreparedRequest, | |
| client: Optional[AsyncHTTPHandler] = None, | |
| ): | |
| if client is None: | |
| client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK) | |
| try: | |
| response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore | |
| response.raise_for_status() | |
| except httpx.HTTPStatusError as err: | |
| error_code = err.response.status_code | |
| raise BedrockError(status_code=error_code, message=err.response.text) | |
| except httpx.TimeoutException: | |
| raise BedrockError(status_code=408, message="Timeout error occurred.") | |
| return BedrockRerankConfig()._transform_response(response.json()) | |
| def rerank( | |
| self, | |
| model: str, | |
| query: str, | |
| documents: List[Union[str, Dict[str, Any]]], | |
| optional_params: dict, | |
| logging_obj: LitellmLogging, | |
| top_n: Optional[int] = None, | |
| rank_fields: Optional[List[str]] = None, | |
| return_documents: Optional[bool] = True, | |
| max_chunks_per_doc: Optional[int] = None, | |
| _is_async: Optional[bool] = False, | |
| api_base: Optional[str] = None, | |
| extra_headers: Optional[dict] = None, | |
| client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, | |
| ) -> RerankResponse: | |
| request_data = RerankRequest( | |
| model=model, | |
| query=query, | |
| documents=documents, | |
| top_n=top_n, | |
| rank_fields=rank_fields, | |
| return_documents=return_documents, | |
| ) | |
| data = BedrockRerankConfig()._transform_request(request_data) | |
| prepared_request = self._prepare_request( | |
| model=model, | |
| optional_params=optional_params, | |
| api_base=api_base, | |
| extra_headers=extra_headers, | |
| data=cast(dict, data), | |
| ) | |
| logging_obj.pre_call( | |
| input=data, | |
| api_key="", | |
| additional_args={ | |
| "complete_input_dict": data, | |
| "api_base": prepared_request["endpoint_url"], | |
| "headers": prepared_request["prepped"].headers, | |
| }, | |
| ) | |
| if _is_async: | |
| return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None) # type: ignore | |
| if client is None or not isinstance(client, HTTPHandler): | |
| client = _get_httpx_client() | |
| try: | |
| response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore | |
| response.raise_for_status() | |
| except httpx.HTTPStatusError as err: | |
| error_code = err.response.status_code | |
| raise BedrockError(status_code=error_code, message=err.response.text) | |
| except httpx.TimeoutException: | |
| raise BedrockError(status_code=408, message="Timeout error occurred.") | |
| logging_obj.post_call( | |
| original_response=response.text, | |
| api_key="", | |
| ) | |
| response_json = response.json() | |
| return BedrockRerankConfig()._transform_response(response_json) | |
| def _prepare_request( | |
| self, | |
| model: str, | |
| api_base: Optional[str], | |
| extra_headers: Optional[dict], | |
| data: dict, | |
| optional_params: dict, | |
| ) -> BedrockPreparedRequest: | |
| try: | |
| from botocore.auth import SigV4Auth | |
| from botocore.awsrequest import AWSRequest | |
| except ImportError: | |
| raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") | |
| boto3_credentials_info = self._get_boto_credentials_from_optional_params( | |
| optional_params, model | |
| ) | |
| ### SET RUNTIME ENDPOINT ### | |
| _, proxy_endpoint_url = self.get_runtime_endpoint( | |
| api_base=api_base, | |
| aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint, | |
| aws_region_name=boto3_credentials_info.aws_region_name, | |
| ) | |
| proxy_endpoint_url = proxy_endpoint_url.replace( | |
| "bedrock-runtime", "bedrock-agent-runtime" | |
| ) | |
| proxy_endpoint_url = f"{proxy_endpoint_url}/rerank" | |
| sigv4 = SigV4Auth( | |
| boto3_credentials_info.credentials, | |
| "bedrock", | |
| boto3_credentials_info.aws_region_name, | |
| ) | |
| # Make POST Request | |
| body = json.dumps(data).encode("utf-8") | |
| headers = {"Content-Type": "application/json"} | |
| if extra_headers is not None: | |
| headers = {"Content-Type": "application/json", **extra_headers} | |
| request = AWSRequest( | |
| method="POST", url=proxy_endpoint_url, data=body, headers=headers | |
| ) | |
| sigv4.add_auth(request) | |
| if ( | |
| extra_headers is not None and "Authorization" in extra_headers | |
| ): # prevent sigv4 from overwriting the auth header | |
| request.headers["Authorization"] = extra_headers["Authorization"] | |
| prepped = request.prepare() | |
| return BedrockPreparedRequest( | |
| endpoint_url=proxy_endpoint_url, | |
| prepped=prepped, | |
| body=body, | |
| data=data, | |
| ) | |