Eliot0110 commited on
Commit
794c23a
·
1 Parent(s): 86c5051

improve: voice-totext and classifier model

Browse files
Files changed (2) hide show
  1. modules/ai_model.py +77 -45
  2. modules/intent_classifier.py +1 -2
modules/ai_model.py CHANGED
@@ -108,6 +108,58 @@ class AIModel:
108
  return "audio"
109
 
110
  return "text"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Union[str, Image.Image, None]]:
113
 
@@ -133,9 +185,22 @@ class AIModel:
133
  return "text", None, f"图片加载失败,请检查路径或URL。"
134
 
135
  elif input_type == "audio":
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- log.warning("⚠️ 音频处理功能暂未实现")
138
- return "text", None, "抱歉,音频输入功能正在开发中。请使用文字描述您的需求。"
 
139
 
140
  else: # text
141
  return input_type, None, raw_input
@@ -143,36 +208,14 @@ class AIModel:
143
  def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str,temperature: float = 0.5) -> str:
144
 
145
  try:
146
- # 截断过长的 prompt
147
- if len(prompt) > 500:
148
- prompt = prompt[:500] + "..."
149
-
150
- # 准备输入 (处理图片或文本)
151
- if input_type == "image" and isinstance(formatted_input, Image.Image):
152
- image_token = getattr(self.processor.tokenizer, 'image_token', '<image>')
153
- if image_token not in prompt:
154
- prompt = f"{image_token}\n{prompt}"
155
- inputs = self.processor(
156
- text=prompt,
157
- images=formatted_input,
158
- return_tensors="pt"
159
- ).to(self.model.device, dtype=torch.bfloat16)
160
- else:
161
- inputs = self.processor(
162
- text=prompt,
163
- return_tensors="pt"
164
- ).to(self.model.device, dtype=torch.bfloat16)
165
-
166
- if hasattr(inputs, 'input_ids') and inputs.input_ids.shape[-1] > 512:
167
- log.warning(f"⚠️ 截断过长输入: {inputs.input_ids.shape[-1]} -> 512")
168
- inputs.input_ids = inputs.input_ids[:, :512]
169
- if hasattr(inputs, 'attention_mask'):
170
- inputs.attention_mask = inputs.attention_mask[:, :512]
171
-
172
 
173
  with torch.inference_mode():
174
  generation_args = {
175
- "max_new_tokens": 512,
176
  "pad_token_id": self.processor.tokenizer.eos_token_id,
177
  "use_cache": True
178
  }
@@ -218,7 +261,7 @@ class AIModel:
218
 
219
  full_prompt = "\n".join([msg.get("content", "") for msg in messages])
220
 
221
- temperature = kwargs.get("temperature", 0.7)
222
 
223
  if kwargs.get("response_format", {}).get("type") == "json_object":
224
  # 在 prompt 末尾添加指令,强制模型输出 JSON
@@ -236,16 +279,8 @@ class AIModel:
236
  )
237
 
238
 
239
- def _build_limited_prompt(self, processed_text: str, context: str = "") -> str:
240
- """构建长度受限的prompt - 新增辅助方法"""
241
- # 限制输入长度
242
- if len(processed_text) > 200:
243
- processed_text = processed_text[:200] + "..."
244
-
245
- if context and len(context) > 300:
246
- context = context[:300] + "..."
247
-
248
- # 保持你原有的prompt结构
249
  if context:
