# Page residue from the Hugging Face Space view ("Spaces:", status shown: "Build error").
from __future__ import annotations

import heapq
import math
import os
import os
import pickle
import re
from abc import abstractmethod
from collections import Counter
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
from typing import TypedDict
from typing import Iterable, List, Optional, Type
from typing import Type

import gradio as gr
import nltk
import numpy as np
import pytrec_eval
import tqdm
import tqdm
from matplotlib import pyplot as plt
from nltk.corpus import stopwords as nltk_stopwords
from scipy.sparse._csc import csc_matrix

from nlp4web_codebase.ir.data_loaders import Split
from nlp4web_codebase.ir.data_loaders.dm import Document
from nlp4web_codebase.ir.data_loaders.dm import Document
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
from nlp4web_codebase.ir.models import BaseRetriever

# The stop-word corpus must be present before nltk_stopwords.words(...) is called.
nltk.download("stopwords", quiet=True)
# -*- coding: utf-8 -*-
"""Kopie von HW1 (more instructed).ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1vFJ6AROcCYNkZRIxpyHs9T1sf-bdB_jW
"""
"""## Pre-requisite code

The code within this section will be used in the tasks. Please do not change these code lines.

### SciQ loading and counting
"""
LANGUAGE = "english"
# Tokens are runs of at least two word characters; the pattern is compiled once
# and only its bound `findall` method is kept around.
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))


def word_splitting(text: str) -> List[str]:
    """Lower-case ``text`` and split it into word tokens."""
    lowered = text.lower()
    return word_splitter(lowered)


def lemmatization(words: List[str]) -> List[str]:
    """Identity placeholder: lemmatization is deliberately skipped for simplicity."""
    return words


def simple_tokenize(text: str) -> List[str]:
    """Split ``text`` into lower-cased words, drop stop words, then lemmatize."""
    kept = [word for word in word_splitting(text) if word not in stopwords]
    return lemmatization(kept)
T = TypeVar("T", bound="InvertedIndex")


@dataclass
class PostingList:
    """Postings of a single term, stored as two parallel lists."""

    term: str  # The term
    docid_postings: List[int]  # docid_postings[i] means the docid (int) of the i-th associated posting
    tweight_postings: List[float]  # tweight_postings[i] means the term weight (float) of the i-th associated posting


@dataclass
class InvertedIndex:
    """Term -> postings index plus the docid <-> collection-id mappings."""

    posting_lists: List[PostingList]  # tid -> posting_list
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle the whole index to ``<output_dir>/index.pkl``, creating the directory if needed."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load an index previously written by :meth:`save` from ``<saved_dir>/index.pkl``.

        Security: unpickling can execute arbitrary code -- only load index files you trust.
        """
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
# The output of the counting function:
@dataclass
class Counting:
    """Corpus statistics produced by ``run_counting``."""

    posting_lists: List[PostingList]  # tid -> posting list; weights are raw term frequencies
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length
    avgdl: float  # average document length (in tokens)
    nterms: int  # total number of tokens over all counted documents
    doc_texts: Optional[List[str]] = None  # docid -> raw text, if it was stored
def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Count TFs, DFs, document lengths, etc. over ``documents``.

    Documents whose ``collection_id`` was already seen are skipped, so docids are
    assigned densely (0, 1, 2, ...) in first-seen order.

    Args:
        documents: documents to count; each must expose ``collection_id`` and ``text``.
        tokenize_fn: maps a text to its list of tokens.
        store_raw: keep each document's raw text in ``Counting.doc_texts``.
        ndocs: expected number of documents (only used as the progress-bar total).
        show_progress_bar: display a tqdm progress bar.

    Returns:
        A ``Counting`` whose posting lists hold raw term frequencies.
        ``avgdl`` is 0.0 when no documents were counted.
    """
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms = 0
    doc_texts: Optional[List[str]] = [] if store_raw else None
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue
        docid = len(collection_ids)
        cid2docid[doc.collection_id] = docid
        collection_ids.append(doc.collection_id)
        tok2tf = Counter(tokenize_fn(doc.text))
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok)
            if tid is None:
                tid = len(vocab)
                vocab[tok] = tid
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                dfs.append(0)
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            # Each (document, term) pair is visited exactly once, so this counts
            # documents. (Earlier code started a new term at 0 instead of 1, making
            # every df one too small; reference scores quoted in this notebook may
            # have been produced with that off-by-one.)
            dfs[tid] += 1
        if doc_texts is not None:
            doc_texts.append(doc.text)
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=sum(dls) / len(dls) if dls else 0.0,
        nterms=nterms,
        doc_texts=doc_texts,
    )
