fahmiaziz98 committed on
Commit
7f8bfb2
·
1 Parent(s): 8786174

Refactor reranking models and configuration management; add YAML support for model settings

Browse files
Files changed (8) hide show
  1. app.py +14 -328
  2. config.yaml +28 -0
  3. core/__init__.py +3 -0
  4. core/base.py +21 -0
  5. core/cross_encoder.py +239 -0
  6. core/model_manager.py +137 -0
  7. models/__init__.py +4 -0
  8. models/model.py +52 -0
app.py CHANGED
@@ -1,326 +1,12 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel, Field
3
- from typing import List, Optional, Dict, Any
4
- from loguru import logger
5
  import time
6
- import torch
 
7
  from contextlib import asynccontextmanager
8
 
9
- from sentence_transformers import CrossEncoder
10
- from transformers import AutoTokenizer, AutoModelForCausalLM
11
-
12
-
13
- # -------------------------
14
- # Request/Response Models
15
- # -------------------------
16
-
17
- class RerankRequest(BaseModel):
18
- """
19
- Request model for document reranking.
20
-
21
- Attributes:
22
- query: The search query
23
- documents: List of documents to rerank
24
- model_id: Identifier of the reranking model to use
25
- instruction: Optional instruction for instruction-based models
26
- top_k: Maximum number of documents to return (optional)
27
- """
28
- query: str = Field(..., description="Search query text")
29
- documents: List[str] = Field(..., min_items=1, description="List of documents to rerank")
30
- model_id: str = Field(..., description="Model identifier for reranking")
31
- instruction: Optional[str] = Field(None, description="Optional instruction for reranking task")
32
- top_k: Optional[int] = Field(None, description="Maximum number of results to return")
33
-
34
-
35
- class RerankResult(BaseModel):
36
- """
37
- Single reranking result.
38
-
39
- Attributes:
40
- text: The document text
41
- score: Relevance score from the reranking model
42
- index: Original index of the document in input list
43
- """
44
- text: str
45
- score: float
46
- index: int
47
-
48
-
49
- class RerankResponse(BaseModel):
50
- """
51
- Response model for document reranking.
52
-
53
- Attributes:
54
- results: List of reranked documents with scores
55
- query: The original search query
56
- model_id: Identifier of the model used
57
- processing_time: Time taken to process the request
58
- total_documents: Total number of input documents
59
- returned_documents: Number of documents returned
60
- """
61
- results: List[RerankResult]
62
- query: str
63
- model_id: str
64
- processing_time: float
65
- total_documents: int
66
- returned_documents: int
67
-
68
-
69
- # -------------------------
70
- # Model Management
71
- # -------------------------
72
-
73
- class RerankerModel:
74
- """Base class for reranking models."""
75
-
76
- def __init__(self, model_id: str, model_name: str, model_type: str):
77
- self.model_id = model_id
78
- self.model_name = model_name
79
- self.model_type = model_type
80
- self.model = None
81
- self.tokenizer = None
82
- self.loaded = False
83
-
84
- def load(self):
85
- """Load the model. To be implemented by subclasses."""
86
- raise NotImplementedError
87
-
88
- def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
89
- """Rerank documents. To be implemented by subclasses."""
90
- raise NotImplementedError
91
-
92
-
93
- class SentenceTransformersReranker(RerankerModel):
94
- """Reranker using sentence-transformers CrossEncoder."""
95
-
96
- def load(self):
97
- """Load sentence-transformers CrossEncoder model."""
98
- try:
99
- logger.info(f"Loading SentenceTransformers model: {self.model_name}")
100
- self.model = CrossEncoder(
101
- self.model_name,
102
- model_kwargs={"torch_dtype": "auto"},
103
- trust_remote_code=True
104
- )
105
- self.loaded = True
106
- logger.success(f"Successfully loaded {self.model_id}")
107
- except Exception as e:
108
- logger.error(f"Failed to load {self.model_id}: {e}")
109
- raise
110
-
111
- def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
112
- """Rerank documents using CrossEncoder."""
113
- if not self.loaded:
114
- raise RuntimeError(f"Model {self.model_id} not loaded")
115
-
116
- try:
117
- # For sentence-transformers, we can use the rank method directly
118
- rankings = self.model.rank(query, documents, convert_to_tensor=True)
119
-
120
- # Extract scores and maintain original order
121
- scores = [0.0] * len(documents)
122
- for ranking in rankings:
123
- scores[ranking['corpus_id']] = float(ranking['score'])
124
-
125
- return scores
126
-
127
- except Exception as e:
128
- logger.error(f"Reranking failed with {self.model_id}: {e}")
129
- raise
130
 
