khoaneem committed on
Commit cd876a3 · verified · 1 Parent(s): 3f8f41d

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +154 -33
src/streamlit_app.py CHANGED
@@ -1,40 +1,161 @@
-import altair as alt
+# streamlit_app.py
+import os
+import streamlit as st
 import numpy as np
 import pandas as pd
-import streamlit as st
+from huggingface_hub import InferenceClient
+from sentence_transformers import SentenceTransformer
+from datasets import load_dataset
+import faiss
+from sklearn.preprocessing import normalize
 
-"""
-# Welcome to Streamlit!
+st.set_page_config(page_title="🩺 Medical RAG Demo", layout="wide")
+
+# --------------------
+# Config
+# --------------------
+EMB_FILE = "embeddings.npy"
+KB_FILE = "kb.csv"
+FAISS_INDEX_FILE = "faiss.index"
+EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+HF_MODEL_ID = "FreedomIntelligence/HuatuoGPT-o1-7B"  # inference model id on HF
+# --------------------
+
+@st.cache_resource
+def load_kb_and_embeddings():
+    # Load dataset from huggingface if not present
+    if not os.path.exists(KB_FILE):
+        ds = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k", split="train")
+        df = pd.DataFrame(ds)
+        df["combined"] = df["input"].astype(str) + " " + df["output"].astype(str)
+        df.to_csv(KB_FILE, index=False)
+    else:
+        df = pd.read_csv(KB_FILE)
+
+    embed_model = SentenceTransformer(EMBED_MODEL_NAME)
+
+    # compute or load embeddings
+    if os.path.exists(EMB_FILE) and os.path.exists(FAISS_INDEX_FILE):
+        embeddings = np.load(EMB_FILE)
+        # load faiss index
+        index = faiss.read_index(FAISS_INDEX_FILE)
+    else:
+        texts = df["combined"].astype(str).tolist()
+        embeddings = embed_model.encode(texts, show_progress_bar=True, batch_size=128)
+        # normalize for inner product (cosine)
+        faiss.normalize_L2(embeddings)
+        # save embeddings
+        np.save(EMB_FILE, embeddings)
+        # build faiss index
+        dim = embeddings.shape[1]
+        index = faiss.IndexFlatIP(dim)
+        index.add(embeddings)
+        faiss.write_index(index, FAISS_INDEX_FILE)
+
+    return df, embed_model, embeddings, index
+
+@st.cache_resource
+def get_hf_client():
+    token = st.secrets.get("HF_TOKEN", None)
+    if not token:
+        st.error("HF_TOKEN not found in secrets. Please add it in your Space settings.")
+        return None
+    client = InferenceClient(model=HF_MODEL_ID, token=token)
+    return client
+
+def retrieve_contexts_faiss(query, embed_model, index, df, k=3):
+    q_emb = embed_model.encode([query])
+    faiss.normalize_L2(q_emb)
+    D, I = index.search(q_emb, k)
+    results = []
+    for score, idx in zip(D[0], I[0]):
+        results.append({
+            "question": df.iloc[int(idx)]["input"],
+            "answer": df.iloc[int(idx)]["output"],
+            "score": float(score)
+        })
+    return results
+
+def generate_with_hf(prompt, client, max_new_tokens=512):
+    # Uses Hugging Face InferenceClient text_generation
+    # Returns string response
+    res = client.text_generation(prompt, max_new_tokens=max_new_tokens)
+    # res could be a list/dict depending on client version
+    # Normalize to string:
+    if isinstance(res, (list, tuple)):
+        out = res[0].get("generated_text") if isinstance(res[0], dict) else str(res[0])
+    elif isinstance(res, dict):
+        out = res.get("generated_text") or res.get("text") or str(res)
+    else:
+        out = str(res)
+    return out
+
+# --------------------
+# UI
+# --------------------
+st.title("🩺 Medical RAG Demo (HuatuoGPT via HF Inference)")
+st.markdown("Enter a clinical question or ECG report. The app retrieves similar cases and asks the model to produce a structured medical answer.")
 
-Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
+# load resources
+with st.spinner("Loading knowledge base and embeddings..."):
+    df, embed_model, embeddings, faiss_index = load_kb_and_embeddings()
 
-In the meantime, below is an example of what you can do with just a few lines of code:
+hf_client = get_hf_client()
+
+col1, col2 = st.columns([3,1])
+with col1:
+    query = st.text_area("Patient query / ECG text", height=180)
+    k = st.number_input("Number of references to retrieve (k)", min_value=1, max_value=10, value=3)
+
+with col2:
+    st.markdown("**Quick tips**")
+    st.markdown("- Paste ECG report or clinical presentation.")
+    st.markdown("- Use English for best model behavior.")
+    st.markdown("- HF_TOKEN must be set in Space Secrets.")
+
+if st.button("Run RAG") and query.strip():
+    if hf_client is None:
+        st.error("Hugging Face client not available (missing HF_TOKEN).")
+    else:
+        with st.spinner("Retrieving contexts..."):
+            contexts = retrieve_contexts_faiss(query, embed_model, faiss_index, df, k=k)
+
+        # Build prompt (keep in English)
+        context_prompt = "\n\n".join([f"Reference {i+1}:\nQ: {c['question']}\nA: {c['answer']}" for i,c in enumerate(contexts)])
+        prompt = f"""You are a professional medical assistant specialized in cardiology.
+Based on the following reference documents and your medical knowledge, provide a comprehensive response to the patient's case.
+
+References:
+{context_prompt}
+
+Patient query / case:
+{query}
+
+When formulating your answer, please consider:
+1. The key medical findings and symptoms in the patient's case.
+2. How the reference cases relate to this patient's situation.
+3. Evidence-based medical principles for diagnosis and treatment.
+4. Possible complications, warnings, or contraindications.
+
+Your response should include:
+- Most likely diagnosis
+- Suggested treatment plan
+- Recommended medications with dosages if applicable
+- Additional advice or warnings
+
+Provide a clear, structured, and professional answer suitable for a healthcare professional.
+Give the final response below:
 """
 
-num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
-num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
-indices = np.linspace(0, 1, num_points)
-theta = 2 * np.pi * num_turns * indices
-radius = indices
-
-x = radius * np.cos(theta)
-y = radius * np.sin(theta)
-
-df = pd.DataFrame({
-    "x": x,
-    "y": y,
-    "idx": indices,
-    "rand": np.random.randn(num_points),
-})
-
-st.altair_chart(alt.Chart(df, height=700, width=700)
-    .mark_point(filled=True)
-    .encode(
-        x=alt.X("x", axis=None),
-        y=alt.Y("y", axis=None),
-        color=alt.Color("idx", legend=None, scale=alt.Scale()),
-        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-    ))
+        with st.spinner("Calling HuatuoGPT (HF Inference)..."):
+            gen = generate_with_hf(prompt, hf_client, max_new_tokens=512)
+
+        st.subheader("🧠 Generated response")
+        st.write(gen)
+
+        st.subheader("🔍 Retrieved references (top-k)")
+        for i,c in enumerate(contexts):
+            st.markdown(f"**Reference {i+1}** (score {c['score']:.3f})")
+            st.write("Q:", c["question"])
+            st.write("A:", c["answer"])
+            st.markdown("---")