# DocuCite-Agent / app.py
from __future__ import annotations

import hashlib
import os
import tempfile
from pathlib import Path

import gradio as gr
import pdfplumber
import torch
from langchain.docstore.document import Document
from langchain.tools import Tool
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
device = "cuda" if torch.cuda.is_available() else "cpu"
EMBEDDER = HuggingFaceEmbeddings(
model_name="BAAI/bge-m3",
encode_kwargs={"normalize_embeddings": True, "device": device},
)
LLM = ChatOpenAI(
    # Read the key from the environment rather than hardcoding it in source;
    # NEBIUS_API_KEY is a naming convention here, set it however your deployment prefers.
    openai_api_key=os.environ["NEBIUS_API_KEY"],
    openai_api_base="https://api.studio.nebius.com/v1",
    model="Qwen/Qwen2.5-72B-Instruct",
)
def get_file_bytes_and_name(pdf_file):
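    """Return (file_bytes, base_name) for a Gradio upload, which may arrive as
    a file-like object or as a plain temp-file path depending on Gradio version."""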
print("DEBUG: pdf_file type:", type(pdf_file))
print("DEBUG: pdf_file dir:", dir(pdf_file))
print("DEBUG: pdf_file repr:", repr(pdf_file))
if hasattr(pdf_file, "read"):
return pdf_file.read(), Path(pdf_file.name).name
if isinstance(pdf_file, str):
file_path = Path(pdf_file)
with open(file_path, "rb") as f:
return f.read(), file_path.name
raise ValueError("Could not extract file bytes from uploaded file.")
VECTOR_ROOT = Path.home() / ".rag_vectors"
VECTOR_ROOT.mkdir(parents=True, exist_ok=True)
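# One persistent Chroma store per PDF, keyed by the MD5 of the file bytes, so
# re-uploading the same document skips re-embedding entirely.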
# ────────────── PDF-to-vectorstore: clean pages and tag chunks ──────────────
def load_or_create_chroma(pdf_bytes: bytes, filename: str) -> Chroma:
"""
Loads persistent Chroma vectorstore for this PDF, or creates it if not found.
Each chunk carries page and paragraph info.
"""
print(f"\n[INFO] Checking vectorstore for file: {filename}")
h = hashlib.md5(pdf_bytes).hexdigest()
vect_dir = VECTOR_ROOT / h
if (vect_dir / "chroma.sqlite3").exists():
print(f"[INFO] Found existing vectorstore: {vect_dir}")
return Chroma(persist_directory=str(vect_dir), embedding_function=EMBEDDER)
print(f"[INFO] No vectorstore found, embedding file: {filename}")
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
tmp.write(pdf_bytes)
tmp_path = tmp.name
docs = []
BAD_PHRASES = {
"Abstracting with credit is permitted",
"Permission to make digital or hard copies",
"arXiv:",
"Β©",
}
def clean_page(text: str) -> str:
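        """Drop boilerplate lines (copyright notices, arXiv stamps) from a page."""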
return "\n".join(
line for line in text.splitlines()
if not any(b in line for b in BAD_PHRASES)
)
    # Build the splitter once; it is reused for every page.
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1200, chunk_overlap=200
    )
    with pdfplumber.open(tmp_path) as pdf:
        for page_num, page in enumerate(pdf.pages, start=1):
            text = clean_page(page.extract_text() or "")
            if not text.strip():
                continue
            # Split the page into overlapping chunks sized for embedding
            para_chunks = splitter.split_text(text)
for para_num, chunk in enumerate(para_chunks, start=1):
docs.append(
Document(
page_content=chunk,
metadata={"page_number": page_num, "paragraph_number": para_num}
)
)
print(f"[INFO] Extracted {len(docs)} chunks from PDF for embedding.")
vectordb = Chroma.from_documents(
docs, EMBEDDER, persist_directory=str(vect_dir)
)
vectordb.persist()
return vectordb
def build_retriever_tool(vectorstore):
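    """Wrap the vectorstore in a LangChain Tool whose results carry page and
    paragraph tags, so the LLM can cite locations verbatim."""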
# 1) build a retriever (here we ask for top 3 matches)
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
# 2) wrap it so every result is tagged with page/paragraph
def custom_search(query: str) -> str:
        docs = retriever.invoke(query)
if not docs:
return "No relevant passages found."
out = []
for d in docs:
page = d.metadata.get("page_number", "?")
para = d.metadata.get("paragraph_number", "?")
txt = d.page_content.replace("\n", " ").strip()
out.append(f"[Page {page}, Paragraph {para}]: {txt}")
# join with blank lines so LLM can see separate chunks
return "\n\n".join(out)
# 3) expose that wrapper as a LangChain Tool
return Tool(
name="document_search",
func=custom_search,
description=(
"Searches the uploaded PDF for a query and returns each matching "
"passage prefixed with its page and paragraph number."
),
)
def make_generate_query_or_respond(retriever_tool):
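    """Build the graph's entry node: the LLM either answers directly or emits
    a call to the retriever tool (bind_tools advertises it to the model)."""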
def generate_query_or_respond(state):
response = (
LLM
.bind_tools([retriever_tool]).invoke(state["messages"])
)
return {"messages": [response]}
return generate_query_or_respond
GENERATE_PROMPT = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question, "
    "citing the page and paragraph numbers given in the context. "
    "Pay attention to the context and use only it to answer the question. "
    "If you don't know the answer, just say that you don't know.\n"
    "Question: {question}\n"
    "Context: {context}"
)
def generate_answer(state: MessagesState):
    print(f"[DEBUG] Answer node, messages so far: {state['messages']}")
    # With chat history in play, the question is the latest *human* message,
    # not necessarily messages[0]; the last message is the retriever's output.
    question = next(
        m.content for m in reversed(state["messages"]) if m.type == "human"
    )
    context = state["messages"][-1].content
    prompt = GENERATE_PROMPT.format(question=question, context=context)
    response = LLM.invoke([{"role": "user", "content": prompt}])
    print(f"[DEBUG] LLM final answer: {response}")
    return {"messages": [response]}
def build_agentic_graph(retriever_tool):
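    """Assemble the agent graph:
    START -> generate_query_or_respond -(tool call)-> retrieve -> generate_answer -> END;
    if the model answers without calling the tool, it goes straight to END.
    """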
workflow = StateGraph(MessagesState)
workflow.add_node("generate_query_or_respond", make_generate_query_or_respond(retriever_tool))
workflow.add_node("retrieve", ToolNode([retriever_tool]))
    workflow.add_node("generate_answer", generate_answer)
workflow.add_edge(START, "generate_query_or_respond")
workflow.add_conditional_edges(
"generate_query_or_respond",
tools_condition,
{
"tools": "retrieve",
END: END,
},
)
workflow.add_edge("retrieve", "generate_answer")
workflow.add_edge("generate_answer", END)
    # workflow.add_edge("retrieve", "generate_query_or_respond")  # uncomment to loop back for multi-step tool use
return workflow.compile()
def gradio_agentic_rag(pdf_file, question, history=None):
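    """Gradio handler: load or build the PDF's vectorstore, run the agent graph
    on the question plus prior chat turns, and return (answer, updated history)."""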
pdf_bytes, filename = get_file_bytes_and_name(pdf_file)
vectordb = load_or_create_chroma(pdf_bytes, filename)
    retriever_tool = build_retriever_tool(vectordb)
graph = build_agentic_graph(retriever_tool)
state_messages = []
if history:
for turn in history:
            if isinstance(turn, (list, tuple)):
if turn[0]:
state_messages.append({"role": "user", "content": turn[0]})
if len(turn) > 1 and turn[1]:
state_messages.append({"role": "assistant", "content": turn[1]})
state_messages.append({"role": "user", "content": question})
state = {"messages": state_messages}
result = None
for chunk in graph.stream(state):
print(f"Chunk: {chunk}")
for node, update in chunk.items():
print(f"Node: {node}, Update: {update}")
last_msg = update["messages"][-1]
if node == "generate_answer" or (
node == "generate_query_or_respond" and not update["messages"][-1].tool_calls
):
result = last_msg.content
if history is None:
history = []
history.append([question, result])
return result, history
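# A minimal sketch of exercising the pipeline without the UI, e.g. as a smoke
# test ("paper.pdf" is a hypothetical local file; NEBIUS_API_KEY must be set):
#
#   answer, history = gradio_agentic_rag("paper.pdf", "What is LoRA?")
#   print(answer)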
iface = gr.Interface(
fn=gradio_agentic_rag,
inputs=[
gr.File(label="Upload your PDF"),
gr.Textbox(label="Ask a question about your PDF"),
gr.State()
],
outputs=[gr.Textbox(label="Answer from RAG Agent"),
gr.State()],
title="DocuCite Agent",
description="An agentic RAG (Retrieval-Augmented Generation) system that can answer questions about the contents of a PDF document with references to the page and paragraph number.",
examples=[
["paper.pdf", "What is LoRA? please use the tool"],
],
)
if __name__ == "__main__":
iface.launch(
mcp_server=True,
show_error=True,
show_api=True
)