131
 
132
- class QwenReranker(RerankerModel):
133
- """Reranker using Qwen3-Reranker model."""
134
-
135
- def load(self):
136
- """Load Qwen reranker model."""
137
- try:
138
- logger.info(f"Loading Qwen model: {self.model_name}")
139
-
140
- self.tokenizer = AutoTokenizer.from_pretrained(
141
- self.model_name,
142
- padding_side='left'
143
- )
144
- self.model = AutoModelForCausalLM.from_pretrained(
145
- self.model_name
146
- ).eval()
147
-
148
- # Set up Qwen-specific tokens
149
- self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
150
- self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
151
- self.max_length = 8192
152
-
153
- # Set up prompt templates
154
- self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
155
- self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
156
- self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
157
- self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
158
-
159
- self.loaded = True
160
- logger.success(f"Successfully loaded {self.model_id}")
161
-
162
- except Exception as e:
163
- logger.error(f"Failed to load {self.model_id}: {e}")
164
- raise
165
-
166
- def _format_instruction(self, instruction: str, query: str, doc: str) -> str:
167
- """Format instruction for Qwen model."""
168
- if instruction is None:
169
- instruction = 'Given a web search query, retrieve relevant passages that answer the query'
170
-
171
- return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
172
- instruction=instruction, query=query, doc=doc
173
- )
174
-
175
- def _process_inputs(self, pairs: List[str]):
176
- """Process input pairs for Qwen model."""
177
- inputs = self.tokenizer(
178
- pairs,
179
- padding=False,
180
- truncation='longest_first',
181
- return_attention_mask=False,
182
- max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
183
- )
184
-
185
- for i, ele in enumerate(inputs['input_ids']):
186
- inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
187
-
188
- inputs = self.tokenizer.pad(
189
- inputs,
190
- padding=True,
191
- return_tensors="pt",
192
- max_length=self.max_length
193
- )
194
-
195
- for key in inputs:
196
- inputs[key] = inputs[key].to(self.model.device)
197
-
198
- return inputs
199
-
200
- @torch.no_grad()
201
- def _compute_logits(self, inputs):
202
- """Compute relevance scores from model logits."""
203
- batch_scores = self.model(**inputs).logits[:, -1, :]
204
- true_vector = batch_scores[:, self.token_true_id]
205
- false_vector = batch_scores[:, self.token_false_id]
206
- batch_scores = torch.stack([false_vector, true_vector], dim=1)
207
- batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
208
- scores = batch_scores[:, 1].exp().tolist()
209
- return scores
210
-
211
- def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
212
- """Rerank documents using Qwen model."""
213
- if not self.loaded:
214
- raise RuntimeError(f"Model {self.model_id} not loaded")
215
-
216
- try:
217
- # Format instruction pairs
218
- pairs = [
219
- self._format_instruction(instruction, query, doc)
220
- for doc in documents
221
- ]
222
-
223
- # Process inputs
224
- inputs = self._process_inputs(pairs)
225
-
226
- # Compute scores
227
- scores = self._compute_logits(inputs)
228
-
229
- return scores
230
-
231
- except Exception as e:
232
- logger.error(f"Reranking failed with {self.model_id}: {e}")
233
- raise
234
-
235
-
236
- class ModelManager:
237
- """Manager for reranking models with preloading."""
238
-
239
- def __init__(self):
240
- self.models: Dict[str, RerankerModel] = {}
241
- self.model_configs = {
242
- "jina-reranker-v2": {
243
- "model_name": "jinaai/jina-reranker-v2-base-multilingual",
244
- "model_type": "sentence_transformers",
245
- "description": "Multilingual reranker from Jina AI"
246
- },
247
- "bge-reranker-v2": {
248
- "model_name": "BAAI/bge-reranker-v2-m3",
249
- "model_type": "sentence_transformers",
250
- "description": "BGE multilingual reranker"
251
- },
252
- "qwen3-reranker": {
253
- "model_name": "Qwen/Qwen3-Reranker-0.6B",
254
- "model_type": "qwen",
255
- "description": "Qwen3 instruction-based reranker"
256
- }
257
- }
258
-
259
- async def preload_all_models(self):
260
- """Preload all configured models."""
261
- logger.info(f"Starting preload of {len(self.model_configs)} reranking models...")
262
-
263
- for model_id, config in self.model_configs.items():
264
- try:
265
- logger.info(f"Loading {model_id}...")
266
-
267
- if config["model_type"] == "sentence_transformers":
268
- model = SentenceTransformersReranker(
269
- model_id=model_id,
270
- model_name=config["model_name"],
271
- model_type=config["model_type"]
272
- )
273
- elif config["model_type"] == "qwen":
274
- model = QwenReranker(
275
- model_id=model_id,
276
- model_name=config["model_name"],
277
- model_type=config["model_type"]
278
- )
279
- else:
280
- logger.error(f"Unknown model type: {config['model_type']}")
281
- continue
282
-
283
- model.load()
284
- self.models[model_id] = model
285
- logger.success(f"Successfully preloaded {model_id}")
286
-
287
- except Exception as e:
288
- logger.error(f"Failed to preload {model_id}: {e}")
289
-
290
- loaded_count = len([m for m in self.models.values() if m.loaded])
291
- logger.success(f"Preloaded {loaded_count}/{len(self.model_configs)} models successfully")
292
-
293
- def get_model(self, model_id: str) -> RerankerModel:
294
- """Get a loaded model by ID."""
295
- if model_id not in self.models:
296
- raise ValueError(f"Model {model_id} not found")
297
-
298
- model = self.models[model_id]
299
- if not model.loaded:
300
- raise ValueError(f"Model {model_id} not loaded")
301
-
302
- return model
303
-
304
- def list_models(self) -> List[Dict[str, Any]]:
305
- """List all available models with their status."""
306
- models_info = []
307
- for model_id, config in self.model_configs.items():
308
- model = self.models.get(model_id)
309
- info = {
310
- "id": model_id,
311
- "name": config["model_name"],
312
- "type": config["model_type"],
313
- "description": config["description"],
314
- "loaded": model.loaded if model else False
315
- }
316
- models_info.append(info)
317
- return models_info
318
-
319
-
320
- # -------------------------
321
- # Application Setup
322
- # -------------------------
323
-
324
  model_manager = None
