fahmiaziz98 commited on
Commit
073edba
Β·
1 Parent(s): 76d149a
Files changed (2) hide show
  1. app.py +514 -5
  2. requirements.txt +6 -2
app.py CHANGED
@@ -1,7 +1,516 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
2
 
3
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel, Field
3
+ from typing import List, Optional, Dict, Any
4
+ from loguru import logger
5
+ import time
6
+ import torch
7
+ from contextlib import asynccontextmanager
8
 
9
+ from sentence_transformers import CrossEncoder
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
11
+
12
+
13
+ # -------------------------
14
+ # Request/Response Models
15
+ # -------------------------
16
+
17
+ class RerankRequest(BaseModel):
18
+ """
19
+ Request model for document reranking.
20
+
21
+ Attributes:
22
+ query: The search query
23
+ documents: List of documents to rerank
24
+ model_id: Identifier of the reranking model to use
25
+ instruction: Optional instruction for instruction-based models
26
+ top_k: Maximum number of documents to return (optional)
27
+ """
28
+ query: str = Field(..., description="Search query text")
29
+ documents: List[str] = Field(..., min_items=1, description="List of documents to rerank")
30
+ model_id: str = Field(..., description="Model identifier for reranking")
31
+ instruction: Optional[str] = Field(None, description="Optional instruction for reranking task")
32
+ top_k: Optional[int] = Field(None, description="Maximum number of results to return")
33
+
34
+
35
+ class RerankResult(BaseModel):
36
+ """
37
+ Single reranking result.
38
+
39
+ Attributes:
40
+ text: The document text
41
+ score: Relevance score from the reranking model
42
+ index: Original index of the document in input list
43
+ """
44
+ text: str
45
+ score: float
46
+ index: int
47
+
48
+
49
+ class RerankResponse(BaseModel):
50
+ """
51
+ Response model for document reranking.
52
+
53
+ Attributes:
54
+ results: List of reranked documents with scores
55
+ query: The original search query
56
+ model_id: Identifier of the model used
57
+ processing_time: Time taken to process the request
58
+ total_documents: Total number of input documents
59
+ returned_documents: Number of documents returned
60
+ """
61
+ results: List[RerankResult]
62
+ query: str
63
+ model_id: str
64
+ processing_time: float
65
+ total_documents: int
66
+ returned_documents: int
67
+
68
+
69
+ # -------------------------
70
+ # Model Management
71
+ # -------------------------
72
+
73
+ class RerankerModel:
74
+ """Base class for reranking models."""
75
+
76
+ def __init__(self, model_id: str, model_name: str, model_type: str):
77
+ self.model_id = model_id
78
+ self.model_name = model_name
79
+ self.model_type = model_type
80
+ self.model = None
81
+ self.tokenizer = None
82
+ self.loaded = False
83
+
84
+ def load(self):
85
+ """Load the model. To be implemented by subclasses."""
86
+ raise NotImplementedError
87
+
88
+ def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
89
+ """Rerank documents. To be implemented by subclasses."""
90
+ raise NotImplementedError
91
+
92
+
93
+ class SentenceTransformersReranker(RerankerModel):
94
+ """Reranker using sentence-transformers CrossEncoder."""
95
+
96
+ def load(self):
97
+ """Load sentence-transformers CrossEncoder model."""
98
+ try:
99
+ logger.info(f"Loading SentenceTransformers model: {self.model_name}")
100
+ self.model = CrossEncoder(
101
+ self.model_name,
102
+ model_kwargs={"torch_dtype": "auto"},
103
+ trust_remote_code=True
104
+ )
105
+ self.loaded = True
106
+ logger.success(f"Successfully loaded {self.model_id}")
107
+ except Exception as e:
108
+ logger.error(f"Failed to load {self.model_id}: {e}")
109
+ raise
110
+
111
+ def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
112
+ """Rerank documents using CrossEncoder."""
113
+ if not self.loaded:
114
+ raise RuntimeError(f"Model {self.model_id} not loaded")
115
+
116
+ try:
117
+ # For sentence-transformers, we can use the rank method directly
118
+ rankings = self.model.rank(query, documents, convert_to_tensor=True)
119
+
120
+ # Extract scores and maintain original order
121
+ scores = [0.0] * len(documents)
122
+ for ranking in rankings:
123
+ scores[ranking['corpus_id']] = float(ranking['score'])
124
+
125
+ return scores
126
+
127
+ except Exception as e:
128
+ logger.error(f"Reranking failed with {self.model_id}: {e}")
129
+ raise
130
+
131
+
132
+ class QwenReranker(RerankerModel):
133
+ """Reranker using Qwen3-Reranker model."""
134
+
135
+ def load(self):
136
+ """Load Qwen reranker model."""
137
+ try:
138
+ logger.info(f"Loading Qwen model: {self.model_name}")
139
+
140
+ self.tokenizer = AutoTokenizer.from_pretrained(
141
+ self.model_name,
142
+ padding_side='left'
143
+ )
144
+ self.model = AutoModelForCausalLM.from_pretrained(
145
+ self.model_name
146
+ ).eval()
147
+
148
+ # Set up Qwen-specific tokens
149
+ self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
150
+ self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
151
+ self.max_length = 8192
152
+
153
+ # Set up prompt templates
154
+ self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
155
+ self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
156
+ self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
157
+ self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
158
+
159
+ self.loaded = True
160
+ logger.success(f"Successfully loaded {self.model_id}")
161
+
162
+ except Exception as e:
163
+ logger.error(f"Failed to load {self.model_id}: {e}")
164
+ raise
165
+
166
+ def _format_instruction(self, instruction: str, query: str, doc: str) -> str:
167
+ """Format instruction for Qwen model."""
168
+ if instruction is None:
169
+ instruction = 'Given a web search query, retrieve relevant passages that answer the query'
170
+
171
+ return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
172
+ instruction=instruction, query=query, doc=doc
173
+ )
174
+
175
+ def _process_inputs(self, pairs: List[str]):
176
+ """Process input pairs for Qwen model."""
177
+ inputs = self.tokenizer(
178
+ pairs,
179
+ padding=False,
180
+ truncation='longest_first',
181
+ return_attention_mask=False,
182
+ max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
183
+ )
184
+
185
+ for i, ele in enumerate(inputs['input_ids']):
186
+ inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
187
+
188
+ inputs = self.tokenizer.pad(
189
+ inputs,
190
+ padding=True,
191
+ return_tensors="pt",
192
+ max_length=self.max_length
193
+ )
194
+
195
+ for key in inputs:
196
+ inputs[key] = inputs[key].to(self.model.device)
197
+
198
+ return inputs
199
+
200
+ @torch.no_grad()
201
+ def _compute_logits(self, inputs):
202
+ """Compute relevance scores from model logits."""
203
+ batch_scores = self.model(**inputs).logits[:, -1, :]
204
+ true_vector = batch_scores[:, self.token_true_id]
205
+ false_vector = batch_scores[:, self.token_false_id]
206
+ batch_scores = torch.stack([false_vector, true_vector], dim=1)
207
+ batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
208
+ scores = batch_scores[:, 1].exp().tolist()
209
+ return scores
210
+
211
+ def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
212
+ """Rerank documents using Qwen model."""
213
+ if not self.loaded:
214
+ raise RuntimeError(f"Model {self.model_id} not loaded")
215
+
216
+ try:
217
+ # Format instruction pairs
218
+ pairs = [
219
+ self._format_instruction(instruction, query, doc)
220
+ for doc in documents
221
+ ]
222
+
223
+ # Process inputs
224
+ inputs = self._process_inputs(pairs)
225
+
226
+ # Compute scores
227
+ scores = self._compute_logits(inputs)
228
+
229
+ return scores
230
+
231
+ except Exception as e:
232
+ logger.error(f"Reranking failed with {self.model_id}: {e}")
233
+ raise
234
+
235
+
236
+ class ModelManager:
237
+ """Manager for reranking models with preloading."""
238
+
239
+ def __init__(self):
240
+ self.models: Dict[str, RerankerModel] = {}
241
+ self.model_configs = {
242
+ "jina-reranker-v2": {
243
+ "model_name": "jinaai/jina-reranker-v2-base-multilingual",
244
+ "model_type": "sentence_transformers",
245
+ "description": "Multilingual reranker from Jina AI"
246
+ },
247
+ "bge-reranker-v2": {
248
+ "model_name": "BAAI/bge-reranker-v2-m3",
249
+ "model_type": "sentence_transformers",
250
+ "description": "BGE multilingual reranker"
251
+ },
252
+ "qwen3-reranker": {
253
+ "model_name": "Qwen/Qwen3-Reranker-0.6B",
254
+ "model_type": "qwen",
255
+ "description": "Qwen3 instruction-based reranker"
256
+ }
257
+ }
258
+
259
+ async def preload_all_models(self):
260
+ """Preload all configured models."""
261
+ logger.info(f"Starting preload of {len(self.model_configs)} reranking models...")
262
+
263
+ for model_id, config in self.model_configs.items():
264
+ try:
265
+ logger.info(f"Loading {model_id}...")
266
+
267
+ if config["model_type"] == "sentence_transformers":
268
+ model = SentenceTransformersReranker(
269
+ model_id=model_id,
270
+ model_name=config["model_name"],
271
+ model_type=config["model_type"]
272
+ )
273
+ elif config["model_type"] == "qwen":
274
+ model = QwenReranker(
275
+ model_id=model_id,
276
+ model_name=config["model_name"],
277
+ model_type=config["model_type"]
278
+ )
279
+ else:
280
+ logger.error(f"Unknown model type: {config['model_type']}")
281
+ continue
282
+
283
+ model.load()
284
+ self.models[model_id] = model
285
+ logger.success(f"Successfully preloaded {model_id}")
286
+
287
+ except Exception as e:
288
+ logger.error(f"Failed to preload {model_id}: {e}")
289
+
290
+ loaded_count = len([m for m in self.models.values() if m.loaded])
291
+ logger.success(f"Preloaded {loaded_count}/{len(self.model_configs)} models successfully")
292
+
293
+ def get_model(self, model_id: str) -> RerankerModel:
294
+ """Get a loaded model by ID."""
295
+ if model_id not in self.models:
296
+ raise ValueError(f"Model {model_id} not found")
297
+
298
+ model = self.models[model_id]
299
+ if not model.loaded:
300
+ raise ValueError(f"Model {model_id} not loaded")
301
+
302
+ return model
303
+
304
+ def list_models(self) -> List[Dict[str, Any]]:
305
+ """List all available models with their status."""
306
+ models_info = []
307
+ for model_id, config in self.model_configs.items():
308
+ model = self.models.get(model_id)
309
+ info = {
310
+ "id": model_id,
311
+ "name": config["model_name"],
312
+ "type": config["model_type"],
313
+ "description": config["description"],
314
+ "loaded": model.loaded if model else False
315
+ }
316
+ models_info.append(info)
317
+ return models_info
318
+
319
+
320
+ # -------------------------
321
+ # Application Setup
322
+ # -------------------------
323
+
324
+ model_manager = None
325
+
326
+ @asynccontextmanager
327
+ async def lifespan(app: FastAPI):
328
+ """Application lifespan manager with model preloading."""
329
+ global model_manager
330
+
331
+ # Startup
332
+ logger.info("Starting reranking API...")
333
+ try:
334
+ model_manager = ModelManager()
335
+ await model_manager.preload_all_models()
336
+ logger.success("Reranking API startup complete!")
337
+ except Exception as e:
338
+ logger.error(f"Failed to initialize models: {e}")
339
+ raise
340
+
341
+ yield
342
+
343
+ # Shutdown
344
+ logger.info("Shutting down reranking API...")
345
+
346
+
347
+ app = FastAPI(
348
+ title="Multi-Model Reranking API",
349
+ description="""
350
+ High-performance API for document reranking using multiple state-of-the-art models.
351
+
352
+ βœ… **Supported Models:**
353
+ - **Jina Reranker V2**: Multilingual reranker optimized for search
354
+ - **BGE Reranker V2**: High-performance multilingual reranking
355
+ - **Qwen3 Reranker**: Instruction-based reranking with reasoning
356
+
357
+ πŸš€ **Features:**
358
+ - Multiple reranking models preloaded at startup
359
+ - Batch document reranking with relevance scoring
360
+ - Optional instruction-based reranking (Qwen3)
361
+ - Comprehensive performance metrics
362
+ - Zero cold start delay
363
+
364
+ πŸ“Š **Input/Output:**
365
+ - Input: Query + documents + optional instruction
366
+ - Output: Ranked documents with relevance scores
367
+ """,
368
+ version="1.0.0",
369
+ lifespan=lifespan
370
+ )
371
+
372
+
373
+ # -------------------------
374
+ # API Endpoints
375
+ # -------------------------
376
+
377
+ @app.post("/rerank", response_model=RerankResponse, tags=["Reranking"])
378
+ async def rerank_documents(request: RerankRequest):
379
+ """
380
+ Rerank documents based on relevance to query.
381
+
382
+ This endpoint takes a query and list of documents, then returns them
383
+ ranked by relevance using the specified reranking model.
384
+
385
+ Args:
386
+ request: RerankRequest containing query, documents, and model info
387
+
388
+ Returns:
389
+ RerankResponse with ranked documents, scores, and metadata
390
+
391
+ Example:
392
+ ```json
393
+ {
394
+ "query": "machine learning algorithms",
395
+ "documents": [
396
+ "Deep learning uses neural networks",
397
+ "Weather forecast for tomorrow",
398
+ "Supervised learning with labeled data"
399
+ ],
400
+ "model_id": "jina-reranker-v2"
401
+ }
402
+ ```
403
+ """
404
+ if not request.query.strip():
405
+ raise HTTPException(400, "Query cannot be empty")
406
+
407
+ if not request.documents:
408
+ raise HTTPException(400, "Documents list cannot be empty")
409
+
410
+ # Filter out empty documents
411
+ valid_docs = [(i, doc.strip()) for i, doc in enumerate(request.documents) if doc.strip()]
412
+ if not valid_docs:
413
+ raise HTTPException(400, "No valid documents found after filtering empty strings")
414
+
415
+ try:
416
+ start_time = time.time()
417
+
418
+ # Get model
419
+ model = model_manager.get_model(request.model_id)
420
+
421
+ # Extract valid documents and their indices
422
+ original_indices, documents = zip(*valid_docs)
423
+
424
+ # Perform reranking
425
+ scores = model.rerank(
426
+ query=request.query.strip(),
427
+ documents=list(documents),
428
+ instruction=request.instruction
429
+ )
430
+
431
+ # Create results with original indices
432
+ results = []
433
+ for i, (orig_idx, doc, score) in enumerate(zip(original_indices, documents, scores)):
434
+ results.append(RerankResult(
435
+ text=doc,
436
+ score=score,
437
+ index=orig_idx
438
+ ))
439
+
440
+ # Sort by score (descending)
441
+ results.sort(key=lambda x: x.score, reverse=True)
442
+
443
+ # Apply top_k limit if specified
444
+ if request.top_k:
445
+ results = results[:request.top_k]
446
+
447
+ processing_time = time.time() - start_time
448
+
449
+ logger.info(
450
+ f"Reranked {len(documents)} documents in {processing_time:.3f}s "
451
+ f"using {request.model_id}"
452
+ )
453
+
454
+ return RerankResponse(
455
+ results=results,
456
+ query=request.query.strip(),
457
+ model_id=request.model_id,
458
+ processing_time=processing_time,
459
+ total_documents=len(request.documents),
460
+ returned_documents=len(results)
461
+ )
462
+
463
+ except ValueError as e:
464
+ raise HTTPException(400, str(e))
465
+ except Exception as e:
466
+ logger.error(f"Reranking failed: {e}")
467
+ raise HTTPException(500, f"Reranking failed: {str(e)}")
468
+
469
+
470
+ @app.get("/models", tags=["Models"])
471
+ async def list_models():
472
+ """
473
+ List all available reranking models.
474
+
475
+ Returns information about all configured models including their
476
+ loading status and capabilities.
477
+
478
+ Returns:
479
+ List of model information dictionaries
480
+ """
481
+ try:
482
+ return model_manager.list_models()
483
+ except Exception as e:
484
+ logger.error(f"Failed to list models: {e}")
485
+ raise HTTPException(500, str(e))
486
+
487
+
488
+ @app.get("/health", tags=["Monitoring"])
489
+ async def health_check():
490
+ """
491
+ Check API health and model status.
492
+
493
+ Returns comprehensive health information including model loading
494
+ status and system metrics.
495
+
496
+ Returns:
497
+ Health status dictionary
498
+ """
499
+ try:
500
+ models = model_manager.list_models()
501
+ loaded_models = [m for m in models if m['loaded']]
502
+
503
+ return {
504
+ "status": "ok",
505
+ "total_models": len(models),
506
+ "loaded_models": len(loaded_models),
507
+ "available_models": [m['id'] for m in loaded_models],
508
+ "models_info": models
509
+ }
510
+ except Exception as e:
511
+ logger.error(f"Health check failed: {e}")
512
+ return {
513
+ "status": "error",
514
+ "error": str(e)
515
+ }
516
 
 
 
 
requirements.txt CHANGED
@@ -1,2 +1,6 @@
1
- fastapi
2
- uvicorn[standard]
 
 
 
 
 
1
+ fastapi==0.116.2
2
+ uvicorn[standard]==0.35.0
3
+ torch==2.8.0
4
+ sentence-transformers==5.1.1
5
+ loguru==0.7.3
6
+ einops==0.8.1