# streamlit_app.py
import os
import streamlit as st
import numpy as np
import pandas as pd
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss
st.set_page_config(page_title="🩺 Medical RAG Demo", layout="wide")
# --------------------
# Config
# --------------------
EMB_FILE = "embeddings.npy"
KB_FILE = "kb.csv"
FAISS_INDEX_FILE = "faiss.index"
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
HF_MODEL_ID = "FreedomIntelligence/HuatuoGPT-o1-7B" # inference model id on HF
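# Note: KB_FILE, EMB_FILE, and FAISS_INDEX_FILE are written to the working
# directory on first run. On a Hugging Face Space this storage is ephemeral
# (unless persistent storage is enabled), so they get rebuilt after a restart.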
# --------------------
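# st.cache_resource keeps heavy objects (the dataframe, the embedding model,
# and the FAISS index) alive for the lifetime of the process, so the slow
# embedding step below runs only once per deployment.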
@st.cache_resource
def load_kb_and_embeddings():
    # Download the dataset from the Hugging Face Hub if no local copy exists
    if not os.path.exists(KB_FILE):
        ds = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k", split="train")
        df = pd.DataFrame(ds)
        df["combined"] = df["input"].astype(str) + " " + df["output"].astype(str)
        df.to_csv(KB_FILE, index=False)
    else:
        df = pd.read_csv(KB_FILE)

    embed_model = SentenceTransformer(EMBED_MODEL_NAME)

    # Compute or load embeddings and the FAISS index
    if os.path.exists(EMB_FILE) and os.path.exists(FAISS_INDEX_FILE):
        embeddings = np.load(EMB_FILE)
        index = faiss.read_index(FAISS_INDEX_FILE)
    else:
        texts = df["combined"].astype(str).tolist()
        embeddings = embed_model.encode(texts, show_progress_bar=True, batch_size=128)
        embeddings = np.asarray(embeddings, dtype="float32")  # FAISS requires float32
        # L2-normalize so inner product equals cosine similarity
        faiss.normalize_L2(embeddings)
        np.save(EMB_FILE, embeddings)
        # Build a flat (exact) inner-product index
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)
        faiss.write_index(index, FAISS_INDEX_FILE)
    return df, embed_model, embeddings, index
@st.cache_resource
def get_hf_client():
    token = st.secrets.get("HF_TOKEN", None)
    if not token:
        st.error("HF_TOKEN not found in secrets. Please add it in your Space settings.")
        return None
    return InferenceClient(model=HF_MODEL_ID, token=token)
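# HF_TOKEN is read via st.secrets: on a Space, add it under Settings ->
# "Variables and secrets"; for local runs it can live in .streamlit/secrets.toml.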
def retrieve_contexts_faiss(query, embed_model, index, df, k=3):
    # Embed the query and normalize it the same way as the corpus vectors
    q_emb = np.asarray(embed_model.encode([query]), dtype="float32")
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, k)
    results = []
    for score, idx in zip(D[0], I[0]):
        results.append({
            "question": df.iloc[int(idx)]["input"],
            "answer": df.iloc[int(idx)]["output"],
            "score": float(score),
        })
    return results
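# Illustrative call (the query string is made up, not taken from the dataset):
#   retrieve_contexts_faiss("ECG shows ST elevation in II, III, aVF",
#                           embed_model, index, df, k=3)
# returns a list of {"question": ..., "answer": ..., "score": ...} dicts,
# ordered by descending cosine similarity.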
def generate_with_hf(prompt, client, max_new_tokens=512):
    # Call the Hugging Face InferenceClient text-generation endpoint and
    # normalize the response to a plain string.
    res = client.text_generation(prompt, max_new_tokens=max_new_tokens)
    # res can be a str, list, or dict depending on the client version
    if isinstance(res, (list, tuple)):
        out = res[0].get("generated_text") if isinstance(res[0], dict) else str(res[0])
    elif isinstance(res, dict):
        out = res.get("generated_text") or res.get("text") or str(res)
    else:
        out = str(res)
    return out
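# With recent huggingface_hub releases, text_generation returns a plain str by
# default (and a structured response object when details=True), so the
# branches above mainly keep compatibility with older client versions.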
# --------------------
# UI
# --------------------
st.title("🩺 Medical RAG Demo (HuatuoGPT via HF Inference)")
st.markdown("Enter a clinical question or ECG report. The app retrieves similar cases and asks the model to produce a structured medical answer.")
# load resources
with st.spinner("Loading knowledge base and embeddings..."):
    df, embed_model, embeddings, faiss_index = load_kb_and_embeddings()
hf_client = get_hf_client()
col1, col2 = st.columns([3, 1])
with col1:
    query = st.text_area("Patient query / ECG text", height=180)
    k = st.number_input("Number of references to retrieve (k)", min_value=1, max_value=10, value=3)
with col2:
    st.markdown("**Quick tips**")
    st.markdown("- Paste an ECG report or clinical presentation.")
    st.markdown("- Use English for best model behavior.")
    st.markdown("- HF_TOKEN must be set in the Space secrets.")
if st.button("Run RAG") and query.strip():
    if hf_client is None:
        st.error("Hugging Face client not available (missing HF_TOKEN).")
    else:
        with st.spinner("Retrieving contexts..."):
            contexts = retrieve_contexts_faiss(query, embed_model, faiss_index, df, k=k)

        # Build the prompt (kept in English)
        context_prompt = "\n\n".join(
            f"Reference {i+1}:\nQ: {c['question']}\nA: {c['answer']}" for i, c in enumerate(contexts)
        )
        prompt = f"""You are a professional medical assistant specialized in cardiology.
Based on the following reference documents and your medical knowledge, provide a comprehensive response to the patient's case.

References:
{context_prompt}

Patient query / case:
{query}

When formulating your answer, please consider:
1. The key medical findings and symptoms in the patient's case.
2. How the reference cases relate to this patient's situation.
3. Evidence-based medical principles for diagnosis and treatment.
4. Possible complications, warnings, or contraindications.

Your response should include:
- Most likely diagnosis
- Suggested treatment plan
- Recommended medications with dosages if applicable
- Additional advice or warnings

Provide a clear, structured, and professional answer suitable for a healthcare professional.
Give the final response below:
"""

        with st.spinner("Calling HuatuoGPT (HF Inference)..."):
            gen = generate_with_hf(prompt, hf_client, max_new_tokens=512)

        st.subheader("🧠 Generated response")
        st.write(gen)

        st.subheader("🔍 Retrieved references (top-k)")
        for i, c in enumerate(contexts):
            st.markdown(f"**Reference {i+1}** (score {c['score']:.3f})")
            st.write("Q:", c["question"])
            st.write("A:", c["answer"])
            st.markdown("---")