SRA25 committed on
Commit 926b19a · verified · 1 Parent(s): 5c20905

Upload 5 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+domain_index_sec.faiss filter=lfs diff=lfs merge=lfs -text
Db_domain_agent.db ADDED
Binary file (90.1 kB).
 
domain_index_sec.faiss ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5e2bda09d5e81a04ddefa8442ec2ed3e664aab2a570290298ca1867dde66b80
+size 528429
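The three + lines above are a Git LFS pointer, not the index itself; the 528,429-byte binary is fetched when LFS materializes the file. A minimal sketch, assuming git lfs pull has already replaced the pointer with the real payload, of opening it with the same faiss.read_index call that upload_documents() in mydomain_agent.py uses:

# Minimal sketch (not part of the commit): open the LFS-backed FAISS index.
# Assumes `git lfs pull` has replaced the pointer file with the binary payload.
import faiss

index = faiss.read_index("domain_index_sec.faiss")
print(index.ntotal, "vectors of dimension", index.d)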
mydomain_agent.py ADDED
@@ -0,0 +1,650 @@
+from langchain_core.tools import tool
+from langgraph.graph import StateGraph, START, END, MessagesState
+# from llm_initializer import initialize_llm, generate_prompt_phi4
+from langchain_core.messages import ToolMessage, HumanMessage, AIMessage, SystemMessage
+from typing_extensions import Literal, TypedDict
+from IPython.display import Image, display
+from pydantic import BaseModel, Field, validator
+from typing import List, Optional, Dict, Any, Generic, TypeVar
+from abc import ABC
+import uuid
+import io
+import os
+import PyPDF2
+import re
+import logging
+import time
+from docx import Document as dx
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import (
+    DirectoryLoader,
+    PyPDFLoader,
+    TextLoader
+)
+import tempfile
+import faiss
+from langchain_community.docstore.in_memory import InMemoryDocstore
+from langchain_community.vectorstores import FAISS
+from langchain_core.prompts import PromptTemplate
+from langchain_huggingface import HuggingFaceEmbeddings
+from langgraph.checkpoint.memory import MemorySaver
+from sqlalchemy import create_engine, Column, String, Integer, DateTime, ForeignKey, Text
+from sqlalchemy.dialects.sqlite import JSON as SQLiteJSON
+# from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker, relationship, DeclarativeBase
+from sentence_transformers import SentenceTransformer
+from huggingface_hub import login
+from langchain_google_genai import ChatGoogleGenerativeAI
+import datetime
+from enum import Enum as PyEnum
+# from config import Config
+from functools import lru_cache
+from dotenv import load_dotenv
+
+load_dotenv()
+hf_token = os.getenv("hf_user_token")
+login(hf_token)
+
+T = TypeVar("T")
+# --- 1. Database Setup ---
+DATABASE_URL = "sqlite:///Db_domain_agent.db"
+engine = create_engine(DATABASE_URL)
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+class Base(DeclarativeBase):
+    pass
+
+class FeedbackScore(PyEnum):
+    POSITIVE = 1
+    NEGATIVE = -1
+
+class Telemetry(Base):
+    __tablename__ = "telemetry_table"
+    transaction_id = Column(String, primary_key=True)
+    session_id = Column(String)
+    user_question = Column(Text)
+    response = Column(Text)
+    context = Column(Text)
+    model_name = Column(String)
+    input_tokens = Column(Integer)
+    output_tokens = Column(Integer)
+    total_tokens = Column(Integer)
+    latency = Column(Integer)
+    dtcreatedon = Column(DateTime)
+
+    feedback = relationship("Feedback", back_populates="telemetry_entry", uselist=False)
+
+class Feedback(Base):
+    __tablename__ = "feedback_table"
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    telemetry_entry_id = Column(String, ForeignKey("telemetry_table.transaction_id"), nullable=False, unique=True)
+    feedback_score = Column(Integer, nullable=False)
+    feedback_text = Column(Text, nullable=True)
+    user_query = Column(Text, nullable=False)
+    llm_response = Column(Text, nullable=False)
+    timestamp = Column(DateTime, default=datetime.datetime.now)
+
+    telemetry_entry = relationship("Telemetry", back_populates="feedback")
+
+class ConversationHistory(Base):
+    __tablename__ = "conversation_history"
+    session_id = Column(String, primary_key=True)
+    messages = Column(SQLiteJSON, nullable=False)
+    last_updated = Column(DateTime, default=datetime.datetime.now)
+
+Base.metadata.create_all(bind=engine)
+# --- 2. Initialize LLM and Embeddings ---
+gak = os.getenv("Gapi_key")
+llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite", google_api_key=gak)
+# embedding_model = SentenceTransformer("ibm-granite/granite-embedding-english-r2")
+
+# my_model_name = "gemma3:1b-it-qat"
+# llm = ChatOllama(model=my_model_name)
+embedding_model = HuggingFaceEmbeddings(
+    model_name="ibm-granite/granite-embedding-english-r2",
+    model_kwargs={'device': 'cpu'},
+    encode_kwargs={'normalize_embeddings': False}
+)
+
+# --- 3. LangGraph State and Workflow ---
+class GraphState(TypedDict):
+    chat_history: List[Dict[str, Any]]
+    retrieved_documents: List[str]
+    user_question: str
+    decision: str
+    session_id: str
+    telemetry_id: Optional[str]  # TypedDict fields cannot carry default values
+
+class Route(BaseModel):
+    step: Literal['HR Agent', 'Finance Agent', 'Legal Compliance Agent'] = Field(
+        None, description="The next step in the routing process"
+    )
+
+router = llm.with_structured_output(Route)
+
+# class State(TypedDict):
+#     input: str
+#     decision: str
+#     output: str
+
+chathistory = {}
+
+def retrieve_documents(state: GraphState):
+    # global vectorstore_retriever
+    # upload_documents()
+    # Loads the LangChain-format index folder 'domain_index'; note that
+    # upload_documents() below persists the raw FAISS index to
+    # 'domain_index_sec.faiss' instead.
+    saved_vectorstore_index = FAISS.load_local('domain_index', embedding_model, allow_dangerous_deserialization=True)
+    user_question = state["user_question"]
+    # meta_filter = {'Domain':'HR'}
+    if saved_vectorstore_index is None:
+        raise ValueError("Knowledge base not loaded.")
+    retriever = saved_vectorstore_index.as_retriever(search_type="mmr", search_kwargs={"k": 5})
+    top_docs = retriever.invoke(user_question)
+    print("Top Docs: ", top_docs)
+    retrieved_docs_content = [doc.page_content if doc.page_content else doc for doc in top_docs]
+    print("retrieved_documents List: ", retrieved_docs_content)
+    return {"retrieved_documents": retrieved_docs_content}
+
+def generate_response(user_question, retrieved_documents):
+    print("Inside generate_response--------------")
+    global llm
+    global chathistory
+    global agent_name
+    # user_question = state["user_question"]
+    # retrieved_documents = state["retrieved_documents"]
+
+    formatted_chat_history = []
+    for msg in chathistory["chat_history"]:
+        if msg['role'] == 'user':
+            formatted_chat_history.append(HumanMessage(content=msg['content']))
+        elif msg['role'] == 'assistant':
+            formatted_chat_history.append(AIMessage(content=msg['content']))
+
+    if not retrieved_documents:
+        response_content = "I couldn't find any relevant information in the uploaded documents for your question. Can you please rephrase or provide more context?"
+        response_obj = AIMessage(content=response_content)
+    else:
+        context = "\n\n".join(retrieved_documents)
+        template = """
+        You are a helpful AI assistant. Answer the user's question based on the provided context {context} and the conversation history {chat_history}.
+        If the answer is not in the context, state that you don't have enough information.
+        Do not make up answers. Only use the given context and chat_history.
+        Remove unwanted words like 'Response:' or 'Answer:' from answers.
+        \n\nHere is the Question:\n{user_question}
+        """
+        rag_prompt = PromptTemplate(
+            input_variables=["context", "chat_history", "user_question"],
+            template=template
+        )
+        rag_chain = rag_prompt | llm
+        time.sleep(3)  # crude pause to stay under the API rate limit
+        response_obj = rag_chain.invoke({
+            "context": [SystemMessage(content=context)],
+            "chat_history": formatted_chat_history,
+            "user_question": [HumanMessage(content=user_question)]
+        })
+
+    telemetry_data = response_obj.model_dump()
+    input_tokens = telemetry_data.get('usage_metadata', {}).get('input_tokens', 0)
+    output_tokens = telemetry_data.get('usage_metadata', {}).get('output_tokens', 0)
+    total_tokens = telemetry_data.get('usage_metadata', {}).get('total_tokens', 0)
+    model_name = telemetry_data.get('response_metadata', {}).get('model', 'unknown')
+    total_duration = telemetry_data.get('response_metadata', {}).get('total_duration', 0)
+
+    db = SessionLocal()
+    transaction_id = str(uuid.uuid4())
+    try:
+        telemetry_record = Telemetry(
+            transaction_id=transaction_id,
+            session_id=chathistory.get("session_id"),
+            user_question=user_question,
+            response=response_obj.content,
+            context="\n\n".join(retrieved_documents) if retrieved_documents else "No documents retrieved",
+            model_name=model_name,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=total_tokens,
+            latency=total_duration,
+            dtcreatedon=datetime.datetime.now()
+        )
+        db.add(telemetry_record)
+
+        new_messages = chathistory["chat_history"] + [
+            {"role": "user", "content": user_question},
+            {"role": "assistant", "content": response_obj.content, "telemetry_id": transaction_id}
+        ]
+
+        # --- FIX: Refactored Database Save Logic ---
+        print(f"Saving conversation for session_id: {chathistory.get('session_id')}")
+        conversation_entry = db.query(ConversationHistory).filter_by(session_id=chathistory.get("session_id")).first()
+        if conversation_entry:
+            print(f"Updating existing conversation for session_id: {chathistory.get('session_id')}")
+            conversation_entry.messages = new_messages
+            conversation_entry.last_updated = datetime.datetime.now()
+        else:
+            print(f"Creating new conversation for session_id: {chathistory.get('session_id')}")
+            new_conversation_entry = ConversationHistory(
+                session_id=chathistory.get("session_id"),
+                messages=new_messages,
+                last_updated=datetime.datetime.now()
+            )
+            db.add(new_conversation_entry)
+
+        db.commit()
+        print(f"Successfully saved conversation for session_id: {chathistory.get('session_id')}")
+
+    except Exception as e:
+        db.rollback()
+        print(f"***CRITICAL ERROR***: Failed to save data to database. Error: {e}")
+    finally:
+        db.close()
+
+    return {
+        "chat_history": new_messages,
+        "telemetry_id": transaction_id,
+        "agent_name": agent_name
+    }
+
+agent_name = ""
+
+def hr_agent(state: GraphState):
+    """Answer the user question based on Human Resources (HR)."""
+    global agent_name
+    user_question = state["user_question"]
+    retrieved_documents = state["retrieved_documents"]
+    print("HR Agent")
+    agent_name = "HR Agent"
+    result = generate_response(user_question, retrieved_documents)
+    # return {"output": result}
+    return result
+
+def finance_agent(state: GraphState):
+    """Answer the user question based on Finance and Banking."""
+    global agent_name
+    user_question = state["user_question"]
+    retrieved_documents = state["retrieved_documents"]
+    print("Finance Agent")
+    agent_name = "Finance Agent"
+    result = generate_response(user_question, retrieved_documents)
+    return result
+
+def legals_agent(state: GraphState):
+    """Answer the user question based on Legal Compliance."""
+    global agent_name
+    user_question = state["user_question"]
+    retrieved_documents = state["retrieved_documents"]
+    print("LC agent")
+    agent_name = "Legal Compliance Agent"
+    result = generate_response(user_question, retrieved_documents)
+    # return {"output": result}
+    return result
+
+def llm_call_router(state: GraphState):
+    decision = router.invoke(
+        [
+            SystemMessage(
+                content="Route the user_question to HR Agent, Finance Agent, or Legal Compliance Agent based on the user's request"
+            ),
+            HumanMessage(
+                content=state['user_question']
+            ),
+        ]
+    )
+    return {"decision": decision.step}
+
+def route_decision(state: GraphState):
+    if state['decision'] == 'HR Agent':
+        return "hr_agent"
+    elif state['decision'] == 'Finance Agent':
+        return "finance_agent"
+    elif state['decision'] == 'Legal Compliance Agent':
+        return "legals_agent"
+
+router_builder = StateGraph(GraphState)
+
+router_builder.add_node("retrieve", retrieve_documents)
+router_builder.add_node("hr_agent", hr_agent)
+router_builder.add_node("finance_agent", finance_agent)
+router_builder.add_node("legals_agent", legals_agent)
+router_builder.add_node("llm_call_router", llm_call_router)
+
+# router_builder.add_node("generate", generate_response)
+# router_builder.set_entry_point("retrieve")
+# router_builder.add_edge("retrieve", "generate")
+# router_builder.add_edge("generate", END)
+# compiled_app = workflow.compile(checkpointer=memory)
+
+# "retrieve" is the single entry point; a second START edge straight to
+# "llm_call_router" would give the graph two competing entry paths.
+# router_builder.add_edge(START, "llm_call_router")
+router_builder.set_entry_point("retrieve")
+router_builder.add_edge("retrieve", "llm_call_router")
+router_builder.add_conditional_edges(
+    "llm_call_router",
+    route_decision,
+    {
+        "hr_agent": "hr_agent",
+        "finance_agent": "finance_agent",
+        "legals_agent": "legals_agent",
+    },
+)
+router_builder.add_edge("hr_agent", END)
+router_builder.add_edge("finance_agent", END)
+router_builder.add_edge("legals_agent", END)
+
+route_workflow = router_builder.compile()
+
+# state = route_workflow.invoke({'input': "Write a poem about a wicked cat"})
+# print(state['output'])
+
+vectorstore_retriever = None
+compiled_app = None
+memory = MemorySaver()
+
+# --- 4. LangGraph Nodes ---
+# def load_documents(state: GraphState):
+#     global selected_domain
+
+# --- 5. API Models ---
+class ChatHistoryEntry(BaseModel):
+    role: str
+    content: str
+    telemetry_id: Optional[str] = None
+
+class ChatRequest(BaseModel):
+    user_question: str
+    session_id: str
+    chat_history: Optional[List[ChatHistoryEntry]] = Field(default_factory=list)
+
+    @validator('user_question')
+    def validate_prompt(cls, v):
+        v = v.strip()
+        if not v:
+            raise ValueError('Question cannot be empty')
+        return v
+
+class ChatResponse(BaseModel):
+    ai_response: str
+    updated_chat_history: List[ChatHistoryEntry]
+    telemetry_entry_id: str
+    is_restricted: bool = False
+    moderation_reason: Optional[str] = None
+
+class FeedbackRequest(BaseModel):
+    session_id: str
+    telemetry_entry_id: str
+    feedback_score: int
+    feedback_text: Optional[str] = None
+
+class ConversationSummary(BaseModel):
+    session_id: str
+    title: str
+
+@lru_cache(maxsize=5)
+def process_text(file):
+    string_data = (file.read()).decode("utf-8")
+    return string_data
+
+@lru_cache(maxsize=5)
+def process_pdf(file):
+    pdf_bytes = io.BytesIO(file.read())
+    reader = PyPDF2.PdfReader(pdf_bytes)
+    pdf_text = "".join([page.extract_text() + "\n" for page in reader.pages])
+    return pdf_text
+
+@lru_cache(maxsize=5)
+def process_docx(file):
+    docx_bytes = io.BytesIO(file.read())
+    docx_docs = dx(docx_bytes)
+    docx_content = "\n".join([para.text for para in docx_docs.paragraphs])
+    return docx_content
+
+# @app.post("/upload-documents")
+# def upload_documents(files):
+def upload_documents():
+    global vectorstore_retriever
+    # saved_vectorstore_index = FAISS.load_local('domain_index', embedding_model, allow_dangerous_deserialization=True)
+    try:
+        saved_vectorstore_index = faiss.read_index("domain_index_sec.faiss")
+        if saved_vectorstore_index:
+            vectorstore_retriever = saved_vectorstore_index
+
+        msg = "Successfully loaded the knowledge base."
+        return msg, True
+    except Exception as e:
+        print("unable to find index...", e)
+        print("Creating new index.....")
+        all_documents = []
+        # Raw strings keep the Windows path backslashes from being read as escapes.
+        hr_loader = PyPDFLoader(r"D:\Pdf_data\Developments_in_HR_management_in_QAAs.pdf").load()
+        hr_finance = PyPDFLoader(r"D:\Pdf_data\White Paper_QA Practice.pdf").load()
+        hr_legal = PyPDFLoader(r"D:\Pdf_data\Legal-Aspects-Compliances.pdf").load()
+
+        for doc in hr_loader:
+            doc.metadata['Domain'] = 'HR'
+            all_documents.append(doc)
+        for doc in hr_finance:
+            doc.metadata['Domain'] = 'Finance'
+            all_documents.append(doc)
+        for doc in hr_legal:
+            doc.metadata['Domain'] = 'Legal'
+            all_documents.append(doc)
+        # for uploaded_file in files:
+        #     doc_loader = PyPDFLoader(uploaded_file)
+        #     all_documents.extend(doc_loader.load())
+
+        if not all_documents:
+            raise Exception("No supported documents uploaded.")
+
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+        text_chunks = text_splitter.split_documents(all_documents)
+        print("text_chunks: ", text_chunks[:100])
+
+        # processed_chunks_with_ids = []
+        # for i, chunk in enumerate(text_chunks):
+        #     # Generate a unique ID for each chunk
+        #     # Option 1 (Recommended): Using UUID for global uniqueness
+        #     # chunk_id = str(uuid.uuid4())
+
+        #     # Option 2 (Alternative): Combining source file path with chunk index
+        #     # This is good if you want IDs to be deterministic based on file/chunk.
+        #     # You might need to make the file path more robust (e.g., hash it or normalize it).
+        #     file_source = chunk.metadata.get('source', 'unknown_source')
+        #     chunk_id = f"{file_source.replace('.','_')}_chunk_{i}"
+
+        #     # Add the unique ID to the chunk's metadata
+        #     # It's good practice to keep original metadata and just add your custom ID.
+        #     chunk.metadata['doc_id'] = chunk_id
+
+        #     processed_chunks_with_ids.append(chunk)
+        # embeddings = [embedding_model.encode(doc_chunks.page_content, convert_to_numpy=True) for doc_chunks in processed_chunks_with_ids]
+
+        print(f"Split {len(text_chunks)} chunks.")
+        print("Assigned unique 'doc_id' to each chunk in metadata.")
+        # dimension = 768
+        # # hnsw_m = 32
+        # # index = faiss.IndexHNSWFlat(dimension, hnsw_m, faiss.METRIC_INNER_PRODUCT)
+        # index = faiss.IndexFlatL2(dimension)
+        # vector_store = FAISS(
+        #     embedding_function=embedding_model.embed_query,
+        #     index=index,
+        #     docstore=InMemoryDocstore(),
+        #     index_to_docstore_id={}
+        # )
+        vectorstore = FAISS.from_documents(documents=text_chunks, embedding=embedding_model)
+        # vectorstore.add_documents(text_chunks, ids=[cid.metadata['doc_id'] for cid in text_chunks])
+        # from_documents() already indexes text_chunks; re-adding them would duplicate every chunk:
+        # vectorstore.add_documents(text_chunks)
+        # vectorstore_retriever = vectorstore.as_retriever(search_kwargs={'k': 5})
+        faiss.write_index(vectorstore.index, "domain_index_sec.faiss")
+        # vectorstore.save_local("domain_index")
+        vectorstore_retriever = vectorstore
+        if vectorstore:
+            msg = "Successfully loaded the knowledge base."
+            return msg, True
+        else:
+            msg = "Failed to process documents."
+            return msg, False
+
+# @app.post("/chat", response_model=ChatResponse)
+def chat_with_rag(chatdata):
+    global compiled_app
+    global vectorstore_retriever
+    global chathistory
+    if vectorstore_retriever is None:
+        raise Exception("Knowledge base not loaded. Please upload documents first.")
+    print(f"Received request: {chatdata}")
+    # moderation_result = moderator.moderate_content(request.user_question)
+    # if moderation_result["is_restricted"]:
+    #     # Get appropriate response based on restriction type
+    #     response_type = moderation_result.get("response_type", "general")
+    #     response_text = Config.RESTRICTED_RESPONSES.get(
+    #         response_type,
+    #         Config.RESTRICTED_RESPONSES["general"]
+    #     )
+
+    #     logger.warning(
+    #         f"Restricted query: {request.prompt[:100]}... "
+    #         f"Reason: {moderation_result['reason']}"
+    #     )
+
+    #     return ChatResponse(
+    #         ai_response=response_text,
+    #         updated_chat_history=[],
+    #         telemetry_entry_id=request.session_id,
+    #         is_restricted=True,
+    #         moderation_reason=moderation_result["reason"],
+    #     )
+    print("✅ Question passed the RAI check.........")
+    print("Received data from UI: ", chatdata)
+    chathistory = chatdata
+    initial_state = {
+        # "chat_history": [msg.model_dump() for msg in chatdata.get('chat_history')],
+        "chat_history": [msg for msg in chatdata.get('chat_history')],
+        "retrieved_documents": [],
+        "user_question": chatdata.get('user_question'),
+        "session_id": chatdata.get('session_id')
+    }
+
+    try:
+        config = {"configurable": {"thread_id": chatdata.get('session_id')}}
+        final_state = route_workflow.invoke(initial_state, config=config)
+
+        # chathistory = final_state
+        print("chathistory inside chat_with_rag-----------------")
+        print("Final State--- : ", final_state)
+
+        ai_response_message = final_state["chat_history"][-1]["content"]
+        updated_chat_history_dicts = final_state["chat_history"]
+        agent_name = final_state.get("decision", "No Agent")
+
+        response_chat = ChatResponse(
+            ai_response=ai_response_message,
+            updated_chat_history=updated_chat_history_dicts,
+            telemetry_entry_id=final_state.get("telemetry_id"),
+            is_restricted=False,
+        )
+
+        return agent_name, response_chat.model_dump()
+    except Exception as e:
+        print(f"Internal Server Error: {e}")
+        raise Exception(f"An error occurred during chat processing: {e}")
+
+
+def submit_feedback(feedbackdata):
+    db = SessionLocal()
+    try:
+        telemetry_record = db.query(Telemetry).filter(
+            Telemetry.transaction_id == feedbackdata['telemetry_entry_id'],
+            Telemetry.session_id == feedbackdata['session_id']
+        ).first()
+
+        if not telemetry_record:
+            raise Exception("Telemetry entry not found or session ID mismatch.")
+
+        existing_feedback = db.query(Feedback).filter(
+            Feedback.telemetry_entry_id == feedbackdata['telemetry_entry_id']
+        ).first()
+
+        if existing_feedback:
+            existing_feedback.feedback_score = feedbackdata['feedback_score']
+            existing_feedback.feedback_text = feedbackdata['feedback_text']
+            existing_feedback.timestamp = datetime.datetime.now()
+        else:
+            feedback_record = Feedback(
+                telemetry_entry_id=feedbackdata['telemetry_entry_id'],
+                feedback_score=feedbackdata['feedback_score'],
+                feedback_text=feedbackdata['feedback_text'],
+                user_query=telemetry_record.user_question,
+                llm_response=telemetry_record.response,
+                timestamp=datetime.datetime.now()
+            )
+            db.add(feedback_record)
+
+        db.commit()
+
+        return {"message": "Feedback submitted successfully."}
+
+    except Exception as e:
+        db.rollback()
+        raise Exception(f"An error occurred: {str(e)}")
+    finally:
+        db.close()
+
+# @app.get("/conversations", response_model=List[ConversationSummary])
+def get_conversations():
+    db = SessionLocal()
+    try:
+        conversations = db.query(ConversationHistory).order_by(ConversationHistory.last_updated.desc()).all()
+        summaries = []
+        for conv in conversations:
+            for msg in conv.messages:
+                print(msg)
+            first_user_message = next((msg for msg in conv.messages if msg["role"] == "user"), None)
+            title = first_user_message.get("content") if first_user_message else "New Conversation"
+            summaries.append({"session_id": conv.session_id, "title": title[:30] + "..." if len(title) > 30 else title})
+        return summaries
+    finally:
+        db.close()
+
+# @app.get("/conversations/{session_id}", response_model=List[ChatHistoryEntry])
+def get_conversation_history(session_id: str):
+    db = SessionLocal()
+    try:
+        conversation = db.query(ConversationHistory).filter(ConversationHistory.session_id == session_id).first()
+        if not conversation:
+            raise Exception("Conversation not found.")
+        return conversation.messages
+    finally:
+        db.close()
+
+
+# if 'selected_model' not in st.session_state:
+#     st.session_state.selected_model = ""
+# @st.dialog("Choose a domain")
+# def domain_modal():
+#     domain = st.selectbox("Select a domain", ["HR", "Finance", "Legal"])
+#     st.session_state.selected_model = domain
+#     if st.button("submit"):
+#         st.rerun()
+
+# domain_modal()
+# print("Selected Domain: ", st.session_state['selected_model'])
+
+# llm = initialize_llm()
+
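For reference, a hedged smoke test of the module's public entry points (not part of the commit). It assumes the .env file supplies hf_user_token and Gapi_key, that domain_index_sec.faiss is present for upload_documents(), and that the LangChain-format 'domain_index' folder loaded by retrieve_documents() also exists; the question text is a placeholder.

# Hedged smoke test for mydomain_agent (assumptions noted above).
from mydomain_agent import upload_documents, chat_with_rag, submit_feedback

msg, ok = upload_documents()  # loads the FAISS index or rebuilds it from the PDFs
print(msg)
if ok:
    agent, reply = chat_with_rag({
        "user_question": "What does the HR policy say about leave carry-over?",  # placeholder
        "session_id": "smoke-test-session",
        "chat_history": [],
    })
    print(agent, "->", reply["ai_response"])
    submit_feedback({  # thumbs-up on the answer we just received
        "session_id": "smoke-test-session",
        "telemetry_entry_id": reply["telemetry_entry_id"],
        "feedback_score": 1,
        "feedback_text": None,
    })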
requirements.txt CHANGED
@@ -1,3 +1,24 @@
-altair
-pandas
-streamlit
+transformers
+torch
+streamlit
+requests
+SQLAlchemy
+sentence-transformers
+python-docx
+python-dotenv
+PyMuPDF
+pypdf
+PyPDF2
+langgraph
+langchain-unstructured
+faiss-cpu
+huggingface-hub
+langchain
+langchain-community
+langchain-core
+langchain-huggingface
+langchain-openai
+langchain-text-splitters
+langchain-google-genai
+pandas
+python-multipart
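The file pins no versions, so after pip install -r requirements.txt a quick sanity check is to import the heavyweight entries; note that the import names below differ from the PyPI names for faiss-cpu and langchain-google-genai. A small sketch:

# Optional environment check (not part of the commit).
import importlib

for mod in ["torch", "faiss", "langgraph", "langchain_google_genai", "streamlit"]:
    importlib.import_module(mod)
    print(mod, "imports OK")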
streamlitapp.py ADDED
@@ -0,0 +1,276 @@
+import streamlit as st
+import uuid
+import hashlib
+from typing import List, Optional, Dict, Any, TypedDict, Generic, TypeVar
+from huggingface_hub import login
+import logging
+import time
+import os
+from dotenv import load_dotenv
+from mydomain_agent import upload_documents, submit_feedback, get_conversations, get_conversation_history, chat_with_rag
+
+load_dotenv()
+
+# --- 2. Streamlit UI Components and State Management ---
+st.set_page_config(page_title="Agentic WorkFlow", layout="wide")
+st.title("💬 Domain-Aware AI Agent")
+st.caption("Your expert assistant across HR, Finance, and Legal Compliance.")
+
+# Initialize session state for conversations, messages, and the current session ID
+if "conversations" not in st.session_state:
+    st.session_state.conversations = []
+if "session_id" not in st.session_state:
+    st.session_state.session_id = str(uuid.uuid4())
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+if "retriever_ready" not in st.session_state:
+    st.session_state.retriever_ready = False
+if "feedback_given" not in st.session_state:
+    st.session_state.feedback_given = {}
+# New state variable to handle negative feedback comments
+if "negative_feedback_for" not in st.session_state:
+    st.session_state.negative_feedback_for = None
+
+# Initialize session state for storing uploaded file hashes
+if 'uploaded_file_hashes' not in st.session_state:
+    st.session_state.uploaded_file_hashes = set()
+if 'uploaded_files_info' not in st.session_state:
+    st.session_state.uploaded_files_info = []
+
+ def get_file_hash(file):
41
+ """Generates a unique hash for a file using its name, size, and content."""
42
+ hasher = hashlib.sha256()
43
+ # Read a small chunk of the file to ensure content-based uniqueness
44
+ # Combine with file name and size for a robust identifier
45
+ file_content = file.getvalue()
46
+ hasher.update(file.name.encode('utf-8'))
47
+ hasher.update(str(file.size).encode('utf-8'))
48
+ hasher.update(file_content[:1024]) # Use first 1KB of content
49
+ return hasher.hexdigest()
50
+ # --- 3. Helper Functions for Backend Communication ---
51
+ # def send_documents_to_backend(uploaded_files):
52
+ # try:
53
+ # for file in uploaded_files:
54
+ # process_status = upload_documents(file)
55
+ # return process_status
56
+ # except Exception as e:
57
+ # st.error(f"Error processing documents: {e}")
58
+ # return None
59
+
60
+ def send_chat_message_to_backend(prompt: str, chat_history: List[Dict[str, Any]]):
61
+ """Sends a chat message to the FastAPI backend and handles the response."""
62
+ if not prompt.strip():
63
+ return {"empty":"Invalid Question"}
64
+ history_for_api = [
65
+ {"role": msg.get("role"), "content": msg.get("content")}
66
+ for msg in chat_history
67
+ ]
68
+
69
+ payload = {
70
+ "user_question": str(prompt),
71
+ "session_id": st.session_state.session_id,
72
+ "chat_history": history_for_api,
73
+ }
74
+ print(f"Sending payload: {payload}") # Debug print
75
+ agent_name,response_dict = chat_with_rag(payload)
76
+ try:
77
+ return agent_name,response_dict
78
+ except Exception as e:
79
+ st.error(f"Error communicating with the backend")
80
+ print(f"Error communicating with the backend: {e}")
81
+ return None
82
+
83
+ def send_feedback_to_backend(telemetry_entry_id: str, feedback_score: int, feedback_text: Optional[str] = None):
84
+ """Sends feedback to the FastAPI backend."""
85
+ payload = {
86
+ "session_id": st.session_state.session_id,
87
+ "telemetry_entry_id": telemetry_entry_id,
88
+ "feedback_score": feedback_score,
89
+ "feedback_text": feedback_text
90
+ }
91
+ try:
92
+ # response = requests.post(f"{API_URL}/feedback", json=payload)
93
+ response = submit_feedback(payload)
94
+ # response.raise_for_status()
95
+ st.toast("Feedback submitted! Thank you.")
96
+ except Exception as e:
97
+ st.error(f"Error submitting feedback: {e}")
98
+
99
+ def get_conversations_from_backend() -> list:
100
+ """Fetches a list of all conversations from the backend."""
101
+ try:
102
+ # response = requests.get(f"{API_URL}/conversations")
103
+ response = get_conversations()
104
+ # response.raise_for_status()
105
+ return response
106
+ except Exception as e:
107
+ st.sidebar.error(f"Error fetching conversations: {e}")
108
+ return []
109
+
110
+ def get_conversation_history_from_backend(session_id: str):
111
+ """Fetches the messages for a specific conversation ID."""
112
+ try:
113
+ # response = requests.get(f"{API_URL}/conversations/{session_id}")
114
+
115
+ response = get_conversation_history(session_id)
116
+ return response
117
+ except Exception as e:
118
+ st.error(f"Error loading conversation history: {e}")
119
+ return None
120
+
+def handle_positive_feedback(telemetry_id):
+    """Handles positive feedback submission."""
+    send_feedback_to_backend(telemetry_id, 1)
+    st.session_state.feedback_given[telemetry_id] = True
+
+
+def handle_negative_feedback_comment_submit(telemetry_id, comment_text):
+    """Handles the negative feedback comment submission."""
+    send_feedback_to_backend(telemetry_id, -1, comment_text)
+    st.session_state.feedback_given[telemetry_id] = True
+    st.session_state.negative_feedback_for = None
+
+
+def refresh_conversations():
+    """Refreshes the conversation list in the sidebar."""
+    st.session_state.conversations = get_conversations_from_backend()
+
+# --- 4. Sidebar for Document Upload and Conversation History ---
+with st.sidebar:
+    st.header("Load Documents")
+    if st.button("Process Documents", key="process_docs_button"):
+        newmsg, status = upload_documents()
+        if status:
+            st.session_state.retriever_ready = True
+            # st.success(response_data.get("message", "Documents processed and knowledge base ready!"))
+            st.success(newmsg)
+            st.session_state.messages = []
+            refresh_conversations()  # SQL query needs to be added here
+        else:
+            st.session_state.retriever_ready = False
+            st.error(newmsg)
+    else:
+        st.warning("Please load the documents.")
+
+    st.markdown("---")
+    st.header("Conversations")
+    if st.button("➕ New Chat", key="new_chat_button", use_container_width=True, type="primary"):
+        st.session_state.session_id = str(uuid.uuid4())
+        st.session_state.messages = []
+        st.session_state.feedback_given = {}
+        st.session_state.negative_feedback_for = None
+        refresh_conversations()
+        st.rerun()
+
+    refresh_conversations()
+
+    if st.session_state.conversations:
+        for conv in st.session_state.conversations:
+            if st.button(
+                conv["title"],
+                key=f"conv_{conv['session_id']}",
+                use_container_width=True
+            ):
+                if st.session_state.session_id != conv["session_id"]:
+                    st.session_state.session_id = conv["session_id"]
+                    history = get_conversation_history_from_backend(conv["session_id"])
+                    if history:  # guard against both None and an empty list
+                        st.session_state.messages = history
+                        st.session_state.feedback_given = {msg.get("telemetry_id"): True for msg in history if msg.get("telemetry_id")}
+                    else:
+                        st.session_state.messages = []
+                    st.session_state.negative_feedback_for = None
+                    st.rerun()
+
+ # --- 5. Main Chat Interface ---
186
+ # Display chat messages from history on app rerun
187
+ for message in st.session_state.messages:
188
+ with st.chat_message(message["role"]):
189
+ st.markdown(message["content"])
190
+
191
+ # Display feedback buttons for the last AI response
192
+ if message["role"] == "assistant" and message.get("telemetry_id") and not st.session_state.feedback_given.get(message["telemetry_id"], False):
193
+ col1, col2 = st.columns(2)
194
+ with col1:
195
+ if st.button("πŸ‘", key=f"positive_{message['telemetry_id']}", on_click=handle_positive_feedback, args=(message['telemetry_id'],)):
196
+ pass
197
+ with col2:
198
+ if st.button("πŸ‘Ž", key=f"negative_{message['telemetry_id']}"):
199
+ st.session_state.negative_feedback_for = message['telemetry_id']
200
+ st.rerun()
201
+
202
+ # --- NEW LOGIC FOR NEGATIVE FEEDBACK COMMENT ---
203
+ # Only render the comment input if this is the message the user clicked thumbs down on
204
+ if st.session_state.negative_feedback_for == message['telemetry_id']:
205
+ with st.container():
206
+ comment = st.text_area(
207
+ "Please provide some details (optional):",
208
+ key=f"feedback_text_{message['telemetry_id']}"
209
+ )
210
+ if st.button("Submit Comment", key=f"submit_feedback_button_{message['telemetry_id']}"):
211
+ handle_negative_feedback_comment_submit(message['telemetry_id'], comment)
212
+
+# Chat input for new questions
+if st.session_state.retriever_ready:
+    if prompt := st.chat_input("Ask me anything about the uploaded documents..."):
+        st.session_state.messages.append({"role": "user", "content": prompt})
+        with st.chat_message("user"):
+            st.markdown(prompt)
+
+        with st.chat_message("assistant"):
+            with st.spinner("Thinking..."):
+                agent_name, response_data = send_chat_message_to_backend(prompt, st.session_state.messages)
+                if response_data:
+                    if response_data.get("is_restricted"):
+                        ai_response = response_data.get("ai_response", "Sorry, I couldn't generate a response.")
+                        reason = response_data.get("moderation_reason")
+                        st.markdown(ai_response)
+                        st.markdown(reason)
+                    elif response_data.get("empty"):
+                        st.markdown(response_data.get("empty"))
+                    else:
+                        ai_response = response_data.get("ai_response", "Sorry, I couldn't generate a response.")
+                        telemetry_id = response_data.get("telemetry_entry_id")
+
+                        st.markdown(ai_response)
+                        st.caption(agent_name)
+
+                        st.session_state.messages.append({
+                            "role": "assistant",
+                            "content": ai_response,
+                            "telemetry_id": telemetry_id
+                        })
+
+                        refresh_conversations()
+
+                        if telemetry_id:
+                            col1, col2 = st.columns(2)
+                            with col1:
+                                if st.button("👍", key=f"positive_{telemetry_id}", on_click=handle_positive_feedback, args=(telemetry_id,)):
+                                    pass
+                            with col2:
+                                if st.button("👎", key=f"negative_{telemetry_id}"):
+                                    st.session_state.negative_feedback_for = telemetry_id
+                                    st.rerun()
+                else:
+                    st.markdown("An error occurred.")
+else:
+    st.info("Please upload and process documents to start chatting.")
+
+
+# import streamlit as st
+
+# if 'selected_model' not in st.session_state:
+#     st.session_state.selected_model = ""
+# @st.dialog("Choose a domain")
+# def domain_modal():
+#     domain = st.selectbox("Select a domain", ["HR", "Finance", "Legal"])
+#     st.session_state.selected_model = domain
+#     if st.button("submit"):
+#         st.rerun()
+
+# domain_modal()
+# print("Selected Domain: ", st.session_state['selected_model'])
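To try the UI locally once the dependencies and .env secrets are in place, the app is started the usual Streamlit way; a tiny launcher sketch equivalent to running streamlit run streamlitapp.py from a shell:

# Hypothetical launcher (not part of the commit).
import subprocess
import sys

subprocess.run([sys.executable, "-m", "streamlit", "run", "streamlitapp.py"], check=True)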