# Load SciQ once; `counting` holds corpus statistics under the default tokenizer.
sciq = load_sciq()
counting = run_counting(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
)
"""### BM25 Index"""
@dataclass
class BM25Index(InvertedIndex):
    """Inverted index whose postings cache precomputed BM25 term weights.

    After :meth:`build_from_documents`, ``tweight_postings[i]`` of term ``t`` holds
    ``idf(t) * tf / (tf + k1 * (1 - b + b * dl / avgdl))`` for the i-th document
    containing ``t``, so a query is scored by summing cached weights.
    """

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize queries exactly like documents (``simple_tokenize``)."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
        show_progress_bar: bool = True,
    ) -> None:
        """Replace the raw TFs in ``posting_lists`` with BM25 weights (in place).

        Args:
            posting_lists: tid -> posting list whose ``tweight_postings`` hold raw TFs.
            total_docs: collection size N used in the IDF.
            avgdl: average document length.
            dfs: tid -> document frequency.
            dls: docid -> document length.
            k1: TF saturation parameter.
            b: length-normalisation strength (0 disables it, 1 applies it fully).
            show_progress_bar: display a tqdm progress bar.
        """
        for tid, posting_list in enumerate(
            tqdm.tqdm(
                posting_lists, desc="Regularizing TFs", disable=not show_progress_bar
            )
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=total_docs)
            posting_list.tweight_postings = [
                BM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                * idf
                for docid, tf in zip(
                    posting_list.docid_postings, posting_list.tweight_postings
                )
            ]

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Saturated, length-normalised TF of BM25.

        The textbook ``(k1 + 1)`` numerator factor is omitted: it multiplies every
        weight by the same constant and therefore never changes a ranking.
        """
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """BM25 IDF in the ``log(1 + ...)`` form, which stays positive for ``df <= N``."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        """Count ``documents``, cache BM25 weights and return the index.

        Args:
            documents: documents to index.
            store_raw: keep raw document texts in ``doc_texts``.
            output_dir: if given, the index is also saved there via :meth:`save`.
            ndocs: expected number of documents (progress-bar total only).
            show_progress_bar: display progress bars.
            k1: BM25 TF saturation parameter.
            b: BM25 length-normalisation parameter.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )
        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        cls.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
            show_progress_bar=show_progress_bar,
        )
        # Assembly and (optionally) save:
        index = cls(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        if output_dir is not None:
            index.save(output_dir)
        return index
# Build the BM25 index with default parameters and persist it for the retriever.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),  # progress-bar total only; derived from the corpus, not hard-coded
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")
"""### BM25 Retriever"""
class BaseInvertedIndexRetriever(BaseRetriever):
    """Retriever over a pickled :class:`InvertedIndex` holding cached term weights.

    Subclasses pick the concrete index type through :attr:`index_class`.
    """

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        """Index class whose ``from_saved`` loads the index directory."""

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return ``{token: cached weight}`` for query tokens occurring in document ``cid``.

        Out-of-vocabulary tokens and tokens absent from the document are omitted.
        Raises ``KeyError`` if ``cid`` is not in the index.
        """
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights: Dict[str, float] = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights

    def score(self, query: str, cid: str) -> float:
        """Sum of :meth:`get_term_weights` for ``query`` and document ``cid``.

        NOTE(review): a repeated query token counts once here but once per
        occurrence in :meth:`retrieve`, so the two scores can differ.
        """
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Score every document sharing a token with ``query``; return the best ``topk``.

        Returns ``{collection_id: score}`` in descending score order (ties keep the
        order in which documents were first scored). ``topk <= 0`` yields ``{}``.
        """
        docid2score: Dict[int, float] = {}
        for tok in self.index.tokenize(query):
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                docid2score[docid] = docid2score.get(docid, 0) + tweight
        # nlargest(k, ...) equals sorted(..., reverse=True)[:k] without a full sort.
        top = heapq.nlargest(topk, docid2score.items(), key=lambda pair: pair[1])
        return {self.index.collection_ids[docid]: score for docid, score in top}
class BM25Retriever(BaseInvertedIndexRetriever):
    """Retriever backed by a pickled :class:`BM25Index`."""

    @property
    def index_class(self) -> Type[BM25Index]:
        return BM25Index
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
# A bare expression only displays its value inside a notebook; print it so the
# ranking is also visible when this file runs as a script.
print(bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?"))
r"""# TASK1: tune b and k1 (4 points)

Tune b and k1 on the **dev** split of SciQ using the metric MAP@10. The evaluation function (`evaluate_map`) is provided. Record the values in `plots_k1` and `plots_b`. Do it in a greedy manner: as the influence from b is larger, please first tune b (with k1 fixed to the default value 0.9) and use the best value of b to further tune k1.

$${\displaystyle {\text{score}}(D,Q)=\sum _{i=1}^{n}{\text{IDF}}(q_{i})\cdot {\frac {f(q_{i},D)\cdot (k_{1}+1)}{f(q_{i},D)+k_{1}\cdot \left(1-b+b\cdot {\frac {|D|}{\text{avgdl}}}\right)}}}$$

(`calc_regularized_tf` omits the constant factor $(k_1+1)$; it scales all scores equally and therefore does not change rankings.)
"""
def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float:
    """Return MAP@10 of ``rankings`` on SciQ ``split``, averaged over evaluated queries.

    Args:
        rankings: query_id -> {collection_id: score}, e.g. outputs of ``retrieve``.
        split: SciQ split whose relevance judgements are used (dev by default).

    Returns:
        The mean ``map_cut_10``; NaN if pytrec_eval returns no per-query results.
    """
    metric = "map_cut_10"
    # Fetch the qrels once and hand the same dict to the evaluator.
    qrels = sciq.get_qrels_dict(split)
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, (metric,))
    per_query = evaluator.evaluate(rankings)
    if not per_query:
        # np.mean([]) would also be NaN, but with a RuntimeWarning.
        return float("nan")
    return float(np.mean([scores[metric] for scores in per_query.values()]))
"""Example of using the pre-requisite code:"""
# Loading dataset:
sciq = load_sciq()
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
# Building BM25 index and save:
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),  # progress-bar total only; derived from the corpus, not hard-coded
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")
# Loading index and use BM25 retriever to retrieve:
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
print(bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?"))  # the ranking
# Candidate parameter values ("X") and the dev MAP@10 measured for each ("Y", filled below).
plots_b: Dict[str, List[float]] = dict(
    X=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    Y=[],
)
plots_k1: Dict[str, List[float]] = dict(
    X=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    Y=[],
)
## YOUR_CODE_STARTS_HERE
# Two steps should be involved:
# Step 1. Fix k1 value to the default one 0.9,
# go through all the candidate b values (0, 0.1, ..., 1.0),
# and record in plots_b["Y"] the corresponding performances obtained via evaluate_map;
# Step 2. Fix b to the best one in step 1. and do the same for k1.
# Hint (on using the pre-requisite code):
# - One can use the loaded sciq dataset directly (loaded in the pre-requisite code);
# - One can build bm25_index with `BM25Index.build_from_documents`;
# - One can use BM25Retriever to load the index and perform retrieval on the dev queries
# (dev queries can be obtained via sciq.get_split_queries(Split.dev))

# `sciq` is already loaded above; the dev queries do not depend on (k1, b),
# so they are collected once instead of once per candidate value.
dev_queries = {query.query_id: query.text for query in sciq.get_split_queries(Split.dev)}


def _evaluate_bm25_on_dev(k1: float, b: float) -> float:
    """Build a BM25 index with (k1, b), retrieve every dev query and return MAP@10."""
    index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        ndocs=len(sciq.corpus),
        k1=k1,
        b=b,
        show_progress_bar=True,
    )
    index.save("output/bm25_index")
    retriever = BM25Retriever(index_dir="output/bm25_index")
    rankings: Dict[str, Dict[str, float]] = {
        qid: retriever.retrieve(text) for qid, text in dev_queries.items()
    }
    return evaluate_map(rankings, split=Split.dev)


