Add support for offline LLMs via Ollama
Files changed:
- app.py +59 -26
- global_config.py +22 -2
- helpers/llm_helper.py +30 -9
- requirements.txt +6 -1
app.py
CHANGED
@@ -3,23 +3,34 @@ Streamlit app containing the UI and the application logic.
 """
 import datetime
 import logging
+import os
 import pathlib
 import random
 import tempfile
 from typing import List, Union
 
+import httpx
 import huggingface_hub
 import json5
+import ollama
 import requests
 import streamlit as st
+from dotenv import load_dotenv
 from langchain_community.chat_message_histories import StreamlitChatMessageHistory
 from langchain_core.messages import HumanMessage
 from langchain_core.prompts import ChatPromptTemplate
 
+import global_config as gcfg
 from global_config import GlobalConfig
 from helpers import llm_helper, pptx_helper, text_helper
 
 
+load_dotenv()
+
+
+RUN_IN_OFFLINE_MODE = os.getenv('RUN_IN_OFFLINE_MODE', 'False').lower() == 'true'
+
+
 @st.cache_data
 def _load_strings() -> dict:
     """
@@ -135,25 +146,36 @@ with st.sidebar:
         horizontal=True
     )
 
-    # The LLMs
-    llm_provider_to_use = st.sidebar.selectbox(
-        label='2: Select an LLM to use:',
-        options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
-        index=GlobalConfig.DEFAULT_MODEL_INDEX,
-        help=GlobalConfig.LLM_PROVIDER_HELP,
-        on_change=reset_api_key
-    ).split(' ')[0]
-
-    # The API key/access token
-    api_key_token = st.text_input(
-        label=(
-            '3: Paste your API key/access token:\n\n'
-            '*Mandatory* for Cohere and Gemini LLMs.'
-            ' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
-        ),
-        type='password',
-        key='api_key_input'
-    )
+    if RUN_IN_OFFLINE_MODE:
+        llm_provider_to_use = st.text_input(
+            label='2: Enter Ollama model name to use:',
+            help=(
+                'Specify a correct, locally available LLM, found by running `ollama list`, for'
+                ' example `mistral:v0.2` and `mistral-nemo:latest`. Having an Ollama-compatible'
+                ' and supported GPU is strongly recommended.'
+            )
+        )
+        api_key_token: str = ''
+    else:
+        # The LLMs
+        llm_provider_to_use = st.sidebar.selectbox(
+            label='2: Select an LLM to use:',
+            options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
+            index=GlobalConfig.DEFAULT_MODEL_INDEX,
+            help=GlobalConfig.LLM_PROVIDER_HELP,
+            on_change=reset_api_key
+        ).split(' ')[0]
+
+        # The API key/access token
+        api_key_token = st.text_input(
+            label=(
+                '3: Paste your API key/access token:\n\n'
+                '*Mandatory* for Cohere and Gemini LLMs.'
+                ' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
+            ),
+            type='password',
+            key='api_key_input'
+        )
 
 
 def build_ui():
@@ -200,7 +222,11 @@ def set_up_chat_ui():
         placeholder=APP_TEXT['chat_placeholder'],
         max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
     ):
-        provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use)
+        provider, llm_name = llm_helper.get_provider_model(
+            llm_provider_to_use,
+            use_ollama=RUN_IN_OFFLINE_MODE
+        )
+        print(f'{llm_provider_to_use=}, {provider=}, {llm_name=}, {api_key_token=}')
 
         if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token):
             return
@@ -233,7 +259,7 @@ def set_up_chat_ui():
         llm = llm_helper.get_langchain_llm(
             provider=provider,
             model=llm_name,
-            max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'],
+            max_new_tokens=gcfg.get_max_output_tokens(llm_provider_to_use),
             api_key=api_key_token.strip(),
         )
 
@@ -252,18 +278,17 @@ def set_up_chat_ui():
                 # Update the progress bar with an approx progress percentage
                 progress_bar.progress(
                     min(
-                        len(response) / GlobalConfig.VALID_MODELS[
-                            llm_provider_to_use
-                        ]['max_new_tokens'],
+                        len(response) / gcfg.get_max_output_tokens(llm_provider_to_use),
                         0.95
                     ),
                     text='Streaming content...this might take a while...'
                 )
-        except requests.exceptions.ConnectionError:
+        except (httpx.ConnectError, requests.exceptions.ConnectionError):
             handle_error(
                 'A connection error occurred while streaming content from the LLM endpoint.'
                 ' Unfortunately, the slide deck cannot be generated. Please try again later.'
-                ' Alternatively, try selecting a different LLM from the dropdown list.',
+                ' Alternatively, try selecting a different LLM from the dropdown list. If you are'
+                ' using Ollama, make sure that Ollama is already running on your system.',
                 True
             )
             return
@@ -274,6 +299,14 @@ def set_up_chat_ui():
                 True
             )
             return
+        except ollama.ResponseError:
+            handle_error(
+                f'The model `{llm_name}` is unavailable with Ollama on your system.'
+                f' Make sure that you have provided the correct LLM name or pull it using'
+                f' `ollama pull {llm_name}`. View LLMs available locally by running `ollama list`.',
+                True
+            )
+            return
         except Exception as ex:
             handle_error(
                 f'An unexpected error occurred while generating the content: {ex}'
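Note: a minimal sketch of how the new offline-mode switch is driven. The `.env` file location and the `true` value shown are assumptions based on the `load_dotenv()` and `os.getenv()` calls added above; only the variable name comes from the diff.

# .env (assumed to sit next to app.py)
# RUN_IN_OFFLINE_MODE=true

# Sketch of the toggle logic added in app.py:
import os

from dotenv import load_dotenv

load_dotenv()
RUN_IN_OFFLINE_MODE = os.getenv('RUN_IN_OFFLINE_MODE', 'False').lower() == 'true'

if RUN_IN_OFFLINE_MODE:
    print('Offline mode: enter any locally pulled Ollama model, e.g. mistral:v0.2')
else:
    print('Online mode: pick an LLM from GlobalConfig.VALID_MODELS')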
global_config.py
CHANGED
@@ -20,7 +20,13 @@ class GlobalConfig:
     PROVIDER_COHERE = 'co'
     PROVIDER_GOOGLE_GEMINI = 'gg'
     PROVIDER_HUGGING_FACE = 'hf'
-
+    PROVIDER_OLLAMA = 'ol'
+    VALID_PROVIDERS = {
+        PROVIDER_COHERE,
+        PROVIDER_GOOGLE_GEMINI,
+        PROVIDER_HUGGING_FACE,
+        PROVIDER_OLLAMA
+    }
     VALID_MODELS = {
         '[co]command-r-08-2024': {
             'description': 'simpler, slower',
@@ -47,7 +53,7 @@ class GlobalConfig:
         'LLM provider codes:\n\n'
         '- **[co]**: Cohere\n'
         '- **[gg]**: Google Gemini API\n'
-        '- **[hf]**: Hugging Face Inference
+        '- **[hf]**: Hugging Face Inference API\n'
     )
     DEFAULT_MODEL_INDEX = 2
     LLM_MODEL_TEMPERATURE = 0.2
@@ -125,3 +131,17 @@ logging.basicConfig(
     format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
     datefmt='%Y-%m-%d %H:%M:%S'
 )
+
+
+def get_max_output_tokens(llm_name: str) -> int:
+    """
+    Get the max output tokens value configured for an LLM. Return a default value if not configured.
+
+    :param llm_name: The name of the LLM.
+    :return: Max output tokens or a default count.
+    """
+
+    try:
+        return GlobalConfig.VALID_MODELS[llm_name]['max_new_tokens']
+    except KeyError:
+        return 2048
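For illustration, a small usage sketch of the new `get_max_output_tokens()` helper; the model keys shown are taken from the diff and the help text above, and the 2048 fallback is the default defined in the function.

import global_config as gcfg

# Returns the 'max_new_tokens' configured for this model in VALID_MODELS, if any; otherwise 2048.
print(gcfg.get_max_output_tokens('[co]command-r-08-2024'))

# Anything not in VALID_MODELS (for example, a free-form Ollama model name) falls back to 2048.
print(gcfg.get_max_output_tokens('mistral-nemo:latest'))  # 2048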
helpers/llm_helper.py
CHANGED
@@ -17,8 +17,9 @@ from global_config import GlobalConfig
 
 
 LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
+OLLAMA_MODEL_REGEX = re.compile(r'[a-zA-Z0-9._:-]+$')
 # 6-64 characters long, only containing alphanumeric characters, hyphens, and underscores
-API_KEY_REGEX = re.compile(r'^[a-zA-Z0-
+API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9_-]{6,64}$')
 HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
 REQUEST_TIMEOUT = 35
 
@@ -39,20 +40,28 @@ http_session.mount('https://', adapter)
 http_session.mount('http://', adapter)
 
 
-def get_provider_model(provider_model: str) -> Tuple[str, str]:
+def get_provider_model(provider_model: str, use_ollama: bool) -> Tuple[str, str]:
     """
     Parse and get LLM provider and model name from strings like `[provider]model/name-version`.
 
     :param provider_model: The provider, model name string from `GlobalConfig`.
-    :return: The provider and the model name.
+    :param use_ollama: Whether Ollama is used (i.e., running in offline mode).
+    :return: The provider and the model name; empty strings in case no matching pattern found.
     """
 
-    match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
+    provider_model = provider_model.strip()
 
-    if match:
-        inside_brackets = match.group(1)
-        outside_brackets = match.group(2)
-        return inside_brackets, outside_brackets
+    if use_ollama:
+        match = OLLAMA_MODEL_REGEX.match(provider_model)
+        if match:
+            return GlobalConfig.PROVIDER_OLLAMA, match.group(0)
+    else:
+        match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
+
+        if match:
+            inside_brackets = match.group(1)
+            outside_brackets = match.group(2)
+            return inside_brackets, outside_brackets
 
     return '', ''
 
@@ -152,6 +161,18 @@ def get_langchain_llm(
             streaming=True,
         )
 
+    if provider == GlobalConfig.PROVIDER_OLLAMA:
+        from langchain_ollama.llms import OllamaLLM
+
+        logger.debug('Getting LLM via Ollama: %s', model)
+        return OllamaLLM(
+            model=model,
+            temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
+            num_predict=max_new_tokens,
+            format='json',
+            streaming=True,
+        )
+
     return None
 
 
@@ -163,4 +184,4 @@ if __name__ == '__main__':
     ]
 
     for text in inputs:
-        print(get_provider_model(text))
+        print(get_provider_model(text, use_ollama=False))
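A quick sketch of how the reworked `get_provider_model()` behaves in the two modes; the model strings below are illustrative only and need not exist in `GlobalConfig.VALID_MODELS`.

from helpers import llm_helper

# Online mode: strings follow the '[provider]model-name' convention used by GlobalConfig.VALID_MODELS.
print(llm_helper.get_provider_model('[gg]gemini-1.5-flash', use_ollama=False))  # ('gg', 'gemini-1.5-flash')

# Offline mode: any Ollama-style tag matching OLLAMA_MODEL_REGEX maps to the 'ol' provider.
print(llm_helper.get_provider_model('mistral:v0.2', use_ollama=True))  # ('ol', 'mistral:v0.2')

# Unparsable input yields empty strings, which the app treats as invalid.
print(llm_helper.get_provider_model('???', use_ollama=True))  # ('', '')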
requirements.txt
CHANGED
@@ -12,9 +12,10 @@ langchain-core~=0.3.0
 langchain-community==0.3.0
 langchain-google-genai==2.0.6
 langchain-cohere==0.3.3
+langchain-ollama==0.2.1
 streamlit~=1.38.0
 
-python-pptx
+python-pptx~=0.6.21
 # metaphor-python
 json5~=0.9.14
 requests~=2.32.3
@@ -32,3 +33,7 @@ certifi==2024.8.30
 urllib3==2.2.3
 
 anyio==4.4.0
+
+httpx~=0.27.2
+huggingface-hub~=0.24.5
+ollama~=0.4.3