Spaces:
Sleeping
Sleeping
| # modules/ai_model.py | |
| import torch | |
| import base64 | |
| import requests | |
| from io import BytesIO | |
| from PIL import Image | |
| from transformers import AutoProcessor, Gemma3nForConditionalGeneration | |
| from utils.logger import log | |
| from typing import Union, Tuple | |
class AIModel:
    """Wrapper around a Gemma 3n multimodal model used as a travel assistant.

    The model and processor are loaded eagerly in ``__init__``. If loading
    fails, both attributes are left as ``None`` and ``generate`` returns a
    fixed "unavailable" message instead of raising.
    """

    # File extensions (lower-case) used to classify a string input.
    _IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")
    _AUDIO_EXTENSIONS = (".wav", ".mp3", ".m4a", ".ogg")
    # Prefixes that mark an input as a remote URL.
    _URL_PREFIXES = ("http://", "https://")
    # Upper bound on how much of the raw input is echoed back in an error
    # message; a base64 data URI can be megabytes long.
    _MAX_ECHO_CHARS = 200

    def __init__(self, model_name: str = "google/gemma-3n-e2b-it"):
        """Store the model name and immediately try to load the model.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the processor.
        """
        self.model_name = model_name
        self.model = None
        self.processor = None
        self._initialize_model()

    def _initialize_model(self):
        """Load the Gemma model and its processor.

        On any failure the error is logged and ``self.model`` /
        ``self.processor`` are reset to ``None`` so ``is_available`` reports
        ``False``.
        """
        try:
            log.info(f"正在加载模型: {self.model_name}")
            # NOTE(review): trust_remote_code=True allows code shipped in the
            # model repository to be executed; confirm it is really required
            # for the checkpoints this service loads.
            self.model = Gemma3nForConditionalGeneration.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            ).eval()
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            log.info("✅ Gemma AI 模型初始化成功")
        except Exception as e:
            log.error(f"❌ Gemma AI 模型初始化失败: {e}", exc_info=True)
            self.model = None
            self.processor = None

    def is_available(self) -> bool:
        """Return True when both the model and the processor are loaded."""
        return self.model is not None and self.processor is not None

    def detect_input_type(self, input_data: str) -> str:
        """Classify an input as ``"image"``, ``"audio"`` or ``"text"``.

        A ``data:image/`` URI is an image. Otherwise the decision is made on
        the file extension, compared case-insensitively. For ``http(s)`` URLs
        only the path component is inspected, so a query string or fragment
        after the extension does not hide it. Anything else, including
        non-string input, is ``"text"``.
        """
        from urllib.parse import urlsplit

        if not isinstance(input_data, str):
            return "text"
        if input_data.startswith("data:image/"):
            return "image"

        candidate = input_data
        if input_data.startswith(self._URL_PREFIXES):
            try:
                candidate = urlsplit(input_data).path
            except ValueError:
                # Malformed URL: fall back to the raw string.
                candidate = input_data
        candidate = candidate.lower()

        if candidate.endswith(self._IMAGE_EXTENSIONS):
            return "image"
        if candidate.endswith(self._AUDIO_EXTENSIONS):
            return "audio"
        return "text"

    def _load_image(self, raw_input: str) -> "Image.Image":
        """Load an RGB image from a data URI, an http(s) URL or a local path."""
        if raw_input.startswith("data:image/"):
            # "data:image/<type>;base64,<payload>" -> keep only the payload.
            _, encoded = raw_input.split(",", 1)
            source = BytesIO(base64.b64decode(encoded))
        elif raw_input.startswith(self._URL_PREFIXES):
            response = requests.get(raw_input, timeout=10)
            response.raise_for_status()
            source = BytesIO(response.content)
        else:
            source = raw_input
        with Image.open(source) as img:
            return img.convert("RGB")

    def format_input(self, input_type: str, raw_input: str) -> "Tuple[str, Union[Image.Image, None], str]":
        """Prepare the raw input for prompt building.

        Args:
            input_type: One of the values returned by ``detect_input_type``.
            raw_input: The original user input.

        Returns:
            A 3-tuple ``(input_type, formatted_data, processed_text)``:

            * image loaded: ``("image", <PIL image>, <fixed image prompt>)``
            * image failed: ``("text", None, <failure notice>)``
            * audio: ``("audio", None, <not-supported notice>)``
            * anything else: ``(input_type, None, raw_input)``
        """
        formatted_data = None
        processed_text = raw_input

        if input_type == "image":
            try:
                formatted_data = self._load_image(raw_input)
            except Exception as e:
                log.error(f"❌ 图片加载失败: {e}")
                preview = raw_input
                if len(preview) > self._MAX_ECHO_CHARS:
                    preview = preview[:self._MAX_ECHO_CHARS] + "..."
                # Always a 3-tuple so callers can unpack uniformly.
                return "text", None, f"图片加载失败,请检查图片路径或URL。原始输入: {preview}"
            processed_text = "请描述这张图片,并基于图片内容提供旅游建议。"
            log.info("✅ 图片加载成功")
        elif input_type == "audio":
            # Audio is not implemented yet; hand back a notice for the user.
            log.warning("⚠️ 音频处理功能暂未实现")
            processed_text = "抱歉,音频输入功能正在开发中。请使用文字描述您的需求。"

        return input_type, formatted_data, processed_text

    def run_inference(self, input_type: str, formatted_input: "Union[str, Image.Image, None]", prompt: str) -> str:
        """Run the model on ``prompt`` (plus an image, if given) and return the reply.

        Args:
            input_type: ``"image"`` to attach ``formatted_input`` as an image;
                any other value runs text-only.
            formatted_input: The PIL image for image input; ignored otherwise.
            prompt: The full text prompt.

        Returns:
            The decoded completion, or a fixed apology string if anything
            raises during preprocessing or generation.
        """
        # NOTE(review): the prompt is passed as plain text rather than through
        # the processor's chat template; confirm that is intended for an
        # instruction-tuned checkpoint.
        try:
            if input_type == "image" and isinstance(formatted_input, Image.Image):
                # The processor needs the image placeholder token in the text.
                image_token = self.processor.tokenizer.image_token
                if image_token not in prompt:
                    prompt = f"{image_token}\n{prompt}"
                inputs = self.processor(
                    text=prompt,
                    images=formatted_input,
                    return_tensors="pt"
                ).to(self.model.device, dtype=torch.bfloat16)
            else:
                inputs = self.processor(
                    text=prompt,
                    return_tensors="pt"
                ).to(self.model.device, dtype=torch.bfloat16)

            prompt_length = inputs["input_ids"].shape[-1]

            with torch.inference_mode():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.processor.tokenizer.eos_token_id
                )

            # Drop the prompt tokens before decoding. Stripping the prompt
            # from the decoded string is unreliable once special/image tokens
            # have been removed, and can delete legitimate output.
            generated_ids = outputs[0][prompt_length:]
            return self.processor.tokenizer.decode(
                generated_ids, skip_special_tokens=True
            ).strip()
        except Exception as e:
            log.error(f"❌ 模型推理失败: {e}", exc_info=True)
            return "抱歉,我在处理您的请求时遇到了技术问题,请稍后再试。"

    def generate(self, user_input: str, context: str = "") -> str:
        """Produce a reply for ``user_input``; the main entry point.

        Args:
            user_input: Plain text, an image path/URL/data URI, or an audio
                path/URL.
            context: Optional background information included in the prompt.

        Returns:
            The model's reply, or a fixed notice when the model is not
            loaded, the input is audio, the image cannot be loaded, or an
            unexpected error occurs.
        """
        if not self.is_available():
            return "抱歉,AI 模型当前不可用,请稍后再试。"

        try:
            # 1. Detect the input type.
            detected_type = self.detect_input_type(user_input)
            log.info(f"检测到输入类型: {detected_type}")

            # 2. Normalise the input.
            input_type, formatted_data, processed_text = self.format_input(detected_type, user_input)

            # For audio and for a failed image load, processed_text is a
            # notice for the user, not a question; return it directly instead
            # of sending it to the model.
            if detected_type == "audio" or (detected_type == "image" and formatted_data is None):
                return processed_text

            # 3. Build the prompt.
            if context:
                prompt = (
                    f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n"
                    f"--- 背景信息 ---\n{context}\n\n"
                    f"--- 用户问题 ---\n{processed_text}\n\n"
                    f"请提供专业、实用的旅游建议:"
                )
            else:
                prompt = (
                    f"你是一个专业的旅游助手。请用中文友好地回答用户的问题。\n\n"
                    f"用户问题:{processed_text}\n\n"
                    f"请提供专业、实用的旅游建议:"
                )

            # 4. Run inference.
            if input_type == "image" and formatted_data is not None:
                return self.run_inference("image", formatted_data, prompt)
            return self.run_inference("text", processed_text, prompt)
        except Exception as e:
            log.error(f"❌ 生成回复时发生错误: {e}", exc_info=True)
            return "抱歉,我在思考时遇到了点麻烦,请稍后再试。"