import torch
from typing import List, Optional
from loguru import logger
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM

from .base import RerankerModel


class SentenceTransformersReranker(RerankerModel):
    """
    Reranker using sentence-transformers CrossEncoder.

    This class leverages the CrossEncoder model from the sentence-transformers
    library to score the relevance of documents given a query. It is suitable
    for reranking tasks in information retrieval pipelines.

    Attributes:
        model_name (str): Name or path of the model to load.
        model (CrossEncoder): The loaded CrossEncoder model instance.
        loaded (bool): Whether the model has been loaded.
        model_id (str): Unique identifier for the model instance.
    """

    def load(self):
        """
        Load the sentence-transformers CrossEncoder model.

        Loads the CrossEncoder model specified by ``self.model_name`` and sets
        ``self.loaded`` to True on success.

        Raises:
            Exception: If the model fails to load (the error is logged and
                re-raised unchanged).
        """
        try:
            logger.info(f"Loading SentenceTransformers model: {self.model_name}")
            # NOTE(review): trust_remote_code=True runs code shipped with the
            # model repository; only point model_name at trusted sources.
            self.model = CrossEncoder(
                self.model_name,
                model_kwargs={"torch_dtype": "auto"},
                trust_remote_code=True
            )
            self.loaded = True
            logger.success(f"Successfully loaded {self.model_id}")
        except Exception as e:
            logger.error(f"Failed to load {self.model_id}: {e}")
            raise

    def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
        """
        Rerank documents using the CrossEncoder model.

        Args:
            query (str): The search query string.
            documents (List[str]): List of documents to be reranked.
            instruction (Optional[str]): Additional instruction for reranking
                (not used in this implementation).

        Returns:
            List[float]: One relevance score per document, in the same order
                as ``documents``. An empty list if ``documents`` is empty.

        Raises:
            RuntimeError: If the model is not loaded.
            Exception: If reranking fails (logged and re-raised).
        """
        if not self.loaded:
            raise RuntimeError(f"Model {self.model_id} not loaded")
        if not documents:
            # Nothing to score; don't hand an empty batch to the model.
            return []

        try:
            rankings = self.model.rank(query, documents, convert_to_tensor=True)
            # Each ranking entry carries the index of its document in
            # 'corpus_id'; scatter the scores back into input order.
            scores = [0.0] * len(documents)
            for ranking in rankings:
                scores[ranking['corpus_id']] = float(ranking['score'])
            return scores
        except Exception as e:
            logger.error(f"Reranking failed with {self.model_id}: {e}")
            raise