325
 
326
  @asynccontextmanager
@@ -331,7 +17,7 @@ async def lifespan(app: FastAPI):
331
  # Startup
332
  logger.info("Starting reranking API...")
333
  try:
334
- model_manager = ModelManager()
335
  await model_manager.preload_all_models()
336
  logger.success("Reranking API startup complete!")
337
  except Exception as e:
@@ -357,6 +43,7 @@ High-performance API for document reranking using multiple state-of-the-art mode
357
  🚀 **Features:**
358
  - Multiple reranking models preloaded at startup
359
  - Batch document reranking with relevance scoring
 
360
  - Optional instruction-based reranking (Qwen3)
361
  - Comprehensive performance metrics
362
  - Zero cold start delay
@@ -364,6 +51,8 @@ High-performance API for document reranking using multiple state-of-the-art mode
364
  📊 **Input/Output:**
365
  - Input: Query + documents + optional instruction
366
  - Output: Ranked documents with relevance scores
 
 
367
  """,
368
  version="1.0.0",
369
  lifespan=lifespan
@@ -407,7 +96,6 @@ async def rerank_documents(request: RerankRequest):
407
  if not request.documents:
408
  raise HTTPException(400, "Documents list cannot be empty")
409
 
410
- # Filter out empty documents
411
  valid_docs = [(i, doc.strip()) for i, doc in enumerate(request.documents) if doc.strip()]
412
  if not valid_docs:
413
  raise HTTPException(400, "No valid documents found after filtering empty strings")
@@ -415,20 +103,16 @@ async def rerank_documents(request: RerankRequest):
415
  try:
416
  start_time = time.time()
417
 
418
- # Get model
419
  model = model_manager.get_model(request.model_id)
420
-
421
- # Extract valid documents and their indices
422
  original_indices, documents = zip(*valid_docs)
423
-
424
- # Perform reranking
425
  scores = model.rerank(
426
  query=request.query.strip(),
427
  documents=list(documents),
428
  instruction=request.instruction
429
  )
430
 
431
- # Create results with original indices
432
  results = []
433
  for i, (orig_idx, doc, score) in enumerate(zip(original_indices, documents, scores)):
434
  results.append(RerankResult(
@@ -437,10 +121,8 @@ async def rerank_documents(request: RerankRequest):
437
  index=orig_idx
438
  ))
439
 
440
- # Sort by score (descending)
441
  results.sort(key=lambda x: x.score, reverse=True)
442
 
443
- # Apply top_k limit if specified
444
  if request.top_k:
445
  results = results[:request.top_k]
446
 
@@ -513,4 +195,8 @@ async def health_check():
513
  "status": "error",
514
  "error": str(e)
515
  }
 
 
 
 
516
 
 
 
 
 
 
1
  import time
2
+ from loguru import logger
3
+ from fastapi import FastAPI, HTTPException
4
  from contextlib import asynccontextmanager
5
 
6
+ from models import RerankRequest, RerankResponse, RerankResult
7
+ from core import ModelManager
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  model_manager = None
11
 
12
  @asynccontextmanager
 
17
  # Startup
18
  logger.info("Starting reranking API...")
19
  try:
20
+ model_manager = ModelManager("config.yaml")
21
  await model_manager.preload_all_models()
22
  logger.success("Reranking API startup complete!")
23
  except Exception as e:
 
43
  🚀 **Features:**
44
  - Multiple reranking models preloaded at startup
45
  - Batch document reranking with relevance scoring
46
+ - Fast prototyping app
47
  - Optional instruction-based reranking (Qwen3)
48
  - Comprehensive performance metrics
49
  - Zero cold start delay
 
51
  📊 **Input/Output:**
52
  - Input: Query + documents + optional instruction
53
  - Output: Ranked documents with relevance scores
54
+
55
+ **Warning**: Not use production!.
56
  """,
