Spaces:
Paused
Paused
| """ | |
| async with websockets.connect( # type: ignore | |
| url, | |
| extra_headers={ | |
| "api-key": api_key, # type: ignore | |
| }, | |
| ) as backend_ws: | |
| forward_task = asyncio.create_task( | |
| forward_messages(websocket, backend_ws) | |
| ) | |
| try: | |
| while True: | |
| message = await websocket.receive_text() | |
| await backend_ws.send(message) | |
| except websockets.exceptions.ConnectionClosed: # type: ignore | |
| forward_task.cancel() | |
| finally: | |
| if not forward_task.done(): | |
| forward_task.cancel() | |
| try: | |
| await forward_task | |
| except asyncio.CancelledError: | |
| pass | |
| """ | |
| import asyncio | |
| import concurrent.futures | |
| import json | |
| from typing import Any, Dict, List, Optional, Union | |
| import litellm | |
| from litellm._logging import verbose_logger | |
| from litellm.types.llms.openai import ( | |
| OpenAIRealtimeStreamResponseBaseObject, | |
| OpenAIRealtimeStreamSessionEvents, | |
| ) | |
| from .litellm_logging import Logging as LiteLLMLogging | |
# Module-wide worker pool, capped at 10 threads; RealTimeStreaming.log_messages
# submits the synchronous success logging to it.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)

