"""
DocMind - Retriever Module
Semantic search over arXiv papers using FAISS and sentence-transformers
"""
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Tuple
import pickle
from pathlib import Path


class PaperRetriever:
    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        index_path: str = "data/faiss_index"
    ):
        """
        Initialize retriever with embedding model and FAISS index

        Args:
            model_name: HuggingFace sentence-transformer model
            index_path: Directory to save/load FAISS index
        """
        print(f"Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.index_path = Path(index_path)
        self.index_path.mkdir(parents=True, exist_ok=True)
        self.index = None
        self.papers = []
        self.embeddings = None

    def build_index(self, papers: List[Dict]):
        """
        Build FAISS index from papers

        Args:
            papers: List of paper dictionaries with 'title' and 'abstract';
                'authors', 'arxiv_id', and 'published' are also expected when
                formatting retrieval context
        """
        print(f"Building index for {len(papers)} papers...")
        self.papers = papers

        # Create text to embed: title + abstract
        texts = [
            f"{paper['title']}. {paper['abstract']}"
            for paper in papers
        ]

        # Generate embeddings
        print("Generating embeddings...")
        self.embeddings = self.model.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True
        )

        # Build FAISS index: inner product over L2-normalized vectors
        # is equivalent to cosine similarity
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)

        # Normalize embeddings for cosine similarity
        faiss.normalize_L2(self.embeddings)
        self.index.add(self.embeddings)
        print(f"Index built with {self.index.ntotal} papers")

    def save_index(self, name: str = "papers"):
        """Save FAISS index and metadata"""
        faiss.write_index(self.index, str(self.index_path / f"{name}.index"))
        with open(self.index_path / f"{name}_papers.pkl", 'wb') as f:
            pickle.dump(self.papers, f)
        with open(self.index_path / f"{name}_embeddings.npy", 'wb') as f:
            np.save(f, self.embeddings)
        print(f"Saved index to {self.index_path}/{name}.*")

    def load_index(self, name: str = "papers"):
        """Load FAISS index and metadata"""
        index_file = self.index_path / f"{name}.index"
        if not index_file.exists():
            print(f"No index found at {index_file}")
            return False
        self.index = faiss.read_index(str(index_file))
        with open(self.index_path / f"{name}_papers.pkl", 'rb') as f:
            self.papers = pickle.load(f)
        with open(self.index_path / f"{name}_embeddings.npy", 'rb') as f:
            self.embeddings = np.load(f)
        print(f"Loaded index with {len(self.papers)} papers")
        return True

    def search(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Tuple[Dict, float]]:
        """
        Search for relevant papers

        Args:
            query: Search query string
            top_k: Number of results to return

        Returns:
            List of (paper_dict, score) tuples
        """
        if self.index is None:
            raise ValueError("Index not built or loaded")

        # Embed and normalize the query so inner-product search matches
        # the cosine-style scoring used at index time
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_embedding)

        # Search
        scores, indices = self.index.search(query_embedding, top_k)

        # Collect results, skipping the -1 padding FAISS returns when
        # top_k exceeds the number of indexed papers
        results = []
        for idx, score in zip(indices[0], scores[0]):
            if idx == -1:
                continue
            paper = self.papers[idx]
            results.append((paper, float(score)))
        return results

    def get_retrieval_context(
        self,
        query: str,
        top_k: int = 5
    ) -> str:
        """
        Get formatted context string for LLM consumption

        Args:
            query: Search query
            top_k: Number of papers to retrieve

        Returns:
            Formatted context string with paper summaries
        """
        results = self.search(query, top_k)
        context = f"Retrieved {len(results)} relevant papers:\n\n"
        for i, (paper, score) in enumerate(results, 1):
            context += f"[{i}] {paper['title']}\n"
            context += f" Authors: {', '.join(paper['authors'][:3])}"
            if len(paper['authors']) > 3:
                context += " et al."
            context += f"\n arXiv ID: {paper['arxiv_id']}\n"
            context += f" Published: {paper['published']}\n"
            context += f" Relevance: {score:.3f}\n"
            context += f" Abstract: {paper['abstract']}\n\n"
        return context


def main():
    """Example: Build and test retriever"""
    from fetch_arxiv_data import ArxivFetcher

    # Load papers
    fetcher = ArxivFetcher()
    papers = fetcher.load_papers("arxiv_papers.json")
    if not papers:
        print("No papers found. Run fetch_arxiv_data.py first")
        return

    # Build index
    retriever = PaperRetriever()
    retriever.build_index(papers)
    retriever.save_index()

    # Test search
    test_queries = [
        "diffusion models for image generation",
        "reinforcement learning from human feedback",
        "large language model alignment"
    ]
    for query in test_queries:
        print(f"\n{'=' * 60}")
        print(f"Query: {query}")
        print('=' * 60)
        results = retriever.search(query, top_k=3)
        for i, (paper, score) in enumerate(results, 1):
            print(f"\n[{i}] Score: {score:.3f}")
            print(f" {paper['title']}")
            print(f" {paper['arxiv_id']}")


if __name__ == "__main__":
    main()
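

# Hedged usage sketch (an assumption, not part of this module): the string
# returned by get_retrieval_context() is intended for LLM consumption, so a
# downstream DocMind component might use it roughly as below. The
# `llm_generate` call is a hypothetical placeholder for whatever generation
# function the rest of the project provides.
#
#     retriever = PaperRetriever()
#     if retriever.load_index():
#         context = retriever.get_retrieval_context(
#             "large language model alignment", top_k=5
#         )
#         prompt = (
#             "Answer the question using only the papers below.\n\n"
#             f"{context}\nQuestion: How are LLMs aligned with human intent?"
#         )
#         # answer = llm_generate(prompt)  # hypothetical LLM call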