import argparse
import csv
import json
import os
import pickle
import time

import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

def gtr_build_index(encoder, docs):
    """Encode every passage with GTR and cache the embeddings at the GTR_EMB path."""
    with torch.inference_mode():
        embs = encoder.encode(docs, show_progress_bar=True, normalize_embeddings=True)
    # float16 halves the memory footprint of the index.
    embs = embs.astype("float16")
    GTR_EMB = os.environ.get("GTR_EMB")
    if not GTR_EMB:
        # Fail early with a clear message instead of letting open(None) crash below.
        raise ValueError("set the GTR_EMB environment variable to the embedding cache path")
    with open(GTR_EMB, "wb") as f:
        pickle.dump(embs, f)
    return embs
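
# A minimal usage sketch for building the cache ahead of time (the path and
# the smaller gtr-t5-base checkpoint are illustrative assumptions):
#
#   os.environ["GTR_EMB"] = "/tmp/gtr_wikipedia_index.pkl"
#   encoder = SentenceTransformer("sentence-transformers/gtr-t5-base")
#   embs = gtr_build_index(encoder, ["first passage text", "second passage text"])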

class DPRRetriever:
    def __init__(self, DPR_WIKI_TSV, GTR_EMB=None, emb_model="sentence-transformers/gtr-t5-xxl") -> None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.encoder = SentenceTransformer(emb_model, device=device)
        self.docs = []
        print("loading wikipedia file...")
        with open(DPR_WIKI_TSV) as f:
            reader = csv.reader(f, delimiter="\t")
            for i, row in enumerate(reader):
                if i == 0:
                    # Skip the header row (id, text, title).
                    continue
                # Store each passage as "title\ntext" so both fields can be recovered later.
                self.docs.append(row[2] + "\n" + row[1])
        if not GTR_EMB:
            print("gtr embeddings not found, building...")
            embs = gtr_build_index(self.encoder, self.docs)
        else:
            print("gtr embeddings found, loading...")
            with open(GTR_EMB, "rb") as f:
                embs = pickle.load(f)
        self.gtr_emb = torch.tensor(embs, dtype=torch.float16, device=device)

    def retrieve(self, question, topk):
        with torch.inference_mode():
            query = self.encoder.encode(question, batch_size=4, show_progress_bar=True, normalize_embeddings=True)
            query = torch.tensor(query, dtype=torch.float16, device=self.device)
            # Embeddings are L2-normalized, so the inner product is cosine similarity.
            scores = torch.matmul(self.gtr_emb, query)
            score, idx = torch.topk(scores, topk)
        ret = []
        for i in range(idx.size(0)):
            # Split on the first newline only, in case the passage text itself contains one.
            title, text = self.docs[idx[i].item()].split("\n", 1)
            ret.append({"id": str(idx[i].item() + 1), "title": title, "text": text, "score": score[i].item()})
        return ret

    def __repr__(self) -> str:
        return 'DPR Retriever'

    def __str__(self) -> str:
        return repr(self)
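
# Usage sketch (illustrative paths; psgs_w100.tsv is the standard DPR
# Wikipedia dump whose id/text/title columns __init__ assumes):
#
#   retriever = DPRRetriever("psgs_w100.tsv", GTR_EMB="/tmp/gtr_wikipedia_index.pkl")
#   hits = retriever.retrieve("who wrote the iliad?", topk=5)
#   for hit in hits:
#       print(hit["score"], hit["title"])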

class BM25Retriever:
    def __init__(self, DPR_WIKI_TSV):
        self.docs = []
        self.tokenized_docs = []
        print("loading wikipedia file...")
        with open(DPR_WIKI_TSV) as f:
            reader = csv.reader(f, delimiter="\t")
            for i, row in enumerate(reader):
                if i == 0:
                    # Skip the header row (id, text, title).
                    continue
                self.docs.append(row[2] + "\n" + row[1])
                # BM25Okapi expects pre-tokenized documents; simple whitespace tokenization here.
                self.tokenized_docs.append((row[2] + " " + row[1]).split())
        print("BM25 index building...")
        self.bm25 = BM25Okapi(self.tokenized_docs)

    def retrieve(self, question, topk):
        query = question.split()
        scores = self.bm25.get_scores(query)
        # Indices of the topk highest-scoring passages, best first.
        topk_indices = scores.argsort()[-topk:][::-1]
        ret = []
        for idx in topk_indices:
            title, text = self.docs[idx].split("\n", 1)
            # Cast the numpy score to a plain float so results are JSON-serializable.
            ret.append({"id": str(idx + 1), "title": title, "text": text, "score": float(scores[idx])})
        return ret

    def __repr__(self) -> str:
        return 'BM25 Retriever'

    def __str__(self) -> str:
        return repr(self)
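
# A minimal CLI sketch tying the two retrievers together; the flag names and
# defaults are illustrative assumptions, not part of the original interface.
# It also exercises the argparse/json/time imports above.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Query a BM25 or GTR retriever over the DPR Wikipedia dump")
    parser.add_argument("--dpr_wiki_tsv", required=True, help="path to psgs_w100.tsv")
    parser.add_argument("--gtr_emb", default=None, help="optional cached GTR embedding pickle")
    parser.add_argument("--retriever", choices=["bm25", "gtr"], default="bm25")
    parser.add_argument("--query", required=True)
    parser.add_argument("--topk", type=int, default=5)
    args = parser.parse_args()

    if args.retriever == "bm25":
        retriever = BM25Retriever(args.dpr_wiki_tsv)
    else:
        retriever = DPRRetriever(args.dpr_wiki_tsv, GTR_EMB=args.gtr_emb)

    start = time.time()
    hits = retriever.retrieve(args.query, topk=args.topk)
    print(f"{retriever}: top {len(hits)} hits in {time.time() - start:.1f}s")
    print(json.dumps(hits, indent=2))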