sairampillai committed
Commit 15a4e19 · 1 Parent(s): 038ba10

Fix OpenRouter and Azure endpoint integration

Files changed (2):
  1. app.py +34 -3
  2. helpers/llm_helper.py +34 -8
app.py CHANGED
@@ -210,17 +210,39 @@ with st.sidebar:
         api_version: str = ''
     else:
         # The online LLMs
-        llm_provider_to_use = st.sidebar.selectbox(
+        selected_option = st.sidebar.selectbox(
             label='2: Select a suitable LLM to use:\n\n(Gemini and Mistral-Nemo are recommended)',
             options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
             index=GlobalConfig.DEFAULT_MODEL_INDEX,
             help=GlobalConfig.LLM_PROVIDER_HELP,
             on_change=reset_api_key
-        ).split(' ')[0]
+        )
+
+        # Extract provider key more robustly using regex
+        provider_match = GlobalConfig.PROVIDER_REGEX.match(selected_option)
+        if provider_match:
+            llm_provider_to_use = selected_option  # Use full string for get_provider_model
+        else:
+            # Fallback: try to extract the key before the first space
+            llm_provider_to_use = selected_option.split(' ')[0]
+            logger.warning(f"Could not parse provider from selectbox option: {selected_option}")
 
         # --- Automatically fetch API key from .env if available ---
         provider_match = GlobalConfig.PROVIDER_REGEX.match(llm_provider_to_use)
-        selected_provider = provider_match.group(1) if provider_match else llm_provider_to_use
+        if provider_match:
+            selected_provider = provider_match.group(1)
+        else:
+            # If the regex doesn't match, try to extract the provider from the beginning
+            selected_provider = llm_provider_to_use.split(' ')[0] if ' ' in llm_provider_to_use else llm_provider_to_use
+            logger.warning(f"Provider regex did not match for: {llm_provider_to_use}, using: {selected_provider}")
+
+        # Validate that the selected provider is valid
+        if selected_provider not in GlobalConfig.VALID_PROVIDERS:
+            logger.error(f"Invalid provider: {selected_provider}")
+            handle_error(f"Invalid provider selected: {selected_provider}", True)
+            st.error(f"Invalid provider selected: {selected_provider}")
+            st.stop()
+
         env_key_name = GlobalConfig.PROVIDER_ENV_KEYS.get(selected_provider)
         default_api_key = os.getenv(env_key_name, "") if env_key_name else ""

@@ -372,6 +394,15 @@ def set_up_chat_ui():
         use_ollama=RUN_IN_OFFLINE_MODE
     )
 
+    # Validate that provider and model were parsed successfully
+    if not provider or not llm_name:
+        handle_error(
+            f'Failed to parse provider and model from: "{llm_provider_to_use}". '
+            f'Please select a valid LLM from the dropdown.',
+            True
+        )
+        return
+
     user_key = api_key_token.strip()
     az_deployment = azure_deployment.strip()
     az_endpoint = azure_endpoint.strip()
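Note: the app.py changes above boil down to parse, validate, then look up the provider's API key in the environment. A minimal standalone sketch of that flow, assuming VALID_MODELS keys carry a bracketed provider prefix such as '[or]mistralai/mistral-nemo'; the actual regex, provider list, and env-key map are defined in GlobalConfig and may differ:

import logging
import os
import re

logger = logging.getLogger(__name__)

# Hypothetical stand-ins for GlobalConfig.PROVIDER_REGEX, VALID_PROVIDERS,
# and PROVIDER_ENV_KEYS; the real values live in global_config.py.
PROVIDER_REGEX = re.compile(r'^\[(\w+)\]')
VALID_PROVIDERS = {'az', 'gg', 'hf', 'or'}
PROVIDER_ENV_KEYS = {'or': 'OPENROUTER_API_KEY', 'az': 'AZURE_OPENAI_API_KEY'}

selected_option = '[or]mistralai/mistral-nemo (Mistral-Nemo)'  # what the selectbox returns

provider_match = PROVIDER_REGEX.match(selected_option)
if provider_match:
    llm_provider_to_use = selected_option  # keep the full string for get_provider_model
    selected_provider = provider_match.group(1)
else:
    llm_provider_to_use = selected_option.split(' ')[0]  # fallback
    selected_provider = llm_provider_to_use
    logger.warning('Could not parse provider from option: %s', selected_option)

if selected_provider not in VALID_PROVIDERS:
    raise ValueError(f'Invalid provider selected: {selected_provider}')

env_key_name = PROVIDER_ENV_KEYS.get(selected_provider)
default_api_key = os.getenv(env_key_name, '') if env_key_name else ''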
helpers/llm_helper.py CHANGED
@@ -70,8 +70,20 @@ def get_provider_model(provider_model: str, use_ollama: bool) -> Tuple[str, str]
     if match:
         inside_brackets = match.group(1)
         outside_brackets = match.group(2)
+
+        # Validate that the provider is in the valid providers list
+        if inside_brackets not in GlobalConfig.VALID_PROVIDERS:
+            logger.warning(f"Provider '{inside_brackets}' not in VALID_PROVIDERS: {GlobalConfig.VALID_PROVIDERS}")
+            return '', ''
+
+        # Validate that the model name is not empty
+        if not outside_brackets.strip():
+            logger.warning(f"Empty model name for provider '{inside_brackets}'")
+            return '', ''
+
         return inside_brackets, outside_brackets
 
+    logger.warning(f"Could not parse provider_model: '{provider_model}' (use_ollama={use_ollama})")
     return '', ''
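Note: the early returns above make ('', '') the single failure signal, which the new guard in app.py's set_up_chat_ui relies on. A rough standalone sketch of that contract, with a placeholder pattern and provider list (the real ones come from GlobalConfig):

import logging
import re
from typing import Tuple

logger = logging.getLogger(__name__)

VALID_PROVIDERS = {'az', 'gg', 'hf', 'or'}   # placeholder list
PATTERN = re.compile(r'\[(.*?)\](.*)')       # assumed '[provider]model' key format

def parse_provider_model(provider_model: str) -> Tuple[str, str]:
    """Return (provider, model) on success, ('', '') on any failure."""
    match = PATTERN.match(provider_model)
    if match:
        provider, model = match.group(1), match.group(2)
        if provider in VALID_PROVIDERS and model.strip():
            return provider, model
    logger.warning("Could not parse provider_model: '%s'", provider_model)
    return '', ''

assert parse_provider_model('[or]openai/gpt-4o-mini') == ('or', 'openai/gpt-4o-mini')
assert parse_provider_model('no-brackets') == ('', '')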
 
@@ -183,7 +195,14 @@ def stream_litellm_completion(
         raise ImportError("LiteLLM is not installed. Please install it with: pip install litellm")
 
     # Convert to LiteLLM model name
-    litellm_model = get_litellm_model_name(provider, model)
+    if provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
+        # For Azure OpenAI, use the deployment name as the model
+        # This is consistent with Azure OpenAI's requirement to use deployment names
+        if not azure_deployment_name:
+            raise ValueError("Azure deployment name is required for Azure OpenAI provider")
+        litellm_model = f"azure/{azure_deployment_name}"
+    else:
+        litellm_model = get_litellm_model_name(provider, model)
 
     # Prepare the request parameters
     request_params = {
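Note: the 'azure/<deployment>' model string is how LiteLLM addresses Azure OpenAI, which routes requests by deployment name rather than by model id. A minimal sketch of such a call; the key, endpoint, deployment name, and API version below are placeholders:

import litellm

# Azure OpenAI is addressed by deployment name, so the LiteLLM model string
# is 'azure/<deployment-name>' rather than the underlying model id.
response = litellm.completion(
    model='azure/my-gpt-4o-deployment',                 # placeholder deployment
    api_key='...',                                      # placeholder Azure key
    api_base='https://my-resource.openai.azure.com',    # placeholder endpoint
    api_version='2024-02-15-preview',                   # example API version
    messages=[{'role': 'user', 'content': 'Hello'}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or '', end='')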
@@ -196,13 +215,20 @@
 
     # Set API key and any provider-specific params
     if provider != GlobalConfig.PROVIDER_OLLAMA:
-        api_key_to_use = get_litellm_api_key(provider, api_key)
-        request_params["api_key"] = api_key_to_use
-
-        if provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
-            request_params["azure_endpoint"] = azure_endpoint_url
-            request_params["azure_deployment"] = azure_deployment_name
-            request_params["api_version"] = azure_api_version
+        # For OpenRouter, set environment variable as per documentation
+        if provider == GlobalConfig.PROVIDER_OPENROUTER:
+            os.environ["OPENROUTER_API_KEY"] = api_key
+            # Optional: Set base URL if different from default
+            # os.environ["OPENROUTER_API_BASE"] = "https://openrouter.ai/api/v1"
+        elif provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
+            # For Azure OpenAI, set environment variables as per documentation
+            os.environ["AZURE_API_KEY"] = api_key
+            os.environ["AZURE_API_BASE"] = azure_endpoint_url
+            os.environ["AZURE_API_VERSION"] = azure_api_version
+        else:
+            # For other providers, pass API key as parameter
+            api_key_to_use = get_litellm_api_key(provider, api_key)
+            request_params["api_key"] = api_key_to_use
 
     logger.debug('Streaming completion via LiteLLM: %s', litellm_model)
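Note: for OpenRouter, LiteLLM reads OPENROUTER_API_KEY from the environment and routes 'openrouter/...' model ids through https://openrouter.ai/api/v1 by default. A minimal sketch; the key and model id below are placeholders:

import os

import litellm

os.environ['OPENROUTER_API_KEY'] = 'sk-or-...'  # placeholder key

response = litellm.completion(
    model='openrouter/openai/gpt-4o-mini',  # example OpenRouter model id
    messages=[{'role': 'user', 'content': 'Hello'}],
)
print(response.choices[0].message.content)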