# Step 1: tune b with k1 fixed to its default value.
fixed_k1 = 0.9
for b in plots_b["X"]:
    print(b)
    score = _evaluate_bm25_on_dev(k1=fixed_k1, b=b)
    plots_b["Y"].append(score)
    print(f"appended {score} to the plots_b list")
tuned_b = plots_b["X"][np.argmax(plots_b["Y"])]
print(f"The best value for b is: {tuned_b}")

# Step 2: tune k1 with b fixed to the best value from step 1.
for k1 in plots_k1["X"]:
    print(k1)
    score = _evaluate_bm25_on_dev(k1=k1, b=tuned_b)
    plots_k1["Y"].append(score)
    print(f"appended {score} to the plots_k1 list")
tuned_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]
print(f"The best value for k1 is: {tuned_k1}")
## YOU_CODE_ENDS_HERE
## TEST_CASES (should be close to 0.8135637188208616 and 0.7512916099773244)
print(plots_k1["Y"][9])
print(plots_b["Y"][1])
## RESULT_CHECKING_POINT
print(plots_k1)
print(plots_b)
# Overlay both tuning curves (dev MAP@10 vs. candidate value) in one figure.
for label, curve in (("b", plots_b), ("k1", plots_k1)):
    plt.plot(curve["X"], curve["Y"], label=label)
