sairampillai committed
Commit 80f53c9 · unverified · 1 parent: cb43301

Fix review comments and change to lazy logging

Files changed (2):
  1. app.py +3 -4
  2. helpers/llm_helper.py +15 -8
app.py CHANGED
@@ -223,9 +223,8 @@ with st.sidebar:
 
     # Validate that the selected provider is valid
     if selected_provider not in GlobalConfig.VALID_PROVIDERS:
-        logger.error(f"Invalid provider: {selected_provider}")
-        handle_error(f"Invalid provider selected: {selected_provider}", True)
-        st.error(f"Invalid provider selected: {selected_provider}")
+        logger.error('Invalid provider: %s', selected_provider)
+        handle_error(f'Invalid provider selected: {selected_provider}', True)
         st.stop()
 
     env_key_name = GlobalConfig.PROVIDER_ENV_KEYS.get(selected_provider)
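Note on the lazy-logging change above: passing the value as a separate argument defers string formatting to the logging module, so nothing is interpolated when the record is filtered out by the log level. A minimal sketch of the difference (selected_provider is given a dummy value purely for illustration):

import logging

logger = logging.getLogger(__name__)
selected_provider = 'unknown-provider'  # dummy value for illustration only

# Eager: the f-string is built before logging even decides to emit the record.
logger.error(f'Invalid provider: {selected_provider}')

# Lazy: '%s' is interpolated only if the record is actually handled.
logger.error('Invalid provider: %s', selected_provider)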
@@ -610,7 +609,7 @@ def generate_slide_deck(json_str: str) -> Union[pathlib.Path, None]:
         )
     except Exception as ex:
         st.error(APP_TEXT['content_generation_error'])
-        logger.error('Caught a generic exception: %s', str(ex))
+        logger.exception('Caught a generic exception: %s', str(ex))
 
     return path
 
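The logger.error to logger.exception switch above makes the handler record the full traceback: logger.exception() logs at ERROR level with exc_info=True, so the stack trace is appended automatically. A minimal, self-contained sketch (the failing operation here is only a stand-in):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    1 / 0  # stand-in for the slide-deck generation call
except Exception as ex:
    # Equivalent to logger.error(..., exc_info=True): message plus traceback.
    logger.exception('Caught a generic exception: %s', str(ex))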
 
helpers/llm_helper.py CHANGED
@@ -124,6 +124,13 @@ def is_valid_llm_provider_model(
 def get_litellm_model_name(provider: str, model: str) -> str:
     """
     Convert provider and model to LiteLLM model name format.
+
+    Note: Azure OpenAI models are handled separately in stream_litellm_completion()
+    and should not be passed to this function.
+
+    :param provider: The LLM provider.
+    :param model: The model name.
+    :return: LiteLLM-compatible model name, or None if provider is not supported.
     """
     provider_prefix_map = {
         GlobalConfig.PROVIDER_HUGGING_FACE: 'huggingface',
@@ -136,8 +143,9 @@ def get_litellm_model_name(provider: str, model: str) -> str:
     }
     prefix = provider_prefix_map.get(provider)
     if prefix:
-        return '%s/%s' % (prefix, model)
-    return model
+        return f'{prefix}/{model}'
+    # LiteLLM always expects a prefix for model names; if not found, return None
+    return None
 
 
 def stream_litellm_completion(
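With these changes, get_litellm_model_name() returns a prefixed LiteLLM identifier such as 'huggingface/<model>' and None for unrecognised providers, so callers must check the result. A rough, self-contained approximation of the new behaviour (the provider keys below are simplified stand-ins for the GlobalConfig constants):

from typing import Optional

PROVIDER_PREFIX_MAP = {'hf': 'huggingface', 'co': 'cohere'}  # stand-in mapping

def litellm_model_name(provider: str, model: str) -> Optional[str]:
    prefix = PROVIDER_PREFIX_MAP.get(provider)
    # LiteLLM expects '<prefix>/<model>'; unknown providers now yield None instead of a bare model name.
    return f'{prefix}/{model}' if prefix else None

assert litellm_model_name('hf', 'mistralai/Mistral-7B-Instruct-v0.2') == 'huggingface/mistralai/Mistral-7B-Instruct-v0.2'
assert litellm_model_name('foo', 'some-model') is None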
@@ -173,7 +181,7 @@ def stream_litellm_completion(
         # This is consistent with Azure OpenAI's requirement to use deployment names
         if not azure_deployment_name:
             raise ValueError("Azure deployment name is required for Azure OpenAI provider")
-        litellm_model = 'azure/%s' % azure_deployment_name
+        litellm_model = f'azure/{azure_deployment_name}'
     else:
         litellm_model = get_litellm_model_name(provider, model)
 
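Azure OpenAI addresses a deployment rather than a public model name, which is why this branch builds the LiteLLM model string directly from the deployment instead of going through get_litellm_model_name(). A tiny illustration with a placeholder deployment name:

azure_deployment_name = 'my-gpt-4o-deployment'  # placeholder
litellm_model = f'azure/{azure_deployment_name}'
# -> 'azure/my-gpt-4o-deployment', as opposed to e.g. 'huggingface/<model>' for other providers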
 
@@ -194,8 +202,8 @@ def stream_litellm_completion(
     elif provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
         # For Azure OpenAI, pass credentials as parameters
         request_params['api_key'] = api_key
-        request_params['azure_api_base'] = azure_endpoint_url
-        request_params['azure_api_version'] = azure_api_version
+        request_params['api_base'] = azure_endpoint_url
+        request_params['api_version'] = azure_api_version
     else:
         # For other providers, pass API key as parameter
         request_params['api_key'] = api_key
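After the rename above, the Azure branch populates the api_base and api_version keys that litellm.completion() accepts for Azure OpenAI, alongside api_key. A hedged sketch of what the assembled request might look like (all endpoint, key, version, and deployment values are placeholders, and the actual call is left commented out):

request_params = {
    'model': 'azure/my-gpt-4o-deployment',           # placeholder deployment
    'messages': [{'role': 'user', 'content': 'Hello'}],
    'stream': True,
    'api_key': '<AZURE_OPENAI_KEY>',                 # placeholder credential
    'api_base': 'https://example.openai.azure.com',  # placeholder endpoint
    'api_version': '2024-02-01',                     # placeholder API version
}
# With the litellm package installed:
# import litellm
# for chunk in litellm.completion(**request_params):
#     ...  # stream chunks to the UI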
@@ -216,7 +224,7 @@ def stream_litellm_completion(
                 yield choice.message.content
 
     except Exception as e:
-        logger.error('Error in LiteLLM completion: %s', e)
+        logger.exception('Error in LiteLLM completion: %s', e)
         raise
 
 
@@ -243,8 +251,7 @@ def get_litellm_llm(
     """
 
     if litellm is None:
-        logger.error("LiteLLM is not installed")
-        return None
+        raise ImportError("LiteLLM is not installed. Please install it with: pip install litellm")
 
     # Create a simple wrapper object that mimics the LangChain streaming interface
     class LiteLLMWrapper:
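Because the helper now raises instead of logging and returning None when LiteLLM is absent, callers that used to check for None need an except ImportError path. A self-contained sketch of the new failure mode, with the helper reduced to just the import check (hypothetical stub, not the real function body):

try:
    import litellm
except ImportError:
    litellm = None

def get_litellm_llm_stub():
    # Mirrors the changed check: raise instead of silently returning None.
    if litellm is None:
        raise ImportError('LiteLLM is not installed. Please install it with: pip install litellm')
    return litellm

try:
    backend = get_litellm_llm_stub()
except ImportError as err:
    backend = None
    print(f'LLM backend unavailable: {err}')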
 