# streamlit_app.py
import os

import streamlit as st
import numpy as np
import pandas as pd
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss
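
# Runtime dependencies (sketch; assumes a requirements.txt next to this file):
#   streamlit, numpy, pandas, huggingface_hub, sentence-transformers,
#   datasets, faiss-cpu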

st.set_page_config(page_title="🩺 Medical RAG Demo", layout="wide")

# --------------------
# Config
# --------------------
EMB_FILE = "embeddings.npy"
KB_FILE = "kb.csv"
FAISS_INDEX_FILE = "faiss.index"
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
HF_MODEL_ID = "FreedomIntelligence/HuatuoGPT-o1-7B"  # inference model id on the HF Hub

# --------------------
def load_kb_and_embeddings():
    # Download the dataset from the Hugging Face Hub on first run and cache it as CSV.
    if not os.path.exists(KB_FILE):
        ds = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k", split="train")
        df = pd.DataFrame(ds)
        df["combined"] = df["input"].astype(str) + " " + df["output"].astype(str)
        df.to_csv(KB_FILE, index=False)
    else:
        df = pd.read_csv(KB_FILE)

    embed_model = SentenceTransformer(EMBED_MODEL_NAME)

    # Load cached embeddings and FAISS index if both exist; otherwise compute and persist them.
    if os.path.exists(EMB_FILE) and os.path.exists(FAISS_INDEX_FILE):
        embeddings = np.load(EMB_FILE)
        index = faiss.read_index(FAISS_INDEX_FILE)
    else:
        texts = df["combined"].astype(str).tolist()
        embeddings = embed_model.encode(texts, show_progress_bar=True, batch_size=128)
        # FAISS expects float32; L2-normalize so inner-product search equals cosine similarity.
        embeddings = np.asarray(embeddings, dtype="float32")
        faiss.normalize_L2(embeddings)
        np.save(EMB_FILE, embeddings)
        # Build a flat inner-product index over the normalized vectors.
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)
        faiss.write_index(index, FAISS_INDEX_FILE)

    return df, embed_model, embeddings, index
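
# Optional sketch (assumption: Streamlit >= 1.18, which provides st.cache_resource):
# Streamlit reruns this whole script on every widget interaction, so a cached
# wrapper keeps the embedding model and FAISS index in memory across reruns.
#
#   @st.cache_resource
#   def load_resources():
#       return load_kb_and_embeddings()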

def get_hf_client():
    token = st.secrets.get("HF_TOKEN", None)
    if not token:
        st.error("HF_TOKEN not found in secrets. Please add it in your Space settings.")
        return None
    client = InferenceClient(model=HF_MODEL_ID, token=token)
    return client

def retrieve_contexts_faiss(query, embed_model, index, df, k=3):
    # Embed and L2-normalize the query exactly as the corpus was processed.
    q_emb = np.asarray(embed_model.encode([query]), dtype="float32")
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, k)
    results = []
    for score, idx in zip(D[0], I[0]):
        results.append({
            "question": df.iloc[int(idx)]["input"],
            "answer": df.iloc[int(idx)]["output"],
            "score": float(score),
        })
    return results
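
# Example usage (sketch with a made-up query); scores are cosine similarities
# in [-1, 1] because both corpus and query vectors are L2-normalized:
#
#   hits = retrieve_contexts_faiss("chest pain with ST elevation", embed_model, index, df, k=3)
#   print(hits[0]["score"], hits[0]["question"])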

def generate_with_hf(prompt, client, max_new_tokens=512):
    # Call text generation through the Hugging Face InferenceClient.
    res = client.text_generation(prompt, max_new_tokens=max_new_tokens)
    # Depending on the huggingface_hub version, the result may be a string,
    # a dict, or a list of dicts; normalize it to a plain string.
    if isinstance(res, (list, tuple)):
        out = res[0].get("generated_text") if isinstance(res[0], dict) else str(res[0])
    elif isinstance(res, dict):
        out = res.get("generated_text") or res.get("text") or str(res)
    else:
        out = str(res)
    return out
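
# Alternative sketch (assumption: a recent huggingface_hub whose InferenceClient
# exposes chat_completion): chat-tuned checkpoints often behave better through
# the chat endpoint than through raw text_generation.
def generate_with_hf_chat(prompt, client, max_new_tokens=512):
    resp = client.chat_completion(
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_new_tokens,
    )
    return resp.choices[0].message.content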

# --------------------
# UI
# --------------------
st.title("🩺 Medical RAG Demo (HuatuoGPT via HF Inference)")
st.markdown("Enter a clinical question or ECG report. The app retrieves similar cases and asks the model to produce a structured medical answer.")

# Load resources
with st.spinner("Loading knowledge base and embeddings..."):
    df, embed_model, embeddings, faiss_index = load_kb_and_embeddings()
hf_client = get_hf_client()

col1, col2 = st.columns([3, 1])
with col1:
    query = st.text_area("Patient query / ECG text", height=180)
    k = st.number_input("Number of references to retrieve (k)", min_value=1, max_value=10, value=3)
with col2:
    st.markdown("**Quick tips**")
    st.markdown("- Paste an ECG report or a clinical presentation.")
    st.markdown("- Use English for the best model behavior.")
    st.markdown("- HF_TOKEN must be set in the Space secrets.")

if st.button("Run RAG") and query.strip():
    if hf_client is None:
        st.error("Hugging Face client not available (missing HF_TOKEN).")
    else:
        with st.spinner("Retrieving contexts..."):
            contexts = retrieve_contexts_faiss(query, embed_model, faiss_index, df, k=k)

        # Build the prompt (kept in English)
        context_prompt = "\n\n".join(
            [f"Reference {i+1}:\nQ: {c['question']}\nA: {c['answer']}" for i, c in enumerate(contexts)]
        )
        prompt = f"""You are a professional medical assistant specialized in cardiology.
Based on the following reference documents and your medical knowledge, provide a comprehensive response to the patient's case.

References:
{context_prompt}

Patient query / case:
{query}

When formulating your answer, please consider:
1. The key medical findings and symptoms in the patient's case.
2. How the reference cases relate to this patient's situation.
3. Evidence-based medical principles for diagnosis and treatment.
4. Possible complications, warnings, or contraindications.

Your response should include:
- Most likely diagnosis
- Suggested treatment plan
- Recommended medications with dosages if applicable
- Additional advice or warnings

Provide a clear, structured, and professional answer suitable for a healthcare professional.
Give the final response below:
"""

        with st.spinner("Calling HuatuoGPT (HF Inference)..."):
            gen = generate_with_hf(prompt, hf_client, max_new_tokens=512)

        st.subheader("🧠 Generated response")
        st.write(gen)

        st.subheader("🔍 Retrieved references (top-k)")
        for i, c in enumerate(contexts):
            st.markdown(f"**Reference {i+1}** (score {c['score']:.3f})")
            st.write("Q:", c["question"])
            st.write("A:", c["answer"])
            st.markdown("---")