Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| from pymilvus import MilvusClient | |
| from model import encode_dpr_question, get_dpr_encoder | |
| from model import summarize_text, get_summarizer | |
| from model import ask_reader, get_reader | |
# Page title, reused for the browser tab and the on-page header.
TITLE = 'ReSRer: Retriever-Summarizer-Reader'
# Question pre-filled in the input box on first load.
INITIAL = "What is the population of NYC"

# Page configuration is issued before any other Streamlit output.
st.set_page_config(page_title=TITLE)
st.header(TITLE)
# Short usage hint shown under the header (grammar of the original text fixed).
st.markdown('''
### Ask a short-answer question that can be found in Wikipedia data.
''', unsafe_allow_html=True)
def load_models():
    """Build the model bundle used by the question-answering pipeline.

    Returns:
        A dict with the keys ``'encoder'``, ``'summarizer'`` and ``'reader'``,
        holding whatever ``get_dpr_encoder()``, ``get_summarizer()`` and
        ``get_reader()`` return. ``main`` indexes each value with ``[0]`` and
        ``[1]``, so each is expected to be an indexable pair.
    """
    # NOTE(review): Streamlit re-runs the script on every interaction; confirm
    # the get_* helpers cache their results, otherwise models reload each run.
    # Dict-literal entries are evaluated top to bottom, so the loaders run in
    # the same order as before: encoder, summarizer, reader.
    return {
        'encoder': get_dpr_encoder(),
        'summarizer': get_summarizer(),
        'reader': get_reader(),
    }
def load_client():
    """Create the Milvus client used for passage retrieval.

    Connects as user ``resrer`` to database ``psgs_w100`` on port 19530 of the
    host named by the ``MILVUS_HOST`` environment variable, authenticating
    with the password in ``MILVUS_PW``.

    Returns:
        A connected ``MilvusClient`` instance.

    Raises:
        KeyError: If ``MILVUS_PW`` or ``MILVUS_HOST`` is not set.
    """
    # The environment mapping is ``os.environ``; ``os.env`` does not exist and
    # raised AttributeError before any connection was attempted.
    password = os.environ['MILVUS_PW']
    host = os.environ['MILVUS_HOST']
    return MilvusClient(
        user='resrer',
        password=password,
        uri=f"http://{host}:19530",
        db_name='psgs_w100',
    )
# Shared resources used by main(): the Milvus client first, then the models.
client = load_client()
models = load_models()

# CSS override: pins elements with class ``StatusWidget-enter-done`` to the
# centre of the viewport and hides any button inside them.
# NOTE(review): presumably this targets Streamlit's "running" status widget;
# the class name is a Streamlit internal, so verify it after upgrades.
styl = """
<style>
.StatusWidget-enter-done{
position: fixed;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
}
.StatusWidget-enter-done button{
display: none;
}
</style>
"""
st.markdown(styl, unsafe_allow_html=True)

# Question input box, pre-filled with the sample question.
question = st.text_area(label="Text to summarize", value=INITIAL, height=400)
def main(question: str) -> None:
    """Answer ``question`` and render the answer, summary and context.

    On a cache miss the question is embedded with the DPR encoder, the ten
    nearest passages are fetched from the ``dpr_nq`` Milvus collection, the
    joined passages are summarized, and the reader extracts an answer from
    them. The ``(ctx, summary, answer)`` triple is stored in
    ``st.session_state`` under the question text, so repeating a question in
    the same session skips the models and the search.

    Args:
        question: The question text typed by the user.
    """
    if question in st.session_state:
        print("Cache hit!")
        ctx, summary, answer = st.session_state[question]
    else:
        print(f"Input: {question}")

        # Embedding: encode the single question and take its vector as a list.
        encoder = models['encoder']
        question_vectors = encode_dpr_question(
            encoder[0], encoder[1], [question])
        query_vector = question_vectors.detach().cpu().numpy().tolist()[0]

        # Retriever: top-10 passages for the one query vector.
        results = client.search(
            collection_name='dpr_nq',
            data=[query_vector],
            limit=10,
            output_fields=['title', 'text'],
        )
        texts = [result['entity']['text'] for result in results[0]]
        ctx = '\n'.join(texts)

        # Summarizer: condense the retrieved context. The original passed
        # ``[summary]`` here, reading the variable before it was assigned and
        # raising UnboundLocalError on every cache miss.
        # NOTE(review): a one-element batch is passed in; if summarize_text
        # returns a list, the markup below shows its repr - confirm.
        summarizer = models['summarizer']
        summary = summarize_text(summarizer[0], summarizer[1], [ctx])

        # Reader: extract the answer from the full retrieved context.
        reader = models['reader']
        answers = ask_reader(reader[0], reader[1], [question], [ctx])
        answer = answers[0]['answer']
        print(f"\nAnswer: {answer}")

        st.session_state[question] = (ctx, summary, answer)

    # Answer
    st.markdown(answer)

    # Summary
    st.write("## Summary")
    st.markdown(
        f"<h6 style='padding: 0'>{summary}</h6><hr style='margin: 1em 0px'>",
        unsafe_allow_html=True)

    # Retrieved context, then the question itself.
    st.markdown(ctx)
    # NOTE(review): the user's text is rendered with unsafe_allow_html=True,
    # so HTML typed into the question box is rendered as markup - confirm this
    # is intended.
    st.write(f"{question}", unsafe_allow_html=True)
# Run the pipeline only when the input box holds a non-empty question.
if question:
    main(question)