Spaces:
Running
Running
Commit
·
15a4e19
1
Parent(s):
038ba10
Fix openrouter and Azure endpoint integration
Browse files- app.py +34 -3
- helpers/llm_helper.py +34 -8
app.py
CHANGED
|
@@ -210,17 +210,39 @@ with st.sidebar:
|
|
| 210 |
api_version: str = ''
|
| 211 |
else:
|
| 212 |
# The online LLMs
|
| 213 |
-
|
| 214 |
label='2: Select a suitable LLM to use:\n\n(Gemini and Mistral-Nemo are recommended)',
|
| 215 |
options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
|
| 216 |
index=GlobalConfig.DEFAULT_MODEL_INDEX,
|
| 217 |
help=GlobalConfig.LLM_PROVIDER_HELP,
|
| 218 |
on_change=reset_api_key
|
| 219 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
# --- Automatically fetch API key from .env if available ---
|
| 222 |
provider_match = GlobalConfig.PROVIDER_REGEX.match(llm_provider_to_use)
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
env_key_name = GlobalConfig.PROVIDER_ENV_KEYS.get(selected_provider)
|
| 225 |
default_api_key = os.getenv(env_key_name, "") if env_key_name else ""
|
| 226 |
|
|
@@ -372,6 +394,15 @@ def set_up_chat_ui():
|
|
| 372 |
use_ollama=RUN_IN_OFFLINE_MODE
|
| 373 |
)
|
| 374 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
user_key = api_key_token.strip()
|
| 376 |
az_deployment = azure_deployment.strip()
|
| 377 |
az_endpoint = azure_endpoint.strip()
|
|
|
|
| 210 |
api_version: str = ''
|
| 211 |
else:
|
| 212 |
# The online LLMs
|
| 213 |
+
selected_option = st.sidebar.selectbox(
|
| 214 |
label='2: Select a suitable LLM to use:\n\n(Gemini and Mistral-Nemo are recommended)',
|
| 215 |
options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
|
| 216 |
index=GlobalConfig.DEFAULT_MODEL_INDEX,
|
| 217 |
help=GlobalConfig.LLM_PROVIDER_HELP,
|
| 218 |
on_change=reset_api_key
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# Extract provider key more robustly using regex
|
| 222 |
+
provider_match = GlobalConfig.PROVIDER_REGEX.match(selected_option)
|
| 223 |
+
if provider_match:
|
| 224 |
+
llm_provider_to_use = selected_option # Use full string for get_provider_model
|
| 225 |
+
else:
|
| 226 |
+
# Fallback: try to extract the key before the first space
|
| 227 |
+
llm_provider_to_use = selected_option.split(' ')[0]
|
| 228 |
+
logger.warning(f"Could not parse provider from selectbox option: {selected_option}")
|
| 229 |
|
| 230 |
# --- Automatically fetch API key from .env if available ---
|
| 231 |
provider_match = GlobalConfig.PROVIDER_REGEX.match(llm_provider_to_use)
|
| 232 |
+
if provider_match:
|
| 233 |
+
selected_provider = provider_match.group(1)
|
| 234 |
+
else:
|
| 235 |
+
# If regex doesn't match, try to extract provider from the beginning
|
| 236 |
+
selected_provider = llm_provider_to_use.split(' ')[0] if ' ' in llm_provider_to_use else llm_provider_to_use
|
| 237 |
+
logger.warning(f"Provider regex did not match for: {llm_provider_to_use}, using: {selected_provider}")
|
| 238 |
+
|
| 239 |
+
# Validate that the selected provider is valid
|
| 240 |
+
if selected_provider not in GlobalConfig.VALID_PROVIDERS:
|
| 241 |
+
logger.error(f"Invalid provider: {selected_provider}")
|
| 242 |
+
handle_error(f"Invalid provider selected: {selected_provider}", True)
|
| 243 |
+
st.error(f"Invalid provider selected: {selected_provider}")
|
| 244 |
+
st.stop()
|
| 245 |
+
|
| 246 |
env_key_name = GlobalConfig.PROVIDER_ENV_KEYS.get(selected_provider)
|
| 247 |
default_api_key = os.getenv(env_key_name, "") if env_key_name else ""
|
| 248 |
|
|
|
|
| 394 |
use_ollama=RUN_IN_OFFLINE_MODE
|
| 395 |
)
|
| 396 |
|
| 397 |
+
# Validate that provider and model were parsed successfully
|
| 398 |
+
if not provider or not llm_name:
|
| 399 |
+
handle_error(
|
| 400 |
+
f'Failed to parse provider and model from: "{llm_provider_to_use}". '
|
| 401 |
+
f'Please select a valid LLM from the dropdown.',
|
| 402 |
+
True
|
| 403 |
+
)
|
| 404 |
+
return
|
| 405 |
+
|
| 406 |
user_key = api_key_token.strip()
|
| 407 |
az_deployment = azure_deployment.strip()
|
| 408 |
az_endpoint = azure_endpoint.strip()
|
helpers/llm_helper.py
CHANGED
|
@@ -70,8 +70,20 @@ def get_provider_model(provider_model: str, use_ollama: bool) -> Tuple[str, str]
|
|
| 70 |
if match:
|
| 71 |
inside_brackets = match.group(1)
|
| 72 |
outside_brackets = match.group(2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
return inside_brackets, outside_brackets
|
| 74 |
|
|
|
|
| 75 |
return '', ''
|
| 76 |
|
| 77 |
|
|
@@ -183,7 +195,14 @@ def stream_litellm_completion(
|
|
| 183 |
raise ImportError("LiteLLM is not installed. Please install it with: pip install litellm")
|
| 184 |
|
| 185 |
# Convert to LiteLLM model name
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
# Prepare the request parameters
|
| 189 |
request_params = {
|
|
@@ -196,13 +215,20 @@ def stream_litellm_completion(
|
|
| 196 |
|
| 197 |
# Set API key and any provider-specific params
|
| 198 |
if provider != GlobalConfig.PROVIDER_OLLAMA:
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
logger.debug('Streaming completion via LiteLLM: %s', litellm_model)
|
| 208 |
|
|
|
|
| 70 |
if match:
|
| 71 |
inside_brackets = match.group(1)
|
| 72 |
outside_brackets = match.group(2)
|
| 73 |
+
|
| 74 |
+
# Validate that the provider is in the valid providers list
|
| 75 |
+
if inside_brackets not in GlobalConfig.VALID_PROVIDERS:
|
| 76 |
+
logger.warning(f"Provider '{inside_brackets}' not in VALID_PROVIDERS: {GlobalConfig.VALID_PROVIDERS}")
|
| 77 |
+
return '', ''
|
| 78 |
+
|
| 79 |
+
# Validate that the model name is not empty
|
| 80 |
+
if not outside_brackets.strip():
|
| 81 |
+
logger.warning(f"Empty model name for provider '{inside_brackets}'")
|
| 82 |
+
return '', ''
|
| 83 |
+
|
| 84 |
return inside_brackets, outside_brackets
|
| 85 |
|
| 86 |
+
logger.warning(f"Could not parse provider_model: '{provider_model}' (use_ollama={use_ollama})")
|
| 87 |
return '', ''
|
| 88 |
|
| 89 |
|
|
|
|
| 195 |
raise ImportError("LiteLLM is not installed. Please install it with: pip install litellm")
|
| 196 |
|
| 197 |
# Convert to LiteLLM model name
|
| 198 |
+
if provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
|
| 199 |
+
# For Azure OpenAI, use the deployment name as the model
|
| 200 |
+
# This is consistent with Azure OpenAI's requirement to use deployment names
|
| 201 |
+
if not azure_deployment_name:
|
| 202 |
+
raise ValueError("Azure deployment name is required for Azure OpenAI provider")
|
| 203 |
+
litellm_model = f"azure/{azure_deployment_name}"
|
| 204 |
+
else:
|
| 205 |
+
litellm_model = get_litellm_model_name(provider, model)
|
| 206 |
|
| 207 |
# Prepare the request parameters
|
| 208 |
request_params = {
|
|
|
|
| 215 |
|
| 216 |
# Set API key and any provider-specific params
|
| 217 |
if provider != GlobalConfig.PROVIDER_OLLAMA:
|
| 218 |
+
# For OpenRouter, set environment variable as per documentation
|
| 219 |
+
if provider == GlobalConfig.PROVIDER_OPENROUTER:
|
| 220 |
+
os.environ["OPENROUTER_API_KEY"] = api_key
|
| 221 |
+
# Optional: Set base URL if different from default
|
| 222 |
+
# os.environ["OPENROUTER_API_BASE"] = "https://openrouter.ai/api/v1"
|
| 223 |
+
elif provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
|
| 224 |
+
# For Azure OpenAI, set environment variables as per documentation
|
| 225 |
+
os.environ["AZURE_API_KEY"] = api_key
|
| 226 |
+
os.environ["AZURE_API_BASE"] = azure_endpoint_url
|
| 227 |
+
os.environ["AZURE_API_VERSION"] = azure_api_version
|
| 228 |
+
else:
|
| 229 |
+
# For other providers, pass API key as parameter
|
| 230 |
+
api_key_to_use = get_litellm_api_key(provider, api_key)
|
| 231 |
+
request_params["api_key"] = api_key_to_use
|
| 232 |
|
| 233 |
logger.debug('Streaming completion via LiteLLM: %s', litellm_model)
|
| 234 |
|