Spaces:
Runtime error
Runtime error
Commit
·
e09fe1d
1
Parent(s):
3842297
final tiny changes
Browse files- app.py +11 -10
- backend_utils.py +3 -12
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
from backend_utils import (get_plain_pipeline, get_retrieval_augmented_pipeline,
|
| 3 |
-
get_web_retrieval_augmented_pipeline, set_q1, set_q2, set_q3, set_q4, set_q5, QUERIES
|
|
|
|
| 4 |
|
| 5 |
st.set_page_config(
|
| 6 |
page_title="Retrieval Augmentation with Haystack",
|
|
@@ -51,39 +52,39 @@ st.radio("Answer Type:", ("Retrieval Augmented (Static news dataset)", "Retrieva
|
|
| 51 |
# QUERIES,
|
| 52 |
# key='q_drop_down', on_change=set_question)
|
| 53 |
|
| 54 |
-
st.markdown("<h5>
|
| 55 |
placeholder_plain_gpt = st.empty()
|
| 56 |
st.text(" ")
|
| 57 |
st.text(" ")
|
| 58 |
if st.session_state.get("query_type", "Retrieval Augmented (Static news dataset)") == "Retrieval Augmented (Static news dataset)":
|
| 59 |
-
st.markdown("<h5>
|
| 60 |
else:
|
| 61 |
-
st.markdown("<h5>
|
| 62 |
placeholder_retrieval_augmented = st.empty()
|
| 63 |
|
| 64 |
if st.session_state.get('query') and run_pressed:
|
| 65 |
ip = st.session_state['query']
|
| 66 |
with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
|
| 67 |
p1 = get_plain_pipeline()
|
| 68 |
-
with st.spinner('Fetching answers from GPT
|
| 69 |
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
| 70 |
answers = p1.run(ip)
|
| 71 |
placeholder_plain_gpt.markdown(answers['results'][0])
|
| 72 |
|
| 73 |
if st.session_state.get("query_type", "Retrieval Augmented") == "Retrieval Augmented":
|
| 74 |
with st.spinner(
|
| 75 |
-
'Loading Retrieval Augmented pipeline...
|
| 76 |
-
n This may take a few mins and might also fail if OpenAI API server is down.'):
|
| 77 |
p2 = get_retrieval_augmented_pipeline()
|
| 78 |
-
with st.spinner('
|
| 79 |
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
| 80 |
answers_2 = p2.run(ip)
|
| 81 |
else:
|
| 82 |
with st.spinner(
|
| 83 |
-
'Loading Retrieval Augmented pipeline... \
|
| 84 |
n This may take a few mins and might also fail if OpenAI API server is down.'):
|
| 85 |
p3 = get_web_retrieval_augmented_pipeline()
|
| 86 |
-
with st.spinner('
|
| 87 |
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
| 88 |
answers_2 = p3.run(ip)
|
| 89 |
placeholder_retrieval_augmented.markdown(answers_2['results'][0])
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from backend_utils import (get_plain_pipeline, get_retrieval_augmented_pipeline,
|
| 3 |
+
get_web_retrieval_augmented_pipeline, set_q1, set_q2, set_q3, set_q4, set_q5, QUERIES,
|
| 4 |
+
PLAIN_GPT_ANS, GPT_WEB_RET_AUG_ANS, GPT_LOCAL_RET_AUG_ANS)
|
| 5 |
|
| 6 |
st.set_page_config(
|
| 7 |
page_title="Retrieval Augmentation with Haystack",
|
|
|
|
| 52 |
# QUERIES,
|
| 53 |
# key='q_drop_down', on_change=set_question)
|
| 54 |
|
| 55 |
+
st.markdown(f"<h5> {PLAIN_GPT_ANS} </h5>", unsafe_allow_html=True)
|
| 56 |
placeholder_plain_gpt = st.empty()
|
| 57 |
st.text(" ")
|
| 58 |
st.text(" ")
|
| 59 |
if st.session_state.get("query_type", "Retrieval Augmented (Static news dataset)") == "Retrieval Augmented (Static news dataset)":
|
| 60 |
+
st.markdown(f"<h5> {GPT_LOCAL_RET_AUG_ANS} </h5>", unsafe_allow_html=True)
|
| 61 |
else:
|
| 62 |
+
st.markdown(f"<h5>{GPT_WEB_RET_AUG_ANS} </h5>", unsafe_allow_html=True)
|
| 63 |
placeholder_retrieval_augmented = st.empty()
|
| 64 |
|
| 65 |
if st.session_state.get('query') and run_pressed:
|
| 66 |
ip = st.session_state['query']
|
| 67 |
with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
|
| 68 |
p1 = get_plain_pipeline()
|
| 69 |
+
with st.spinner('Fetching answers from plain GPT... '
|
| 70 |
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
| 71 |
answers = p1.run(ip)
|
| 72 |
placeholder_plain_gpt.markdown(answers['results'][0])
|
| 73 |
|
| 74 |
if st.session_state.get("query_type", "Retrieval Augmented") == "Retrieval Augmented":
|
| 75 |
with st.spinner(
|
| 76 |
+
'Loading Retrieval Augmented pipeline that can fetch relevant documents from local data store... '
|
| 77 |
+
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
| 78 |
p2 = get_retrieval_augmented_pipeline()
|
| 79 |
+
with st.spinner('Getting relevant documents from documented stores and calculating answers... '
|
| 80 |
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
| 81 |
answers_2 = p2.run(ip)
|
| 82 |
else:
|
| 83 |
with st.spinner(
|
| 84 |
+
'Loading Retrieval Augmented pipeline that can fetch relevant documents from the web... \
|
| 85 |
n This may take a few mins and might also fail if OpenAI API server is down.'):
|
| 86 |
p3 = get_web_retrieval_augmented_pipeline()
|
| 87 |
+
with st.spinner('Getting relevant documents from the Web and calculating answers... '
|
| 88 |
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
| 89 |
answers_2 = p3.run(ip)
|
| 90 |
placeholder_retrieval_augmented.markdown(answers_2['results'][0])
|
backend_utils.py
CHANGED
|
@@ -12,6 +12,9 @@ QUERIES = [
|
|
| 12 |
"Who is responsible for SVC collapse?",
|
| 13 |
"When did SVB collapse?"
|
| 14 |
]
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
@st.cache_resource(show_spinner=False)
|
|
@@ -76,18 +79,6 @@ def get_web_retrieval_augmented_pipeline():
|
|
| 76 |
return pipeline
|
| 77 |
|
| 78 |
|
| 79 |
-
# @st.cache_resource(show_spinner=False)
|
| 80 |
-
# def app_init():
|
| 81 |
-
# print("Loading Pipelines...")
|
| 82 |
-
# p1 = get_plain_pipeline()
|
| 83 |
-
# print("Loaded Plain Pipeline")
|
| 84 |
-
# p2 = get_retrieval_augmented_pipeline()
|
| 85 |
-
# print("Loaded Retrieval Augmented Pipeline")
|
| 86 |
-
# p3 = get_web_retrieval_augmented_pipeline()
|
| 87 |
-
# print("Loaded Web Retrieval Augmented Pipeline")
|
| 88 |
-
# return p1, p2, p3
|
| 89 |
-
|
| 90 |
-
|
| 91 |
if 'query' not in st.session_state:
|
| 92 |
st.session_state['query'] = ""
|
| 93 |
|
|
|
|
| 12 |
"Who is responsible for SVC collapse?",
|
| 13 |
"When did SVB collapse?"
|
| 14 |
]
|
| 15 |
+
PLAIN_GPT_ANS = "Answer with plain GPT"
|
| 16 |
+
GPT_LOCAL_RET_AUG_ANS = "Answer with Retrieval Augmented GPT (Static news dataset)"
|
| 17 |
+
GPT_WEB_RET_AUG_ANS = "Answer with Retrieval Augmented GPT (Web Search)"
|
| 18 |
|
| 19 |
|
| 20 |
@st.cache_resource(show_spinner=False)
|
|
|
|
| 79 |
return pipeline
|
| 80 |
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
if 'query' not in st.session_state:
|
| 83 |
st.session_state['query'] = ""
|
| 84 |
|