import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np

from pyserini.encode import QueryEncoder


class SpladeQueryEncoder(QueryEncoder):
    def __init__(self, model_name_or_path, tokenizer_name=None, device='cpu'):
        self.device = device
        self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path)
        self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
        # Map vocabulary ids back to token strings so the sparse query vector
        # can be returned as a {token: weight} dictionary.
        self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

    def encode(self, text, max_length=256, **kwargs):
        inputs = self.tokenizer([text], max_length=max_length, padding='longest',
                                truncation=True, add_special_tokens=True,
                                return_tensors='pt').to(self.device)
        input_ids = inputs['input_ids']
        input_attention = inputs['attention_mask']
        batch_logits = self.model(input_ids)['logits']
        # SPLADE aggregation: log(1 + ReLU(logits)), masked by the attention
        # mask, then max-pooled over the sequence dimension to produce one
        # weight per vocabulary term.
        batch_aggregated_logits, _ = torch.max(torch.log(1 + torch.relu(batch_logits))
                                               * input_attention.unsqueeze(-1), dim=1)
        batch_aggregated_logits = batch_aggregated_logits.cpu().detach().numpy()
        return self._output_to_weight_dicts(batch_aggregated_logits)[0]

    def _output_to_weight_dicts(self, batch_aggregated_logits):
        to_return = []
        for aggregated_logits in batch_aggregated_logits:
            # Keep only the non-zero dimensions and translate each vocabulary
            # id into its token string.
            col = np.nonzero(aggregated_logits)[0]
            weights = aggregated_logits[col]
            d = {self.reverse_voc[k]: float(v) for k, v in zip(list(col), list(weights))}
            to_return.append(d)
        return to_return
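

# --- Usage sketch (not part of the original class) ---
# A minimal example of how the encoder above might be driven, assuming a SPLADE
# checkpoint is available on the Hugging Face Hub; the checkpoint name and the
# sample query below are illustrative, not prescribed by the original code.
if __name__ == '__main__':
    encoder = SpladeQueryEncoder('naver/splade-cocondenser-ensembledistil', device='cpu')
    query_vector = encoder.encode('what is a lobster roll?')
    # Show the ten highest-weighted expansion terms produced by SPLADE.
    for term, weight in sorted(query_vector.items(), key=lambda kv: -kv[1])[:10]:
        print(f'{term}\t{weight:.4f}')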