Spaces:
Paused
Paused
| from datetime import datetime | |
| from typing import Any, Dict, List, Optional, Set, Union | |
| from fastapi import HTTPException, status | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.proxy._types import CommonProxyErrors | |
| from litellm.proxy.utils import PrismaClient | |
| from litellm.types.proxy.management_endpoints.common_daily_activity import ( | |
| BreakdownMetrics, | |
| DailySpendData, | |
| DailySpendMetadata, | |
| KeyMetadata, | |
| KeyMetricWithMetadata, | |
| MetricWithMetadata, | |
| SpendAnalyticsPaginatedResponse, | |
| SpendMetrics, | |
| ) | |
def update_metrics(existing_metrics: SpendMetrics, record: Any) -> SpendMetrics:
    """Accumulate one record's usage counters into ``existing_metrics``.

    Every counter read from ``record`` is added to the attribute of the same
    name on ``existing_metrics``. ``total_tokens`` is the exception: it is not
    read from the record but grows by ``prompt_tokens + completion_tokens``.

    Args:
        existing_metrics: Running totals; updated in place.
        record: Any object exposing the counter attributes read below.

    Returns:
        The same ``existing_metrics`` object that was passed in.
    """
    # Counters are processed in a fixed order, with the derived total in the
    # middle, so a record missing an attribute fails at the same point as a
    # field-by-field update would.
    leading_counters = ("spend", "prompt_tokens", "completion_tokens")
    trailing_counters = (
        "cache_read_input_tokens",
        "cache_creation_input_tokens",
        "api_requests",
        "successful_requests",
        "failed_requests",
    )

    for counter in leading_counters:
        running = getattr(existing_metrics, counter)
        setattr(existing_metrics, counter, running + getattr(record, counter))

    # Derived from the record's token counts rather than copied from it.
    existing_metrics.total_tokens += record.prompt_tokens + record.completion_tokens

    for counter in trailing_counters:
        running = getattr(existing_metrics, counter)
        setattr(existing_metrics, counter, running + getattr(record, counter))

    return existing_metrics
def update_breakdown_metrics(
    breakdown: BreakdownMetrics,
    record: Any,
    model_metadata: Dict[str, Dict[str, Any]],
    provider_metadata: Dict[str, Dict[str, Any]],
    api_key_metadata: Dict[str, Dict[str, Any]],
    entity_id_field: Optional[str] = None,
    entity_metadata_field: Optional[Dict[str, dict]] = None,
) -> BreakdownMetrics:
    """Fold one record into the per-model, per-provider, per-key and per-entity slices.

    For each dimension, an entry is created on first sight (seeded with empty
    ``SpendMetrics`` and whatever metadata the matching lookup dict holds) and
    then updated through ``update_metrics``.

    Args:
        breakdown: Breakdown container to update; modified in place.
        record: Object exposing ``model``, ``custom_llm_provider``, ``api_key``
            and the counters consumed by ``update_metrics``.
        model_metadata: Metadata keyed by model name; missing keys yield ``{}``.
        provider_metadata: Metadata keyed by provider; missing keys yield ``{}``.
        api_key_metadata: Per-key dicts from which ``key_alias`` and ``team_id``
            are read; missing values become ``None``.
        entity_id_field: Name of the record attribute that identifies the
            entity. When falsy, no entity slice is maintained.
        entity_metadata_field: Optional metadata keyed by entity value.

    Returns:
        The same ``breakdown`` object that was passed in.
    """
    # --- model slice ---
    model_name = record.model
    if model_name not in breakdown.models:
        breakdown.models[model_name] = MetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=model_metadata.get(model_name, {}),
        )
    model_entry = breakdown.models[model_name]
    model_entry.metrics = update_metrics(model_entry.metrics, record)

    # --- provider slice (records without a provider are bucketed as "unknown") ---
    provider_name = record.custom_llm_provider or "unknown"
    if provider_name not in breakdown.providers:
        breakdown.providers[provider_name] = MetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=provider_metadata.get(provider_name, {}),
        )
    provider_entry = breakdown.providers[provider_name]
    provider_entry.metrics = update_metrics(provider_entry.metrics, record)

    # --- api key slice ---
    key_id = record.api_key
    if key_id not in breakdown.api_keys:
        key_info = api_key_metadata.get(key_id, {})
        breakdown.api_keys[key_id] = KeyMetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=KeyMetadata(
                key_alias=key_info.get("key_alias", None),
                team_id=key_info.get("team_id", None),
            ),
        )
    key_entry = breakdown.api_keys[key_id]
    key_entry.metrics = update_metrics(key_entry.metrics, record)

    # --- entity slice (only when an entity field is named and set on the record) ---
    if not entity_id_field:
        return breakdown
    entity_key = getattr(record, entity_id_field, None)
    if not entity_key:
        return breakdown

    if entity_key not in breakdown.entities:
        if entity_metadata_field:
            entity_info = entity_metadata_field.get(entity_key, {})
        else:
            entity_info = {}
        breakdown.entities[entity_key] = MetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=entity_info,
        )
    entity_entry = breakdown.entities[entity_key]
    entity_entry.metrics = update_metrics(entity_entry.metrics, record)
    return breakdown