57
  version="1.0.0",
58
  lifespan=lifespan
 
96
  if not request.documents:
97
  raise HTTPException(400, "Documents list cannot be empty")
98
 
 
99
  valid_docs = [(i, doc.strip()) for i, doc in enumerate(request.documents) if doc.strip()]
100
  if not valid_docs:
101
  raise HTTPException(400, "No valid documents found after filtering empty strings")
 
103
  try:
104
  start_time = time.time()
105
 
 
106
  model = model_manager.get_model(request.model_id)
 
 
107
  original_indices, documents = zip(*valid_docs)
108
+ logger.info(f"Query: {request.query.strip()}")
109
+ logger.info(f"Document: {list(documents)}")
110
  scores = model.rerank(
111
  query=request.query.strip(),
112
  documents=list(documents),
113
  instruction=request.instruction
114
  )
115
 
 
116
  results = []
117
  for i, (orig_idx, doc, score) in enumerate(zip(original_indices, documents, scores)):
118
  results.append(RerankResult(
 
121
  index=orig_idx
122
  ))
123
 
 
124
  results.sort(key=lambda x: x.score, reverse=True)
125
 
 
126
  if request.top_k:
127
  results = results[:request.top_k]
128
 
 
195
  "status": "error",
196
  "error": str(e)
197
  }
198
+
199
@app.get("/")
async def root():
    """Return a static welcome payload that points clients at the /docs page."""
    welcome = {
        "message": "Welcome to the Multi-Model Reranking API. Visit /docs for API documentation.",
        "version": "1.0.0",
    }
    return welcome
202
 
