Spaces:
Running
Running
fahmiaziz98
commited on
Commit
·
7f8bfb2
1
Parent(s):
8786174
Refactor reranking models and configuration management; add YAML support for model settings
Browse files- app.py +14 -328
- config.yaml +28 -0
- core/__init__.py +3 -0
- core/base.py +21 -0
- core/cross_encoder.py +239 -0
- core/model_manager.py +137 -0
- models/__init__.py +4 -0
- models/model.py +52 -0
app.py
CHANGED
|
@@ -1,326 +1,12 @@
|
|
| 1 |
-
from fastapi import FastAPI, HTTPException
|
| 2 |
-
from pydantic import BaseModel, Field
|
| 3 |
-
from typing import List, Optional, Dict, Any
|
| 4 |
-
from loguru import logger
|
| 5 |
import time
|
| 6 |
-
import
|
|
|
|
| 7 |
from contextlib import asynccontextmanager
|
| 8 |
|
| 9 |
-
from
|
| 10 |
-
from
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
# -------------------------
|
| 14 |
-
# Request/Response Models
|
| 15 |
-
# -------------------------
|
| 16 |
-
|
| 17 |
-
class RerankRequest(BaseModel):
|
| 18 |
-
"""
|
| 19 |
-
Request model for document reranking.
|
| 20 |
-
|
| 21 |
-
Attributes:
|
| 22 |
-
query: The search query
|
| 23 |
-
documents: List of documents to rerank
|
| 24 |
-
model_id: Identifier of the reranking model to use
|
| 25 |
-
instruction: Optional instruction for instruction-based models
|
| 26 |
-
top_k: Maximum number of documents to return (optional)
|
| 27 |
-
"""
|
| 28 |
-
query: str = Field(..., description="Search query text")
|
| 29 |
-
documents: List[str] = Field(..., min_items=1, description="List of documents to rerank")
|
| 30 |
-
model_id: str = Field(..., description="Model identifier for reranking")
|
| 31 |
-
instruction: Optional[str] = Field(None, description="Optional instruction for reranking task")
|
| 32 |
-
top_k: Optional[int] = Field(None, description="Maximum number of results to return")
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
class RerankResult(BaseModel):
|
| 36 |
-
"""
|
| 37 |
-
Single reranking result.
|
| 38 |
-
|
| 39 |
-
Attributes:
|
| 40 |
-
text: The document text
|
| 41 |
-
score: Relevance score from the reranking model
|
| 42 |
-
index: Original index of the document in input list
|
| 43 |
-
"""
|
| 44 |
-
text: str
|
| 45 |
-
score: float
|
| 46 |
-
index: int
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
class RerankResponse(BaseModel):
|
| 50 |
-
"""
|
| 51 |
-
Response model for document reranking.
|
| 52 |
-
|
| 53 |
-
Attributes:
|
| 54 |
-
results: List of reranked documents with scores
|
| 55 |
-
query: The original search query
|
| 56 |
-
model_id: Identifier of the model used
|
| 57 |
-
processing_time: Time taken to process the request
|
| 58 |
-
total_documents: Total number of input documents
|
| 59 |
-
returned_documents: Number of documents returned
|
| 60 |
-
"""
|
| 61 |
-
results: List[RerankResult]
|
| 62 |
-
query: str
|
| 63 |
-
model_id: str
|
| 64 |
-
processing_time: float
|
| 65 |
-
total_documents: int
|
| 66 |
-
returned_documents: int
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
# -------------------------
|
| 70 |
-
# Model Management
|
| 71 |
-
# -------------------------
|
| 72 |
-
|
| 73 |
-
class RerankerModel:
|
| 74 |
-
"""Base class for reranking models."""
|
| 75 |
-
|
| 76 |
-
def __init__(self, model_id: str, model_name: str, model_type: str):
|
| 77 |
-
self.model_id = model_id
|
| 78 |
-
self.model_name = model_name
|
| 79 |
-
self.model_type = model_type
|
| 80 |
-
self.model = None
|
| 81 |
-
self.tokenizer = None
|
| 82 |
-
self.loaded = False
|
| 83 |
-
|
| 84 |
-
def load(self):
|
| 85 |
-
"""Load the model. To be implemented by subclasses."""
|
| 86 |
-
raise NotImplementedError
|
| 87 |
-
|
| 88 |
-
def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
|
| 89 |
-
"""Rerank documents. To be implemented by subclasses."""
|
| 90 |
-
raise NotImplementedError
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
class SentenceTransformersReranker(RerankerModel):
|
| 94 |
-
"""Reranker using sentence-transformers CrossEncoder."""
|
| 95 |
-
|
| 96 |
-
def load(self):
|
| 97 |
-
"""Load sentence-transformers CrossEncoder model."""
|
| 98 |
-
try:
|
| 99 |
-
logger.info(f"Loading SentenceTransformers model: {self.model_name}")
|
| 100 |
-
self.model = CrossEncoder(
|
| 101 |
-
self.model_name,
|
| 102 |
-
model_kwargs={"torch_dtype": "auto"},
|
| 103 |
-
trust_remote_code=True
|
| 104 |
-
)
|
| 105 |
-
self.loaded = True
|
| 106 |
-
logger.success(f"Successfully loaded {self.model_id}")
|
| 107 |
-
except Exception as e:
|
| 108 |
-
logger.error(f"Failed to load {self.model_id}: {e}")
|
| 109 |
-
raise
|
| 110 |
-
|
| 111 |
-
def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
|
| 112 |
-
"""Rerank documents using CrossEncoder."""
|
| 113 |
-
if not self.loaded:
|
| 114 |
-
raise RuntimeError(f"Model {self.model_id} not loaded")
|
| 115 |
-
|
| 116 |
-
try:
|
| 117 |
-
# For sentence-transformers, we can use the rank method directly
|
| 118 |
-
rankings = self.model.rank(query, documents, convert_to_tensor=True)
|
| 119 |
-
|
| 120 |
-
# Extract scores and maintain original order
|
| 121 |
-
scores = [0.0] * len(documents)
|
| 122 |
-
for ranking in rankings:
|
| 123 |
-
scores[ranking['corpus_id']] = float(ranking['score'])
|
| 124 |
-
|
| 125 |
-
return scores
|
| 126 |
-
|
| 127 |
-
except Exception as e:
|
| 128 |
-
logger.error(f"Reranking failed with {self.model_id}: {e}")
|
| 129 |
-
raise
|
| 130 |
|
| 131 |
|
| 132 |
-
class QwenReranker(RerankerModel):
|
| 133 |
-
"""Reranker using Qwen3-Reranker model."""
|
| 134 |
-
|
| 135 |
-
def load(self):
|
| 136 |
-
"""Load Qwen reranker model."""
|
| 137 |
-
try:
|
| 138 |
-
logger.info(f"Loading Qwen model: {self.model_name}")
|
| 139 |
-
|
| 140 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 141 |
-
self.model_name,
|
| 142 |
-
padding_side='left'
|
| 143 |
-
)
|
| 144 |
-
self.model = AutoModelForCausalLM.from_pretrained(
|
| 145 |
-
self.model_name
|
| 146 |
-
).eval()
|
| 147 |
-
|
| 148 |
-
# Set up Qwen-specific tokens
|
| 149 |
-
self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
|
| 150 |
-
self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
|
| 151 |
-
self.max_length = 8192
|
| 152 |
-
|
| 153 |
-
# Set up prompt templates
|
| 154 |
-
self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
|
| 155 |
-
self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
| 156 |
-
self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
|
| 157 |
-
self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
|
| 158 |
-
|
| 159 |
-
self.loaded = True
|
| 160 |
-
logger.success(f"Successfully loaded {self.model_id}")
|
| 161 |
-
|
| 162 |
-
except Exception as e:
|
| 163 |
-
logger.error(f"Failed to load {self.model_id}: {e}")
|
| 164 |
-
raise
|
| 165 |
-
|
| 166 |
-
def _format_instruction(self, instruction: str, query: str, doc: str) -> str:
|
| 167 |
-
"""Format instruction for Qwen model."""
|
| 168 |
-
if instruction is None:
|
| 169 |
-
instruction = 'Given a web search query, retrieve relevant passages that answer the query'
|
| 170 |
-
|
| 171 |
-
return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
|
| 172 |
-
instruction=instruction, query=query, doc=doc
|
| 173 |
-
)
|
| 174 |
-
|
| 175 |
-
def _process_inputs(self, pairs: List[str]):
|
| 176 |
-
"""Process input pairs for Qwen model."""
|
| 177 |
-
inputs = self.tokenizer(
|
| 178 |
-
pairs,
|
| 179 |
-
padding=False,
|
| 180 |
-
truncation='longest_first',
|
| 181 |
-
return_attention_mask=False,
|
| 182 |
-
max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
|
| 183 |
-
)
|
| 184 |
-
|
| 185 |
-
for i, ele in enumerate(inputs['input_ids']):
|
| 186 |
-
inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
|
| 187 |
-
|
| 188 |
-
inputs = self.tokenizer.pad(
|
| 189 |
-
inputs,
|
| 190 |
-
padding=True,
|
| 191 |
-
return_tensors="pt",
|
| 192 |
-
max_length=self.max_length
|
| 193 |
-
)
|
| 194 |
-
|
| 195 |
-
for key in inputs:
|
| 196 |
-
inputs[key] = inputs[key].to(self.model.device)
|
| 197 |
-
|
| 198 |
-
return inputs
|
| 199 |
-
|
| 200 |
-
@torch.no_grad()
|
| 201 |
-
def _compute_logits(self, inputs):
|
| 202 |
-
"""Compute relevance scores from model logits."""
|
| 203 |
-
batch_scores = self.model(**inputs).logits[:, -1, :]
|
| 204 |
-
true_vector = batch_scores[:, self.token_true_id]
|
| 205 |
-
false_vector = batch_scores[:, self.token_false_id]
|
| 206 |
-
batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
| 207 |
-
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
|
| 208 |
-
scores = batch_scores[:, 1].exp().tolist()
|
| 209 |
-
return scores
|
| 210 |
-
|
| 211 |
-
def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
|
| 212 |
-
"""Rerank documents using Qwen model."""
|
| 213 |
-
if not self.loaded:
|
| 214 |
-
raise RuntimeError(f"Model {self.model_id} not loaded")
|
| 215 |
-
|
| 216 |
-
try:
|
| 217 |
-
# Format instruction pairs
|
| 218 |
-
pairs = [
|
| 219 |
-
self._format_instruction(instruction, query, doc)
|
| 220 |
-
for doc in documents
|
| 221 |
-
]
|
| 222 |
-
|
| 223 |
-
# Process inputs
|
| 224 |
-
inputs = self._process_inputs(pairs)
|
| 225 |
-
|
| 226 |
-
# Compute scores
|
| 227 |
-
scores = self._compute_logits(inputs)
|
| 228 |
-
|
| 229 |
-
return scores
|
| 230 |
-
|
| 231 |
-
except Exception as e:
|
| 232 |
-
logger.error(f"Reranking failed with {self.model_id}: {e}")
|
| 233 |
-
raise
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
class ModelManager:
|
| 237 |
-
"""Manager for reranking models with preloading."""
|
| 238 |
-
|
| 239 |
-
def __init__(self):
|
| 240 |
-
self.models: Dict[str, RerankerModel] = {}
|
| 241 |
-
self.model_configs = {
|
| 242 |
-
"jina-reranker-v2": {
|
| 243 |
-
"model_name": "jinaai/jina-reranker-v2-base-multilingual",
|
| 244 |
-
"model_type": "sentence_transformers",
|
| 245 |
-
"description": "Multilingual reranker from Jina AI"
|
| 246 |
-
},
|
| 247 |
-
"bge-reranker-v2": {
|
| 248 |
-
"model_name": "BAAI/bge-reranker-v2-m3",
|
| 249 |
-
"model_type": "sentence_transformers",
|
| 250 |
-
"description": "BGE multilingual reranker"
|
| 251 |
-
},
|
| 252 |
-
"qwen3-reranker": {
|
| 253 |
-
"model_name": "Qwen/Qwen3-Reranker-0.6B",
|
| 254 |
-
"model_type": "qwen",
|
| 255 |
-
"description": "Qwen3 instruction-based reranker"
|
| 256 |
-
}
|
| 257 |
-
}
|
| 258 |
-
|
| 259 |
-
async def preload_all_models(self):
|
| 260 |
-
"""Preload all configured models."""
|
| 261 |
-
logger.info(f"Starting preload of {len(self.model_configs)} reranking models...")
|
| 262 |
-
|
| 263 |
-
for model_id, config in self.model_configs.items():
|
| 264 |
-
try:
|
| 265 |
-
logger.info(f"Loading {model_id}...")
|
| 266 |
-
|
| 267 |
-
if config["model_type"] == "sentence_transformers":
|
| 268 |
-
model = SentenceTransformersReranker(
|
| 269 |
-
model_id=model_id,
|
| 270 |
-
model_name=config["model_name"],
|
| 271 |
-
model_type=config["model_type"]
|
| 272 |
-
)
|
| 273 |
-
elif config["model_type"] == "qwen":
|
| 274 |
-
model = QwenReranker(
|
| 275 |
-
model_id=model_id,
|
| 276 |
-
model_name=config["model_name"],
|
| 277 |
-
model_type=config["model_type"]
|
| 278 |
-
)
|
| 279 |
-
else:
|
| 280 |
-
logger.error(f"Unknown model type: {config['model_type']}")
|
| 281 |
-
continue
|
| 282 |
-
|
| 283 |
-
model.load()
|
| 284 |
-
self.models[model_id] = model
|
| 285 |
-
logger.success(f"Successfully preloaded {model_id}")
|
| 286 |
-
|
| 287 |
-
except Exception as e:
|
| 288 |
-
logger.error(f"Failed to preload {model_id}: {e}")
|
| 289 |
-
|
| 290 |
-
loaded_count = len([m for m in self.models.values() if m.loaded])
|
| 291 |
-
logger.success(f"Preloaded {loaded_count}/{len(self.model_configs)} models successfully")
|
| 292 |
-
|
| 293 |
-
def get_model(self, model_id: str) -> RerankerModel:
|
| 294 |
-
"""Get a loaded model by ID."""
|
| 295 |
-
if model_id not in self.models:
|
| 296 |
-
raise ValueError(f"Model {model_id} not found")
|
| 297 |
-
|
| 298 |
-
model = self.models[model_id]
|
| 299 |
-
if not model.loaded:
|
| 300 |
-
raise ValueError(f"Model {model_id} not loaded")
|
| 301 |
-
|
| 302 |
-
return model
|
| 303 |
-
|
| 304 |
-
def list_models(self) -> List[Dict[str, Any]]:
|
| 305 |
-
"""List all available models with their status."""
|
| 306 |
-
models_info = []
|
| 307 |
-
for model_id, config in self.model_configs.items():
|
| 308 |
-
model = self.models.get(model_id)
|
| 309 |
-
info = {
|
| 310 |
-
"id": model_id,
|
| 311 |
-
"name": config["model_name"],
|
| 312 |
-
"type": config["model_type"],
|
| 313 |
-
"description": config["description"],
|
| 314 |
-
"loaded": model.loaded if model else False
|
| 315 |
-
}
|
| 316 |
-
models_info.append(info)
|
| 317 |
-
return models_info
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
# -------------------------
|
| 321 |
-
# Application Setup
|
| 322 |
-
# -------------------------
|
| 323 |
-
|
| 324 |
model_manager = None
|
| 325 |
|
| 326 |
@asynccontextmanager
|
|
@@ -331,7 +17,7 @@ async def lifespan(app: FastAPI):
|
|
| 331 |
# Startup
|
| 332 |
logger.info("Starting reranking API...")
|
| 333 |
try:
|
| 334 |
-
model_manager = ModelManager()
|
| 335 |
await model_manager.preload_all_models()
|
| 336 |
logger.success("Reranking API startup complete!")
|
| 337 |
except Exception as e:
|
|
@@ -357,6 +43,7 @@ High-performance API for document reranking using multiple state-of-the-art mode
|
|
| 357 |
🚀 **Features:**
|
| 358 |
- Multiple reranking models preloaded at startup
|
| 359 |
- Batch document reranking with relevance scoring
|
|
|
|
| 360 |
- Optional instruction-based reranking (Qwen3)
|
| 361 |
- Comprehensive performance metrics
|
| 362 |
- Zero cold start delay
|
|
@@ -364,6 +51,8 @@ High-performance API for document reranking using multiple state-of-the-art mode
|
|
| 364 |
📊 **Input/Output:**
|
| 365 |
- Input: Query + documents + optional instruction
|
| 366 |
- Output: Ranked documents with relevance scores
|
|
|
|
|
|
|
| 367 |
""",
|
| 368 |
version="1.0.0",
|
| 369 |
lifespan=lifespan
|
|
@@ -407,7 +96,6 @@ async def rerank_documents(request: RerankRequest):
|
|
| 407 |
if not request.documents:
|
| 408 |
raise HTTPException(400, "Documents list cannot be empty")
|
| 409 |
|
| 410 |
-
# Filter out empty documents
|
| 411 |
valid_docs = [(i, doc.strip()) for i, doc in enumerate(request.documents) if doc.strip()]
|
| 412 |
if not valid_docs:
|
| 413 |
raise HTTPException(400, "No valid documents found after filtering empty strings")
|
|
@@ -415,20 +103,16 @@ async def rerank_documents(request: RerankRequest):
|
|
| 415 |
try:
|
| 416 |
start_time = time.time()
|
| 417 |
|
| 418 |
-
# Get model
|
| 419 |
model = model_manager.get_model(request.model_id)
|
| 420 |
-
|
| 421 |
-
# Extract valid documents and their indices
|
| 422 |
original_indices, documents = zip(*valid_docs)
|
| 423 |
-
|
| 424 |
-
|
| 425 |
scores = model.rerank(
|
| 426 |
query=request.query.strip(),
|
| 427 |
documents=list(documents),
|
| 428 |
instruction=request.instruction
|
| 429 |
)
|
| 430 |
|
| 431 |
-
# Create results with original indices
|
| 432 |
results = []
|
| 433 |
for i, (orig_idx, doc, score) in enumerate(zip(original_indices, documents, scores)):
|
| 434 |
results.append(RerankResult(
|
|
@@ -437,10 +121,8 @@ async def rerank_documents(request: RerankRequest):
|
|
| 437 |
index=orig_idx
|
| 438 |
))
|
| 439 |
|
| 440 |
-
# Sort by score (descending)
|
| 441 |
results.sort(key=lambda x: x.score, reverse=True)
|
| 442 |
|
| 443 |
-
# Apply top_k limit if specified
|
| 444 |
if request.top_k:
|
| 445 |
results = results[:request.top_k]
|
| 446 |
|
|
@@ -513,4 +195,8 @@ async def health_check():
|
|
| 513 |
"status": "error",
|
| 514 |
"error": str(e)
|
| 515 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import time
|
| 2 |
+
from loguru import logger
|
| 3 |
+
from fastapi import FastAPI, HTTPException
|
| 4 |
from contextlib import asynccontextmanager
|
| 5 |
|
| 6 |
+
from models import RerankRequest, RerankResponse, RerankResult
|
| 7 |
+
from core import ModelManager
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
model_manager = None
|
| 11 |
|
| 12 |
@asynccontextmanager
|
|
|
|
| 17 |
# Startup
|
| 18 |
logger.info("Starting reranking API...")
|
| 19 |
try:
|
| 20 |
+
model_manager = ModelManager("config.yaml")
|
| 21 |
await model_manager.preload_all_models()
|
| 22 |
logger.success("Reranking API startup complete!")
|
| 23 |
except Exception as e:
|
|
|
|
| 43 |
🚀 **Features:**
|
| 44 |
- Multiple reranking models preloaded at startup
|
| 45 |
- Batch document reranking with relevance scoring
|
| 46 |
+
- Fast prototyping app
|
| 47 |
- Optional instruction-based reranking (Qwen3)
|
| 48 |
- Comprehensive performance metrics
|
| 49 |
- Zero cold start delay
|
|
|
|
| 51 |
📊 **Input/Output:**
|
| 52 |
- Input: Query + documents + optional instruction
|
| 53 |
- Output: Ranked documents with relevance scores
|
| 54 |
+
|
| 55 |
+
**Warning**: Not use production!.
|
| 56 |
""",
|
| 57 |
version="1.0.0",
|
| 58 |
lifespan=lifespan
|
|
|
|
| 96 |
if not request.documents:
|
| 97 |
raise HTTPException(400, "Documents list cannot be empty")
|
| 98 |
|
|
|
|
| 99 |
valid_docs = [(i, doc.strip()) for i, doc in enumerate(request.documents) if doc.strip()]
|
| 100 |
if not valid_docs:
|
| 101 |
raise HTTPException(400, "No valid documents found after filtering empty strings")
|
|
|
|
| 103 |
try:
|
| 104 |
start_time = time.time()
|
| 105 |
|
|
|
|
| 106 |
model = model_manager.get_model(request.model_id)
|
|
|
|
|
|
|
| 107 |
original_indices, documents = zip(*valid_docs)
|
| 108 |
+
logger.info(f"Query: {request.query.strip()}")
|
| 109 |
+
logger.info(f"Document: {list(documents)}")
|
| 110 |
scores = model.rerank(
|
| 111 |
query=request.query.strip(),
|
| 112 |
documents=list(documents),
|
| 113 |
instruction=request.instruction
|
| 114 |
)
|
| 115 |
|
|
|
|
| 116 |
results = []
|
| 117 |
for i, (orig_idx, doc, score) in enumerate(zip(original_indices, documents, scores)):
|
| 118 |
results.append(RerankResult(
|
|
|
|
| 121 |
index=orig_idx
|
| 122 |
))
|
| 123 |
|
|
|
|
| 124 |
results.sort(key=lambda x: x.score, reverse=True)
|
| 125 |
|
|
|
|
| 126 |
if request.top_k:
|
| 127 |
results = results[:request.top_k]
|
| 128 |
|
|
|
|
| 195 |
"status": "error",
|
| 196 |
"error": str(e)
|
| 197 |
}
|
| 198 |
+
|
| 199 |
+
@app.get("/")
|
| 200 |
+
async def root():
|
| 201 |
+
return {"message": "Welcome to the Multi-Model Reranking API. Visit /docs for API documentation.", "version": "1.0.0"}
|
| 202 |
|
config.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model configuration for ModelManager
|
| 2 |
+
# You can add or modify model entries as needed
|
| 3 |
+
models:
|
| 4 |
+
jina-reranker-v2:
|
| 5 |
+
model_name: jinaai/jina-reranker-v2-base-multilingual
|
| 6 |
+
model_type: sentence_transformers
|
| 7 |
+
description: |
|
| 8 |
+
The Jina Reranker v2 (jina-reranker-v2-base-multilingual) is a transformer-based model that has been fine-tuned for text reranking task, which is a crucial component in many information retrieval systems. It is a cross-encoder model that takes a query and a document pair as input and outputs a score indicating the relevance of the document to the query. The model is trained on a large dataset of query-document pairs and is capable of reranking documents in multiple languages with high accuracy.
|
| 9 |
+
languages: ["multilingual"]
|
| 10 |
+
repository: https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual
|
| 11 |
+
|
| 12 |
+
bge-reranker-v2:
|
| 13 |
+
model_name: BAAI/bge-reranker-v2-m3
|
| 14 |
+
model_type: sentence_transformers
|
| 15 |
+
description: |
|
| 16 |
+
Different from embedding model, reranker uses question and document as input and directly output similarity instead of embedding. You can get a relevance score by inputting query and passage to the reranker. And the score can be mapped to a float value in [0,1] by sigmoid function.
|
| 17 |
+
languages: ["multilingual"]
|
| 18 |
+
repository: https://huggingface.co/BAAI/bge-reranker-v2-m3
|
| 19 |
+
|
| 20 |
+
qwen3-reranker:
|
| 21 |
+
model_name: Qwen/Qwen3-Reranker-0.6B
|
| 22 |
+
model_type: qwen
|
| 23 |
+
description: |
|
| 24 |
+
The Qwen3 Embedding model series is the latest proprietary model of the Qwen family, specifically designed for text embedding and ranking tasks. Building upon the dense foundational models of the Qwen3 series, it provides a comprehensive range of text embeddings and reranking models in various sizes (0.6B, 4B, and 8B). This series inherits the exceptional multilingual capabilities, long-text understanding, and reasoning skills of its foundational model. The Qwen3 Embedding series represents significant advancements in multiple text embedding and ranking tasks, including text retrieval, code retrieval, text classification, text clustering, and bitext mining.
|
| 25 |
+
languages: ["multilingual"]
|
| 26 |
+
repository: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
|
| 27 |
+
|
| 28 |
+
default_model: bge-reranker-v2
|
core/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .model_manager import ModelManager
|
| 2 |
+
|
| 3 |
+
__all__ = ["ModelManager"]
|
core/base.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class RerankerModel:
|
| 5 |
+
"""Base class for reranking models."""
|
| 6 |
+
|
| 7 |
+
def __init__(self, model_id: str, model_name: str, model_type: str):
|
| 8 |
+
self.model_id = model_id
|
| 9 |
+
self.model_name = model_name
|
| 10 |
+
self.model_type = model_type
|
| 11 |
+
self.model = None
|
| 12 |
+
self.tokenizer = None
|
| 13 |
+
self.loaded = False
|
| 14 |
+
|
| 15 |
+
def load(self):
|
| 16 |
+
"""Load the model. To be implemented by subclasses."""
|
| 17 |
+
raise NotImplementedError
|
| 18 |
+
|
| 19 |
+
def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
|
| 20 |
+
"""Rerank documents. To be implemented by subclasses."""
|
| 21 |
+
raise NotImplementedError
|
core/cross_encoder.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
from loguru import logger
|
| 4 |
+
from sentence_transformers import CrossEncoder
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 6 |
+
|
| 7 |
+
from .base import RerankerModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SentenceTransformersReranker(RerankerModel):
|
| 11 |
+
"""
|
| 12 |
+
Reranker using sentence-transformers CrossEncoder.
|
| 13 |
+
|
| 14 |
+
This class leverages the CrossEncoder model from the sentence-transformers library to score the relevance of documents given a query. It is suitable for reranking tasks in information retrieval pipelines.
|
| 15 |
+
|
| 16 |
+
Attributes:
|
| 17 |
+
model_name (str): Name or path of the model to load.
|
| 18 |
+
model (CrossEncoder): The loaded CrossEncoder model instance.
|
| 19 |
+
loaded (bool): Whether the model has been loaded.
|
| 20 |
+
model_id (str): Unique identifier for the model instance.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def load(self):
|
| 24 |
+
"""
|
| 25 |
+
Load the sentence-transformers CrossEncoder model.
|
| 26 |
+
|
| 27 |
+
Loads the CrossEncoder model specified by self.model_name. Sets self.loaded to True if successful.
|
| 28 |
+
|
| 29 |
+
Raises:
|
| 30 |
+
Exception: If the model fails to load.
|
| 31 |
+
"""
|
| 32 |
+
try:
|
| 33 |
+
logger.info(f"Loading SentenceTransformers model: {self.model_name}")
|
| 34 |
+
self.model = CrossEncoder(
|
| 35 |
+
self.model_name,
|
| 36 |
+
model_kwargs={"torch_dtype": "auto"},
|
| 37 |
+
trust_remote_code=True
|
| 38 |
+
)
|
| 39 |
+
self.loaded = True
|
| 40 |
+
logger.success(f"Successfully loaded {self.model_id}")
|
| 41 |
+
except Exception as e:
|
| 42 |
+
logger.error(f"Failed to load {self.model_id}: {e}")
|
| 43 |
+
raise
|
| 44 |
+
|
| 45 |
+
def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
|
| 46 |
+
"""
|
| 47 |
+
Rerank documents using the CrossEncoder model.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
query (str): The search query string.
|
| 51 |
+
documents (List[str]): List of documents to be reranked.
|
| 52 |
+
instruction (Optional[str]): Additional instruction for reranking (not used in this implementation).
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
List[float]: List of relevance scores for each document.
|
| 56 |
+
|
| 57 |
+
Raises:
|
| 58 |
+
RuntimeError: If the model is not loaded.
|
| 59 |
+
Exception: If reranking fails.
|
| 60 |
+
"""
|
| 61 |
+
if not self.loaded:
|
| 62 |
+
raise RuntimeError(f"Model {self.model_id} not loaded")
|
| 63 |
+
|
| 64 |
+
try:
|
| 65 |
+
rankings = self.model.rank(query, documents, convert_to_tensor=True)
|
| 66 |
+
|
| 67 |
+
scores = [0.0] * len(documents)
|
| 68 |
+
for ranking in rankings:
|
| 69 |
+
scores[ranking['corpus_id']] = float(ranking['score'])
|
| 70 |
+
|
| 71 |
+
return scores
|
| 72 |
+
|
| 73 |
+
except Exception as e:
|
| 74 |
+
logger.error(f"Reranking failed with {self.model_id}: {e}")
|
| 75 |
+
raise
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class QwenReranker(RerankerModel):
|
| 80 |
+
"""
|
| 81 |
+
Reranker using Qwen3-Reranker model (LLM-based).
|
| 82 |
+
|
| 83 |
+
This class uses a Qwen LLM to judge the relevance of documents to a query and instruction. The model outputs a probability that each document is relevant ("yes") or not ("no").
|
| 84 |
+
|
| 85 |
+
Attributes:
|
| 86 |
+
model_name (str): Name or path of the Qwen model.
|
| 87 |
+
tokenizer (AutoTokenizer): Tokenizer for the Qwen model.
|
| 88 |
+
model (AutoModelForCausalLM): Loaded Qwen model instance.
|
| 89 |
+
loaded (bool): Whether the model has been loaded.
|
| 90 |
+
model_id (str): Unique identifier for the model instance.
|
| 91 |
+
token_false_id (int): Token ID for "no".
|
| 92 |
+
token_true_id (int): Token ID for "yes".
|
| 93 |
+
max_length (int): Maximum input token length.
|
| 94 |
+
prefix (str): Prompt prefix for the system message.
|
| 95 |
+
suffix (str): Prompt suffix for the assistant message.
|
| 96 |
+
prefix_tokens (List[int]): Tokenized prefix.
|
| 97 |
+
suffix_tokens (List[int]): Tokenized suffix.
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def load(self):
|
| 101 |
+
"""
|
| 102 |
+
Load the Qwen reranker model and tokenizer, and initialize prompt templates and special tokens.
|
| 103 |
+
|
| 104 |
+
Raises:
|
| 105 |
+
Exception: If the model or tokenizer fails to load.
|
| 106 |
+
"""
|
| 107 |
+
try:
|
| 108 |
+
logger.info(f"Loading Qwen model: {self.model_name}")
|
| 109 |
+
|
| 110 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 111 |
+
self.model_name,
|
| 112 |
+
padding_side='left'
|
| 113 |
+
)
|
| 114 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 115 |
+
self.model_name
|
| 116 |
+
).eval()
|
| 117 |
+
|
| 118 |
+
# Set up Qwen-specific tokens
|
| 119 |
+
self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
|
| 120 |
+
self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
|
| 121 |
+
self.max_length = 8192
|
| 122 |
+
|
| 123 |
+
# Set up prompt templates
|
| 124 |
+
self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
|
| 125 |
+
self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
| 126 |
+
self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
|
| 127 |
+
self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
|
| 128 |
+
|
| 129 |
+
self.loaded = True
|
| 130 |
+
logger.success(f"Successfully loaded {self.model_id}")
|
| 131 |
+
|
| 132 |
+
except Exception as e:
|
| 133 |
+
logger.error(f"Failed to load {self.model_id}: {e}")
|
| 134 |
+
raise
|
| 135 |
+
|
| 136 |
+
def _format_instruction(self, instruction: str, query: str, doc: str) -> str:
|
| 137 |
+
"""
|
| 138 |
+
Format the instruction string for the Qwen model prompt.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
instruction (str): The instruction for the reranker. If None, a default instruction is used.
|
| 142 |
+
query (str): The search query string.
|
| 143 |
+
doc (str): The document to be evaluated.
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
str: Formatted prompt string for the model.
|
| 147 |
+
"""
|
| 148 |
+
if instruction is None:
|
| 149 |
+
instruction = 'Given a web search query, retrieve relevant passages that answer the query'
|
| 150 |
+
|
| 151 |
+
return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
|
| 152 |
+
instruction=instruction, query=query, doc=doc
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
def _process_inputs(self, pairs: List[str]):
|
| 156 |
+
"""
|
| 157 |
+
Tokenize and prepare input pairs for the Qwen model.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
pairs (List[str]): List of formatted prompt strings for each document.
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
dict: Tokenized and padded input tensors for the model.
|
| 164 |
+
"""
|
| 165 |
+
inputs = self.tokenizer(
|
| 166 |
+
pairs,
|
| 167 |
+
padding=False,
|
| 168 |
+
truncation='longest_first',
|
| 169 |
+
return_attention_mask=False,
|
| 170 |
+
max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
for i, ele in enumerate(inputs['input_ids']):
|
| 174 |
+
inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
|
| 175 |
+
|
| 176 |
+
inputs = self.tokenizer.pad(
|
| 177 |
+
inputs,
|
| 178 |
+
padding=True,
|
| 179 |
+
return_tensors="pt",
|
| 180 |
+
max_length=self.max_length
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
for key in inputs:
|
| 184 |
+
inputs[key] = inputs[key].to(self.model.device)
|
| 185 |
+
|
| 186 |
+
return inputs
|
| 187 |
+
|
| 188 |
+
@torch.no_grad()
|
| 189 |
+
def _compute_logits(self, inputs):
|
| 190 |
+
"""
|
| 191 |
+
Compute relevance scores from model logits.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
inputs (dict): Tokenized and padded input tensors for the model.
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
List[float]: List of probabilities that each document is relevant ("yes").
|
| 198 |
+
"""
|
| 199 |
+
batch_scores = self.model(**inputs).logits[:, -1, :]
|
| 200 |
+
true_vector = batch_scores[:, self.token_true_id]
|
| 201 |
+
false_vector = batch_scores[:, self.token_false_id]
|
| 202 |
+
batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
| 203 |
+
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
|
| 204 |
+
scores = batch_scores[:, 1].exp().tolist()
|
| 205 |
+
return scores
|
| 206 |
+
|
| 207 |
+
def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
|
| 208 |
+
"""
|
| 209 |
+
Rerank documents using the Qwen model.
|
| 210 |
+
|
| 211 |
+
Args:
|
| 212 |
+
query (str): The search query string.
|
| 213 |
+
documents (List[str]): List of documents to be reranked.
|
| 214 |
+
instruction (Optional[str]): Additional instruction for reranking.
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
List[float]: List of relevance scores for each document.
|
| 218 |
+
|
| 219 |
+
Raises:
|
| 220 |
+
RuntimeError: If the model is not loaded.
|
| 221 |
+
Exception: If reranking fails.
|
| 222 |
+
"""
|
| 223 |
+
if not self.loaded:
|
| 224 |
+
raise RuntimeError(f"Model {self.model_id} not loaded")
|
| 225 |
+
|
| 226 |
+
try:
|
| 227 |
+
pairs = [
|
| 228 |
+
self._format_instruction(instruction, query, doc)
|
| 229 |
+
for doc in documents
|
| 230 |
+
]
|
| 231 |
+
|
| 232 |
+
inputs = self._process_inputs(pairs)
|
| 233 |
+
scores = self._compute_logits(inputs)
|
| 234 |
+
|
| 235 |
+
return scores
|
| 236 |
+
|
| 237 |
+
except Exception as e:
|
| 238 |
+
logger.error(f"Reranking failed with {self.model_id}: {e}")
|
| 239 |
+
raise
|
core/model_manager.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
from typing import List, Dict, Any
|
| 3 |
+
from loguru import logger
|
| 4 |
+
|
| 5 |
+
from .base import RerankerModel
|
| 6 |
+
from .cross_encoder import SentenceTransformersReranker, QwenReranker
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ModelManager:
|
| 10 |
+
"""
|
| 11 |
+
Manager for reranking models with preloading and configuration.
|
| 12 |
+
|
| 13 |
+
This class loads model configurations from a YAML file (default: config.yaml),
|
| 14 |
+
instantiates and manages multiple reranker models, and provides methods to preload,
|
| 15 |
+
retrieve, and list the available models. Supports a default model if model_id is not provided.
|
| 16 |
+
|
| 17 |
+
Attributes:
|
| 18 |
+
models (Dict[str, RerankerModel]): Dictionary of loaded model instances keyed by model ID.
|
| 19 |
+
model_configs (Dict[str, Dict[str, Any]]): Model configuration loaded from YAML file.
|
| 20 |
+
default_model_id (str): The default model ID to use if none is provided.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, config_path: str = 'config.yaml'):
|
| 24 |
+
"""
|
| 25 |
+
Initialize the ModelManager and load model configurations from a YAML file.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
config_path (str): Path to the YAML configuration file. Defaults to 'config.yaml'.
|
| 29 |
+
|
| 30 |
+
Side Effects:
|
| 31 |
+
Loads model configuration into self.model_configs.
|
| 32 |
+
Initializes an empty dictionary for loaded models.
|
| 33 |
+
Sets the default model ID from config.
|
| 34 |
+
"""
|
| 35 |
+
self.models: Dict[str, RerankerModel] = {}
|
| 36 |
+
try:
|
| 37 |
+
with open(config_path, 'r') as f:
|
| 38 |
+
config_data = yaml.safe_load(f)
|
| 39 |
+
self.model_configs = config_data.get('models', {})
|
| 40 |
+
self.default_model_id = config_data.get('default_model')
|
| 41 |
+
logger.info(f"Loaded model configs from {config_path}")
|
| 42 |
+
except Exception as e:
|
| 43 |
+
logger.error(f"Failed to load config.yaml: {e}")
|
| 44 |
+
self.model_configs = {}
|
| 45 |
+
self.default_model_id = None
|
| 46 |
+
|
| 47 |
+
async def preload_all_models(self):
|
| 48 |
+
"""
|
| 49 |
+
Preload all models defined in the configuration file.
|
| 50 |
+
|
| 51 |
+
Iterates through all model configurations, instantiates the appropriate reranker class
|
| 52 |
+
(SentenceTransformersReranker or QwenReranker), loads the model, and stores it in self.models.
|
| 53 |
+
Logs the status of each model load and a summary at the end.
|
| 54 |
+
|
| 55 |
+
Raises:
|
| 56 |
+
Exception: If a model fails to load, logs the error and continues with the next model.
|
| 57 |
+
"""
|
| 58 |
+
logger.info(f"Starting preload of {len(self.model_configs)} reranking models...")
|
| 59 |
+
|
| 60 |
+
for model_id, config in self.model_configs.items():
|
| 61 |
+
try:
|
| 62 |
+
logger.info(f"Loading {model_id}...")
|
| 63 |
+
|
| 64 |
+
if config["model_type"] == "sentence_transformers":
|
| 65 |
+
model = SentenceTransformersReranker(
|
| 66 |
+
model_id=model_id,
|
| 67 |
+
model_name=config["model_name"],
|
| 68 |
+
model_type=config["model_type"]
|
| 69 |
+
)
|
| 70 |
+
elif config["model_type"] == "qwen":
|
| 71 |
+
model = QwenReranker(
|
| 72 |
+
model_id=model_id,
|
| 73 |
+
model_name=config["model_name"],
|
| 74 |
+
model_type=config["model_type"]
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
logger.error(f"Unknown model type: {config['model_type']}")
|
| 78 |
+
continue
|
| 79 |
+
|
| 80 |
+
model.load()
|
| 81 |
+
self.models[model_id] = model
|
| 82 |
+
logger.success(f"Successfully preloaded {model_id}")
|
| 83 |
+
|
| 84 |
+
except Exception as e:
|
| 85 |
+
logger.error(f"Failed to preload {model_id}: {e}")
|
| 86 |
+
|
| 87 |
+
loaded_count = len([m for m in self.models.values() if m.loaded])
|
| 88 |
+
logger.success(f"Preloaded {loaded_count}/{len(self.model_configs)} models successfully")
|
| 89 |
+
|
| 90 |
+
def get_model(self, model_id: str = None) -> RerankerModel:
|
| 91 |
+
"""
|
| 92 |
+
Retrieve a loaded model instance by its ID, or use the default model if not specified.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
model_id (str, optional): The unique identifier of the model to retrieve. If None, uses the default model.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
RerankerModel: The loaded reranker model instance.
|
| 99 |
+
|
| 100 |
+
Raises:
|
| 101 |
+
ValueError: If the model is not found or not loaded.
|
| 102 |
+
"""
|
| 103 |
+
if model_id is None:
|
| 104 |
+
if not self.default_model_id:
|
| 105 |
+
raise ValueError("No model_id provided and no default_model set in config.yaml")
|
| 106 |
+
model_id = self.default_model_id
|
| 107 |
+
|
| 108 |
+
if model_id not in self.models:
|
| 109 |
+
raise ValueError(f"Model {model_id} not found")
|
| 110 |
+
|
| 111 |
+
model = self.models[model_id]
|
| 112 |
+
if not model.loaded:
|
| 113 |
+
raise ValueError(f"Model {model_id} not loaded")
|
| 114 |
+
|
| 115 |
+
return model
|
| 116 |
+
|
| 117 |
+
def list_models(self) -> List[Dict[str, Any]]:
|
| 118 |
+
"""
|
| 119 |
+
List all available models with their configuration and load status.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
List[Dict[str, Any]]: A list of dictionaries, each containing model ID, name, type, description, and loaded status.
|
| 123 |
+
"""
|
| 124 |
+
models_info = []
|
| 125 |
+
for model_id, config in self.model_configs.items():
|
| 126 |
+
model = self.models.get(model_id)
|
| 127 |
+
info = {
|
| 128 |
+
"id": model_id,
|
| 129 |
+
"name": config.get("model_name"),
|
| 130 |
+
"type": config.get("model_type"),
|
| 131 |
+
"language": config.get("language"),
|
| 132 |
+
"description": config.get("description"),
|
| 133 |
+
"repository": config.get("repository"),
|
| 134 |
+
"loaded": model.loaded if model else False
|
| 135 |
+
}
|
| 136 |
+
models_info.append(info)
|
| 137 |
+
return models_info
|
models/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .model import RerankRequest, RerankResponse, RerankResult
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
__all__ = ["RerankRequest", "RerankResponse", "RerankResult"]
|
models/model.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class RerankRequest(BaseModel):
|
| 6 |
+
"""
|
| 7 |
+
Request model for document reranking.
|
| 8 |
+
|
| 9 |
+
Attributes:
|
| 10 |
+
query: The search query
|
| 11 |
+
documents: List of documents to rerank
|
| 12 |
+
model_id: Identifier of the reranking model to use
|
| 13 |
+
instruction: Optional instruction for instruction-based models
|
| 14 |
+
top_k: Maximum number of documents to return (optional)
|
| 15 |
+
"""
|
| 16 |
+
query: str = Field(..., description="Search query text")
|
| 17 |
+
documents: List[str] = Field(..., min_items=1, description="List of documents to rerank")
|
| 18 |
+
model_id: Optional[str] = Field(..., description="Model identifier for reranking")
|
| 19 |
+
instruction: Optional[str] = Field(None, description="Optional instruction for reranking task")
|
| 20 |
+
top_k: Optional[int] = Field(None, description="Maximum number of results to return")
|
| 21 |
+
|
| 22 |
+
class RerankResult(BaseModel):
|
| 23 |
+
"""
|
| 24 |
+
Single reranking result.
|
| 25 |
+
|
| 26 |
+
Attributes:
|
| 27 |
+
text: The document text
|
| 28 |
+
score: Relevance score from the reranking model
|
| 29 |
+
index: Original index of the document in input list
|
| 30 |
+
"""
|
| 31 |
+
text: str
|
| 32 |
+
score: float
|
| 33 |
+
index: int
|
| 34 |
+
|
| 35 |
+
class RerankResponse(BaseModel):
|
| 36 |
+
"""
|
| 37 |
+
Response model for document reranking.
|
| 38 |
+
|
| 39 |
+
Attributes:
|
| 40 |
+
results: List of reranked documents with scores
|
| 41 |
+
query: The original search query
|
| 42 |
+
model_id: Identifier of the model used
|
| 43 |
+
processing_time: Time taken to process the request
|
| 44 |
+
total_documents: Total number of input documents
|
| 45 |
+
returned_documents: Number of documents returned
|
| 46 |
+
"""
|
| 47 |
+
results: List[RerankResult]
|
| 48 |
+
query: str
|
| 49 |
+
model_id: str
|
| 50 |
+
processing_time: float
|
| 51 |
+
total_documents: int
|
| 52 |
+
returned_documents: int
|