SRA25 commited on
Commit
1b46363
·
verified ·
1 Parent(s): 0737ac9

Update mydomain_agent.py

Browse files
Files changed (1) hide show
  1. mydomain_agent.py +646 -650
mydomain_agent.py CHANGED
@@ -1,650 +1,646 @@
1
- from langchain_core.tools import tool
2
- from langgraph.graph import StateGraph, START, END
3
- # from llm_initializer import initialize_llm, generate_prompt_phi4
4
- from langgraph.graph import MessagesState
5
- from langchain_core.messages import ToolMessage, HumanMessage, SystemMessage
6
- from typing_extensions import Literal, TypedDict
7
- from IPython.display import Image, display
8
- from pydantic import BaseModel, Field
9
- from pydantic import BaseModel, Field, validator
10
- from typing import List, Optional, Dict, Any, TypedDict,Generic, TypeVar
11
- from abc import ABC
12
- import uuid
13
- import io
14
- import os
15
- import PyPDF2
16
- import re
17
- import logging
18
- import time
19
- from docx import Document as dx
20
- from langchain_text_splitters import RecursiveCharacterTextSplitter
21
- from langchain_community.document_loaders import (
22
- DirectoryLoader,
23
- PyPDFLoader,
24
- TextLoader
25
- )
26
- import tempfile
27
- import faiss
28
- from langchain_community.docstore.in_memory import InMemoryDocstore
29
- from langchain_community.vectorstores import FAISS
30
- from langchain_core.prompts import PromptTemplate
31
- from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
32
- from langchain_huggingface import HuggingFaceEmbeddings
33
- from langgraph.checkpoint.memory import MemorySaver
34
- from langgraph.graph import StateGraph, END
35
- from sqlalchemy import create_engine, Column, String, Integer, DateTime, ForeignKey, Text
36
- from sqlalchemy.dialects.sqlite import JSON as SQLiteJSON
37
- # from sqlalchemy.ext.declarative import declarative_base
38
- from sqlalchemy.orm import sessionmaker, relationship
39
- from sentence_transformers import SentenceTransformer
40
- from huggingface_hub import login
41
- from langchain_google_genai import ChatGoogleGenerativeAI
42
- import datetime
43
- from enum import Enum as PyEnum
44
- from sqlalchemy.orm import DeclarativeBase
45
- # from config import Config
46
- from functools import lru_cache
47
- from dotenv import load_dotenv
48
-
49
- load_dotenv()
50
- hf_token = os.getenv("hf_user_token")
51
- login(hf_token)
52
-
53
- T = TypeVar("T")
54
- # --- 1. Database Setup ---
55
- DATABASE_URL = "sqlite:///Db_domain_agent.db"
56
- engine = create_engine(DATABASE_URL)
57
- SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
58
-
59
- class Base(DeclarativeBase):
60
- pass
61
-
62
- class FeedbackScore(PyEnum):
63
- POSITIVE = 1
64
- NEGATIVE = -1
65
-
66
- class Telemetry(Base):
67
- __tablename__ = "telemetry_table"
68
- transaction_id = Column(String, primary_key=True)
69
- session_id = Column(String)
70
- user_question = Column(Text)
71
- response = Column(Text)
72
- context = Column(Text)
73
- model_name = Column(String)
74
- input_tokens = Column(Integer)
75
- output_tokens = Column(Integer)
76
- total_tokens = Column(Integer)
77
- latency = Column(Integer)
78
- dtcreatedon = Column(DateTime)
79
-
80
- feedback = relationship("Feedback", back_populates="telemetry_entry", uselist=False)
81
-
82
- class Feedback(Base):
83
- __tablename__ = "feedback_table"
84
- id = Column(Integer, primary_key=True, autoincrement=True)
85
- telemetry_entry_id = Column(String, ForeignKey("telemetry_table.transaction_id"), nullable=False, unique=True)
86
- feedback_score = Column(Integer, nullable=False)
87
- feedback_text = Column(Text, nullable=True)
88
- user_query = Column(Text, nullable=False)
89
- llm_response = Column(Text, nullable=False)
90
- timestamp = Column(DateTime, default=datetime.datetime.now)
91
-
92
- telemetry_entry = relationship("Telemetry", back_populates="feedback")
93
-
94
- class ConversationHistory(Base):
95
- __tablename__ = "conversation_history"
96
- session_id = Column(String, primary_key=True)
97
- messages = Column(SQLiteJSON, nullable=False)
98
- last_updated = Column(DateTime, default=datetime.datetime.now)
99
-
100
- Base.metadata.create_all(bind=engine)
101
- # --- 2. Initialize LLM and Embeddings ---
102
- gak = os.getenv("Gapi_key")
103
- llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite",google_api_key=gak)
104
- # embedding_model = SentenceTransformer("ibm-granite/granite-embedding-english-r2")
105
-
106
- # my_model_name = "gemma3:1b-it-qat"
107
- # llm = ChatOllama(model=my_model_name)
108
- embedding_model = HuggingFaceEmbeddings(
109
- model_name="ibm-granite/granite-embedding-english-r2",
110
- model_kwargs={'device': 'cpu'},
111
- encode_kwargs={'normalize_embeddings': False}
112
- )
113
-
114
- # --- 3. LangGraph State and Workflow ---
115
- class GraphState(TypedDict):
116
- chat_history: List[Dict[str, Any]]
117
- retrieved_documents: List[str]
118
- user_question: str
119
- decision:str
120
- session_id: str
121
- telemetry_id: Optional[str] = None
122
-
123
- class Route(BaseModel):
124
- step: Literal['HR Agent','Finance Agent','Legal Compliance Agent'] = Field(
125
- None, description="The next step in routing process"
126
- )
127
- router = llm.with_structured_output(Route)
128
-
129
- # class State(TypedDict):
130
- # input:str
131
- # decision:str
132
- # output:str
133
-
134
- chathistory = {}
135
-
136
- def retrieve_documents(state: GraphState):
137
- # global vectorstore_retriever
138
- # upload_documents()
139
- saved_vectorstore_index = FAISS.load_local('domain_index', embedding_model,allow_dangerous_deserialization=True)
140
- user_question = state["user_question"]
141
- # meta_filter = {'Domain':'HR'}
142
- if saved_vectorstore_index is None:
143
- raise ValueError("Knowledge base not loaded.")
144
- retrieved_docs = saved_vectorstore_index.as_retriever(search_type="mmr", search_kwargs={"k": 5})
145
- top_docs = retrieved_docs.invoke(user_question)
146
- print("Top Docs: ", top_docs)
147
- retrieved_docs_content = [doc.page_content if doc.page_content else doc for doc in top_docs]
148
- print("retrieved_documents List: ", retrieved_docs_content)
149
- return {"retrieved_documents": retrieved_docs_content}
150
-
151
- def generate_response(user_question, retrieved_documents):
152
- print("Inside generate_response--------------")
153
- global llm
154
- global chathistory
155
- global agent_name
156
- # user_question = state["user_question"]
157
- # retrieved_documents = state["retrieved_documents"]
158
-
159
- formatted_chat_history = []
160
- for msg in chathistory["chat_history"]:
161
- if msg['role'] == 'user':
162
- formatted_chat_history.append(HumanMessage(content=msg['content']))
163
- elif msg['role'] == 'assistant':
164
- formatted_chat_history.append(AIMessage(content=msg['content']))
165
-
166
- if not retrieved_documents:
167
- response_content = "I couldn't find any relevant information in the uploaded documents for your question. Can you please rephrase or provide more context?"
168
- response_obj = AIMessage(content=response_content)
169
- else:
170
- context = "\n\n".join(retrieved_documents)
171
- template = """
172
- You are a helpful AI assistant. Answer the user's question based on the provided context {context} and the conversation history {chat_history}.
173
- If the answer is not in the context, state that you don't have enough information.
174
- Do not make up answers. Only use the given context and chat_history.
175
- Remove unwanted words like 'Response:' or 'Answer:' from answers.
176
- \n\nHere is the Question:\n{user_question}
177
- """
178
- rag_prompt = PromptTemplate(
179
- input_variables=["context", "chat_history", "user_question"],
180
- template=template
181
- )
182
- rag_chain = rag_prompt | llm
183
- time.sleep(3)
184
- response_obj = rag_chain.invoke({
185
- "context": [SystemMessage(content=context)],
186
- "chat_history": formatted_chat_history,
187
- "user_question": [HumanMessage(content=user_question)]
188
- })
189
-
190
- telemetry_data = response_obj.model_dump()
191
- input_tokens = telemetry_data.get('usage_metadata', {}).get('input_tokens', 0)
192
- output_tokens = telemetry_data.get('usage_metadata', {}).get('output_tokens', 0)
193
- total_tokens = telemetry_data.get('usage_metadata', {}).get('total_tokens', 0)
194
- model_name = telemetry_data.get('response_metadata', {}).get('model', 'unknown')
195
- total_duration = telemetry_data.get('response_metadata', {}).get('total_duration', 0)
196
-
197
- db = SessionLocal()
198
- transaction_id = str(uuid.uuid4())
199
- try:
200
- telemetry_record = Telemetry(
201
- transaction_id=transaction_id,
202
- session_id=chathistory.get("session_id"),
203
- user_question=user_question,
204
- response=response_obj.content,
205
- context="\n\n".join(retrieved_documents) if retrieved_documents else "No documents retrieved",
206
- model_name=model_name,
207
- input_tokens=input_tokens,
208
- output_tokens=output_tokens,
209
- total_tokens=total_tokens,
210
- latency=total_duration,
211
- dtcreatedon=datetime.datetime.now()
212
- )
213
- db.add(telemetry_record)
214
-
215
- new_messages = chathistory["chat_history"] + [
216
- {"role": "user", "content": user_question},
217
- {"role": "assistant", "content": response_obj.content, "telemetry_id": transaction_id}
218
- ]
219
-
220
- # --- FIX: Refactored Database Save Logic ---
221
- print(f"Saving conversation for session_id: {chathistory.get('session_id')}")
222
- conversation_entry = db.query(ConversationHistory).filter_by(session_id=chathistory.get("session_id")).first()
223
- if conversation_entry:
224
- print(f"Updating existing conversation for session_id: {chathistory.get('session_id')}")
225
- conversation_entry.messages = new_messages
226
- conversation_entry.last_updated = datetime.datetime.now()
227
- else:
228
- print(f"Creating new conversation for session_id: {chathistory.get('session_id')}")
229
- new_conversation_entry = ConversationHistory(
230
- session_id=chathistory.get("session_id"),
231
- messages=new_messages,
232
- last_updated=datetime.datetime.now()
233
- )
234
- db.add(new_conversation_entry)
235
-
236
- db.commit()
237
- print(f"Successfully saved conversation for session_id: {chathistory.get('session_id')}")
238
-
239
- except Exception as e:
240
- db.rollback()
241
- print(f"***CRITICAL ERROR***: Failed to save data to database. Error: {e}")
242
- finally:
243
- db.close()
244
-
245
- return {
246
- "chat_history": new_messages,
247
- "telemetry_id": transaction_id,
248
- "agent_name": agent_name
249
- }
250
-
251
- agent_name = ""
252
- def hr_agent(state:GraphState):
253
- """Answer the user question based on Human Resource(HR)"""
254
- global agent_name
255
- user_question = state["user_question"]
256
- retrieved_documents = state["retrieved_documents"]
257
- print("HR Agent")
258
- agent_name = "HR Agent"
259
- result = generate_response(user_question,retrieved_documents)
260
- # return {"output":result}
261
- return result
262
-
263
- def finance_agent(state:GraphState):
264
- """Answer the user question based on Finance and Bank"""
265
- global agent_name
266
- user_question = state["user_question"]
267
- retrieved_documents = state["retrieved_documents"]
268
- print("Finance Agent")
269
- agent_name = "Finance Agent"
270
- result = generate_response(user_question,retrieved_documents)
271
- return result
272
-
273
- def legals_agent(state:GraphState):
274
- """Answer the user question based on Legal Compliance"""
275
- global agent_name
276
- user_question = state["user_question"]
277
- retrieved_documents = state["retrieved_documents"]
278
- print("LC agent")
279
- agent_name = "Legal Compliance Agent"
280
- result = generate_response(user_question,retrieved_documents)
281
- # return {"output":result}
282
- return result
283
-
284
- def llm_call_router(state:GraphState):
285
- decision = router.invoke(
286
- [
287
- SystemMessage(
288
- content="Route the user_question to HR Agent, Finance Agent, Legal Compliance Agent based on the user's request"
289
- ),
290
- HumanMessage(
291
- content=state['user_question']
292
- ),
293
- ]
294
- )
295
- return {"decision":decision.step}
296
-
297
- def route_decision(state:GraphState):
298
-
299
- if state['decision'] == 'HR Agent':
300
- return "hr_agent"
301
- elif state['decision'] == 'Finance Agent':
302
- return "finance_agent"
303
- elif state['decision'] == 'Legal Compliance Agent':
304
- return "legals_agent"
305
-
306
- router_builder = StateGraph(GraphState)
307
-
308
- router_builder.add_node("retrieve", retrieve_documents)
309
- router_builder.add_node("hr_agent", hr_agent)
310
- router_builder.add_node("finance_agent", finance_agent)
311
- router_builder.add_node("legals_agent", legals_agent)
312
- router_builder.add_node("llm_call_router", llm_call_router)
313
-
314
- # router_builder.add_node("generate", generate_response)
315
- # router_builder.set_entry_point("retrieve")
316
- # router_builder.add_edge("retrieve", "generate")
317
- # router_builder.add_edge("generate", END)
318
- # compiled_app = workflow.compile(checkpointer=memory)
319
-
320
-
321
- router_builder.add_edge(START, "llm_call_router")
322
- router_builder.add_conditional_edges(
323
- "llm_call_router",
324
- route_decision,
325
- {
326
- "hr_agent":"hr_agent",
327
- "finance_agent":"finance_agent",
328
- "legals_agent":"legals_agent",
329
- },
330
- )
331
- router_builder.set_entry_point("retrieve")
332
- router_builder.add_edge("retrieve","llm_call_router")
333
- router_builder.add_edge("hr_agent",END)
334
- router_builder.add_edge("finance_agent",END)
335
- router_builder.add_edge("legals_agent",END)
336
-
337
- route_workflow = router_builder.compile()
338
-
339
- # state = route_workflow.invoke({'input': "Write a poem about a wicked cat"})
340
- # print(state['output'])
341
-
342
-
343
-
344
- vectorstore_retriever = None
345
- compiled_app = None
346
- memory = MemorySaver()
347
-
348
- # --- 4. LangGraph Nodes ---
349
- # def load_documents(state:GraphState):
350
- # global selected_domain
351
-
352
-
353
-
354
-
355
-
356
-
357
-
358
- # --- 5. API Models ---
359
- class ChatHistoryEntry(BaseModel):
360
- role: str
361
- content: str
362
- telemetry_id: Optional[str] = None
363
-
364
- class ChatRequest(BaseModel):
365
- user_question: str
366
- session_id: str
367
- chat_history: Optional[List[ChatHistoryEntry]] = Field(default_factory=list)
368
-
369
- @validator('user_question')
370
- def validate_prompt(cls, v):
371
- v = v.strip()
372
- if not v:
373
- raise ValueError('Question cannot be empty')
374
- return v
375
-
376
- class ChatResponse(BaseModel):
377
- ai_response: str
378
- updated_chat_history: List[ChatHistoryEntry]
379
- telemetry_entry_id: str
380
- is_restricted: bool = False
381
- moderation_reason: Optional[str] = None
382
-
383
- class FeedbackRequest(BaseModel):
384
- session_id: str
385
- telemetry_entry_id: str
386
- feedback_score: int
387
- feedback_text: Optional[str] = None
388
-
389
- class ConversationSummary(BaseModel):
390
- session_id: str
391
- title: str
392
-
393
-
394
- @lru_cache(maxsize=5)
395
- def process_text(file):
396
- string_data = (file.read()).decode("utf-8")
397
- return string_data
398
-
399
- @lru_cache(maxsize=5)
400
- def process_pdf(file):
401
- pdf_bytes = io.BytesIO(file.read())
402
- reader = PyPDF2.PdfReader(pdf_bytes)
403
- pdf_text = "".join([page.extract_text() + "\n" for page in reader.pages])
404
- return pdf_text
405
-
406
- @lru_cache(maxsize=5)
407
- def process_docx(file):
408
- docx_bytes = io.BytesIO(file.read())
409
- docx_docs = dx(docx_bytes)
410
- docx_content = "\n".join([para.text for para in docx_docs.paragraphs])
411
- return docx_content
412
-
413
-
414
- # @app.post("/upload-documents")
415
- # def upload_documents(files):
416
- def upload_documents():
417
- global vectorstore_retriever
418
- # saved_vectorstore_index = FAISS.load_local('domain_index', embedding_model,allow_dangerous_deserialization=True)
419
- try:
420
- saved_vectorstore_index = faiss.read_index("domain_index_sec.faiss")
421
- if saved_vectorstore_index:
422
- vectorstore_retriever = saved_vectorstore_index
423
-
424
- msg = f"Successfully loaded the knowledge base."
425
- return msg, True
426
- except Exception as e:
427
- print("unable to find index...", e)
428
- print("Creating new index.....")
429
- all_documents = []
430
- hr_loader = PyPDFLoader("D:\Pdf_data\Developments_in_HR_management_in_QAAs.pdf").load()
431
- hr_finance = PyPDFLoader("D:\Pdf_data\White Paper_QA Practice.pdf").load()
432
- hr_legal = PyPDFLoader("D:\Pdf_data\Legal-Aspects-Compliances.pdf").load()
433
-
434
- for doc in hr_loader:
435
- doc.metadata['Domain'] = 'HR'
436
- all_documents.append(doc)
437
- for doc in hr_finance:
438
- doc.metadata['Domain'] = 'Finance'
439
- all_documents.append(doc)
440
- for doc in hr_legal:
441
- doc.metadata['Domain'] = 'Legal'
442
- all_documents.append(doc)
443
- # for uploaded_file in files:
444
- # doc_loader = PyPDFLoader(uploaded_file)
445
- # all_documents.extend(doc_loader.load())
446
-
447
- if not all_documents:
448
- raise Exception(status_code=400, detail="No supported documents uploaded.")
449
-
450
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
451
- text_chunks = text_splitter.split_documents(all_documents)
452
- print("text_chucks: ", text_chunks[:100])
453
-
454
- # processed_chunks_with_ids = []
455
- # for i, chunk in enumerate(text_chunks):
456
- # # Generate a unique ID for each chunk
457
- # # Option 1 (Recommended): Using UUID for global uniqueness
458
- # # chunk_id = str(uuid.uuid4())
459
-
460
- # # Option 2 (Alternative): Combining source file path with chunk index
461
- # # This is good if you want IDs to be deterministic based on file/chunk.
462
- # # You might need to make the file path more robust (e.g., hash it or normalize it).
463
- # file_source = chunk.metadata.get('source', 'unknown_source')
464
- # chunk_id = f"{file_source.replace('.','_')}_chunk_{i}"
465
-
466
- # # Add the unique ID to the chunk's metadata
467
- # # It's good practice to keep original metadata and just add your custom ID.
468
- # chunk.metadata['doc_id'] = chunk_id
469
-
470
-
471
- # processed_chunks_with_ids.append(chunk)
472
- # embeddings = [embedding_model.encode(doc_chunks.page_content, convert_to_numpy=True) for doc_chunks in processed_chunks_with_ids]
473
-
474
- print(f"Split {len(text_chunks)} chunks.")
475
- print(f"Assigned unique 'doc_id' to each chunk in metadata.")
476
- # dimension = 768
477
- # # hnsw_m = 32
478
- # # index = faiss.IndexHNSWFlat(dimension, hnsw_m, faiss.METRIC_INNER_PRODUCT)
479
- # index = faiss.IndexFlatL2(dimension)
480
- # vector_store = FAISS(
481
- # embedding_function=embedding_model.embed_query,
482
- # index=index,
483
- # docstore= InMemoryDocstore(),
484
- # index_to_docstore_id={}
485
- # )
486
- vectorstore = FAISS.from_documents(documents=text_chunks, embedding=embedding_model)
487
- # vectorstore.add_documents(text_chunks, ids = [cid.metadata['doc_id'] for cid in text_chunks])
488
- vectorstore.add_documents(text_chunks)
489
- # vectorstore_retriever = vectorstore.as_retriever(search_kwargs={'k': 5})
490
- faiss.write_index(vectorstore.index, "domain_index_sec.faiss")
491
- # vectorstore.save_local("domain_index")
492
- vectorstore_retriever = vectorstore
493
- if vectorstore:
494
- msg = f"Successfully loaded the knowledge base."
495
- return msg, True
496
- else:
497
- msg = f"Failed to process documents."
498
- return msg, False
499
-
500
- # @app.post("/chat", response_model=ChatResponse)
501
- def chat_with_rag(chatdata):
502
- global compiled_app
503
- global vectorstore_retriever
504
- global chathistory
505
- if vectorstore_retriever is None:
506
- raise Exception(status_code=400, detail="Knowledge base not loaded. Please upload documents first.")
507
- print(f"Received request: {chatdata}")
508
- # moderation_result = moderator.moderate_content(request.user_question)
509
- # if moderation_result["is_restricted"]:
510
- # # Get appropriate response based on restriction type
511
- # response_type = moderation_result.get("response_type", "general")
512
- # response_text = Config.RESTRICTED_RESPONSES.get(
513
- # response_type,
514
- # Config.RESTRICTED_RESPONSES["general"]
515
- # )
516
-
517
- # logger.warning(
518
- # f"Restricted query: {request.prompt[:100]}... "
519
- # f"Reason: {moderation_result['reason']}"
520
- # )
521
-
522
- # return ChatResponse(
523
- # ai_response=response_text,
524
- # updated_chat_history=[],
525
- # telemetry_entry_id=request.session_id,
526
- # is_restricted=True,
527
- # moderation_reason=moderation_result["reason"],
528
- # )
529
- print(" Question passed the RAI check.........")
530
- print("Received data from UI: ", chatdata)
531
- chathistory = chatdata
532
- initial_state = {
533
- # "chat_history": [msg.model_dump() for msg in chatdata.get('chat_history')],
534
- "chat_history": [msg for msg in chatdata.get('chat_history')],
535
- "retrieved_documents": [],
536
- "user_question": chatdata.get('user_question'),
537
- "session_id": chatdata.get('session_id')
538
- }
539
-
540
- try:
541
- config = {"configurable": {"thread_id": chatdata.get('session_id')}}
542
- final_state = route_workflow.invoke(initial_state, config=config)
543
-
544
- # chathistory = final_state
545
- print("chathistory inside chat_with_rag-----------------")
546
- print("Final State--- : ", final_state)
547
-
548
- ai_response_message = final_state["chat_history"][-1]["content"]
549
- updated_chat_history_dicts = final_state["chat_history"]
550
- agent_name = final_state.get("decision","No Agent")
551
-
552
- response_chat = ChatResponse(
553
- ai_response=ai_response_message,
554
- updated_chat_history=updated_chat_history_dicts,
555
- telemetry_entry_id=final_state.get("telemetry_id"),
556
- is_restricted=False,
557
- )
558
-
559
- return agent_name,response_chat.dict()
560
- except Exception as e:
561
- print(f"Internal Server Error: {e}")
562
- raise Exception(status_code=500, detail=f"An error occurred during chat processing: {e}")
563
-
564
-
565
- def submit_feedback(feedbackdata):
566
- db = SessionLocal()
567
- try:
568
- telemetry_record = db.query(Telemetry).filter(
569
- Telemetry.transaction_id == feedbackdata['telemetry_entry_id'],
570
- Telemetry.session_id == feedbackdata['session_id']
571
- ).first()
572
-
573
- if not telemetry_record:
574
- raise Exception(status_code=404, detail="Telemetry entry not found or session ID mismatch.")
575
-
576
- existing_feedback = db.query(Feedback).filter(
577
- Feedback.telemetry_entry_id == feedbackdata['telemetry_entry_id']
578
- ).first()
579
-
580
- if existing_feedback:
581
- existing_feedback.feedback_score = feedbackdata['feedback_score']
582
- existing_feedback.feedback_text = feedbackdata['feedback_text']
583
- existing_feedback.timestamp = datetime.datetime.now()
584
- else:
585
- feedback_record = Feedback(
586
- telemetry_entry_id=feedbackdata['telemetry_entry_id'],
587
- feedback_score=feedbackdata['feedback_score'],
588
- feedback_text=feedbackdata['feedback_text'],
589
- user_query=telemetry_record.user_question,
590
- llm_response=telemetry_record.response,
591
- timestamp=datetime.datetime.now()
592
- )
593
- db.add(feedback_record)
594
-
595
- db.commit()
596
-
597
- return {"message": "Feedback submitted successfully."}
598
-
599
- except Exception as e:
600
- raise e
601
- except Exception as e:
602
- db.rollback()
603
- raise Exception(status_code=500, detail=f"An error occurred: {str(e)}")
604
- finally:
605
- db.close()
606
-
607
- # @app.get("/conversations", response_model=List[ConversationSummary])
608
- def get_conversations():
609
- db = SessionLocal()
610
- try:
611
- conversations = db.query(ConversationHistory).order_by(ConversationHistory.last_updated.desc()).all()
612
- summaries = []
613
- for conv in conversations:
614
- for msg in conv.messages:
615
- print(msg)
616
- first_user_message = next((msg for msg in conv.messages if msg["role"] == "user"), None)
617
- title = first_user_message.get("content") if first_user_message else "New Conversation"
618
- summaries.append({"session_id":conv.session_id, "title":title[:30] + "..." if len(title) > 30 else title})
619
- return summaries
620
- finally:
621
- db.close()
622
-
623
- # @app.get("/conversations/{session_id}", response_model=List[ChatHistoryEntry])
624
- def get_conversation_history(session_id: str):
625
- db = SessionLocal()
626
- try:
627
- conversation = db.query(ConversationHistory).filter(ConversationHistory.session_id == session_id).first()
628
- if not conversation:
629
- raise Exception(status_code=404, detail="Conversation not found.")
630
- return conversation.messages
631
- finally:
632
- db.close()
633
-
634
-
635
-
636
-
637
- # if 'selected_model' not in st.session_state:
638
- # st.session_state.selected_model = ""
639
- # @st.dialog("Choose a domain")
640
- # def domain_modal():
641
- # domain = st.selectbox("Select a domain",["HR","Finance","Legal"])
642
- # st.session_state.selected_model = domain
643
- # if st.button("submit"):
644
- # st.rerun()
645
-
646
- # domain_modal()
647
- # print("Selected Domain: ",st.session_state['selected_model'])
648
-
649
- # llm = initialize_llm()
650
-
 