config.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model configuration for ModelManager
2
+ # You can add or modify model entries as needed
3
+ models:
4
+ jina-reranker-v2:
5
+ model_name: jinaai/jina-reranker-v2-base-multilingual
6
+ model_type: sentence_transformers
7
+ description: |
8
+ The Jina Reranker v2 (jina-reranker-v2-base-multilingual) is a transformer-based model that has been fine-tuned for text reranking task, which is a crucial component in many information retrieval systems. It is a cross-encoder model that takes a query and a document pair as input and outputs a score indicating the relevance of the document to the query. The model is trained on a large dataset of query-document pairs and is capable of reranking documents in multiple languages with high accuracy.
9
+ languages: ["multilingual"]
10
+ repository: https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual
11
+
12
+ bge-reranker-v2:
13
+ model_name: BAAI/bge-reranker-v2-m3
14
+ model_type: sentence_transformers
15
+ description: |
16
+ Different from embedding model, reranker uses question and document as input and directly output similarity instead of embedding. You can get a relevance score by inputting query and passage to the reranker. And the score can be mapped to a float value in [0,1] by sigmoid function.
17
+ languages: ["multilingual"]
18
+ repository: https://huggingface.co/BAAI/bge-reranker-v2-m3
19
+
20
+ qwen3-reranker:
21
+ model_name: Qwen/Qwen3-Reranker-0.6B
22
+ model_type: qwen
23
+ description: |
24
+ The Qwen3 Embedding model series is the latest proprietary model of the Qwen family, specifically designed for text embedding and ranking tasks. Building upon the dense foundational models of the Qwen3 series, it provides a comprehensive range of text embeddings and reranking models in various sizes (0.6B, 4B, and 8B). This series inherits the exceptional multilingual capabilities, long-text understanding, and reasoning skills of its foundational model. The Qwen3 Embedding series represents significant advancements in multiple text embedding and ranking tasks, including text retrieval, code retrieval, text classification, text clustering, and bitext mining.
25
+ languages: ["multilingual"]
26
+ repository: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
27
+
28
+ default_model: bge-reranker-v2
core/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .model_manager import ModelManager
2
+
3
+ __all__ = ["ModelManager"]
core/base.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+
4
class RerankerModel:
    """
    Base class for reranking models.

    Stores the identifying metadata of a model and declares the two hooks
    (``load`` and ``rerank``) that concrete rerankers must override.

    Attributes:
        model_id: Identifier the model is registered under.
        model_name: Name or path handed to the underlying library when loading.
        model_type: Type tag of the model (e.g. the value read from configuration).
        model: The loaded model object; ``None`` until a subclass sets it.
        tokenizer: The loaded tokenizer; ``None`` until a subclass sets it.
        loaded: ``False`` until a subclass marks the model as loaded.
    """

    def __init__(self, model_id: str, model_name: str, model_type: str):
        self.model_id = model_id
        self.model_name = model_name
        self.model_type = model_type
        self.model = None
        self.tokenizer = None
        self.loaded = False

    def load(self) -> None:
        """
        Load the model. To be implemented by subclasses.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(f"{type(self).__name__} must implement load()")

    def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
        """
        Rerank documents. To be implemented by subclasses.

        Args:
            query: The search query.
            documents: Documents to score against the query.
            instruction: Optional instruction for instruction-aware models.

        Returns:
            One relevance score per document.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement rerank()")
core/cross_encoder.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import List, Optional
3
+ from loguru import logger
4
+ from sentence_transformers import CrossEncoder
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+
7
+ from .base import RerankerModel
8
+
9
+
10
class SentenceTransformersReranker(RerankerModel):
    """
    Reranker backed by a sentence-transformers ``CrossEncoder``.

    The cross-encoder scores each (query, document) pair; ``rerank`` returns
    those scores in the same order as the input documents.

    Attributes:
        model_name (str): Name or path passed to ``CrossEncoder``.
        model (CrossEncoder): The loaded cross-encoder (``None`` before ``load``).
        loaded (bool): Set to ``True`` once ``load`` succeeds.
        model_id (str): Identifier used in log and error messages.
    """

    def load(self):
        """
        Instantiate the ``CrossEncoder`` named by ``self.model_name``.

        On success ``self.loaded`` becomes ``True``. Any failure is logged
        and re-raised unchanged.

        Raises:
            Exception: Whatever the underlying library raises while loading.
        """
        try:
            logger.info(f"Loading SentenceTransformers model: {self.model_name}")
            # NOTE(review): trust_remote_code=True allows code shipped with the
            # model repository to run locally — confirm the sources are trusted.
            encoder = CrossEncoder(
                self.model_name,
                model_kwargs={"torch_dtype": "auto"},
                trust_remote_code=True,
            )
            self.model = encoder
            self.loaded = True
            logger.success(f"Successfully loaded {self.model_id}")
        except Exception as e:
            logger.error(f"Failed to load {self.model_id}: {e}")
            raise

    def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
        """
        Score every document against the query.

        Args:
            query (str): The search query.
            documents (List[str]): Documents to score.
            instruction (Optional[str]): Accepted for interface parity; unused here.

        Returns:
            List[float]: Scores aligned with ``documents`` (index ``i`` is the
            score of ``documents[i]``).

        Raises:
            RuntimeError: If ``load`` has not completed successfully.
            Exception: Any error raised while ranking is logged and re-raised.
        """
        if not self.loaded:
            raise RuntimeError(f"Model {self.model_id} not loaded")

        try:
            ranked = self.model.rank(query, documents, convert_to_tensor=True)

            # rank() returns entries tagged with 'corpus_id' (the input index);
            # write each score back into its original slot.
            aligned = [0.0] * len(documents)
            for entry in ranked:
                aligned[entry['corpus_id']] = float(entry['score'])
            return aligned

        except Exception as e:
            logger.error(f"Reranking failed with {self.model_id}: {e}")
            raise
