Spaces:
Paused
Paused
| # What is this? | |
| ## Helper utils for the management endpoints (keys/users/teams) | |
| import uuid | |
| from datetime import datetime | |
| from functools import wraps | |
| from typing import Optional, Tuple | |
| from fastapi import HTTPException, Request | |
| import litellm | |
| from litellm._logging import verbose_logger | |
| from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types | |
| DeleteCustomerRequest, | |
| DeleteTeamRequest, | |
| DeleteUserRequest, | |
| KeyRequest, | |
| LiteLLM_TeamMembership, | |
| LiteLLM_UserTable, | |
| ManagementEndpointLoggingPayload, | |
| Member, | |
| SSOUserDefinedValues, | |
| UpdateCustomerRequest, | |
| UpdateKeyRequest, | |
| UpdateTeamRequest, | |
| UpdateUserRequest, | |
| UserAPIKeyAuth, | |
| VirtualKeyEvent, | |
| ) | |
| from litellm.proxy.common_utils.http_parsing_utils import _read_request_body | |
| from litellm.proxy.utils import PrismaClient | |
def get_new_internal_user_defaults(
    user_id: str, user_email: Optional[str] = None
) -> dict:
    """
    Build the field values used when creating a new internal user row.

    Values are taken from `litellm.default_internal_user_params` when it is set.
    `max_budget` and `budget_duration` fall back to
    `litellm.max_internal_user_budget` / `litellm.internal_user_budget_duration`.
    An explicit `user_email` argument takes precedence over a configured one.
    `user_role` is always "internal_user".

    Entries whose value is None are omitted from the returned dict.
    """
    configured = litellm.default_internal_user_params or {}
    candidate_values: SSOUserDefinedValues = {
        "models": configured.get("models", None),
        "max_budget": configured.get("max_budget", litellm.max_internal_user_budget),
        "budget_duration": configured.get(
            "budget_duration", litellm.internal_user_budget_duration
        ),
        "user_email": user_email or configured.get("user_email", None),
        "user_id": user_id,
        "user_role": "internal_user",
    }
    # Drop unset fields so callers only send explicitly-known values.
    return {
        field: value
        for field, value in candidate_values.items()
        if value is not None
    }
async def add_new_member(
    new_member: Member,
    max_budget_in_team: Optional[float],
    prisma_client: PrismaClient,
    team_id: str,
    user_api_key_dict: UserAPIKeyAuth,
    litellm_proxy_admin_name: str,
) -> Tuple[LiteLLM_UserTable, Optional[LiteLLM_TeamMembership]]:
    """
    Add a new member to a team.

    1. Record `team_id` in the member's user row:
       - when `new_member.user_id` is set: upsert the row (push `team_id` onto
         an existing row, or create it with internal-user defaults);
       - otherwise, when `new_member.user_email` is set: look users up by email.
         With no match, insert a new user (random uuid4 `user_id`). With exactly
         one match, push `team_id` onto that row.
    2. When `max_budget_in_team` is set, create a budget row and a team
       membership row linking the user, the team and that budget.

    Returns:
        (created/updated user row, team membership with budget, or None when
        no per-team budget was requested)

    Raises:
        HTTPException: status 400 when several users share `user_email`.
        Exception: when no user row was created/updated (e.g. the member has
            neither `user_id` nor `user_email`).
    """
    member_user: Optional[LiteLLM_UserTable] = None
    team_membership: Optional[LiteLLM_TeamMembership] = None

    if new_member.user_id is not None:
        # Known user id: upsert so an existing row gets the team appended and a
        # missing row is created with the internal-user defaults.
        create_fields = get_new_internal_user_defaults(user_id=new_member.user_id)
        db_user = await prisma_client.db.litellm_usertable.upsert(
            where={"user_id": new_member.user_id},
            data={
                "update": {"teams": {"push": [team_id]}},
                "create": {"teams": [team_id], **create_fields},  # type: ignore
            },
        )
        if db_user is not None:
            member_user = LiteLLM_UserTable(**db_user.model_dump())
    elif new_member.user_email is not None:
        create_fields = get_new_internal_user_defaults(
            user_id=str(uuid.uuid4()), user_email=new_member.user_email
        )
        # user_email is not unique according to the prisma schema (future
        # improvement), so look it up first and only insert when nothing matches.
        rows_with_email: Optional[list] = await prisma_client.get_data(
            key_val={"user_email": new_member.user_email},
            table_name="user",
            query_type="find_all",
        )
        if rows_with_email is None or (
            isinstance(rows_with_email, list) and len(rows_with_email) == 0
        ):
            create_fields["teams"] = [team_id]
            db_user = await prisma_client.insert_data(data=create_fields, table_name="user")  # type: ignore
            if db_user is not None:
                member_user = LiteLLM_UserTable(**db_user.model_dump())
        elif len(rows_with_email) == 1:
            matched_user = rows_with_email[0]
            db_user = await prisma_client.db.litellm_usertable.update(
                where={"user_id": matched_user.user_id},  # type: ignore
                data={"teams": {"push": [team_id]}},
            )
            if db_user is not None:
                member_user = LiteLLM_UserTable(**db_user.model_dump())
        elif len(rows_with_email) > 1:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Multiple users with this email found in db. Please use 'user_id' instead."
                },
            )

    # Optional per-team budget: create a budget row, then a membership row
    # pointing at it.
    if (
        max_budget_in_team is not None
        and member_user is not None
        and member_user.user_id is not None
    ):
        acting_user = user_api_key_dict.user_id or litellm_proxy_admin_name
        budget_row = await prisma_client.db.litellm_budgettable.create(
            data={
                "max_budget": max_budget_in_team,
                "created_by": acting_user,
                "updated_by": acting_user,
            }
        )
        db_membership = await prisma_client.db.litellm_teammembership.create(
            data={
                "team_id": team_id,
                "user_id": member_user.user_id,
                "budget_id": budget_row.budget_id,
            },
            include={"litellm_budget_table": True},
        )
        team_membership = LiteLLM_TeamMembership(**db_membership.model_dump())

    if member_user is None:
        raise Exception("Unable to update user table with membership information!")
    return member_user, team_membership