plt.ylabel("MAP")
plt.legend()
plt.grid()
plt.show()
"""Let's check the effectiveness gain on test after this tuning on dev."""
# NOTE(review): presumably the test-split MAP@10 of the default (k1=0.9, b=0.4)
# index; it is a fixed reference value and is not recomputed here.
default_map = 0.7849
best_b = plots_b["X"][np.argmax(plots_b["Y"])]
best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),  # progress-bar total only; derived from the corpus, not hard-coded
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
bm25_index.save("output/bm25_index")
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
# Note: evaluated on the *test* split, which was not used for tuning.
rankings = {
    query.query_id: bm25_retriever.retrieve(query=query.text)
    for query in sciq.get_split_queries(Split.test)
}
optimized_map = evaluate_map(rankings, split=Split.test)  # note this is now on test
print(default_map, optimized_map)
r"""# TASK2: CSC matrix and `CSCBM25Index` (12 points)

Recall that we use Python lists to implement posting lists, mapping term IDs to the documents in which they appear. This is inefficient due to its naive design. In fact, a [Compressed Sparse Column matrix](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html) is well suited for storing the posting lists and can improve efficiency.

## TASK2.1: learn about `scipy.sparse.csc_matrix` (2 points)

Convert the matrix

$$
\begin{bmatrix}
0 & 1 & 0 & 3 \\
10 & 2 & 1 & 0 \\
0 & 0 & 0 & 9
\end{bmatrix}
$$

to a `csc_matrix` by specifying `data`, `indices`, `indptr` and `shape`.
"""
input_matrix = [[0, 1, 0, 3], [10, 2, 1, 0], [0, 0, 0, 9]]
data = None
indices = None
indptr = None
shape = None
## YOUR_CODE_STARTS_HERE
# Please assign the values to data, indices, indptr and shape
# One can just do it in a hard-coded manner
# Instead of hard-coding, derive the arrays from input_matrix so they stay in
# sync with it. CSC stores the matrix column by column:
#   data    - the non-zero values in column-major order
#   indices - the row index of each entry of `data`
#   indptr  - column j's entries are data[indptr[j]:indptr[j + 1]]
data, indices, indptr = [], [], [0]
for column in zip(*input_matrix):
    for row, value in enumerate(column):
        if value != 0:
            data.append(value)
            indices.append(row)
    indptr.append(len(data))
shape = (len(input_matrix), len(input_matrix[0]) if input_matrix else 0)
## YOUR_CODE_ENDS_HERE
output_matrix = csc_matrix((data, indices, indptr), shape=shape)
## TEST_CASES (should be 3 and 11)
print((output_matrix.indices + output_matrix.data).tolist()[2])
print((output_matrix.indices + output_matrix.data).tolist()[-1])
## RESULT_CHECKING_POINT
print((output_matrix.indices + output_matrix.data).tolist())
"""## TASK2.2: implement `CSCBM25Index` (4 points)

Implement `CSCBM25Index` by completing the missing code. Note that `CSCInvertedIndex` is similar to `InvertedIndex`, which we talked about during the class. The main difference is that the posting lists are represented by a CSC sparse matrix.
"""
# Own TypeVar: the module-level `T` is bound to the unrelated `InvertedIndex`.
_CSC_T = TypeVar("_CSC_T", bound="CSCInvertedIndex")


