Spaces:
Sleeping
Sleeping
Use LangChain to get streaming response from the LLM; update progress bar to display the current status
Browse files
app.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import datetime
|
| 2 |
import logging
|
| 3 |
import pathlib
|
|
@@ -7,9 +10,7 @@ from typing import List
|
|
| 7 |
|
| 8 |
import json5
|
| 9 |
import streamlit as st
|
| 10 |
-
from langchain_community.chat_message_histories import
|
| 11 |
-
StreamlitChatMessageHistory
|
| 12 |
-
)
|
| 13 |
from langchain_core.messages import HumanMessage
|
| 14 |
from langchain_core.prompts import ChatPromptTemplate
|
| 15 |
|
|
@@ -47,17 +48,9 @@ def _get_prompt_template(is_refinement: bool) -> str:
|
|
| 47 |
return template
|
| 48 |
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
# Get Mistral tokenizer for counting tokens.
|
| 54 |
-
#
|
| 55 |
-
# :return: The tokenizer.
|
| 56 |
-
# """
|
| 57 |
-
#
|
| 58 |
-
# return AutoTokenizer.from_pretrained(
|
| 59 |
-
# pretrained_model_name_or_path=GlobalConfig.HF_LLM_MODEL_NAME
|
| 60 |
-
# )
|
| 61 |
|
| 62 |
|
| 63 |
APP_TEXT = _load_strings()
|
|
@@ -66,9 +59,10 @@ APP_TEXT = _load_strings()
|
|
| 66 |
CHAT_MESSAGES = 'chat_messages'
|
| 67 |
DOWNLOAD_FILE_KEY = 'download_file_name'
|
| 68 |
IS_IT_REFINEMENT = 'is_it_refinement'
|
|
|
|
|
|
|
| 69 |
|
| 70 |
logger = logging.getLogger(__name__)
|
| 71 |
-
progress_bar = st.progress(0, text='Setting up SlideDeck AI...')
|
| 72 |
|
| 73 |
texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
|
| 74 |
captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
|
|
@@ -110,7 +104,6 @@ def build_ui():
|
|
| 110 |
with st.expander('Usage Policies and Limitations'):
|
| 111 |
display_page_footer_content()
|
| 112 |
|
| 113 |
-
progress_bar.progress(50, text='Setting up chat interface...')
|
| 114 |
set_up_chat_ui()
|
| 115 |
|
| 116 |
|
|
@@ -131,8 +124,6 @@ def set_up_chat_ui():
|
|
| 131 |
st.chat_message('ai').write(
|
| 132 |
random.choice(APP_TEXT['ai_greetings'])
|
| 133 |
)
|
| 134 |
-
progress_bar.progress(100, text='Done!')
|
| 135 |
-
progress_bar.empty()
|
| 136 |
|
| 137 |
history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
|
| 138 |
|
|
@@ -188,66 +179,51 @@ def set_up_chat_ui():
|
|
| 188 |
}
|
| 189 |
)
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
expanded=False
|
| 194 |
-
) as status:
|
| 195 |
-
response: dict = llm_helper.hf_api_query({
|
| 196 |
-
'inputs': formatted_template,
|
| 197 |
-
'parameters': {
|
| 198 |
-
'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
|
| 199 |
-
'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
|
| 200 |
-
'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
|
| 201 |
-
'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
|
| 202 |
-
'num_return_sequences': 1,
|
| 203 |
-
'return_full_text': False,
|
| 204 |
-
# "repetition_penalty": 0.0001
|
| 205 |
-
},
|
| 206 |
-
'options': {
|
| 207 |
-
'wait_for_model': True,
|
| 208 |
-
'use_cache': True
|
| 209 |
-
}
|
| 210 |
-
})
|
| 211 |
|
| 212 |
-
|
| 213 |
-
|
| 214 |
|
| 215 |
-
|
|
|
|
|
|
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
expanded=False
|
| 236 |
-
)
|
| 237 |
-
generate_slide_deck(response_cleaned)
|
| 238 |
-
status.update(label='Done!', state='complete', expanded=True)
|
| 239 |
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
|
| 246 |
-
def generate_slide_deck(json_str: str):
|
| 247 |
"""
|
| 248 |
-
Create a slide deck.
|
|
|
|
| 249 |
|
| 250 |
:param json_str: The content in *valid* JSON format.
|
|
|
|
| 251 |
"""
|
| 252 |
|
| 253 |
if DOWNLOAD_FILE_KEY in st.session_state:
|
|
@@ -269,17 +245,6 @@ def generate_slide_deck(json_str: str):
|
|
| 269 |
output_file_path=path
|
| 270 |
)
|
| 271 |
except ValueError:
|
| 272 |
-
# st.error(
|
| 273 |
-
# f"{APP_TEXT['json_parsing_error']}"
|
| 274 |
-
# f"\n\nAdditional error info: {ve}"
|
| 275 |
-
# f"\n\nHere are some sample instructions that you could try to possibly fix this error;"
|
| 276 |
-
# f" if these don't work, try rephrasing or refreshing:"
|
| 277 |
-
# f"\n\n"
|
| 278 |
-
# "- Regenerate content and fix the JSON error."
|
| 279 |
-
# "\n- Regenerate content and fix the JSON error. Quotes inside quotes should be escaped."
|
| 280 |
-
# )
|
| 281 |
-
# logger.error('%s', APP_TEXT['json_parsing_error'])
|
| 282 |
-
# logger.error('Additional error info: %s', str(ve))
|
| 283 |
st.error(
|
| 284 |
'Encountered error while parsing JSON...will fix it and retry'
|
| 285 |
)
|
|
@@ -295,8 +260,8 @@ def generate_slide_deck(json_str: str):
|
|
| 295 |
except Exception as ex:
|
| 296 |
st.error(APP_TEXT['content_generation_error'])
|
| 297 |
logger.error('Caught a generic exception: %s', str(ex))
|
| 298 |
-
|
| 299 |
-
|
| 300 |
|
| 301 |
|
| 302 |
def _is_it_refinement() -> bool:
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Streamlit app containing the UI and the application logic.
|
| 3 |
+
"""
|
| 4 |
import datetime
|
| 5 |
import logging
|
| 6 |
import pathlib
|
|
|
|
| 10 |
|
| 11 |
import json5
|
| 12 |
import streamlit as st
|
| 13 |
+
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
|
|
|
|
|
|
|
| 14 |
from langchain_core.messages import HumanMessage
|
| 15 |
from langchain_core.prompts import ChatPromptTemplate
|
| 16 |
|
|
|
|
| 48 |
return template
|
| 49 |
|
| 50 |
|
| 51 |
+
@st.cache_resource
|
| 52 |
+
def _get_llm():
|
| 53 |
+
return llm_helper.get_hf_endpoint()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
APP_TEXT = _load_strings()
|
|
|
|
| 59 |
CHAT_MESSAGES = 'chat_messages'
|
| 60 |
DOWNLOAD_FILE_KEY = 'download_file_name'
|
| 61 |
IS_IT_REFINEMENT = 'is_it_refinement'
|
| 62 |
+
APPROX_TARGET_LENGTH = GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH / 2
|
| 63 |
+
|
| 64 |
|
| 65 |
logger = logging.getLogger(__name__)
|
|
|
|
| 66 |
|
| 67 |
texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
|
| 68 |
captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
|
|
|
|
| 104 |
with st.expander('Usage Policies and Limitations'):
|
| 105 |
display_page_footer_content()
|
| 106 |
|
|
|
|
| 107 |
set_up_chat_ui()
|
| 108 |
|
| 109 |
|
|
|
|
| 124 |
st.chat_message('ai').write(
|
| 125 |
random.choice(APP_TEXT['ai_greetings'])
|
| 126 |
)
|
|
|
|
|
|
|
| 127 |
|
| 128 |
history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
|
| 129 |
|
|
|
|
| 179 |
}
|
| 180 |
)
|
| 181 |
|
| 182 |
+
progress_bar = st.progress(0, 'Preparing to call LLM...')
|
| 183 |
+
response = ''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
+
for chunk in _get_llm().stream(formatted_template):
|
| 186 |
+
response += chunk
|
| 187 |
|
| 188 |
+
# Update the progress bar
|
| 189 |
+
progress_percentage = min(len(response) / APPROX_TARGET_LENGTH, 0.95)
|
| 190 |
+
progress_bar.progress(progress_percentage, text='Streaming content...')
|
| 191 |
|
| 192 |
+
history.add_user_message(prompt)
|
| 193 |
+
history.add_ai_message(response)
|
| 194 |
|
| 195 |
+
# The content has been generated as JSON
|
| 196 |
+
# There maybe trailing ``` at the end of the response -- remove them
|
| 197 |
+
# To be careful: ``` may be part of the content as well when code is generated
|
| 198 |
+
response_cleaned = text_helper.get_clean_json(response)
|
| 199 |
|
| 200 |
+
logger.info(
|
| 201 |
+
'Cleaned JSON response:: original length: %d | cleaned length: %d',
|
| 202 |
+
len(response), len(response_cleaned)
|
| 203 |
+
)
|
| 204 |
+
logger.debug('Cleaned JSON: %s', response_cleaned)
|
| 205 |
|
| 206 |
+
# Now create the PPT file
|
| 207 |
+
progress_bar.progress(0.95, text='Searching photos and generating the slide deck...')
|
| 208 |
+
path = generate_slide_deck(response_cleaned)
|
| 209 |
+
progress_bar.progress(1.0, text='Done!')
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
+
st.chat_message('ai').code(response, language='json')
|
| 212 |
+
_display_download_button(path)
|
| 213 |
+
|
| 214 |
+
logger.info(
|
| 215 |
+
'#messages in history / 2: %d',
|
| 216 |
+
len(st.session_state[CHAT_MESSAGES]) / 2
|
| 217 |
+
)
|
| 218 |
|
| 219 |
|
| 220 |
+
def generate_slide_deck(json_str: str) -> pathlib.Path:
|
| 221 |
"""
|
| 222 |
+
Create a slide deck and return the file path. In case there is any error creating the slide
|
| 223 |
+
deck, the path may be to an empty file.
|
| 224 |
|
| 225 |
:param json_str: The content in *valid* JSON format.
|
| 226 |
+
:return: The file of the .pptx file.
|
| 227 |
"""
|
| 228 |
|
| 229 |
if DOWNLOAD_FILE_KEY in st.session_state:
|
|
|
|
| 245 |
output_file_path=path
|
| 246 |
)
|
| 247 |
except ValueError:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
st.error(
|
| 249 |
'Encountered error while parsing JSON...will fix it and retry'
|
| 250 |
)
|
|
|
|
| 260 |
except Exception as ex:
|
| 261 |
st.error(APP_TEXT['content_generation_error'])
|
| 262 |
logger.error('Caught a generic exception: %s', str(ex))
|
| 263 |
+
|
| 264 |
+
return path
|
| 265 |
|
| 266 |
|
| 267 |
def _is_it_refinement() -> bool:
|