Eliot0110 commited on
Commit
632df2f
·
1 Parent(s): 2b20519

fix : recover file

Browse files
Files changed (1) hide show
  1. modules/info_extractor.py +119 -0
modules/info_extractor.py CHANGED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from .config_loader import ConfigLoader
3
+ from utils.logger import log
4
+
5
+ class InfoExtractor:
6
+ def __init__(self, config_loader: ConfigLoader):
7
+ self.configs = config_loader
8
+
9
+ def extract(self, user_input: str) -> dict:
10
+ """从用户输入中提取目的地、天数和旅行风格"""
11
+ extracted_info = {}
12
+ user_lower = user_input.lower()
13
+
14
+ try:
15
+ # 提取目的地
16
+ destination = self._extract_destination(user_lower)
17
+ if destination:
18
+ extracted_info["destination"] = destination
19
+
20
+ # 提取天数
21
+ duration = self._extract_duration(user_input)
22
+ if duration:
23
+ extracted_info["duration"] = duration
24
+
25
+ # 提取旅行风格 (persona)
26
+ persona = self._extract_persona(user_input)
27
+ if persona:
28
+ extracted_info["persona"] = persona
29
+
30
+ return extracted_info
31
+
32
+ except Exception as e:
33
+ log.error(f"❌ 信息提取失败: {e}", exc_info=True)
34
+ return {}
35
+
36
+ def _extract_destination(self, user_lower: str) -> dict:
37
+ """提取目的地信息"""
38
+ for alias, city_info in self.configs.cities.items():
39
+ if alias in user_lower:
40
+ return city_info
41
+ return None
42
+
43
+ def _extract_duration(self, user_input: str) -> dict:
44
+ """提取天数信息 - 改进版本"""
45
+ patterns = [
46
+ (r'(\d+)\s*天', lambda x: int(x)),
47
+ (r'(\d+)\s*日', lambda x: int(x)),
48
+ (r'一周|7天|七天', lambda x: 7),
49
+ (r'两周|14天|十四天', lambda x: 14),
50
+ (r'周末', lambda x: 2),
51
+ (r'三天|3天', lambda x: 3),
52
+ (r'五天|5天', lambda x: 5),
53
+ (r'十天|10天', lambda x: 10),
54
+ ]
55
+
56
+ for pattern, converter in patterns:
57
+ match = re.search(pattern, user_input)
58
+ if match:
59
+ try:
60
+ if match.groups():
61
+ days = converter(match.group(1))
62
+ else:
63
+ days = converter(None)
64
+
65
+ if 1 <= days <= 30:
66
+ return {"days": days, "description": f"{days}天旅行"}
67
+ except (ValueError, TypeError):
68
+ continue
69
+ return None
70
+
71
+ def _extract_persona(self, user_input: str) -> dict:
72
+ """提取用户偏好 - 根据实际personas.json结构"""
73
+ user_lower = user_input.lower()
74
+
75
+ # 检查personas配置
76
+ for persona_key, persona_info in self.configs.personas.items():
77
+ # 检查persona名称中的关键词
78
+ persona_name = persona_info.get('name', '').lower()
79
+ # 移除emoji并检查关键词
80
+ clean_name = ''.join(char for char in persona_name if char.isalpha() or char.isspace())
81
+
82
+ if any(keyword in user_lower for keyword in clean_name.split() if len(keyword) > 1):
83
+ return {**persona_info, "key": persona_key} # 添加key方便后续使用
84
+
85
+ # 检查风格关键词
86
+ style = persona_info.get('style', '').lower()
87
+ if any(keyword in user_lower for keyword in style.split('、') if len(keyword) > 1):
88
+ return {**persona_info, "key": persona_key}
89
+
90
+ # 检查特征描述
91
+ characteristics = persona_info.get('characteristics', [])
92
+ for char in characteristics:
93
+ char_keywords = char.lower().split()[:3] # 取前3个词作为关键词
94
+ if any(keyword in user_lower for keyword in char_keywords if len(keyword) > 1):
95
+ return {**persona_info, "key": persona_key}
96
+
97
+ # 基于关键词的简单映射
98
+ keyword_persona_map = {
99
+ '规划': 'planner',
100
+ '计划': 'planner',
101
+ '效率': 'planner',
102
+ '预算': 'planner',
103
+ '分享': 'social',
104
+ '朋友': 'social',
105
+ '拍照': 'social',
106
+ '打卡': 'social',
107
+ '深度': 'experiential',
108
+ '文化': 'experiential',
109
+ '地道': 'experiential',
110
+ '当地': 'experiential'
111
+ }
112
+
113
+ for keyword, persona_key in keyword_persona_map.items():
114
+ if keyword in user_input:
115
+ persona_info = self.configs.personas.get(persona_key, {})
116
+ if persona_info:
117
+ return {**persona_info, "key": persona_key}
118
+
119
+ return None