# modules/ai_model.py
import torch
import base64
import requests
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
from utils.logger import log
from typing import Optional, Tuple, Union


class AIModel:
    def __init__(self, model_name: str = "google/gemma-3n-e2b-it"):
        self.model_name = model_name
        self.model = None
        self.processor = None
        self._initialize_model()

    def _initialize_model(self):
        """Initialize the Gemma model, following the official loading recipe."""
        try:
            log.info(f"Loading model: {self.model_name}")
            self.model = Gemma3nForConditionalGeneration.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            ).eval()
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            log.info("✅ Gemma AI model initialized successfully")
        except Exception as e:
            log.error(f"❌ Failed to initialize Gemma AI model: {e}", exc_info=True)
            self.model = None
            self.processor = None

    def is_available(self) -> bool:
        """Check whether the model is ready to use."""
        return self.model is not None and self.processor is not None

    def detect_input_type(self, input_data: str) -> str:
        """Detect the input type: image, audio, or text."""
        if isinstance(input_data, str):
            lowered = input_data.lower()
            # Base64-encoded image (data URI)
            if lowered.startswith("data:image/"):
                return "image"
            # Image URL or local path, matched case-insensitively by extension
            if lowered.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")):
                return "image"
            # Audio URL or local path
            if lowered.endswith((".wav", ".mp3", ".m4a", ".ogg")):
                return "audio"
        return "text"

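    # Illustrative examples of the classification above (the URLs and file
    # names are hypothetical, shown only to document the expected behavior):
    #
    #   detect_input_type("https://example.com/photo.JPG")  -> "image"
    #   detect_input_type("recording.mp3")                  -> "audio"
    #   detect_input_type("data:image/png;base64,iVBOR...") -> "image"
    #   detect_input_type("Plan a weekend trip to Kyoto")   -> "text"
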
    def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Optional[Image.Image], str]:
        """Normalize raw input into (input_type, image_or_None, prompt_text)."""
        formatted_data = None
        processed_text = raw_input
        if input_type == "image":
            try:
                if raw_input.startswith("data:image/"):
                    # Base64-encoded image (data URI)
                    header, encoded = raw_input.split(",", 1)
                    image_data = base64.b64decode(encoded)
                    image = Image.open(BytesIO(image_data)).convert("RGB")
                elif raw_input.startswith(("http://", "https://")):
                    # Image URL
                    response = requests.get(raw_input, timeout=10)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content)).convert("RGB")
                else:
                    # Local image path
                    image = Image.open(raw_input).convert("RGB")
                formatted_data = image
                processed_text = "Please describe this image and offer travel suggestions based on its content."
                log.info("✅ Image loaded successfully")
            except Exception as e:
                log.error(f"❌ Failed to load image: {e}")
                # Degrade to text, returning None for the image slot so that
                # callers can always unpack a 3-tuple
                return "text", None, f"Failed to load the image; please check the path or URL. Original input: {raw_input}"
        elif input_type == "audio":
            # Audio handling is not implemented yet; return a notice instead
            log.warning("⚠️ Audio handling is not implemented yet")
            processed_text = "Sorry, audio input is still under development. Please describe your request in text."
        elif input_type == "text":
            # Text input is used as-is
            formatted_data = None
            processed_text = raw_input
        return input_type, formatted_data, processed_text

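    # Illustrative: format_input("image", "https://example.com/a.jpg") returns
    # ("image", <PIL.Image.Image>, "Please describe this image ..."); on a
    # load failure it degrades to ("text", None, <error message>).
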
    def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str) -> str:
        """Run model inference."""
        try:
            if input_type == "image" and isinstance(formatted_input, Image.Image):
                # Image input: make sure the prompt carries the image placeholder token
                image_token = self.processor.tokenizer.image_token
                if image_token not in prompt:
                    prompt = f"{image_token}\n{prompt}"
                inputs = self.processor(
                    text=prompt,
                    images=formatted_input,
                    return_tensors="pt"
                ).to(self.model.device, dtype=torch.bfloat16)
            else:
                # Plain text input
                inputs = self.processor(
                    text=prompt,
                    return_tensors="pt"
                ).to(self.model.device, dtype=torch.bfloat16)
            # Generate the response
            with torch.inference_mode():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.processor.tokenizer.eos_token_id
                )
            # Decode only the newly generated tokens; slicing by input length
            # is more robust than string-replacing the prompt out of the output
            input_len = inputs["input_ids"].shape[-1]
            decoded = self.processor.tokenizer.decode(
                outputs[0][input_len:], skip_special_tokens=True
            ).strip()
            return decoded
        except Exception as e:
            log.error(f"❌ Model inference failed: {e}", exc_info=True)
            return "Sorry, I ran into a technical problem while processing your request. Please try again later."

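    def _build_chat_inputs(self, prompt: str, image: Optional[Image.Image] = None):
        """Optional alternative to the manual image-token handling above.

        A minimal sketch, assuming a recent transformers release where the
        Gemma processor supports apply_chat_template for multimodal messages.
        This helper is not wired into run_inference, so the original code
        path is unchanged.
        """
        content = []
        if image is not None:
            content.append({"type": "image", "image": image})
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
        return self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)
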
    def generate(self, user_input: str, context: str = "") -> str:
        """Main generation entry point, with multimodal input support."""
        if not self.is_available():
            return "Sorry, the AI model is currently unavailable. Please try again later."
        try:
            # 1. Detect the input type
            input_type = self.detect_input_type(user_input)
            log.info(f"Detected input type: {input_type}")
            # 2. Normalize the input
            input_type, formatted_data, processed_text = self.format_input(input_type, user_input)
            # 3. Build the prompt
            if context:
                prompt = (
                    f"You are a professional travel assistant. Based on the background information below, "
                    f"answer the user's question in Chinese in a friendly tone.\n\n"
                    f"--- Background ---\n{context}\n\n"
                    f"--- User question ---\n{processed_text}\n\n"
                    f"Please give professional, practical travel advice:"
                )
            else:
                prompt = (
                    f"You are a professional travel assistant. Answer the user's question in Chinese in a friendly tone.\n\n"
                    f"User question: {processed_text}\n\n"
                    f"Please give professional, practical travel advice:"
                )
            # 4. Run inference
            if input_type == "image" and formatted_data is not None:
                return self.run_inference("image", formatted_data, prompt)
            else:
                return self.run_inference("text", processed_text, prompt)
        except Exception as e:
            log.error(f"❌ Error while generating a reply: {e}", exc_info=True)
            return "Sorry, I ran into some trouble while thinking. Please try again later."
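

# --- Usage sketch (hypothetical query and URL; assumes utils.logger is on the
# import path and enough memory is available to load the checkpoint) ---
if __name__ == "__main__":
    assistant = AIModel()
    if assistant.is_available():
        # Plain-text question
        print(assistant.generate("Plan a three-day itinerary for Kyoto"))
        # Image input by URL (hypothetical URL)
        print(assistant.generate("https://example.com/landmark.jpg"))
    else:
        print("Model failed to initialize; check the logs for details.")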