1
+ from langgraph.graph import StateGraph, START, END
2
+ # from llm_initializer import initialize_llm, generate_prompt_phi4
3
+ from langgraph.graph import MessagesState
4
+ from langchain_core.messages import ToolMessage, HumanMessage, SystemMessage
5
+ from typing_extensions import Literal, TypedDict
6
+ from pydantic import BaseModel, Field
7
+ from pydantic import BaseModel, Field, validator
8
+ from typing import List, Optional, Dict, Any, TypedDict,Generic, TypeVar
9
+ import uuid
10
+ import io
11
+ import os
12
+ import PyPDF2
13
+ import re
14
+ import logging
15
+ import time
16
+ from docx import Document as dx
17
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
18
+ from langchain_community.document_loaders import (
19
+ DirectoryLoader,
20
+ PyPDFLoader,
21
+ TextLoader
22
+ )
23
+ import tempfile
24
+ import faiss
25
+ from langchain_community.vectorstores import FAISS
26
+ from langchain_core.prompts import PromptTemplate
27
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
28
+ from langchain_huggingface import HuggingFaceEmbeddings
29
+ from langgraph.checkpoint.memory import MemorySaver
30
+ from langgraph.graph import StateGraph, END
31
+ from sqlalchemy import create_engine, Column, String, Integer, DateTime, ForeignKey, Text
32
+ from sqlalchemy.dialects.sqlite import JSON as SQLiteJSON
33
+ # from sqlalchemy.ext.declarative import declarative_base
34
+ from sqlalchemy.orm import sessionmaker, relationship
35
+ from sentence_transformers import SentenceTransformer
36
+ from huggingface_hub import login
37
+ from langchain_google_genai import ChatGoogleGenerativeAI
38
+ import datetime
39
+ from enum import Enum as PyEnum
40
+ from sqlalchemy.orm import DeclarativeBase
41
+ # from config import Config
42
+ from functools import lru_cache
43
+ from dotenv import load_dotenv
44
+
45
+ load_dotenv()
46
+ hf_token = os.getenv("hf_user_token")
47
+ login(hf_token)
48
+
49
+ T = TypeVar("T")
50
+ # --- 1. Database Setup ---
51
+ DATABASE_URL = "sqlite:///Db_domain_agent.db"
52
+ engine = create_engine(DATABASE_URL)
53
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
54
+
55
+ class Base(DeclarativeBase):
56
+ pass
57
+
58
+ class FeedbackScore(PyEnum):
59
+ POSITIVE = 1
60
+ NEGATIVE = -1
61
+
62
+ class Telemetry(Base):
63
+ __tablename__ = "telemetry_table"
64
+ transaction_id = Column(String, primary_key=True)
65
+ session_id = Column(String)
66
+ user_question = Column(Text)
67
+ response = Column(Text)
68
+ context = Column(Text)
69
+ model_name = Column(String)
70
+ input_tokens = Column(Integer)
71
+ output_tokens = Column(Integer)
72
+ total_tokens = Column(Integer)
73
+ latency = Column(Integer)
74
+ dtcreatedon = Column(DateTime)
75
+
76
+ feedback = relationship("Feedback", back_populates="telemetry_entry", uselist=False)
77
+
78
+ class Feedback(Base):
79
+ __tablename__ = "feedback_table"
80
+ id = Column(Integer, primary_key=True, autoincrement=True)
81
+ telemetry_entry_id = Column(String, ForeignKey("telemetry_table.transaction_id"), nullable=False, unique=True)
82
+ feedback_score = Column(Integer, nullable=False)
83
+ feedback_text = Column(Text, nullable=True)
84
+ user_query = Column(Text, nullable=False)
85
+ llm_response = Column(Text, nullable=False)
86
+ timestamp = Column(DateTime, default=datetime.datetime.now)
87
+
88
+ telemetry_entry = relationship("Telemetry", back_populates="feedback")
89
+
90
+ class ConversationHistory(Base):
91
+ __tablename__ = "conversation_history"
92
+ session_id = Column(String, primary_key=True)
93
+ messages = Column(SQLiteJSON, nullable=False)
94
+ last_updated = Column(DateTime, default=datetime.datetime.now)
95
+
96
+ Base.metadata.create_all(bind=engine)
97
+ # --- 2. Initialize LLM and Embeddings ---
98
+ gak = os.getenv("Gapi_key")
99
+ llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite",google_api_key=gak)
100
+ # embedding_model = SentenceTransformer("ibm-granite/granite-embedding-english-r2")
101
+
102
+ # my_model_name = "gemma3:1b-it-qat"
103
+ # llm = ChatOllama(model=my_model_name)
104
+ embedding_model = HuggingFaceEmbeddings(
105
+ model_name="ibm-granite/granite-embedding-english-r2",
106
+ model_kwargs={'device': 'cpu'},
107
+ encode_kwargs={'normalize_embeddings': False}
108
+ )
109
+
110
+ # --- 3. LangGraph State and Workflow ---
111
+ class GraphState(TypedDict):
112
+ chat_history: List[Dict[str, Any]]
113
+ retrieved_documents: List[str]
114
+ user_question: str
115
+ decision:str
116
+ session_id: str
117
+ telemetry_id: Optional[str] = None
118
+
119
+ class Route(BaseModel):
120
+ step: Literal['HR Agent','Finance Agent','Legal Compliance Agent'] = Field(
121
+ None, description="The next step in routing process"
122
+ )
123
+ router = llm.with_structured_output(Route)
124
+
125
+ # class State(TypedDict):
126
+ # input:str
127
+ # decision:str
128
+ # output:str
129
+
130
+ chathistory = {}
131
+
132
+ def retrieve_documents(state: GraphState):
133
+ # global vectorstore_retriever
134
+ # upload_documents()
135
+ saved_vectorstore_index = FAISS.load_local('domain_index', embedding_model,allow_dangerous_deserialization=True)
136
+ user_question = state["user_question"]
137
+ # meta_filter = {'Domain':'HR'}
138
+ if saved_vectorstore_index is None:
139
+ raise ValueError("Knowledge base not loaded.")
140
+ retrieved_docs = saved_vectorstore_index.as_retriever(search_type="mmr", search_kwargs={"k": 5})
141
+ top_docs = retrieved_docs.invoke(user_question)
142
+ print("Top Docs: ", top_docs)
143
+ retrieved_docs_content = [doc.page_content if doc.page_content else doc for doc in top_docs]
144
+ print("retrieved_documents List: ", retrieved_docs_content)
145
+ return {"retrieved_documents": retrieved_docs_content}
146
+
147
+ def generate_response(user_question, retrieved_documents):
148
+ print("Inside generate_response--------------")
149
+ global llm
150
+ global chathistory
151
+ global agent_name
152
+ # user_question = state["user_question"]
153
+ # retrieved_documents = state["retrieved_documents"]
154
+
155
+ formatted_chat_history = []
156
+ for msg in chathistory["chat_history"]:
157
+ if msg['role'] == 'user':
158
+ formatted_chat_history.append(HumanMessage(content=msg['content']))
159
+ elif msg['role'] == 'assistant':
160
+ formatted_chat_history.append(AIMessage(content=msg['content']))
161
+
162
+ if not retrieved_documents:
163
+ response_content = "I couldn't find any relevant information in the uploaded documents for your question. Can you please rephrase or provide more context?"
164
+ response_obj = AIMessage(content=response_content)
165
+ else:
166
+ context = "\n\n".join(retrieved_documents)
167
+ template = """
168
+ You are a helpful AI assistant. Answer the user's question based on the provided context {context} and the conversation history {chat_history}.
169
+ If the answer is not in the context, state that you don't have enough information.
170
+ Do not make up answers. Only use the given context and chat_history.
171
+ Remove unwanted words like 'Response:' or 'Answer:' from answers.
172
+ \n\nHere is the Question:\n{user_question}
173
+ """
174
+ rag_prompt = PromptTemplate(
175
+ input_variables=["context", "chat_history", "user_question"],
176
+ template=template
177
+ )
178
+ rag_chain = rag_prompt | llm
179
+ time.sleep(3)
180
+ response_obj = rag_chain.invoke({
181
+ "context": [SystemMessage(content=context)],
182
+ "chat_history": formatted_chat_history,
183
+ "user_question": [HumanMessage(content=user_question)]
184
+ })
185
+
186
+ telemetry_data = response_obj.model_dump()
187
+ input_tokens = telemetry_data.get('usage_metadata', {}).get('input_tokens', 0)
188
+ output_tokens = telemetry_data.get('usage_metadata', {}).get('output_tokens', 0)
189
+ total_tokens = telemetry_data.get('usage_metadata', {}).get('total_tokens', 0)
190
+ model_name = telemetry_data.get('response_metadata', {}).get('model', 'unknown')
191
+ total_duration = telemetry_data.get('response_metadata', {}).get('total_duration', 0)
192
+
193
+ db = SessionLocal()
194
+ transaction_id = str(uuid.uuid4())
195
+ try:
196
+ telemetry_record = Telemetry(
197
+ transaction_id=transaction_id,
198
+ session_id=chathistory.get("session_id"),
199
+ user_question=user_question,
200
+ response=response_obj.content,
201
+ context="\n\n".join(retrieved_documents) if retrieved_documents else "No documents retrieved",
202
+ model_name=model_name,
203
+ input_tokens=input_tokens,
204
+ output_tokens=output_tokens,
205
+ total_tokens=total_tokens,
206
+ latency=total_duration,
207
+ dtcreatedon=datetime.datetime.now()
208
+ )
209
+ db.add(telemetry_record)
210
+
211
+ new_messages = chathistory["chat_history"] + [
212
+ {"role": "user", "content": user_question},
213
+ {"role": "assistant", "content": response_obj.content, "telemetry_id": transaction_id}
214
+ ]
215
+
216
+ # --- FIX: Refactored Database Save Logic ---
217
+ print(f"Saving conversation for session_id: {chathistory.get('session_id')}")
218
+ conversation_entry = db.query(ConversationHistory).filter_by(session_id=chathistory.get("session_id")).first()
219
+ if conversation_entry:
220
+ print(f"Updating existing conversation for session_id: {chathistory.get('session_id')}")
221
+ conversation_entry.messages = new_messages
222
+ conversation_entry.last_updated = datetime.datetime.now()
223
+ else:
224
+ print(f"Creating new conversation for session_id: {chathistory.get('session_id')}")
225
+ new_conversation_entry = ConversationHistory(
226
+ session_id=chathistory.get("session_id"),
227
+ messages=new_messages,
228
+ last_updated=datetime.datetime.now()
229
+ )
230
+ db.add(new_conversation_entry)
231
+
232
+ db.commit()
233
+ print(f"Successfully saved conversation for session_id: {chathistory.get('session_id')}")
234
+
235
+ except Exception as e:
236
+ db.rollback()
237
+ print(f"***CRITICAL ERROR***: Failed to save data to database. Error: {e}")
238
+ finally:
239
+ db.close()
240
+
241
+ return {
242
+ "chat_history": new_messages,
243
+ "telemetry_id": transaction_id,
244
+ "agent_name": agent_name
245
+ }
246
+
247
+ agent_name = ""
248
+ def hr_agent(state:GraphState):
249
+ """Answer the user question based on Human Resource(HR)"""
250
+ global agent_name
251
+ user_question = state["user_question"]
252
+ retrieved_documents = state["retrieved_documents"]
253
+ print("HR Agent")
254
+ agent_name = "HR Agent"
255
+ result = generate_response(user_question,retrieved_documents)
256
+ # return {"output":result}
257
+ return result
258
+
259
+ def finance_agent(state:GraphState):
260
+ """Answer the user question based on Finance and Bank"""
261
+ global agent_name
262
+ user_question = state["user_question"]
263
+ retrieved_documents = state["retrieved_documents"]
264
+ print("Finance Agent")
265
+ agent_name = "Finance Agent"
266
+ result = generate_response(user_question,retrieved_documents)
267
+ return result
268
+
269
+ def legals_agent(state:GraphState):
270
+ """Answer the user question based on Legal Compliance"""
271
+ global agent_name
272
+ user_question = state["user_question"]
273
+ retrieved_documents = state["retrieved_documents"]
274
+ print("LC agent")
275
+ agent_name = "Legal Compliance Agent"
276
+ result = generate_response(user_question,retrieved_documents)
277
+ # return {"output":result}
278
+ return result
279
+
280
+ def llm_call_router(state:GraphState):
281
+ decision = router.invoke(
282
+ [
283
+ SystemMessage(
284
+ content="Route the user_question to HR Agent, Finance Agent, Legal Compliance Agent based on the user's request"
285
+ ),
286
+ HumanMessage(
287
+ content=state['user_question']
288
+ ),
289
+ ]
290
+ )
291
+ return {"decision":decision.step}
292
+
293
+ def route_decision(state:GraphState):
294
+
295
+ if state['decision'] == 'HR Agent':
296
+ return "hr_agent"
297
+ elif state['decision'] == 'Finance Agent':
298
+ return "finance_agent"
299
+ elif state['decision'] == 'Legal Compliance Agent':
300
+ return "legals_agent"
301
+
302
+ router_builder = StateGraph(GraphState)
303
+
304
+ router_builder.add_node("retrieve", retrieve_documents)
305
+ router_builder.add_node("hr_agent", hr_agent)
306
+ router_builder.add_node("finance_agent", finance_agent)
307
+ router_builder.add_node("legals_agent", legals_agent)
308
+ router_builder.add_node("llm_call_router", llm_call_router)
309
+
310
+ # router_builder.add_node("generate", generate_response)
311
+ # router_builder.set_entry_point("retrieve")
312
+ # router_builder.add_edge("retrieve", "generate")
313
+ # router_builder.add_edge("generate", END)
314
+ # compiled_app = workflow.compile(checkpointer=memory)
315
+
316
+
317
+ router_builder.add_edge(START, "llm_call_router")
318
+ router_builder.add_conditional_edges(
319
+ "llm_call_router",
320
+ route_decision,
321
+ {
322
+ "hr_agent":"hr_agent",
323
+ "finance_agent":"finance_agent",
324
+ "legals_agent":"legals_agent",
325
+ },
326
+ )
327
+ router_builder.set_entry_point("retrieve")
328
+ router_builder.add_edge("retrieve","llm_call_router")
329
+ router_builder.add_edge("hr_agent",END)
330
+ router_builder.add_edge("finance_agent",END)
331
+ router_builder.add_edge("legals_agent",END)
332
+
333
+ route_workflow = router_builder.compile()
334
+
335
+ # state = route_workflow.invoke({'input': "Write a poem about a wicked cat"})
336
+ # print(state['output'])
337
+
338
+
339
+
340
+ vectorstore_retriever = None
341
+ compiled_app = None
342
+ memory = MemorySaver()
343
+
344
+ # --- 4. LangGraph Nodes ---
345
+ # def load_documents(state:GraphState):
346
+ # global selected_domain
347
+
348
+
349
+
350
+
351
+
352
+
353
+
354
+ # --- 5. API Models ---
355
+ class ChatHistoryEntry(BaseModel):
356
+ role: str
357
+ content: str
358
+ telemetry_id: Optional[str] = None
359
+
360
+ class ChatRequest(BaseModel):
361
+ user_question: str
362
+ session_id: str
363
+ chat_history: Optional[List[ChatHistoryEntry]] = Field(default_factory=list)
364
+
365
+ @validator('user_question')
366
+ def validate_prompt(cls, v):
367
+ v = v.strip()
368
+ if not v:
369
+ raise ValueError('Question cannot be empty')
370
+ return v
371
+
372
+ class ChatResponse(BaseModel):
373
+ ai_response: str
374
+ updated_chat_history: List[ChatHistoryEntry]
375
+ telemetry_entry_id: str
376
+ is_restricted: bool = False
377
+ moderation_reason: Optional[str] = None
378
+
379
+ class FeedbackRequest(BaseModel):
380
+ session_id: str
381
+ telemetry_entry_id: str
382
+ feedback_score: int
383
+ feedback_text: Optional[str] = None
384
+
385
+ class ConversationSummary(BaseModel):
386
+ session_id: str
387
+ title: str
388
+
389
+
390
+ @lru_cache(maxsize=5)
391
+ def process_text(file):
392
+ string_data = (file.read()).decode("utf-8")
393
+ return string_data
394
+
395
+ @lru_cache(maxsize=5)
396
+ def process_pdf(file):
397
+ pdf_bytes = io.BytesIO(file.read())
398
+ reader = PyPDF2.PdfReader(pdf_bytes)
399
+ pdf_text = "".join([page.extract_text() + "\n" for page in reader.pages])
400
+ return pdf_text
401
+
402
+ @lru_cache(maxsize=5)
403
+ def process_docx(file):
404
+ docx_bytes = io.BytesIO(file.read())
405
+ docx_docs = dx(docx_bytes)
406
+ docx_content = "\n".join([para.text for para in docx_docs.paragraphs])
407
+ return docx_content
408
+
409
+
410
+ # @app.post("/upload-documents")
411
+ # def upload_documents(files):
412
+ def upload_documents():
413
+ global vectorstore_retriever
414
+ # saved_vectorstore_index = FAISS.load_local('domain_index', embedding_model,allow_dangerous_deserialization=True)
415
+ try:
416
+ saved_vectorstore_index = faiss.read_index("domain_index_sec.faiss")
417
+ if saved_vectorstore_index:
418
+ vectorstore_retriever = saved_vectorstore_index
419
+
420
+ msg = f"Successfully loaded the knowledge base."
421
+ return msg, True
422
+ except Exception as e:
423
+ print("unable to find index...", e)
424
+ print("Creating new index.....")
425
+ all_documents = []
426
+ hr_loader = PyPDFLoader("D:\Pdf_data\Developments_in_HR_management_in_QAAs.pdf").load()
427
+ hr_finance = PyPDFLoader("D:\Pdf_data\White Paper_QA Practice.pdf").load()
428
+ hr_legal = PyPDFLoader("D:\Pdf_data\Legal-Aspects-Compliances.pdf").load()
429
+
430
+ for doc in hr_loader:
431
+ doc.metadata['Domain'] = 'HR'
432
+ all_documents.append(doc)
433
+ for doc in hr_finance:
434
+ doc.metadata['Domain'] = 'Finance'
435
+ all_documents.append(doc)
436
+ for doc in hr_legal:
437
+ doc.metadata['Domain'] = 'Legal'
438
+ all_documents.append(doc)
439
+ # for uploaded_file in files:
440
+ # doc_loader = PyPDFLoader(uploaded_file)
441
+ # all_documents.extend(doc_loader.load())
442
+
443
+ if not all_documents:
444
+ raise Exception(status_code=400, detail="No supported documents uploaded.")
445
+
446
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
447
+ text_chunks = text_splitter.split_documents(all_documents)
448
+ print("text_chucks: ", text_chunks[:100])
449
+
450
+ # processed_chunks_with_ids = []
451
+ # for i, chunk in enumerate(text_chunks):
452
+ # # Generate a unique ID for each chunk
453
+ # # Option 1 (Recommended): Using UUID for global uniqueness
454
+ # # chunk_id = str(uuid.uuid4())
455
+
456
+ # # Option 2 (Alternative): Combining source file path with chunk index
457
+ # # This is good if you want IDs to be deterministic based on file/chunk.
458
+ # # You might need to make the file path more robust (e.g., hash it or normalize it).
459
+ # file_source = chunk.metadata.get('source', 'unknown_source')
460
+ # chunk_id = f"{file_source.replace('.','_')}_chunk_{i}"
461
+
462
+ # # Add the unique ID to the chunk's metadata
463
+ # # It's good practice to keep original metadata and just add your custom ID.
464
+ # chunk.metadata['doc_id'] = chunk_id
465
+
466
+
467
+ # processed_chunks_with_ids.append(chunk)
468
+ # embeddings = [embedding_model.encode(doc_chunks.page_content, convert_to_numpy=True) for doc_chunks in processed_chunks_with_ids]
469
+
470
+ print(f"Split {len(text_chunks)} chunks.")
471
+ print(f"Assigned unique 'doc_id' to each chunk in metadata.")
472
+ # dimension = 768
473
+ # # hnsw_m = 32
474
+ # # index = faiss.IndexHNSWFlat(dimension, hnsw_m, faiss.METRIC_INNER_PRODUCT)
475
+ # index = faiss.IndexFlatL2(dimension)
476
+ # vector_store = FAISS(
477
+ # embedding_function=embedding_model.embed_query,
478
+ # index=index,
479
+ # docstore= InMemoryDocstore(),
480
+ # index_to_docstore_id={}
481
+ # )
482
+ vectorstore = FAISS.from_documents(documents=text_chunks, embedding=embedding_model)
483
+ # vectorstore.add_documents(text_chunks, ids = [cid.metadata['doc_id'] for cid in text_chunks])
484
+ vectorstore.add_documents(text_chunks)
485
+ # vectorstore_retriever = vectorstore.as_retriever(search_kwargs={'k': 5})
486
+ faiss.write_index(vectorstore.index, "domain_index_sec.faiss")
487
+ # vectorstore.save_local("domain_index")
488
+ vectorstore_retriever = vectorstore
489
+ if vectorstore:
490
+ msg = f"Successfully loaded the knowledge base."
491
+ return msg, True
492
+ else:
493
+ msg = f"Failed to process documents."
494
+ return msg, False
495
+
496
+ # @app.post("/chat", response_model=ChatResponse)
497
+ def chat_with_rag(chatdata):
498
+ global compiled_app
499
+ global vectorstore_retriever
500
+ global chathistory
501
+ if vectorstore_retriever is None:
502
+ raise Exception(status_code=400, detail="Knowledge base not loaded. Please upload documents first.")
503
+ print(f"Received request: {chatdata}")
504
+ # moderation_result = moderator.moderate_content(request.user_question)
505
+ # if moderation_result["is_restricted"]:
506
+ # # Get appropriate response based on restriction type
507
+ # response_type = moderation_result.get("response_type", "general")
508
+ # response_text = Config.RESTRICTED_RESPONSES.get(
509
+ # response_type,
510
+ # Config.RESTRICTED_RESPONSES["general"]
511
+ # )
512
+
513
+ # logger.warning(
514
+ # f"Restricted query: {request.prompt[:100]}... "
515
+ # f"Reason: {moderation_result['reason']}"
516
+ # )
517
+
518
+ # return ChatResponse(
519
+ # ai_response=response_text,
520
+ # updated_chat_history=[],
521
+ # telemetry_entry_id=request.session_id,
522
+ # is_restricted=True,
523
+ # moderation_reason=moderation_result["reason"],
524
+ # )
525
+ print("✅ Question passed the RAI check.........")
526
+ print("Received data from UI: ", chatdata)
527
+ chathistory = chatdata
528
+ initial_state = {
529
+ # "chat_history": [msg.model_dump() for msg in chatdata.get('chat_history')],
530
+ "chat_history": [msg for msg in chatdata.get('chat_history')],
531
+ "retrieved_documents": [],
532
+ "user_question": chatdata.get('user_question'),
533
+ "session_id": chatdata.get('session_id')
534
+ }
535
+
536
+ try:
537
+ config = {"configurable": {"thread_id": chatdata.get('session_id')}}
538
+ final_state = route_workflow.invoke(initial_state, config=config)
539
+
540
+ # chathistory = final_state
541
+ print("chathistory inside chat_with_rag-----------------")
542
+ print("Final State--- : ", final_state)
543
+
544
+ ai_response_message = final_state["chat_history"][-1]["content"]
545
+ updated_chat_history_dicts = final_state["chat_history"]
546
+ agent_name = final_state.get("decision","No Agent")
547
+
548
+ response_chat = ChatResponse(
549
+ ai_response=ai_response_message,
550
+ updated_chat_history=updated_chat_history_dicts,
551
+ telemetry_entry_id=final_state.get("telemetry_id"),
552
+ is_restricted=False,
553
+ )
554
+
555
+ return agent_name,response_chat.dict()
556
+ except Exception as e:
557
+ print(f"Internal Server Error: {e}")
558
+ raise Exception(status_code=500, detail=f"An error occurred during chat processing: {e}")
559
+
560
+
561
+ def submit_feedback(feedbackdata):
562
+ db = SessionLocal()
563
+ try:
564
+ telemetry_record = db.query(Telemetry).filter(
565
+ Telemetry.transaction_id == feedbackdata['telemetry_entry_id'],
566
+ Telemetry.session_id == feedbackdata['session_id']
567
+ ).first()
568
+
569
+ if not telemetry_record:
570
+ raise Exception(status_code=404, detail="Telemetry entry not found or session ID mismatch.")
571
+
572
+ existing_feedback = db.query(Feedback).filter(
573
+ Feedback.telemetry_entry_id == feedbackdata['telemetry_entry_id']
574
+ ).first()
575
+
576
+ if existing_feedback:
577
+ existing_feedback.feedback_score = feedbackdata['feedback_score']
578
+ existing_feedback.feedback_text = feedbackdata['feedback_text']
579
+ existing_feedback.timestamp = datetime.datetime.now()
580
+ else:
581
+ feedback_record = Feedback(
582
+ telemetry_entry_id=feedbackdata['telemetry_entry_id'],
583
+ feedback_score=feedbackdata['feedback_score'],
584
+ feedback_text=feedbackdata['feedback_text'],
585
+ user_query=telemetry_record.user_question,
586
+ llm_response=telemetry_record.response,
587
+ timestamp=datetime.datetime.now()
588
+ )
589
+ db.add(feedback_record)
590
+
591
+ db.commit()
592
+
593
+ return {"message": "Feedback submitted successfully."}
594
+
595
+ except Exception as e:
596
+ raise e
597
+ except Exception as e:
598
+ db.rollback()
599
+ raise Exception(status_code=500, detail=f"An error occurred: {str(e)}")
600
+ finally:
601
+ db.close()
602
+
603
+ # @app.get("/conversations", response_model=List[ConversationSummary])
604
+ def get_conversations():
605
+ db = SessionLocal()
606
+ try:
607
+ conversations = db.query(ConversationHistory).order_by(ConversationHistory.last_updated.desc()).all()
608
+ summaries = []
609
+ for conv in conversations:
610
+ for msg in conv.messages:
611
+ print(msg)
612
+ first_user_message = next((msg for msg in conv.messages if msg["role"] == "user"), None)
613
+ title = first_user_message.get("content") if first_user_message else "New Conversation"
614
+ summaries.append({"session_id":conv.session_id, "title":title[:30] + "..." if len(title) > 30 else title})
615
+ return summaries
616
+ finally:
617
+ db.close()
618
+
619
+ # @app.get("/conversations/{session_id}", response_model=List[ChatHistoryEntry])
620
+ def get_conversation_history(session_id: str):
621
+ db = SessionLocal()
622
+ try:
623
+ conversation = db.query(ConversationHistory).filter(ConversationHistory.session_id == session_id).first()
624
+ if not conversation:
625
+ raise Exception(status_code=404, detail="Conversation not found.")
626
+ return conversation.messages
627
+ finally:
628
+ db.close()
629
+
630
+
631
+
632
+
633
+ # if 'selected_model' not in st.session_state:
634
+ # st.session_state.selected_model = ""
635
+ # @st.dialog("Choose a domain")
636
+ # def domain_modal():
637
+ # domain = st.selectbox("Select a domain",["HR","Finance","Legal"])
638
+ # st.session_state.selected_model = domain
639
+ # if st.button("submit"):
640
+ # st.rerun()
641
+
642
+ # domain_modal()
643
+ # print("Selected Domain: ",st.session_state['selected_model'])
644
+
645
+ # llm = initialize_llm()
646
+