Eliot0110 commited on
Commit
af60cba
·
1 Parent(s): 75b799c

improve: 优化模型调用并对各组件升级

Browse files
app.py CHANGED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py — minimal multimodal demo: ask Gemma 3n to describe an image.
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
from PIL import Image
import requests
import torch

model_id = "google/gemma-3n-e2b-it"

# FIX: `from_pretrained` has no `device` keyword — passing device="cuda" raises a
# TypeError (unexpected keyword argument). The supported kwarg is `device_map`
# (same as modules/ai_model.py). `.eval()` disables dropout for inference.
model = Gemma3nForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
).eval()

processor = AutoProcessor.from_pretrained(model_id)

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "Describe this image in detail."}
        ]
    }
]

# Render the chat template (text + image) straight into model-ready tensors.
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]  # keep only the newly generated tokens

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)

# **Overall Impression:** The image is a close-up shot of a vibrant garden scene,
# focusing on a cluster of pink cosmos flowers and a busy bumblebee.
# It has a slightly soft, natural feel, likely captured in daylight.
modules/__init__.py CHANGED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Public exports for the travel-assistant module package."""
from .config_loader import ConfigLoader
from .ai_model import AIModel
from .knowledge_base import KnowledgeBase
from .info_extractor import InfoExtractor
from .session_manager import SessionManager
from .response_generator import ResponseGenerator
from .travel_assistant import TravelAssistant

# Explicit public API of the package.
__all__ = [
    'ConfigLoader',
    'AIModel',
    'KnowledgeBase',
    'InfoExtractor',
    'SessionManager',
    'ResponseGenerator',
    'TravelAssistant',
]

# Package version string.
__version__ = '1.0.0'


def create_travel_assistant():
    """Convenience factory: build a fully configured TravelAssistant instance."""
    return TravelAssistant()
modules/ai_model.py CHANGED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modules/ai_model.py
2
+ import torch
3
+ import base64
4
+ import requests
5
+ from io import BytesIO
6
+ from PIL import Image
7
+ from transformers import AutoProcessor, Gemma3nForConditionalGeneration
8
+ from utils.logger import log
9
+ from typing import Union, Tuple
10
+
11
class AIModel:
    """Thin wrapper around the Gemma 3n multimodal model (text + image input).

    On construction it eagerly loads the model and processor; when loading
    fails both stay ``None`` and ``generate`` degrades to a friendly error.
    """

    def __init__(self, model_name: str = "google/gemma-3n-e2b-it"):
        self.model_name = model_name
        self.model = None       # set by _initialize_model; stays None on failure
        self.processor = None   # ditto
        self._initialize_model()

    def _initialize_model(self):
        """Load the Gemma model + processor; on any failure leave both as None."""
        try:
            log.info(f"正在加载模型: {self.model_name}")

            self.model = Gemma3nForConditionalGeneration.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            ).eval()

            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            log.info("✅ Gemma AI 模型初始化成功")

        except Exception as e:
            log.error(f"❌ Gemma AI 模型初始化失败: {e}", exc_info=True)
            self.model = None
            self.processor = None

    def is_available(self) -> bool:
        """Return True only when both model and processor loaded successfully."""
        return self.model is not None and self.processor is not None

    def detect_input_type(self, input_data: str) -> str:
        """Heuristically classify the raw input as "image", "audio" or "text".

        Detection relies purely on data-URI prefixes and file extensions of
        URLs/paths; anything unrecognized falls through to "text".
        """
        image_exts = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")
        audio_exts = (".wav", ".mp3", ".m4a", ".ogg")

        if isinstance(input_data, str):
            if input_data.startswith("data:image/"):
                return "image"
            # FIX: compare extensions case-insensitively for local paths too —
            # the original lower-cased only URLs, so "photo.JPG" became "text".
            lowered = input_data.lower()
            if lowered.endswith(image_exts):
                return "image"
            if lowered.endswith(audio_exts):
                return "audio"

        return "text"

    def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Union[Image.Image, None], str]:
        """Normalize raw input into ``(input_type, media_object, prompt_text)``.

        FIX: the original annotated a 2-tuple and — on image-load failure —
        actually *returned* a 2-tuple, while the caller (``generate``) always
        unpacks three values, raising ValueError at runtime. Every path now
        returns exactly three elements.
        """
        formatted_data = None
        processed_text = raw_input

        if input_type == "image":
            try:
                if raw_input.startswith("data:image/"):
                    # base64 data URI: strip the "data:image/...;base64," header
                    header, encoded = raw_input.split(",", 1)
                    image = Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")
                elif raw_input.startswith(("http://", "https://")):
                    # remote image URL
                    response = requests.get(raw_input, timeout=10)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content)).convert("RGB")
                else:
                    # local file path
                    image = Image.open(raw_input).convert("RGB")

                formatted_data = image
                processed_text = "请描述这张图片,并基于图片内容提供旅游建议。"
                log.info("✅ 图片加载成功")

            except Exception as e:
                log.error(f"❌ 图片加载失败: {e}")
                # FIX: was `return "text", f"..."` (two items) — callers unpack three.
                return "text", None, f"图片加载失败,请检查图片路径或URL。原始输入: {raw_input}"

        elif input_type == "audio":
            # Audio is not implemented yet; degrade gracefully to a text notice.
            log.warning("⚠️ 音频处理功能暂未实现")
            processed_text = "抱歉,音频输入功能正在开发中。请使用文字描述您的需求。"

        # "text" needs no transformation: defaults already hold (None, raw_input).
        return input_type, formatted_data, processed_text

    def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str) -> str:
        """Run one generation pass; return decoded text or a friendly error string."""
        try:
            if input_type == "image" and isinstance(formatted_input, Image.Image):
                # NOTE(review): assumes the processor's tokenizer exposes
                # `image_token` — confirm against the installed transformers version.
                image_token = self.processor.tokenizer.image_token
                if image_token not in prompt:
                    prompt = f"{image_token}\n{prompt}"

                inputs = self.processor(
                    text=prompt,
                    images=formatted_input,
                    return_tensors="pt"
                ).to(self.model.device, dtype=torch.bfloat16)
            else:
                # plain-text path
                inputs = self.processor(
                    text=prompt,
                    return_tensors="pt"
                ).to(self.model.device, dtype=torch.bfloat16)

            with torch.inference_mode():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.processor.tokenizer.eos_token_id
                )

            # FIX: slice off the prompt tokens and decode only the generated tail
            # (same pattern as app.py) instead of decoding the full sequence and
            # string-replacing the prompt out — that replace silently fails when
            # the tokenizer round-trip does not reproduce the prompt byte-for-byte.
            input_len = inputs["input_ids"].shape[-1]
            decoded = self.processor.tokenizer.decode(
                outputs[0][input_len:], skip_special_tokens=True
            ).strip()

            return decoded

        except Exception as e:
            log.error(f"❌ 模型推理失败: {e}", exc_info=True)
            return "抱歉,我在处理您的请求时遇到了技术问题,请稍后再试。"

    def generate(self, user_input: str, context: str = "") -> str:
        """Main entry point: detect modality, build the prompt, run inference.

        `context` carries optional RAG background knowledge; when present the
        prompt instructs the model to answer based on it.
        """
        if not self.is_available():
            return "抱歉,AI 模型当前不可用,请稍后再试。"

        try:
            # 1. Classify the raw input.
            input_type = self.detect_input_type(user_input)
            log.info(f"检测到输入类型: {input_type}")

            # 2. Normalize it (may downgrade image -> text on load failure).
            input_type, formatted_data, processed_text = self.format_input(input_type, user_input)

            # 3. Build the instruction prompt.
            if context:
                prompt = (
                    f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n"
                    f"--- 背景信息 ---\n{context}\n\n"
                    f"--- 用户问题 ---\n{processed_text}\n\n"
                    f"请提供专业、实用的旅游建议:"
                )
            else:
                prompt = (
                    f"你是一个专业的旅游助手。请用中文友好地回答用户的问题。\n\n"
                    f"用户问题:{processed_text}\n\n"
                    f"请提供专业、实用的旅游建议:"
                )

            # 4. Dispatch to the multimodal or text-only inference path.
            if input_type == "image" and formatted_data is not None:
                return self.run_inference("image", formatted_data, prompt)
            return self.run_inference("text", processed_text, prompt)

        except Exception as e:
            log.error(f"❌ 生成回复时发生错误: {e}", exc_info=True)
            return "抱歉,我在思考时遇到了点麻烦,请稍后再试。"
modules/config_loader.py CHANGED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modules/config_loader.py
2
+ import json
3
+ from pathlib import Path
4
+ from utils.logger import log
5
+
6
class ConfigLoader:
    """Loads city/persona/interest configuration from JSON files in *config_dir*.

    Raises on any load failure so the application fails fast at startup.
    """

    def __init__(self, config_dir: Path = Path("./config")):
        self.config_dir = config_dir
        self.cities = {}     # lower-cased name/alias -> full city record
        self.personas = {}
        self.interests = {}
        try:
            self._load_all()
            log.info("✅ 所有配置文件加载成功")
        except Exception as e:
            log.error(f"❌ 配置文件加载失败: {e}", exc_info=True)
            raise

    def _load_all(self):
        """Read cities.json, personas.json and interests.json from config_dir."""
        # Cities: index every record under its canonical name and all aliases
        # so lookups can be done with a single lower-cased key.
        with open(self.config_dir / "cities.json", 'r', encoding='utf-8') as fh:
            for record in json.load(fh)['cities']:
                for key in [record['name']] + record.get('aliases', []):
                    self.cities[key.lower()] = record

        # Personas (travel styles).
        with open(self.config_dir / "personas.json", 'r', encoding='utf-8') as fh:
            self.personas = json.load(fh)['personas']

        # Interests.
        with open(self.config_dir / "interests.json", 'r', encoding='utf-8') as fh:
            self.interests = json.load(fh)['interests']
modules/info_extractor.py CHANGED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modules/info_extractor.py
2
+ import re
3
+ from .config_loader import ConfigLoader
4
+
5
class InfoExtractor:
    """Rule-based extraction of destination, duration and persona from user text."""

    def __init__(self, config_loader: ConfigLoader):
        self.configs = config_loader

    def extract(self, user_input: str) -> dict:
        """Return a dict with any of "destination", "duration", "persona" found."""
        info = {}
        lowered = user_input.lower()

        # Destination: first configured city whose name/alias occurs in the text.
        for alias, city in self.configs.cities.items():
            if alias in lowered:
                info["destination"] = city
                break

        # Duration: a number immediately followed by "天" (days).
        days = re.search(r'(\d+)\s*天', user_input)
        if days:
            info["duration"] = {"days": int(days.group(1))}

        # Persona / travel style: match on display name or on the config key.
        for key, persona in self.configs.personas.items():
            if persona['name'] in user_input or key in user_input:
                info["persona"] = persona
                break

        return info
modules/knowledge_base.py CHANGED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modules/knowledge_base.py
2
+ import json
3
+ from pathlib import Path
4
+ from utils.logger import log
5
+
6
class KnowledgeBase:
    """In-memory travel knowledge loaded once from a JSON file at startup."""

    def __init__(self, file_path: Path = Path("./config/general_travelplan.json")):
        self.knowledge = []
        try:
            with open(file_path, 'r', encoding='utf-8') as fh:
                self.knowledge = json.load(fh).get('clean_knowledge', [])
            log.info(f"✅ 知识库加载完成")
        except Exception as e:
            log.error(f"❌ 知识库加载失败: {e}", exc_info=True)
            raise

    def search(self, query: str) -> list:
        """Return knowledge items whose primary destinations appear in *query*.

        Matching is a simple case-insensitive substring test against each
        item's configured destination names.
        """
        query_lower = query.lower()
        hits = []
        for entry in self.knowledge:
            destinations = (
                entry.get('knowledge', {})
                     .get('travel_knowledge', {})
                     .get('destination_info', {})
                     .get('primary_destinations', [])
            )
            # De-duplicate in case the same entry matches multiple destinations.
            if entry not in hits and any(d.lower() in query_lower for d in destinations):
                hits.append(entry)
        return hits
modules/response_generator.py CHANGED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modules/response_generator.py
2
+ from .ai_model import AIModel
3
+ from .knowledge_base import KnowledgeBase
4
+
5
class ResponseGenerator:
    """Chooses a reply strategy: RAG, guided slot-filling, or a generated plan."""

    def __init__(self, ai_model: AIModel, knowledge_base: KnowledgeBase):
        self.ai_model = ai_model
        self.kb = knowledge_base

    def generate(self, user_message: str, session_state: dict) -> str:
        """Produce the bot reply for one user turn given the session slots."""
        # Strategy 1: RAG — strengthen the retrieval query with the known
        # destination name, then answer grounded in the retrieved knowledge.
        search_query = user_message
        destination = session_state.get("destination")
        if destination:
            search_query += f" {destination['name']}"

        hits = self.kb.search(search_query)
        if hits:
            return self.ai_model.generate(user_message, self._format_knowledge_context(hits))

        # Strategy 2: no knowledge hit — ask for whichever slot is still empty.
        if not destination:
            return "听起来很棒!你想去欧洲的哪个城市呢?比如巴黎, 罗马, 巴塞罗那?"
        if not session_state.get("duration"):
            return f"好的,{session_state['destination']['name']}是个很棒的选择!你计划玩几天呢?"
        if not session_state.get("persona"):
            return "最后一个问题,这次旅行对你来说什么最重要呢?(例如:美食、艺术、购物、历史)"

        # Strategy 3: every slot filled but nothing retrieved — have the model
        # produce a generic plan from the collected preferences.
        plan_prompt = (
            f"请为用户生成一个在 {session_state['destination']['name']} 的 "
            f"{session_state['duration']['days']} 天旅行计划。"
            f"旅行风格侧重于: {session_state['persona']['description']}。"
        )
        return self.ai_model.generate(plan_prompt, context="用户需要一个详细的旅行计划。")

    def _format_knowledge_context(self, knowledge_items: list) -> str:
        """Render the single most relevant knowledge item as prompt context."""
        if not knowledge_items:
            return "没有特定的背景知识。"
        # Simplification: only the top hit is used.
        info = knowledge_items[0]['knowledge']['travel_knowledge']
        return (
            f"相关知识:\n- 目的地: {info['destination_info']['primary_destinations']}\n"
            f"- 推荐天数: {info['destination_info']['recommended_duration']}天\n"
            f"- 专业见解: {info['professional_insights']['key_takeaways']}\n"
        )
modules/session_manager.py CHANGED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modules/session_manager.py
2
+ import uuid
3
+ from typing import Dict, Any
4
+
5
class SessionManager:
    """In-memory per-conversation slot state, keyed by a short session id."""

    def __init__(self):
        # session_id -> mutable state dict
        self.sessions: Dict[str, Dict[str, Any]] = {}

    def get_or_create_session(self, session_id: str = None) -> Dict[str, Any]:
        """Return the existing session for *session_id*, else create a fresh one."""
        if session_id and session_id in self.sessions:
            return self.sessions[session_id]
        new_id = str(uuid.uuid4())[:8]  # short id suffices for an in-memory store
        state = {
            "session_id": new_id,
            "destination": None,
            "duration": None,
            "persona": None,
            "stage": "greeting"  # conversation state machine
        }
        self.sessions[new_id] = state
        return state

    def update_session(self, session_id: str, updates: Dict[str, Any]):
        """Merge *updates* into an existing session; unknown ids are ignored."""
        state = self.sessions.get(session_id)
        if state is not None:
            state.update(updates)

    def format_session_info(self, session_state: dict) -> str:
        """Build a one-line human-readable summary of the filled slots."""
        parts = [f"ID: {session_state.get('session_id', 'N/A')}"]
        if session_state.get('destination'):
            parts.append(f"目的地: {session_state['destination']['name']}")
        if session_state.get('duration'):
            parts.append(f"天数: {session_state['duration']['days']}")
        if session_state.get('persona'):
            parts.append(f"风格: {session_state['persona']['name']}")
        return " | ".join(parts)

    def reset(self, session_id: str):
        """Forget a session entirely; no-op when the id is unknown."""
        if session_id in self.sessions:
            del self.sessions[session_id]
modules/travel_assistant.py CHANGED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modules/travel_assistant.py
2
+ from .config_loader import ConfigLoader
3
+ from .ai_model import AIModel
4
+ from .knowledge_base import KnowledgeBase
5
+ from .info_extractor import InfoExtractor
6
+ from .session_manager import SessionManager
7
+ from .response_generator import ResponseGenerator
8
+ from utils.logger import log
9
+
10
class TravelAssistant:
    """Facade wiring config, knowledge base, model and session handling together."""

    def __init__(self):
        # Dependency injection: build every collaborator here, once.
        log.info("开始初始化 Travel Assistant 核心模块...")
        self.config = ConfigLoader()
        self.kb = KnowledgeBase()
        self.ai_model = AIModel()
        self.session_manager = SessionManager()
        self.info_extractor = InfoExtractor(self.config)
        self.response_generator = ResponseGenerator(self.ai_model, self.kb)
        log.info("✅ Travel Assistant 核心模块全部初始化完成!")

    def chat(self, message: str, session_id: str, history: list):
        """Handle one user turn.

        Returns (reply, session_id, status_line, updated_history) — the shape
        the front-end consumes.
        """
        # Resolve the session (creates a new one when the id is empty/unknown).
        state = self.session_manager.get_or_create_session(session_id)
        sid = state['session_id']

        # Pull structured slots (destination/duration/persona) from the message
        # and fold them into the session before generating the reply.
        extracted = self.info_extractor.extract(message)
        if extracted:
            self.session_manager.update_session(sid, extracted)
            # Re-read so the generator sees the merged state.
            state = self.session_manager.get_or_create_session(sid)

        # Generate the bot reply from the (possibly updated) session state.
        reply = self.response_generator.generate(message, state)

        # Status string shown in the UI.
        status = self.session_manager.format_session_info(state)

        # Append the new (user, bot) pair to the chat history.
        return reply, sid, status, history + [[message, reply]]