| """ | |
| LangGraph Agent State and Processing Nodes | |
| """ | |
| from typing import Dict, List, Optional, TypedDict, Annotated | |
| from langchain.schema import Document | |
| from langchain_core.messages import AnyMessage | |
| from langgraph.graph.message import add_messages | |
| import json | |
| import re | |
| from src.agent.prompts import ( | |
| INTENT_CLASSIFICATION_PROMPT, | |
| QUERY_ENHANCEMENT_PROMPT, | |
| RESPONSE_GENERATION_PROMPT, | |
| get_system_prompt_by_intent, | |
| ) | |


class ViettelPayState(TypedDict):
    """State for ViettelPay agent workflow with message history support"""

    # Message history for multi-turn conversation
    messages: Annotated[List[AnyMessage], add_messages]

    # Intent classification
    intent: Optional[str]
    confidence: Optional[float]

    # Query enhancement
    enhanced_query: Optional[str]

    # Knowledge retrieval
    retrieved_docs: Optional[List[Document]]

    # Conversation context (cached to avoid repeated computation)
    conversation_context: Optional[str]

    # Response type metadata
    response_type: Optional[str]  # "script" or "generated"

    # Metadata
    error: Optional[str]
    processing_info: Optional[Dict]
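
# Note on state updates (standard LangGraph semantics): because "messages" is
# annotated with the add_messages reducer, nodes return only the *new*
# message(s), e.g. {"messages": [ai_message]}, and LangGraph appends them to
# the existing history instead of replacing it. A graph compiled from this
# state (see the illustrative build_viettelpay_graph() sketch at the end of
# this file) would be invoked with just the seeded message history:
#
#     from langchain_core.messages import HumanMessage
#     result = app.invoke({"messages": [HumanMessage(content="Xin chào")]})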


def get_conversation_context(messages: List[AnyMessage], max_messages: int = 3) -> str:
    """
    Extract conversation context from message history.

    Args:
        messages: List of conversation messages
        max_messages: Maximum number of recent messages to include

    Returns:
        Formatted conversation context string
    """
    if len(messages) <= 1:
        return ""

    context = "\n\nLịch sử cuộc hội thoại:\n"  # "Conversation history:"

    # Get the recent messages, excluding the very last one (the current user
    # message, which the calling node handles separately)
    recent_messages = messages[-(max_messages + 1) : -1]

    for msg in recent_messages:
        # Handle different message types robustly
        if hasattr(msg, "type"):
            if msg.type == "human":
                role = "Người dùng"  # "User"
            elif msg.type == "ai":
                role = "Trợ lý"  # "Assistant"
            else:
                role = f"Unknown-{msg.type}"
        elif hasattr(msg, "role"):
            if msg.role in ["user", "human"]:
                role = "Người dùng"
            elif msg.role in ["assistant", "ai"]:
                role = "Trợ lý"
            else:
                role = f"Unknown-{msg.role}"
        else:
            role = "Unknown"

        # Note: message content could be truncated here (e.g., to ~1000 chars)
        # to avoid token overflow; currently the full content is kept.
        context += f"{role}: {msg.content}\n"

    return context
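
# Illustrative output (hypothetical two-turn history, not from a real session):
#
#     Lịch sử cuộc hội thoại:
#     Người dùng: Nạp cước bị báo lỗi 606
#     Trợ lý: Anh/chị vui lòng kiểm tra lại thông tin giao dịch...
#
# ("Conversation history: / User: Top-up failed with error 606 /
#  Assistant: Please double-check the transaction details...")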


def classify_intent_node(state: ViettelPayState, llm_client) -> ViettelPayState:
    """Node for intent classification using LLM with conversation context"""

    # Get the latest user message
    messages = state["messages"]
    if not messages:
        return {
            **state,
            "intent": "unclear",
            "confidence": 0.0,
            "error": "No messages found",
        }

    # Find the last human/user message
    user_message = None
    for msg in reversed(messages):
        if hasattr(msg, "type") and msg.type == "human":
            user_message = msg.content
            break
        elif hasattr(msg, "role") and msg.role == "user":
            user_message = msg.content
            break

    if not user_message:
        return {
            **state,
            "intent": "unclear",
            "confidence": 0.0,
            "error": "No user message found",
        }

    try:
        # Get conversation context for better intent classification
        conversation_context = get_conversation_context(messages)

        # Build the classification prompt from the shared prompts module
        classification_prompt = INTENT_CLASSIFICATION_PROMPT.format(
            conversation_context=conversation_context, user_message=user_message
        )

        # Get the classification from the pre-initialized LLM client
        response = llm_client.generate(classification_prompt, temperature=0.1)

        # Parse the JSON response
        try:
            response_clean = response.strip()

            # Extract the JSON object in case the model added surrounding text
            json_match = re.search(r"\{.*\}", response_clean, re.DOTALL)
            if json_match:
                result = json.loads(json_match.group())
            else:
                # Try parsing the whole response
                result = json.loads(response_clean)

            intent = result.get("intent", "unclear")
            confidence = result.get("confidence", 0.5)
            explanation = result.get("explanation", "")

        except (json.JSONDecodeError, AttributeError) as e:
            print(f"❌ JSON parsing failed: {e}")
            print(f"   Raw response: {response}")

            # Fallback: keyword-based classification over the raw response text
            response_lower = response.lower()
            if any(
                word in response_lower for word in ["lỗi", "error", "606", "mã lỗi"]
            ):
                intent = "error_help"
                confidence = 0.7
            elif any(word in response_lower for word in ["xin chào", "hello", "chào"]):
                intent = "greeting"
                confidence = 0.8
            elif any(word in response_lower for word in ["hủy", "cancel", "thủ tục"]):
                intent = "procedure_guide"
                confidence = 0.7
            elif any(
                word in response_lower for word in ["nạp", "cước", "dịch vụ", "faq"]
            ):
                intent = "faq"
                confidence = 0.7
            else:
                intent = "unclear"
                confidence = 0.3

            print(f"🔄 Fallback classification: {intent} (confidence: {confidence})")
            explanation = "Fallback classification due to JSON parse error"

        return {
            **state,
            "intent": intent,
            "confidence": confidence,
            "conversation_context": conversation_context,  # Cache context for reuse
            "processing_info": {
                "classification_raw": response,
                "explanation": explanation,
                "context_used": bool(conversation_context.strip()),
            },
        }

    except Exception as e:
        print(f"❌ Intent classification error: {e}")
        return {**state, "intent": "unclear", "confidence": 0.0, "error": str(e)}


def query_enhancement_node(state: ViettelPayState, llm_client) -> ViettelPayState:
    """Node for enhancing the search query using conversation context"""

    # Get the latest user message
    messages = state["messages"]
    if not messages:
        return {**state, "enhanced_query": "", "error": "No messages found"}

    # Find the last human/user message
    user_message = None
    for msg in reversed(messages):
        if hasattr(msg, "type") and msg.type == "human":
            user_message = msg.content
            break
        elif hasattr(msg, "role") and msg.role == "user":
            user_message = msg.content
            break

    if not user_message:
        return {**state, "enhanced_query": "", "error": "No user message found"}

    try:
        # Reuse the cached conversation context if available
        conversation_context = state.get("conversation_context")
        if conversation_context is None:
            conversation_context = get_conversation_context(messages)

        # Without context there is nothing to enhance; keep the original query
        if not conversation_context.strip():
            print(f"🔍 No context available, using original query: {user_message}")
            return {**state, "enhanced_query": user_message}

        # Build the enhancement prompt from the shared prompts module
        enhancement_prompt = QUERY_ENHANCEMENT_PROMPT.format(
            conversation_context=conversation_context, user_message=user_message
        )

        # Get the enhanced query
        enhanced_query = llm_client.generate(enhancement_prompt, temperature=0.1).strip()

        print(f"🔍 Original query: {user_message}")
        print(f"🚀 Enhanced query: {enhanced_query}")

        return {**state, "enhanced_query": enhanced_query}

    except Exception as e:
        print(f"❌ Query enhancement error: {e}")
        # Fall back to the original message
        return {**state, "enhanced_query": user_message, "error": str(e)}


def knowledge_retrieval_node(
    state: ViettelPayState, knowledge_retriever
) -> ViettelPayState:
    """Node for knowledge retrieval using the pre-initialized ViettelKnowledgeBase"""

    # Use the enhanced query if available, otherwise fall back to the messages
    enhanced_query = state.get("enhanced_query", "")

    if not enhanced_query:
        # Fallback: extract the last human/user message
        messages = state["messages"]
        if not messages:
            return {**state, "retrieved_docs": [], "error": "No messages found"}

        for msg in reversed(messages):
            if hasattr(msg, "type") and msg.type == "human":
                enhanced_query = msg.content
                break
            elif hasattr(msg, "role") and msg.role == "user":
                enhanced_query = msg.content
                break

    if not enhanced_query:
        return {**state, "retrieved_docs": [], "error": "No query available"}

    try:
        if not knowledge_retriever:
            raise ValueError("Knowledge retriever not available")

        # Retrieve relevant documents for the (possibly enhanced) query
        retrieved_docs = knowledge_retriever.search(enhanced_query, top_k=10)
        print(
            f"📚 Retrieved {len(retrieved_docs)} documents for enhanced query: {enhanced_query}"
        )

        return {**state, "retrieved_docs": retrieved_docs}

    except Exception as e:
        print(f"❌ Knowledge retrieval error: {e}")
        return {**state, "retrieved_docs": [], "error": str(e)}


def script_response_node(state: ViettelPayState) -> ViettelPayState:
    """Node for script-based responses"""
    from src.agent.scripts import ConversationScripts
    from langchain_core.messages import AIMessage

    intent = state.get("intent", "")

    try:
        # Load the scripted responses
        scripts = ConversationScripts("./viettelpay_docs/processed/kich_ban.csv")

        # Map intents to script types
        intent_to_script = {
            "greeting": "greeting",
            "out_of_scope": "out_of_scope",
            "human_request": "human_request_attempt_1",  # Could be enhanced later
            "unclear": "ask_for_clarity",
        }

        script_type = intent_to_script.get(intent)

        if script_type and scripts.has_script(script_type):
            response_text = scripts.get_script(script_type)
            print(f"📋 Using script response: {script_type}")

            # Append the AI message to the conversation
            ai_message = AIMessage(content=response_text)
            return {**state, "messages": [ai_message], "response_type": "script"}
        else:
            # Fallback script: "Sorry, I don't quite understand your request.
            # Please try again."
            fallback_response = (
                "Xin lỗi, em chưa hiểu rõ yêu cầu của anh/chị. Vui lòng thử lại."
            )
            ai_message = AIMessage(content=fallback_response)
            print(f"📋 Using fallback script for intent: {intent}")
            return {**state, "messages": [ai_message], "response_type": "script"}

    except Exception as e:
        print(f"❌ Script response error: {e}")
        # "Sorry, I ran into a technical error. Please try again later."
        fallback_response = "Xin lỗi, em gặp lỗi kỹ thuật. Vui lòng thử lại sau."
        ai_message = AIMessage(content=fallback_response)
        return {
            **state,
            "messages": [ai_message],
            "response_type": "error",
            "error": str(e),
        }


def generate_response_node(state: ViettelPayState, llm_client) -> ViettelPayState:
    """Node for LLM-based response generation with conversation context"""
    from langchain_core.messages import AIMessage

    # Get the latest user message and conversation history
    messages = state["messages"]
    if not messages:
        ai_message = AIMessage(content="Xin lỗi, em không thể xử lý yêu cầu này.")
        return {**state, "messages": [ai_message], "response_type": "error"}

    # Find the last human/user message
    user_message = None
    for msg in reversed(messages):
        if hasattr(msg, "type") and msg.type == "human":
            user_message = msg.content
            break
        elif hasattr(msg, "role") and msg.role == "user":
            user_message = msg.content
            break

    if not user_message:
        # "Sorry, I cannot process this request."
        ai_message = AIMessage(content="Xin lỗi, em không thể xử lý yêu cầu này.")
        return {**state, "messages": [ai_message], "response_type": "error"}

    intent = state.get("intent", "")
    retrieved_docs = state.get("retrieved_docs", [])
    enhanced_query = state.get("enhanced_query", "")

    try:
        # Build context from the retrieved documents, preferring the full
        # original content stored in metadata when available
        context = ""
        if retrieved_docs:
            context = "\n\n".join(
                f"[{doc.metadata.get('doc_type', 'unknown')}] "
                f"{doc.metadata.get('original_content', doc.page_content)}"
                for doc in retrieved_docs
            )

        # Reuse the cached conversation context if available
        conversation_context = state.get("conversation_context")
        if conversation_context is None:
            conversation_context = get_conversation_context(messages, max_messages=6)

        # Select the system prompt for this intent
        system_prompt = get_system_prompt_by_intent(intent)

        # Combine the knowledge context and conversation context into the final prompt
        generation_prompt = RESPONSE_GENERATION_PROMPT.format(
            system_prompt=system_prompt,
            context=context,
            conversation_context=conversation_context,
            user_message=user_message,
            enhanced_query=enhanced_query,
        )

        # Generate the response using the pre-initialized LLM client
        response_text = llm_client.generate(generation_prompt, temperature=0.1)
        print(f"🤖 Generated response for intent: {intent}")

        # Append the AI message to the conversation
        ai_message = AIMessage(content=response_text)
        return {**state, "messages": [ai_message], "response_type": "generated"}

    except Exception as e:
        print(f"❌ Response generation error: {e}")
        # "Sorry, I ran into an error while processing the request. Please try
        # again later."
        error_response = "Xin lỗi, em gặp lỗi khi xử lý yêu cầu. Vui lòng thử lại sau."
        ai_message = AIMessage(content=error_response)
        return {
            **state,
            "messages": [ai_message],
            "response_type": "error",
            "error": str(e),
        }
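
# Illustrative knowledge context string as built above (hypothetical
# documents; the "[doc_type] content" shape follows the join in this node):
#
#     [error_code] Lỗi 606: Giao dịch không thành công...
#
#     [faq] Hướng dẫn nạp cước ViettelPay: ...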


# Routing functions for conditional edges


def route_after_intent_classification(state: ViettelPayState) -> str:
    """Route to the appropriate node after intent classification"""
    intent = state.get("intent", "unclear")

    # Script-based intents (no knowledge retrieval needed)
    script_intents = {"greeting", "out_of_scope", "human_request", "unclear"}

    if intent in script_intents:
        return "script_response"
    else:
        # Knowledge-based intents need query enhancement first
        return "query_enhancement"


def route_after_query_enhancement(state: ViettelPayState) -> str:
    """Route after query enhancement (always to knowledge retrieval)"""
    return "knowledge_retrieval"


def route_after_knowledge_retrieval(state: ViettelPayState) -> str:
    """Route after knowledge retrieval (always to response generation)"""
    return "generate_response"