76
+
77
+
78
+
79
class QwenReranker(RerankerModel):
    """
    LLM-based reranker built on a Qwen3-Reranker causal language model.

    Every document is wrapped in a chat-style prompt that asks the model to
    answer "yes" or "no"; the score is the softmax probability of the "yes"
    token versus the "no" token at the final position.

    Attributes:
        model_name (str): Name or path of the Qwen model.
        tokenizer (AutoTokenizer): Tokenizer, configured with left padding.
        model (AutoModelForCausalLM): The loaded model in eval mode.
        loaded (bool): Set to ``True`` once ``load`` succeeds.
        model_id (str): Identifier used in log and error messages.
        token_false_id (int): Token ID looked up for "no".
        token_true_id (int): Token ID looked up for "yes".
        max_length (int): Maximum total token length per prompt (8192).
        prefix (str): System/user prompt header placed before each pair.
        suffix (str): Assistant prompt footer placed after each pair.
        prefix_tokens (List[int]): Token IDs of ``prefix``.
        suffix_tokens (List[int]): Token IDs of ``suffix``.
    """

    def load(self):
        """
        Load tokenizer and model, then prepare the yes/no token IDs and the
        pre-tokenized prompt prefix and suffix.

        Raises:
            Exception: Any loading error is logged and re-raised.
        """
        try:
            logger.info(f"Loading Qwen model: {self.model_name}")

            # Left padding keeps the end of every prompt at the last position,
            # which is the position _compute_logits reads.
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, padding_side='left')
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name).eval()

            # Token IDs used to read the yes/no decision from the logits.
            self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
            self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
            self.max_length = 8192

            # Prompt scaffolding wrapped around every formatted pair.
            self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
            self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
            self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)

            self.loaded = True
            logger.success(f"Successfully loaded {self.model_id}")

        except Exception as e:
            logger.error(f"Failed to load {self.model_id}: {e}")
            raise

    def _format_instruction(self, instruction: str, query: str, doc: str) -> str:
        """
        Build the ``<Instruct>/<Query>/<Document>`` text for one document.

        Args:
            instruction (str): Task instruction; ``None`` selects a default
                web-search instruction.
            query (str): The search query.
            doc (str): The document to judge.

        Returns:
            str: The formatted prompt body (without prefix/suffix).
        """
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'

        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def _process_inputs(self, pairs: List[str]):
        """
        Tokenize prompt bodies, wrap them with prefix/suffix tokens, pad the
        batch and move it to the model's device.

        Args:
            pairs (List[str]): Prompt bodies from ``_format_instruction``.

        Returns:
            dict: Padded tensors ready to be passed to the model.
        """
        # Reserve room for the prefix and suffix inside max_length.
        body_budget = self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        encoded = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=body_budget,
        )

        encoded['input_ids'] = [
            self.prefix_tokens + ids + self.suffix_tokens
            for ids in encoded['input_ids']
        ]

        batch = self.tokenizer.pad(
            encoded,
            padding=True,
            return_tensors="pt",
            max_length=self.max_length,
        )

        target_device = self.model.device
        for key in batch:
            batch[key] = batch[key].to(target_device)

        return batch

    @torch.no_grad()
    def _compute_logits(self, inputs):
        """
        Turn last-position logits into "yes" probabilities.

        Args:
            inputs (dict): Padded tensors from ``_process_inputs``.

        Returns:
            List[float]: For each row, the softmax probability of the "yes"
            token computed over the ("no", "yes") logit pair.
        """
        final_logits = self.model(**inputs).logits[:, -1, :]
        yes_logits = final_logits[:, self.token_true_id]
        no_logits = final_logits[:, self.token_false_id]
        # Column 0 = "no", column 1 = "yes".
        pair_logits = torch.stack([no_logits, yes_logits], dim=1)
        log_probs = torch.nn.functional.log_softmax(pair_logits, dim=1)
        return log_probs[:, 1].exp().tolist()

    def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
        """
        Score every document against the query with the Qwen model.

        Args:
            query (str): The search query.
            documents (List[str]): Documents to score.
            instruction (Optional[str]): Task instruction; a default is used
                when ``None``.

        Returns:
            List[float]: One score per document, in input order.

        Raises:
            RuntimeError: If ``load`` has not completed successfully.
            Exception: Any error raised while scoring is logged and re-raised.
        """
        if not self.loaded:
            raise RuntimeError(f"Model {self.model_id} not loaded")

        try:
            prompts = [self._format_instruction(instruction, query, doc) for doc in documents]
            model_inputs = self._process_inputs(prompts)
            return self._compute_logits(model_inputs)

        except Exception as e:
            logger.error(f"Reranking failed with {self.model_id}: {e}")
            raise
