import csv
import os
import pickle

import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer


def gtr_build_index(encoder, docs):
    """Encode every passage with the GTR encoder and cache the embeddings.

    The cache path is read from the GTR_EMB environment variable; if it is
    unset, the embeddings are returned without being written to disk.
    """
    with torch.inference_mode():
        embs = encoder.encode(docs, show_progress_bar=True, normalize_embeddings=True)
    embs = embs.astype("float16")  # halve the memory footprint before caching
    cache_path = os.environ.get("GTR_EMB")
    if cache_path:
        with open(cache_path, "wb") as f:
            pickle.dump(embs, f)
    return embs
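
# Note: gtr_build_index picks up its cache path from the GTR_EMB environment
# variable. A typical first run might look like this (the paths below are
# illustrative placeholders, not files shipped with this repo):
#
#   os.environ["GTR_EMB"] = "gtr_wikipedia_index.pkl"
#   retriever = DPRRetriever("psgs_w100.tsv")  # encodes and caches on first use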


class DPRRetriever:
    def __init__(self, DPR_WIKI_TSV, GTR_EMB=None, emb_model="sentence-transformers/gtr-t5-xxl") -> None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.encoder = SentenceTransformer(emb_model, device=device)
        self.docs = []
        print("loading wikipedia file...")
        with open(DPR_WIKI_TSV) as f:
            reader = csv.reader(f, delimiter="\t")
            for i, row in enumerate(reader):
                if i == 0:  # skip the TSV header row
                    continue
                # DPR wiki TSV columns are (id, text, title); store "title\ntext"
                self.docs.append(row[2] + "\n" + row[1])
        if not GTR_EMB:
            print("gtr embeddings not found, building...")
            embs = gtr_build_index(self.encoder, self.docs)
        else:
            print("gtr embeddings found, loading...")
            with open(GTR_EMB, "rb") as f:
                embs = pickle.load(f)
        self.gtr_emb = torch.tensor(embs, dtype=torch.float16, device=device)

    def retrieve(self, question, topk):
        with torch.inference_mode():
            query = self.encoder.encode(question, batch_size=4, show_progress_bar=True, normalize_embeddings=True)
        query = torch.tensor(query, dtype=torch.float16, device=self.device)
        # Both sides are L2-normalized, so this dot product is cosine similarity.
        scores = torch.matmul(self.gtr_emb, query)
        score, idx = torch.topk(scores, topk)
        ret = []
        for i in range(idx.size(0)):
            # split("\n", 1) keeps any newlines inside the passage text intact
            title, text = self.docs[idx[i].item()].split("\n", 1)
            ret.append({"id": str(idx[i].item() + 1), "title": title, "text": text, "score": score[i].item()})
        return ret

    def __repr__(self) -> str:
        return 'DPR Retriever'

    def __str__(self) -> str:
        return repr(self)


class BM25Retriever:
    def __init__(self, DPR_WIKI_TSV):
        self.docs = []
        self.tokenized_docs = []
        print("loading wikipedia file...")
        with open(DPR_WIKI_TSV) as f:
            reader = csv.reader(f, delimiter="\t")
            for i, row in enumerate(reader):
                if i == 0:  # skip the TSV header row
                    continue
                self.docs.append(row[2] + "\n" + row[1])
                # BM25Okapi expects pre-tokenized documents; a plain
                # whitespace split of "title text" is used here.
                self.tokenized_docs.append((row[2] + " " + row[1]).split())
        print("BM25 index building...")
        self.bm25 = BM25Okapi(self.tokenized_docs)

    def retrieve(self, question, topk):
        query = question.split()
        scores = self.bm25.get_scores(query)
        # argsort is ascending, so take the last topk indices and reverse them
        topk_indices = scores.argsort()[-topk:][::-1]
        ret = []
        for idx in topk_indices:
            title, text = self.docs[idx].split("\n", 1)
            ret.append({"id": str(idx + 1), "title": title, "text": text, "score": float(scores[idx])})
        return ret

    def __repr__(self) -> str:
        return 'BM25 Retriever'

    def __str__(self) -> str:
        return repr(self)
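

# Minimal usage sketch. The TSV path and pickle path below are illustrative
# placeholders (the standard DPR Wikipedia dump and a previously built GTR
# embedding cache), not files shipped with this repo.
if __name__ == "__main__":
    bm25 = BM25Retriever("psgs_w100.tsv")
    for hit in bm25.retrieve("who wrote the declaration of independence", topk=3):
        print(round(hit["score"], 2), hit["title"])

    dpr = DPRRetriever("psgs_w100.tsv", GTR_EMB="gtr_wikipedia_index.pkl")
    for hit in dpr.retrieve("who wrote the declaration of independence", topk=3):
        print(round(hit["score"], 2), hit["title"])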