import hashlib
import json
import secrets
from datetime import datetime, timezone
from typing import Any, List, Optional, cast

from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import PrismaClient, hash_token
from litellm.types.utils import (
    StandardLoggingMCPToolCall,
    StandardLoggingModelInformation,
    StandardLoggingPayload,
)
from litellm.utils import get_end_user_id_for_cost_tracking

def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
    if _master_key is None:
        return False

    ## string comparison
    is_master_key = secrets.compare_digest(api_key, _master_key)
    if is_master_key:
        return True

    ## hash comparison
    is_master_key = secrets.compare_digest(api_key, hash_token(_master_key))
    if is_master_key:
        return True

    return False
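
# Illustrative usage (hypothetical key values, not part of the module's API):
# the helper matches either the raw master key or its hashed form, so both
# calls below would return True.
#
#     _is_master_key(api_key="sk-1234", _master_key="sk-1234")
#     _is_master_key(api_key=hash_token("sk-1234"), _master_key="sk-1234")
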
def _get_spend_logs_metadata(
    metadata: Optional[dict],
    applied_guardrails: Optional[List[str]] = None,
    batch_models: Optional[List[str]] = None,
    mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
    usage_object: Optional[dict] = None,
    model_map_information: Optional[StandardLoggingModelInformation] = None,
) -> SpendLogsMetadata:
    if metadata is None:
        return SpendLogsMetadata(
            user_api_key=None,
            user_api_key_alias=None,
            user_api_key_team_id=None,
            user_api_key_org_id=None,
            user_api_key_user_id=None,
            user_api_key_team_alias=None,
            spend_logs_metadata=None,
            requester_ip_address=None,
            additional_usage_values=None,
            applied_guardrails=None,
            status="success",
            error_information=None,
            proxy_server_request=None,
            batch_models=None,
            mcp_tool_call_metadata=None,
            model_map_information=None,
            usage_object=None,
        )
    verbose_proxy_logger.debug(
        "getting payload for SpendLogs, available keys in metadata: "
        + str(list(metadata.keys()))
    )

    # Filter the metadata dictionary to include only the specified keys
    clean_metadata = SpendLogsMetadata(
        **{  # type: ignore
            key: metadata[key]
            for key in SpendLogsMetadata.__annotations__.keys()
            if key in metadata
        }
    )

    clean_metadata["applied_guardrails"] = applied_guardrails
    clean_metadata["batch_models"] = batch_models
    clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata
    clean_metadata["usage_object"] = usage_object
    clean_metadata["model_map_information"] = model_map_information
    return clean_metadata
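
# Illustrative behavior (hypothetical metadata dict): keys not declared on
# SpendLogsMetadata are dropped, and the keyword overrides are applied after
# filtering.
#
#     _get_spend_logs_metadata({"user_api_key": "hashed-key", "unknown_key": 1})
#     # -> SpendLogsMetadata with user_api_key set; "unknown_key" is filtered out
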
def generate_hash_from_response(response_obj: Any) -> str:
    """
    Generate a stable hash from a response object.

    Args:
        response_obj: The response object to hash (can be dict, list, etc.)

    Returns:
        A hex string representation of the MD5 hash
    """
    try:
        # Create a stable JSON string of the entire response object
        # Sort keys to ensure consistent ordering
        json_str = json.dumps(response_obj, sort_keys=True)

        # Generate a hash of the response object
        unique_hash = hashlib.md5(json_str.encode()).hexdigest()
        return unique_hash
    except Exception:
        # Return a fallback hash if serialization fails
        return hashlib.md5(str(response_obj).encode()).hexdigest()
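
# Because keys are sorted before hashing, dicts that differ only in insertion
# order hash identically (illustrative):
#
#     generate_hash_from_response({"a": 1, "b": 2})
#     # == generate_hash_from_response({"b": 2, "a": 1})
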
def get_spend_logs_id(
    call_type: str, response_obj: dict, kwargs: dict
) -> Optional[str]:
    if call_type == "aretrieve_batch" or call_type == "acreate_file":
        # Generate a hash from the response object
        id: Optional[str] = generate_hash_from_response(response_obj)
    else:
        id = cast(Optional[str], response_obj.get("id")) or cast(
            Optional[str], kwargs.get("litellm_call_id")
        )
    return id
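
# Illustrative precedence for non-batch call types (hypothetical ids): the
# response id wins, then the litellm call id, then None.
#
#     get_spend_logs_id("acompletion", {"id": "chatcmpl-1"}, {"litellm_call_id": "abc"})  # "chatcmpl-1"
#     get_spend_logs_id("acompletion", {}, {"litellm_call_id": "abc"})                    # "abc"
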
def get_logging_payload(  # noqa: PLR0915
    kwargs, response_obj, start_time, end_time
) -> SpendLogsPayload:
    from litellm.proxy.proxy_server import general_settings, master_key

    if kwargs is None:
        kwargs = {}
    if response_obj is None or (
        not isinstance(response_obj, BaseModel) and not isinstance(response_obj, dict)
    ):
        response_obj = {}
    # standardize this function to be used across s3, dynamoDB, and langfuse logging
    litellm_params = kwargs.get("litellm_params", {})
    metadata = get_litellm_metadata_from_kwargs(kwargs)
    completion_start_time = kwargs.get("completion_start_time", end_time)
    call_type = kwargs.get("call_type")
    cache_hit = kwargs.get("cache_hit", False)
    usage = cast(dict, response_obj).get("usage", None) or {}
    if isinstance(usage, litellm.Usage):
        usage = dict(usage)

    if isinstance(response_obj, dict):
        response_obj_dict = response_obj
    elif isinstance(response_obj, BaseModel):
        response_obj_dict = response_obj.model_dump()
    else:
        response_obj_dict = {}

    id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)

    standard_logging_payload = cast(
        Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
    )

    end_user_id = get_end_user_id_for_cost_tracking(litellm_params)

    api_key = metadata.get("user_api_key", "")

    standard_logging_prompt_tokens: int = 0
    standard_logging_completion_tokens: int = 0
    standard_logging_total_tokens: int = 0
    if standard_logging_payload is not None:
        standard_logging_prompt_tokens = standard_logging_payload.get(
            "prompt_tokens", 0
        )
        standard_logging_completion_tokens = standard_logging_payload.get(
            "completion_tokens", 0
        )
        standard_logging_total_tokens = standard_logging_payload.get("total_tokens", 0)

    if api_key is not None and isinstance(api_key, str):
        if api_key.startswith("sk-"):
            # hash the api_key
            api_key = hash_token(api_key)
        if (
            _is_master_key(api_key=api_key, _master_key=master_key)
            and general_settings.get("disable_adding_master_key_hash_to_db") is True
        ):
            api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db
        if (
            standard_logging_payload is not None
        ):  # [TODO] migrate completely to sl payload. currently missing pass-through endpoint data
            api_key = (
                api_key
                or standard_logging_payload["metadata"].get("user_api_key_hash")
                or ""
            )
            end_user_id = end_user_id or standard_logging_payload["metadata"].get(
                "user_api_key_end_user_id"
            )
    else:
        api_key = ""

    request_tags = (
        json.dumps(metadata.get("tags", []))
        if isinstance(metadata.get("tags", []), list)
        else "[]"
    )

    # re-check after the standard_logging_payload fallback may have set api_key
    if (
        _is_master_key(api_key=api_key, _master_key=master_key)
        and general_settings.get("disable_adding_master_key_hash_to_db") is True
    ):
        api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db

    _model_id = metadata.get("model_info", {}).get("id", "")
    _model_group = metadata.get("model_group", "")

    # clean up litellm metadata
    clean_metadata = _get_spend_logs_metadata(
        metadata,
        applied_guardrails=(
            standard_logging_payload["metadata"].get("applied_guardrails", None)
            if standard_logging_payload is not None
            else None
        ),
        batch_models=(
            standard_logging_payload.get("hidden_params", {}).get("batch_models", None)
            if standard_logging_payload is not None
            else None
        ),
        mcp_tool_call_metadata=(
            standard_logging_payload["metadata"].get("mcp_tool_call_metadata", None)
            if standard_logging_payload is not None
            else None
        ),
        usage_object=(
            standard_logging_payload["metadata"].get("usage_object", None)
            if standard_logging_payload is not None
            else None
        ),
        model_map_information=(
            standard_logging_payload["model_map_information"]
            if standard_logging_payload is not None
            else None
        ),
    )

    special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
    additional_usage_values = {}
    for k, v in usage.items():
        if k not in special_usage_fields:
            if isinstance(v, BaseModel):
                v = v.model_dump()
            additional_usage_values.update({k: v})
    clean_metadata["additional_usage_values"] = additional_usage_values

    if litellm.cache is not None:
        cache_key = litellm.cache.get_cache_key(**kwargs)
    else:
        cache_key = "Cache OFF"
    if cache_hit is True:
        import time

        id = f"{id}_cache_hit{time.time()}"  # SpendLogs does not allow duplicate request_id

    try:
        payload: SpendLogsPayload = SpendLogsPayload(
            request_id=str(id),
            call_type=call_type or "",
            api_key=str(api_key),
            cache_hit=str(cache_hit),
            startTime=_ensure_datetime_utc(start_time),
            endTime=_ensure_datetime_utc(end_time),
            completionStartTime=_ensure_datetime_utc(completion_start_time),
            model=kwargs.get("model", "") or "",
            user=metadata.get("user_api_key_user_id", "") or "",
            team_id=metadata.get("user_api_key_team_id", "") or "",
            metadata=json.dumps(clean_metadata),
            cache_key=cache_key,
            spend=kwargs.get("response_cost", 0),
            total_tokens=usage.get("total_tokens", standard_logging_total_tokens),
            prompt_tokens=usage.get("prompt_tokens", standard_logging_prompt_tokens),
            completion_tokens=usage.get(
                "completion_tokens", standard_logging_completion_tokens
            ),
            request_tags=request_tags,
            end_user=end_user_id or "",
            api_base=litellm_params.get("api_base", ""),
            model_group=_model_group,
            model_id=_model_id,
            requester_ip_address=clean_metadata.get("requester_ip_address", None),
            custom_llm_provider=kwargs.get("custom_llm_provider", ""),
            messages=_get_messages_for_spend_logs_payload(
                standard_logging_payload=standard_logging_payload, metadata=metadata
            ),
            response=_get_response_for_spend_logs_payload(standard_logging_payload),
            proxy_server_request=_get_proxy_server_request_for_spend_logs_payload(
                metadata=metadata, litellm_params=litellm_params
            ),
            session_id=_get_session_id_for_spend_log(
                kwargs=kwargs,
                standard_logging_payload=standard_logging_payload,
            ),
        )

        verbose_proxy_logger.debug(
            "SpendTable: created payload - payload: %s\n\n",
            json.dumps(payload, indent=4, default=str),
        )

        return payload
    except Exception as e:
        verbose_proxy_logger.exception(
            "Error creating spendlogs object - {}".format(str(e))
        )
        raise e
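
# Minimal illustrative call (hypothetical values; a configured proxy is assumed,
# since general_settings and master_key are imported from
# litellm.proxy.proxy_server at call time):
#
#     now = datetime.now(timezone.utc)
#     payload = get_logging_payload(
#         kwargs={"model": "gpt-4o", "call_type": "acompletion", "litellm_params": {}},
#         response_obj={
#             "id": "chatcmpl-123",
#             "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
#         },
#         start_time=now,
#         end_time=now,
#     )
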
def _get_session_id_for_spend_log(
    kwargs: dict,
    standard_logging_payload: Optional[StandardLoggingPayload],
) -> str:
    """
    Get the session id for the spend log.

    This ensures each spend log is associated with a unique session id.
    """
    import uuid

    if (
        standard_logging_payload is not None
        and standard_logging_payload.get("trace_id") is not None
    ):
        return str(standard_logging_payload.get("trace_id"))

    # Users can dynamically set the trace_id for each request by passing `litellm_trace_id` in kwargs
    if kwargs.get("litellm_trace_id") is not None:
        return str(kwargs.get("litellm_trace_id"))

    # Ensure we always have a session id, if none is provided
    return str(uuid.uuid4())
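
# Precedence (illustrative, hypothetical ids): the standard logging trace_id
# wins, then `litellm_trace_id` from kwargs, then a fresh UUID.
#
#     _get_session_id_for_spend_log({"litellm_trace_id": "trace-1"}, None)  # "trace-1"
#     _get_session_id_for_spend_log({}, None)                               # random UUID string
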
def _ensure_datetime_utc(timestamp: datetime) -> datetime:
    """Helper to ensure a datetime is in UTC. Note: astimezone() interprets a
    naive datetime as being in the machine's local timezone before converting."""
    timestamp = timestamp.astimezone(timezone.utc)
    return timestamp
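
# Example (illustrative): an aware UTC datetime passes through unchanged; an
# aware datetime in another timezone is converted to UTC.
#
#     _ensure_datetime_utc(datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc))
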
async def get_spend_by_team_and_customer(
    start_date: datetime,
    end_date: datetime,
    team_id: str,
    customer_id: str,
    prisma_client: PrismaClient,
):
    """
    Aggregate daily spend for a given team and end user (customer), broken down
    by model and api_key. Returns one row per day with a jsonb breakdown.
    """
    sql_query = """
    WITH SpendByModelApiKey AS (
        SELECT
            date_trunc('day', sl."startTime") AS group_by_day,
            COALESCE(tt.team_alias, 'Unassigned Team') AS team_name,
            sl.end_user AS customer,
            sl.model,
            sl.api_key,
            SUM(sl.spend) AS model_api_spend,
            SUM(sl.total_tokens) AS model_api_tokens
        FROM
            "LiteLLM_SpendLogs" sl
        LEFT JOIN
            "LiteLLM_TeamTable" tt
        ON
            sl.team_id = tt.team_id
        WHERE
            sl."startTime" BETWEEN $1::date AND $2::date
            AND sl.team_id = $3
            AND sl.end_user = $4
        GROUP BY
            date_trunc('day', sl."startTime"),
            tt.team_alias,
            sl.end_user,
            sl.model,
            sl.api_key
    )
    SELECT
        group_by_day,
        jsonb_agg(jsonb_build_object(
            'team_name', team_name,
            'customer', customer,
            'total_spend', total_spend,
            'metadata', metadata
        )) AS teams_customers
    FROM (
        SELECT
            group_by_day,
            team_name,
            customer,
            SUM(model_api_spend) AS total_spend,
            jsonb_agg(jsonb_build_object(
                'model', model,
                'api_key', api_key,
                'spend', model_api_spend,
                'total_tokens', model_api_tokens
            )) AS metadata
        FROM
            SpendByModelApiKey
        GROUP BY
            group_by_day,
            team_name,
            customer
    ) AS aggregated
    GROUP BY
        group_by_day
    ORDER BY
        group_by_day;
    """

    db_response = await prisma_client.db.query_raw(
        sql_query, start_date, end_date, team_id, customer_id
    )
    if db_response is None:
        return []

    return db_response
def _get_messages_for_spend_logs_payload(
    standard_logging_payload: Optional[StandardLoggingPayload],
    metadata: Optional[dict] = None,
) -> str:
    # Message content is not persisted to spend logs here; an empty JSON object
    # is always returned, regardless of the inputs.
    return "{}"
def _sanitize_request_body_for_spend_logs_payload(
    request_body: dict,
    visited: Optional[set] = None,
) -> dict:
    """
    Recursively sanitize the request body to prevent logging large base64 strings
    or other oversized values. Truncates strings longer than 1000 characters and
    handles nested dictionaries, guarding against circular references.
    """
    MAX_STRING_LENGTH = 1000

    if visited is None:
        visited = set()

    # Track visited dicts by memory address to break circular references
    obj_id = id(request_body)
    if obj_id in visited:
        return {}
    visited.add(obj_id)

    def _sanitize_value(value: Any) -> Any:
        if isinstance(value, dict):
            return _sanitize_request_body_for_spend_logs_payload(value, visited)
        elif isinstance(value, list):
            return [_sanitize_value(item) for item in value]
        elif isinstance(value, str):
            if len(value) > MAX_STRING_LENGTH:
                return f"{value[:MAX_STRING_LENGTH]}... (truncated {len(value) - MAX_STRING_LENGTH} chars)"
            return value
        return value

    return {k: _sanitize_value(v) for k, v in request_body.items()}
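
# Illustrative truncation (hypothetical payload): a 5000-character base64-like
# string is cut to 1000 characters plus a truncation marker.
#
#     body = {"messages": [{"image": "A" * 5000}], "model": "gpt-4o"}
#     _sanitize_request_body_for_spend_logs_payload(body)["messages"][0]["image"]
#     # -> "AAA...A... (truncated 4000 chars)"
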
def _get_proxy_server_request_for_spend_logs_payload(
    metadata: dict,
    litellm_params: dict,
) -> str:
    """
    Only store if _should_store_prompts_and_responses_in_spend_logs() is True
    """
    if _should_store_prompts_and_responses_in_spend_logs():
        _proxy_server_request = cast(
            Optional[dict], litellm_params.get("proxy_server_request", {})
        )
        if _proxy_server_request is not None:
            _request_body = _proxy_server_request.get("body", {}) or {}
            _request_body = _sanitize_request_body_for_spend_logs_payload(_request_body)
            _request_body_json_str = json.dumps(_request_body, default=str)
            return _request_body_json_str
    return "{}"
def _get_response_for_spend_logs_payload(
    payload: Optional[StandardLoggingPayload],
) -> str:
    if payload is None:
        return "{}"
    if _should_store_prompts_and_responses_in_spend_logs():
        return json.dumps(payload.get("response", {}))
    return "{}"
def _should_store_prompts_and_responses_in_spend_logs() -> bool:
    from litellm.proxy.proxy_server import general_settings

    return general_settings.get("store_prompts_in_spend_logs") is True
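
# Prompt/response storage is opt-in. Illustrative proxy config snippet (YAML),
# assuming the `store_prompts_in_spend_logs` general setting read above:
#
#     general_settings:
#       store_prompts_in_spend_logs: true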