async def get_api_key_metadata(
    prisma_client: PrismaClient,
    api_keys: Set[str],
) -> Dict[str, Dict[str, Any]]:
    """Look up alias and team information for a set of API keys in one query.

    Args:
        prisma_client: Client whose ``db.litellm_verificationtoken`` table is
            queried.
        api_keys: Key tokens to look up. An empty set returns ``{}`` without
            touching the database.

    Returns:
        A dict keyed by token; each value holds that key's ``key_alias`` and
        ``team_id``. Tokens with no matching row are absent from the result.
    """
    # Nothing to look up: skip the round trip instead of sending an empty
    # ``in`` filter.
    if not api_keys:
        return {}

    key_records = await prisma_client.db.litellm_verificationtoken.find_many(
        where={"token": {"in": list(api_keys)}}
    )
    return {
        k.token: {"key_alias": k.key_alias, "team_id": k.team_id} for k in key_records
    }
def _build_daily_activity_where(
    entity_id_field: str,
    entity_id: Optional[Union[str, List[str]]],
    start_date: str,
    end_date: str,
    model: Optional[str],
    api_key: Optional[str],
    exclude_entity_ids: Optional[List[str]],
) -> Dict[str, Any]:
    """Build the ``where`` filter shared by the count and fetch queries."""
    where_conditions: Dict[str, Any] = {
        "date": {
            "gte": start_date,
            "lte": end_date,
        }
    }
    if model:
        where_conditions["model"] = model
    if api_key:
        where_conditions["api_key"] = api_key
    if entity_id is not None:
        if isinstance(entity_id, list):
            where_conditions[entity_id_field] = {"in": entity_id}
        else:
            where_conditions[entity_id_field] = entity_id
    if exclude_entity_ids:
        entity_filter = where_conditions.get(entity_id_field)
        if entity_filter is None:
            entity_filter = {}
        elif not isinstance(entity_filter, dict):
            # A single entity_id is stored as a bare value; wrap it so the
            # exclusion can be merged in rather than failing on item
            # assignment against a str.
            entity_filter = {"equals": entity_filter}
        entity_filter["not"] = {"in": exclude_entity_ids}
        where_conditions[entity_id_field] = entity_filter
    return where_conditions


