hassanmzia committed on
Commit 2d51bc9 · verified · 1 Parent(s): 65efd89

Upload 4 files

Files changed (5)
  1. .gitattributes +1 -0
  2. app.py +1710 -0
  3. female.jpg +3 -0
  4. requirements.txt +5 -0
  5. x-ray-chest.png +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ female.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,1710 @@
1
+
2
+
3
+
4
+ # General
5
+ import os
6
+ import kagglehub
7
+ import pandas as pd
8
+ import json
9
+ from typing import Literal
10
+ from datasets import load_dataset
11
+ import random
12
+
13
+ # Markdown
14
+ from IPython.display import Markdown, display, Image
15
+
16
+ # Image
17
+ from PIL import Image  # note: this shadows the IPython.display.Image imported above; it is re-imported later where needed
18
+
19
+ # langchain for llms
20
+ from langchain_groq import ChatGroq
21
+
22
+ # Langchain
23
+ from langchain.prompts import PromptTemplate, ChatPromptTemplate
24
+ from langchain.output_parsers import StructuredOutputParser, ResponseSchema
25
+ from langchain_core.output_parsers import JsonOutputParser
26
+ from langchain_core.messages import HumanMessage
27
+ from langgraph.checkpoint.memory import MemorySaver
28
+ from langgraph.graph import END, START, StateGraph, MessagesState
29
+ from langgraph.prebuilt import ToolNode
30
+ from langchain_core.tools import tool
31
+
32
+ # Hugging Face
33
+ from transformers import AutoModelForImageClassification, AutoProcessor
34
+
35
+ from langchain_huggingface import HuggingFaceEmbeddings
36
+
37
+
38
+ # Extra libraries
39
+ from pydantic import BaseModel, Field, model_validator
40
+
41
+ # Advanced RAG
42
+ from langchain_core.documents import Document
43
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
44
+ from langchain.embeddings import HuggingFaceEmbeddings
45
+ from langchain_community.vectorstores import Chroma
46
+ from langchain.retrievers.multi_query import MultiQueryRetriever
47
+ from langchain_core.runnables import RunnablePassthrough
48
+ from langchain_core.output_parsers import StrOutputParser
49
+
50
+
51
+ # ## APIs
52
+
53
+
54
+
55
+ os.environ["SERPER_API_KEY"] = os.getenv("SERPER_API_KEY")
56
+ os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
57
+ os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
58
+
59
+ GROQ_API_KEY = os.environ["GROQ_API_KEY"]
60
+ HF_TOKEN = os.environ["HF_TOKEN"]
61
+
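One caveat about the lines above: `os.environ[key] = os.getenv(key)` raises a TypeError when a variable is unset, because environment values must be strings. A minimal defensive sketch (the loop and error message are illustrative, not part of the committed app):

for key in ("SERPER_API_KEY", "GROQ_API_KEY", "HF_TOKEN"):
    value = os.getenv(key)
    if value is None:
        # fail early with a readable message instead of a TypeError
        raise RuntimeError(f"Environment variable {key} is not set")
    os.environ[key] = value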
62
+
63
+ # ## Setup LLM (Llama 3.3 via Groq)
64
+
65
+ # Note: Model 3.2 70b is not available on Groq anymore.
66
+ # We will be using 3.3 from now on.
67
+
68
+
69
+
70
+ os.environ["GROQ_API_KEY"] = GROQ_API_KEY
71
+
72
+ #model_3_2 = 'llama-3.2-11b-text-preview' => this model has been removed from the Groq platform
73
+ model_3_2_small = 'llama-3.1-8b-instant' # smaller 8-billion-parameter model if you need speed
74
+ model_3_3 = 'llama-3.3-70b-versatile' # large, versatile model with 70 billion parameters
75
+
76
+ llm = ChatGroq(
77
+ model=model_3_3,
78
+ temperature=0,
79
+ max_tokens=None,
80
+ timeout=None,
81
+ max_retries=2,
82
+ # groq_api_key=os.getenv("GROQ_API_KEY")
83
+ # other params...
84
+ )
85
+
86
+ # A test message
87
+
88
+ response = llm.invoke("Hi, please generate 10 unique Dutch names for both male and female.")
89
+ response
90
+
91
+
92
+
93
+ display(Markdown(response.content))
94
+
95
+
96
+ # # First Agent: Chatbot Agent
97
+
98
+
99
+
100
+
101
+ from typing import Annotated
102
+ from typing_extensions import TypedDict
103
+ from langgraph.graph import StateGraph, START, END
104
+ from langgraph.graph.message import add_messages
105
+
106
+
107
+ class ChatState(TypedDict):
108
+ # Messages have the type "list". The `add_messages` function
109
+ # in the annotation defines how this state key should be updated
110
+ # (in this case, it appends messages to the list, rather than overwriting them)
111
+ messages: Annotated[list, add_messages]
112
+
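To make the reducer's behavior concrete, here is a minimal sketch (assuming langgraph's public add_messages helper, which the code above already imports) showing that updates are appended rather than overwritten:

from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph.message import add_messages

existing = [HumanMessage(content="hi")]
update = [AIMessage(content="hello!")]
merged = add_messages(existing, update)
print(len(merged))  # 2 -- the new message is appended to the history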
113
+
114
+ chat_graph = StateGraph(ChatState)
115
+
116
+ def chatbot_agent(state: ChatState):
117
+ return {"messages": [llm.invoke(state["messages"])]}
118
+
119
+ # The first argument is the unique node name
120
+ # The second argument is the function or object that will be called whenever
121
+ # the node is used.
122
+ chat_graph.add_node("chatbot_agent", chatbot_agent)
123
+ chat_graph.add_edge(START, "chatbot_agent")
124
+ chat_graph.add_edge("chatbot_agent", END)
125
+
126
+ # Finally, we'll want to be able to run our graph. To do so, call "compile()"
127
+ # Compiling turns the graph definition into a runnable AI agent.
128
+ graph_app = chat_graph.compile()
129
+
130
+ # Persistent state to maintain conversation history
131
+ persistent_state = {"messages": []} # Start with an empty message list
132
+
133
+
134
+
135
+
136
+ from IPython.display import Image, display
137
+ display(Image(graph_app.get_graph(xray=True).draw_mermaid_png()))
138
+
139
+
140
+
141
+
142
+ from typing import Annotated
143
+ from typing_extensions import TypedDict
144
+ from langgraph.graph import StateGraph, START, END
145
+ from langgraph.graph.message import add_messages
146
+ from IPython.display import display, Markdown
147
+
148
+ class ChatState(TypedDict):
149
+ messages: Annotated[list, add_messages]
150
+
151
+ chat_graph = StateGraph(ChatState)
152
+
153
+ def chatbot_agent(state: ChatState):
154
+ # Assuming `llm` is your language model that can handle the conversation history
155
+ return {"messages": [llm.invoke(state["messages"])]}
156
+
157
+ chat_graph.add_node("chatbot_agent", chatbot_agent)
158
+ chat_graph.add_edge(START, "chatbot_agent")
159
+ chat_graph.add_edge("chatbot_agent", END)
160
+
161
+ graph_app = chat_graph.compile()
162
+
163
+ # Persistent state to maintain conversation history
164
+ persistent_state = {"messages": []} # Start with an empty message list
165
+
166
+ def stream_graph_updates(user_input: str):
167
+ global persistent_state
168
+ # Append the user's message to the persistent state
169
+ persistent_state["messages"].append(("user", user_input))
170
+
171
+ is_finished = False
172
+ for event in graph_app.stream(persistent_state):
173
+ for value in event.values():
174
+ last_msg = value["messages"][-1]
175
+ display(Markdown("Assistant: " + last_msg.content))
176
+
177
+ # Append the assistant's response to the persistent state
178
+ persistent_state["messages"].append(("assistant", last_msg.content))
179
+
180
+ finish_reason = last_msg.response_metadata.get("finish_reason")
181
+ if finish_reason == "stop":
182
+ is_finished = True
183
+ break
184
+ if is_finished:
185
+ break
186
+
187
+ while True:
188
+ try:
189
+ user_input = input('User:')
190
+ if user_input.lower() in ["quit", "exit", "q"]:
191
+ print("Thank you and Goodbye!")
192
+ break
193
+
194
+ stream_graph_updates(user_input)
195
+ except Exception as e:
196
+ print(f"An error occurred: {e}")
197
+ break
198
+
199
+
200
+
201
+
202
+
203
+
204
+ # # Second Agent: Add Search to the Chatbot to Make It Stronger
205
+
206
+
207
+ from langchain_community.tools import GoogleSerperResults
208
+ from typing import List, Annotated
209
+ from langchain_core.messages import BaseMessage
210
+ from langgraph.prebuilt import ToolNode, create_react_agent
211
+ import operator
212
+ import functools
213
+
214
+ class ChatState(TypedDict):
215
+ # Messages have the type "list". The `add_messages` function
216
+ # in the annotation defines how this state key should be updated
217
+ # (in this case, it appends messages to the list, rather than overwriting them)
218
+ messages: Annotated[list, add_messages]
219
+
220
+ def agent_node(state, agent, name):
221
+ result = agent.invoke(state)
222
+ return {
223
+ "messages": [HumanMessage(content=result["messages"][-1].content, name=name)]
224
+ }
225
+
226
+
227
+ class SearchState(TypedDict):
228
+ # A message is added after each team member finishes
229
+ messages: Annotated[List[BaseMessage], operator.add]
230
+
231
+ # Search Tool
232
+
233
+ serper_tool = GoogleSerperResults(
234
+ num_results=5,
235
+ # how many Google results to return
236
+ )
237
+
238
+ search_agent = create_react_agent(llm, tools=[serper_tool])
239
+ search_node = functools.partial(agent_node,
240
+ agent=search_agent,
241
+ name="search_agent")
242
+
243
+
244
+ # The first argument is the unique node name
245
+ # The second argument is the function or object that will be called whenever
246
+ # the node is used.
247
+ search_graph = StateGraph(SearchState)
248
+ search_graph.add_node("search_agent", search_node)
249
+ search_graph.add_edge(START, "search_agent")
250
+ search_graph.add_edge("search_agent", END)
251
+
252
+ # Finally, we'll want to be able to run our graph. To do so, call "compile()"
253
+ # Compiling turns the graph definition into a runnable AI agent.
254
+ search_app = search_graph.compile()
255
+
256
+
257
+
258
+
259
+
260
+
261
+ from IPython.display import Image, display
262
+ display(Image(search_app.get_graph(xray=True).draw_mermaid_png()))
263
+
264
+
265
+
266
+
267
+ from langchain_community.tools import GoogleSerperResults
268
+ from typing import List, Annotated
269
+ from langchain_core.messages import BaseMessage, HumanMessage
270
+ from langgraph.prebuilt import ToolNode, create_react_agent
271
+ from langgraph.graph import StateGraph, START, END
272
+ from langgraph.graph.message import add_messages
273
+ from IPython.display import display, Markdown
274
+ import operator
275
+ import functools
276
+
277
+ class ChatState(TypedDict):
278
+ messages: Annotated[List[BaseMessage], operator.add]
279
+
280
+ def agent_node(state, agent, name):
281
+ result = agent.invoke(state)
282
+ return {
283
+ "messages": [HumanMessage(content=result["messages"][-1].content, name=name)]
284
+ }
285
+
286
+ class SearchState(TypedDict):
287
+ messages: Annotated[List[BaseMessage], operator.add]
288
+
289
+ # Search Tool
290
+ serper_tool = GoogleSerperResults(num_results=5) # how many Google results to return
291
+
292
+ search_agent = create_react_agent(llm, tools=[serper_tool])
293
+ search_node = functools.partial(agent_node, agent=search_agent, name="search_agent")
294
+
295
+ # Create the search graph
296
+ search_graph = StateGraph(SearchState)
297
+ search_graph.add_node("search_agent", search_node)
298
+ search_graph.add_edge(START, "search_agent")
299
+ search_graph.add_edge("search_agent", END)
300
+
301
+ # Compile the search graph
302
+ search_app = search_graph.compile()
303
+
304
+ # Persistent state to maintain conversation history
305
+ persistent_state = {"messages": []} # Start with an empty message list
306
+
307
+ def stream_graph_updates(user_input: str):
308
+ global persistent_state
309
+ # Append the user's message to the persistent state
310
+ persistent_state["messages"].append(HumanMessage(content=user_input))
311
+
312
+ # Display "Searching the Web Now..." message
313
+ display(Markdown("**Assistant:** Searching the Web Now..."))
314
+
315
+ is_finished = False
316
+ for event in search_app.stream(persistent_state):
317
+ for value in event.values():
318
+ last_msg = value["messages"][-1]
319
+ display(Markdown("**Assistant:** " + last_msg.content))
320
+
321
+ # Append the assistant's response to the persistent state
322
+ persistent_state["messages"].append(last_msg)
323
+
324
+ finish_reason = last_msg.response_metadata.get("finish_reason")
325
+ if finish_reason == "stop":
326
+ is_finished = True
327
+ break
328
+ if is_finished:
329
+ break
330
+
331
+ while True:
332
+ try:
333
+ user_input = input('User:')
334
+ if user_input.lower() in ["quit", "exit", "q"]:
335
+ print("Thank you and Goodbye!")
336
+ break
337
+
338
+ stream_graph_updates(user_input)
339
+ except Exception as e:
340
+ print(f"An error occurred: {e}")
341
+ break
342
+
343
+
344
+ # # Step 1: Medical Database Preparation
345
+ # This step involves preparing and enhancing patient data to be used throughout the simulation.
346
+
347
+ # ## 1.1 Load Dataset
348
+
349
+ # ### 1.1.1 Disease Symptoms and Patient Profile Dataset
350
+ # Ensure you have downloaded it and placed it in your project directory.
351
+ # - https://www.kaggle.com/datasets/uom190346a/disease-symptoms-and-patient-profile-dataset
352
+
353
+
354
+
355
+
356
+ # Download latest version
357
+ path = kagglehub.dataset_download("uom190346a/disease-symptoms-and-patient-profile-dataset")
358
+ print("Path to dataset files:", path)
359
+
360
+
361
+
362
+
363
+ patient_df = pd.read_csv(path+'/Disease_symptom_and_patient_profile_dataset.csv')
364
+ patient_df.shape
365
+
366
+
367
+
368
+
369
+ patient_df.head()
370
+
371
+
372
+
373
+
374
+ # Calculate the counts of each gender
375
+ female_count = patient_df[patient_df['Gender'] == 'Female'].shape[0]
376
+ male_count = patient_df[patient_df['Gender'] == 'Male'].shape[0]
377
+
378
+ # Calculate the ratio
379
+ ratio = female_count / male_count
380
+ print(f"The ratio of Female to Male is {ratio}:1")
381
+
382
+
383
+
384
+
385
+
386
+ patient_df['Disease'].value_counts().head(20)
387
+
388
+
389
+ # **prepare_medical_dataset Code in One Place**
390
+
391
+
392
+
393
+ def prepare_medical_dataset(path, file_name):
394
+ patient_df = pd.read_csv(path+file_name)
395
+ return patient_df
396
+
397
+ path = kagglehub.dataset_download("uom190346a/disease-symptoms-and-patient-profile-dataset")
398
+ file_name = '/Disease_symptom_and_patient_profile_dataset.csv'
399
+ patient_df = prepare_medical_dataset(path, file_name)
400
+
401
+
402
+ # ### 1.1.2 Chest X-Ray Images (Pneumonia)
403
+ #
404
+ # - https://huggingface.co/lxyuan/vit-xray-pneumonia-classification
405
+ # - https://huggingface.co/datasets/keremberke/chest-xray-classification
406
+ #
407
+ #
408
+
409
+
410
+
411
+ #from datasets import load_dataset
412
+ #patient_x_ray_path = "keremberke/chest-xray-classification"
413
+ #x_ray_ds = load_dataset(patient_x_ray_path, name="full")
414
+ from datasets import load_dataset
415
+ x_ray_ds = load_dataset("keremberke/chest-xray-classification", name="full")
416
+
417
+
418
+
419
+ random_index = random.randint(0, x_ray_ds['train'].shape[0] - 1)
420
+ patient_x_ray = x_ray_ds['train'][random_index]['image']
421
+
422
+
423
+
424
+
425
+
426
+
427
+
428
+ x_ray_ds['train'].shape[0]
429
+
430
+
431
+
432
+
433
+
434
+ # Assuming x_ray_ds['train'] is a dataset where we want to pick a random row
435
+ import random
436
+ random_index = random.randint(0, x_ray_ds['train'].shape[0] - 1)
437
+
438
+
439
+
440
+
441
+ patient_x_ray = x_ray_ds['train'][random_index]['image']
442
+ patient_x_ray
443
+
444
+
445
+
446
+
447
+ type(patient_x_ray)
448
+
449
+
450
+
451
+
452
+ #!pip install --upgrade accelerate==0.31.0
453
+ #!pip install --upgrade huggingface-hub>=0.23.0
454
+
455
+
456
+
457
+
458
+
459
+ from transformers import pipeline
460
+
461
+ # Model in Hugging Face: https://huggingface.co/lxyuan/vit-xray-pneumonia-classification
462
+ # vit-xray-pneumonia-classification
463
+ classifier = pipeline(model="lxyuan/vit-xray-pneumonia-classification")
464
+ patient_x_ray_results = classifier(patient_x_ray)
465
+ patient_x_ray_results
466
+
467
+
468
+
469
+
470
+ # Find the label with the highest score
471
+ patient_x_ray_label = max(patient_x_ray_results, key=lambda x: x['score'])['label']
472
+ print(patient_x_ray_label)
473
+
474
+
475
+
476
+ # Model in Hugging Face: https://huggingface.co/lxyuan/vit-xray-pneumonia-classification
477
+ # vit-xray-pneumonia-classification
478
+ classifier = pipeline(model="lxyuan/vit-xray-pneumonia-classification")
479
+ patient_x_ray_results = classifier(patient_x_ray)
480
+
481
+ # Find the label with the highest score and its score
482
+ highest = max(patient_x_ray_results, key=lambda x: x['score'])
483
+ highest_score_label = highest['label']
484
+ highest_score = highest['score'] * 100 # Convert to percentage
485
+
486
+ # Choose the correct verb based on the label
487
+ verb = "is" if highest_score_label == "NORMAL" else "has"
488
+
489
+ # Print the result dynamically
490
+ print(f"Patient {verb} {highest_score_label} with Probability of ca. {highest_score:.0f}%")
491
+
492
+
493
+ # ## 1.2 Generate Synthetic Data with LLMs
494
+ # Generate culturally appropriate Dutch names and unique alphanumeric IDs for each patient.
495
+
496
+ # ### 1.2.1 Generate Random Names and IDs for Patients
497
+
498
+ # This code runs slower because Llama 3.3 70B is a much larger and slower LLM
499
+ # compared to Llama 3.2 11B.
500
+ # Switch to model_3_2_small when running this code.
501
+
502
+
503
+
504
+ # === Step 1: Define Response Schemas ===
505
+ # Define the structure of the expected JSON output.
506
+
507
+ # ResponseSchema for First_Name
508
+ first_name_schema = ResponseSchema(
509
+ name="First_Name",
510
+ description="The first name of the patient."
511
+ )
512
+
513
+ # ResponseSchema for Last_Name
514
+ last_name_schema = ResponseSchema(
515
+ name="Last_Name",
516
+ description="The last name of the patient."
517
+ )
518
+
519
+ # ResponseSchema for Patient_ID
520
+ patient_id_schema = ResponseSchema(
521
+ name="Patient_ID",
522
+ description="A unique 13-character alphanumeric patient identifier."
523
+ )
524
+
525
+ # ResponseSchema for Gender
526
+ gender_schema = ResponseSchema(
527
+ name="G_Gender",
528
+ description="Indicates which gender the generated first name belongs to: Male or Female"
529
+ )
530
+
531
+ # Aggregate all response schemas
532
+ response_schemas = [
533
+ first_name_schema,
534
+ last_name_schema,
535
+ patient_id_schema,
536
+ gender_schema
537
+ ]
538
+
539
+ # === Step 2: Set Up the Output Parser ===
540
+ # Initialize the StructuredOutputParser with the defined response schemas.
541
+
542
+ output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
543
+
544
+ # Get the format instructions to include in the prompt
545
+ format_instructions = output_parser.get_format_instructions()
546
+
547
+ # === Step 3: Craft the Prompt ===
548
+ # Create a prompt that instructs the LLM to generate only the structured JSON data.
549
+
550
+ # Define the prompt template using ChatPromptTemplate
551
+ prompt_template = ChatPromptTemplate.from_template("""
552
+ You MUST generate a list of {n} Dutch names along with a unique 13-character alphanumeric Patient_ID for each gender provided.
553
+ Always use {genders} to generate a First_Name that belongs to the correct Gender; two categories are possible: 'Male' or 'Female'.
554
+ Ensure the names are culturally appropriate for the Netherlands.
555
+ Generate unique names, no repetitions, and ensure diversity.
556
+ The ratio of Female to Male is {ratio}:1
557
+
558
+ {format_instructions}
559
+
560
+ Genders:
561
+ {genders}
562
+
563
+ **IMPORTANT:** Do not include any explanations, code, or additional text.
564
+ You MUST ALWAYS generate Dutch names and Patient_IDs according to {format_instructions}
565
+ and NEVER return empty values.
566
+ YOU MUST provide only the JSON array as specified.
567
+ The JSON array should have exactly {n} rows, each with the four specified fields.
568
+ """)
569
+
570
+ # Determine the number of patients
571
+ n_patients = len(patient_df)
572
+ #n_patients = 120
573
+ # Calculate the counts of each gender
574
+ female_count = patient_df[patient_df['Gender'] == 'Female'].shape[0]
575
+ male_count = patient_df[patient_df['Gender'] == 'Male'].shape[0]
576
+
577
+ # Calculate the ratio
578
+ ratio = female_count / male_count
579
+
580
+ # Prepare the list of genders
581
+ genders = patient_df['Gender'].tolist()
582
+
583
+ # === Step 6: Generate the Prompt ===
584
+ # Format the prompt with the number of patients and their genders.
585
+
586
+ formatted_prompt = prompt_template.format(
587
+ n=n_patients,
588
+ ratio = ratio,
589
+ genders=', '.join(genders),
590
+ format_instructions=format_instructions
591
+ )
592
+
593
+ # Invoke the model with a smaller Llama model for speed
594
+ model_3_2_small = 'llama-3.1-8b-instant' # if you need speed
595
+
596
+ llm = ChatGroq(
597
+ model= model_3_2_small, #
598
+ temperature=0,
599
+ max_tokens=None,
600
+ timeout=None,
601
+ max_retries=2
602
+ )
603
+
604
+ output = llm.invoke(formatted_prompt, timeout=1000)
605
+
606
+
607
+
608
+
609
+ display(Markdown(output.content))
610
+
611
+
612
+
613
+
614
+ output_parser = JsonOutputParser()
615
+ json_output = output_parser.invoke(output)
616
+ json_output
617
+
618
+
619
+
620
+
621
+
622
+ all_patients = []
623
+ generated_patients = pd.DataFrame(json_output)
624
+ generated_patients.head(5)
625
+
626
+
627
+
628
+
629
+
630
+ generated_patients.shape
631
+
632
+
633
+
634
+
635
+ # Adjusted LLM parameters (if supported)
636
+ llm.temperature = 0.9 # Increases randomness
637
+
638
+ all_patients_name_id = pd.DataFrame()
639
+ output_parser = JsonOutputParser()
640
+
641
+ while all_patients_name_id.shape[0] < n_patients:
642
+ output = llm.invoke(formatted_prompt)
643
+ json_output = output_parser.invoke(output)
644
+ generated_patients = pd.DataFrame(json_output)
645
+ all_patients_name_id = pd.concat([generated_patients, all_patients_name_id], axis = 0)
646
+ print(f"len all_patients_name_id: {len(all_patients_name_id)}")
647
+ all_patients_name_id = all_patients_name_id.drop_duplicates()
648
+ print(f"len all_patients_name_id after droping duplicates: {len(all_patients_name_id)}")
649
+
650
+
651
+
652
+
653
+
654
+ all_patients_name_id.rename(columns={"G_Gender": "Gender"}, inplace=True)
655
+ all_patients_name_id.head(10)
656
+
657
+
658
+
659
+
660
+
661
+ gender_counts = patient_df['Gender'].value_counts()
662
+ gender_counts
663
+
664
+
665
+
666
+
667
+ all_patients_name_id['Gender'].value_counts()
668
+
669
+
670
+
671
+
672
+
673
+ # Step 1: Count the number of males and females in patient_df
674
+ gender_counts = patient_df['Gender'].value_counts()
675
+
676
+ # Step 2: Select the required number of unique males and females from all_patients_name_id
677
+ unique_males = all_patients_name_id[all_patients_name_id['Gender'] == 'Male'].drop_duplicates().head(gender_counts['Male'])
678
+ unique_females = all_patients_name_id[all_patients_name_id['Gender'] == 'Female'].drop_duplicates().head(gender_counts['Female'])
679
+
680
+
681
+ patient_male = patient_df[patient_df['Gender'] == 'Male'].reset_index(drop=True)
682
+ patient_female = patient_df[patient_df['Gender'] == 'Female'].reset_index(drop=True)
683
+
684
+
685
+ updated_male_patients = pd.concat([patient_male.reset_index(drop=True),
686
+ unique_males[0:patient_male.shape[0]].reset_index(drop=True)],
687
+ axis = 1)
688
+
689
+ updated_female_patients = pd.concat([patient_female.reset_index(drop=True),
690
+ unique_females[0:patient_female.shape[0]].reset_index(drop=True)],
691
+ axis = 1)
692
+
693
+ # Step 3: Concatenate patient_df with the selected rows from all_patients_name_id
694
+ updated_patient_df = pd.concat([updated_male_patients, updated_female_patients], axis = 0)
695
+
696
+
697
+
698
+
699
+ updated_patient_df.shape[0]
700
+
701
+
702
+
703
+
704
+
705
+ # Display the final concatenated dataframe
706
+ updated_patient_df
707
+
708
+
709
+
710
+
711
+
712
+ updated_patient_df = updated_patient_df.loc[:, ~updated_patient_df.columns.duplicated()]
713
+ updated_patient_df
714
+
715
+
716
+
717
+
718
+ updated_patient_df['Gender'].value_counts()
719
+
720
+
721
+ # #### 1.2.1.1 Select a Random Patient
722
+
723
+
724
+
725
+
726
+ # Pick a random patient: a Female between 20 and 29 with a Positive outcome, so that we can later check the X-Ray agent
727
+ mask = (updated_patient_df['Gender'] == 'Female') & \
728
+ (updated_patient_df["Age"].between(20, 29)) & \
729
+ (updated_patient_df['Difficulty Breathing'] == 'Yes') & \
730
+ (updated_patient_df['Outcome Variable'] == 'Positive')
731
+ selected_patients = updated_patient_df[mask].reset_index(drop=True)
732
+ selected_patients.head()
733
+
734
+
735
+
736
+
737
+
738
+ selected_patient = selected_patients.iloc[0]
739
+ selected_patient
740
+
741
+
742
+ # # Step 2: Create Identity Photo for the Front Desk Agent
743
+
744
+ # ## 2.1 Build the Vision Model for Gender Classification (Image Classification Task)
745
+
746
+
747
+
748
+
749
+ # Use a pipeline as a high-level helper
750
+ from transformers import pipeline
751
+
752
+ pipe = pipeline("image-classification", model="rizvandwiki/gender-classification")
753
+
754
+
755
+
756
+
757
+
758
+ # Load model directly
759
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
760
+
761
+ processor = AutoImageProcessor.from_pretrained("rizvandwiki/gender-classification")
762
+ model = AutoModelForImageClassification.from_pretrained("rizvandwiki/gender-classification")
763
+
764
+
765
+ # In machine learning, particularly in classification tasks, logits are the raw, unnormalized outputs produced by a model's final layer before any activation function is applied. These outputs represent the model's confidence scores for each class and are essential for subsequent probability calculations.
766
+
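As a worked example of turning logits into probabilities (a generic sketch, not code from app.py), applying softmax normalizes the raw scores into confidences that sum to 1:

import torch

logits = torch.tensor([[2.0, -1.0]])                 # raw scores for two classes
probs = torch.nn.functional.softmax(logits, dim=-1)  # normalize to probabilities
print(probs)  # tensor([[0.9526, 0.0474]]) -- exp(2)/(exp(2)+exp(-1)) ~ 0.95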
767
+
768
+
769
+
770
+ from transformers import AutoModelForImageClassification, AutoProcessor
771
+ from PIL import Image
772
+ import requests
773
+
774
+ # Load the model and processor
775
+ model_name = "rizvandwiki/gender-classification"
776
+ model = AutoModelForImageClassification.from_pretrained(model_name)
777
+ processor = AutoProcessor.from_pretrained(model_name)
778
+
779
+ # Load the image from URL or local path
780
+ image_url = "https://thispersondoesnotexist.com"
781
+ image = Image.open(requests.get(image_url, stream=True).raw)
782
+
783
+ # Prepare the image for the model
784
+ inputs = processor(images=image, return_tensors="pt")
785
+
786
+ # Perform inference
787
+ outputs = model(**inputs)
788
+ logits = outputs.logits
789
+ predicted_class = logits.argmax(-1).item()
790
+
791
+ # Map prediction to class label
792
+ classes = model.config.id2label
793
+ gender_label = classes[predicted_class]
794
+
795
+ print(f"Predicted Gender: {gender_label}")
796
+
797
+
798
+
799
+
800
+
801
+ import matplotlib.pyplot as plt
802
+
803
+ # Display the image and prediction
804
+ plt.imshow(image)
805
+ plt.axis('off') # Hide axes
806
+ plt.title(f"Predicted Gender: {gender_label}")
807
+ plt.show()
808
+
809
+
810
+ # ## 2.2 Build the Vision Model for Age Classification (Image Classification Task)
811
+
812
+
813
+
814
+ # Load age classification model
815
+ age_model_name = "nateraw/vit-age-classifier"
816
+ age_model = AutoModelForImageClassification.from_pretrained(age_model_name)
817
+ age_processor = AutoProcessor.from_pretrained(age_model_name)
818
+
819
+
820
+
821
+
822
+
823
+ # Age Prediction
824
+ age_inputs = age_processor(images=image, return_tensors="pt")
825
+ age_outputs = age_model(**age_inputs)
826
+ age_logits = age_outputs.logits
827
+ age_prediction = age_logits.argmax(-1).item()
828
+ age_label = age_model.config.id2label[age_prediction]
829
+ age_label
830
+
831
+
832
+
833
+
834
+ # Display the image with both predictions
835
+ plt.imshow(image)
836
+ plt.axis('off')
837
+ plt.title(f"Predicted Gender: {gender_label}, Predicted Age: {age_label}")
838
+ plt.show()
839
+
840
+
841
+ # # Step 3: Start Building Multi-Agents
842
+ #
843
+ # Define Each AI Agent
844
+ # We'll define agents for:
845
+ #
846
+ # * Administration Front Desk
847
+ # * Physician for General Health Examination + Blood Laboratory
848
+ # * X-Ray Image Department
849
+
850
+ # ## 3.1 Hospital Front Desk Agent
851
+ #
852
+ #
853
+
854
+ # **--IMPORTANT NOTE--** <br>
855
+ # 1. Don't forget to save one photo from https://thispersondoesnotexist.com/
856
+ # <br> as female.jpg to the path "/content/sample_data/"
857
+ # <br> which is the standard path within your Google Colab
858
+ #
859
+ # ---
860
+ # 2. Don't forget to save one of the images from the x-ray dataset. <br>**Load the dataset this way:** <br>
861
+ # patient_x_ray_path = "keremberke/chest-xray-classification" <br>
862
+ # x_ray_ds = load_dataset(patient_x_ray_path, name="full")
863
+ # <br> Then save one image labelled x-ray-chest.jpg to the path "/content/sample_data/" (a sketch follows below)
864
+
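A minimal sketch of both save steps (the paths and the index 0 are illustrative; the site returns a new random face per request, so verify manually that the downloaded photo is actually female before using it):

import requests
from datasets import load_dataset

# 1. Save an identity photo from thispersondoesnotexist.com
resp = requests.get("https://thispersondoesnotexist.com", timeout=30)
with open("/content/sample_data/female.jpg", "wb") as f:
    f.write(resp.content)

# 2. Save one chest X-ray; dataset entries are PIL images, so .save() works
x_ray_ds = load_dataset("keremberke/chest-xray-classification", name="full")
x_ray_ds["train"][0]["image"].save("/content/sample_data/x-ray-chest.jpg")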
865
+
866
+
867
+
868
+ patient_x_ray_path = "keremberke/chest-xray-classification"
869
+ x_ray_ds = load_dataset(patient_x_ray_path, name="full")
870
+
871
+
872
+
873
+
874
+ from typing import List, Tuple, Dict, Any, Sequence, Annotated, Literal
875
+ from typing_extensions import TypedDict
876
+ from langchain_core.messages import BaseMessage
877
+ import operator
878
+ import functools
879
+ from langchain_core.messages import HumanMessage
880
+ from langgraph.checkpoint.memory import MemorySaver
881
+ from langgraph.graph import END, START, StateGraph, MessagesState
882
+ from langgraph.prebuilt import ToolNode, create_react_agent
883
+ from langchain_core.tools import tool
884
+ from transformers import AutoModelForImageClassification, AutoProcessor
885
+ from PIL import Image
886
+ from pydantic import BaseModel
887
+
888
+ from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
889
+ from langchain_core.prompts import ChatPromptTemplate
890
+
891
+ # Annotated lets developers attach extra metadata to a type hint without changing the underlying type.
892
+ # Literal restricts a value to an exact, enumerated set of allowed values (see the small illustration below).
893
+
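A small illustration of both (generic Python, not from app.py):

from typing import Annotated, Literal

Label = Literal["NORMAL", "PNEUMONIA"]               # value must be exactly one of these
Notes = Annotated[str, "free-text physician notes"]  # str plus attached metadata

def record_finding(label: Label, notes: Notes) -> str:
    return f"{label}: {notes}"

print(record_finding("NORMAL", "clear lungs"))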
894
+
895
+
896
+
897
+
898
+ #----------------- Build Functions that Agents use ----------------------
899
+
900
+ def patient_verification_tool(image_Path, selected_patient_data, updated_patient_df) -> str:
901
+ """Detects the gender from an image provided as a file path."""
902
+ from PIL import Image
903
+ print(image_Path)
904
+ model = AutoModelForImageClassification.from_pretrained("rizvandwiki/gender-classification")
905
+ processor = AutoProcessor.from_pretrained("rizvandwiki/gender-classification")
906
+ image = Image.open(image_Path)
907
+ inputs = processor(images=image, return_tensors="pt")
908
+ outputs = model(**inputs)
909
+ predicted_class = outputs.logits.argmax(-1).item()
910
+ print(f"Predicted Gender Of Patient is : {model.config.id2label[predicted_class]}")
911
+ predicted_gender = model.config.id2label[predicted_class]
912
+
913
+ from PIL import Image
914
+ model = AutoModelForImageClassification.from_pretrained("nateraw/vit-age-classifier")
915
+ processor = AutoProcessor.from_pretrained("nateraw/vit-age-classifier")
916
+ image = Image.open(image_Path)
917
+ inputs = processor(images=image, return_tensors="pt")
918
+ outputs = model(**inputs)
919
+ predicted_class = outputs.logits.argmax(-1).item()
920
+ print(f"predicted Age Class: {model.config.id2label[predicted_class]}")
921
+ predicted_age_range = model.config.id2label[predicted_class]
922
+
923
+ # Parse the age range string (e.g., "20-29")
924
+ age_min, age_max = map(int, predicted_age_range.split('-'))
925
+ print(f"age_mi: {age_min}, age_max: {age_max}")
926
+
927
+ # Verify against the DataFrame
928
+ matching_row = updated_patient_df[
929
+ (updated_patient_df["First_Name"] == selected_patient["First_Name"]) &
930
+ (updated_patient_df["Last_Name"] == selected_patient["Last_Name"]) &
931
+ (updated_patient_df["Patient_ID"] == selected_patient["Patient_ID"]) &
932
+ (updated_patient_df["Gender"].str.lower() == predicted_gender) &
933
+ (updated_patient_df["Age"].between(age_min, age_max))
934
+ ]
935
+ print(f"matching_row {matching_row} ")
936
+ if not matching_row.empty:
937
+ patient_verification = f'''Verification successful.
938
+ Patient is : {selected_patient["First_Name"]} {selected_patient["Last_Name"]}
939
+ with ID {selected_patient["Patient_ID"]}
940
+ who is {predicted_gender} in the age range {predicted_age_range}, can proceed to the physician.'''
941
+ else:
942
+ patient_verification = "ID not verified. Patient cannot proceed."
943
+ return patient_verification
944
+
945
+ #------------------- Define Agents-----------------------------
946
+
947
+ class AgentState(TypedDict):
948
+ initial_prompt : str
949
+ messages: Annotated[List[BaseMessage], operator.add]
950
+ patient_verification : str
951
+
952
+ def front_desk_agent(state, image_Path, selected_patient_data, updated_patient_df):
953
+ initial_prompt = state["initial_prompt"]
954
+ # Call function
955
+ patient_verification = patient_verification_tool(image_Path, selected_patient_data, updated_patient_df)
956
+ print(patient_verification)
957
+ return {"patient_verification": patient_verification}
958
+
959
+ #-----------------------------------------------------------------
960
+ # Build the LangGraph for Hospital Front Desk #
961
+ #-----------------------------------------------------------------
962
+
963
+ image_Path = "female.jpg"
964
+ selected_patient_data = selected_patient.to_dict()
965
+ updated_patient_df
966
+
967
+
968
+ front_desk_agent_node = functools.partial(front_desk_agent,
969
+ image_Path = image_Path,
970
+ selected_patient_data=selected_patient_data,
971
+ updated_patient_df =updated_patient_df)
972
+
973
+ # Set up the LangGraph state graph
974
+ FrontDeskGraph = StateGraph(AgentState)
975
+
976
+ # Define nodes for workflow
977
+ FrontDeskGraph.add_node("front_desk_agent", front_desk_agent_node)
978
+ FrontDeskGraph.add_edge(START, "front_desk_agent")
979
+ FrontDeskGraph.add_edge("front_desk_agent", END)
980
+
981
+
982
+ # Compile the workflow graph (no checkpointer is attached here)
983
+ FrontDeskWorkflow = FrontDeskGraph.compile()
984
+
985
+ from IPython.display import Markdown, display, Image
986
+ display(Image(FrontDeskWorkflow.get_graph(xray=True).draw_mermaid_png()))
987
+
988
+
989
+
990
+
991
+
992
+ initial_prompt = "You are Front Desk Administrator in an Hospital in the Netherlands. Start Verification of the following Patient:"
993
+
994
+
995
+ # Run the workflow
996
+ inputs = {"initial_prompt" : initial_prompt
997
+ }
998
+ output = FrontDeskWorkflow.invoke(inputs)
999
+ output
1000
+
1001
+
1002
+
1003
+
1004
+
1005
+ display(Markdown(output['patient_verification']))
1006
+
1007
+
1008
+ # ## 3.2 Physician Agent
1009
+
1010
+
1011
+
1012
+
1013
+ def question_patient_symptoms(selected_patient_data) -> str:
1014
+ """Asks the patient about symptoms, generates responses, and summarizes the answers based on patient data."""
1015
+ symptoms_questions = {
1016
+ "Cough": "\nAre you coughing?\n",
1017
+ "Fatigue": "\nDo you feel fatigue?\n",
1018
+ "\nDifficulty Breathing": "Do you have difficulty breathing?\n"
1019
+ }
1020
+
1021
+ conversation = []
1022
+
1023
+ for symptom, question in symptoms_questions.items():
1024
+ conversation.append(f"\nPhysician: {question}")
1025
+ response = selected_patient_data.get(symptom, "No")
1026
+ answer = "Yes" if response == "Yes" else "No"
1027
+ conversation.append(f"\nPatient: {answer}")
1028
+
1029
+ first_name = selected_patient_data.get("First_Name", "")
1030
+ last_name = selected_patient_data.get("Last_Name", "")
1031
+ patient_id = selected_patient_data.get("Patient_ID", "")
1032
+ gender = selected_patient_data.get("Gender", "")
1033
+ age = selected_patient_data.get("Age", "")
1034
+
1035
+ profile = f"\nYou are {first_name} {last_name}, a {age} years old {gender} with Patient ID: {patient_id}."
1036
+ summary = profile +"I gathered that you are experiencing the following: "
1037
+ summaries = []
1038
+ for symptom in symptoms_questions.keys():
1039
+ response = selected_patient_data.get(symptom, "No")
1040
+ if response == "Yes":
1041
+ summaries.append(f"you are experiencing {symptom.lower()}")
1042
+ else:
1043
+ summaries.append(f"\nI am glad you are not experiencing {symptom.lower()}")
1044
+ summary += "; ".join(summaries) + "."
1045
+
1046
+ conversation.append(f"\nPhysician: {summary}")
1047
+
1048
+ return "\n".join(conversation)
1049
+
1050
+ def perform_examination(selected_patient_data) -> str:
1051
+ """Performs examination by reporting fever, blood pressure, and cholesterol level from patient data."""
1052
+ fever = selected_patient_data.get("Fever", "Unknown")
1053
+ blood_pressure = selected_patient_data.get("Blood Pressure", "Unknown")
1054
+ cholesterol = selected_patient_data.get("Cholesterol Level", "Unknown")
1055
+ return f"Examination Results: Fever - {fever}, Blood Pressure - {blood_pressure}, Cholesterol Level - {cholesterol}"
1056
+
1057
+ def diagnose_patient(selected_patient_data) -> tuple:
1058
+ """Provides diagnosis based on Disease and Outcome columns in patient data."""
1059
+ disease = selected_patient_data.get("Disease", "Unknown Disease")
1060
+ outcome = selected_patient_data.get("Outcome Variable", "Unknown Outcome")
1061
+ if outcome == 'Positive':
1062
+ diagnosis = 'Make X-Ray from Chest'
1063
+ else:
1064
+ diagnosis = 'Rest to Recover'
1065
+ return f"Diagnosis: {disease}. Test Result: {outcome}. Final Diagnosis: {diagnosis}", diagnosis
1066
+
1067
+
1068
+ class AgentState(TypedDict):
1069
+ initial_prompt : str
1070
+ messages: Annotated[List[BaseMessage], operator.add]
1071
+ question_patient_symptoms: str
1072
+ examination_patient: str
1073
+ diagnosis_patient: str
1074
+ diagnosis : str
1075
+
1076
+
1077
+ def physician_agent(state, selected_patient_data):
1078
+ question_patient= question_patient_symptoms(selected_patient_data)
1079
+ examination = perform_examination(selected_patient_data)
1080
+ diagnosis_report, diagnosis = diagnose_patient(selected_patient_data)
1081
+ return {"question_patient_symptoms": question_patient,
1082
+ "examination_patient": examination,
1083
+ "diagnosis_patient": diagnosis_report,
1084
+ "diagnosis": diagnosis}
1085
+
1086
+
1087
+ selected_patient_data = selected_patient.to_dict()
1088
+
1089
+ physician_agent_node = functools.partial(physician_agent,
1090
+ selected_patient_data=selected_patient_data)
1091
+
1092
+
1093
+ # Set up the LangGraph state graph
1094
+ PhysicianGraph = StateGraph(AgentState)
1095
+
1096
+ # Define nodes for workflow
1097
+ PhysicianGraph.add_node("physician_agent", physician_agent_node)
1098
+ PhysicianGraph.add_edge(START, "physician_agent")
1099
+ PhysicianGraph.add_edge("physician_agent", END)
1100
+
1101
+
1102
+ # Compile the workflow graph (no checkpointer is attached here)
1103
+ PhysicianWorkflow = PhysicianGraph.compile()
1104
+
1105
+ display(Image(PhysicianWorkflow.get_graph(xray=True).draw_mermaid_png()))
1106
+
1107
+
1108
+
1109
+
1110
+
1111
+ initial_prompt = "You are a Very Experience Doctor in an Hospital in the Netherlands. Start a conversation with the patient and determine \
1112
+ symptoms and give diagnosis"
1113
+
1114
+
1115
+ # Run the workflow
1116
+ inputs = {"initial_prompt" : initial_prompt
1117
+ }
1118
+ output = PhysicianWorkflow.invoke(inputs)
1119
+ output
1120
+
1121
+
1122
+
1123
+
1124
+
1125
+ display(Markdown(output['question_patient_symptoms']))
1126
+ display(Markdown(output['examination_patient']))
1127
+ display(Markdown(output['diagnosis_patient']))
1128
+
1129
+
1130
+ # ## 3.3 Radiologist
1131
+
1132
+
1133
+
1134
+
1135
+ def examine_X_ray_image(patient_x_ray_path) -> str:
1136
+ """Use Vision Models to recognise if the X-Ray Image of Patient is NORMAL or PNEUMONIA"""
1137
+ # Model in Hugging Face: https://huggingface.co/lxyuan/vit-xray-pneumonia-classification
1138
+ # vit-xray-pneumonia-classification
1139
+ x_ray_ds = load_dataset(patient_x_ray_path, name="full")
1140
+ random_index = random.randint(0, x_ray_ds['train'].shape[0] - 1)
1141
+ patient_x_ray_image = x_ray_ds['train'][random_index]['image']
1142
+ classifier = pipeline(model="lxyuan/vit-xray-pneumonia-classification")
1143
+ patient_x_ray_results = classifier(patient_x_ray_image)
1144
+
1145
+ # Find the label with the highest score and its score
1146
+ highest = max(patient_x_ray_results, key=lambda x: x['score'])
1147
+ highest_score_label = highest['label']
1148
+ highest_score = highest['score'] * 100 # Convert to percentage
1149
+
1150
+ # Choose the correct verb based on the label
1151
+ verb = "is" if highest_score_label == "NORMAL" else "has"
1152
+
1153
+ return f"Patient {verb} {highest_score_label} with Probability of ca. {highest_score:.0f}%"
1154
+
1155
+ class AgentState(TypedDict):
1156
+ initial_prompt : str
1157
+ messages: Annotated[List[BaseMessage], operator.add]
1158
+ pneumonia_detection: str
1159
+
1160
+
1161
+
1162
+ def radiologist_agent(state, patient_x_ray_path):
1163
+ pneumonia_detection = examine_X_ray_image(patient_x_ray_path)
1164
+ return {"pneumonia_detection": pneumonia_detection}
1165
+
1166
+ patient_x_ray_path = "keremberke/chest-xray-classification"
1167
+
1168
+ radiologist_agent_node = functools.partial(radiologist_agent,
1169
+ patient_x_ray_path=patient_x_ray_path)
1170
+
1171
+ # Set up the LangGraph state graph
1172
+ RadiologistGraph = StateGraph(AgentState)
1173
+
1174
+ # Define nodes for workflow
1175
+ RadiologistGraph.add_node("radiologist_agent", radiologist_agent_node)
1176
+ RadiologistGraph.add_edge(START, "radiologist_agent")
1177
+ RadiologistGraph.add_edge("radiologist_agent", END)
1178
+
1179
+ # Compile the workflow graph (no checkpointer is attached here)
1180
+ RadiologistWorkflow = RadiologistGraph.compile()
1181
+
1182
+ display(Image(RadiologistWorkflow.get_graph(xray=True).draw_mermaid_png()))
1183
+
1184
+
1185
+
1186
+ initial_prompt = "You are a Very Experienced Radiologist in an Hospital in the Netherlands. Diagnose if the patient has pneumonia"
1187
+
1188
+
1189
+ # Run the workflow
1190
+ inputs = {"initial_prompt" : initial_prompt
1191
+ }
1192
+ output = RadiologistWorkflow.invoke(inputs)
1193
+ output
1194
+
1195
+
1196
+
1197
+ display(Markdown(output['pneumonia_detection']))
1198
+
1199
+
1200
+ # # Step 4: Putting All Agents in One Graph
1201
+
1202
+
1203
+
1204
+
1205
+ from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
1206
+ from langchain_core.prompts import ChatPromptTemplate
1207
+
1208
+ selected_patient_data = selected_patient.to_dict()
1209
+ image_Path = "female.jpg"
1210
+ patient_x_ray_image = patient_x_ray
1211
+
1212
+ def patient_verification_tool(image_Path, selected_patient_data, updated_patient_df) -> str:
1213
+ """Detects the gender from an image provided as a file path."""
1214
+ from PIL import Image
1215
+ print(image_Path)
1216
+ model = AutoModelForImageClassification.from_pretrained("rizvandwiki/gender-classification")
1217
+ processor = AutoProcessor.from_pretrained("rizvandwiki/gender-classification")
1218
+ image = Image.open(image_Path)
1219
+ inputs = processor(images=image, return_tensors="pt")
1220
+ outputs = model(**inputs)
1221
+ predicted_class = outputs.logits.argmax(-1).item()
1222
+ print(f"Predicted Gender Of Patient is : {model.config.id2label[predicted_class]}")
1223
+ predicted_gender = model.config.id2label[predicted_class]
1224
+
1225
+ from PIL import Image
1226
+ model = AutoModelForImageClassification.from_pretrained("nateraw/vit-age-classifier")
1227
+ processor = AutoProcessor.from_pretrained("nateraw/vit-age-classifier")
1228
+ image = Image.open(image_Path)
1229
+ inputs = processor(images=image, return_tensors="pt")
1230
+ outputs = model(**inputs)
1231
+ predicted_class = outputs.logits.argmax(-1).item()
1232
+ print(f"predicted Age Class: {model.config.id2label[predicted_class]}")
1233
+ predicted_age_range = model.config.id2label[predicted_class]
1234
+
1235
+ # Parse the age range string (e.g., "20-29")
1236
+ age_min, age_max = map(int, predicted_age_range.split('-'))
1237
+ print(f"age_mi: {age_min}, age_max: {age_max}")
1238
+
1239
+ # Verify against the DataFrame
1240
+ matching_row = updated_patient_df[
1241
+ (updated_patient_df["First_Name"] == selected_patient["First_Name"]) &
1242
+ (updated_patient_df["Last_Name"] == selected_patient["Last_Name"]) &
1243
+ (updated_patient_df["Patient_ID"] == selected_patient["Patient_ID"]) &
1244
+ (updated_patient_df["Gender"].str.lower() == predicted_gender) &
1245
+ (updated_patient_df["Age"].between(age_min, age_max))
1246
+ ]
1247
+ print(f"matching_row {matching_row} ")
1248
+ if not matching_row.empty:
1249
+ patient_verification = f'''Verification successful.
1250
+ Patient is : {selected_patient["First_Name"]} {selected_patient["Last_Name"]}
1251
+ with ID {selected_patient["Patient_ID"]}
1252
+ who is {predicted_gender} in the age range {predicted_age_range}, can proceed to the physician.'''
1253
+ else:
1254
+ patient_verification = "ID not verified. Patient cannot proceed."
1255
+ return patient_verification
1256
+
1257
+ def question_patient_symptoms(selected_patient_data) -> str:
1258
+ """Asks the patient about symptoms, generates responses, and summarizes the answers based on patient data."""
1259
+ symptoms_questions = {
1260
+ "Cough": "\nAre you coughing?\n",
1261
+ "Fatigue": "\nDo you feel fatigue?\n",
1262
+ "\nDifficulty Breathing": "Do you have difficulty breathing?\n"
1263
+ }
1264
+
1265
+ conversation = []
1266
+
1267
+ for symptom, question in symptoms_questions.items():
1268
+ conversation.append(f"\nPhysician: {question}")
1269
+ response = selected_patient_data.get(symptom, "No")
1270
+ answer = "Yes" if response == "Yes" else "No"
1271
+ conversation.append(f"\nPatient: {answer}")
1272
+
1273
+ first_name = selected_patient_data.get("First_Name", "")
1274
+ last_name = selected_patient_data.get("Last_Name", "")
1275
+ patient_id = selected_patient_data.get("Patient_ID", "")
1276
+ gender = selected_patient_data.get("Gender", "")
1277
+ age = selected_patient_data.get("Age", "")
1278
+
1279
+ profile = f"\nYou are {first_name} {last_name}, a {age} years old {gender} with Patient ID: {patient_id}."
1280
+ summary = profile +"I gathered that you are experiencing the following: "
1281
+ summaries = []
1282
+ for symptom in symptoms_questions.keys():
1283
+ response = selected_patient_data.get(symptom, "No")
1284
+ if response == "Yes":
1285
+ summaries.append(f"you are experiencing {symptom.lower()}")
1286
+ else:
1287
+ summaries.append(f"\nI am glad you are not experiencing {symptom.lower()}")
1288
+ summary += "; ".join(summaries) + "."
1289
+
1290
+ conversation.append(f"\nPhysician: {summary}")
1291
+
1292
+ return "\n".join(conversation)
1293
+
1294
+ def perform_examination(selected_patient_data) -> str:
1295
+ """Performs examination by reporting fever, blood pressure, and cholesterol level from patient data."""
1296
+ fever = selected_patient_data.get("Fever", "Unknown")
1297
+ blood_pressure = selected_patient_data.get("Blood Pressure", "Unknown")
1298
+ cholesterol = selected_patient_data.get("Cholesterol Level", "Unknown")
1299
+ return f"Examination Results: Fever - {fever}, Blood Pressure - {blood_pressure}, Cholesterol Level - {cholesterol}"
1300
+
1301
+ def diagnose_patient(selected_patient_data) -> tuple:
1302
+ """Provides diagnosis based on Disease and Outcome columns in patient data."""
1303
+ disease = selected_patient_data.get("Disease", "Unknown Disease")
1304
+ outcome = selected_patient_data.get("Outcome Variable", "Unknown Outcome")
1305
+ if outcome == 'Positive':
1306
+ diagnosis = 'Make X-Ray from Chest'
1307
+ else:
1308
+ diagnosis = 'Rest to Recover'
1309
+ return f"Diagnosis: {disease}. Test Result: {outcome}. Final Diagnosis: {diagnosis}", diagnosis
1310
+
1311
+ def examine_X_ray_image(patient_x_ray_path) -> str:
1312
+ """Use Vision Models to recognise if the X-Ray Image of Patient is NORMAL or PNEUMONIA"""
1313
+ # Model in Hugging Face: https://huggingface.co/lxyuan/vit-xray-pneumonia-classification
1314
+ # vit-xray-pneumonia-classification
1315
+ x_ray_ds = load_dataset(patient_x_ray_path, name="full")
1316
+ random_index = random.randint(0, x_ray_ds['train'].shape[0] - 1)
1317
+ patient_x_ray_image = x_ray_ds['train'][random_index]['image']
1318
+ classifier = pipeline(model="lxyuan/vit-xray-pneumonia-classification")
1319
+ patient_x_ray_results = classifier(patient_x_ray_image)
1320
+
1321
+ # Find the label with the highest score and its score
1322
+ highest = max(patient_x_ray_results, key=lambda x: x['score'])
1323
+ highest_score_label = highest['label']
1324
+ highest_score = highest['score'] * 100 # Convert to percentage
1325
+
1326
+ # Choose the correct verb based on the label
1327
+ verb = "is" if highest_score_label == "NORMAL" else "has"
1328
+
1329
+ return f"Patient {verb} {highest_score_label} with Probability of ca. {highest_score:.0f}%"
1330
+
1331
+ # The agent state is the input to each node in the graph
1332
+ class AgentState(TypedDict):
1333
+ # The annotation tells the graph that new messages will always
1334
+ # be added to the current states
1335
+ initial_prompt : str
1336
+ messages: Annotated[List[BaseMessage], operator.add]
1337
+ patient_verification : str
1338
+ question_patient_symptoms: str
1339
+ examination_patient: str
1340
+ diagnosis_patient: str
1341
+ diagnosis : str
1342
+ pneumonia_detection: str
1343
+
1344
+ def front_desk_agent(state, image_Path, selected_patient_data, updated_patient_df):
1345
+ initial_prompt = state["initial_prompt"]
1346
+ patient_verification = patient_verification_tool(image_Path, selected_patient_data, updated_patient_df)
1347
+ print(patient_verification)
1348
+ return {"patient_verification": patient_verification}
1349
+
1350
+ def physician_agent(state, selected_patient_data):
1351
+ question_patient= question_patient_symptoms(selected_patient_data)
1352
+ examination = perform_examination(selected_patient_data)
1353
+ diagnosis_report, diagnosis = diagnose_patient(selected_patient_data)
1354
+ # note: pneumonia detection is performed by radiologist_agent, so it is not computed here
1355
+ return {"question_patient_symptoms": question_patient,
1356
+ "examination_patient": examination,
1357
+ "diagnosis_patient": diagnosis_report,
1358
+ "diagnosis": diagnosis}
1359
+
1360
+ def radiologist_agent(state, patient_x_ray_path):
1361
+ pneumonia_detection = examine_X_ray_image(patient_x_ray_path)
1362
+ return {"pneumonia_detection": pneumonia_detection}
1363
+
1364
+ def decide_on_radiologist(state):
1365
+ if state["diagnosis"] == 'Make X-Ray from Chest':
1366
+ return 'radiologist'
1367
+ else:
1368
+ return 'end'
1369
+
1370
+
1371
+ image_Path = "female.jpg"
1372
+ selected_patient_data = selected_patient.to_dict()
1373
+ updated_patient_df
1374
+ patient_x_ray_path = "keremberke/chest-xray-classification"
1375
+
1376
+ front_desk_agent_node = functools.partial(front_desk_agent,
1377
+ image_Path = image_Path,
1378
+ selected_patient_data=selected_patient_data,
1379
+ updated_patient_df =updated_patient_df)
1380
+ physician_agent_node = functools.partial(physician_agent,
1381
+ selected_patient_data=selected_patient_data)
1382
+
1383
+ radiologist_agent_node = functools.partial(radiologist_agent,
1384
+ patient_x_ray_path=patient_x_ray_path)
1385
+
1386
+ def decide_on_radiologist(state):
1387
+ if state["diagnosis"] == 'Make X-Ray from Chest':
1388
+ return 'radiologist'
1389
+ else:
1390
+ return 'end'
1391
+
1392
+ # Set up the LangGraph state graph
1393
+ HospitalGraph = StateGraph(AgentState)
1394
+
1395
+ # Define nodes for workflow
1396
+ HospitalGraph.add_node("front_desk_agent", front_desk_agent_node)
1397
+ HospitalGraph.add_node("physician_agent", physician_agent_node)
1398
+ HospitalGraph.add_node("radiologist_agent", radiologist_agent_node)
1399
+
1400
+ HospitalGraph.add_edge(START, "front_desk_agent")
1401
+ HospitalGraph.add_edge("front_desk_agent", "physician_agent")
1402
+ HospitalGraph.add_conditional_edges("physician_agent",
1403
+ decide_on_radiologist,
1404
+ {'radiologist': "radiologist_agent",
1405
+ 'end': END})
1406
+
1407
+
1408
+ # Compile the workflow graph (no checkpointer is attached here)
1409
+ HospitalWorkflow = HospitalGraph.compile()
1410
+
1411
+ display(Image(HospitalWorkflow.get_graph(xray=True).draw_mermaid_png()))
1412
+
1413
+
1414
+
1415
+
1416
+
1417
+ initial_prompt = "Start with the following Patient"
1418
+
1419
+
1420
+ # Run the workflow
1421
+ inputs = {"initial_prompt" : initial_prompt
1422
+ }
1423
+ output = HospitalWorkflow.invoke(inputs)
1424
+ output
1425
+
1426
+
1427
+
1428
+
1429
+ display(Markdown(output['patient_verification']))
1430
+
1431
+
1432
+
1433
+
1434
+
1435
+ display(Markdown(output['question_patient_symptoms']))
1436
+ display(Markdown(output['examination_patient']))
1437
+ display(Markdown(output['diagnosis_patient']))
1438
+
1439
+
1440
+
1441
+
1442
+ display(Markdown(output['pneumonia_detection']))
1443
+
1444
+
1445
+ # # Step 5: Gradio Dashboard
1446
+
1447
+ # ## 5.1 Build the Hospital Dashboard App
1448
+
1449
+
1450
+
1451
+
1452
+ x_ray_image_path = 'x-ray-chest.png'
1453
+
1454
+ import gradio as gr
1455
+ info = (
1456
+ f"**First Name:** {selected_patient_data['First_Name']}\n\n"
1457
+ f"**Last Name:** {selected_patient_data['Last_Name']}\n\n"
1458
+ f"**Patient ID:** {selected_patient_data['Patient_ID']}"
1459
+ )
1460
+
1461
+ def verify_age_gender():
1462
+ """
1463
+ Function to verify age and gender.
1464
+ """
1465
+ # Placeholder logic: In a real scenario, perform necessary checks or computations
1466
+ initial_prompt = "You are Front Desk Administrator in an Hospital in the Netherlands. Start Verification of the following Patient:"
1467
+ inputs = {"initial_prompt" : initial_prompt
1468
+ }
1469
+ output = FrontDeskWorkflow.invoke(inputs)
1470
+ verification_message = '✅ ' + output['patient_verification']
1471
+ return verification_message, gr.update(visible=True)
1472
+
1473
+ def physician_examination():
1474
+ initial_prompt = "You are a Very Experience Doctor in an Hospital in the Netherlands. Start a conversation with the patient and determine \
1475
+ symptoms and give diagnosis"
1476
+ # Run the workflow
1477
+ inputs = {"initial_prompt" : initial_prompt
1478
+ }
1479
+ output = PhysicianWorkflow.invoke(inputs)
1480
+ output_all = f''' 🩺 {output['question_patient_symptoms']}\n
1481
+ 💓 {output['examination_patient']}\n
1482
+ 🌬️ {output['diagnosis_patient']}'''
1483
+ return output_all, gr.update(visible=True)
1484
+
1485
+ def pneumonia_detection():
1486
+ initial_prompt = "You are a Very Experienced Radiologist in an Hospital in the Netherlands. Diagnose if the patient has pneumonia"
1487
+ inputs = {"initial_prompt" : initial_prompt
1488
+ }
1489
+ output = RadiologistWorkflow.invoke(inputs)
1490
+ pneumonia_detection = 'From X-Ray Image 🖼️ ' + output['pneumonia_detection']
1491
+ return pneumonia_detection
1492
+
1493
+ def take_xray_image():
1494
+
1495
+ return gr.update(visible=True), gr.update(visible=True)
1496
+
1497
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown(info)
+             # Verification button and its status output
+             verify_button = gr.Button("Verify Age and Gender")
+             verification_output = gr.Textbox(label="Verification Status", interactive=False)
+             # Physician examination (hidden until verification has run)
+             physician_button = gr.Button("Get Examination at Physician", visible=False)
+             physician_output = gr.Textbox(label="Examination by Physician", interactive=False)
+             x_ray_button = gr.Button("Take Chest X-Ray Image", visible=False)
+             # X-ray image display (initially hidden)
+             xray_image_display = gr.Image(value=x_ray_image_path, label="X-Ray Image", visible=False)
+             radiologist_button = gr.Button("Go to Radiologist", visible=False)
+             radiologist_output = gr.Textbox(label="Radiologist Report", interactive=False)
+ 
+         with gr.Column(scale=1):
+             gr.Image(value=image_Path, label="Patient Photo", show_label=True)
+ 
+     # Wire each button to its handler; the gr.update return values
+     # progressively reveal the next step in the workflow
+     verify_button.click(fn=verify_age_gender, inputs=None, outputs=[verification_output, physician_button])
+     physician_button.click(fn=physician_examination, inputs=None, outputs=[physician_output, x_ray_button])
+     x_ray_button.click(fn=take_xray_image, inputs=None, outputs=[xray_image_display, radiologist_button])
+     radiologist_button.click(fn=pneumonia_detection, inputs=None, outputs=[radiologist_output])
+ 
+ 
+ # ## 5.2 Run the App
+ 
+ 
+ # Launch the app
+ demo.launch()
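+ 
+ # A sketch of common launch options (assumptions, not part of the original code):
+ ## demo.launch(server_name="0.0.0.0",  # listen on all interfaces
+ ##             server_port=7860,       # default Gradio port
+ ##             share=True)             # create a temporary public link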
+ 
+ 
+ # # Step 6: Building Advanced Retrieval (RAG)
+ 
+ # ## 6.1 Textsplitter
+ 
+ 
+ # Patient records (3 example patients)
+ 
+ text_content = ["Patient 1: Mette Smit, a 25-year-old female with Patient ID: X8g6eC2R7uPvN5a1."
+                 "Mette is coughing and is experiencing fatigue. Mette has fever and influenza."
+                 "Mette has pneumonia with a probability of ca. 92%.",
+                 "Patient 2: Tim Sutherland has fever and suffers from difficulty in breathing.",
+                 "We made an X-ray image of Tim Sutherland's chest.",
+                 "The radiologist gives Tim Sutherland a 93% chance of pneumonia.",
+                 "Patient 3: Jane Bright has no fever and suffers from high blood pressure and high cholesterol.",
+                 "We made an X-ray image of Jane Bright's chest because of non-stop coughing.",
+                 "The radiologist gives only an 8% chance of pneumonia for Jane. It seems that Jane Bright has influenza.",]
+ 
+ documents = [Document(page_content=text) for text in text_content]
+ 
+ # Split into overlapping chunks of at most 100 characters
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
+ splits = text_splitter.split_documents(documents)
+ splits
+ 
+ 
+ # Re-split each chunk's text into plain strings for the embedder
+ text_chunks = []
+ for page in splits:
+     chunks = text_splitter.split_text(page.page_content)
+     text_chunks.extend(chunks)
+ text_chunks
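+ 
+ # A quick sanity check (illustrative, not part of the original code):
+ print(len(splits), "document chunks;", len(text_chunks), "text chunks")
+ print(text_chunks[0])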
+ 
+ 
+ # ## 6.2 Embedding
+ 
+ 
+ #!pip install -U sentence-transformers langchain-huggingface accelerate
+ #!pip install "transformers==4.41.1"
+ #!pip install "peft==0.13.2"
+ #from langchain_huggingface import HuggingFaceEmbeddings
+ 
+ # Embed each text chunk with a small sentence-transformer (384-dim vectors)
+ hf_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+ embeddings = hf_embeddings.embed_documents(text_chunks)
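+ 
+ # Illustrative check (not part of the original code): cosine similarity
+ # between the first two chunk embeddings.
+ import numpy as np
+ a, b = np.array(embeddings[0]), np.array(embeddings[1])
+ print(float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b))))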
+ 
+ 
+ # ## 6.3 Vector Database
+ 
+ 
+ # The retriever in Section 6.6 needs `vectordb`, so build the Chroma store
+ # in memory here; the `##` lines show the persisted Google Drive variant.
+ vectordb = Chroma.from_documents(documents=splits,
+                                  embedding=hf_embeddings)
+ 
+ ## persist_directory = '/content/drive/MyDrive/chromadb'
+ ## vectordb = Chroma.from_documents(documents=splits,
+ ##                                  embedding=hf_embeddings,
+ ##                                  persist_directory=persist_directory)
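+ 
+ # Illustrative usage (not part of the original code): the two chunks most
+ # similar to a sample query.
+ vectordb.similarity_search("Which patient has pneumonia?", k=2)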
+ 
+ 
+ # ## 6.4 LLM | Groq + Llama 3.3
+ 
+ 
+ os.environ["GROQ_API_KEY"] = GROQ_API_KEY
+ 
+ # Llama 3.2 has been removed from the Groq platform,
+ # so we use the newest model: Llama 3.3
+ model_3_3 = 'llama-3.3-70b-versatile'
+ 
+ llm = ChatGroq(
+     model=model_3_3,
+     temperature=0,    # deterministic output for reproducible answers
+     max_tokens=None,
+     timeout=None,
+     max_retries=2,
+     # other params...
+ )
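+ 
+ # Quick smoke test (illustrative, not part of the original code): ChatGroq
+ # returns an AIMessage whose text is in `.content`.
+ ## print(llm.invoke("Reply with the single word: ready").content)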
1612
+
1613
+
1614
+ # ## 6.5 Query Prompt
1615
+
1616
+ # In[76]:
1617
+
1618
+
1619
+ QUERY_PROMPT = PromptTemplate(
1620
+ input_variables=["question"],
1621
+ template="""You are an AI language model assistant. Your task is to generate five
1622
+ different versions of the given user question to retrieve relevant documents from
1623
+ a vector database. By generating multiple perspectives on the user question, your
1624
+ goal is to help the user overcome some of the limitations of the distance-based
1625
+ similarity search. Provide these alternative questions separated by newlines.
1626
+ Original question: {question}""",
1627
+ )
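+ 
+ # Illustrative only (not part of the original code): render the template with
+ # a sample question to see the exact prompt the LLM receives.
+ print(QUERY_PROMPT.format(question="Which patients have pneumonia?"))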
+ 
+ 
+ # ## 6.6 Retriever
+ 
+ 
+ # MultiQueryRetriever asks the LLM (via QUERY_PROMPT) for several rewordings
+ # of each question and merges the documents retrieved for all of them
+ overall_retriever = MultiQueryRetriever.from_llm(
+     vectordb.as_retriever(),
+     llm,
+     prompt=QUERY_PROMPT
+ )
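+ 
+ # Illustrative usage (not part of the original code); depending on the
+ # LangChain version the call is `.invoke(...)` or `.get_relevant_documents(...)`.
+ ## docs = overall_retriever.invoke("Which patients have pneumonia?")
+ ## print(len(docs), "documents retrieved")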
+ 
+ # RAG prompt
+ template = """Answer the question based ONLY on the following context:
+ {context}
+ Question: {question}
+ """
+ 
+ prompt = ChatPromptTemplate.from_template(template)
+ 
+ 
+ # ## 6.7 Chain
+ 
+ 
+ # LCEL pipeline: the retriever fills {context}, RunnablePassthrough forwards
+ # the raw question into {question}; then prompt, LLM and string parser run
+ chain = (
+     {"context": overall_retriever, "question": RunnablePassthrough()}
+     | prompt
+     | llm
+     | StrOutputParser()
+ )
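+ 
+ # Like any LCEL runnable, the chain also supports batching (illustrative,
+ # not part of the original code):
+ ## answers = chain.batch(["Who has pneumonia?", "Who has influenza?"])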
+ 
+ 
+ # # Step 7: Chatting with RAG
+ 
+ 
+ questions = '''What are the names of all the patients in the database?'''
+ display(Markdown(chain.invoke(questions)))
+ 
+ 
+ questions = '''What are all the health issues that Jane Bright has?'''
+ display(Markdown(chain.invoke(questions)))
+ 
+ 
+ questions = '''What are all the health issues that Mette Smit has?'''
+ display(Markdown(chain.invoke(questions)))
+ 
+ 
+ questions = '''What is the age of Tim Sutherland?'''
+ display(Markdown(chain.invoke(questions)))
+ 
+ 
+ questions = '''Which patient has a Patient ID?'''
+ display(Markdown(chain.invoke(questions)))
+ 
+ 
female.jpg ADDED

Git LFS Details

  • SHA256: d46988d99d369d2b2f04c3db12f86146661731f3b331ceb64112bf885e5649ff
  • Pointer size: 131 Bytes
  • Size of remote file: 542 kB
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ textwrap3
+ crewai
+ crewai-tools
+ gradio
+ python-dotenv
x-ray-chest.png ADDED