Spaces:
Sleeping
Sleeping
improve: voice-totext and classifier model
Browse files- modules/ai_model.py +77 -45
- modules/intent_classifier.py +1 -2
modules/ai_model.py
CHANGED
|
@@ -108,6 +108,58 @@ class AIModel:
|
|
| 108 |
return "audio"
|
| 109 |
|
| 110 |
return "text"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Union[str, Image.Image, None]]:
|
| 113 |
|
|
@@ -133,9 +185,22 @@ class AIModel:
|
|
| 133 |
return "text", None, f"图片加载失败,请检查路径或URL。"
|
| 134 |
|
| 135 |
elif input_type == "audio":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
|
|
|
| 139 |
|
| 140 |
else: # text
|
| 141 |
return input_type, None, raw_input
|
|
@@ -143,36 +208,14 @@ class AIModel:
|
|
| 143 |
def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str,temperature: float = 0.5) -> str:
|
| 144 |
|
| 145 |
try:
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
# 准备输入 (处理图片或文本)
|
| 151 |
-
if input_type == "image" and isinstance(formatted_input, Image.Image):
|
| 152 |
-
image_token = getattr(self.processor.tokenizer, 'image_token', '<image>')
|
| 153 |
-
if image_token not in prompt:
|
| 154 |
-
prompt = f"{image_token}\n{prompt}"
|
| 155 |
-
inputs = self.processor(
|
| 156 |
-
text=prompt,
|
| 157 |
-
images=formatted_input,
|
| 158 |
-
return_tensors="pt"
|
| 159 |
-
).to(self.model.device, dtype=torch.bfloat16)
|
| 160 |
-
else:
|
| 161 |
-
inputs = self.processor(
|
| 162 |
-
text=prompt,
|
| 163 |
-
return_tensors="pt"
|
| 164 |
-
).to(self.model.device, dtype=torch.bfloat16)
|
| 165 |
-
|
| 166 |
-
if hasattr(inputs, 'input_ids') and inputs.input_ids.shape[-1] > 512:
|
| 167 |
-
log.warning(f"⚠️ 截断过长输入: {inputs.input_ids.shape[-1]} -> 512")
|
| 168 |
-
inputs.input_ids = inputs.input_ids[:, :512]
|
| 169 |
-
if hasattr(inputs, 'attention_mask'):
|
| 170 |
-
inputs.attention_mask = inputs.attention_mask[:, :512]
|
| 171 |
-
|
| 172 |
|
| 173 |
with torch.inference_mode():
|
| 174 |
generation_args = {
|
| 175 |
-
"max_new_tokens":
|
| 176 |
"pad_token_id": self.processor.tokenizer.eos_token_id,
|
| 177 |
"use_cache": True
|
| 178 |
}
|
|
@@ -218,7 +261,7 @@ class AIModel:
|
|
| 218 |
|
| 219 |
full_prompt = "\n".join([msg.get("content", "") for msg in messages])
|
| 220 |
|
| 221 |
-
temperature = kwargs.get("temperature", 0.
|
| 222 |
|
| 223 |
if kwargs.get("response_format", {}).get("type") == "json_object":
|
| 224 |
# 在 prompt 末尾添加指令,强制模型输出 JSON
|
|
@@ -236,16 +279,8 @@ class AIModel:
|
|
| 236 |
)
|
| 237 |
|
| 238 |
|
| 239 |
-
def
|
| 240 |
-
|
| 241 |
-
# 限制输入长度
|
| 242 |
-
if len(processed_text) > 200:
|
| 243 |
-
processed_text = processed_text[:200] + "..."
|
| 244 |
-
|
| 245 |
-
if context and len(context) > 300:
|
| 246 |
-
context = context[:300] + "..."
|
| 247 |
-
|
| 248 |
-
# 保持你原有的prompt结构
|
| 249 |
if context:
|
| 250 |
return (
|
| 251 |
f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n"
|
|
@@ -274,13 +309,10 @@ class AIModel:
|
|
| 274 |
input_type, formatted_data, processed_text = self.format_input(input_type, user_input)
|
| 275 |
|
| 276 |
# 3. 构建prompt - 使用你的原有结构
|
| 277 |
-
prompt = self.
|
| 278 |
|
| 279 |
# 4. 执行推理
|
| 280 |
-
|
| 281 |
-
return self.run_inference("image", formatted_data, prompt)
|
| 282 |
-
else:
|
| 283 |
-
return self.run_inference("text", processed_text, prompt)
|
| 284 |
|
| 285 |
except Exception as e:
|
| 286 |
log.error(f"❌ 生成回复时发生错误: {e}", exc_info=True)
|
|
|
|
| 108 |
return "audio"
|
| 109 |
|
| 110 |
return "text"
|
| 111 |
+
|
| 112 |
+
def transcribe_audio(self, audio_path: str) -> str:
|
| 113 |
+
"""
|
| 114 |
+
使用 Hugging Face Inference API 将音频文件转写为文本。
|
| 115 |
+
- 通过环境变量加载 HF_TOKEN 保证安全。
|
| 116 |
+
- 包含网络请求超时和状态码检查,增强健壮性。
|
| 117 |
+
"""
|
| 118 |
+
# 1. 从环境变量安全地获取 Token
|
| 119 |
+
hf_token = os.getenv("Assitant_tocken")
|
| 120 |
+
|
| 121 |
+
API_URL = "https://api-inference.huggingface.co/models/openai/whisper-large" # 建议使用更新的 v3 版本
|
| 122 |
+
headers = {"Authorization": f"Bearer {hf_token}"}
|
| 123 |
+
|
| 124 |
+
# 2. 检查音频文件是否存在
|
| 125 |
+
if not os.path.exists(audio_path):
|
| 126 |
+
log.error(f"❌ 音频文件不存在: {audio_path}")
|
| 127 |
+
raise FileNotFoundError(f"指定的音频文件路径不存在: {audio_path}")
|
| 128 |
+
|
| 129 |
+
try:
|
| 130 |
+
with open(audio_path, "rb") as f:
|
| 131 |
+
# 3. 发送请求,并设置较长的超时时间 (例如 60 秒)
|
| 132 |
+
log.info(f"🎤 正在向 HF API 发送音频数据... (超时设置为60秒)")
|
| 133 |
+
response = requests.post(API_URL, headers=headers, data=f, timeout=60)
|
| 134 |
+
|
| 135 |
+
# 4. 检查 HTTP 响应状态码,主动抛出错误
|
| 136 |
+
response.raise_for_status() # 如果状态码不是 2xx,则会引发 HTTPError
|
| 137 |
+
|
| 138 |
+
result = response.json()
|
| 139 |
+
log.info("✅ HF API 响应成功。")
|
| 140 |
+
|
| 141 |
+
# 5. 可靠地提取结果或处理错误信息
|
| 142 |
+
if "text" in result:
|
| 143 |
+
return result["text"].strip()
|
| 144 |
+
else:
|
| 145 |
+
error_message = result.get("error", "未知的 API 错误结构。")
|
| 146 |
+
log.error(f"❌ 转录失败,API 返回: {error_message}")
|
| 147 |
+
# 如果模型正在加载,HuggingFace 会在 error 字段中提示
|
| 148 |
+
if isinstance(error_message, dict) and "estimated_time" in error_message:
|
| 149 |
+
raise RuntimeError(f"模型正在加载中,请稍后重试。预计等待时间: {error_message['estimated_time']:.1f}秒")
|
| 150 |
+
raise RuntimeError(f"转录失败: {error_message}")
|
| 151 |
+
|
| 152 |
+
except requests.exceptions.Timeout:
|
| 153 |
+
log.error("❌ 请求超时!API 未在60秒内响应。")
|
| 154 |
+
raise RuntimeError("语音识别服务请求超时,请稍后再试。")
|
| 155 |
+
except requests.exceptions.RequestException as e:
|
| 156 |
+
log.error(f"❌ 网络请求失败: {e}")
|
| 157 |
+
raise RuntimeError(f"无法连接到语音识别服务: {e}")
|
| 158 |
+
except Exception as e:
|
| 159 |
+
# 捕获其他所有可能的异常,例如文件读取错误、JSON解码错误等
|
| 160 |
+
log.error(f"❌ 处理音频时发生未知错误: {e}", exc_info=True)
|
| 161 |
+
raise e
|
| 162 |
+
|
| 163 |
|
| 164 |
def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Union[str, Image.Image, None]]:
|
| 165 |
|
|
|
|
| 185 |
return "text", None, f"图片加载失败,请检查路径或URL。"
|
| 186 |
|
| 187 |
elif input_type == "audio":
|
| 188 |
+
try:
|
| 189 |
+
# --- 音频处理核心 ---
|
| 190 |
+
# 假设: 您的类中有一个方法 `transcribe_audio` 用于语音转文字。
|
| 191 |
+
# 您需要自行实现这个方法, 例如通过调用 Whisper, FunASR 或其他 ASR 服务。
|
| 192 |
+
# 它接收音频文件路径 (raw_input) 并返回转写的文本字符串。
|
| 193 |
+
log.info(f"🎤 开始处理音频文件: {raw_input}")
|
| 194 |
+
transcribed_text = self.transcribe_audio(raw_input)
|
| 195 |
+
log.info(f"✅ 音频转写成功: '{transcribed_text[:50]}...'")
|
| 196 |
+
|
| 197 |
+
# 注意:处理成功后,我们将 input_type 转为 "text",
|
| 198 |
+
# 因为音频内容已变为文本,后续流程可以统一处理。
|
| 199 |
+
return "text", None, transcribed_text
|
| 200 |
|
| 201 |
+
except Exception as e:
|
| 202 |
+
log.error(f"❌ 音频处理失败: {e}", exc_info=True)
|
| 203 |
+
return "text", None, f"音频处理失败,请检查文件或稍后再试。"
|
| 204 |
|
| 205 |
else: # text
|
| 206 |
return input_type, None, raw_input
|
|
|
|
| 208 |
def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str,temperature: float = 0.5) -> str:
|
| 209 |
|
| 210 |
try:
|
| 211 |
+
inputs = self.processor(
|
| 212 |
+
text=prompt,
|
| 213 |
+
return_tensors="pt"
|
| 214 |
+
).to(self.model.device, dtype=torch.bfloat16)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
with torch.inference_mode():
|
| 217 |
generation_args = {
|
| 218 |
+
"max_new_tokens": 1024,
|
| 219 |
"pad_token_id": self.processor.tokenizer.eos_token_id,
|
| 220 |
"use_cache": True
|
| 221 |
}
|
|
|
|
| 261 |
|
| 262 |
full_prompt = "\n".join([msg.get("content", "") for msg in messages])
|
| 263 |
|
| 264 |
+
temperature = kwargs.get("temperature", 0.6)
|
| 265 |
|
| 266 |
if kwargs.get("response_format", {}).get("type") == "json_object":
|
| 267 |
# 在 prompt 末尾添加指令,强制模型输出 JSON
|
|
|
|
| 279 |
)
|
| 280 |
|
| 281 |
|
| 282 |
+
def _build_prompt(self, processed_text: str, context: str = "") -> str:
|
| 283 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
if context:
|
| 285 |
return (
|
| 286 |
f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n"
|
|
|
|
| 309 |
input_type, formatted_data, processed_text = self.format_input(input_type, user_input)
|
| 310 |
|
| 311 |
# 3. 构建prompt - 使用你的原有结构
|
| 312 |
+
prompt = self._build_prompt(processed_text, context)
|
| 313 |
|
| 314 |
# 4. 执行推理
|
| 315 |
+
return self.run_inference("text", formatted_data, prompt)
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
except Exception as e:
|
| 318 |
log.error(f"❌ 生成回复时发生错误: {e}", exc_info=True)
|
modules/intent_classifier.py
CHANGED
|
@@ -43,7 +43,7 @@ PROVIDING_TRAVEL_INFO > INQUIRY > GREETING > OTHER
|
|
| 43 |
- 用户输入: "你好,我想去东京玩" -> 分类: PROVIDING_TRAVEL_INFO
|
| 44 |
- 用户输入: "Hi, 巴黎有什么推荐的吗?" -> 分类: INQUIRY
|
| 45 |
- 用户输入: "周末愉快!" -> 分类: GREETING
|
| 46 |
-
- 用户输入: "我们预算不多,大概3000
|
| 47 |
- 用户输入: "你好在吗" -> 分类: GREETING
|
| 48 |
- 用户输入: "随便聊聊" -> 分类: OTHER
|
| 49 |
|
|
@@ -65,7 +65,6 @@ PROVIDING_TRAVEL_INFO > INQUIRY > GREETING > OTHER
|
|
| 65 |
|
| 66 |
try:
|
| 67 |
response = self.ai_model.chat_completion(
|
| 68 |
-
model="gpt-3.5-turbo",
|
| 69 |
messages=[{"role": "user", "content": prompt}],
|
| 70 |
temperature=0.0,
|
| 71 |
max_tokens=10
|
|
|
|
| 43 |
- 用户输入: "你好,我想去东京玩" -> 分类: PROVIDING_TRAVEL_INFO
|
| 44 |
- 用户输入: "Hi, 巴黎有什么推荐的吗?" -> 分类: INQUIRY
|
| 45 |
- 用户输入: "周末愉快!" -> 分类: GREETING
|
| 46 |
+
- 用户输入: "我们预算不多,大概3000元,目的地是柏林。" -> 分类: PROVIDING_TRAVEL_INFO
|
| 47 |
- 用户输入: "你好在吗" -> 分类: GREETING
|
| 48 |
- 用户输入: "随便聊聊" -> 分类: OTHER
|
| 49 |
|
|
|
|
| 65 |
|
| 66 |
try:
|
| 67 |
response = self.ai_model.chat_completion(
|
|
|
|
| 68 |
messages=[{"role": "user", "content": prompt}],
|
| 69 |
temperature=0.0,
|
| 70 |
max_tokens=10
|