class QwenReranker(RerankerModel):
    """
    Reranker using Qwen3-Reranker model (LLM-based).

    This class uses a Qwen LLM to judge the relevance of documents to a query
    and instruction. The score for each document is the probability of the
    "yes" token versus the "no" token at the position following the prompt.

    Attributes:
        model_name (str): Name or path of the Qwen model.
        tokenizer (AutoTokenizer): Tokenizer for the Qwen model.
        model (AutoModelForCausalLM): Loaded Qwen model instance.
        loaded (bool): Whether the model has been loaded.
        model_id (str): Unique identifier for the model instance.
        token_false_id (int): Token ID for "no".
        token_true_id (int): Token ID for "yes".
        max_length (int): Maximum input token length (prompt included).
        prefix (str): Prompt prefix for the system message.
        suffix (str): Prompt suffix for the assistant message.
        prefix_tokens (List[int]): Tokenized prefix.
        suffix_tokens (List[int]): Tokenized suffix.
        batch_size (int): Maximum number of documents scored per forward pass.
    """

    # Upper bound on documents per forward pass. Each sequence may be up to
    # max_length tokens, so scoring every document in one batch can exhaust
    # memory for large candidate lists.
    batch_size: int = 16

    def load(self):
        """
        Load the Qwen reranker model and tokenizer, and initialize prompt
        templates and special tokens.

        Raises:
            Exception: If the model or tokenizer fails to load (logged and
                re-raised).
        """
        try:
            logger.info(f"Loading Qwen model: {self.model_name}")
            # Left padding keeps every row's last prompt token at position -1,
            # which _compute_logits reads.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name, padding_side='left'
            )
            # NOTE(review): no device_map / .to() is applied, so the model stays
            # on whatever device from_pretrained loads to by default — confirm
            # this is intended.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name
            ).eval()

            # Set up Qwen-specific tokens: the score compares these two logits.
            self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
            self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
            self.max_length = 8192

            # Prompt template expected by Qwen3-Reranker. The suffix opens the
            # assistant turn with an empty <think></think> block so the very
            # next token is the "yes"/"no" judgement.
            self.prefix = (
                "<|im_start|>system\nJudge whether the Document meets the requirements "
                "based on the Query and the Instruct provided. Note that the answer can "
                "only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
            )
            self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
            self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)

            self.loaded = True
            logger.success(f"Successfully loaded {self.model_id}")
        except Exception as e:
            logger.error(f"Failed to load {self.model_id}: {e}")
            raise

    def _format_instruction(self, instruction: Optional[str], query: str, doc: str) -> str:
        """
        Format the user-turn body of the Qwen reranker prompt.

        Args:
            instruction (Optional[str]): The instruction for the reranker. If
                None, a default web-search instruction is used.
            query (str): The search query string.
            doc (str): The document to be evaluated.

        Returns:
            str: The labelled ``<Instruct>`` / ``<Query>`` / ``<Document>`` text
                that goes between the prompt prefix and suffix.
        """
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
            instruction=instruction, query=query, doc=doc
        )

    def _process_inputs(self, pairs: List[str]):
        """
        Tokenize and prepare input pairs for the Qwen model.

        Args:
            pairs (List[str]): List of formatted prompt strings, one per document.

        Returns:
            dict: Tokenized, left-padded input tensors (input_ids and
                attention_mask) on the model's device.
        """
        # Tokenize without padding so the prefix/suffix tokens can be spliced
        # around each sequence first; the truncation budget reserves room for
        # them so the full prompt never exceeds max_length.
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        )
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
        inputs = self.tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
            max_length=self.max_length
        )
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    @torch.no_grad()
    def _compute_logits(self, inputs):
        """
        Compute relevance scores from model logits.

        Args:
            inputs (dict): Tokenized and padded input tensors for the model.

        Returns:
            List[float]: Probability that each document is relevant ("yes"),
                computed as a softmax over the "no"/"yes" logits only.
        """
        # Logits at the last position, i.e. the token that follows the prompt.
        batch_scores = self.model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
        """
        Rerank documents using the Qwen model.

        Documents are scored in chunks of at most ``self.batch_size`` to bound
        peak memory; scores are returned in input order.

        Args:
            query (str): The search query string.
            documents (List[str]): List of documents to be reranked.
            instruction (Optional[str]): Additional instruction for reranking.
                If None, a default web-search instruction is used.

        Returns:
            List[float]: One relevance score per document, in the same order
                as ``documents``. An empty list if ``documents`` is empty.

        Raises:
            RuntimeError: If the model is not loaded.
            Exception: If reranking fails (logged and re-raised).
        """
        if not self.loaded:
            raise RuntimeError(f"Model {self.model_id} not loaded")
        if not documents:
            # Nothing to score; an empty batch would fail in the tokenizer/model.
            return []

        try:
            # Fall back to a single batch if batch_size is unset or non-positive.
            step = self.batch_size if self.batch_size and self.batch_size > 0 else len(documents)
            scores: List[float] = []
            for start in range(0, len(documents), step):
                pairs = [
                    self._format_instruction(instruction, query, doc)
                    for doc in documents[start:start + step]
                ]
                inputs = self._process_inputs(pairs)
                scores.extend(self._compute_logits(inputs))
            return scores
        except Exception as e:
            logger.error(f"Reranking failed with {self.model_id}: {e}")
            raise