def _delete_user_id_from_cache(kwargs):
    """
    Evict user ids touched by a user update/delete request from `user_api_key_cache`.

    `kwargs` are the endpoint's keyword arguments. The request model is read
    from `kwargs["data"]`:
    - `UpdateUserRequest`: evict its `user_id`.
    - `DeleteUserRequest`: evict every id in `user_ids`.
    Any other (or missing) payload is ignored.
    """
    from litellm.proxy.proxy_server import user_api_key_cache

    request_data = kwargs.get("data")
    if request_data is None:
        return
    if isinstance(request_data, UpdateUserRequest):
        user_api_key_cache.delete_cache(key=request_data.user_id)
    if isinstance(request_data, DeleteUserRequest):
        for deleted_user_id in request_data.user_ids:
            user_api_key_cache.delete_cache(key=deleted_user_id)
def _delete_api_key_from_cache(kwargs):
    """
    Evict keys referenced by a key update/delete request from `user_api_key_cache`.

    `kwargs` are the endpoint's keyword arguments. The request model is read
    from `kwargs["data"]`:
    - `UpdateKeyRequest`: evict its `key`.
    - `KeyRequest` with a non-empty `keys`: evict every entry in `keys`.
    Any other (or missing) payload is ignored.

    NOTE(review): this evicts using the key value exactly as it appears in the
    request; confirm the cache is keyed by that same form (raw vs hashed).
    """
    from litellm.proxy.proxy_server import user_api_key_cache

    request_data = kwargs.get("data")
    if request_data is None:
        return
    if isinstance(request_data, UpdateKeyRequest):
        user_api_key_cache.delete_cache(key=request_data.key)
    if isinstance(request_data, KeyRequest) and request_data.keys:
        for cached_key in request_data.keys:
            user_api_key_cache.delete_cache(key=cached_key)
def _delete_team_id_from_cache(kwargs):
    """
    Evict team ids touched by a team update/delete request from `user_api_key_cache`.

    `kwargs` are the endpoint's keyword arguments. The request model is read
    from `kwargs["data"]`:
    - `UpdateTeamRequest`: evict its `team_id`.
    - `DeleteTeamRequest`: evict every id in `team_ids`.
    Any other (or missing) payload is ignored.
    """
    from litellm.proxy.proxy_server import user_api_key_cache

    request_data = kwargs.get("data")
    if request_data is None:
        return
    if isinstance(request_data, UpdateTeamRequest):
        user_api_key_cache.delete_cache(key=request_data.team_id)
    if isinstance(request_data, DeleteTeamRequest):
        for deleted_team_id in request_data.team_ids:
            user_api_key_cache.delete_cache(key=deleted_team_id)
