"""
DocMind - Retriever Module
Semantic search over arXiv papers using FAISS and sentence-transformers
"""
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Tuple
import pickle
from pathlib import Path
class PaperRetriever:
def __init__(
self,
model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
index_path: str = "data/faiss_index"
):
"""
Initialize retriever with embedding model and FAISS index
Args:
model_name: HuggingFace sentence-transformer model
index_path: Directory to save/load FAISS index
"""
print(f"Loading embedding model: {model_name}")
self.model = SentenceTransformer(model_name)
self.index_path = Path(index_path)
self.index_path.mkdir(parents=True, exist_ok=True)
self.index = None
self.papers = []
self.embeddings = None
def build_index(self, papers: List[Dict]):
"""
Build FAISS index from papers
Args:
papers: List of paper dictionaries with 'title' and 'abstract'
"""
print(f"Building index for {len(papers)} papers...")
self.papers = papers
# Create text to embed: title + abstract
texts = [
f"{paper['title']}. {paper['abstract']}"
for paper in papers
]
# Generate embeddings
print("Generating embeddings...")
self.embeddings = self.model.encode(
texts,
show_progress_bar=True,
convert_to_numpy=True
)
# Build FAISS index
dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # Inner product; equals cosine similarity once vectors are L2-normalized
# Normalize embeddings for cosine similarity
faiss.normalize_L2(self.embeddings)
self.index.add(self.embeddings)
print(f"Index built with {self.index.ntotal} papers")
def save_index(self, name: str = "papers"):
"""Save FAISS index and metadata"""
faiss.write_index(self.index, str(self.index_path / f"{name}.index"))
with open(self.index_path / f"{name}_papers.pkl", 'wb') as f:
pickle.dump(self.papers, f)
with open(self.index_path / f"{name}_embeddings.npy", 'wb') as f:
np.save(f, self.embeddings)
print(f"Saved index to {self.index_path}/{name}.*")
def load_index(self, name: str = "papers"):
"""Load FAISS index and metadata"""
index_file = self.index_path / f"{name}.index"
if not index_file.exists():
print(f"No index found at {index_file}")
return False
self.index = faiss.read_index(str(index_file))
with open(self.index_path / f"{name}_papers.pkl", 'rb') as f:
self.papers = pickle.load(f)
with open(self.index_path / f"{name}_embeddings.npy", 'rb') as f:
self.embeddings = np.load(f)
print(f"Loaded index with {len(self.papers)} papers")
return True
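    # Typical round trip (a sketch using the defaults defined above):
    #
    #     retriever = PaperRetriever()
    #     if not retriever.load_index():
    #         retriever.build_index(papers)
    #         retriever.save_index()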
def search(
self,
query: str,
top_k: int = 5
) -> List[Tuple[Dict, float]]:
"""
Search for relevant papers
Args:
query: Search query string
top_k: Number of results to return
Returns:
List of (paper_dict, score) tuples
"""
        if self.index is None:
            raise ValueError("Index not built or loaded; call build_index() or load_index() first")
# Embed query
query_embedding = self.model.encode([query], convert_to_numpy=True)
faiss.normalize_L2(query_embedding)
        # Search (cap top_k so FAISS does not pad missing results with -1 indices)
        scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
# Return results
results = []
for idx, score in zip(indices[0], scores[0]):
paper = self.papers[idx]
results.append((paper, float(score)))
return results
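    # Example (a sketch; scores are cosine similarities, so they fall in [-1, 1]):
    #
    #     for paper, score in retriever.search("graph neural networks", top_k=3):
    #         print(f"{score:.3f}  {paper['title']}")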
def get_retrieval_context(
self,
query: str,
top_k: int = 5
) -> str:
"""
Get formatted context string for LLM consumption
Args:
query: Search query
top_k: Number of papers to retrieve
Returns:
Formatted context string with paper summaries
"""
results = self.search(query, top_k)
context = f"Retrieved {len(results)} relevant papers:\n\n"
for i, (paper, score) in enumerate(results, 1):
context += f"[{i}] {paper['title']}\n"
context += f" Authors: {', '.join(paper['authors'][:3])}"
if len(paper['authors']) > 3:
context += f" et al."
context += f"\n arXiv ID: {paper['arxiv_id']}\n"
context += f" Published: {paper['published']}\n"
context += f" Relevance: {score:.3f}\n"
context += f" Abstract: {paper['abstract']}\n\n"
return context
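# Sketch of how get_retrieval_context() plugs into an LLM call. The prompt
# wording below is illustrative, not part of DocMind's pipeline; pass the
# returned string to whatever chat/completion client you use.
def build_rag_prompt(
    retriever: PaperRetriever,
    question: str,
    top_k: int = 5
) -> str:
    """Build a grounded prompt from retrieved papers (illustrative only)."""
    context = retriever.get_retrieval_context(question, top_k)
    return (
        "Answer the question using only the papers below, "
        "citing them by their [n] markers.\n\n"
        f"{context}"
        f"Question: {question}\n"
    )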
def main():
"""Example: Build and test retriever"""
from fetch_arxiv_data import ArxivFetcher
# Load papers
fetcher = ArxivFetcher()
papers = fetcher.load_papers("arxiv_papers.json")
if not papers:
print("No papers found. Run fetch_arxiv_data.py first")
return
# Build index
retriever = PaperRetriever()
retriever.build_index(papers)
retriever.save_index()
# Test search
test_queries = [
"diffusion models for image generation",
"reinforcement learning from human feedback",
"large language model alignment"
]
for query in test_queries:
print(f"\n{'=' * 60}")
print(f"Query: {query}")
print('=' * 60)
results = retriever.search(query, top_k=3)
for i, (paper, score) in enumerate(results, 1):
print(f"\n[{i}] Score: {score:.3f}")
print(f" {paper['title']}")
print(f" {paper['arxiv_id']}")
if __name__ == "__main__":
main()