import torch
from typing import List, Optional
from loguru import logger
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM
from .base import RerankerModel
class SentenceTransformersReranker(RerankerModel):
    """
    Reranker using sentence-transformers CrossEncoder.

    Scores the relevance of each document to a query with a CrossEncoder
    model from the sentence-transformers library. Suitable for reranking
    stages in information retrieval pipelines.

    Attributes:
        model_name (str): Name or path of the model to load. Read by
            ``load``; presumably set by ``RerankerModel`` (not visible here).
        model (CrossEncoder): The loaded CrossEncoder instance, set by ``load``.
        loaded (bool): Set to True once ``load`` has succeeded.
        model_id (str): Identifier used in log and error messages;
            presumably set by ``RerankerModel`` (not visible here).
    """

    def load(self):
        """
        Load the sentence-transformers CrossEncoder model.

        Loads the CrossEncoder named by ``self.model_name`` and sets
        ``self.loaded`` to True on success.

        Raises:
            Exception: Whatever the underlying loader raises; it is logged
                and re-raised unchanged.
        """
        try:
            logger.info(f"Loading SentenceTransformers model: {self.model_name}")
            # NOTE(review): trust_remote_code=True lets the model repository
            # execute its own Python code at load time. Only point model_name
            # at sources you trust.
            self.model = CrossEncoder(
                self.model_name,
                model_kwargs={"torch_dtype": "auto"},
                trust_remote_code=True,
            )
            self.loaded = True
            logger.success(f"Successfully loaded {self.model_id}")
        except Exception as e:
            logger.error(f"Failed to load {self.model_id}: {e}")
            raise

    def rerank(
        self,
        query: str,
        documents: List[str],
        instruction: Optional[str] = None,
    ) -> List[float]:
        """
        Score documents against a query using the CrossEncoder model.

        Args:
            query (str): The search query string.
            documents (List[str]): Documents to be scored.
            instruction (Optional[str]): Accepted for interface parity with
                other rerankers; not used by this implementation.

        Returns:
            List[float]: One relevance score per document, in the same order
            as ``documents``. An empty input yields an empty list.

        Raises:
            RuntimeError: If the model is not loaded.
            Exception: If scoring fails; the error is logged and re-raised.
        """
        if not self.loaded:
            raise RuntimeError(f"Model {self.model_id} not loaded")

        # Nothing to score: return early instead of handing the model an
        # empty batch, which it is not guaranteed to accept.
        if not documents:
            return []

        try:
            rankings = self.model.rank(query, documents, convert_to_tensor=True)
            # Each ranking entry carries the document's original index in
            # 'corpus_id'; scatter the scores back so the output lines up
            # with the input order rather than the ranked order.
            scores = [0.0] * len(documents)
            for ranking in rankings:
                scores[ranking['corpus_id']] = float(ranking['score'])
            return scores
        except Exception as e:
            logger.error(f"Reranking failed with {self.model_id}: {e}")
            raise