def _delete_customer_id_from_cache(kwargs):
    """
    Evict customer ids touched by a customer update/delete request from `user_api_key_cache`.

    `kwargs` are the endpoint's keyword arguments. The request model is read
    from `kwargs["data"]`:
    - `UpdateCustomerRequest`: evict its `user_id`.
    - `DeleteCustomerRequest`: evict every id in `user_ids`.
    Any other (or missing) payload is ignored.
    """
    from litellm.proxy.proxy_server import user_api_key_cache

    request_data = kwargs.get("data")
    if request_data is None:
        return
    if isinstance(request_data, UpdateCustomerRequest):
        user_api_key_cache.delete_cache(key=request_data.user_id)
    if isinstance(request_data, DeleteCustomerRequest):
        for deleted_customer_id in request_data.user_ids:
            user_api_key_cache.delete_cache(key=deleted_customer_id)
async def send_management_endpoint_alert(
    request_kwargs: dict,
    user_api_key_dict: UserAPIKeyAuth,
    function_name: str,
):
    """
    Send a Slack alert for a management-endpoint event.

    An alert is sent only when all of these hold:
    - `premium_user` is True on the proxy server;
    - `proxy_logging_obj` has a `slack_alerting_instance`;
    - `function_name` is one of the mapped endpoints, i.e. when
      a virtual key, an internal user, or a team is created, updated, or deleted.

    Args:
        request_kwargs: the endpoint's keyword arguments, forwarded in the event.
        user_api_key_dict: the caller's auth object (who triggered the event).
        function_name: name of the decorated endpoint function.
    """
    from litellm.proxy.proxy_server import premium_user, proxy_logging_obj
    from litellm.types.integrations.slack_alerting import AlertType

    if premium_user is not True:
        return

    alert_for_endpoint = {
        # Virtual key events
        "generate_key_fn": AlertType.new_virtual_key_created,
        "update_key_fn": AlertType.virtual_key_updated,
        "delete_key_fn": AlertType.virtual_key_deleted,
        # Team events
        "new_team": AlertType.new_team_created,
        "update_team": AlertType.team_updated,
        "delete_team": AlertType.team_deleted,
        # Internal user events
        "new_user": AlertType.new_internal_user_created,
        "user_update": AlertType.internal_user_updated,
        "delete_user": AlertType.internal_user_deleted,
    }

    # Nothing to do when alerting is not configured or the endpoint is unmapped.
    if (
        proxy_logging_obj is None
        or proxy_logging_obj.slack_alerting_instance is None
        or function_name not in alert_for_endpoint
    ):
        return

    alert_type: AlertType = alert_for_endpoint[function_name]
    key_event = VirtualKeyEvent(
        created_by_user_id=user_api_key_dict.user_id or "Unknown",
        created_by_user_role=user_api_key_dict.user_role or "Unknown",
        created_by_key_alias=user_api_key_dict.key_alias,
        request_kwargs=request_kwargs,
    )
    # Human-readable title: underscores become spaces, then title-case.
    readable_event_name = alert_type.replace("_", " ").title()
    await proxy_logging_obj.slack_alerting_instance.send_virtual_key_event_slack(
        key_event=key_event,
        event_name=readable_event_name,
        alert_type=alert_type,
    )
async def _log_management_endpoint_success_to_otel(
    kwargs: dict,
    user_api_key_dict: "UserAPIKeyAuth",
    result,
    start_time: datetime,
    end_time: datetime,
) -> None:
    """
    Send a successful management call (request body + response) to OTEL.

    No-op unless the auth object carries a `parent_otel_span`, the proxy has an
    `open_telemetry_logger`, and the endpoint received an `http_request` kwarg.
    `result` must be None or convertible with `dict(...)`.
    """
    parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
    if parent_otel_span is None:
        return
    from litellm.proxy.proxy_server import open_telemetry_logger

    if open_telemetry_logger is None:
        return
    http_request: Optional[Request] = kwargs.get("http_request", None)
    if not http_request:
        return
    request_body: dict = await _read_request_body(request=http_request)
    logging_payload = ManagementEndpointLoggingPayload(
        route=http_request.url.path,
        request_data=request_body,
        response=dict(result) if result is not None else None,
        start_time=start_time,
        end_time=end_time,
    )
    await open_telemetry_logger.async_management_endpoint_success_hook(  # type: ignore
        logging_payload=logging_payload,
        parent_otel_span=parent_otel_span,
    )