async def get_daily_activity(
    prisma_client: Optional[PrismaClient],
    table_name: str,
    entity_id_field: str,
    entity_id: Optional[Union[str, List[str]]],
    entity_metadata_field: Optional[Dict[str, dict]],
    start_date: Optional[str],
    end_date: Optional[str],
    model: Optional[str],
    api_key: Optional[str],
    page: int,
    page_size: int,
    exclude_entity_ids: Optional[List[str]] = None,
) -> SpendAnalyticsPaginatedResponse:
    """Return one page of daily spend activity, grouped by date, for any entity type.

    Rows are read from ``prisma_client.db.<table_name>``, filtered by date
    range and the optional model / api key / entity constraints, ordered by
    date descending, and paginated with ``page`` and ``page_size``. The rows on
    the requested page are grouped by their ``date`` value; each group carries
    aggregate metrics plus model / provider / api-key / entity breakdowns.

    Note that the ``total_*`` values in the response metadata are summed over
    the rows fetched for this page only, while ``total_pages`` and ``has_more``
    are derived from the count of all matching rows.

    Args:
        prisma_client: Database client; ``None`` means the DB is not connected.
        table_name: Attribute name of the table on ``prisma_client.db``.
        entity_id_field: Column (and record attribute) identifying the entity.
        entity_id: A single id, a list of ids, or ``None`` for no entity filter.
        entity_metadata_field: Optional metadata keyed by entity id, attached
            to the entity breakdown.
        start_date: Inclusive lower bound for ``date``; required.
        end_date: Inclusive upper bound for ``date``; required.
        model: Optional exact-match filter on ``model``.
        api_key: Optional exact-match filter on ``api_key``.
        page: 1-based page number (``skip`` is ``(page - 1) * page_size``).
        page_size: Number of rows per page.
        exclude_entity_ids: Optional entity ids to exclude.

    Returns:
        A ``SpendAnalyticsPaginatedResponse`` with per-date results (newest
        first) and pagination / totals metadata.

    Raises:
        HTTPException: 500 if ``prisma_client`` is ``None``; 400 if either date
            is missing; 500 wrapping any exception raised while querying or
            aggregating.
    """
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    if start_date is None or end_date is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "Please provide start_date and end_date"},
        )
    try:
        where_conditions = _build_daily_activity_where(
            entity_id_field=entity_id_field,
            entity_id=entity_id,
            start_date=start_date,
            end_date=end_date,
            model=model,
            api_key=api_key,
            exclude_entity_ids=exclude_entity_ids,
        )
        table = getattr(prisma_client.db, table_name)

        # Total number of matching rows, used only for pagination metadata.
        total_count = await table.count(where=where_conditions)

        # Fetch the requested page, newest dates first.
        daily_spend_data = await table.find_many(
            where=where_conditions,
            order=[
                {"date": "desc"},
            ],
            skip=(page - 1) * page_size,
            take=page_size,
        )

        # Resolve key aliases / team ids for every key on this page in bulk.
        api_keys = {record.api_key for record in daily_spend_data if record.api_key}
        api_key_metadata: Dict[str, Dict[str, Any]] = {}
        model_metadata: Dict[str, Dict[str, Any]] = {}
        provider_metadata: Dict[str, Dict[str, Any]] = {}
        if api_keys:
            api_key_metadata = await get_api_key_metadata(prisma_client, api_keys)

        # Group the page's rows by date and accumulate metrics.
        total_metrics = SpendMetrics()
        grouped_data: Dict[str, Dict[str, Any]] = {}
        for record in daily_spend_data:
            date_str = record.date
            if date_str not in grouped_data:
                grouped_data[date_str] = {
                    "metrics": SpendMetrics(),
                    "breakdown": BreakdownMetrics(),
                }
            day = grouped_data[date_str]
            day["metrics"] = update_metrics(day["metrics"], record)
            day["breakdown"] = update_breakdown_metrics(
                day["breakdown"],
                record,
                model_metadata,
                provider_metadata,
                api_key_metadata,
                entity_id_field=entity_id_field,
                entity_metadata_field=entity_metadata_field,
            )
            # Page-wide totals use the same accumulation rule as per-day metrics.
            update_metrics(total_metrics, record)

        # Convert grouped data to the response format; dates are parsed from
        # their "%Y-%m-%d" string form.
        results = [
            DailySpendData(
                date=datetime.strptime(date_str, "%Y-%m-%d").date(),
                metrics=data["metrics"],
                breakdown=data["breakdown"],
            )
            for date_str, data in grouped_data.items()
        ]
        results.sort(key=lambda x: x.date, reverse=True)

        return SpendAnalyticsPaginatedResponse(
            results=results,
            metadata=DailySpendMetadata(
                total_spend=total_metrics.spend,
                total_prompt_tokens=total_metrics.prompt_tokens,
                total_completion_tokens=total_metrics.completion_tokens,
                total_tokens=total_metrics.total_tokens,
                total_api_requests=total_metrics.api_requests,
                total_successful_requests=total_metrics.successful_requests,
                total_failed_requests=total_metrics.failed_requests,
                total_cache_read_input_tokens=total_metrics.cache_read_input_tokens,
                total_cache_creation_input_tokens=total_metrics.cache_creation_input_tokens,
                page=page,
                total_pages=-(-total_count // page_size),  # Ceiling division
                has_more=(page * page_size) < total_count,
            ),
        )
    except Exception as e:
        verbose_proxy_logger.exception(f"Error fetching daily activity: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Failed to fetch analytics: {str(e)}"},
        ) from e