| """ | |
| ViettelPay AI Agent using LangGraph | |
| Multi-turn conversation support with short-term memory using InMemorySaver | |
| """ | |
| import os | |
| from typing import Dict, Optional | |
| from functools import partial | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.checkpoint.memory import InMemorySaver | |
| from langchain_core.messages import HumanMessage | |
from src.agent.nodes import (
    ViettelPayState,
    classify_intent_node,
    query_enhancement_node,
    knowledge_retrieval_node,
    script_response_node,
    generate_response_node,
    route_after_intent_classification,
    route_after_query_enhancement,
    route_after_knowledge_retrieval,
)

# Import configuration utilities
from src.utils.config import get_knowledge_base_path, get_llm_provider

class ViettelPayAgent:
    """Main ViettelPay AI Agent using a LangGraph workflow with multi-turn conversation support."""

    def __init__(
        self,
        knowledge_base_path: Optional[str] = None,
        scripts_file: Optional[str] = None,
        llm_provider: Optional[str] = None,
    ):
        knowledge_base_path = knowledge_base_path or get_knowledge_base_path()
        scripts_file = scripts_file or "./viettelpay_docs/processed/kich_ban.csv"
        llm_provider = llm_provider or get_llm_provider()

        self.knowledge_base_path = knowledge_base_path
        self.scripts_file = scripts_file
        self.llm_provider = llm_provider

        # Initialize the LLM client once during agent creation
        print(f"🧠 Initializing LLM client ({self.llm_provider})...")
        from src.llm.llm_client import LLMClientFactory

        self.llm_client = LLMClientFactory.create_client(self.llm_provider)
        print("✅ LLM client initialized and ready")
        # Initialize the knowledge retriever once during agent creation
        print("📚 Initializing knowledge retriever...")
        try:
            from src.knowledge_base.viettel_knowledge_base import ViettelKnowledgeBase

            self.knowledge_base = ViettelKnowledgeBase()
            ensemble_retriever = self.knowledge_base.load_knowledge_base(
                knowledge_base_path
            )
            if not ensemble_retriever:
                raise ValueError(
                    f"Knowledge base not found at {knowledge_base_path}. Run build_database_script.py first."
                )
            print("✅ Knowledge retriever initialized and ready")
        except Exception as e:
            print(f"⚠️ Knowledge retriever initialization failed: {e}")
            self.knowledge_base = None

        # Checkpointer for short-term conversation memory
        self.checkpointer = InMemorySaver()

        # Build the workflow with pre-initialized components
        self.workflow = self._build_workflow()
        self.app = self.workflow.compile(checkpointer=self.checkpointer)

        print("✅ ViettelPay Agent initialized with multi-turn conversation support")

    def _build_workflow(self) -> StateGraph:
        """Build the LangGraph workflow with pre-initialized components."""
        workflow = StateGraph(ViettelPayState)

        # Bind pre-initialized components to the node functions with
        # functools.partial, so nodes don't re-initialize them on every call.
        classify_intent_with_llm = partial(
            classify_intent_node, llm_client=self.llm_client
        )
        query_enhancement_with_llm = partial(
            query_enhancement_node, llm_client=self.llm_client
        )
        knowledge_retrieval_with_retriever = partial(
            knowledge_retrieval_node, knowledge_retriever=self.knowledge_base
        )
        generate_response_with_llm = partial(
            generate_response_node, llm_client=self.llm_client
        )

        # Add nodes (some with pre-bound components, some without)
        workflow.add_node("classify_intent", classify_intent_with_llm)
        workflow.add_node("query_enhancement", query_enhancement_with_llm)
        workflow.add_node("knowledge_retrieval", knowledge_retrieval_with_retriever)
        workflow.add_node("script_response", script_response_node)  # no pre-bound components needed
        workflow.add_node("generate_response", generate_response_with_llm)

        # Set entry point
        workflow.set_entry_point("classify_intent")

        # Conditional routing after intent classification
        workflow.add_conditional_edges(
            "classify_intent",
            route_after_intent_classification,
            {
                "script_response": "script_response",
                "query_enhancement": "query_enhancement",
            },
        )

        # Script responses go directly to END; knowledge questions flow through
        # enhancement, retrieval, and response generation.
        workflow.add_edge("script_response", END)
        workflow.add_edge("query_enhancement", "knowledge_retrieval")
        workflow.add_edge("knowledge_retrieval", "generate_response")
        workflow.add_edge("generate_response", END)

        print("🔄 LangGraph workflow built successfully with optimized component usage")
        return workflow
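
    # Workflow topology, reconstructed from the edges above for reference:
    #
    #   classify_intent ──┬──► script_response ──────────────────────► END
    #                     └──► query_enhancement ──► knowledge_retrieval
    #                                                        │
    #                                                        ▼
    #                                              generate_response ──► END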

    def process_message(self, user_message: str, thread_id: str = "default") -> Dict:
        """Process a user message in a multi-turn conversation."""
        print(f"\n💬 Processing message: '{user_message}' (thread: {thread_id})")
        print("=" * 50)

        # The thread_id scopes the checkpointer's conversation memory
        config = {"configurable": {"thread_id": thread_id}}

        try:
            human_message = HumanMessage(content=user_message)

            # Initialize state with the new message. conversation_context is
            # reset to None so it is recomputed from the fresh message history.
            initial_state = {
                "messages": [human_message],
                "intent": None,
                "confidence": None,
                "enhanced_query": None,
                "retrieved_docs": None,
                "conversation_context": None,
                "response_type": None,
                "error": None,
                "processing_info": None,
            }

            # Run the workflow with memory
            result = self.app.invoke(initial_state, config)
            # Extract the response from the last message (the AI reply)
            messages = result.get("messages", [])
            if messages:
                last_message = messages[-1]
                if hasattr(last_message, "content"):
                    response = last_message.content
                else:
                    response = str(last_message)
            else:
                response = "Xin lỗi, em không thể xử lý yêu cầu này."

            response_type = result.get("response_type", "unknown")
            intent = result.get("intent", "unknown")
            confidence = result.get("confidence", 0.0)
            enhanced_query = result.get("enhanced_query", "")
            error = result.get("error")

            # Build the response info
            response_info = {
                "response": response,
                "intent": intent,
                "confidence": confidence,
                "response_type": response_type,
                "enhanced_query": enhanced_query,
                "success": error is None,
                "error": error,
                "thread_id": thread_id,
                "message_count": len(messages),
            }

            print("✅ Response generated successfully")
            print(f"   Intent: {intent} (confidence: {confidence})")
            print(f"   Type: {response_type}")
            if enhanced_query and enhanced_query != user_message:
                print(f"   Enhanced query: {enhanced_query}")
            print(f"   Thread: {thread_id}")

            return response_info
        except Exception as e:
            print(f"❌ Workflow error: {e}")
            return {
                "response": "Xin lỗi, em gặp lỗi kỹ thuật. Vui lòng thử lại sau.",
                "intent": "error",
                "confidence": 0.0,
                "response_type": "error",
                "enhanced_query": "",
                "success": False,
                "error": str(e),
                "thread_id": thread_id,
                "message_count": 0,
            }

    def chat(self, user_message: str, thread_id: str = "default") -> str:
        """Simple chat interface - returns just the response text."""
        result = self.process_message(user_message, thread_id)
        return result["response"]
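
    # Usage sketch (the thread id is arbitrary; "user-123" is just an example):
    #   agent = ViettelPayAgent()
    #   reply = agent.chat("Mã lỗi 606 là gì?", thread_id="user-123")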

    def get_conversation_history(self, thread_id: str = "default") -> list:
        """Get conversation history for a specific thread."""
        try:
            config = {"configurable": {"thread_id": thread_id}}

            # Read the current checkpointed state to access the message history
            current_state = self.app.get_state(config)

            if current_state and current_state.values.get("messages"):
                messages = current_state.values["messages"]
                history = []
                for msg in messages:
                    if hasattr(msg, "type") and hasattr(msg, "content"):
                        role = "user" if msg.type == "human" else "assistant"
                        history.append({"role": role, "content": msg.content})
                    elif hasattr(msg, "role") and hasattr(msg, "content"):
                        history.append({"role": msg.role, "content": msg.content})
                return history
            else:
                return []
        except Exception as e:
            print(f"❌ Error getting conversation history: {e}")
            return []
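
    # Returned shape, e.g.:
    #   [{"role": "user", "content": "Mã lỗi 606 là gì?"},
    #    {"role": "assistant", "content": "..."}]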

    def clear_conversation(self, thread_id: str = "default") -> bool:
        """Clear conversation history for a specific thread."""
        try:
            # Note: InMemorySaver doesn't expose a direct clear method here;
            # conversations are cleared when the app restarts. For persistent
            # memory you would need to implement clearing yourself.
            print(f"📝 Conversation clearing requested for thread: {thread_id}")
            print("   Note: InMemorySaver conversations clear on app restart")
            return True
        except Exception as e:
            print(f"❌ Error clearing conversation: {e}")
            return False
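
    # Sketch (assumption): newer langgraph-checkpoint releases add a
    # delete_thread(thread_id) method to checkpointers. If your installed
    # version has it, clear_conversation could actually drop the thread:
    #   if hasattr(self.checkpointer, "delete_thread"):
    #       self.checkpointer.delete_thread(thread_id)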

    def get_workflow_info(self) -> Dict:
        """Get information about the workflow structure."""
        return {
            "nodes": [
                "classify_intent",
                "query_enhancement",
                "knowledge_retrieval",
                "script_response",
                "generate_response",
            ],
            "entry_point": "classify_intent",
            "knowledge_base_path": self.knowledge_base_path,
            "scripts_file": self.scripts_file,
            "llm_provider": self.llm_provider,
            "memory_type": "InMemorySaver",
            "multi_turn": True,
            "query_enhancement": True,
            "optimizations": {
                "llm_client": "Single initialization with functools.partial",
                "knowledge_retriever": "Single initialization with functools.partial",
                "conversation_context": "Cached in state to avoid repeated computation",
            },
        }

    def health_check(self) -> Dict:
        """Check whether all components are working."""
        health_status = {
            "agent": True,
            "workflow": True,
            "memory": True,
            "llm": False,
            "knowledge_base": False,
            "scripts": False,
            "overall": False,
        }

        try:
            # Test the LLM client (already initialized)
            test_response = self.llm_client.generate("Hello", temperature=0.1)
            health_status["llm"] = bool(test_response)
            print("✅ LLM client working")
        except Exception as e:
            print(f"⚠️ LLM health check failed: {e}")
            health_status["llm"] = False

        try:
            # Test the memory/checkpointer by invoking the app with a config
            test_config = {"configurable": {"thread_id": "health_check"}}
            test_state = {"messages": [HumanMessage(content="test")]}
            self.app.invoke(test_state, test_config)
            health_status["memory"] = True
            print("✅ Memory/checkpointer working")
        except Exception as e:
            print(f"⚠️ Memory health check failed: {e}")
            health_status["memory"] = False

        try:
            # Test the knowledge base (using the pre-initialized retriever)
            if self.knowledge_base:
                # Run a simple search to verify the retriever responds
                self.knowledge_base.search("test", top_k=1)
                health_status["knowledge_base"] = True
                print("✅ Knowledge retriever working")
            else:
                health_status["knowledge_base"] = False
                print("❌ Knowledge retriever not initialized")
        except Exception as e:
            print(f"⚠️ Knowledge base health check failed: {e}")
            health_status["knowledge_base"] = False

        try:
            # Test the conversation scripts
            from src.agent.scripts import ConversationScripts

            scripts = ConversationScripts(self.scripts_file)
            health_status["scripts"] = len(scripts.get_all_script_types()) > 0
        except Exception as e:
            print(f"⚠️ Scripts health check failed: {e}")

        # Overall health
        health_status["overall"] = all(
            [
                health_status["agent"],
                health_status["memory"],
                health_status["llm"],
                health_status["knowledge_base"],
                health_status["scripts"],
            ]
        )

        return health_status

# Usage example and testing
if __name__ == "__main__":
    # Initialize the agent
    agent = ViettelPayAgent()

    # Health check
    print("\n🏥 Health Check:")
    health = agent.health_check()
    for component, status in health.items():
        status_icon = "✅" if status else "❌"
        print(f"   {component}: {status_icon}")

    if not health["overall"]:
        print("\n⚠️ Some components are not healthy. Check requirements and data files.")
        raise SystemExit(1)

    print(f"\n🤖 Agent ready! Workflow info: {agent.get_workflow_info()}")

    # Test a multi-turn conversation with query enhancement
    test_thread = "test_conversation"
    print(
        f"\n🧪 Testing multi-turn conversation with query enhancement (thread: {test_thread}):"
    )

    test_messages = [
        "Xin chào!",
        "Mã lỗi 606 là gì?",
        "Làm sao khắc phục?",  # should be enhanced to "làm sao khắc phục lỗi 606"
        "Còn lỗi nào khác tương tự không?",  # should be enhanced with error context
        "Cảm ơn bạn!",
    ]

    for i, message in enumerate(test_messages, 1):
        print(f"\n--- Turn {i} ---")
        result = agent.process_message(message, test_thread)
        print(f"User: {message}")
        print(f"Bot: {result['response'][:150]}...")

        if result.get("enhanced_query") and result["enhanced_query"] != message:
            print(f"🚀 Query enhanced: {result['enhanced_query']}")

        # Show the conversation history length after the first turn
        if i > 1:
            history = agent.get_conversation_history(test_thread)
            print(f"History length: {len(history)} messages")

    print("\n📜 Final conversation history:")
    history = agent.get_conversation_history(test_thread)
    for i, msg in enumerate(history, 1):
        print(f"   {i}. {msg['role']}: {msg['content'][:100]}...")