sairampillai committed
Commit 038ba10 · 1 Parent(s): 0ad77a1

Clean up llm_helper.py

Files changed (1)
  1. helpers/llm_helper.py +15 -23
helpers/llm_helper.py CHANGED
@@ -121,10 +121,6 @@ def is_valid_llm_provider_model(
 def get_litellm_model_name(provider: str, model: str) -> str:
     """
     Convert provider and model to LiteLLM model name format.
-
-    :param provider: The LLM provider.
-    :param model: The model name.
-    :return: LiteLLM formatted model name.
     """
     provider_prefix_map = {
         GlobalConfig.PROVIDER_HUGGING_FACE: "huggingface",
@@ -144,12 +140,18 @@ def get_litellm_model_name(provider: str, model: str) -> str:
 def get_litellm_api_key(provider: str, api_key: str) -> str:
     """
     Get the appropriate API key for LiteLLM based on provider.
-
-    :param provider: The LLM provider.
-    :param api_key: The API key.
-    :return: The API key.
     """
-    # All current providers just return the api_key, but this is left for future extensibility.
+    # All listed providers just return the api_key, so we can use a set for clarity
+    providers_with_api_key = {
+        GlobalConfig.PROVIDER_OPENROUTER,
+        GlobalConfig.PROVIDER_COHERE,
+        GlobalConfig.PROVIDER_TOGETHER_AI,
+        GlobalConfig.PROVIDER_GOOGLE_GEMINI,
+        GlobalConfig.PROVIDER_AZURE_OPENAI,
+        GlobalConfig.PROVIDER_HUGGING_FACE,
+    }
+    if provider in providers_with_api_key:
+        return api_key
     return api_key
 
 
@@ -192,25 +194,15 @@ def stream_litellm_completion(
         "stream": True,
     }
 
-    # Set API key based on provider
+    # Set API key and any provider-specific params
     if provider != GlobalConfig.PROVIDER_OLLAMA:
         api_key_to_use = get_litellm_api_key(provider, api_key)
-
-        if provider == GlobalConfig.PROVIDER_OPENROUTER:
-            request_params["api_key"] = api_key_to_use
-        elif provider == GlobalConfig.PROVIDER_COHERE:
-            request_params["api_key"] = api_key_to_use
-        elif provider == GlobalConfig.PROVIDER_TOGETHER_AI:
-            request_params["api_key"] = api_key_to_use
-        elif provider == GlobalConfig.PROVIDER_GOOGLE_GEMINI:
-            request_params["api_key"] = api_key_to_use
-        elif provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
-            request_params["api_key"] = api_key_to_use
+        request_params["api_key"] = api_key_to_use
+
+        if provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
             request_params["azure_endpoint"] = azure_endpoint_url
             request_params["azure_deployment"] = azure_deployment_name
             request_params["api_version"] = azure_api_version
-        elif provider == GlobalConfig.PROVIDER_HUGGING_FACE:
-            request_params["api_key"] = api_key_to_use
 
     logger.debug('Streaming completion via LiteLLM: %s', litellm_model)
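
For context, here is a minimal, standalone sketch of the control flow that results from this clean-up. It is written against placeholder constants rather than the project's GlobalConfig; the provider strings, the example model name, and the build_request_params helper below are illustrative assumptions, not part of the commit.

# Hypothetical sketch: mirrors the simplified branch in stream_litellm_completion.
# Placeholder constants stand in for GlobalConfig.PROVIDER_*.
PROVIDER_OLLAMA = "ollama"
PROVIDER_AZURE_OPENAI = "azure-openai"


def build_request_params(provider, litellm_model, messages, api_key,
                         azure_endpoint_url=None, azure_deployment_name=None,
                         azure_api_version=None):
    # Base parameters for a streamed LiteLLM completion call.
    request_params = {
        "model": litellm_model,
        "messages": messages,
        "stream": True,
    }
    # The API key is now set once for every non-Ollama provider;
    # only Azure OpenAI needs additional deployment parameters.
    if provider != PROVIDER_OLLAMA:
        request_params["api_key"] = api_key
        if provider == PROVIDER_AZURE_OPENAI:
            request_params["azure_endpoint"] = azure_endpoint_url
            request_params["azure_deployment"] = azure_deployment_name
            request_params["api_version"] = azure_api_version
    return request_params


if __name__ == "__main__":
    params = build_request_params(
        provider="hugging-face",  # any non-Ollama, non-Azure provider
        litellm_model="huggingface/mistralai/Mistral-7B-Instruct-v0.2",  # example name only
        messages=[{"role": "user", "content": "Hello"}],
        api_key="hf_...",  # a real key is needed to actually stream
    )
    print(params)
    # With litellm installed and a valid key, the streamed call would look like:
    # import litellm
    # for chunk in litellm.completion(**params):
    #     print(chunk.choices[0].delta.content or "", end="")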