# Event types kept for logging when `litellm.logged_real_time_event_types`
# is not configured (i.e. is None).
DefaultLoggedRealTimeEventTypes = ["session.created", "response.create", "response.done"]
class RealTimeStreaming:
    """Relay realtime websocket traffic between a client and a backend, with logging.

    Messages read from ``backend_ws`` are forwarded to ``websocket`` and, when
    their ``type`` is selected by ``logged_real_time_event_types``, collected in
    ``self.messages``. Messages read from ``websocket`` are recorded as the
    input and forwarded to ``backend_ws``. When the backend-to-client loop ends,
    the collected messages are handed to the logging object's success handlers.
    """

    # Strong references to in-flight async logging tasks. The event loop only
    # keeps weak references to tasks, so an unreferenced fire-and-forget task
    # may be garbage-collected before it finishes.
    _background_tasks: set = set()

    def __init__(
        self,
        websocket: Any,
        backend_ws: Any,
        logging_obj: Optional[LiteLLMLogging] = None,
    ):
        """Bind the two websocket endpoints and the optional logging object.

        Args:
            websocket: Client-side connection; must provide the awaitables
                ``receive_text()`` and ``send_text(message)``.
            backend_ws: Backend connection; must provide the awaitables
                ``recv()`` and ``send(message)``.
            logging_obj: Logging object whose ``pre_call``, ``success_handler``
                and ``async_success_handler`` are invoked; ``None`` disables
                logging calls.
        """
        self.websocket = websocket
        self.backend_ws = backend_ws
        self.logging_obj = logging_obj
        self.messages: List[
            Union[
                OpenAIRealtimeStreamResponseBaseObject,
                OpenAIRealtimeStreamSessionEvents,
            ]
        ] = []
        # Holds the most recent client message passed to store_input().
        self.input_message: Union[str, Dict] = {}

        # Fall back to the module default when nothing is configured on litellm.
        configured_event_types = litellm.logged_real_time_event_types
        if configured_event_types is None:
            configured_event_types = DefaultLoggedRealTimeEventTypes
        self.logged_real_time_event_types = configured_event_types

    def _should_store_message(
        self,
        message_obj: Union[
            dict,
            OpenAIRealtimeStreamSessionEvents,
            OpenAIRealtimeStreamResponseBaseObject,
        ],
    ) -> bool:
        """Return True when the message's ``type`` is selected for logging.

        ``"*"`` selects every type; otherwise the type must be contained in
        ``self.logged_real_time_event_types``.

        Raises:
            KeyError: If ``message_obj`` has no ``"type"`` key.
        """
        # Looked up first so a missing "type" raises even in wildcard mode.
        event_type = message_obj["type"]
        if self.logged_real_time_event_types == "*":
            return True
        return event_type in self.logged_real_time_event_types

    def store_message(self, message: Union[str, bytes]):
        """Parse a backend message and keep it if its type is selected for logging.

        Args:
            message: JSON text, or UTF-8 encoded JSON bytes.

        Raises:
            Exception: Whatever JSON decoding raises, or whatever building the
                typed event object raises (the latter is logged at debug level
                and then re-raised).
        """
        if isinstance(message, bytes):
            message = message.decode("utf-8")
        message_obj = json.loads(message)
        try:
            if message_obj.get("type") in ("session.created", "session.updated"):
                message_obj = OpenAIRealtimeStreamSessionEvents(**message_obj)  # type: ignore
            else:
                message_obj = OpenAIRealtimeStreamResponseBaseObject(**message_obj)  # type: ignore
        except Exception as e:
            verbose_logger.debug(f"Error parsing message for logging: {e}")
            raise
        if self._should_store_message(message_obj):
            self.messages.append(message_obj)

    def store_input(self, message: Union[str, dict]):
        """Record the latest client message and notify the logging object.

        Args:
            message: The client message. ``client_ack_messages`` passes the raw
                value returned by ``receive_text()``, so this is not
                necessarily a dict.
        """
        self.input_message = message
        if self.logging_obj:
            self.logging_obj.pre_call(input=message, api_key="")

    async def log_messages(self):
        """Hand the collected messages to the async and sync success handlers.

        Does nothing when no logging object is configured. Neither handler is
        awaited here; both are scheduled to run in the background.
        """
        if not self.logging_obj:
            return

        ## ASYNC LOGGING
        logging_task = asyncio.create_task(
            self.logging_obj.async_success_handler(self.messages)
        )
        # Hold a reference until the task completes, then drop it.
        self._background_tasks.add(logging_task)
        logging_task.add_done_callback(self._background_tasks.discard)

        ## SYNC LOGGING
        # Pass the callable and its argument separately so the handler runs on
        # a worker thread. Calling it inline would block the event loop and
        # submit the handler's return value instead of the handler.
        executor.submit(self.logging_obj.success_handler, self.messages)

    async def backend_to_client_send_messages(self):
        """Forward backend messages to the client until the stream ends.

        Each message is sent to the client first and then passed to
        ``store_message``. Whatever ends the loop, ``log_messages`` runs
        afterwards.
        """
        import websockets

        try:
            while True:
                message = await self.backend_ws.recv()
                await self.websocket.send_text(message)
                ## LOGGING
                self.store_message(message)
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            # The backend connection closed; fall through to logging.
            pass
        except Exception:
            # Still best-effort (not re-raised), but leave a trace instead of
            # discarding the failure silently.
            verbose_logger.exception(
                "Error forwarding realtime messages from backend to client"
            )
        finally:
            await self.log_messages()

    async def client_ack_messages(self):
        """Forward client messages to the backend until the connection closes.

        Each message is recorded through ``store_input`` before being sent. A
        ``websockets`` ``ConnectionClosed`` ends the loop quietly; any other
        exception propagates to the caller.
        """
        import websockets

        try:
            while True:
                message = await self.websocket.receive_text()
                ## LOGGING
                self.store_input(message=message)
                ## FORWARD TO BACKEND
                await self.backend_ws.send(message)
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            pass

    async def bidirectional_forward(self):
        """Run both relay directions until the client side finishes.

        The backend-to-client relay runs as a background task while this
        coroutine drives the client-to-backend relay. When the latter returns
        or raises, the background task is cancelled if still running and then
        awaited so its cleanup completes.
        """
        import websockets

        forward_task = asyncio.create_task(self.backend_to_client_send_messages())
        try:
            await self.client_ack_messages()
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            forward_task.cancel()
        finally:
            if not forward_task.done():
                forward_task.cancel()
                try:
                    await forward_task
                except asyncio.CancelledError:
                    pass