| """ | |
| This is a rate limiter implementation based on a similar one by Envoy proxy. | |
| This is currently in development and not yet ready for production. | |
| """ | |
| import os | |
| from datetime import datetime | |
| from typing import ( | |
| TYPE_CHECKING, | |
| Any, | |
| Dict, | |
| List, | |
| Literal, | |
| Optional, | |
| TypedDict, | |
| Union, | |
| cast, | |
| ) | |
| from fastapi import HTTPException | |
| from litellm import DualCache | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.integrations.custom_logger import CustomLogger | |
| from litellm.proxy._types import UserAPIKeyAuth | |
| if TYPE_CHECKING: | |
| from opentelemetry.trace import Span as _Span | |
| from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache | |
| from litellm.types.caching import RedisPipelineIncrementOperation | |
| Span = Union[_Span, Any] | |
| InternalUsageCache = _InternalUsageCache | |
| else: | |
| Span = Any | |
| InternalUsageCache = Any | |


BATCH_RATE_LIMITER_SCRIPT = """
local results = {}
local now = tonumber(ARGV[1])
local window_size = tonumber(ARGV[2])

-- Process each window/counter pair
for i = 1, #KEYS, 2 do
    local window_key = KEYS[i]
    local counter_key = KEYS[i + 1]
    local increment_value = 1

    -- Check if window exists and is valid
    local window_start = redis.call('GET', window_key)
    if not window_start or (now - tonumber(window_start)) >= window_size then
        -- Reset window and counter
        redis.call('SET', window_key, tostring(now))
        redis.call('SET', counter_key, increment_value)
        redis.call('EXPIRE', window_key, window_size)
        redis.call('EXPIRE', counter_key, window_size)
        table.insert(results, tostring(now)) -- window_start
        table.insert(results, increment_value) -- counter
    else
        local counter = redis.call('INCR', counter_key)
        table.insert(results, window_start) -- window_start
        table.insert(results, counter) -- counter
    end
end

return results
"""


class RateLimitDescriptorRateLimitObject(TypedDict, total=False):
    requests_per_unit: Optional[int]
    tokens_per_unit: Optional[int]
    max_parallel_requests: Optional[int]
    window_size: Optional[int]


class RateLimitDescriptor(TypedDict):
    key: str
    value: str
    rate_limit: Optional[RateLimitDescriptorRateLimitObject]


class RateLimitStatus(TypedDict):
    code: str
    current_limit: int
    limit_remaining: int
    rate_limit_type: Literal["requests", "tokens", "max_parallel_requests"]
    descriptor_key: str


class RateLimitResponse(TypedDict):
    overall_code: str
    statuses: List[RateLimitStatus]


class RateLimitResponseWithDescriptors(TypedDict):
    descriptors: List[RateLimitDescriptor]
    response: RateLimitResponse
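
# Illustrative example of a descriptor: an API key limited to 100 requests and
# 10,000 tokens per 60-second window would be described as
#   RateLimitDescriptor(
#       key="api_key",
#       value="sk-...",
#       rate_limit={"requests_per_unit": 100, "tokens_per_unit": 10000, "window_size": 60},
#   )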


class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
    def __init__(self, internal_usage_cache: InternalUsageCache):
        self.internal_usage_cache = internal_usage_cache
        if self.internal_usage_cache.dual_cache.redis_cache is not None:
            self.batch_rate_limiter_script = (
                self.internal_usage_cache.dual_cache.redis_cache.async_register_script(
                    BATCH_RATE_LIMITER_SCRIPT
                )
            )
        else:
            self.batch_rate_limiter_script = None
        self.window_size = int(os.getenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", 60))
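
    # The rate limit window defaults to 60 seconds; it can be overridden via the
    # LITELLM_RATE_LIMIT_WINDOW_SIZE environment variable, e.g.
    # LITELLM_RATE_LIMIT_WINDOW_SIZE=30 for a 30-second window.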

    async def in_memory_cache_sliding_window(
        self,
        keys: List[str],
        now_int: int,
        window_size: int,
    ) -> List[Any]:
        """
        Implement sliding window rate limiting logic using in-memory cache operations.
        This follows the same logic as the Redis Lua script but uses async cache operations.
        """
        results: List[Any] = []

        # Process each window/counter pair
        for i in range(0, len(keys), 2):
            window_key = keys[i]
            counter_key = keys[i + 1]
            increment_value = 1

            # Get the window start time
            window_start = await self.internal_usage_cache.async_get_cache(
                key=window_key,
                litellm_parent_otel_span=None,
                local_only=True,
            )

            # Check if window exists and is valid
            if window_start is None or (now_int - int(window_start)) >= window_size:
                # Reset window and counter
                await self.internal_usage_cache.async_set_cache(
                    key=window_key,
                    value=str(now_int),
                    ttl=window_size,
                    litellm_parent_otel_span=None,
                    local_only=True,
                )
                await self.internal_usage_cache.async_set_cache(
                    key=counter_key,
                    value=increment_value,
                    ttl=window_size,
                    litellm_parent_otel_span=None,
                    local_only=True,
                )
                results.append(str(now_int))  # window_start
                results.append(increment_value)  # counter
            else:
                # Increment the counter
                current_counter = await self.internal_usage_cache.async_get_cache(
                    key=counter_key,
                    litellm_parent_otel_span=None,
                    local_only=True,
                )
                new_counter_value = (
                    int(current_counter) if current_counter is not None else 0
                ) + increment_value
                await self.internal_usage_cache.async_set_cache(
                    key=counter_key,
                    value=new_counter_value,
                    ttl=window_size,
                    litellm_parent_otel_span=None,
                    local_only=True,
                )
                results.append(window_start)  # window_start
                results.append(new_counter_value)  # counter

        return results

    def create_rate_limit_keys(
        self,
        key: str,
        value: str,
        rate_limit_type: Literal["requests", "tokens", "max_parallel_requests"],
    ) -> str:
        """
        Create the rate limit counter key for the given descriptor key/value pair and limit type.
        """
        counter_key = f"{{{key}:{value}}}:{rate_limit_type}"
        return counter_key
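
    # For example, create_rate_limit_keys("api_key", "sk-123", "requests") returns
    # "{api_key:sk-123}:requests". The braces form a Redis Cluster hash tag, so the
    # window key and all counter keys for the same descriptor hash to the same slot,
    # which lets the batch Lua script operate on them together.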

    def is_cache_list_over_limit(
        self,
        keys_to_fetch: List[str],
        cache_values: List[Any],
        key_metadata: Dict[str, Any],
    ) -> RateLimitResponse:
        """
        Check if the cache values are over the limit.
        """
        statuses: List[RateLimitStatus] = []
        overall_code = "OK"

        for i in range(0, len(cache_values), 2):
            item_code = "OK"
            window_key = keys_to_fetch[i]
            counter_key = keys_to_fetch[i + 1]
            counter_value = cache_values[i + 1]

            requests_limit = key_metadata[window_key]["requests_limit"]
            max_parallel_requests_limit = key_metadata[window_key][
                "max_parallel_requests_limit"
            ]
            tokens_limit = key_metadata[window_key]["tokens_limit"]

            # Determine which limit to use for current_limit and limit_remaining
            current_limit: Optional[int] = None
            rate_limit_type: Optional[
                Literal["requests", "tokens", "max_parallel_requests"]
            ] = None
            if counter_key.endswith(":requests"):
                current_limit = requests_limit
                rate_limit_type = "requests"
            elif counter_key.endswith(":max_parallel_requests"):
                current_limit = max_parallel_requests_limit
                rate_limit_type = "max_parallel_requests"
            elif counter_key.endswith(":tokens"):
                current_limit = tokens_limit
                rate_limit_type = "tokens"

            if current_limit is None or rate_limit_type is None:
                continue

            if counter_value is not None and int(counter_value) + 1 > current_limit:
                overall_code = "OVER_LIMIT"
                item_code = "OVER_LIMIT"

            # Only compute limit_remaining if current_limit is not None
            limit_remaining = (
                current_limit - int(counter_value)
                if counter_value is not None
                else current_limit
            )

            statuses.append(
                {
                    "code": item_code,
                    "current_limit": current_limit,
                    "limit_remaining": limit_remaining,
                    "rate_limit_type": rate_limit_type,
                    "descriptor_key": key_metadata[window_key]["descriptor_key"],
                }
            )

        return RateLimitResponse(overall_code=overall_code, statuses=statuses)
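
    # Note: is_cache_list_over_limit assumes cache_values mirrors keys_to_fetch, i.e.
    # even indices hold window-start values and odd indices hold the counters, which
    # is the layout returned by both the Lua script and in_memory_cache_sliding_window.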

    async def should_rate_limit(
        self,
        descriptors: List[RateLimitDescriptor],
        parent_otel_span: Optional[Span] = None,
        read_only: bool = False,
    ) -> RateLimitResponse:
        """
        Check if any of the rate limit descriptors should be rate limited.
        Returns a RateLimitResponse with the overall code and status for each descriptor.
        Uses batch operations for Redis to improve performance.
        """
        now = datetime.now().timestamp()
        now_int = int(now)  # Convert to integer for Redis Lua script

        # Collect all keys and their metadata upfront
        keys_to_fetch: List[str] = []
        key_metadata = {}  # Store metadata for each key

        for descriptor in descriptors:
            descriptor_key = descriptor["key"]
            descriptor_value = descriptor["value"]
            rate_limit = descriptor.get("rate_limit", {}) or {}
            requests_limit = rate_limit.get("requests_per_unit")
            tokens_limit = rate_limit.get("tokens_per_unit")
            max_parallel_requests_limit = rate_limit.get("max_parallel_requests")
            window_size = rate_limit.get("window_size") or self.window_size

            window_key = f"{{{descriptor_key}:{descriptor_value}}}:window"
            rate_limit_set = False

            if requests_limit is not None:
                rpm_key = self.create_rate_limit_keys(
                    descriptor_key, descriptor_value, "requests"
                )
                keys_to_fetch.extend([window_key, rpm_key])
                rate_limit_set = True

            if tokens_limit is not None:
                tpm_key = self.create_rate_limit_keys(
                    descriptor_key, descriptor_value, "tokens"
                )
                keys_to_fetch.extend([window_key, tpm_key])
                rate_limit_set = True

            if max_parallel_requests_limit is not None:
                max_parallel_requests_key = self.create_rate_limit_keys(
                    descriptor_key, descriptor_value, "max_parallel_requests"
                )
                keys_to_fetch.extend([window_key, max_parallel_requests_key])
                rate_limit_set = True

            if not rate_limit_set:
                continue

            key_metadata[window_key] = {
                "requests_limit": int(requests_limit)
                if requests_limit is not None
                else None,
                "tokens_limit": int(tokens_limit) if tokens_limit is not None else None,
                "max_parallel_requests_limit": int(max_parallel_requests_limit)
                if max_parallel_requests_limit is not None
                else None,
                "window_size": int(window_size),
                "descriptor_key": descriptor_key,
            }

        ## CHECK IN-MEMORY CACHE
        cache_values = await self.internal_usage_cache.async_batch_get_cache(
            keys=keys_to_fetch,
            parent_otel_span=parent_otel_span,
            local_only=True,
        )
        if cache_values is not None:
            rate_limit_response = self.is_cache_list_over_limit(
                keys_to_fetch, cache_values, key_metadata
            )
            if rate_limit_response["overall_code"] == "OVER_LIMIT":
                return rate_limit_response

        ## IF under limit, check Redis
        if self.batch_rate_limiter_script is not None:
            cache_values = await self.batch_rate_limiter_script(
                keys=keys_to_fetch,
                args=[now_int, self.window_size],  # Use integer timestamp
            )
            # update in-memory cache with new values
            for i in range(0, len(cache_values), 2):
                window_key = keys_to_fetch[i]
                counter_key = keys_to_fetch[i + 1]
                window_value = cache_values[i]
                counter_value = cache_values[i + 1]
                await self.internal_usage_cache.async_set_cache(
                    key=counter_key,
                    value=counter_value,
                    ttl=self.window_size,
                    litellm_parent_otel_span=parent_otel_span,
                    local_only=True,
                )
                await self.internal_usage_cache.async_set_cache(
                    key=window_key,
                    value=window_value,
                    ttl=self.window_size,
                    litellm_parent_otel_span=parent_otel_span,
                    local_only=True,
                )
        else:
            cache_values = await self.in_memory_cache_sliding_window(
                keys=keys_to_fetch,
                now_int=now_int,
                window_size=self.window_size,
            )

        rate_limit_response = self.is_cache_list_over_limit(
            keys_to_fetch, cache_values, key_metadata
        )
        return rate_limit_response
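
    # Illustrative result: for a single descriptor with requests_per_unit=100, the
    # first request in a fresh window yields
    #   {"overall_code": "OK",
    #    "statuses": [{"code": "OK", "current_limit": 100, "limit_remaining": 99,
    #                  "rate_limit_type": "requests", "descriptor_key": "api_key"}]}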

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        """
        Pre-call hook to check rate limits before making the API call.
        """
        from litellm.proxy.auth.auth_utils import (
            get_key_model_rpm_limit,
            get_key_model_tpm_limit,
        )

        verbose_proxy_logger.debug("Inside Rate Limit Pre-Call Hook")

        # Create rate limit descriptors
        descriptors = []

        # API Key rate limits
        if user_api_key_dict.api_key:
            descriptors.append(
                RateLimitDescriptor(
                    key="api_key",
                    value=user_api_key_dict.api_key,
                    rate_limit={
                        "requests_per_unit": user_api_key_dict.rpm_limit,
                        "tokens_per_unit": user_api_key_dict.tpm_limit,
                        "max_parallel_requests": user_api_key_dict.max_parallel_requests,
                        "window_size": self.window_size,  # 1 minute window by default
                    },
                )
            )

        # User rate limits
        if user_api_key_dict.user_id:
            descriptors.append(
                RateLimitDescriptor(
                    key="user",
                    value=user_api_key_dict.user_id,
                    rate_limit={
                        "requests_per_unit": user_api_key_dict.user_rpm_limit,
                        "tokens_per_unit": user_api_key_dict.user_tpm_limit,
                        "window_size": self.window_size,
                    },
                )
            )

        # Team rate limits
        if user_api_key_dict.team_id:
            descriptors.append(
                RateLimitDescriptor(
                    key="team",
                    value=user_api_key_dict.team_id,
                    rate_limit={
                        "requests_per_unit": user_api_key_dict.team_rpm_limit,
                        "tokens_per_unit": user_api_key_dict.team_tpm_limit,
                        "window_size": self.window_size,
                    },
                )
            )

        # End user rate limits
        if user_api_key_dict.end_user_id:
            descriptors.append(
                RateLimitDescriptor(
                    key="end_user",
                    value=user_api_key_dict.end_user_id,
                    rate_limit={
                        "requests_per_unit": user_api_key_dict.end_user_rpm_limit,
                        "tokens_per_unit": user_api_key_dict.end_user_tpm_limit,
                        "window_size": self.window_size,
                    },
                )
            )

        # Model rate limits
        requested_model = data.get("model", None)
        if requested_model and (
            get_key_model_tpm_limit(user_api_key_dict) is not None
            or get_key_model_rpm_limit(user_api_key_dict) is not None
        ):
            _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict) or {}
            _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict) or {}
            should_check_rate_limit = False
            if requested_model in _tpm_limit_for_key_model:
                should_check_rate_limit = True
            elif requested_model in _rpm_limit_for_key_model:
                should_check_rate_limit = True

            if should_check_rate_limit:
                model_specific_tpm_limit: Optional[int] = None
                model_specific_rpm_limit: Optional[int] = None
                if requested_model in _tpm_limit_for_key_model:
                    model_specific_tpm_limit = _tpm_limit_for_key_model[requested_model]
                if requested_model in _rpm_limit_for_key_model:
                    model_specific_rpm_limit = _rpm_limit_for_key_model[requested_model]
                descriptors.append(
                    RateLimitDescriptor(
                        key="model_per_key",
                        value=f"{user_api_key_dict.api_key}:{requested_model}",
                        rate_limit={
                            "requests_per_unit": model_specific_rpm_limit,
                            "tokens_per_unit": model_specific_tpm_limit,
                            "window_size": self.window_size,
                        },
                    )
                )

        # Check rate limits
        response = await self.should_rate_limit(
            descriptors=descriptors,
            parent_otel_span=user_api_key_dict.parent_otel_span,
        )
| if response["overall_code"] == "OVER_LIMIT": | |
| # Find which descriptor hit the limit | |
| for i, status in enumerate(response["statuses"]): | |
| if status["code"] == "OVER_LIMIT": | |
| descriptor = descriptors[i] | |
| raise HTTPException( | |
| status_code=429, | |
| detail=f"Rate limit exceeded for {descriptor['key']}: {descriptor['value']}. Remaining: {status['limit_remaining']}", | |
| headers={ | |
| "retry-after": str(self.window_size) | |
| }, # Retry after 1 minute | |
| ) | |
| else: | |
| # add descriptors to request headers | |
| data["litellm_proxy_rate_limit_response"] = response | |

    def _create_pipeline_operations(
        self,
        key: str,
        value: str,
        rate_limit_type: Literal["requests", "tokens", "max_parallel_requests"],
        total_tokens: int,
    ) -> List["RedisPipelineIncrementOperation"]:
        """
        Create pipeline increment operations for TPM (token) counters.
        """
        from litellm.types.caching import RedisPipelineIncrementOperation

        pipeline_operations: List[RedisPipelineIncrementOperation] = []
        counter_key = self.create_rate_limit_keys(
            key=key,
            value=value,
            rate_limit_type=rate_limit_type,
        )
        pipeline_operations.append(
            RedisPipelineIncrementOperation(
                key=counter_key,
                increment_value=total_tokens,
                ttl=self.window_size,
            )
        )
        return pipeline_operations

    def get_rate_limit_type(self) -> Literal["output", "input", "total"]:
        from litellm.proxy.proxy_server import general_settings

        specified_rate_limit_type = general_settings.get(
            "token_rate_limit_type", "output"
        )
        if not specified_rate_limit_type or specified_rate_limit_type not in [
            "output",
            "input",
            "total",
        ]:
            return "total"  # default to total
        return specified_rate_limit_type
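
    # Illustrative proxy config snippet (general_settings comes from the proxy's
    # config.yaml) to count prompt tokens instead of completion tokens:
    #
    #   general_settings:
    #     token_rate_limit_type: "input"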

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Update TPM usage on successful API calls by incrementing counters using a pipeline.
        """
        from litellm.litellm_core_utils.core_helpers import (
            _get_parent_otel_span_from_kwargs,
        )
        from litellm.proxy.common_utils.callback_utils import (
            get_model_group_from_litellm_kwargs,
        )
        from litellm.types.caching import RedisPipelineIncrementOperation
        from litellm.types.utils import ModelResponse, Usage

        rate_limit_type = self.get_rate_limit_type()
        litellm_parent_otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs(
            kwargs
        )
        try:
            verbose_proxy_logger.debug(
                "INSIDE parallel request limiter ASYNC SUCCESS LOGGING"
            )

            # Get metadata from kwargs
            user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key")
            user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
                "user_api_key_user_id"
            )
            user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
                "user_api_key_team_id"
            )
            user_api_key_end_user_id = kwargs.get("user") or kwargs["litellm_params"][
                "metadata"
            ].get("user_api_key_end_user_id")
            model_group = get_model_group_from_litellm_kwargs(kwargs)

            # Get total tokens from response
            total_tokens = 0
            if isinstance(response_obj, ModelResponse):
                _usage = getattr(response_obj, "usage", None)
                if _usage and isinstance(_usage, Usage):
                    if rate_limit_type == "output":
                        total_tokens = _usage.completion_tokens
                    elif rate_limit_type == "input":
                        total_tokens = _usage.prompt_tokens
                    elif rate_limit_type == "total":
                        total_tokens = _usage.total_tokens

            # Create pipeline operations for TPM increments
            pipeline_operations: List[RedisPipelineIncrementOperation] = []

            # API Key TPM
            if user_api_key:
                # MAX PARALLEL REQUESTS - only supported for API keys; just decrement the counter
                counter_key = self.create_rate_limit_keys(
                    key="api_key",
                    value=user_api_key,
                    rate_limit_type="max_parallel_requests",
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=counter_key,
                        increment_value=-1,
                        ttl=self.window_size,
                    )
                )
                pipeline_operations.extend(
                    self._create_pipeline_operations(
                        key="api_key",
                        value=user_api_key,
                        rate_limit_type="tokens",
                        total_tokens=total_tokens,
                    )
                )

            # User TPM
            if user_api_key_user_id:
                pipeline_operations.extend(
                    self._create_pipeline_operations(
                        key="user",
                        value=user_api_key_user_id,
                        rate_limit_type="tokens",
                        total_tokens=total_tokens,
                    )
                )

            # Team TPM
            if user_api_key_team_id:
                pipeline_operations.extend(
                    self._create_pipeline_operations(
                        key="team",
                        value=user_api_key_team_id,
                        rate_limit_type="tokens",
                        total_tokens=total_tokens,
                    )
                )

            # End User TPM
            if user_api_key_end_user_id:
                pipeline_operations.extend(
                    self._create_pipeline_operations(
                        key="end_user",
                        value=user_api_key_end_user_id,
                        rate_limit_type="tokens",
                        total_tokens=total_tokens,
                    )
                )

            # Model-specific TPM
            if model_group and user_api_key:
                pipeline_operations.extend(
                    self._create_pipeline_operations(
                        key="model_per_key",
                        value=f"{user_api_key}:{model_group}",
                        rate_limit_type="tokens",
                        total_tokens=total_tokens,
                    )
                )

            # Execute all increments in a single pipeline
            if pipeline_operations:
                await self.internal_usage_cache.dual_cache.async_increment_cache_pipeline(
                    increment_list=pipeline_operations,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                )
        except Exception as e:
            verbose_proxy_logger.exception(
                f"Error in rate limit success event: {str(e)}"
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """
        Decrement the max parallel requests counter for the API key.
        """
        from litellm.litellm_core_utils.core_helpers import (
            _get_parent_otel_span_from_kwargs,
        )
        from litellm.types.caching import RedisPipelineIncrementOperation

        try:
            litellm_parent_otel_span: Union[
                Span, None
            ] = _get_parent_otel_span_from_kwargs(kwargs)
            user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key")
            pipeline_operations: List[RedisPipelineIncrementOperation] = []

            if user_api_key:
                # MAX PARALLEL REQUESTS - only supported for API keys; just decrement the counter
                counter_key = self.create_rate_limit_keys(
                    key="api_key",
                    value=user_api_key,
                    rate_limit_type="max_parallel_requests",
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=counter_key,
                        increment_value=-1,
                        ttl=self.window_size,
                    )
                )

            # Execute all decrements in a single pipeline
            if pipeline_operations:
                await self.internal_usage_cache.dual_cache.async_increment_cache_pipeline(
                    increment_list=pipeline_operations,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                )
        except Exception as e:
            verbose_proxy_logger.exception(
                f"Error in rate limit failure event: {str(e)}"
            )

    async def async_post_call_success_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
    ):
        """
        Post-call hook to update rate limit headers in the response.
        """
        try:
            from pydantic import BaseModel

            litellm_proxy_rate_limit_response = cast(
                Optional[RateLimitResponse],
                data.get("litellm_proxy_rate_limit_response", None),
            )
            if litellm_proxy_rate_limit_response is not None:
                # Update response headers
                if hasattr(response, "_hidden_params"):
                    _hidden_params = getattr(response, "_hidden_params")
                else:
                    _hidden_params = None
                if _hidden_params is not None and (
                    isinstance(_hidden_params, BaseModel)
                    or isinstance(_hidden_params, dict)
                ):
                    if isinstance(_hidden_params, BaseModel):
                        _hidden_params = _hidden_params.model_dump()

                    _additional_headers = (
                        _hidden_params.get("additional_headers", {}) or {}
                    )

                    # Add rate limit headers
                    for status in litellm_proxy_rate_limit_response["statuses"]:
                        prefix = f"x-ratelimit-{status['descriptor_key']}"
                        _additional_headers[
                            f"{prefix}-remaining-{status['rate_limit_type']}"
                        ] = status["limit_remaining"]
                        _additional_headers[
                            f"{prefix}-limit-{status['rate_limit_type']}"
                        ] = status["current_limit"]

                    setattr(
                        response,
                        "_hidden_params",
                        {**_hidden_params, "additional_headers": _additional_headers},
                    )
        except Exception as e:
            verbose_proxy_logger.exception(
                f"Error in rate limit post-call hook: {str(e)}"
            )
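

# Example wiring (illustrative sketch; the exact setup inside the proxy may differ):
# the handler is constructed with the proxy's InternalUsageCache and registered as a
# CustomLogger callback, so async_pre_call_hook runs before each request and the
# success/failure events decrement parallel-request counters and add token usage.
#
#   from litellm import DualCache
#   from litellm.proxy.utils import InternalUsageCache
#
#   internal_usage_cache = InternalUsageCache(dual_cache=DualCache())
#   rate_limiter = _PROXY_MaxParallelRequestsHandler_v3(
#       internal_usage_cache=internal_usage_cache
#   )
#
# With an "api_key" descriptor limited to 100 requests per window, the post-call
# hook above would surface headers such as
#   x-ratelimit-api_key-remaining-requests: 99
#   x-ratelimit-api_key-limit-requests: 100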