Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| from pymilvus import MilvusClient | |
| import torch | |
| from model import encode_dpr_question, get_dpr_encoder | |
| from model import summarize_text, get_summarizer | |
| from model import ask_reader, get_reader | |
# Page identity and the example query pre-filled in the input box.
TITLE = 'ReSRer: Retriever-Summarizer-Reader'
INITIAL = "What is the population of NYC"

st.set_page_config(page_title=TITLE)
st.header(TITLE)
# Fixed grammar in the user-facing prompt ("question that can be find" ->
# "questions that can be found").
st.markdown('''
<h5>Ask short-answer questions that can be found in Wikipedia data.</h5>
''', unsafe_allow_html=True)
st.markdown(
    'This demo searches through 21,000,000 Wikipedia passages in real-time under the hood.')
@st.cache_resource
def load_models():
    """Load and cache the DPR encoder, summarizer, and reader model pairs.

    Streamlit re-executes the whole script on every user interaction;
    without ``st.cache_resource`` the heavy model weights would be
    reloaded from disk on each rerun.

    Returns:
        dict: keys 'encoder', 'summarizer', 'reader', each holding the
        pair returned by the corresponding ``get_*`` factory (indexed
        as ``[0]``/``[1]`` by the pipeline below).
    """
    models = {}
    models['encoder'] = get_dpr_encoder()
    models['summarizer'] = get_summarizer()
    models['reader'] = get_reader()
    return models
@st.cache_resource
def load_client():
    """Create and cache a Milvus client for the Wikipedia passage DB.

    Credentials and host come from the ``MILVUS_PW`` / ``MILVUS_HOST``
    environment variables (raises KeyError if either is unset).
    Cached with ``st.cache_resource`` so the connection is reused
    across Streamlit reruns instead of being re-opened every time.
    """
    client = MilvusClient(user='resrer', password=os.environ['MILVUS_PW'],
                          uri=f"http://{os.environ['MILVUS_HOST']}:19530",
                          db_name='psgs_w100')
    return client
# Shared resources, created once at script start.
client = load_client()
models = load_models()

# CSS tweak: pin Streamlit's running-status widget to the center of the
# viewport and hide its stop button.
_STATUS_CSS = """
<style>
.StatusWidget-enter-done{
position: fixed;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
}
.StatusWidget-enter-done button{
display: none;
}
</style>
"""
st.markdown(_STATUS_CSS, unsafe_allow_html=True)
# Free-form question box, pre-filled with an example query.
question = st.text_input("Question", INITIAL)

# One-click example questions; pressing a button overrides the typed
# question for this run.
_EXAMPLES = (
    "What is the capital of South Korea",
    "What is the most famous building in Paris",
    "Who is the actor of Harry Potter",
)
for _col, _example in zip(st.columns(3), _EXAMPLES):
    if _col.button(_example):
        question = _example
def main(question: str):
    """Answer *question* via retrieve -> summarize -> read and render it.

    Results are memoized in ``st.session_state`` keyed by the question
    text, so asking the same question again skips the whole pipeline.
    """
    if question in st.session_state:
        # Seen this question before: reuse the stored triple.
        print("Cache hit!")
        ctx, summary, answer = st.session_state[question]
    else:
        print(f"Input: {question}")

        # Embed the question into a dense vector with the DPR encoder.
        vectors = encode_dpr_question(
            models['encoder'][0], models['encoder'][1], [question])
        query_vector = vectors.detach().cpu().numpy().tolist()[0]

        # Retrieve the 10 nearest passages from Milvus and join their texts.
        hits = client.search(collection_name='dpr_nq', data=[query_vector],
                             limit=10, output_fields=['title', 'text'])
        ctx = '\n'.join(hit['entity']['text'] for hit in hits[0])

        # Condense the retrieved context, then extract the answer span.
        [summary] = summarize_text(
            models['summarizer'][0], models['summarizer'][1], [ctx])
        answer = ask_reader(models['reader'][0], models['reader'][1],
                            [question], [summary])[0]['answer']
        print(f"\nAnswer: {answer}")

        # Memoize for subsequent identical questions.
        st.session_state[question] = (ctx, summary, answer)

    # Render: answer first, then the summarized and original contexts.
    st.write(f"### Answer: {answer}")
    st.markdown('<h5>Summarized Context</h5>', unsafe_allow_html=True)
    st.markdown(
        f"<h6 style='padding: 0'>{summary}</h6><hr style='margin: 1em 0px'>", unsafe_allow_html=True)
    st.markdown('<h5>Original Context</h5>', unsafe_allow_html=True)
    st.markdown(ctx)


if question:
    main(question)