# DocMind / agents.py
# (Hugging Face web-page upload metadata removed: uploaded by Tanaybh,
#  "Upload 4 files", commit 3a5fdfb verified, 10.6 kB)
"""
DocMind - Multi-Agent System
Implements Retriever, Reader, Critic, and Synthesizer agents
"""
from typing import List, Dict, Tuple
from retriever import PaperRetriever
import os
class RetrieverAgent:
    """Agent responsible for finding relevant papers"""

    def __init__(self, retriever: PaperRetriever):
        # Keep a handle on the underlying vector-search retriever.
        self.retriever = retriever

    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[Dict, float]]:
        """
        Retrieve relevant papers for the query.

        Returns:
            List of (paper, relevance_score) tuples
        """
        print(f"πŸ” Retriever Agent: Searching for '{query}'...")
        matches = self.retriever.search(query, top_k)
        print(f"   Found {len(matches)} relevant papers")
        return matches
class ReaderAgent:
    """Agent responsible for reading and summarizing papers"""

    def __init__(self, llm_client=None):
        """
        Args:
            llm_client: Optional LLM client (OpenAI, Anthropic, etc.)
                If None, uses rule-based summarization
        """
        self.llm_client = llm_client

    def summarize_paper(self, paper: Dict) -> Dict:
        """
        Generate a summary of a single paper.

        Args:
            paper: Paper dictionary with 'title', 'abstract', 'arxiv_id',
                'authors' and 'published' keys.

        Returns:
            Summary dict with title, arxiv_id, authors (first 3),
            summary text and publication year.
            (Fixed: the original annotation said ``-> str`` but both
            helpers return a dict, which read_papers relies on.)
        """
        if self.llm_client:
            return self._llm_summarize(paper)
        else:
            return self._rule_based_summarize(paper)

    def _rule_based_summarize(self, paper: Dict) -> Dict:
        """Simple extractive summary (first 3 sentences of the abstract)."""
        abstract = paper['abstract']
        sentences = abstract.split('. ')
        summary = '. '.join(sentences[:3])
        # Only add a terminating period when one is missing; the original
        # unconditional `+ '.'` produced a double period ("...text..")
        # for abstracts with three or fewer sentences.
        if not summary.endswith('.'):
            summary += '.'
        return {
            'title': paper['title'],
            'arxiv_id': paper['arxiv_id'],
            'authors': paper['authors'][:3],  # keep at most 3 names
            'summary': summary,
            'year': paper['published'][:4]  # 'YYYY-MM-DD...' -> 'YYYY'
        }

    def _llm_summarize(self, paper: Dict) -> Dict:
        """Use LLM to generate intelligent summary."""
        prompt = f"""Summarize this research paper in 2-3 sentences, focusing on:
1. The main contribution/idea
2. The key methodology or approach
3. Important results or implications
Title: {paper['title']}
Abstract: {paper['abstract']}
Summary:"""
        # Call LLM (implementation depends on client)
        # This is a placeholder - replace with actual LLM call
        response = "LLM summary would go here"
        return {
            'title': paper['title'],
            'arxiv_id': paper['arxiv_id'],
            'authors': paper['authors'][:3],
            'summary': response,
            'year': paper['published'][:4]
        }

    def read_papers(self, papers: List[Tuple[Dict, float]]) -> List[Dict]:
        """
        Read and summarize multiple papers.

        Args:
            papers: List of (paper, score) tuples from retriever

        Returns:
            List of summary dicts, each annotated with the retriever's
            'relevance_score'.
        """
        print(f"πŸ“– Reader Agent: Reading {len(papers)} papers...")
        summaries = []
        for paper, score in papers:
            summary = self.summarize_paper(paper)
            summary['relevance_score'] = score  # carried forward for the critic
            summaries.append(summary)
        print(f"   Generated {len(summaries)} summaries")
        return summaries
class CriticAgent:
    """Agent responsible for evaluating and filtering summaries"""

    def __init__(self, llm_client=None, *, min_relevance: float = 0.3,
                 recent_year: int = 2024):
        """
        Args:
            llm_client: Optional LLM client for smarter critiques.
            min_relevance: Retrieval-score threshold (exclusive) below
                which a summary is dropped. Default preserves the
                original hard-coded 0.3.
            recent_year: Papers from this year or later earn a recency
                bonus. Default preserves the original hard-coded 2024.
        """
        self.llm_client = llm_client
        self.min_relevance = min_relevance
        self.recent_year = recent_year

    def critique(self, summaries: List[Dict], query: str) -> List[Dict]:
        """
        Evaluate summaries for quality and relevance.

        Args:
            summaries: List of paper summaries; each must carry the
                'relevance_score' added by the reader stage.
            query: Original user query

        Returns:
            Filtered summaries, each annotated with 'quality_score',
            sorted best-first by a 70/30 relevance/quality blend.
        """
        print(f"πŸ”Ž Critic Agent: Evaluating {len(summaries)} summaries...")
        # Simple rule-based filtering
        filtered = []
        for summary in summaries:
            # Drop papers the retriever considered only weakly relevant.
            if summary['relevance_score'] > self.min_relevance:
                # Add quality score (can be enhanced with LLM)
                summary['quality_score'] = self._assess_quality(summary, query)
                filtered.append(summary)
        # Sort by combined score: relevance dominates, quality breaks ties.
        filtered.sort(
            key=lambda x: x['relevance_score'] * 0.7 + x['quality_score'] * 0.3,
            reverse=True
        )
        print(f"   Retained {len(filtered)} high-quality summaries")
        return filtered

    def _assess_quality(self, summary: Dict, query: str) -> float:
        """
        Simple heuristic quality assessment (can be enhanced with LLM).

        Returns:
            Quality score in [0, 1].
        """
        score = 0.5  # Base score
        # Longer summaries might be more informative
        if len(summary['summary']) > 100:
            score += 0.2
        # Recent papers get a bonus; tolerate malformed year strings
        # instead of crashing (the original int() raised ValueError).
        try:
            if int(summary['year']) >= self.recent_year:
                score += 0.3
        except (ValueError, TypeError):
            pass
        return min(score, 1.0)
