from langgraph.graph import StateGraph, MessagesState, START, END
# from llm_initializer import initialize_llm, generate_prompt_phi4
from langchain_core.messages import ToolMessage, HumanMessage, AIMessage, SystemMessage
from typing_extensions import Literal
from pydantic import BaseModel, Field, field_validator
from typing import List, Optional, Dict, Any, TypedDict, Generic, TypeVar
import uuid
import io
import os
import PyPDF2
import re
import logging
import time
import datetime
import tempfile
from enum import Enum as PyEnum
from functools import lru_cache
from docx import Document as dx
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
    DirectoryLoader,
    PyPDFLoader,
    TextLoader
)
import faiss
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.checkpoint.memory import MemorySaver
from sqlalchemy import create_engine, Column, String, Integer, DateTime, ForeignKey, Text
from sqlalchemy.dialects.sqlite import JSON as SQLiteJSON
# from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import DeclarativeBase, sessionmaker, relationship
from sentence_transformers import SentenceTransformer
from huggingface_hub import login
from langchain_google_genai import ChatGoogleGenerativeAI
# from config import Config
from dotenv import load_dotenv

load_dotenv()
hf_token = os.getenv("hf_user_token")
if hf_token:  # only log in when a token is configured
    login(hf_token)

T = TypeVar("T")
# --- 1. Database Setup ---
DATABASE_URL = "sqlite:///Db_domain_agent.db"
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


class Base(DeclarativeBase):
    pass


class FeedbackScore(PyEnum):
    POSITIVE = 1
    NEGATIVE = -1


class Telemetry(Base):
    __tablename__ = "telemetry_table"

    transaction_id = Column(String, primary_key=True)
    session_id = Column(String)
    user_question = Column(Text)
    response = Column(Text)
    context = Column(Text)
    model_name = Column(String)
    input_tokens = Column(Integer)
    output_tokens = Column(Integer)
    total_tokens = Column(Integer)
    latency = Column(Integer)
    dtcreatedon = Column(DateTime)

    feedback = relationship("Feedback", back_populates="telemetry_entry", uselist=False)


class Feedback(Base):
    __tablename__ = "feedback_table"

    id = Column(Integer, primary_key=True, autoincrement=True)
    telemetry_entry_id = Column(String, ForeignKey("telemetry_table.transaction_id"), nullable=False, unique=True)
    feedback_score = Column(Integer, nullable=False)
    feedback_text = Column(Text, nullable=True)
    user_query = Column(Text, nullable=False)
    llm_response = Column(Text, nullable=False)
    timestamp = Column(DateTime, default=datetime.datetime.now)

    telemetry_entry = relationship("Telemetry", back_populates="feedback")


class ConversationHistory(Base):
    __tablename__ = "conversation_history"

    session_id = Column(String, primary_key=True)
    messages = Column(SQLiteJSON, nullable=False)
    last_updated = Column(DateTime, default=datetime.datetime.now)


Base.metadata.create_all(bind=engine)
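# Schema recap (descriptive only): each Telemetry row can have at most one Feedback row,
# joined on Telemetry.transaction_id == Feedback.telemetry_entry_id, and ConversationHistory
# keeps the whole per-session message list in a single JSON column keyed by session_id.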
# --- 2. Initialize LLM and Embeddings ---
gak = os.getenv("Gapi_key")
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite", google_api_key=gak)
# embedding_model = SentenceTransformer("ibm-granite/granite-embedding-english-r2")
# my_model_name = "gemma3:1b-it-qat"
# llm = ChatOllama(model=my_model_name)
embedding_model = HuggingFaceEmbeddings(
    model_name="ibm-granite/granite-embedding-english-r2",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": False},
)
# --- 3. LangGraph State and Workflow ---
class GraphState(TypedDict):
    chat_history: List[Dict[str, Any]]
    retrieved_documents: List[str]
    user_question: str
    decision: str
    session_id: str
    telemetry_id: Optional[str]  # TypedDict fields cannot carry defaults; this key is simply absent until set
class Route(BaseModel):
    step: Literal["HR Agent", "Finance Agent", "Legal Compliance Agent"] = Field(
        None, description="The next step in routing process"
    )


router = llm.with_structured_output(Route)

# class State(TypedDict):
#     input: str
#     decision: str
#     output: str

chathistory = {}
def retrieve_documents(state: GraphState):
    # global vectorstore_retriever
    # upload_documents()
    saved_vectorstore_index = FAISS.load_local("domain_index", embedding_model, allow_dangerous_deserialization=True)
    user_question = state["user_question"]
    # meta_filter = {'Domain': 'HR'}
    if saved_vectorstore_index is None:
        raise ValueError("Knowledge base not loaded.")
    retriever = saved_vectorstore_index.as_retriever(search_type="mmr", search_kwargs={"k": 5})
    top_docs = retriever.invoke(user_question)
    print("Top Docs: ", top_docs)
    retrieved_docs_content = [doc.page_content if doc.page_content else doc for doc in top_docs]
    print("retrieved_documents List: ", retrieved_docs_content)
    return {"retrieved_documents": retrieved_docs_content}
def generate_response(user_question, retrieved_documents):
    print("Inside generate_response--------------")
    global llm
    global chathistory
    global agent_name

    formatted_chat_history = []
    for msg in chathistory["chat_history"]:
        if msg["role"] == "user":
            formatted_chat_history.append(HumanMessage(content=msg["content"]))
        elif msg["role"] == "assistant":
            formatted_chat_history.append(AIMessage(content=msg["content"]))

    if not retrieved_documents:
        response_content = (
            "I couldn't find any relevant information in the uploaded documents for your question. "
            "Can you please rephrase or provide more context?"
        )
        response_obj = AIMessage(content=response_content)
    else:
        context = "\n\n".join(retrieved_documents)
        template = """
        You are a helpful AI assistant. Answer the user's question based on the provided context {context} and the conversation history {chat_history}.
        If the answer is not in the context, state that you don't have enough information.
        Do not make up answers. Only use the given context and chat_history.
        Remove unwanted words like 'Response:' or 'Answer:' from answers.
        \n\nHere is the Question:\n{user_question}
        """
        rag_prompt = PromptTemplate(
            input_variables=["context", "chat_history", "user_question"],
            template=template,
        )
        rag_chain = rag_prompt | llm
        time.sleep(3)  # crude pacing to avoid hitting the API rate limit
        response_obj = rag_chain.invoke({
            "context": [SystemMessage(content=context)],
            "chat_history": formatted_chat_history,
            "user_question": [HumanMessage(content=user_question)],
        })

    telemetry_data = response_obj.model_dump()
    input_tokens = telemetry_data.get("usage_metadata", {}).get("input_tokens", 0)
    output_tokens = telemetry_data.get("usage_metadata", {}).get("output_tokens", 0)
    total_tokens = telemetry_data.get("usage_metadata", {}).get("total_tokens", 0)
    model_name = telemetry_data.get("response_metadata", {}).get("model", "unknown")
    total_duration = telemetry_data.get("response_metadata", {}).get("total_duration", 0)

    transaction_id = str(uuid.uuid4())
    # Build the updated history before touching the DB so the return value is
    # well-defined even if the save below fails.
    new_messages = chathistory["chat_history"] + [
        {"role": "user", "content": user_question},
        {"role": "assistant", "content": response_obj.content, "telemetry_id": transaction_id},
    ]

    db = SessionLocal()
    try:
        telemetry_record = Telemetry(
            transaction_id=transaction_id,
            session_id=chathistory.get("session_id"),
            user_question=user_question,
            response=response_obj.content,
            context="\n\n".join(retrieved_documents) if retrieved_documents else "No documents retrieved",
            model_name=model_name,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
            latency=total_duration,
            dtcreatedon=datetime.datetime.now(),
        )
        db.add(telemetry_record)

        # Upsert the conversation history for this session.
        print(f"Saving conversation for session_id: {chathistory.get('session_id')}")
        conversation_entry = db.query(ConversationHistory).filter_by(session_id=chathistory.get("session_id")).first()
        if conversation_entry:
            print(f"Updating existing conversation for session_id: {chathistory.get('session_id')}")
            conversation_entry.messages = new_messages
            conversation_entry.last_updated = datetime.datetime.now()
        else:
            print(f"Creating new conversation for session_id: {chathistory.get('session_id')}")
            new_conversation_entry = ConversationHistory(
                session_id=chathistory.get("session_id"),
                messages=new_messages,
                last_updated=datetime.datetime.now(),
            )
            db.add(new_conversation_entry)
        db.commit()
        print(f"Successfully saved conversation for session_id: {chathistory.get('session_id')}")
    except Exception as e:
        db.rollback()
        print(f"***CRITICAL ERROR***: Failed to save data to database. Error: {e}")
    finally:
        db.close()

    return {
        "chat_history": new_messages,
        "telemetry_id": transaction_id,
        "agent_name": agent_name,
    }
| agent_name = "" | |
| def hr_agent(state:GraphState): | |
| """Answer the user question based on Human Resource(HR)""" | |
| global agent_name | |
| user_question = state["user_question"] | |
| retrieved_documents = state["retrieved_documents"] | |
| print("HR Agent") | |
| agent_name = "HR Agent" | |
| result = generate_response(user_question,retrieved_documents) | |
| # return {"output":result} | |
| return result | |
| def finance_agent(state:GraphState): | |
| """Answer the user question based on Finance and Bank""" | |
| global agent_name | |
| user_question = state["user_question"] | |
| retrieved_documents = state["retrieved_documents"] | |
| print("Finance Agent") | |
| agent_name = "Finance Agent" | |
| result = generate_response(user_question,retrieved_documents) | |
| return result | |
| def legals_agent(state:GraphState): | |
| """Answer the user question based on Legal Compliance""" | |
| global agent_name | |
| user_question = state["user_question"] | |
| retrieved_documents = state["retrieved_documents"] | |
| print("LC agent") | |
| agent_name = "Legal Compliance Agent" | |
| result = generate_response(user_question,retrieved_documents) | |
| # return {"output":result} | |
| return result | |
def llm_call_router(state: GraphState):
    decision = router.invoke(
        [
            SystemMessage(
                content="Route the user_question to HR Agent, Finance Agent, Legal Compliance Agent based on the user's request"
            ),
            HumanMessage(content=state["user_question"]),
        ]
    )
    return {"decision": decision.step}


def route_decision(state: GraphState):
    if state["decision"] == "HR Agent":
        return "hr_agent"
    elif state["decision"] == "Finance Agent":
        return "finance_agent"
    elif state["decision"] == "Legal Compliance Agent":
        return "legals_agent"
router_builder = StateGraph(GraphState)
router_builder.add_node("retrieve", retrieve_documents)
router_builder.add_node("hr_agent", hr_agent)
router_builder.add_node("finance_agent", finance_agent)
router_builder.add_node("legals_agent", legals_agent)
router_builder.add_node("llm_call_router", llm_call_router)
# router_builder.add_node("generate", generate_response)

# Documents must be retrieved before routing, so the graph enters at "retrieve";
# the extra START -> "llm_call_router" edge from the earlier draft is dropped to
# avoid running the router twice.
router_builder.set_entry_point("retrieve")
router_builder.add_edge("retrieve", "llm_call_router")
router_builder.add_conditional_edges(
    "llm_call_router",
    route_decision,
    {
        "hr_agent": "hr_agent",
        "finance_agent": "finance_agent",
        "legals_agent": "legals_agent",
    },
)
router_builder.add_edge("hr_agent", END)
router_builder.add_edge("finance_agent", END)
router_builder.add_edge("legals_agent", END)

route_workflow = router_builder.compile()
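# Resulting flow: START -> retrieve -> llm_call_router -> (hr_agent | finance_agent | legals_agent) -> END,
# where route_decision maps the structured Route.step value onto the matching agent node.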
# state = route_workflow.invoke({'input': "Write a poem about a wicked cat"})
# print(state['output'])

vectorstore_retriever = None
compiled_app = None
memory = MemorySaver()

# --- 4. LangGraph Nodes ---
# def load_documents(state: GraphState):
#     global selected_domain

# --- 5. API Models ---
class ChatHistoryEntry(BaseModel):
    role: str
    content: str
    telemetry_id: Optional[str] = None


class ChatRequest(BaseModel):
    user_question: str
    session_id: str
    chat_history: Optional[List[ChatHistoryEntry]] = Field(default_factory=list)

    @field_validator("user_question")
    @classmethod
    def validate_prompt(cls, v):
        v = v.strip()
        if not v:
            raise ValueError("Question cannot be empty")
        return v


class ChatResponse(BaseModel):
    ai_response: str
    updated_chat_history: List[ChatHistoryEntry]
    telemetry_entry_id: str
    is_restricted: bool = False
    moderation_reason: Optional[str] = None


class FeedbackRequest(BaseModel):
    session_id: str
    telemetry_entry_id: str
    feedback_score: int
    feedback_text: Optional[str] = None


class ConversationSummary(BaseModel):
    session_id: str
    title: str
def process_text(file):
    string_data = file.read().decode("utf-8")
    return string_data


def process_pdf(file):
    pdf_bytes = io.BytesIO(file.read())
    reader = PyPDF2.PdfReader(pdf_bytes)
    # extract_text() can return None for image-only pages, so fall back to ""
    pdf_text = "".join([(page.extract_text() or "") + "\n" for page in reader.pages])
    return pdf_text


def process_docx(file):
    docx_bytes = io.BytesIO(file.read())
    docx_docs = dx(docx_bytes)
    docx_content = "\n".join([para.text for para in docx_docs.paragraphs])
    return docx_content
| # @app.post("/upload-documents") | |
| # def upload_documents(files): | |
| def upload_documents(): | |
| global vectorstore_retriever | |
| # saved_vectorstore_index = FAISS.load_local('domain_index', embedding_model,allow_dangerous_deserialization=True) | |
| try: | |
| saved_vectorstore_index = faiss.read_index("domain_index_sec.faiss") | |
| if saved_vectorstore_index: | |
| vectorstore_retriever = saved_vectorstore_index | |
| msg = f"Successfully loaded the knowledge base." | |
| return msg, True | |
| except Exception as e: | |
| print("unable to find index...", e) | |
| print("Creating new index.....") | |
| all_documents = [] | |
| hr_loader = PyPDFLoader("D:\Pdf_data\Developments_in_HR_management_in_QAAs.pdf").load() | |
| hr_finance = PyPDFLoader("D:\Pdf_data\White Paper_QA Practice.pdf").load() | |
| hr_legal = PyPDFLoader("D:\Pdf_data\Legal-Aspects-Compliances.pdf").load() | |
| for doc in hr_loader: | |
| doc.metadata['Domain'] = 'HR' | |
| all_documents.append(doc) | |
| for doc in hr_finance: | |
| doc.metadata['Domain'] = 'Finance' | |
| all_documents.append(doc) | |
| for doc in hr_legal: | |
| doc.metadata['Domain'] = 'Legal' | |
| all_documents.append(doc) | |
| # for uploaded_file in files: | |
| # doc_loader = PyPDFLoader(uploaded_file) | |
| # all_documents.extend(doc_loader.load()) | |
| if not all_documents: | |
| raise Exception(status_code=400, detail="No supported documents uploaded.") | |
| text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100) | |
| text_chunks = text_splitter.split_documents(all_documents) | |
| print("text_chucks: ", text_chunks[:100]) | |
| # processed_chunks_with_ids = [] | |
| # for i, chunk in enumerate(text_chunks): | |
| # # Generate a unique ID for each chunk | |
| # # Option 1 (Recommended): Using UUID for global uniqueness | |
| # # chunk_id = str(uuid.uuid4()) | |
| # # Option 2 (Alternative): Combining source file path with chunk index | |
| # # This is good if you want IDs to be deterministic based on file/chunk. | |
| # # You might need to make the file path more robust (e.g., hash it or normalize it). | |
| # file_source = chunk.metadata.get('source', 'unknown_source') | |
| # chunk_id = f"{file_source.replace('.','_')}_chunk_{i}" | |
| # # Add the unique ID to the chunk's metadata | |
| # # It's good practice to keep original metadata and just add your custom ID. | |
| # chunk.metadata['doc_id'] = chunk_id | |
| # processed_chunks_with_ids.append(chunk) | |
| # embeddings = [embedding_model.encode(doc_chunks.page_content, convert_to_numpy=True) for doc_chunks in processed_chunks_with_ids] | |
| print(f"Split {len(text_chunks)} chunks.") | |
| print(f"Assigned unique 'doc_id' to each chunk in metadata.") | |
| # dimension = 768 | |
| # # hnsw_m = 32 | |
| # # index = faiss.IndexHNSWFlat(dimension, hnsw_m, faiss.METRIC_INNER_PRODUCT) | |
| # index = faiss.IndexFlatL2(dimension) | |
| # vector_store = FAISS( | |
| # embedding_function=embedding_model.embed_query, | |
| # index=index, | |
| # docstore= InMemoryDocstore(), | |
| # index_to_docstore_id={} | |
| # ) | |
| vectorstore = FAISS.from_documents(documents=text_chunks, embedding=embedding_model) | |
| # vectorstore.add_documents(text_chunks, ids = [cid.metadata['doc_id'] for cid in text_chunks]) | |
| vectorstore.add_documents(text_chunks) | |
| # vectorstore_retriever = vectorstore.as_retriever(search_kwargs={'k': 5}) | |
| faiss.write_index(vectorstore.index, "domain_index_sec.faiss") | |
| # vectorstore.save_local("domain_index") | |
| vectorstore_retriever = vectorstore | |
| if vectorstore: | |
| msg = f"Successfully loaded the knowledge base." | |
| return msg, True | |
| else: | |
| msg = f"Failed to process documents." | |
| return msg, False | |
| # @app.post("/chat", response_model=ChatResponse) | |
| def chat_with_rag(chatdata): | |
| global compiled_app | |
| global vectorstore_retriever | |
| global chathistory | |
| if vectorstore_retriever is None: | |
| raise Exception(status_code=400, detail="Knowledge base not loaded. Please upload documents first.") | |
| print(f"Received request: {chatdata}") | |
| # moderation_result = moderator.moderate_content(request.user_question) | |
| # if moderation_result["is_restricted"]: | |
| # # Get appropriate response based on restriction type | |
| # response_type = moderation_result.get("response_type", "general") | |
| # response_text = Config.RESTRICTED_RESPONSES.get( | |
| # response_type, | |
| # Config.RESTRICTED_RESPONSES["general"] | |
| # ) | |
| # logger.warning( | |
| # f"Restricted query: {request.prompt[:100]}... " | |
| # f"Reason: {moderation_result['reason']}" | |
| # ) | |
| # return ChatResponse( | |
| # ai_response=response_text, | |
| # updated_chat_history=[], | |
| # telemetry_entry_id=request.session_id, | |
| # is_restricted=True, | |
| # moderation_reason=moderation_result["reason"], | |
| # ) | |
| print("✅ Question passed the RAI check.........") | |
| print("Received data from UI: ", chatdata) | |
| chathistory = chatdata | |
| initial_state = { | |
| # "chat_history": [msg.model_dump() for msg in chatdata.get('chat_history')], | |
| "chat_history": [msg for msg in chatdata.get('chat_history')], | |
| "retrieved_documents": [], | |
| "user_question": chatdata.get('user_question'), | |
| "session_id": chatdata.get('session_id') | |
| } | |
| try: | |
| config = {"configurable": {"thread_id": chatdata.get('session_id')}} | |
| final_state = route_workflow.invoke(initial_state, config=config) | |
| # chathistory = final_state | |
| print("chathistory inside chat_with_rag-----------------") | |
| print("Final State--- : ", final_state) | |
| ai_response_message = final_state["chat_history"][-1]["content"] | |
| updated_chat_history_dicts = final_state["chat_history"] | |
| agent_name = final_state.get("decision","No Agent") | |
| response_chat = ChatResponse( | |
| ai_response=ai_response_message, | |
| updated_chat_history=updated_chat_history_dicts, | |
| telemetry_entry_id=final_state.get("telemetry_id"), | |
| is_restricted=False, | |
| ) | |
| return agent_name,response_chat.dict() | |
| except Exception as e: | |
| print(f"Internal Server Error: {e}") | |
| raise Exception(status_code=500, detail=f"An error occurred during chat processing: {e}") | |
def submit_feedback(feedbackdata):
    db = SessionLocal()
    try:
        telemetry_record = db.query(Telemetry).filter(
            Telemetry.transaction_id == feedbackdata["telemetry_entry_id"],
            Telemetry.session_id == feedbackdata["session_id"],
        ).first()
        if not telemetry_record:
            raise ValueError("Telemetry entry not found or session ID mismatch.")
        existing_feedback = db.query(Feedback).filter(
            Feedback.telemetry_entry_id == feedbackdata["telemetry_entry_id"]
        ).first()
        if existing_feedback:
            existing_feedback.feedback_score = feedbackdata["feedback_score"]
            existing_feedback.feedback_text = feedbackdata.get("feedback_text")
            existing_feedback.timestamp = datetime.datetime.now()
        else:
            feedback_record = Feedback(
                telemetry_entry_id=feedbackdata["telemetry_entry_id"],
                feedback_score=feedbackdata["feedback_score"],
                feedback_text=feedbackdata.get("feedback_text"),
                user_query=telemetry_record.user_question,
                llm_response=telemetry_record.response,
                timestamp=datetime.datetime.now(),
            )
            db.add(feedback_record)
        db.commit()
        return {"message": "Feedback submitted successfully."}
    except ValueError:
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        raise RuntimeError(f"An error occurred: {str(e)}")
    finally:
        db.close()
| # @app.get("/conversations", response_model=List[ConversationSummary]) | |
| def get_conversations(): | |
| db = SessionLocal() | |
| try: | |
| conversations = db.query(ConversationHistory).order_by(ConversationHistory.last_updated.desc()).all() | |
| summaries = [] | |
| for conv in conversations: | |
| for msg in conv.messages: | |
| print(msg) | |
| first_user_message = next((msg for msg in conv.messages if msg["role"] == "user"), None) | |
| title = first_user_message.get("content") if first_user_message else "New Conversation" | |
| summaries.append({"session_id":conv.session_id, "title":title[:30] + "..." if len(title) > 30 else title}) | |
| return summaries | |
| finally: | |
| db.close() | |
| # @app.get("/conversations/{session_id}", response_model=List[ChatHistoryEntry]) | |
| def get_conversation_history(session_id: str): | |
| db = SessionLocal() | |
| try: | |
| conversation = db.query(ConversationHistory).filter(ConversationHistory.session_id == session_id).first() | |
| if not conversation: | |
| raise Exception(status_code=404, detail="Conversation not found.") | |
| return conversation.messages | |
| finally: | |
| db.close() | |
# if 'selected_model' not in st.session_state:
#     st.session_state.selected_model = ""
# @st.dialog("Choose a domain")
# def domain_modal():
#     domain = st.selectbox("Select a domain", ["HR", "Finance", "Legal"])
#     st.session_state.selected_model = domain
#     if st.button("submit"):
#         st.rerun()
# domain_modal()
# print("Selected Domain: ", st.session_state['selected_model'])
# llm = initialize_llm()
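# Minimal local smoke test (a sketch, not part of the original module): it builds or loads
# the FAISS index, then runs one routed RAG turn end-to-end. The session_id and question
# values below are illustrative assumptions, not taken from the source.
if __name__ == "__main__":
    load_msg, ok = upload_documents()
    print(load_msg)
    if ok:
        agent, reply = chat_with_rag({
            "session_id": str(uuid.uuid4()),  # hypothetical session id
            "user_question": "What does the HR policy say about leave carry-forward?",  # hypothetical question
            "chat_history": [],
        })
        print(f"[{agent}] {reply['ai_response']}")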