Spaces:
Running
Running
| from langchain_core.prompts import PromptTemplate | |
| import os | |
| from langchain_community.embeddings import HuggingFaceBgeEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.llms.ctransformers import CTransformers | |
| from langchain.chains.retrieval_qa.base import RetrievalQA | |
| import streamlit as st | |
| import fitz # PyMuPDF | |
| from PIL import Image | |
| import io | |
# Directory holding the persisted FAISS index that qa_bot() loads.
DB_FAISS_PATH = "vectorstores/"
#pdf_path = 'data/Harrisons_Internal_Medicine_2022,_21th_Edition_Vol_1_&_Vol_2_.pdf'

# Prompt text for the retrieval QA chain. The {context} and {question}
# placeholders are filled in by the PromptTemplate built from this string.
custom_prompt_template = (
    "use the following pieces of information to answer the user's questions.\n"
    "If you don't know the answer, please just say that don't know the answer, "
    "don't try to make up an answer.\n"
    "Context : {context}\n"
    "Question : {question}\n"
    "only return the helpful answer below and nothing else.\n"
)
def set_custom_prompt():
    """
    Build the PromptTemplate used for QA retrieval over the vector store.

    The template text is the module-level ``custom_prompt_template``; it is
    declared with the two input variables ``context`` and ``question``.
    """
    return PromptTemplate(
        input_variables=['context', 'question'],
        template=custom_prompt_template,
    )
def load_llm():
    """
    Create the CTransformers-backed Llama 2 chat model used to answer queries.

    Generation settings are supplied through ``config``: the LangChain
    ``CTransformers`` wrapper forwards only that dict to the underlying
    ctransformers model, and the token-limit option is spelled
    ``max_new_tokens``. Passing ``max_new_token`` / ``temperature`` as
    top-level keyword arguments does not configure generation.

    Returns:
        The configured ``CTransformers`` LLM instance.
    """
    generation_config = {
        'max_new_tokens': 512,
        'temperature': 0.5,
    }
    llm = CTransformers(
        # model='epfl-llm/meditron-7b',
        model='TheBloke/Llama-2-7B-Chat-GGML',
        model_type='llama',
        config=generation_config,
    )
    return llm
| # def load_embeddings(): | |
| # embeddings = HuggingFaceBgeEmbeddings(model_name='NeuML/pubmedbert-base-embeddings', | |
| # model_kwargs={'device': 'cpu'}) | |
| # return embeddings | |
| # def load_faiss_index(embeddings): | |
| # db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True) | |
| # return db | |
def retrieval_qa_chain(llm, prompt, db):
    """
    Assemble a 'stuff' RetrievalQA chain over the given vector store.

    Args:
        llm: Language model handed to the chain.
        prompt: Prompt passed to the chain via ``chain_type_kwargs``.
        db: Vector store; its ``as_retriever`` is called with ``k=2``.

    Returns:
        The RetrievalQA chain, configured to also return source documents.
    """
    top_two_retriever = db.as_retriever(search_kwargs={'k': 2})
    chain_options = {'prompt': prompt}
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=top_two_retriever,
        return_source_documents=True,
        chain_type_kwargs=chain_options,
    )
@st.cache_resource
def qa_bot():
    """
    Build the retrieval QA chain: embeddings, FAISS index, LLM and prompt.

    The result is cached with ``st.cache_resource`` so the embedding model,
    the FAISS index and the LLM are loaded once rather than on every query
    (Streamlit re-executes the script on each interaction, and
    ``final_result`` calls this function per question).

    Returns:
        The RetrievalQA chain produced by ``retrieval_qa_chain``.
    """
    # NOTE(review): the cached chain is shared by all sessions of the app;
    # confirm the underlying model tolerates concurrent use if that matters.
    # NOTE(review): a BGE embeddings wrapper is used with a MiniLM model —
    # confirm this matches how the index in DB_FAISS_PATH was built.
    embeddings = HuggingFaceBgeEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
    )
    # allow_dangerous_deserialization=True unpickles the stored index; only
    # safe because the index directory is produced by this project.
    db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    qa = retrieval_qa_chain(llm, qa_prompt, db)
    return qa
def final_result(query):
    """
    Run *query* through the chain returned by ``qa_bot()``.

    The chain is called with ``{'query': query}`` and its response is
    returned unchanged.
    """
    chain = qa_bot()
    return chain({'query': query})
def get_pdf_page_as_image(pdf_path, page_number):
    """
    Render one page of a PDF and return it as a PIL image.

    Args:
        pdf_path: Path of the PDF file to open with PyMuPDF.
        page_number: Page index passed to ``Document.load_page``.

    Returns:
        A ``PIL.Image.Image`` decoded from the rendered page bytes.
    """
    # Open the document in a context manager so its file handle is released
    # even if loading or rendering the page raises.
    with fitz.open(pdf_path) as document:
        page = document.load_page(page_number)
        pix = page.get_pixmap()
        image_bytes = pix.tobytes()
    # The encoded bytes are held in memory, so the image can be decoded
    # after the document has been closed.
    return Image.open(io.BytesIO(image_bytes))
| # # Initialize the bot | |
| # bot = qa_bot() | |
# Streamlit webpage title
st.title('Medical Chatbot')

# User input
user_query = st.text_input("Please enter your question:")

# Button to get answer
if st.button('Get Answer'):
    if user_query:
        # Run the question through the retrieval QA chain
        response = final_result(user_query)
        if response:
            # Displaying the response
            st.write("### Answer")
            st.write(response['result'])

            # Displaying source document details if available
            if 'source_documents' in response:
                st.write("### Source Document Information")
                for doc in response['source_documents']:
                    # Turn literal backslash-n sequences in the stored text
                    # into real line breaks for display
                    formatted_content = doc.page_content.replace("\\n", "\n")
                    st.write("#### Document Content")
                    st.text_area(label="Page Content", value=formatted_content, height=300)

                    # Retrieve source and page from metadata. Lookups are
                    # tolerant so a document without these fields does not
                    # abort rendering after the answer was already shown.
                    metadata = doc.metadata or {}
                    source = metadata.get('source', 'Unknown')
                    page = metadata.get('page')
                    # Stored page indices are shown 1-based; anything that is
                    # not an integer is reported as unknown instead of raising.
                    page_label = page + 1 if isinstance(page, int) else 'Unknown'
                    st.write(f"Source: {source}")
                    st.write(f"Page Number: {page_label}")

                    # Display the PDF page as an image (disabled: needs pdf_path)
                    # pdf_page_image = get_pdf_page_as_image(pdf_path, page)
                    # st.image(pdf_page_image, caption=f"Page {page+1} from {source}")
        else:
            st.write("Sorry, I couldn't find an answer to your question.")
    else:
        st.write("Please enter a question to get an answer.")