Add support for Gemini 1.5 Flash via Gemini API
Files changed:
- app.py (+122 −110)
- global_config.py (+18 −3)
- helpers/llm_helper.py (+50 −31)
- requirements.txt (+1 −1)
- strings.json (+2 −1)
app.py
CHANGED
@@ -5,7 +5,6 @@ import datetime
 import logging
 import pathlib
 import random
-import sys
 import tempfile
 from typing import List, Union
 
@@ -17,9 +16,6 @@ from langchain_community.chat_message_histories import StreamlitChatMessageHistory
 from langchain_core.messages import HumanMessage
 from langchain_core.prompts import ChatPromptTemplate
 
-sys.path.append('..')
-sys.path.append('../..')
-
 from global_config import GlobalConfig
 from helpers import llm_helper, pptx_helper, text_helper
 
@@ -54,6 +50,60 @@ def _get_prompt_template(is_refinement: bool) -> str:
     return template
 
 
+def are_all_inputs_valid(
+        user_prompt: str,
+        selected_provider: str,
+        selected_model: str,
+        user_key: str,
+) -> bool:
+    """
+    Validate user input and LLM selection.
+
+    :param user_prompt: The prompt.
+    :param selected_provider: The LLM provider.
+    :param selected_model: Name of the model.
+    :param user_key: User-provided API key.
+    :return: `True` if all inputs "look" OK; `False` otherwise.
+    """
+
+    if not text_helper.is_valid_prompt(user_prompt):
+        handle_error(
+            'Not enough information provided!'
+            ' Please be a little more descriptive and type a few words'
+            ' with a few characters :)',
+            False
+        )
+        return False
+
+    if not selected_provider or not selected_model:
+        handle_error('No valid LLM provider and/or model name found!', False)
+        return False
+
+    if not llm_helper.is_valid_llm_provider_model(selected_provider, selected_model, user_key):
+        handle_error(
+            'The LLM settings do not look correct. Make sure that an API key/access token'
+            ' is provided if the selected LLM requires it.',
+            False
+        )
+        return False
+
+    return True
+
+
+def handle_error(error_msg: str, should_log: bool):
+    """
+    Display an error message in the app.
+
+    :param error_msg: The error message to be displayed.
+    :param should_log: If `True`, log the message.
+    """
+
+    if should_log:
+        logger.error(error_msg)
+
+    st.error(error_msg)
+
+
 APP_TEXT = _load_strings()
 
 # Session variables
@@ -80,11 +130,8 @@ with st.sidebar:
     llm_provider_to_use = st.sidebar.selectbox(
         label='2: Select an LLM to use:',
         options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
-        index=…,
-        help=(
-            'LLM provider codes:\n\n'
-            '- **[hf]**: Hugging Face Inference Endpoint\n'
-        ),
+        index=GlobalConfig.DEFAULT_MODEL_INDEX,
+        help=GlobalConfig.LLM_PROVIDER_HELP,
     ).split(' ')[0]
 
     # The API key/access token
@@ -123,53 +170,28 @@ def set_up_chat_ui():
     with st.expander('Usage Instructions'):
         st.markdown(GlobalConfig.CHAT_USAGE_INSTRUCTIONS)
 
-    st.info(
-        'If you like SlideDeck AI, please consider leaving a heart ❤️ on the'
-        ' [Hugging Face Space](https://huggingface.co/spaces/barunsaha/slide-deck-ai/) or'
-        ' a star ⭐ on [GitHub](https://github.com/barun-saha/slide-deck-ai).'
-        ' Your [feedback](https://forms.gle/JECFBGhjvSj7moBx9) is appreciated.'
-    )
-
-    # view_messages = st.expander('View the messages in the session state')
-
-    st.chat_message('ai').write(
-        random.choice(APP_TEXT['ai_greetings'])
-    )
+    st.info(APP_TEXT['like_feedback'])
+    st.chat_message('ai').write(random.choice(APP_TEXT['ai_greetings']))
 
     history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
-
-    if _is_it_refinement():
-        template = _get_prompt_template(is_refinement=True)
-    else:
-        template = _get_prompt_template(is_refinement=False)
-
-    prompt_template = ChatPromptTemplate.from_template(template)
+    prompt_template = ChatPromptTemplate.from_template(
+        _get_prompt_template(
+            is_refinement=_is_it_refinement()
+        )
+    )
 
     # Since Streamlit app reloads at every interaction, display the chat history
     # from the save session state
     for msg in history.messages:
-        msg_type = msg.type
-        if msg_type == 'user':
-            st.chat_message(msg_type).write(msg.content)
-        else:
-            st.chat_message(msg_type).code(msg.content, language='json')
+        st.chat_message(msg.type).code(msg.content, language='json')
 
     if prompt := st.chat_input(
            placeholder=APP_TEXT['chat_placeholder'],
            max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
    ):
-        if not text_helper.is_valid_prompt(prompt):
-            st.error(
-                'Not enough information provided!'
-                ' Please be a little more descriptive and type a few words'
-                ' with a few characters :)'
-            )
-            return
-
         provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use)
 
-        if not provider or not llm_name:
-            st.error('No valid LLM provider and/or model name found!')
+        if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token):
            return
 
        logger.info(
@@ -178,72 +200,76 @@ def set_up_chat_ui():
         )
         st.chat_message('user').write(prompt)
 
-        user_messages = _get_user_messages()
-        user_messages.append(prompt)
-        list_of_msgs = [
-            f'{idx + 1}. {msg}' for idx, msg in enumerate(user_messages)
-        ]
-        list_of_msgs = '\n'.join(list_of_msgs)
-
         if _is_it_refinement():
+            user_messages = _get_user_messages()
+            user_messages.append(prompt)
+            list_of_msgs = [
+                f'{idx + 1}. {msg}' for idx, msg in enumerate(user_messages)
+            ]
             formatted_template = prompt_template.format(
                 **{
-                    'instructions': list_of_msgs,
+                    'instructions': '\n'.join(list_of_msgs),
                     'previous_content': _get_last_response(),
                 }
             )
         else:
-            formatted_template = prompt_template.format(
-                **{
-                    'question': prompt,
-                }
-            )
+            formatted_template = prompt_template.format(**{'question': prompt})
 
         progress_bar = st.progress(0, 'Preparing to call LLM...')
         response = ''
 
         try:
-            llm = llm_helper.get_hf_endpoint(
-                …
-            )
-            …
+            llm = llm_helper.get_langchain_llm(
+                provider=provider,
+                model=llm_name,
+                max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'],
+                api_key=api_key_token.strip(),
+            )
+
+            if not llm:
+                handle_error(
+                    'Failed to create an LLM instance! Make sure that you have selected the'
+                    ' correct model from the dropdown list and have provided correct API key'
+                    ' or access token.',
+                    False
+                )
+                return
+
+            for _ in llm.stream(formatted_template):
+                response += _
+
+            # Update the progress bar with an approx progress percentage
             progress_bar.progress(
-                …
+                min(
+                    len(response) / GlobalConfig.VALID_MODELS[
+                        llm_provider_to_use
+                    ]['max_new_tokens'],
+                    0.95
+                ),
                 text='Streaming content...this might take a while...'
             )
         except requests.exceptions.ConnectionError:
-            msg = (
+            handle_error(
                 'A connection error occurred while streaming content from the LLM endpoint.'
                 ' Unfortunately, the slide deck cannot be generated. Please try again later.'
-                ' Alternatively, try selecting a different LLM from the dropdown list.'
+                ' Alternatively, try selecting a different LLM from the dropdown list.',
+                True
             )
-            logger.error(msg)
-            st.error(msg)
             return
         except huggingface_hub.errors.ValidationError as ve:
-            msg = (
+            handle_error(
                 f'An error occurred while trying to generate the content: {ve}'
-                '\nPlease try again with a significantly shorter input text.'
+                '\nPlease try again with a significantly shorter input text.',
+                True
             )
-            logger.error(msg)
-            st.error(msg)
             return
         except Exception as ex:
-            msg = (
+            handle_error(
                 f'An unexpected error occurred while generating the content: {ex}'
                 '\nPlease try again later, possibly with different inputs.'
-                ' Alternatively, try selecting a different LLM from the dropdown list.'
+                ' Alternatively, try selecting a different LLM from the dropdown list.',
+                True
             )
-            logger.error(msg)
-            st.error(msg)
             return
 
         history.add_user_message(prompt)
@@ -252,25 +278,20 @@ def set_up_chat_ui():
         # The content has been generated as JSON
         # There maybe trailing ``` at the end of the response -- remove them
         # To be careful: ``` may be part of the content as well when code is generated
-        response_cleaned = text_helper.get_clean_json(response)
-
+        response = text_helper.get_clean_json(response)
         logger.info(
-            'Cleaned JSON …',
-            len(response), len(response_cleaned)
+            'Cleaned JSON length: %d', len(response)
         )
-        # logger.debug('Cleaned JSON: %s', response_cleaned)
 
         # Now create the PPT file
         progress_bar.progress(
             GlobalConfig.LLM_PROGRESS_MAX,
             text='Finding photos online and generating the slide deck...'
         )
-        path = generate_slide_deck(response_cleaned)
         progress_bar.progress(1.0, text='Done!')
-
         st.chat_message('ai').code(response, language='json')
 
-        if path:
+        if path := generate_slide_deck(response):
             _display_download_button(path)
 
         logger.info(
@@ -291,44 +312,35 @@ def generate_slide_deck(json_str: str) -> Union[pathlib.Path, None]:
     try:
         parsed_data = json5.loads(json_str)
     except ValueError:
-        st.error(
-            'Encountered error while parsing JSON...will fix it and retry'
-        )
-        logger.error(
-            'Caught ValueError: trying again after repairing JSON...'
+        handle_error(
+            'Encountered error while parsing JSON...will fix it and retry',
+            True
         )
         try:
             parsed_data = json5.loads(text_helper.fix_malformed_json(json_str))
         except ValueError:
-            st.error(
+            handle_error(
                 'Encountered an error again while fixing JSON...'
                 'the slide deck cannot be created, unfortunately ☹'
-                '\nPlease try again later.'
+                '\nPlease try again later.',
+                True
             )
-            logger.error(
-                'Caught ValueError: failed to repair JSON!'
-            )
-
             return None
     except RecursionError:
-        st.error(
-            'Encountered …'
+        handle_error(
+            'Encountered a recursion error while parsing JSON...'
             'the slide deck cannot be created, unfortunately ☹'
-            '\nPlease try again later.'
+            '\nPlease try again later.',
+            True
         )
-        logger.error('Caught RecursionError while parsing JSON. Cannot generate the slide deck!')
-
         return None
     except Exception:
-        st.error(
+        handle_error(
             'Encountered an error while parsing JSON...'
             'the slide deck cannot be created, unfortunately ☹'
-            '\nPlease try again later.'
+            '\nPlease try again later.',
+            True
        )
-        logger.error(
-            'Caught ValueError: failed to parse JSON!'
-        )
-
        return None
 
    if DOWNLOAD_FILE_KEY in st.session_state:
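Note on the streaming loop above: the progress bar is driven by a simple character-count heuristic rather than real token accounting. Below is a standalone sketch of that logic; the function name and the hard-coded limit are illustrative, since the app reads the limit from GlobalConfig.VALID_MODELS and inlines the computation.

max_new_tokens = 8192  # e.g., the '[gg]gemini-1.5-flash-002' limit defined in global_config.py

def approx_progress(partial_response: str) -> float:
    # Each streamed character counts as roughly one token's worth of progress;
    # the 0.95 cap keeps the bar from reaching 100% before streaming actually ends.
    return min(len(partial_response) / max_new_tokens, 0.95)

print(approx_progress(''))             # 0.0
print(approx_progress('x' * 4096))     # 0.5
print(approx_progress('x' * 100_000))  # 0.95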
global_config.py
CHANGED
@@ -17,17 +17,32 @@ class GlobalConfig:
     A data class holding the configurations.
     """
 
-    …
+    PROVIDER_HUGGING_FACE = 'hf'
+    PROVIDER_GOOGLE_GEMINI = 'gg'
+    VALID_PROVIDERS = {PROVIDER_HUGGING_FACE, PROVIDER_GOOGLE_GEMINI}
     VALID_MODELS = {
+        '[gg]gemini-1.5-flash-002': {
+            'description': 'faster response',
+            'max_new_tokens': 8192,
+            'paid': True,
+        },
         '[hf]mistralai/Mistral-7B-Instruct-v0.2': {
             'description': 'faster, shorter',
-            'max_new_tokens': 8192
+            'max_new_tokens': 8192,
+            'paid': False,
         },
         '[hf]mistralai/Mistral-Nemo-Instruct-2407': {
             'description': 'longer response',
-            'max_new_tokens': …
+            'max_new_tokens': 10240,
+            'paid': False,
         },
     }
+    LLM_PROVIDER_HELP = (
+        'LLM provider codes:\n\n'
+        '- **[gg]**: Google Gemini API\n'
+        '- **[hf]**: Hugging Face Inference Endpoint\n'
+    )
+    DEFAULT_MODEL_INDEX = 1
     LLM_MODEL_TEMPERATURE = 0.2
     LLM_MODEL_MIN_OUTPUT_LENGTH = 100
     LLM_MODEL_MAX_INPUT_LENGTH = 400  # characters
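The VALID_MODELS keys now carry the provider as a bracketed prefix ('[gg]' for Gemini, '[hf]' for Hugging Face), which the sidebar splits off for display and get_provider_model() resolves into a (provider, model) pair. A hedged sketch of that convention follows; the regex is an assumption, since get_provider_model()'s body is not part of this diff.

import re

def parse_provider_model(provider_model: str) -> tuple:
    # '[gg]gemini-1.5-flash-002' -> ('gg', 'gemini-1.5-flash-002');
    # keys that do not match the convention yield ('', '').
    match = re.match(r'\[(.+?)\](.+)', provider_model)
    return (match.group(1), match.group(2)) if match else ('', '')

assert parse_provider_model('[gg]gemini-1.5-flash-002') == ('gg', 'gemini-1.5-flash-002')
assert parse_provider_model('[hf]mistralai/Mistral-7B-Instruct-v0.2') == (
    'hf', 'mistralai/Mistral-7B-Instruct-v0.2'
)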
helpers/llm_helper.py
CHANGED
@@ -1,13 +1,18 @@
+"""
+Helper functions to access LLMs.
+"""
 import logging
 import re
+import sys
 from typing import Tuple, Union
 
 import requests
 from requests.adapters import HTTPAdapter
 from urllib3.util import Retry
-
 from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
-from langchain_core.language_models import …
+from langchain_core.language_models import BaseLLM
+
+sys.path.append('..')
 
 from global_config import GlobalConfig
 
@@ -49,30 +54,26 @@ def get_provider_model(provider_model: str) -> Tuple[str, str]:
     return '', ''
 
 
-def get_hf_endpoint(repo_id: str, max_new_tokens: int, api_key: str = ''):
+def is_valid_llm_provider_model(provider: str, model: str, api_key: str) -> bool:
     """
-    …
-    :param …
-    :…
+    Verify whether LLM settings are proper.
+    This function does not verify whether `api_key` is correct. It only confirms that the key has
+    at least five characters. Key verification is done when the LLM is created.
+
+    :param provider: Name of the LLM provider.
+    :param model: Name of the model.
+    :param api_key: The API key or access token.
+    :return: `True` if the settings "look" OK; `False` otherwise.
     """
 
-    …
+    if not provider or not model or provider not in GlobalConfig.VALID_PROVIDERS:
+        return False
+
+    if provider in [GlobalConfig.PROVIDER_GOOGLE_GEMINI, ]:
+        if not api_key or len(api_key) < 5:
+            return False
 
-    return HuggingFaceEndpoint(
-        repo_id=repo_id,
-        max_new_tokens=max_new_tokens,
-        top_k=40,
-        top_p=0.95,
-        temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
-        repetition_penalty=1.03,
-        streaming=True,
-        huggingfacehub_api_token=api_key or GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
-        return_full_text=False,
-        stop_sequences=['</s>'],
-    )
+    return True
 
 
 def get_langchain_llm(
@@ -80,22 +81,19 @@ def get_langchain_llm(
         model: str,
         max_new_tokens: int,
         api_key: str = ''
-) -> Union[…]:
+) -> Union[BaseLLM, None]:
     """
     Get an LLM based on the provider and model specified.
 
     :param provider: The LLM provider. Valid values are `hf` for Hugging Face.
-    :param model: …
-    :param max_new_tokens: …
-    :param api_key: …
-    :return: …
+    :param model: The name of the LLM.
+    :param max_new_tokens: The maximum number of tokens to generate.
+    :param api_key: API key or access token to use.
+    :return: An instance of the LLM or `None` in case of any error.
     """
-    if not provider or not model or provider not in GlobalConfig.VALID_PROVIDERS:
-        return None
 
-    if provider == 'hf':
+    if provider == GlobalConfig.PROVIDER_HUGGING_FACE:
         logger.debug('Getting LLM via HF endpoint: %s', model)
-
         return HuggingFaceEndpoint(
             repo_id=model,
             max_new_tokens=max_new_tokens,
@@ -109,6 +107,27 @@ def get_langchain_llm(
             stop_sequences=['</s>'],
         )
 
+    if provider == GlobalConfig.PROVIDER_GOOGLE_GEMINI:
+        from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
+        from langchain_google_genai import GoogleGenerativeAI
+
+        return GoogleGenerativeAI(
+            model=model,
+            temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
+            max_tokens=max_new_tokens,
+            timeout=None,
+            max_retries=2,
+            google_api_key=api_key,
+            safety_settings={
+                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT:
+                    HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT:
+                    HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
+            }
+        )
+
     return None
 
 
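Taken together, the helpers above give the following end-to-end path for the new provider. A minimal usage sketch, assuming the package layout of this repository; the API key literal is a placeholder:

from global_config import GlobalConfig
from helpers import llm_helper

KEY = '[gg]gemini-1.5-flash-002'
API_KEY = 'AIza...'  # placeholder: supply a real Gemini API key

provider, model = llm_helper.get_provider_model(KEY)  # ('gg', 'gemini-1.5-flash-002')

if llm_helper.is_valid_llm_provider_model(provider, model, API_KEY):
    llm = llm_helper.get_langchain_llm(
        provider=provider,
        model=model,
        max_new_tokens=GlobalConfig.VALID_MODELS[KEY]['max_new_tokens'],
        api_key=API_KEY,
    )
    if llm:
        # BaseLLM.stream() yields text chunks; the app concatenates them into one response
        response = ''.join(llm.stream('Outline a five-slide deck on solar energy'))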
requirements.txt
CHANGED
@@ -10,6 +10,7 @@ pydantic==2.9.1
 langchain~=0.3.7
 langchain-core~=0.3.0
 langchain-community==0.3.0
+langchain-google-genai==2.0.6
 streamlit~=1.38.0
 
 python-pptx
@@ -19,7 +20,6 @@ requests~=2.32.3
 
 transformers~=4.44.0
 torch==2.4.0
-langchain-community
 
 urllib3~=2.2.1
 lxml~=4.9.3
strings.json
CHANGED
@@ -33,5 +33,6 @@
     "Looks like you have a looming deadline. Can I help you get started with your slide deck?",
     "Hello! What topic do you have on your mind today?"
   ],
-  "chat_placeholder": "Write the topic or instructions here"
+  "chat_placeholder": "Write the topic or instructions here",
+  "like_feedback": "If you like SlideDeck AI, please consider leaving a heart ❤\uFE0F on the [Hugging Face Space](https://huggingface.co/spaces/barunsaha/slide-deck-ai/) or a star ⭐ on [GitHub](https://github.com/barun-saha/slide-deck-ai). Your [feedback](https://forms.gle/JECFBGhjvSj7moBx9) is appreciated."
 }