| """ | |
| Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions` | |
| """ | |

from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
    cast,
)

import httpx
from pydantic import BaseModel

from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
    _handle_invalid_parallel_tool_calls,
    _should_convert_tool_call_to_json_mode,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    handle_messages_with_content_list_to_str_conversion,
    strip_name_from_messages,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.anthropic import AllAnthropicToolsValues
from litellm.types.llms.databricks import (
    AllDatabricksContentValues,
    DatabricksChoice,
    DatabricksFunction,
    DatabricksResponse,
    DatabricksTool,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionRedactedThinkingBlock,
    ChatCompletionThinkingBlock,
    ChatCompletionToolChoiceFunctionParam,
    ChatCompletionToolChoiceObjectParam,
)
from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    Choices,
    Message,
    ModelResponse,
    ModelResponseStream,
    ProviderField,
    Usage,
)

from ...anthropic.chat.transformation import AnthropicConfig
from ...openai_like.chat.transformation import OpenAILikeChatConfig
from ..common_utils import DatabricksBase, DatabricksException

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
    """
    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None  # float in [0, 2], not int
    top_p: Optional[float] = None  # float in [0, 1], not int
    top_k: Optional[int] = None
    stop: Optional[Union[List[str], str]] = None
    n: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        stop: Optional[Union[List[str], str]] = None,
        n: Optional[int] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description"""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Databricks API Key.",
                field_value="dapi...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Databricks API Base.",
                field_value="https://adb-..",
            ),
        ]

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        api_base, headers = self.databricks_validate_environment(
            api_base=api_base,
            api_key=api_key,
            endpoint_type="chat_completions",
            custom_endpoint=False,
            headers=headers,
        )
        # Ensure Content-Type header is set
        headers["Content-Type"] = "application/json"
        return headers
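
    # Sketch of the expected result (assuming `databricks_validate_environment`
    # injects the bearer token, as its name suggests):
    #
    #   headers == {
    #       "Authorization": "Bearer dapi...",
    #       "Content-Type": "application/json",
    #   }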

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        api_base = self._get_api_base(api_base)
        complete_url = f"{api_base}/chat/completions"
        return complete_url
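
    # Example (illustrative workspace URL):
    #
    #   api_base = "https://adb-1234567890123456.7.azuredatabricks.net/serving-endpoints"
    #   get_complete_url(...)  # -> ".../serving-endpoints/chat/completions"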

    def get_supported_openai_params(self, model: Optional[str] = None) -> list:
        return [
            "stream",
            "stop",
            "temperature",
            "top_p",
            "max_tokens",
            "max_completion_tokens",
            "n",
            "response_format",
            "tools",
            "tool_choice",
            "reasoning_effort",
            "thinking",
        ]

    def convert_anthropic_tool_to_databricks_tool(
        self, tool: Optional[AllAnthropicToolsValues]
    ) -> Optional[DatabricksTool]:
        if tool is None:
            return None

        return DatabricksTool(
            type="function",
            function=DatabricksFunction(
                name=tool["name"],
                parameters=cast(dict, tool.get("input_schema") or {}),
            ),
        )
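
    # Illustrative mapping (tool shapes follow the Anthropic/OpenAI conventions):
    #
    #   {"name": "get_weather", "input_schema": {"type": "object", ...}}
    # becomes
    #   {"type": "function",
    #    "function": {"name": "get_weather", "parameters": {"type": "object", ...}}}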

    def _map_openai_to_dbrx_tool(self, model: str, tools: List) -> List[DatabricksTool]:
        # if not claude, send as is
        if "claude" not in model:
            return tools

        # if claude, convert to anthropic tool and then to databricks tool
        anthropic_tools = self._map_tools(tools=tools)
        databricks_tools = [
            cast(DatabricksTool, self.convert_anthropic_tool_to_databricks_tool(tool))
            for tool in anthropic_tools
        ]
        return databricks_tools

    def map_response_format_to_databricks_tool(
        self,
        model: str,
        value: Optional[dict],
        optional_params: dict,
        is_thinking_enabled: bool,
    ) -> Optional[DatabricksTool]:
        if value is None:
            return None

        tool = self.map_response_format_to_anthropic_tool(
            value, optional_params, is_thinking_enabled
        )

        databricks_tool = self.convert_anthropic_tool_to_databricks_tool(tool)
        return databricks_tool

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
        replace_max_completion_tokens_with_max_tokens: bool = True,
    ) -> dict:
        is_thinking_enabled = self.is_thinking_enabled(non_default_params)
        mapped_params = super().map_openai_params(
            non_default_params, optional_params, model, drop_params
        )
        if "tools" in mapped_params:
            mapped_params["tools"] = self._map_openai_to_dbrx_tool(
                model=model, tools=mapped_params["tools"]
            )

        if (
            "max_completion_tokens" in non_default_params
            and replace_max_completion_tokens_with_max_tokens
        ):
            mapped_params["max_tokens"] = non_default_params[
                "max_completion_tokens"
            ]  # most openai-compatible providers support 'max_tokens' not 'max_completion_tokens'
            mapped_params.pop("max_completion_tokens", None)

        if "response_format" in non_default_params and "claude" in model:
            _tool = self.map_response_format_to_databricks_tool(
                model,
                non_default_params["response_format"],
                mapped_params,
                is_thinking_enabled,
            )

            if _tool is not None:
                self._add_tools_to_optional_params(
                    optional_params=optional_params, tools=[_tool]
                )
                optional_params["json_mode"] = True
                if not is_thinking_enabled:
                    _tool_choice = ChatCompletionToolChoiceObjectParam(
                        type="function",
                        function=ChatCompletionToolChoiceFunctionParam(
                            name=RESPONSE_FORMAT_TOOL_NAME
                        ),
                    )
                    optional_params["tool_choice"] = _tool_choice
            optional_params.pop(
                "response_format", None
            )  # unsupported for claude models - if json_schema -> convert to tool call

        if "reasoning_effort" in non_default_params and "claude" in model:
            optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
                non_default_params.get("reasoning_effort")
            )
            optional_params.pop("reasoning_effort", None)

        ## handle thinking tokens
        self.update_optional_params_with_thinking_tokens(
            non_default_params=non_default_params, optional_params=mapped_params
        )

        return mapped_params
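
    # Illustrative before/after for a claude model on Databricks (shapes
    # abbreviated for the sketch):
    #
    #   non_default_params = {"max_completion_tokens": 256,
    #                         "response_format": {"type": "json_object"}}
    # maps to roughly
    #   {"max_tokens": 256, "tools": [<response-format tool>],
    #    "tool_choice": {"type": "function",
    #                    "function": {"name": RESPONSE_FORMAT_TOOL_NAME}}}
    # with `json_mode=True`, so the forced tool call is later unwrapped back
    # into plain JSON content by `_transform_dbrx_choices` below.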

    def _should_fake_stream(self, optional_params: dict) -> bool:
        """
        Databricks doesn't support 'response_format' while streaming
        """
        if optional_params.get("response_format") is not None:
            return True

        return False

    def _transform_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> List[AllMessageValues]:
        """
        Databricks does not support:
        - content in list format.
        - 'name' in user message.
        """
        new_messages = []
        for message in messages:
            if isinstance(message, BaseModel):
                _message = message.model_dump(exclude_none=True)
            else:
                _message = message
            new_messages.append(_message)

        new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
        new_messages = strip_name_from_messages(new_messages)
        return super()._transform_messages(messages=new_messages, model=model)
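
    # Illustrative effect on a single message (behavior of the two helpers
    # assumed from their names):
    #
    #   {"role": "user", "name": "alice",
    #    "content": [{"type": "text", "text": "Hi"}]}
    # becomes
    #   {"role": "user", "content": "Hi"}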

    @staticmethod
    def extract_content_str(
        content: Optional[AllDatabricksContentValues],
    ) -> Optional[str]:
        if content is None:
            return None
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            content_str = ""
            for item in content:
                if item["type"] == "text":
                    content_str += item["text"]
            return content_str
        else:
            raise Exception(f"Unsupported content type: {type(content)}")

    @staticmethod
    def extract_reasoning_content(
        content: Optional[AllDatabricksContentValues],
    ) -> Tuple[
        Optional[str],
        Optional[
            List[Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]]
        ],
    ]:
        """
        Extract and return the reasoning content and thinking blocks
        """
        if content is None:
            return None, None
        thinking_blocks: Optional[
            List[Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]]
        ] = None
        reasoning_content: Optional[str] = None
        if isinstance(content, list):
            for item in content:
                if item["type"] == "reasoning":
                    # `summary_item` avoids shadowing the builtin `sum`
                    for summary_item in item["summary"]:
                        if reasoning_content is None:
                            reasoning_content = ""
                        reasoning_content += summary_item["text"]
                        thinking_block = ChatCompletionThinkingBlock(
                            type="thinking",
                            thinking=summary_item["text"],
                            signature=summary_item["signature"],
                        )
                        if thinking_blocks is None:
                            thinking_blocks = []
                        thinking_blocks.append(thinking_block)
        return reasoning_content, thinking_blocks
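
    # Illustrative input/output (the shape of a Databricks reasoning item is
    # assumed from the handling above):
    #
    #   content = [{"type": "reasoning",
    #               "summary": [{"text": "step 1...", "signature": "sig"}]}]
    # returns
    #   ("step 1...",
    #    [{"type": "thinking", "thinking": "step 1...", "signature": "sig"}])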

    def _transform_dbrx_choices(
        self, choices: List[DatabricksChoice], json_mode: Optional[bool] = None
    ) -> List[Choices]:
        transformed_choices = []

        for choice in choices:
            ## HANDLE JSON MODE - anthropic returns single function call
            tool_calls = choice["message"].get("tool_calls", None)
            if tool_calls is not None:
                _openai_tool_calls = []
                for _tc in tool_calls:
                    _openai_tc = ChatCompletionMessageToolCall(**_tc)  # type: ignore
                    _openai_tool_calls.append(_openai_tc)

                fixed_tool_calls = _handle_invalid_parallel_tool_calls(
                    _openai_tool_calls
                )

                if fixed_tool_calls is not None:
                    tool_calls = fixed_tool_calls

            translated_message: Optional[Message] = None
            finish_reason: Optional[str] = None
            if tool_calls and _should_convert_tool_call_to_json_mode(
                tool_calls=tool_calls,
                convert_tool_call_to_json_mode=json_mode,
            ):
                # to support response_format on claude models
                json_mode_content_str: Optional[str] = (
                    str(tool_calls[0]["function"].get("arguments", "")) or None
                )
                if json_mode_content_str is not None:
                    translated_message = Message(content=json_mode_content_str)
                    finish_reason = "stop"

            if translated_message is None:
                ## get the content str
                content_str = DatabricksConfig.extract_content_str(
                    choice["message"]["content"]
                )

                ## get the reasoning content
                (
                    reasoning_content,
                    thinking_blocks,
                ) = DatabricksConfig.extract_reasoning_content(
                    choice["message"].get("content")
                )

                translated_message = Message(
                    role="assistant",
                    content=content_str,
                    reasoning_content=reasoning_content,
                    thinking_blocks=thinking_blocks,
                    tool_calls=choice["message"].get("tool_calls"),
                )

            if finish_reason is None:
                finish_reason = choice["finish_reason"]

            translated_choice = Choices(
                finish_reason=finish_reason,
                index=choice["index"],
                message=translated_message,
                logprobs=None,
                enhancements=None,
            )
            transformed_choices.append(translated_choice)

        return transformed_choices
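
    # Illustrative json_mode unwrap (shapes abbreviated):
    #
    #   {"message": {"tool_calls": [{"function": {"name": RESPONSE_FORMAT_TOOL_NAME,
    #                                             "arguments": '{"answer": 42}'}}]},
    #    "finish_reason": "tool_calls", "index": 0}
    # becomes a Choices(...) whose message content is '{"answer": 42}' and
    # whose finish_reason is "stop".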

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )

        ## RESPONSE OBJECT
        try:
            completion_response = DatabricksResponse(**raw_response.json())  # type: ignore
        except Exception as e:
            response_headers = getattr(raw_response, "headers", None)
            raise DatabricksException(
                message="Unable to get json response - {}, Original Response: {}".format(
                    str(e), raw_response.text
                ),
                status_code=raw_response.status_code,
                headers=response_headers,
            )

        model_response.model = completion_response["model"]
        model_response.id = completion_response["id"]
        model_response.created = completion_response["created"]
        setattr(model_response, "usage", Usage(**completion_response["usage"]))

        model_response.choices = self._transform_dbrx_choices(  # type: ignore
            choices=completion_response["choices"],
            json_mode=json_mode,
        )

        return model_response

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return DatabricksChatResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )


class DatabricksChatResponseIterator(BaseModelResponseIterator):
    def __init__(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        super().__init__(streaming_response, sync_stream)
        self.json_mode = json_mode
        self._last_function_name = None  # Track the last seen function name

    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        try:
            translated_choices = []
            for choice in chunk["choices"]:
                tool_calls = choice["delta"].get("tool_calls")
                if tool_calls and self.json_mode:
                    # 1. Check if the function name is set and == RESPONSE_FORMAT_TOOL_NAME
                    # 2. If no function name, just args -> check last function name (saved via state variable)
                    # 3. Convert args to json
                    # 4. Convert json to message
                    # 5. Set content to message.content
                    # 6. Set tool_calls to None
                    from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
                    from litellm.llms.base_llm.base_utils import (
                        _convert_tool_response_to_message,
                    )

                    # Check if this chunk has a function name
                    function_name = tool_calls[0].get("function", {}).get("name")
                    if function_name is not None:
                        self._last_function_name = function_name

                    # If we have a saved function name that matches RESPONSE_FORMAT_TOOL_NAME
                    # or this chunk has the matching function name
                    if (
                        self._last_function_name == RESPONSE_FORMAT_TOOL_NAME
                        or function_name == RESPONSE_FORMAT_TOOL_NAME
                    ):
                        # Convert tool calls to message format
                        message = _convert_tool_response_to_message(tool_calls)
                        if message is not None:
                            if message.content == "{}":  # empty json
                                message.content = ""
                            choice["delta"]["content"] = message.content
                            choice["delta"]["tool_calls"] = None
                elif tool_calls:
                    for _tc in tool_calls:
                        if _tc.get("function", {}).get("arguments") == "{}":
                            _tc["function"]["arguments"] = ""  # avoid invalid json

                # extract the content str
                content_str = DatabricksConfig.extract_content_str(
                    choice["delta"].get("content")
                )

                # extract the reasoning content
                (
                    reasoning_content,
                    thinking_blocks,
                ) = DatabricksConfig.extract_reasoning_content(
                    choice["delta"]["content"]
                )

                choice["delta"]["content"] = content_str
                choice["delta"]["reasoning_content"] = reasoning_content
                choice["delta"]["thinking_blocks"] = thinking_blocks
                translated_choices.append(choice)
            return ModelResponseStream(
                id=chunk["id"],
                object="chat.completion.chunk",
                created=chunk["created"],
                model=chunk["model"],
                choices=translated_choices,
            )
        except KeyError as e:
            raise DatabricksException(
                message=f"KeyError: {e}, Got unexpected response from Databricks: {chunk}",
                status_code=400,
            )
        except Exception as e:
            raise e
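

# Illustrative streaming flow in json_mode (chunk shapes abbreviated and values
# assumed for the sketch): the first chunk names the response-format tool and
# later chunks stream only arguments; the parser stitches them back into plain
# content deltas, tracking the tool name across chunks via _last_function_name.
#
#   chunk 1 delta: {"tool_calls": [{"function": {"name": RESPONSE_FORMAT_TOOL_NAME,
#                                                "arguments": "{}"}}]}
#   chunk 2 delta: {"tool_calls": [{"function": {"arguments": '{"answer"'}}]}
#   ...
#   parsed deltas: {"content": "", "tool_calls": None},
#                  {"content": '{"answer"', "tool_calls": None}, ...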