first commit
app.py
CHANGED

@@ -7,6 +7,15 @@ from langchain.schema.runnable import RunnableConfig
 welcome_message = "Welcome! I'm Sage, your friendly AI assistant. I'm here to help you quickly find answers to your HR and policy questions. What can I assist you with today?"
 @cl.on_chat_start
 async def start_chat():
+
+    input = {"question": ""}
+
+    for output in app.stream(input):
+        for key, value in output.items():
+            print(f"Finished running...{key}:")
+
+    print("Initialised chain...")
+
     await cl.Message(content=welcome_message).send()
     cl.user_session.set("runnable", app)
 
@@ -18,7 +27,6 @@ async def main(message: cl.Message):
 
     input = {"question": message.content}
 
-    value = None
     for output in runnable.stream(input):
         for key, value in output.items():
             print(f"Finished running: {key}:")
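Note: the warm-up block added to start_chat and the existing handler in main both rely on the same LangGraph streaming contract: each iteration of stream yields a dict mapping the node that just finished to the partial state it returned. A minimal sketch of that contract, assuming app is the workflow compiled in sage.py and that its final node writes an "answer" key (both assumptions based on the code below, not new API):

# Hedged sketch of the streaming contract used above; `app` and the
# "answer" state key are assumptions taken from sage.py.
value = None
for output in app.stream({"question": "What is leave policy?"}):
    for key, value in output.items():
        # `key` is the graph node that just ran; `value` is its state update
        print(f"Finished running: {key}:")
if value is not None:
    print(value["answer"])  # state update from the last node to run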
sage.py
CHANGED

@@ -27,8 +27,6 @@ text_splitter = RecursiveCharacterTextSplitter(
 )
 doc_splits = text_splitter.split_documents(documents)
 
-print(len(doc_splits),doc_splits[0])
-
 vectorstore = FAISS.from_documents(documents=doc_splits,embedding=embed_model)
 
 from langchain.retrievers import ContextualCompressionRetriever

@@ -320,7 +318,7 @@ def dummy_payroll_api_call(employee_id, month, year):
 
     return data[year][month]
 
-print(dummy_payroll_api_call(1234, 'CUR', 2024))
+# print(dummy_payroll_api_call(1234, 'CUR', 2024))
 
 import time
 from langchain.prompts import PromptTemplate

@@ -350,9 +348,9 @@ router_prompt = PromptTemplate(
 
 router_chain = router_prompt | llm | JsonOutputParser()
 
-print(router_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
+# print(router_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
 
-print(router_chain.invoke({"question":"What is leave policy ?"}))
+# print(router_chain.invoke({"question":"What is leave policy ?"}))
 
 payroll_schema= {
     "$schema": "http://json-schema.org/draft-07/schema#",

@@ -482,7 +480,7 @@ payroll_schema= {
     "required": ["employeeDetails", "paymentDetails", "companyDetails"]
 }
 
-print(str(payroll_schema))
+# print(str(payroll_schema))
 
 import time
 from langchain.prompts import PromptTemplate

@@ -511,7 +509,7 @@ filter_extraction_prompt = PromptTemplate(
 
 fiter_extraction_chain = filter_extraction_prompt | llm | JsonOutputParser()
 
-print(fiter_extraction_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
+# print(fiter_extraction_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
 
 import time
 from langchain.prompts import PromptTemplate

@@ -550,11 +548,11 @@ api_result
 
 payroll_qa_chain.invoke({"question":"What is my salary on jan 2024 ?", "data":api_result, "schema":payroll_schema})
 
+
+########### Create Nodes and Actions ###########
 from typing_extensions import TypedDict
 from typing import List
 
-### State
-
 class AgentState(TypedDict):
     question : str
     answer : str

@@ -604,8 +602,8 @@ def retrieve_policy(state):
     documents = compression_retriever.invoke(question)
     return {"documents": documents, "question": question}
 
-state = AgentState(question="What is leave policy?", answer="", documents=None)
-retrieve_policy(state)
+# state = AgentState(question="What is leave policy?", answer="", documents=None)
+# retrieve_policy(state)
 
 def generate_answer(state):
     """

@@ -626,8 +624,8 @@ def generate_answer(state):
 
     return {"documents": documents, "question": question, "answer": answer}
 
-state = AgentState(question="What is leave policy?", answer="", documents=[Document(page_content="According to leave policy, there are two types of leaves 1: PL 2: CL")])
-generate_answer(state)
+# state = AgentState(question="What is leave policy?", answer="", documents=[Document(page_content="According to leave policy, there are two types of leaves 1: PL 2: CL")])
+# generate_answer(state)
 
 def query_payroll(state):
     """

@@ -652,9 +650,11 @@ def query_payroll(state):
     documents = [Document(page_content=context)]
     return {"documents": documents, "question": question}
 
-state = AgentState(question="Tell me salary for Jan 2024?", answer="", documents=None)
-query_payroll(state)
+# state = AgentState(question="Tell me salary for Jan 2024?", answer="", documents=None)
+# query_payroll(state)
+
 
+########### Build Execution Graph ###########
 from langgraph.graph import END, StateGraph
 workflow = StateGraph(AgentState)
 
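The diff stops right after workflow = StateGraph(AgentState), so the actual wiring of the graph is not shown here. For orientation, a hedged sketch of how these nodes are typically assembled with LangGraph; the node functions exist in sage.py above, but route_question and the router's output keys ("datasource", "payroll") are illustrative assumptions, not this repo's code:

# Hypothetical wiring sketch, assuming the router chain decides between
# the payroll branch and the policy-retrieval branch.
def route_question(state):
    # router_chain is the JsonOutputParser chain defined earlier in sage.py
    route = router_chain.invoke({"question": state["question"]})
    return "query_payroll" if route.get("datasource") == "payroll" else "retrieve_policy"

workflow.add_node("retrieve_policy", retrieve_policy)
workflow.add_node("generate_answer", generate_answer)
workflow.add_node("query_payroll", query_payroll)

workflow.set_conditional_entry_point(
    route_question,
    {"retrieve_policy": "retrieve_policy", "query_payroll": "query_payroll"},
)
workflow.add_edge("retrieve_policy", "generate_answer")
workflow.add_edge("query_payroll", "generate_answer")
workflow.add_edge("generate_answer", END)

app = workflow.compile()  # the compiled graph that app.py stores as "runnable"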