class QwenReranker(RerankerModel):
    """
    Reranker using Qwen3-Reranker model (LLM-based).

    Uses a Qwen causal LM to judge whether each document is relevant to a
    query under a given instruction. The score for a document is the
    probability of the "yes" token versus the "no" token at the final
    position of the prompt.

    Attributes:
        model_name (str): Name or path of the Qwen model. Read by ``load``;
            presumably set by ``RerankerModel`` (not visible here).
        tokenizer (AutoTokenizer): Tokenizer for the Qwen model (left-padded).
        model (AutoModelForCausalLM): Loaded Qwen model, in eval mode.
        loaded (bool): Set to True once ``load`` has succeeded.
        model_id (str): Identifier used in log and error messages;
            presumably set by ``RerankerModel`` (not visible here).
        token_false_id (int): Token ID for "no".
        token_true_id (int): Token ID for "yes".
        max_length (int): Maximum input token length, prompt frame included.
        prefix (str): Prompt prefix (system message and start of user turn).
        suffix (str): Prompt suffix (end of user turn and start of the
            assistant turn, including an empty think block).
        prefix_tokens (List[int]): Tokenized prefix.
        suffix_tokens (List[int]): Tokenized suffix.
    """

    def load(self):
        """
        Load the Qwen reranker model and tokenizer, and initialize prompt
        templates and the yes/no token IDs.

        Raises:
            Exception: Whatever the underlying loaders raise; it is logged
                and re-raised unchanged.
        """
        try:
            logger.info(f"Loading Qwen model: {self.model_name}")
            # Left padding keeps every sequence's real last token at index -1,
            # which is the position _compute_logits reads.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side='left',
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name
            ).eval()

            # Set up Qwen-specific tokens
            self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
            self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
            self.max_length = 8192

            # Set up prompt templates. The suffix closes the user turn and
            # opens the assistant turn with an empty think block, following
            # the prompt template published with Qwen3-Reranker, so that the
            # very next token is the yes/no answer.
            self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
            self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
            self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)

            self.loaded = True
            logger.success(f"Successfully loaded {self.model_id}")
        except Exception as e:
            logger.error(f"Failed to load {self.model_id}: {e}")
            raise

    def _format_instruction(self, instruction: Optional[str], query: str, doc: str) -> str:
        """
        Build the user-turn text for one (instruction, query, document) triple.

        Args:
            instruction (Optional[str]): Task instruction for the reranker.
                If None, a default web-search instruction is used.
            query (str): The search query string.
            doc (str): The document to be evaluated.

        Returns:
            str: The labelled prompt body, with ``<Instruct>``, ``<Query>``
            and ``<Document>`` fields on separate lines.
        """
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        # The field labels match the names the system prompt refers to
        # ("the Document", "the Query", "the Instruct").
        return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
            instruction=instruction, query=query, doc=doc
        )

    def _process_inputs(self, pairs: List[str]):
        """
        Tokenize prompt bodies, wrap them in the prefix/suffix frame, pad
        them into a batch and move the batch to the model's device.

        Args:
            pairs (List[str]): Formatted prompt bodies, one per document.

        Returns:
            dict: Padded input tensors on ``self.model.device``.
        """
        # Reserve room for the fixed prefix/suffix so the framed sequence
        # never exceeds max_length after truncation.
        body_budget = self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=body_budget,
        )
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens

        # Padding happens only after framing, so the attention mask produced
        # here covers the prefix and suffix tokens as well.
        inputs = self.tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
            max_length=self.max_length,
        )
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    @torch.no_grad()
    def _compute_logits(self, inputs):
        """
        Compute relevance scores from the model's final-position logits.

        Args:
            inputs (dict): Padded input tensors as built by ``_process_inputs``.

        Returns:
            List[float]: For each sequence, the probability of "yes" under a
            two-way softmax over the "no" and "yes" logits.
        """
        # With left padding, index -1 is the last real token of every row.
        batch_scores = self.model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]
        # Column 0 = "no", column 1 = "yes"; normalise over just these two.
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    def rerank(
        self,
        query: str,
        documents: List[str],
        instruction: Optional[str] = None,
    ) -> List[float]:
        """
        Score documents against a query using the Qwen model.

        Args:
            query (str): The search query string.
            documents (List[str]): Documents to be scored.
            instruction (Optional[str]): Task instruction; a default is used
                when None.

        Returns:
            List[float]: One relevance score per document, in the same order
            as ``documents``. An empty input yields an empty list.

        Raises:
            RuntimeError: If the model is not loaded.
            Exception: If scoring fails; the error is logged and re-raised.
        """
        if not self.loaded:
            raise RuntimeError(f"Model {self.model_id} not loaded")

        # Nothing to score: avoid pushing an empty batch through the
        # tokenizer and model.
        if not documents:
            return []

        try:
            pairs = [
                self._format_instruction(instruction, query, doc)
                for doc in documents
            ]
            inputs = self._process_inputs(pairs)
            scores = self._compute_logits(inputs)
            return scores
        except Exception as e:
            logger.error(f"Reranking failed with {self.model_id}: {e}")
            raise