Spaces:
Sleeping
Sleeping
fix: prompt and loosen intent_classifier
Browse files
- modules/info_extractor.py +41 -39
- modules/travel_assistant.py +10 -2
modules/info_extractor.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import json
|
| 2 |
-
import re
|
| 3 |
from utils.logger import log
|
| 4 |
from .ai_model import AIModel
|
| 5 |
|
|
@@ -9,18 +9,17 @@ class InfoExtractor:
|
|
| 9 |
self.prompt_template = self._build_prompt_template()
|
| 10 |
|
| 11 |
def _build_prompt_template(self) -> str:
|
| 12 |
-
# ---
|
| 13 |
-
return """
|
| 14 |
-
|
| 15 |
|
| 16 |
-
|
| 17 |
-
1.
|
| 18 |
-
2.
|
| 19 |
-
3.
|
| 20 |
-
4.
|
| 21 |
-
5. **处理无关输入**: 如果用户输入是简单的问候或与旅行无关,请返回一个空的JSON对象 `{{}}`。
|
| 22 |
|
| 23 |
-
|
| 24 |
```json
|
| 25 |
{{
|
| 26 |
"destination": {{
|
|
@@ -38,19 +37,18 @@ class InfoExtractor:
|
|
| 38 |
```
|
| 39 |
|
| 40 |
**示例:**
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
你的输出:
|
| 44 |
```json
|
| 45 |
{{
|
| 46 |
"destination": {{
|
| 47 |
-
"name": "
|
| 48 |
}},
|
| 49 |
"duration": {{
|
| 50 |
-
"days":
|
| 51 |
}},
|
| 52 |
"budget": {{
|
| 53 |
-
"type":
|
| 54 |
"amount": null,
|
| 55 |
"currency": null
|
| 56 |
}}
|
|
@@ -58,41 +56,27 @@ class InfoExtractor:
|
|
| 58 |
```
|
| 59 |
|
| 60 |
---
|
| 61 |
-
|
| 62 |
-
现在,请处理以下用户输入。
|
| 63 |
-
|
| 64 |
**用户输入:**
|
| 65 |
-
|
| 66 |
-
{user_message}
|
| 67 |
-
```
|
| 68 |
|
| 69 |
-
|
| 70 |
"""
|
| 71 |
|
| 72 |
def extract(self, message: str) -> dict:
|
| 73 |
-
"""
|
| 74 |
-
使用LLM从用户消息中提取结构化信息。
|
| 75 |
-
"""
|
| 76 |
log.info(f"🧠 使用LLM开始提取信息,消息: '{message}'")
|
| 77 |
-
|
| 78 |
-
# 1. 构建完整的Prompt
|
| 79 |
prompt = self.prompt_template.format(user_message=message)
|
| 80 |
-
|
| 81 |
-
# 2. 调用AI模型生成结果
|
| 82 |
raw_response = self.ai_model.generate(prompt)
|
| 83 |
|
| 84 |
if not raw_response:
|
| 85 |
log.error("❌ LLM模型没有返回任何内容。")
|
| 86 |
return {}
|
| 87 |
|
| 88 |
-
|
| 89 |
try:
|
| 90 |
-
# 优先使用正则表达式从 ```json ... ``` 代码块中提取
|
| 91 |
match = re.search(r'```json\s*(\{.*?\})\s*```', raw_response, re.DOTALL)
|
| 92 |
if match:
|
| 93 |
json_str = match.group(1)
|
| 94 |
else:
|
| 95 |
-
# 如果正则没匹配到,就粗暴地寻找第一个'{'和最后一个'}'
|
| 96 |
start_index = raw_response.find('{')
|
| 97 |
end_index = raw_response.rfind('}')
|
| 98 |
if start_index != -1 and end_index != -1 and end_index > start_index:
|
|
@@ -106,10 +90,28 @@ class InfoExtractor:
|
|
| 106 |
log.error(f"❌ 无法解析LLM返回的JSON: '{raw_response}'. 错误: {e}")
|
| 107 |
return {}
|
| 108 |
|
| 109 |
-
#
|
| 110 |
-
final_info = {
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
log.info(f"📊 LLM
|
| 115 |
return final_info
|
|
|
|
| 1 |
import json
|
| 2 |
+
import re
|
| 3 |
from utils.logger import log
|
| 4 |
from .ai_model import AIModel
|
| 5 |
|
|
|
|
| 9 |
self.prompt_template = self._build_prompt_template()
|
| 10 |
|
| 11 |
def _build_prompt_template(self) -> str:
|
| 12 |
+
# --- 重点更新:使用更严格的指令和结构化示例 ---
|
| 13 |
+
return """你的任务是且仅是作为文本解析器。
|
| 14 |
+
严格分析用户输入,并以一个纯净、无注释的JSON对象格式返回。
|
| 15 |
|
| 16 |
+
**核心规则:**
|
| 17 |
+
1. **绝对禁止** 在JSON之外添加任何文本、注释、解释或Markdown标记。你的输出必须从 `{` 开始,到 `}` 结束。
|
| 18 |
+
2. **必须严格遵守** 下方定义的嵌套JSON结构。不要创造新的键,也不要改变层级。
|
| 19 |
+
3. 如果信息未提供,对应的键值必须为 `null`,而不是省略该键。
|
| 20 |
+
4. 如果用户输入与旅行无关(如 "你好"),必须返回一个空的JSON对象: `{{}}`。
|
|
|
|
| 21 |
|
| 22 |
+
**强制JSON输出结构:**
|
| 23 |
```json
|
| 24 |
{{
|
| 25 |
"destination": {{
|
|
|
|
| 37 |
```
|
| 38 |
|
| 39 |
**示例:**
|
| 40 |
+
- 用户输入: "我想去柏林玩3天"
|
| 41 |
+
- 你的输出:
|
|
|
|
| 42 |
```json
|
| 43 |
{{
|
| 44 |
"destination": {{
|
| 45 |
+
"name": "柏林"
|
| 46 |
}},
|
| 47 |
"duration": {{
|
| 48 |
+
"days": 3
|
| 49 |
}},
|
| 50 |
"budget": {{
|
| 51 |
+
"type": null,
|
| 52 |
"amount": null,
|
| 53 |
"currency": null
|
| 54 |
}}
|
|
|
|
| 56 |
```
|
| 57 |
|
| 58 |
---
|
|
|
|
|
|
|
|
|
|
| 59 |
**用户输入:**
|
| 60 |
+
`{user_message}`
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
**你的输出 (必须是纯JSON):**
|
| 63 |
"""
|
| 64 |
|
| 65 |
def extract(self, message: str) -> dict:
|
|
|
|
|
|
|
|
|
|
| 66 |
log.info(f"🧠 使用LLM开始提取信息,消息: '{message}'")
|
|
|
|
|
|
|
| 67 |
prompt = self.prompt_template.format(user_message=message)
|
|
|
|
|
|
|
| 68 |
raw_response = self.ai_model.generate(prompt)
|
| 69 |
|
| 70 |
if not raw_response:
|
| 71 |
log.error("❌ LLM模型没有返回任何内容。")
|
| 72 |
return {}
|
| 73 |
|
| 74 |
+
json_str = ""
|
| 75 |
try:
|
|
|
|
| 76 |
match = re.search(r'```json\s*(\{.*?\})\s*```', raw_response, re.DOTALL)
|
| 77 |
if match:
|
| 78 |
json_str = match.group(1)
|
| 79 |
else:
|
|
|
|
| 80 |
start_index = raw_response.find('{')
|
| 81 |
end_index = raw_response.rfind('}')
|
| 82 |
if start_index != -1 and end_index != -1 and end_index > start_index:
|
|
|
|
| 90 |
log.error(f"❌ 无法解析LLM返回的JSON: '{raw_response}'. 错误: {e}")
|
| 91 |
return {}
|
| 92 |
|
| 93 |
+
# --- 重点更新:使用更健壮、更安全的逻辑来清理数据 ---
|
| 94 |
+
final_info = {}
|
| 95 |
+
|
| 96 |
+
# 安全地处理 'destination'
|
| 97 |
+
destination_data = extracted_data.get("destination")
|
| 98 |
+
if isinstance(destination_data, dict) and destination_data.get("name"):
|
| 99 |
+
final_info["destination"] = {"name": destination_data["name"]}
|
| 100 |
+
|
| 101 |
+
# 安全地处理 'duration'
|
| 102 |
+
duration_data = extracted_data.get("duration")
|
| 103 |
+
if isinstance(duration_data, dict) and duration_data.get("days"):
|
| 104 |
+
try:
|
| 105 |
+
final_info["duration"] = {"days": int(duration_data["days"])}
|
| 106 |
+
except (ValueError, TypeError):
|
| 107 |
+
log.warning(f"⚠️ 无法将duration days '{duration_data.get('days')}' 转换为整数。")
|
| 108 |
+
|
| 109 |
+
# 安全地处理 'budget'
|
| 110 |
+
budget_data = extracted_data.get("budget")
|
| 111 |
+
if isinstance(budget_data, dict):
|
| 112 |
+
# 只要budget对象里有任何非null的值,就把它加进来
|
| 113 |
+
if any(v is not None for v in budget_data.values()):
|
| 114 |
+
final_info["budget"] = budget_data
|
| 115 |
|
| 116 |
+
log.info(f"📊 LLM最终提取结果 (安全处理后): {list(final_info.keys())}")
|
| 117 |
return final_info
|
modules/travel_assistant.py
CHANGED
|
@@ -46,10 +46,18 @@ class TravelAssistant:
|
|
| 46 |
log.info(f"✅ 设置persona: {persona_info['name']}")
|
| 47 |
|
| 48 |
# 3. 意图识别 (前置守卫)
|
| 49 |
-
|
| 50 |
-
log.info(f"🔍 用户意图识别结果: '{
|
| 51 |
|
| 52 |
extracted_info = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
# 4.: 根据意图进行逻辑分流
|
| 55 |
if intent == 'PROVIDING_TRAVEL_INFO':
|
|
|
|
| 46 |
log.info(f"✅ 设置persona: {persona_info['name']}")
|
| 47 |
|
| 48 |
# 3. 意图识别 (前置守卫)
|
| 49 |
+
raw_intent = self.intent_classifier.classify(message)
|
| 50 |
+
log.info(f"🔍 用户意图识别结果: '{raw_intent}'")
|
| 51 |
|
| 52 |
extracted_info = {}
|
| 53 |
+
intent = 'OTHER'
|
| 54 |
+
|
| 55 |
+
if 'PROVIDING_TRAVEL_INFO' in raw_intent:
|
| 56 |
+
intent = 'PROVIDING_TRAVEL_INFO'
|
| 57 |
+
elif 'GREETING' in raw_intent:
|
| 58 |
+
intent = 'GREETING'
|
| 59 |
+
|
| 60 |
+
log.info(f"✅ 解析后用户意图: '{intent}'")
|
| 61 |
|
| 62 |
# 4.: 根据意图进行逻辑分流
|
| 63 |
if intent == 'PROVIDING_TRAVEL_INFO':
|