250
  return (
251
  f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n"
@@ -274,13 +309,10 @@ class AIModel:
274
  input_type, formatted_data, processed_text = self.format_input(input_type, user_input)
275
 
276
  # 3. 构建prompt - 使用你的原有结构
277
- prompt = self._build_limited_prompt(processed_text, context)
278
 
279
  # 4. 执行推理
280
- if input_type == "image" and formatted_data is not None:
281
- return self.run_inference("image", formatted_data, prompt)
282
- else:
283
- return self.run_inference("text", processed_text, prompt)
284
 
285
  except Exception as e:
286
  log.error(f"❌ 生成回复时发生错误: {e}", exc_info=True)
 
108
  return "audio"
109
 
110
  return "text"
111
+
112
+ def transcribe_audio(self, audio_path: str) -> str:
113
+ """
114
+ 使用 Hugging Face Inference API 将音频文件转写为文本。
115
+ - 通过环境变量加载 HF_TOKEN 保证安全。
116
+ - 包含网络请求超时和状态码检查,增强健壮性。
117
+ """
118
+ # 1. 从环境变量安全地获取 Token
119
+ hf_token = os.getenv("Assitant_tocken")
120
+
121
+ API_URL = "https://api-inference.huggingface.co/models/openai/whisper-large" # 建议使用更新的 v3 版本
122
+ headers = {"Authorization": f"Bearer {hf_token}"}
123
+
124
+ # 2. 检查音频文件是否存在
125
+ if not os.path.exists(audio_path):
126
+ log.error(f"❌ 音频文件不存在: {audio_path}")
127
+ raise FileNotFoundError(f"指定的音频文件路径不存在: {audio_path}")
128
+
129
+ try:
130
+ with open(audio_path, "rb") as f:
131
+ # 3. 发送请求,并设置较长的超时时间 (例如 60 秒)
132
+ log.info(f"🎤 正在向 HF API 发送音频数据... (超时设置为60秒)")
133
+ response = requests.post(API_URL, headers=headers, data=f, timeout=60)
134
+
135
+ # 4. 检查 HTTP 响应状态码,主动抛出错误
136
+ response.raise_for_status() # 如果状态码不是 2xx,则会引发 HTTPError
137
+
138
+ result = response.json()
139
+ log.info("✅ HF API 响应成功。")
140
+
141
+ # 5. 可靠地提取结果或处理错误信息
142
+ if "text" in result:
143
+ return result["text"].strip()
144
+ else:
145
+ error_message = result.get("error", "未知的 API 错误结构。")
146
+ log.error(f"❌ 转录失败,API 返回: {error_message}")
147
+ # 如果模型正在加载,HuggingFace 会在 error 字段中提示
148
+ if isinstance(error_message, dict) and "estimated_time" in error_message:
149
+ raise RuntimeError(f"模型正在加载中,请稍后重试。预计等待时间: {error_message['estimated_time']:.1f}秒")
150
+ raise RuntimeError(f"转录失败: {error_message}")
151
+
152
+ except requests.exceptions.Timeout:
153
+ log.error("❌ 请求超时!API 未在60秒内响应。")
154
+ raise RuntimeError("语音识别服务请求超时,请稍后再试。")
155
+ except requests.exceptions.RequestException as e:
156
+ log.error(f"❌ 网络请求失败: {e}")
157
+ raise RuntimeError(f"无法连接到语音识别服务: {e}")
158
+ except Exception as e:
159
+ # 捕获其他所有可能的异常,例如文件读取错误、JSON解码错误等
160
+ log.error(f"❌ 处理音频时发生未知错误: {e}", exc_info=True)
161
+ raise e
162
+
163
 
164
  def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Union[str, Image.Image, None]]:
165
 
 
185
  return "text", None, f"图片加载失败,请检查路径或URL。"
186
 
187
  elif input_type == "audio":
188
+ try:
189
+ # --- 音频处理核心 ---
190
+ # 假设: 您的类中有一个方法 `transcribe_audio` 用于语音转文字。
191
+ # 您需要自行实现这个方法, 例如通过调用 Whisper, FunASR 或其他 ASR 服务。
192
+ # 它接收音频文件路径 (raw_input) 并返回转写的文本字符串。
193
+ log.info(f"🎤 开始处理音频文件: {raw_input}")
194
+ transcribed_text = self.transcribe_audio(raw_input)
195
+ log.info(f"✅ 音频转写成功: '{transcribed_text[:50]}...'")
196
+
197
+ # 注意:处理成功后,我们将 input_type 转为 "text",
198
+ # 因为音频内容已变为文本,后续流程可以统一处理。
199
+ return "text", None, transcribed_text
200
 