async def _log_management_endpoint_failure_to_otel(
    kwargs: dict,
    user_api_key_dict: "UserAPIKeyAuth",
    exception: Exception,
    start_time: datetime,
    end_time: datetime,
) -> None:
    """
    Send a failed management call (request body + exception) to OTEL.

    Same preconditions as the success variant: a `parent_otel_span` on the auth
    object, a configured `open_telemetry_logger`, and an `http_request` kwarg.
    """
    parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
    if parent_otel_span is None:
        return
    from litellm.proxy.proxy_server import open_telemetry_logger

    if open_telemetry_logger is None:
        return
    http_request: Optional[Request] = kwargs.get("http_request")
    if not http_request:
        return
    request_body: dict = await _read_request_body(request=http_request)
    logging_payload = ManagementEndpointLoggingPayload(
        route=http_request.url.path,
        request_data=request_body,
        response=None,
        start_time=start_time,
        end_time=end_time,
        exception=exception,
    )
    await open_telemetry_logger.async_management_endpoint_failure_hook(  # type: ignore
        logging_payload=logging_payload,
        parent_otel_span=parent_otel_span,
    )


def management_endpoint_wrapper(func):
    """
    Decorate an async management endpoint with alerting, OTEL logging and cache eviction.

    On success, three best-effort steps run, in this order. Each step is
    isolated: a failure is logged at debug level and affects neither the
    endpoint's result nor the remaining steps.
    1. Send a Slack alert via `send_management_endpoint_alert`.
    2. Log the request/response to OTEL.
    3. Evict updated/deleted keys, users, teams and customers from
       `user_api_key_cache`.

    On failure, the exception is logged to OTEL (best-effort) and the
    endpoint's ORIGINAL exception is re-raised, even if OTEL logging fails.

    The endpoint's keyword arguments `user_api_key_dict`, `http_request` and
    `data` are inspected; positional arguments are passed through untouched.
    """

    # Copy name/docstring and set `__wrapped__` so signature introspection
    # (e.g. `inspect.signature`) sees the endpoint's real parameters instead
    # of `(*args, **kwargs)`.
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = datetime.now()
        try:
            result = await func(*args, **kwargs)
        except Exception as e:
            end_time = datetime.now()
            user_api_key_dict = kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
            try:
                await _log_management_endpoint_failure_to_otel(
                    kwargs=kwargs,
                    user_api_key_dict=user_api_key_dict,
                    exception=e,
                    start_time=start_time,
                    end_time=end_time,
                )
            except Exception as logging_error:
                # Never let a logging failure mask the endpoint's own error.
                # (Distinct name: the `as` target is unbound when the handler ends.)
                verbose_logger.debug(
                    "Error in management endpoint wrapper: %s", str(logging_error)
                )
            raise e

        end_time = datetime.now()
        user_api_key_dict = kwargs.get("user_api_key_dict") or UserAPIKeyAuth()

        try:
            await send_management_endpoint_alert(
                request_kwargs=kwargs,
                user_api_key_dict=user_api_key_dict,
                function_name=func.__name__,
            )
        except Exception as alert_error:
            # Non-blocking
            verbose_logger.debug(
                "Error in management endpoint wrapper: %s", str(alert_error)
            )

        try:
            await _log_management_endpoint_success_to_otel(
                kwargs=kwargs,
                user_api_key_dict=user_api_key_dict,
                result=result,
                start_time=start_time,
                end_time=end_time,
            )
        except Exception as otel_error:
            # Non-blocking
            verbose_logger.debug(
                "Error in management endpoint wrapper: %s", str(otel_error)
            )

        # Runs even if alerting/OTEL failed, so updated or deleted objects are
        # not served from a stale cache.
        try:
            _delete_api_key_from_cache(kwargs=kwargs)
            _delete_user_id_from_cache(kwargs=kwargs)
            _delete_team_id_from_cache(kwargs=kwargs)
            _delete_customer_id_from_cache(kwargs=kwargs)
        except Exception as cache_error:
            # Non-blocking
            verbose_logger.debug(
                "Error in management endpoint wrapper: %s", str(cache_error)
            )

        return result

    return wrapper