Eliot0110 commited on
Commit
6c0d50f
·
1 Parent(s): 632df2f

fix : fix the inference

Browse files
Files changed (1) hide show
  1. modules/ai_model.py +98 -82
modules/ai_model.py CHANGED
@@ -69,7 +69,6 @@ class AIModel:
69
  try:
70
  log.info(f"正在加载模型: {self.model_name}")
71
 
72
- # 先进行认证并获取token
73
  token = self._authenticate_hf()
74
 
75
  if not token:
@@ -78,7 +77,6 @@ class AIModel:
78
  self.processor = None
79
  return
80
 
81
- # 设置缓存目录
82
  cache_dir = "/app/.cache/huggingface"
83
 
84
  self.model = Gemma3nForConditionalGeneration.from_pretrained(
@@ -105,150 +103,168 @@ class AIModel:
105
  self.processor = None
106
 
107
  def is_available(self) -> bool:
108
- """检查模型是否可用"""
109
  return self.model is not None and self.processor is not None
110
 
111
  def detect_input_type(self, input_data: str) -> str:
112
- """检测输入类型:图片/音频/文字"""
113
- if isinstance(input_data, str):
114
- # 检查是否为图片URL或路径
115
- if (input_data.startswith(("http://", "https://")) and
116
- any(input_data.lower().endswith(ext) for ext in [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"])):
117
- return "image"
118
- elif input_data.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")):
119
- return "image"
120
- # 检查是否为音频URL或路径
121
- elif (input_data.startswith(("http://", "https://")) and
122
- any(input_data.lower().endswith(ext) for ext in [".wav", ".mp3", ".m4a", ".ogg"])):
123
- return "audio"
124
- elif input_data.endswith((".wav", ".mp3", ".m4a", ".ogg")):
125
- return "audio"
126
- # 检查是否为base64编码的图片
127
- elif input_data.startswith("data:image/"):
128
- return "image"
129
-
 
 
130
  return "text"
131
 
132
  def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Union[str, Image.Image, None]]:
133
- """格式化输入数据"""
134
- formatted_data = None
135
- processed_text = raw_input
136
-
137
  if input_type == "image":
138
  try:
139
  if raw_input.startswith("data:image/"):
140
- # 处理base64编码的图片
141
  header, encoded = raw_input.split(",", 1)
142
  image_data = base64.b64decode(encoded)
143
  image = Image.open(BytesIO(image_data)).convert("RGB")
144
  elif raw_input.startswith(("http://", "https://")):
145
- # 处理图片URL
146
  response = requests.get(raw_input, timeout=10)
147
  response.raise_for_status()
148
  image = Image.open(BytesIO(response.content)).convert("RGB")
149
  else:
150
- # 处理本地图片路径
151
- image = Image.open(raw_input).convert("RGB")
152
 
153
- formatted_data = image
154
- processed_text = "请描述这张图片,并基于图片内容提供旅游建议。"
155
  log.info("✅ 图片加载成功")
156
-
 
157
  except Exception as e:
158
  log.error(f"❌ 图片加载失败: {e}")
159
- return "text", f"图片加载失败,请检查图片路径或URL。原始输入: {raw_input}"
160
-
161
  elif input_type == "audio":
162
- # 音频处理逻辑(如果需要的话,目前先返回提示)
163
- log.warning("⚠️ 音频处理功能暂未实现")
164
- processed_text = "抱歉,音频输入功能正在开发中。请使用文字描述您的需求。"
165
 
166
- elif input_type == "text":
167
- # 文字输入直接使用
168
- formatted_data = None
169
- processed_text = raw_input
170
 
171
- return input_type, formatted_data, processed_text
 
172
 
173
  def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str) -> str:
174
- """执行模型推理"""
175
  try:
 
 
 
 
176
  if input_type == "image" and isinstance(formatted_input, Image.Image):
177
- # 图片输入处理
178
- image_token = self.processor.tokenizer.image_token
179
  if image_token not in prompt:
180
  prompt = f"{image_token}\n{prompt}"
181
-
182
  inputs = self.processor(
183
- text=prompt,
184
- images=formatted_input,
185
  return_tensors="pt"
186
  ).to(self.model.device, dtype=torch.bfloat16)
187
  else:
188
- # 纯文本输入处理
189
  inputs = self.processor(
190
- text=prompt,
191
  return_tensors="pt"
192
  ).to(self.model.device, dtype=torch.bfloat16)
193
 
194
- # 生成响应
 
 
 
 
 
195
  with torch.inference_mode():
196
  outputs = self.model.generate(
197
- **inputs,
198
- max_new_tokens=512,
199
  do_sample=True,
200
  temperature=0.7,
201
  top_p=0.9,
202
- pad_token_id=self.processor.tokenizer.eos_token_id
 
203
  )
204
-
205
- # 解码输出
206
  decoded = self.processor.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
207
-
208
- # 清理输出,移除输入的prompt部分
209
  if prompt in decoded:
210
  decoded = decoded.replace(prompt, "").strip()
211
-
212
- return decoded
213
-
 
 
 
 
 
214
  except Exception as e:
215
  log.error(f"❌ 模型推理失败: {e}", exc_info=True)
216
- return "抱歉,我在处理您的请求时遇到了技术问题,请稍后再试。"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  def generate(self, user_input: str, context: str = "") -> str:
219
- """主要的生成方法 - 支持多模态输入"""
220
  if not self.is_available():
221
  return "抱歉,AI 模型当前不可用,请稍后再试。"
222
-
223
  try:
224
  # 1. 检测输入类型
225
  input_type = self.detect_input_type(user_input)
226
  log.info(f"检测到输入类型: {input_type}")
227
-
228
  # 2. 格式化输入
229
  input_type, formatted_data, processed_text = self.format_input(input_type, user_input)
230
-
231
- # 3. 构建prompt
232
- if context:
233
- prompt = (
234
- f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n"
235
- f"--- 背景信息 ---\n{context}\n\n"
236
- f"--- 用户问题 ---\n{processed_text}\n\n"
237
- f"请提供专业、实用的旅游建议:"
238
- )
239
- else:
240
- prompt = (
241
- f"你是一个专业的旅游助手。请用中文友好地回答用户的问题。\n\n"
242
- f"用户问题:{processed_text}\n\n"
243
- f"请提供专业、实用的旅游建议:"
244
- )
245
-
246
  # 4. 执行推理
247
  if input_type == "image" and formatted_data is not None:
248
  return self.run_inference("image", formatted_data, prompt)
249
  else:
250
  return self.run_inference("text", processed_text, prompt)
251
-
252
  except Exception as e:
253
  log.error(f"❌ 生成回复时发生错误: {e}", exc_info=True)
254
  return "抱歉,我在思考时遇到了点麻烦,请稍后再试。"
 
69
  try:
70
  log.info(f"正在加载模型: {self.model_name}")
71
 
 
72
  token = self._authenticate_hf()
73
 
74
  if not token:
 
77
  self.processor = None
78
  return
79
 
 
80
  cache_dir = "/app/.cache/huggingface"
81
 
82
  self.model = Gemma3nForConditionalGeneration.from_pretrained(
 
103
  self.processor = None
104
 
105
  def is_available(self) -> bool:
106
+
107
  return self.model is not None and self.processor is not None
108
 
109
  def detect_input_type(self, input_data: str) -> str:
110
+
111
+ if not isinstance(input_data, str):
112
+ return "text"
113
+
114
+ image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
115
+ if (input_data.startswith(("http://", "https://")) and
116
+ any(input_data.lower().endswith(ext) for ext in image_extensions)):
117
+ return "image"
118
+ elif any(input_data.endswith(ext) for ext in image_extensions):
119
+ return "image"
120
+ elif input_data.startswith("data:image/"):
121
+ return "image"
122
+
123
+ audio_extensions = [".wav", ".mp3", ".m4a", ".ogg", ".flac"]
124
+ if (input_data.startswith(("http://", "https://")) and
125
+ any(input_data.lower().endswith(ext) for ext in audio_extensions)):
126
+ return "audio"
127
+ elif any(input_data.endswith(ext) for ext in audio_extensions):
128
+ return "audio"
129
+
130
  return "text"
131
 
132
  def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Union[str, Image.Image, None]]:
133
+
 
 
 
134
  if input_type == "image":
135
  try:
136
  if raw_input.startswith("data:image/"):
 
137
  header, encoded = raw_input.split(",", 1)
138
  image_data = base64.b64decode(encoded)
139
  image = Image.open(BytesIO(image_data)).convert("RGB")
140
  elif raw_input.startswith(("http://", "https://")):
 
141
  response = requests.get(raw_input, timeout=10)
142
  response.raise_for_status()
143
  image = Image.open(BytesIO(response.content)).convert("RGB")
144
  else:
 
 
145
 
146
+ image = Image.open(raw_input).convert("RGB")
147
+
148
  log.info("✅ 图片加载成功")
149
+ return input_type, image, "请描述这张图片,并基于图片内容提供旅游建议。"
150
+
151
  except Exception as e:
152
  log.error(f"❌ 图片加载失败: {e}")
153
+ return "text", None, f"图片加载失败,请检查路径或URL"
154
+
155
  elif input_type == "audio":
 
 
 
156
 
157
+ log.warning("⚠️ 音频处理功能暂未实现")
158
+ return "text", None, "抱歉,音频输入功能正在开发中。请使用文字描述您的需求。"
 
 
159
 
160
+ else: # text
161
+ return input_type, None, raw_input
162
 
163
  def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str) -> str:
164
+
165
  try:
166
+ if len(prompt) > 500:
167
+ prompt = prompt[:500] + "..."
168
+
169
+
170
  if input_type == "image" and isinstance(formatted_input, Image.Image):
171
+
172
+ image_token = getattr(self.processor.tokenizer, 'image_token', '<image>')
173
  if image_token not in prompt:
174
  prompt = f"{image_token}\n{prompt}"
175
+
176
  inputs = self.processor(
177
+ text=prompt,
178
+ images=formatted_input,
179
  return_tensors="pt"
180
  ).to(self.model.device, dtype=torch.bfloat16)
181
  else:
182
+
183
  inputs = self.processor(
184
+ text=prompt,
185
  return_tensors="pt"
186
  ).to(self.model.device, dtype=torch.bfloat16)
187
 
188
+ if hasattr(inputs, 'input_ids') and inputs.input_ids.shape[-1] > 512:
189
+ log.warning(f"⚠️ 截断过长输入: {inputs.input_ids.shape[-1]} -> 512")
190
+ inputs.input_ids = inputs.input_ids[:, :512]
191
+ if hasattr(inputs, 'attention_mask'):
192
+ inputs.attention_mask = inputs.attention_mask[:, :512]
193
+
194
  with torch.inference_mode():
195
  outputs = self.model.generate(
196
+ **inputs,
197
+ max_new_tokens=256,
198
  do_sample=True,
199
  temperature=0.7,
200
  top_p=0.9,
201
+ pad_token_id=self.processor.tokenizer.eos_token_id,
202
+ use_cache=True
203
  )
204
+
 
205
  decoded = self.processor.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
206
+
207
+ # 移除prompt部分
208
  if prompt in decoded:
209
  decoded = decoded.replace(prompt, "").strip()
210
+
211
+ return decoded if decoded else "我理解了您的问题,请告诉我更多具体信息。"
212
+
213
+ except RuntimeError as e:
214
+ if "shape" in str(e):
215
+ log.error(f"❌ Tensor形状错误: {e}")
216
+ return "输入处理遇到问题,请尝试简化您的问题。"
217
+ raise e
218
  except Exception as e:
219
  log.error(f"❌ 模型推理失败: {e}", exc_info=True)
220
+ return "抱歉,处理您的请求时遇到技术问题。"
221
+
222
+ def _build_limited_prompt(self, processed_text: str, context: str = "") -> str:
223
+ """构建长度受限的prompt - 新增辅助方法"""
224
+ # 限制输入长度
225
+ if len(processed_text) > 200:
226
+ processed_text = processed_text[:200] + "..."
227
+
228
+ if context and len(context) > 300:
229
+ context = context[:300] + "..."
230
+
231
+ # 保持你原有的prompt结构
232
+ if context:
233
+ return (
234
+ f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n"
235
+ f"--- 背景信息 ---\n{context}\n\n"
236
+ f"--- 用户问题 ---\n{processed_text}\n\n"
237
+ f"请提供专业、实用的旅游建议:"
238
+ )
239
+ else:
240
+ return (
241
+ f"你是一个专业的旅游助手。请用中文友好地回答用户的问题。\n\n"
242
+ f"用户问题:{processed_text}\n\n"
243
+ f"请提供专业、实用的旅游建议:"
244
+ )
245
 
246
  def generate(self, user_input: str, context: str = "") -> str:
247
+ """主要的生成方法 - 保持原有逻辑"""
248
  if not self.is_available():
249
  return "抱歉,AI 模型当前不可用,请稍后再试。"
250
+
251
  try:
252
  # 1. 检测输入类型
253
  input_type = self.detect_input_type(user_input)
254
  log.info(f"检测到输入类型: {input_type}")
255
+
256
  # 2. 格式化输入
257
  input_type, formatted_data, processed_text = self.format_input(input_type, user_input)
258
+
259
+ # 3. 构建prompt - 使用你的原有结构
260
+ prompt = self._build_limited_prompt(processed_text, context)
261
+
 
 
 
 
 
 
 
 
 
 
 
 
262
  # 4. 执行推理
263
  if input_type == "image" and formatted_data is not None:
264
  return self.run_inference("image", formatted_data, prompt)
265
  else:
266
  return self.run_inference("text", processed_text, prompt)
267
+
268
  except Exception as e:
269
  log.error(f"❌ 生成回复时发生错误: {e}", exc_info=True)
270
  return "抱歉,我在思考时遇到了点麻烦,请稍后再试。"