class SynthesizerAgent:
    """Agent responsible for synthesizing final answer"""

    def __init__(self, llm_client=None):
        self.llm_client = llm_client

    def synthesize(self, summaries: List[Dict], query: str,
                   max_papers: int = 10) -> str:
        """
        Synthesize final answer from summaries.

        Args:
            summaries: List of filtered, quality summaries
            query: Original user query
            max_papers: Maximum papers to include in response

        Returns:
            Final synthesized response with citations
        """
        print(f"✨ Synthesizer Agent: Creating final response...")
        # Nothing survived the critic stage -> short-circuit.
        if not summaries:
            return "No relevant papers found for your query."
        selected = summaries[:max_papers]
        # Pick the builder once, then delegate.
        build = self._llm_synthesize if self.llm_client else self._rule_based_synthesize
        return build(selected, query)

    def _rule_based_synthesize(self, summaries: List[Dict], query: str) -> str:
        """Create structured response without LLM"""
        parts = [
            f"# Research Summary: {query}\n\n",
            f"Based on {len(summaries)} relevant papers from arXiv:\n\n",
        ]
        for rank, item in enumerate(summaries, 1):
            byline = f"**Authors:** {', '.join(item['authors'])}"
            if len(item['authors']) >= 3:
                byline += " et al."
            parts.append(f"## [{rank}] {item['title']}\n")
            parts.append(byline)
            parts.append(f"\n**Year:** {item['year']}\n")
            parts.append(f"**arXiv ID:** {item['arxiv_id']}\n")
            parts.append(f"**Relevance:** {item['relevance_score']:.2f}\n\n")
            parts.append(f"{item['summary']}\n\n")
            parts.append("---\n\n")
        return ''.join(parts)

    def _llm_synthesize(self, summaries: List[Dict], query: str) -> str:
        """Use LLM to create coherent synthesis"""
        # Build the numbered context block the prompt refers to.
        context = ''.join(
            f"[{rank}] {item['title']} ({item['arxiv_id']})\n"
            f"   {item['summary']}\n\n"
            for rank, item in enumerate(summaries, 1)
        )
        prompt = f"""You are a research assistant. Based on the following papers, answer this question:
Question: {query}
Papers:
{context}
Provide a comprehensive answer that:
1. Directly addresses the question
2. Synthesizes information across papers
3. Cites papers by number [1], [2], etc.
4. Highlights key findings and consensus/disagreements
5. Is concise but thorough (3-5 paragraphs)
Answer:"""
        # Placeholder for LLM call
        response = "LLM-generated synthesis would go here with citations"
        # Append paper references
        refs = ''.join(
            f"[{rank}] {item['title']} ({item['arxiv_id']}, {item['year']})\n"
            for rank, item in enumerate(summaries, 1)
        )
        return response + "\n\n## References\n" + refs
class DocMindOrchestrator:
    """Main orchestrator that coordinates all agents"""

    def __init__(self, retriever: PaperRetriever, llm_client=None):
        # Wire up the four-stage pipeline; the same (optional) LLM client
        # is shared by reader, critic and synthesizer.
        self.retriever_agent = RetrieverAgent(retriever)
        self.reader_agent = ReaderAgent(llm_client)
        self.critic_agent = CriticAgent(llm_client)
        self.synthesizer_agent = SynthesizerAgent(llm_client)

    def process_query(self, query: str, top_k: int = 10,
                      max_papers_in_response: int = 5) -> str:
        """
        Process user query through full agent pipeline.

        Args:
            query: User question
            top_k: Number of papers to retrieve
            max_papers_in_response: Max papers in final response

        Returns:
            Final synthesized answer
        """
        banner = '=' * 60
        print(f"\n{banner}")
        print(f"Processing query: {query}")
        print(banner)

        # Stage 1: retrieve; bail out early when nothing matches.
        papers = self.retriever_agent.retrieve(query, top_k)
        if not papers:
            return "No relevant papers found for your query."

        # Stages 2-4: read, critique, then synthesize the survivors.
        summaries = self.reader_agent.read_papers(papers)
        quality_summaries = self.critic_agent.critique(summaries, query)
        answer = self.synthesizer_agent.synthesize(
            quality_summaries,
            query,
            max_papers_in_response
        )
        print(f"{banner}\n")
        return answer
def main():
    """Example usage of multi-agent system"""
    from fetch_arxiv_data import ArxivFetcher

    # Setup
    fetcher = ArxivFetcher()
    retriever = PaperRetriever()

    # Reuse a previously saved index when available; otherwise build
    # one from the local paper dump and persist it for next time.
    if not retriever.load_index():
        corpus = fetcher.load_papers("arxiv_papers.json")
        retriever.build_index(corpus)
        retriever.save_index()

    orchestrator = DocMindOrchestrator(retriever)

    # Exercise the pipeline with a few representative questions.
    sample_questions = [
        "What are the latest improvements in diffusion models?",
        "How does RLHF compare to DPO for language model alignment?",
        "What are the main challenges in scaling transformers?"
    ]
    for question in sample_questions:
        answer = orchestrator.process_query(question, top_k=8, max_papers_in_response=3)
        print(answer)
        print("\n" + "=" * 80 + "\n")


if __name__ == "__main__":
    main()