201
+ except Exception as e:
202
+ log.error(f"❌ 音频处理失败: {e}", exc_info=True)
203
+ return "text", None, f"音频处理失败,请检查文件或稍后再试。"
204
 
205
  else: # text
206
  return input_type, None, raw_input
 
208
  def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str,temperature: float = 0.5) -> str:
209
 
210
  try:
211
+ inputs = self.processor(
212
+ text=prompt,
213
+ return_tensors="pt"
214
+ ).to(self.model.device, dtype=torch.bfloat16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  with torch.inference_mode():
217
  generation_args = {
218
+ "max_new_tokens": 1024,
219
  "pad_token_id": self.processor.tokenizer.eos_token_id,
220
  "use_cache": True
221
  }
 
261
 
262
  full_prompt = "\n".join([msg.get("content", "") for msg in messages])
263
 
264
+ temperature = kwargs.get("temperature", 0.6)
265
 
266
  if kwargs.get("response_format", {}).get("type") == "json_object":
267
  # 在 prompt 末尾添加指令,强制模型输出 JSON
 
279
  )
280
 
281
 
282
+ def _build_prompt(self, processed_text: str, context: str = "") -> str:
283
+
 
 
 
 
 
 
 
 
284
  if context:
285
  return (
286
  f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n"
 
309
  input_type, formatted_data, processed_text = self.format_input(input_type, user_input)
310
 
311
  # 3. 构建prompt - 使用你的原有结构
312
+ prompt = self._build_prompt(processed_text, context)
313
 
314
  # 4. 执行推理
315
+ return self.run_inference("text", formatted_data, prompt)
 
 
 
316
 
317
  except Exception as e:
318
  log.error(f"❌ 生成回复时发生错误: {e}", exc_info=True)
modules/intent_classifier.py CHANGED
@@ -43,7 +43,7 @@ PROVIDING_TRAVEL_INFO > INQUIRY > GREETING > OTHER
43
  - 用户输入: "你好,我想去东京玩" -> 分类: PROVIDING_TRAVEL_INFO
44
  - 用户输入: "Hi, 巴黎有什么推荐的吗?" -> 分类: INQUIRY
45
  - 用户输入: "周末愉快!" -> 分类: GREETING
46
- - 用户输入: "我们预算不多,大概3000元,目的地是成都。" -> 分类: PROVIDING_TRAVEL_INFO
47
  - 用户输入: "你好在吗" -> 分类: GREETING
48
  - 用户输入: "随便聊聊" -> 分类: OTHER
49
 
@@ -65,7 +65,6 @@ PROVIDING_TRAVEL_INFO > INQUIRY > GREETING > OTHER
65
 
66
  try:
67
  response = self.ai_model.chat_completion(
68
- model="gpt-3.5-turbo",
69
  messages=[{"role": "user", "content": prompt}],
70
  temperature=0.0,
71
  max_tokens=10
 
43
  - 用户输入: "你好,我想去东京玩" -> 分类: PROVIDING_TRAVEL_INFO
44
  - 用户输入: "Hi, 巴黎有什么推荐的吗?" -> 分类: INQUIRY
45
  - 用户输入: "周末愉快!" -> 分类: GREETING
46
+ - 用户输入: "我们预算不多,大概3000元,目的地是柏林。" -> 分类: PROVIDING_TRAVEL_INFO
47
  - 用户输入: "你好在吗" -> 分类: GREETING
48
  - 用户输入: "随便聊聊" -> 分类: OTHER
49
 
 
65
 
66
  try:
67
  response = self.ai_model.chat_completion(
 
68
  messages=[{"role": "user", "content": prompt}],
69
  temperature=0.0,
70
  max_tokens=10