@dataclass
class CSCInvertedIndex:
    """Inverted index whose posting lists live in one CSC sparse matrix.

    Column ``tid`` of ``posting_lists_matrix`` is the posting list of term ``tid``:
    its row indices are docids and its values are term weights.
    """

    posting_lists_matrix: csc_matrix  # (docid, tid) -> term weight
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle the whole index to ``<output_dir>/index.pkl``, creating the directory if needed."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[_CSC_T], saved_dir: str) -> _CSC_T:
        """Load an index previously written by :meth:`save` from ``<saved_dir>/index.pkl``.

        Security: unpickling can execute arbitrary code -- only load index files you trust.
        """
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
@dataclass
class CSCBM25Index(CSCInvertedIndex):
    """BM25 index whose cached term weights are stored in a CSC matrix.

    Entry ``(docid, tid)`` of ``posting_lists_matrix`` holds, as float32,
    ``idf(tid) * tf / (tf + k1 * (1 - b + b * dl / avgdl))``.
    """

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize queries exactly like documents (``simple_tokenize``)."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
        show_progress_bar: bool = True,
    ) -> csc_matrix:
        """Compute BM25 weights and pack them into a ``(total_docs, n_terms)`` CSC matrix.

        Args:
            posting_lists: tid -> posting list whose ``tweight_postings`` hold raw TFs.
            total_docs: collection size N (also the number of matrix rows).
            avgdl: average document length.
            dfs: tid -> document frequency.
            dls: docid -> document length.
            k1: TF saturation parameter.
            b: length-normalisation strength.
            show_progress_bar: display a tqdm progress bar.
        """
        ## YOUR_CODE_STARTS_HERE
        data: List[float] = []
        indices: List[int] = []
        indptr: List[int] = [0]  # column tid spans data[indptr[tid]:indptr[tid + 1]]
        for tid, posting_list in enumerate(
            tqdm.tqdm(
                posting_lists, desc="Regularizing TFs", disable=not show_progress_bar
            )
        ):
            idf = CSCBM25Index.calc_idf(df=dfs[tid], N=total_docs)
            for docid, tf in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                regularized_tf = CSCBM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                data.append(regularized_tf * idf)
                indices.append(docid)
            indptr.append(len(data))
        # An explicit shape keeps one row per document even when the last
        # documents contain no indexed term (inference from `indices` would drop
        # those rows) and also lets an index without any postings be built.
        return csc_matrix(
            (
                np.array(data, dtype=np.float32),
                np.array(indices, dtype=np.int32),
                np.array(indptr),
            ),
            shape=(total_docs, len(posting_lists)),
        )
        ## YOUR_CODE_ENDS_HERE

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Saturated, length-normalised TF of BM25 (constant ``(k1 + 1)`` factor omitted)."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """BM25 IDF in the ``log(1 + ...)`` form, which stays positive for ``df <= N``."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        """Count ``documents``, build the CSC weight matrix and return the index.

        Args:
            documents: documents to index.
            store_raw: keep raw document texts in ``doc_texts``.
            output_dir: if given, the index is also saved there via :meth:`save`.
            ndocs: expected number of documents (progress-bar total only).
            show_progress_bar: display progress bars.
            k1: BM25 TF saturation parameter.
            b: BM25 length-normalisation parameter.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )
        # Compute term weights and caching:
        posting_lists_matrix = cls.cache_term_weights(
            posting_lists=counting.posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
            show_progress_bar=show_progress_bar,
        )
        # Assembly and (optionally) save:
        index = cls(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        if output_dir is not None:
            index.save(output_dir)
        return index
# Build the CSC-based BM25 index with the tuned parameters and persist it.
csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),  # progress-bar total only; derived from the corpus, not hard-coded
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
csc_bm25_index.save("output/csc_bm25_index")
## TEST_CASES (should be 7 and 95)
_csc_index_size = os.path.getsize("output/csc_bm25_index/index.pkl")
print(len(str(_csc_index_size)))
print(_csc_index_size // int(1e5))
## RESULT_CHECKING_POINT
print(_csc_index_size)
"""We can compare the size of the CSC-based index with the Python-list-based index:"""
print(os.path.getsize("output/bm25_index/index.pkl"))
"""## TASK2.3: implement `BaseCSCInvertedIndexRetriever` (6 points)

Implement `BaseCSCInvertedIndexRetriever` by completing the missing code.
"""
class BaseCSCInvertedIndexRetriever(BaseRetriever):
    """Retriever over a pickled :class:`CSCInvertedIndex`.

    A term's posting list is read straight from the CSC arrays
    (``indptr``/``indices``/``data``) instead of densifying a whole column.
    """

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        """Index class whose ``from_saved`` loads the index directory."""

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def _posting_column(self, tid: int):
        """Return ``(docids, weights)`` arrays of the entries stored in column ``tid``.

        Assumes no duplicate (docid, tid) entries, which holds for matrices built
        from posting lists with one posting per document.
        """
        matrix = self.index.posting_lists_matrix
        start, end = matrix.indptr[tid], matrix.indptr[tid + 1]
        return matrix.indices[start:end], matrix.data[start:end]

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return ``{token: cached weight}`` for query tokens with a positive weight in ``cid``.

        Raises ``KeyError`` if ``cid`` is not in the index.
        """
        ## YOUR_CODE_STARTS_HERE
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights: Dict[str, float] = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            docids, weights = self._posting_column(tid)
            matches = weights[docids == target_docid]
            if matches.size and matches[0] > 0:
                term_weights[tok] = matches[0]
        return term_weights
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        """Sum of :meth:`get_term_weights` for ``query`` and document ``cid``."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Score documents sharing a token with ``query``; return the best ``topk``.

        Returns ``{collection_id: score}`` in descending score order. Only stored
        entries are visited, so the cost is O(df) per query term rather than
        O(number of documents). ``topk <= 0`` yields ``{}``.
        """
        ## YOUR_CODE_STARTS_HERE
        docid2score: Dict[int, float] = {}
        for tok in self.index.tokenize(query):
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            docids, weights = self._posting_column(tid)
            for docid, tweight in zip(docids.tolist(), weights):
                if tweight > 0:
                    docid2score[docid] = docid2score.get(docid, 0) + tweight
        # nlargest(k, ...) equals sorted(..., reverse=True)[:k] without a full sort.
        top = heapq.nlargest(topk, docid2score.items(), key=lambda pair: pair[1])
        return {self.index.collection_ids[docid]: score for docid, score in top}
        ## YOUR_CODE_ENDS_HERE
class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
    """Retriever backed by a pickled :class:`CSCBM25Index`."""

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        return CSCBM25Index
## TEST_CASES (should be close to
# {'theory': 3.1838157176971436, 'evolution': 3.488086223602295, 'natural': 2.629807710647583, 'selection': 3.552377462387085}
# {'train-11632': 16.241527557373047, 'train-10931': 13.352127075195312, 'train-2006': 12.854086875915527, 'train-7040': 12.690572738647461, 'train-1719': 11.01913833618164, 'train-9875': 10.886155128479004, 'train-1971': 10.796306610107422, 'train-9882': 10.535819053649902, 'train-2018': 10.481085777282715, 'test-586': 10.478515625}
#)
csc_bm25_retriever = CSCBM25Retriever(index_dir="output/csc_bm25_index")
query = "Who proposed the theory of evolution by natural selection?"
print(csc_bm25_retriever.get_term_weights(query=query, cid="train-2006"))
print(csc_bm25_retriever.retrieve(query=query))
## RESULT_CHECKING_POINT
# Each checking point reloads the retriever so it can also run on its own.
csc_bm25_retriever = CSCBM25Retriever(index_dir="output/csc_bm25_index")
query = "What are the differences between immunodeficiency and autoimmune diseases?"
print(csc_bm25_retriever.get_term_weights(query=query, cid="train-1691"))
print(csc_bm25_retriever.retrieve(query=query))
"""# TASK3: a search-engine demo based on Hugging Face Spaces (4 points)

## TASK3.1: create the gradio app (2 points)

Create a gradio app to demo the BM25 search engine index on SciQ. The app should have a single input variable for the query (of type `str`) and a single output variable for the returned ranking (of type `List[Hit]` in the code below). Please use the BM25 system with default k1 and b values.

Hint: it should use a "search" function of signature:

```python
def search(query: str) -> List[Hit]:
    ...
```
"""
class Hit(TypedDict):
    """One search result shown by the demo."""

    cid: str  # collection id of the document
    score: float  # BM25 score of the document for the query
    text: str  # raw document text


demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE
# Build the default-parameter index (k1=0.9, b=0.4) once at start-up; the
# previous `search` rebuilt, saved and reloaded the whole index on every query.
_demo_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
)
_demo_index.save("output/bm25_index")
_demo_retriever = BM25Retriever(index_dir="output/bm25_index")


def search(query: str) -> List[Hit]:
    """Return the BM25 top-10 hits for ``query`` with their scores and texts."""
    ranking = _demo_retriever.retrieve(query)
    index = _demo_retriever.index
    return [
        Hit(cid=cid, score=score, text=index.doc_texts[index.cid2docid[cid]])
        for cid, score in ranking.items()
    ]


demo = gr.Interface(
    fn=search,
    inputs="text",
    # JSON renders the List[Hit] structure; a Textbox would only show its repr.
    outputs=gr.JSON(),
    title="Search-engine demo",
    description="Please enter your search query",
)
## YOUR_CODE_ENDS_HERE
demo.launch()