import json
import os

import sqlalchemy
import sqlite_vss
import streamlit as st
import streamlit.components.v1 as components
from langchain import OpenAI
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.embeddings import GPT4AllEmbeddings
from sqlalchemy import event

from chat_history import insert_chat_history, insert_chat_history_articles
from css import load_css
from custom_pgvector import CustomPGVector
from message import Message
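
# Local SQLite database backing the article retriever.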
CONNECTION_STRING = "sqlite:///data/sorbobot.db"

st.set_page_config(layout="wide")
st.title("Sorbobot - Le futur de la recherche scientifique interactive")

chat_column, doc_column = st.columns([2, 1])
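

# Open a database connection with the sqlite-vss extension loaded so that
# vector-similarity queries are available to the retriever.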
def connect() -> sqlalchemy.engine.Connection:
    engine = sqlalchemy.create_engine(CONNECTION_STRING)

    # Load the sqlite-vss extension on every new DBAPI connection.
    @event.listens_for(engine, "connect")
    def receive_connect(connection, _):
        connection.enable_load_extension(True)
        sqlite_vss.load(connection)
        connection.enable_load_extension(False)

    conn = engine.connect()
    return conn


conn = connect()
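

# Set up per-session state: chat history, token counter and the LangChain
# conversational retrieval chain (built once per browser session).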
def initialize_session_state():
    if "history" not in st.session_state:
        st.session_state.history = []
    if "token_count" not in st.session_state:
        st.session_state.token_count = 0
    if "conversation" not in st.session_state:
        embeddings = GPT4AllEmbeddings()

        db = CustomPGVector(
            embedding_function=embeddings,
            table_name="article",
            column_name="abstract_embedding",
            connection=conn,
        )
        retriever = db.as_retriever()

        llm = OpenAI(
            temperature=0,
            openai_api_key=os.environ["OPENAI_API_KEY"],
            model="text-davinci-003",
        )

        memory = ConversationBufferMemory(
            output_key="answer", memory_key="chat_history", return_messages=True
        )

        st.session_state.conversation = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            verbose=True,
            memory=memory,
            return_source_documents=True,
        )
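

# Submit-button callback: run the prompt through the chain, store both sides
# of the exchange in the history and track OpenAI token usage.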
def on_click_callback():
    with get_openai_callback() as cb:
        human_prompt = st.session_state.human_prompt
        llm_response = st.session_state.conversation(human_prompt)
        st.session_state.history.append(Message("human", human_prompt))
        st.session_state.history.append(
            Message(
                "ai", llm_response["answer"], documents=llm_response["source_documents"]
            )
        )
        st.session_state.token_count += cb.total_tokens
        # history_id = insert_chat_history(conn, human_prompt, llm_response["answer"])
        # insert_chat_history_articles(conn, history_id, llm_response["source_documents"])


load_css()
initialize_session_state()
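
# Left-hand column: chat transcript, prompt form and token/debug information.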
with chat_column:
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form")
    information_placeholder = st.empty()

    with chat_placeholder:
        for chat in st.session_state.history:
            div = f"""
<div class="chat-row
    {'' if chat.origin == 'ai' else 'row-reverse'}">
    <img class="chat-icon" src="./app/static/{
        'ai_icon.png' if chat.origin == 'ai'
        else 'user_icon.png'}"
        width=32 height=32>
    <div class="chat-bubble
    {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
        &#8203;{chat.message}
    </div>
</div>
"""
            st.markdown(div, unsafe_allow_html=True)

        for _ in range(3):
            st.markdown("")
    with prompt_placeholder:
        st.markdown("**Chat**")
        cols = st.columns((6, 1))
        cols[0].text_input(
            "Chat",
            value="Hello bot",
            label_visibility="collapsed",
            key="human_prompt",
        )
        cols[1].form_submit_button(
            "Submit",
            type="primary",
            on_click=on_click_callback,
        )
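
    # Show cumulative token usage and the raw LangChain conversation buffer
    # for debugging.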
    information_placeholder.caption(
        f"""
        Used {st.session_state.token_count} tokens \n
        Debug Langchain conversation:
        {st.session_state.conversation.memory.buffer}
        """
    )
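
    # Small JavaScript shim: pressing Enter anywhere on the page clicks the
    # form's Submit button.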
    components.html(
        """
        <script>
        const streamlitDoc = window.parent.document;

        const buttons = Array.from(
            streamlitDoc.querySelectorAll('.stButton > button')
        );
        const submitButton = buttons.find(
            el => el.innerText === 'Submit'
        );

        streamlitDoc.addEventListener('keydown', function(e) {
            switch (e.key) {
                case 'Enter':
                    submitButton.click();
                    break;
            }
        });
        </script>
        """,
        height=0,
        width=0,
    )
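
# Right-hand column: source documents retrieved for the most recent answer,
# one expander per article.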
with doc_column:
    if len(st.session_state.history) > 0:
        st.markdown("**Source documents**")
        for doc in st.session_state.history[-1].documents:
            doc_content = json.loads(doc.page_content)

            expander = st.expander(doc_content["title"])
            expander.markdown(f"**DOI : {doc_content['doi']}**")
            expander.markdown(doc_content["abstract"])
            expander.markdown(f"**Authors** : {doc_content['authors']}")
            expander.markdown(f"**Keywords** : {doc_content['keywords']}")
            expander.markdown(f"**Distance** : {doc_content['distance']}")