import streamlit as st

st.title("Medical RAG and Reasoning App")
st.write("This app demonstrates Retrieval-Augmented Generation (RAG) for medical question answering.")

from datasets import load_dataset

# Load the ChatDoctor-HealthCareMagic patient-doctor Q&A dataset
# that will serve as the retrieval knowledge base.
dataset = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer

# Load the HuatuoGPT-o1 medical reasoning model and its tokenizer.
model_name = "FreedomIntelligence/HuatuoGPT-o1-7B"
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load a lightweight sentence-embedding model for retrieval.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
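
# Note: Streamlit re-executes this script on every user interaction, so the
# model, tokenizer, and embedder above are reloaded each run. In a real
# deployment the heavyweight loads would usually be wrapped in a cached
# helper; a minimal sketch using Streamlit's st.cache_resource:
#
#     @st.cache_resource
#     def load_models():
#         tok = AutoTokenizer.from_pretrained(model_name)
#         llm = AutoModelForCausalLM.from_pretrained(
#             model_name, torch_dtype="auto", device_map="auto"
#         )
#         return tok, llm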

import pandas as pd
import numpy as np

# Build the knowledge base: one row per patient question ("input") and
# doctor answer ("output"), combined into a single retrievable document.
df = pd.DataFrame(dataset["train"])
df["combined"] = df["input"] + " " + df["output"]

# Embed every document in the knowledge base.
st.write("Generating embeddings for the knowledge base...")
embeddings = embed_model.encode(
    df["combined"].tolist(), show_progress_bar=True, batch_size=128
)
st.write("Embeddings generated!")

from sklearn.metrics.pairwise import cosine_similarity


def retrieve_relevant_contexts(query: str, k: int = 3) -> list:
    """
    Retrieve the k most relevant contexts for a given query.

    Args:
        query (str): The user's medical query.
        k (int): The number of relevant contexts to retrieve.

    Returns:
        list: A list of dictionaries, one per retrieved context.
    """
    # Embed the query and score it against every document embedding.
    query_embedding = embed_model.encode([query])[0]
    similarities = cosine_similarity([query_embedding], embeddings)[0]

    # Indices of the k highest-scoring documents, best first.
    top_k_indices = np.argsort(similarities)[-k:][::-1]

    contexts = []
    for idx in top_k_indices:
        contexts.append(
            {
                "question": df.iloc[idx]["input"],
                "answer": df.iloc[idx]["output"],
                "similarity": similarities[idx],
            }
        )
    return contexts
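
# For a corpus this size a brute-force cosine scan is workable, but an
# approximate-nearest-neighbour index scales better. A minimal sketch with
# FAISS (assumes the faiss-cpu package is installed; inner product over
# L2-normalized vectors equals cosine similarity):
#
#     import faiss
#     embs = np.ascontiguousarray(embeddings, dtype="float32")
#     faiss.normalize_L2(embs)                  # normalize in place
#     index = faiss.IndexFlatIP(embs.shape[1])  # inner-product index
#     index.add(embs)
#     q = embed_model.encode(["some query"]).astype("float32")
#     faiss.normalize_L2(q)
#     scores, ids = index.search(q, 3)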

def generate_structured_response(query: str, contexts: list) -> str:
    """
    Generate a detailed response grounded in the retrieved contexts.

    Args:
        query (str): The user's medical query.
        contexts (list): The retrieved contexts.

    Returns:
        str: The generated response.
    """
    # Format the retrieved Q&A pairs as numbered references.
    context_prompt = "\n".join(
        [
            f"Reference {i+1}:"
            f"\nQuestion: {ctx['question']}"
            f"\nAnswer: {ctx['answer']}"
            for i, ctx in enumerate(contexts)
        ]
    )

    prompt = f"""Based on the following references and your medical knowledge, provide a detailed response:

References:
{context_prompt}

Question: {query}

In forming your answer, consider:
1. The key medical concepts in the question.
2. How the reference cases relate to this question.
3. What medical principles should be applied.
4. Any potential complications or considerations.

Give the final response:
"""

    # Wrap the prompt in the model's chat template and tokenize it.
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer(
        tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        ),
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        temperature=0.7,
        num_beams=1,
        do_sample=True,
    )

    # The decoded output echoes the prompt, so keep only the text after
    # the prompt's final instruction.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    final_response = response.split("Give the final response:\n")[-1]
    return final_response
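
# Note: splitting on the literal prompt text is fragile (the phrase could
# recur in the model's own output). A more robust alternative is to decode
# only the newly generated tokens; a sketch using the same variables:
#
#     new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
#     final_response = tokenizer.decode(new_tokens, skip_special_tokens=True)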

def process_query(query: str, k: int = 3) -> tuple:
    """
    Process a medical query end-to-end: retrieve contexts, then generate.

    Args:
        query (str): The user's medical query.
        k (int): The number of relevant contexts to retrieve.

    Returns:
        tuple: The generated response and the retrieved contexts.
    """
    contexts = retrieve_relevant_contexts(query, k)
    response = generate_structured_response(query, contexts)
    return response, contexts

# Demo: run a sample query through the pipeline and display the
# retrieved references alongside the generated answer.
query = "I've been experiencing persistent headaches and dizziness for the past week. What could be the cause?"

response, contexts = process_query(query)

st.write("\nQuery:", query)
st.write("\nRelevant Contexts:")
for i, ctx in enumerate(contexts, 1):
    st.write(f"\nReference {i} (Similarity: {ctx['similarity']:.3f}):")
    st.write(f"Q: {ctx['question']}")
    st.write(f"A: {ctx['answer']}")

st.write("\nGenerated Response:")
st.write(response)
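
# A natural extension (hypothetical, not part of the original script) is to
# accept the query interactively instead of hardcoding it, e.g.:
#
#     query = st.text_input("Enter your medical question")
#     if query:
#         response, contexts = process_query(query)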