Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModel | |
| import torch | |
| import spaces | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import List, Dict, Any | |
| import time | |
# Create the FastAPI application.
app = FastAPI()

# Configure CORS so browser clients on any origin can call the API.
# Credentials are disabled on purpose: the FastAPI CORS documentation states
# that allow_origins / allow_methods / allow_headers must not be ["*"] when
# allow_credentials is True, and nothing in this module reads cookies or
# authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the tokenizer and the model once, at import time, so every request
# reuses the same objects.
model_name = "BAAI/bge-m3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# eval() switches the module to evaluation mode and returns the module
# itself, so it can be chained onto the load call.
model = AutoModel.from_pretrained(model_name).eval()
class EmbeddingRequest(BaseModel):
    """Request body for the embeddings endpoint."""

    # A single text, or a list of texts, to embed.
    input: list[str] | str
    # Model identifier echoed back in the response; defaults to the loaded model.
    model: str | None = model_name
    # Accepted in the request body; not read by process_embeddings in this module.
    encoding_format: str | None = "float"
    # Accepted in the request body; not read by process_embeddings in this module.
    user: str | None = None
class EmbeddingResponse(BaseModel):
    """Response body describing a list of embeddings plus token usage."""

    # Constant type tag for the top-level object.
    object: str = "list"
    # One dictionary per embedded input.
    data: list[dict[str, Any]]
    # Model identifier reported back to the caller.
    model: str
    # Integer token counters keyed by name.
    usage: dict[str, int]
def get_embedding(
    text: str,
    *,
    max_length: int = 512,
    normalize: bool = False,
) -> List[float]:
    """Return the embedding vector of a single text.

    The text is tokenized (truncated to ``max_length`` tokens), run through
    the module-level ``model`` without gradient tracking, and the hidden
    state at sequence position 0 of the last layer is returned.

    Args:
        text: The text to embed.
        max_length: Maximum number of tokens kept by the tokenizer. Defaults
            to 512, the value that was previously hard-coded.
        normalize: When True, scale the vector to unit L2 norm before
            returning it. Defaults to False, which keeps the raw output.

    Returns:
        The embedding as a plain list of Python floats.
    """
    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        outputs = model(**encoded)

    # Only one text is encoded, so take batch row 0 at sequence position 0.
    vector = outputs.last_hidden_state[0, 0, :]
    if normalize:
        vector = torch.nn.functional.normalize(vector, p=2, dim=0)

    # Tensor.tolist() yields Python floats directly; no NumPy round trip needed.
    return vector.cpu().tolist()
def process_embeddings(request_dict: dict) -> dict:
    """Embed every input text and build the response dictionary (synchronous).

    Args:
        request_dict: Mapping with an ``"input"`` entry (one string or a list
            of strings) and an optional ``"model"`` entry.

    Returns:
        A dictionary with ``"object"``, ``"data"`` (one entry per input, in
        input order), ``"model"`` and ``"usage"`` keys.

    Raises:
        KeyError: If ``request_dict`` has no ``"input"`` entry.
    """
    raw_input = request_dict["input"]
    texts = [raw_input] if isinstance(raw_input, str) else raw_input

    data = []
    total_tokens = 0
    for index, text in enumerate(texts):
        # Token usage is counted with tokenizer.encode's default arguments,
        # independently of any truncation get_embedding applies.
        total_tokens += len(tokenizer.encode(text))
        data.append({
            "object": "embedding",
            "embedding": get_embedding(text),
            "index": index,
        })

    return {
        "object": "list",
        "data": data,
        # `or` (rather than a .get default) also covers a "model" key that is
        # present but None, e.g. when the client sends an explicit null.
        "model": request_dict.get("model") or model_name,
        "usage": {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        },
    }
# The route path follows the OpenAI REST convention.
# NOTE(review): confirm that clients expect "/v1/embeddings".
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """Asynchronous API endpoint that returns embeddings for the request.

    Args:
        request: The validated request body.

    Returns:
        The dictionary produced by ``process_embeddings``.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    payload = request.dict()
    if payload.get("model") is None:
        # An explicit null from the client must not reach the response,
        # whose "model" field is a required string.
        payload["model"] = model_name

    # Model inference is blocking work; running it in a worker thread keeps
    # the event loop free to serve other requests in the meantime.
    return await asyncio.to_thread(process_embeddings, payload)
def gradio_embedding(text: str) -> Dict:
    """Gradio callback: embed ``text`` with the default model.

    Args:
        text: The text entered in the Gradio textbox.

    Returns:
        The dictionary produced by ``process_embeddings``.
    """
    return process_embeddings({"input": text, "model": model_name})
# Build the Gradio interface: a multi-line textbox in, JSON out.
demo = gr.Interface(
    title="BGE-M3 Embeddings (OpenAI 兼容格式)",
    description="输入文本,获取其对应的嵌入向量,返回格式与 OpenAI API 兼容。",
    fn=gradio_embedding,
    inputs=gr.Textbox(placeholder="输入要进行编码的文本...", lines=3),
    outputs=gr.Json(),
    # Sample inputs offered to the user.
    examples=[
        ["这是一个示例文本。"],
        ["人工智能正在改变世界。"],
    ],
)
# Mount the Gradio app onto the FastAPI app at the root path.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    # Listen on all interfaces on port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")