core/model_manager.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ from typing import List, Dict, Any
3
+ from loguru import logger
4
+
5
+ from .base import RerankerModel
6
+ from .cross_encoder import SentenceTransformersReranker, QwenReranker
7
+
8
+
9
class ModelManager:
    """
    Manager for reranking models with preloading and configuration.

    Reads model definitions from a YAML file, instantiates and loads one
    reranker per entry, and exposes lookup and listing helpers. A default
    model ID from the configuration is used when no model ID is requested.

    Attributes:
        models (Dict[str, RerankerModel]): Successfully loaded models keyed by model ID.
        model_configs (Dict[str, Dict[str, Any]]): The ``models`` mapping from the YAML file.
        default_model_id (str | None): The ``default_model`` value from the YAML file.
    """

    def __init__(self, config_path: str = 'config.yaml'):
        """
        Load model configurations from a YAML file.

        A missing or unreadable file is logged and leaves the manager with
        no configured models rather than raising.

        Args:
            config_path (str): Path to the YAML configuration file.
        """
        self.models: Dict[str, "RerankerModel"] = {}
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                # safe_load returns None for an empty document.
                config_data = yaml.safe_load(f) or {}
            # `models:` with no entries parses as None; normalise to {}.
            self.model_configs = config_data.get('models') or {}
            self.default_model_id = config_data.get('default_model')
            logger.info(f"Loaded model configs from {config_path}")
        except Exception as e:
            # Report the path that was actually attempted.
            logger.error(f"Failed to load {config_path}: {e}")
            self.model_configs = {}
            self.default_model_id = None

    async def preload_all_models(self):
        """
        Instantiate and load every configured model.

        Entries with an unknown ``model_type`` are logged and skipped. A
        failure while building or loading one model is logged and does not
        stop the remaining models; nothing is raised to the caller. A summary
        of how many models loaded is logged at the end.
        """
        # NOTE(review): model.load() is a synchronous call inside this
        # coroutine — presumably acceptable at startup; confirm.
        reranker_classes = {
            "sentence_transformers": SentenceTransformersReranker,
            "qwen": QwenReranker,
        }

        logger.info(f"Starting preload of {len(self.model_configs)} reranking models...")

        for model_id, config in self.model_configs.items():
            try:
                logger.info(f"Loading {model_id}...")

                model_type = config["model_type"]
                reranker_cls = reranker_classes.get(model_type)
                if reranker_cls is None:
                    logger.error(f"Unknown model type: {model_type}")
                    continue

                model = reranker_cls(
                    model_id=model_id,
                    model_name=config["model_name"],
                    model_type=model_type,
                )
                model.load()
                self.models[model_id] = model
                logger.success(f"Successfully preloaded {model_id}")

            except Exception as e:
                logger.error(f"Failed to preload {model_id}: {e}")

        loaded_count = sum(1 for m in self.models.values() if m.loaded)
        logger.success(f"Preloaded {loaded_count}/{len(self.model_configs)} models successfully")

    def get_model(self, model_id: str = None) -> "RerankerModel":
        """
        Return a loaded model by ID, falling back to the default model.

        Args:
            model_id (str, optional): ID of the model. ``None`` selects
                ``self.default_model_id``.

        Returns:
            RerankerModel: The loaded reranker instance.

        Raises:
            ValueError: If no ID is given and no default is configured, or if
                the model is unknown or not loaded.
        """
        if model_id is None:
            if not self.default_model_id:
                raise ValueError("No model_id provided and no default_model set in config.yaml")
            model_id = self.default_model_id

        if model_id not in self.models:
            raise ValueError(f"Model {model_id} not found")

        model = self.models[model_id]
        if not model.loaded:
            raise ValueError(f"Model {model_id} not loaded")

        return model

    def list_models(self) -> List[Dict[str, Any]]:
        """
        Describe every configured model and whether it is loaded.

        Returns:
            List[Dict[str, Any]]: One dict per configured model with the keys
            ``id``, ``name``, ``type``, ``language``, ``description``,
            ``repository`` and ``loaded``.
        """
        models_info = []
        for model_id, config in self.model_configs.items():
            model = self.models.get(model_id)
            models_info.append({
                "id": model_id,
                "name": config.get("model_name"),
                "type": config.get("model_type"),
                # config.yaml spells the key `languages`; the singular form is
                # still honoured, and the response key stays "language".
                "language": config.get("languages", config.get("language")),
                "description": config.get("description"),
                "repository": config.get("repository"),
                "loaded": model.loaded if model else False,
            })
        return models_info
