| """ | |
| Ollama /chat/completion calls handled in llm_http_handler.py | |
| [TODO]: migrate embeddings to a base handler as well. | |
| """ | |

import asyncio
from typing import Any, Dict, List

import litellm
from litellm.types.utils import EmbeddingResponse


# Call Ollama's /api/embed endpoint and shape the result into an
# OpenAI-compatible EmbeddingResponse.
async def ollama_aembeddings(
    api_base: str,
    model: str,
    prompts: List[str],
    model_response: EmbeddingResponse,
    optional_params: dict,
    logging_obj: Any,
    encoding: Any,
):
    # Normalize the base URL so callers may pass either the bare host or the
    # full /api/embed endpoint.
    if api_base.endswith("/api/embed"):
        url = api_base
    else:
        url = f"{api_base}/api/embed"

    ## Load Config
    # Caller-supplied params take precedence over config defaults, e.g.
    # completion(top_k=3) overrides OllamaConfig(top_k=3).
    config = litellm.OllamaConfig.get_config()
    for k, v in config.items():
        if k not in optional_params:
            optional_params[k] = v

    data: Dict[str, Any] = {"model": model, "input": prompts}
    special_optional_params = ["truncate", "options", "keep_alive"]
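
    # /api/embed accepts "truncate", "options", and "keep_alive" as top-level
    # fields; any other optional param is treated as a model parameter and
    # nested under "options" (e.g. temperature, top_k).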
    for k, v in optional_params.items():
        if k in special_optional_params:
            data[k] = v
        else:
            # Ensure "options" is a dictionary before updating it
            data.setdefault("options", {})
            if isinstance(data["options"], dict):
                data["options"].update({k: v})

    total_input_tokens = 0
    output_data = []

    response = await litellm.module_level_aclient.post(url=url, json=data)
    response_json = response.json()

    # Reshape Ollama's raw vectors into OpenAI-compatible embedding objects.
    embeddings: List[List[float]] = response_json["embeddings"]
    for idx, emb in enumerate(embeddings):
        output_data.append({"object": "embedding", "index": idx, "embedding": emb})

    # Prefer the server-reported token count; otherwise count locally.
    # `encoding` may be None when called via the sync wrapper, so guard it.
    input_tokens = response_json.get("prompt_eval_count")
    if input_tokens is None:
        input_tokens = len(encoding.encode("".join(prompts))) if encoding else 0
    total_input_tokens += input_tokens

    model_response.object = "list"
    model_response.data = output_data
    model_response.model = "ollama/" + model
    setattr(
        model_response,
        "usage",
        litellm.Usage(
            prompt_tokens=total_input_tokens,
            completion_tokens=total_input_tokens,
            total_tokens=total_input_tokens,
            prompt_tokens_details=None,
            completion_tokens_details=None,
        ),
    )
    return model_response


def ollama_embeddings(
    api_base: str,
    model: str,
    prompts: List[str],
    optional_params: dict,
    model_response: EmbeddingResponse,
    logging_obj: Any,
    encoding: Any = None,
):
    """Synchronous wrapper around `ollama_aembeddings`."""
    # NOTE: asyncio.run() starts a fresh event loop, so this wrapper cannot be
    # called from inside an already-running loop; use ollama_aembeddings there.
    return asyncio.run(
        ollama_aembeddings(
            api_base=api_base,
            model=model,
            prompts=prompts,
            model_response=model_response,
            optional_params=optional_params,
            logging_obj=logging_obj,
            encoding=encoding,
        )
    )
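

# --- Usage sketch -------------------------------------------------------------
# A minimal, illustrative example of calling the sync wrapper. The server
# address and model name are assumptions (any reachable Ollama host and
# embedding model will do); EmbeddingResponse() is just an empty response
# object for the handler to populate.
#
#   resp = ollama_embeddings(
#       api_base="http://localhost:11434",
#       model="nomic-embed-text",
#       prompts=["hello world"],
#       optional_params={},
#       model_response=EmbeddingResponse(),
#       logging_obj=None,
#   )
#   print(resp.model, len(resp.data[0]["embedding"]))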