# litellm/proxy/guardrails/guardrail_hooks/pangea.py

import os
from typing import Any, Optional, Protocol

from fastapi import HTTPException

from litellm._logging import verbose_proxy_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.common_utils.callback_utils import (
    add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.utils import LLMResponseTypes, ModelResponse, TextCompletionResponse


class PangeaGuardrailMissingSecrets(Exception):
    """Custom exception for missing Pangea secrets."""

    pass


class _Transformer(Protocol):
    def get_messages(self) -> list[dict]:
        ...

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        ...
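

# The transformer classes below normalize every supported request/response
# shape into the flat [{"role": ..., "content": ...}] message list that this
# module sends to the Pangea AI Guard API, and write guard-redacted text back
# into the original body. A rough sketch of the round trip (illustrative only,
# not taken from the Pangea docs):
#
#   t = _ChatCompletionRequest({"messages": [{"role": "user", "content": "hi"}]})
#   msgs = t.get_messages()       # [{"role": "user", "content": "hi"}]
#   t.update_original_body(msgs)  # writes (possibly redacted) content back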


class _TextCompletionRequest:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        return [{"role": "user", "content": self.body["prompt"]}]

    # This mutates the original dict, but we still return it anyway.
    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        assert len(prompt_messages) == 1
        self.body["prompt"] = prompt_messages[0]["content"]
        return self.body


class _TextCompletionResponse:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        messages = []
        for choice in self.body["choices"]:
            messages.append({"role": "assistant", "content": choice["text"]})
        return messages

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        assert len(prompt_messages) == len(self.body["choices"])
        for choice, prompt_message in zip(self.body["choices"], prompt_messages):
            choice["text"] = prompt_message["content"]
        return self.body


class _ChatCompletionRequest:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        messages = []
        for message in self.body["messages"]:
            role = message["role"]
            content = message.get("content")
            if isinstance(content, str):
                messages.append({"role": role, "content": content})
            elif isinstance(content, list):
                for content_part in content:
                    if content_part["type"] == "text":
                        messages.append({"role": role, "content": content_part["text"]})
        return messages

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        # Walk the messages in the same order as get_messages() so the
        # guard-returned messages line up one-to-one with the text parts.
        count = 0
        for message in self.body["messages"]:
            content = message.get("content")
            if isinstance(content, str):
                message["content"] = prompt_messages[count]["content"]
                count += 1
            elif isinstance(content, list):
                for content_part in content:
                    if content_part["type"] == "text":
                        content_part["text"] = prompt_messages[count]["content"]
                        count += 1
        assert len(prompt_messages) == count
        return self.body
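
# Illustrative example (not from the source): a chat message with mixed
# content parts such as
#   {"role": "user", "content": [
#       {"type": "text", "text": "describe"},
#       {"type": "image_url", "image_url": {"url": "..."}},
#       {"type": "text", "text": "this image"}]}
# flattens to two guardable messages:
#   [{"role": "user", "content": "describe"},
#    {"role": "user", "content": "this image"}]
# and update_original_body writes redacted text back in the same order,
# leaving non-text parts untouched.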


class _ChatCompletionResponse:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        messages = []
        for choice in self.body["choices"]:
            messages.append(
                {
                    "role": choice["message"]["role"],
                    "content": choice["message"]["content"],
                }
            )
        return messages

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        assert len(prompt_messages) == len(self.body["choices"])
        for choice, prompt_message in zip(self.body["choices"], prompt_messages):
            choice["message"]["content"] = prompt_message["content"]
        return self.body


def _get_transformer_for_request(body, call_type) -> Optional[_Transformer]:
    match call_type:
        case "text_completion" | "atext_completion":
            return _TextCompletionRequest(body)
        case "completion" | "acompletion":
            return _ChatCompletionRequest(body)
    return None


def _get_transformer_for_response(body) -> Optional[_Transformer]:
    match body:
        case TextCompletionResponse():
            return _TextCompletionResponse(body)
        case ModelResponse():
            return _ChatCompletionResponse(body)
    return None
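
# NOTE: the call_type strings above mirror litellm's CallTypes values
# ("completion"/"acompletion" for chat, "text_completion"/"atext_completion"
# for legacy completions), and the response match relies on litellm returning
# TextCompletionResponse / ModelResponse instances. Unrecognized shapes fall
# through to None, and the hooks below skip (rather than block) the call.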


class PangeaHandler(CustomGuardrail):
    """
    Pangea AI Guardrail handler to interact with the Pangea AI Guard service.

    This class implements the necessary hooks to call the Pangea AI Guard API
    for input and output scanning based on the configured recipes.
    """

    def __init__(
        self,
        guardrail_name: str,
        pangea_input_recipe: Optional[str] = None,
        pangea_output_recipe: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        """
        Initializes the PangeaHandler.

        Args:
            guardrail_name (str): The name of the guardrail instance.
            pangea_input_recipe (Optional[str]): The Pangea recipe key used to scan request input.
            pangea_output_recipe (Optional[str]): The Pangea recipe key used to scan response output.
            api_key (Optional[str]): The Pangea API key. Reads from the PANGEA_API_KEY env var if None.
            api_base (Optional[str]): The Pangea API base URL. Reads from the PANGEA_API_BASE env var or falls back to the default if None.
            **kwargs: Additional arguments passed to the CustomGuardrail base class.
        """
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.api_key = api_key or os.environ.get("PANGEA_API_KEY")
        if not self.api_key:
            raise PangeaGuardrailMissingSecrets(
                "Pangea API Key not found. Set PANGEA_API_KEY environment variable or pass it in litellm_params."
            )
        # Default Pangea base URL if not provided
        self.api_base = (
            api_base
            or os.environ.get("PANGEA_API_BASE")
            or "https://ai-guard.aws.us.pangea.cloud"
        )
        self.pangea_input_recipe = pangea_input_recipe
        self.pangea_output_recipe = pangea_output_recipe
        self.guardrail_endpoint = f"{self.api_base}/v1/text/guard"

        # Pass relevant kwargs to the parent class
        super().__init__(guardrail_name=guardrail_name, **kwargs)
        verbose_proxy_logger.info(
            f"Initialized Pangea Guardrail: name={guardrail_name}, "
            f"input_recipe={pangea_input_recipe}, output_recipe={pangea_output_recipe}, "
            f"api_base={self.api_base}"
        )
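
    # Example proxy configuration (a sketch; the recipe names are placeholders
    # and the layout assumes litellm's standard guardrail config format, where
    # extra litellm_params are forwarded to __init__ as kwargs):
    #
    #   guardrails:
    #     - guardrail_name: pangea-ai-guard
    #       litellm_params:
    #         guardrail: pangea
    #         mode: [pre_call, post_call]
    #         api_key: os.environ/PANGEA_API_KEY
    #         pangea_input_recipe: my_input_recipe
    #         pangea_output_recipe: my_output_recipe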

    async def _call_pangea_guard(self, payload: dict, hook_name: str) -> dict:
        """
        Makes the API call to the Pangea AI Guard endpoint.

        Raises if the scanned content should be blocked; otherwise returns the
        API response, whose redacted messages the caller should apply to the
        original body.

        Args:
            payload (dict): The request payload.
            hook_name (str): Name of the hook calling this function (for logging).

        Raises:
            HTTPException: If the Pangea API returns a 'blocked: true' response.
            Exception: For other API call failures.

        Returns:
            dict: The parsed Pangea AI Guard API response.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        try:
            verbose_proxy_logger.debug(
                f"Pangea Guardrail ({hook_name}): Calling endpoint {self.guardrail_endpoint} with payload: {payload}"
            )
            response = await self.async_handler.post(
                url=self.guardrail_endpoint, json=payload, headers=headers
            )
            response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
            result = response.json()
            verbose_proxy_logger.debug(
                f"Pangea Guardrail ({hook_name}): Received response: {result}"
            )

            # Check if the request was blocked
            if result.get("result", {}).get("blocked") is True:
                verbose_proxy_logger.warning(
                    f"Pangea Guardrail ({hook_name}): Request blocked. Response: {result}"
                )
                raise HTTPException(
                    status_code=400,  # Bad Request, indicating a policy violation
                    detail={
                        "error": "Violated Pangea guardrail policy",
                        "guardrail_name": self.guardrail_name,
                        "pangea_response": result.get("result"),
                    },
                )
            else:
                verbose_proxy_logger.info(
                    f"Pangea Guardrail ({hook_name}): Request passed. Response: {result.get('result', {}).get('detectors')}"
                )
            return result
        except HTTPException as e:
            # Re-raise the HTTPException raised above for blocking
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                f"Pangea Guardrail ({hook_name}): Error calling API: {e}. Response text: {getattr(e, 'response', None) and getattr(e.response, 'text', None)}"  # type: ignore
            )
            # Fail closed: raising here blocks the request on any API error.
            # To fail open instead, log the error and return without raising.
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Error communicating with Pangea Guardrail",
                    "guardrail_name": self.guardrail_name,
                    "exception": str(e),
                },
            ) from e
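
    # Expected AI Guard response shape, sketched from the fields read above
    # (not an exhaustive schema; see Pangea's API docs for the full contract):
    #
    #   {"result": {"blocked": false,
    #               "detectors": {...},
    #               "prompt_messages": [{"role": "user", "content": "<redacted>"}]}}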

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        event_type = GuardrailEventHooks.pre_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            verbose_proxy_logger.debug(
                f"Pangea Guardrail (async_pre_call_hook): Guardrail {self.guardrail_name} is disabled."
            )
            return data

        transformer = _get_transformer_for_request(data, call_type)
        if not transformer:
            verbose_proxy_logger.warning(
                f"Pangea Guardrail (async_pre_call_hook): Skipping guardrail {self.guardrail_name}"
                f" because we cannot determine the type of request: call_type '{call_type}'"
            )
            return

        messages = transformer.get_messages()
        if not messages:
            verbose_proxy_logger.warning(
                f"Pangea Guardrail (async_pre_call_hook): Skipping guardrail {self.guardrail_name}"
                " because messages is empty."
            )
            return

        ai_guard_payload = {
            "debug": False,  # Or make this configurable if needed
            "messages": messages,
        }
        if self.pangea_input_recipe:
            ai_guard_payload["recipe"] = self.pangea_input_recipe
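
        # Example guard payload built above (illustrative values only):
        #   {"debug": False, "recipe": "<pangea_input_recipe>",
        #    "messages": [{"role": "system", "content": "..."},
        #                 {"role": "user", "content": "..."}]}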
        ai_guard_response = await self._call_pangea_guard(
            ai_guard_payload, "async_pre_call_hook"
        )

        # Record this guardrail in the applied-guardrails response header
        add_guardrail_to_applied_guardrails_header(
            request_data=data, guardrail_name=self.guardrail_name
        )

        prompt_messages = ai_guard_response.get("result", {}).get("prompt_messages", [])
        try:
            return transformer.update_original_body(prompt_messages)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to update original request body",
                    "guardrail_name": self.guardrail_name,
                    "exception": str(e),
                },
            ) from e

    @log_guardrail_information
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        # This union isn't actually correct -- it can get other response types depending on the API called
        response: LLMResponseTypes,
    ):
        """
        Guardrail hook run after a successful LLM call (scans output).

        Args:
            data (dict): The original request data.
            user_api_key_dict (UserAPIKeyAuth): User API key details.
            response (LLMResponseTypes): The response object from the LLM call.
        """
        event_type = GuardrailEventHooks.post_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            verbose_proxy_logger.debug(
                f"Pangea Guardrail (async_post_call_success_hook): Guardrail {self.guardrail_name} is disabled."
            )
            return data

        transformer = _get_transformer_for_response(response)
        if not transformer:
            verbose_proxy_logger.warning(
                f"Pangea Guardrail (async_post_call_success_hook): Skipping guardrail {self.guardrail_name}"
                " because we cannot determine the type of response"
            )
            return

        messages = transformer.get_messages()
        verbose_proxy_logger.debug(
            f"Pangea Guardrail (async_post_call_success_hook): Scanning messages: {messages}"
        )

        ai_guard_payload = {
            "debug": False,  # Or make this configurable if needed
            "messages": messages,
        }
        if self.pangea_output_recipe:
            ai_guard_payload["recipe"] = self.pangea_output_recipe

        ai_guard_response = await self._call_pangea_guard(
            ai_guard_payload, "async_post_call_success_hook"
        )

        prompt_messages = ai_guard_response.get("result", {}).get("prompt_messages", [])
        try:
            return transformer.update_original_body(prompt_messages)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to update original response body",
                    "guardrail_name": self.guardrail_name,
                    "exception": str(e),
                },
            ) from e