"""
DocMind - Retriever Module

Semantic search over arXiv papers using FAISS and sentence-transformers.
"""

import pickle
from pathlib import Path
from typing import Dict, List, Tuple

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer


class PaperRetriever:
    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        index_path: str = "data/faiss_index",
    ):
        """
        Initialize retriever with embedding model and FAISS index.

        Args:
            model_name: HuggingFace sentence-transformer model
            index_path: Directory to save/load FAISS index
        """
        print(f"Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.index_path = Path(index_path)
        self.index_path.mkdir(parents=True, exist_ok=True)

        self.index = None
        self.papers = []
        self.embeddings = None

    def build_index(self, papers: List[Dict]):
        """
        Build FAISS index from papers.

        Args:
            papers: List of paper dictionaries with 'title' and 'abstract'
        """
        print(f"Building index for {len(papers)} papers...")
        self.papers = papers
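
        # Embed the title and abstract together so both contribute to retrieval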
        texts = [
            f"{paper['title']}. {paper['abstract']}"
            for paper in papers
        ]

        print("Generating embeddings...")
        self.embeddings = self.model.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True,
        )
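
        # Inner-product index; with L2-normalized vectors this is cosine similarity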
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)
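
        # normalize_L2 modifies the embeddings in place; the index stores its own copy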
        faiss.normalize_L2(self.embeddings)
        self.index.add(self.embeddings)

        print(f"Index built with {self.index.ntotal} papers")

    def save_index(self, name: str = "papers"):
        """Save FAISS index and metadata."""
        faiss.write_index(self.index, str(self.index_path / f"{name}.index"))

        with open(self.index_path / f"{name}_papers.pkl", 'wb') as f:
            pickle.dump(self.papers, f)

        with open(self.index_path / f"{name}_embeddings.npy", 'wb') as f:
            np.save(f, self.embeddings)

        print(f"Saved index to {self.index_path}/{name}.*")

    def load_index(self, name: str = "papers") -> bool:
        """Load FAISS index and metadata. Returns True on success, False if no index exists."""
        index_file = self.index_path / f"{name}.index"
        if not index_file.exists():
            print(f"No index found at {index_file}")
            return False
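
        # Note: the saved index must have been built with the same embedding model,
        # otherwise query vectors will not match the index dimension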
        self.index = faiss.read_index(str(index_file))

        with open(self.index_path / f"{name}_papers.pkl", 'rb') as f:
            self.papers = pickle.load(f)

        with open(self.index_path / f"{name}_embeddings.npy", 'rb') as f:
            self.embeddings = np.load(f)

        print(f"Loaded index with {len(self.papers)} papers")
        return True

    def search(
        self,
        query: str,
        top_k: int = 5,
    ) -> List[Tuple[Dict, float]]:
        """
        Search for relevant papers.

        Args:
            query: Search query string
            top_k: Number of results to return

        Returns:
            List of (paper_dict, score) tuples
        """
        if self.index is None:
            raise ValueError("Index not built or loaded")
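
        # Embed and normalize the query the same way as the indexed documents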
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_embedding)

        scores, indices = self.index.search(query_embedding, top_k)

        results = []
        for idx, score in zip(indices[0], scores[0]):
            if idx == -1:
                # FAISS pads results with -1 when fewer than top_k vectors are indexed
                continue
            paper = self.papers[idx]
            results.append((paper, float(score)))

        return results

    def get_retrieval_context(
        self,
        query: str,
        top_k: int = 5,
    ) -> str:
        """
        Get formatted context string for LLM consumption.

        Args:
            query: Search query
            top_k: Number of papers to retrieve

        Returns:
            Formatted context string with paper summaries
        """
        results = self.search(query, top_k)

        context = f"Retrieved {len(results)} relevant papers:\n\n"
        for i, (paper, score) in enumerate(results, 1):
            context += f"[{i}] {paper['title']}\n"
            context += f" Authors: {', '.join(paper['authors'][:3])}"
            if len(paper['authors']) > 3:
                context += " et al."
            context += f"\n arXiv ID: {paper['arxiv_id']}\n"
            context += f" Published: {paper['published']}\n"
            context += f" Relevance: {score:.3f}\n"
            context += f" Abstract: {paper['abstract']}\n\n"

        return context


def main():
    """Example: build and test the retriever."""
    from fetch_arxiv_data import ArxivFetcher

    fetcher = ArxivFetcher()
    papers = fetcher.load_papers("arxiv_papers.json")

    if not papers:
        print("No papers found. Run fetch_arxiv_data.py first")
        return
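
    # Build a fresh index here; on subsequent runs load_index() could be used instead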
    retriever = PaperRetriever()
    retriever.build_index(papers)
    retriever.save_index()

    test_queries = [
        "diffusion models for image generation",
        "reinforcement learning from human feedback",
        "large language model alignment",
    ]

    for query in test_queries:
        print(f"\n{'=' * 60}")
        print(f"Query: {query}")
        print('=' * 60)

        results = retriever.search(query, top_k=3)
        for i, (paper, score) in enumerate(results, 1):
            print(f"\n[{i}] Score: {score:.3f}")
            print(f" {paper['title']}")
            print(f" {paper['arxiv_id']}")


if __name__ == "__main__":
    main()