models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
# Re-export the request/response schemas defined in .model so they can be
# imported directly from this package.
from .model import RerankRequest, RerankResponse, RerankResult


# Names exposed by `from <package> import *`.
__all__ = ["RerankRequest", "RerankResponse", "RerankResult"]
models/model.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from pydantic import BaseModel, Field
3
+
4
+
5
class RerankRequest(BaseModel):
    """
    Request model for document reranking.

    Attributes:
        query: The search query
        documents: List of documents to rerank (at least one is required)
        model_id: Identifier of the reranking model to use. May be omitted or
            null; the model manager's ``get_model`` falls back to the default
            model when it is given ``None``.
        instruction: Optional instruction for instruction-based models
        top_k: Maximum number of documents to return (optional)
    """
    query: str = Field(..., description="Search query text")
    documents: List[str] = Field(..., min_items=1, description="List of documents to rerank")
    # Default is None rather than `...`: with `...` the key stays required even
    # though the type is Optional, so a client could never omit it to get the
    # default model.
    model_id: Optional[str] = Field(None, description="Model identifier for reranking")
    instruction: Optional[str] = Field(None, description="Optional instruction for reranking task")
    top_k: Optional[int] = Field(None, description="Maximum number of results to return")
21
+
22
class RerankResult(BaseModel):
    """
    Single reranking result.

    All three fields are required (none declares a default).

    Attributes:
        text: The document text
        score: Relevance score from the reranking model
        index: Original index of the document in input list
    """
    text: str
    score: float
    index: int
34
+
35
class RerankResponse(BaseModel):
    """
    Response model for document reranking.

    All fields are required (none declares a default).

    Attributes:
        results: List of reranked documents with scores
        query: The original search query
        model_id: Identifier of the model used
        processing_time: Time taken to process the request
            (NOTE(review): presumably seconds — confirm against the endpoint)
        total_documents: Total number of input documents
        returned_documents: Number of documents returned
    """
    results: List[RerankResult]
    query: str
    # Unlike RerankRequest.model_id, this is a non-optional str.
    model_id: str
    processing_time: float
    total_documents: int
    returned_documents: int