"""
DocMind - Retriever Module
Semantic search over arXiv papers using FAISS and sentence-transformers
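
Example (sketch, assuming an index was previously built and saved):

    retriever = PaperRetriever()
    if retriever.load_index():
        results = retriever.search("transformer attention", top_k=3)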
"""

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Tuple
import pickle
from pathlib import Path


class PaperRetriever:
    def __init__(
            self,
            model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
            index_path: str = "data/faiss_index"
    ):
        """
        Initialize retriever with embedding model and FAISS index

        Args:
            model_name: HuggingFace sentence-transformer model
            index_path: Directory to save/load FAISS index
        """
        print(f"Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.index_path = Path(index_path)
        self.index_path.mkdir(parents=True, exist_ok=True)

        # Populated by build_index() or load_index()
        self.index = None
        self.papers = []
        self.embeddings = None

    def build_index(self, papers: List[Dict]):
        """
        Build FAISS index from papers

        Args:
            papers: List of paper dictionaries with 'title' and 'abstract'
        """
        print(f"Building index for {len(papers)} papers...")
        self.papers = papers

        # Create text to embed: title + abstract
        texts = [
            f"{paper['title']}. {paper['abstract']}"
            for paper in papers
        ]

        # Generate embeddings
        print("Generating embeddings...")
        self.embeddings = self.model.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True
        )

        # Build FAISS index
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # Inner product (cosine similarity)
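        # With L2-normalized vectors, the inner product a . b equals
        # cos(theta), so nearest-neighbor search ranks by cosine similarity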

        # Normalize embeddings for cosine similarity
        faiss.normalize_L2(self.embeddings)
        self.index.add(self.embeddings)

        print(f"Index built with {self.index.ntotal} papers")

    def save_index(self, name: str = "papers"):
        """Save FAISS index and metadata"""
        if self.index is None:
            raise ValueError("No index to save; call build_index() first")

        faiss.write_index(self.index, str(self.index_path / f"{name}.index"))

        with open(self.index_path / f"{name}_papers.pkl", 'wb') as f:
            pickle.dump(self.papers, f)

        with open(self.index_path / f"{name}_embeddings.npy", 'wb') as f:
            np.save(f, self.embeddings)

        print(f"Saved index to {self.index_path}/{name}.*")

    def load_index(self, name: str = "papers"):
        """Load FAISS index and metadata"""
        index_file = self.index_path / f"{name}.index"
        if not index_file.exists():
            print(f"No index found at {index_file}")
            return False

        self.index = faiss.read_index(str(index_file))
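
        # Note: pickle.load can execute arbitrary code from a crafted file,
        # so only load metadata files produced by save_index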

        with open(self.index_path / f"{name}_papers.pkl", 'rb') as f:
            self.papers = pickle.load(f)

        with open(self.index_path / f"{name}_embeddings.npy", 'rb') as f:
            self.embeddings = np.load(f)

        print(f"Loaded index with {len(self.papers)} papers")
        return True

    def search(
            self,
            query: str,
            top_k: int = 5
    ) -> List[Tuple[Dict, float]]:
        """
        Search for relevant papers

        Args:
            query: Search query string
            top_k: Number of results to return

        Returns:
            List of (paper_dict, score) tuples
        """
        if self.index is None:
            raise ValueError("Index not built or loaded")

        # Embed query
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_embedding)
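        # The query must be normalized the same way as the indexed
        # vectors; the resulting inner-product scores are then cosine
        # similarities in [-1, 1]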

        # Search
        scores, indices = self.index.search(query_embedding, top_k)

        # Collect results, skipping the -1 padding FAISS returns when the
        # index holds fewer than top_k vectors
        results = []
        for idx, score in zip(indices[0], scores[0]):
            if idx < 0:
                continue
            results.append((self.papers[int(idx)], float(score)))

        return results

    def get_retrieval_context(
            self,
            query: str,
            top_k: int = 5
    ) -> str:
        """
        Get formatted context string for LLM consumption

        Args:
            query: Search query
            top_k: Number of papers to retrieve

        Returns:
            Formatted context string with paper summaries
        """
        results = self.search(query, top_k)

        context = f"Retrieved {len(results)} relevant papers:\n\n"
        for i, (paper, score) in enumerate(results, 1):
            context += f"[{i}] {paper['title']}\n"
            context += f"    Authors: {', '.join(paper['authors'][:3])}"
            if len(paper['authors']) > 3:
                context += " et al."
            context += f"\n    arXiv ID: {paper['arxiv_id']}\n"
            context += f"    Published: {paper['published']}\n"
            context += f"    Relevance: {score:.3f}\n"
            context += f"    Abstract: {paper['abstract']}\n\n"

        return context


def main():
    """Example: Build and test retriever"""
    from fetch_arxiv_data import ArxivFetcher

    # Load papers
    fetcher = ArxivFetcher()
    papers = fetcher.load_papers("arxiv_papers.json")

    if not papers:
        print("No papers found. Run fetch_arxiv_data.py first.")
        return

    # Build index
    retriever = PaperRetriever()
    retriever.build_index(papers)
    retriever.save_index()

    # Test search
    test_queries = [
        "diffusion models for image generation",
        "reinforcement learning from human feedback",
        "large language model alignment"
    ]

    for query in test_queries:
        print(f"\n{'=' * 60}")
        print(f"Query: {query}")
        print('=' * 60)

        results = retriever.search(query, top_k=3)
        for i, (paper, score) in enumerate(results, 1):
            print(f"\n[{i}] Score: {score:.3f}")
            print(f"    {paper['title']}")
            print(f"    {paper['arxiv_id']}")
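
    # Sketch: also preview the LLM-ready context for the first query
    # (get_retrieval_context assumes each paper dict carries 'authors',
    # 'arxiv_id', and 'published', as produced by fetch_arxiv_data)
    print(f"\n{'=' * 60}")
    print("Context preview for the first query:")
    print(retriever.get_retrieval_context(test_queries[0], top_k=2)[:500])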


if __name__ == "__main__":
    main()