Spaces: Running
Jatin Mehra
Refactor retrieval and agent functions for improved chunk handling and error management
63ed7c1
| import os | |
| from langchain_community.document_loaders import PyMuPDFLoader | |
| import faiss | |
| from langchain_groq import ChatGroq | |
| from langchain.agents import AgentExecutor, create_tool_calling_agent | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain.memory import ConversationBufferMemory | |
| from sentence_transformers import SentenceTransformer | |
| import dotenv | |
| from langchain.tools import tool | |
| import traceback | |
# Load key/value pairs from a .env file into the process environment so the
# os.getenv("GROQ_API_KEY") lookup in model_selection can see them.
dotenv.load_dotenv()
# LLM factory and module-level defaults (web-search tool list, conversation memory).
def model_selection(model_name):
    """Create a Groq-hosted chat model.

    Args:
        model_name: Groq model identifier passed through as ``model``.

    Returns:
        A ``ChatGroq`` instance authenticated with the ``GROQ_API_KEY``
        environment variable (the key is ``None`` if that variable is unset).
    """
    groq_api_key = os.getenv("GROQ_API_KEY")
    return ChatGroq(model=model_name, api_key=groq_api_key)
# Default tool list: Tavily web search returning at most 5 results per query.
# NOTE(review): presumably the Tavily tool reads its API key (TAVILY_API_KEY)
# from the environment — confirm.
tools = [TavilySearchResults(max_results=5)]
# Module-level conversation memory. memory_key="chat_history" matches the
# MessagesPlaceholder name used in agentic_rag's prompt; return_messages=True
# presumably yields message objects rather than one flattened string — confirm.
# NOTE(review): this single instance is shared by everything that imports it; if
# it is passed to agentic_rag for several users/sessions, their histories mix.
# No input_key is set, so LangChain has to infer which chain input to record —
# verify this works for chains invoked with more than one input key.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
def estimate_tokens(text):
    """Return a rough token count for ``text``.

    Uses the common heuristic of about four characters per token; the result
    is ``len(text)`` floor-divided by 4, so strings shorter than 4 chars
    count as zero tokens.
    """
    chars_per_token = 4
    return len(text) // chars_per_token
def process_pdf_file(file_path):
    """Load a PDF with PyMuPDFLoader and return its Document objects.

    Args:
        file_path: Path to the PDF on disk.

    Returns:
        The list returned by ``PyMuPDFLoader(file_path).load()`` — Document
        objects carrying ``page_content`` and ``metadata``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if os.path.exists(file_path):
        return PyMuPDFLoader(file_path).load()
    raise FileNotFoundError(f"The file {file_path} does not exist.")
def chunk_text(documents, max_length=1000):
    """Split documents into paragraph-packed text chunks that keep their metadata.

    Each document's ``page_content`` is split on blank-line paragraph breaks,
    and consecutive paragraphs are packed greedily into one chunk while
    ``estimate_tokens(chunk) <= max_length // 4`` — i.e. roughly ``max_length``
    characters per chunk. A paragraph too large to fit on its own is first
    hard-cut (possibly mid-word) into ``max_length``-character pieces, so no
    chunk grows far past the limit and later truncation (e.g. the
    ``max_chunk_length`` slice in ``retrieve_similar_chunks``) cannot hide text.

    Args:
        documents: Iterable of objects exposing ``page_content`` (str) and
            ``metadata`` (dict-like with ``.copy()``), e.g. the Documents
            returned by ``process_pdf_file``.
        max_length: Approximate maximum chunk size in characters.

    Returns:
        List of ``{"text": str, "metadata": dict}``. Empty or whitespace-only
        chunks are omitted, and every chunk gets its own copy of the metadata.
    """
    token_budget = max_length // 4
    chunks = []
    for doc in documents:
        current_chunk = ""
        paragraphs = doc.page_content.split("\n\n")
        for piece in _split_oversized(paragraphs, max_length, token_budget):
            if estimate_tokens(current_chunk + piece) <= token_budget:
                current_chunk += piece + "\n\n"
            else:
                # Current chunk is full: emit it and start a new one with this piece.
                _append_chunk(chunks, current_chunk, doc.metadata)
                current_chunk = piece + "\n\n"
        _append_chunk(chunks, current_chunk, doc.metadata)
    return chunks


def _split_oversized(paragraphs, max_length, token_budget):
    """Yield paragraphs, cutting any over the token budget into max_length-char pieces."""
    for paragraph in paragraphs:
        # max_length > 0 guards range() against a zero/negative step.
        if max_length > 0 and estimate_tokens(paragraph) > token_budget:
            for start in range(0, len(paragraph), max_length):
                yield paragraph[start:start + max_length]
        else:
            yield paragraph


def _append_chunk(chunks, text, metadata):
    """Append stripped ``text`` with a private metadata copy; skip blank text."""
    text = text.strip()
    if text:
        chunks.append({"text": text, "metadata": metadata.copy()})
def create_embeddings(chunks, model):
    """Encode each chunk's text and return the vectors alongside the chunks.

    Args:
        chunks: List of ``{"text": ..., "metadata": ...}`` dicts (as built by
            ``chunk_text``); only ``"text"`` is read.
        model: Encoder exposing ``encode(texts, show_progress_bar=...,
            convert_to_tensor=...)`` that returns a tensor (e.g. a
            SentenceTransformer).

    Returns:
        Tuple ``(embeddings, chunks)``: the encoded tensor moved to CPU and
        converted to a NumPy array (presumably one row per chunk, in order),
        and the very same ``chunks`` list that was passed in.
    """
    encoded = model.encode(
        [item["text"] for item in chunks],
        show_progress_bar=True,
        convert_to_tensor=True,
    )
    return encoded.cpu().numpy(), chunks
def build_faiss_index(embeddings):
    """Build an HNSW FAISS index (flat vector storage) over ``embeddings``.

    Args:
        embeddings: 2-D array of shape ``(n_vectors, dim)``; the dimension is
            read from ``shape[1]``. NOTE(review): FAISS expects float32 rows —
            assumed to be what ``create_embeddings`` produces; confirm.

    Returns:
        The populated ``faiss.IndexHNSWFlat`` (default metric, presumably L2).
    """
    vector_dim = embeddings.shape[1]
    hnsw_index = faiss.IndexHNSWFlat(vector_dim, 32)  # M=32: graph connectivity per node
    # Build-time candidate list size: must be set before add(); larger means a
    # better graph but a slower build.
    hnsw_index.hnsw.efConstruction = 200
    # Query-time candidate list size: larger means better recall, slower search.
    hnsw_index.hnsw.efSearch = 50
    hnsw_index.add(embeddings)
    return hnsw_index
def retrieve_similar_chunks(query, index, chunks_with_metadata, embedding_model, k=10, max_chunk_length=1000):
    """Return up to ``k`` chunks most similar to ``query`` from a FAISS index.

    Args:
        query: Search text.
        index: FAISS index whose ids are positions in ``chunks_with_metadata``.
        chunks_with_metadata: List of ``{"text": ..., "metadata": ...}`` dicts.
        embedding_model: Encoder with ``encode([...], convert_to_tensor=True)``
            returning a tensor (e.g. a SentenceTransformer).
        k: Number of neighbours to request; ``k <= 0`` returns ``[]``.
        max_chunk_length: Each returned text is truncated to this many characters.

    Returns:
        List of ``(text, distance, metadata)`` tuples in the order the index
        returned them. Ids outside ``chunks_with_metadata`` (e.g. FAISS's -1
        padding when fewer than ``k`` hits exist) are skipped.
    """
    if k <= 0 or not chunks_with_metadata:
        return []
    query_embedding = embedding_model.encode([query], convert_to_tensor=True).cpu().numpy()
    distances, indices = index.search(query_embedding, k)
    results = []
    # Pair each id with its own distance *before* filtering; filtering ids
    # first and re-enumerating would shift distances onto the wrong chunks
    # whenever an invalid id is not at the end of the result row.
    for chunk_id, distance in zip(indices[0], distances[0]):
        if 0 <= chunk_id < len(chunks_with_metadata):
            chunk = chunks_with_metadata[chunk_id]
            results.append((chunk["text"][:max_chunk_length], distance, chunk["metadata"]))
    return results
def create_vector_search_tool(faiss_index, document_chunks_with_metadata, embedding_model, k=3, max_chunk_length=1000):
    """Return a search function bound to one session's FAISS index and chunks.

    The returned function's name and docstring are written as an LLM tool
    description. NOTE(review): it is returned as a plain function — ``tool``
    is imported at the top of this file but not applied here, while
    agentic_rag reads ``.name`` on each tool, so callers presumably wrap it
    with ``tool(...)``; confirm.
    """
    def vector_database_search(query: str) -> str:
        """
        Searches the currently uploaded PDF document for information semantically similar to the query.
        Use this tool when the user's question is likely answerable from the content of the specific document they provided.
        Input should be the search query.
        """
        hits = retrieve_similar_chunks(
            query,
            faiss_index,
            document_chunks_with_metadata,  # list of {"text": ..., "metadata": ...} dicts
            embedding_model,
            k=k,
            max_chunk_length=max_chunk_length,
        )
        if hits:
            passages = (text for text, _distance, _metadata in hits)
            header = f"The following information was found in the document regarding '{query}':\n"
            return header + "\n\n---\n\n".join(passages)
        return "No relevant information found in the document for that query."

    return vector_database_search
def agentic_rag(llm, agent_specific_tools, query, context_chunks, memory, Use_Tavily=False, max_context_tokens=7000):
    """Answer ``query`` with a tool-calling agent, falling back to a plain LLM call.

    Args:
        llm: Chat model used both by the agent and by the fallback path.
        agent_specific_tools: Tools the agent may call. Tools named
            'vector_database_search' / 'tavily_search_results_json' switch on
            the matching guidance in the system prompt. Plain functions are
            accepted for the name check (their ``__name__`` is used).
        query: The user's question.
        context_chunks: ``(text, distance, metadata)`` tuples from a
            preliminary similarity search (e.g. ``retrieve_similar_chunks``);
            may be empty or None. Lower distance is treated as more relevant.
        memory: Conversation memory attached to the AgentExecutor, or None.
        Use_Tavily: Not read by this function; kept for backward
            compatibility. Web search is enabled purely by including the
            Tavily tool in ``agent_specific_tools``.
        max_context_tokens: Estimated-token budget for the preliminary
            context (leaves room for the prompt and the response).

    Returns:
        dict with at least an ``"output"`` key holding the answer text. If the
        agent raises, the error is printed and a single context-only LLM call
        provides the answer instead.
    """
    context = _build_agent_context(context_chunks, max_context_tokens)
    search_behavior_instructions = _search_guidance(agent_specific_tools)

    prompt = ChatPromptTemplate.from_messages([
        ("system", f"""
You are an expert Q&A system. Your primary function is to answer questions using a given set of documents (Context) and available tools.
**Your Process:**
1. **Analyze the Question:** Understand exactly what the user is asking.
2. **Scan the Context:** Thoroughly review the 'Context' provided (if any) to find relevant information. This context is derived from a preliminary similarity search in the document.
3. **Formulate the Answer:**
    * If the initially provided context contains a clear answer, synthesize it into a concise response. Start your answer with "Based on the Document, ...".
    * {search_behavior_instructions}
    * When using the 'vector_database_search' tool, the information comes from the document. Prepend your answer with "Based on the Document, ...".
    * When using the 'tavily_search_results_json' tool, the information comes from the web. Prepend your answer with "According to a web search, ...". If no useful information is found, state that.
4. **Clarity:** Ensure your final answer is clear, direct, and avoids jargon if possible.
**Important Rules:**
* **Stick to Sources:** Do *not* use any information outside of the provided 'Context', document search results ('vector_database_search'), or web search results ('tavily_search_results_json').
* **No Speculation:** Do not make assumptions or infer information not explicitly present.
* **Cite Sources (If Web Searching):** If you use the 'tavily_search_results_json' tool and it provides source links, you MUST include them in your response.
"""),
        # Earlier turns come before the current question so the conversation
        # reads chronologically; optional so memory=None still formats.
        MessagesPlaceholder(variable_name="chat_history", optional=True),
        # Plain (non-f) string: single braces are template variables. The old
        # double braces rendered literally, so the question never reached the model.
        ("human", "Context: {context}\n\nQuestion: {input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])
    # Bind the context up front so the executor is invoked with a single
    # "input" key: memory without an explicit input_key can then tell which
    # input to record instead of failing on two candidate keys.
    prompt = prompt.partial(context=context)

    try:
        agent = create_tool_calling_agent(llm, agent_specific_tools, prompt)
        agent_executor = AgentExecutor(agent=agent, tools=agent_specific_tools, memory=memory, verbose=True)
        return agent_executor.invoke({"input": query})  # dict containing 'output'
    except Exception as e:
        # Top-level boundary: log the failure and answer from the preliminary context alone.
        print(f"Error during agent execution: {str(e)} \nTraceback: {traceback.format_exc()}")
        fallback_prompt_template = ChatPromptTemplate.from_messages([
            ("system", "You are a helpful assistant. Use the provided context to answer the user's question. If the context is insufficient, say you don't know."),
            ("human", "Context: {context}\n\nQuestion: {input}")
        ])
        formatted_fallback_prompt = fallback_prompt_template.format_prompt(context=context, input=query).to_messages()
        response = llm.invoke(formatted_fallback_prompt)
        return {"output": response.content if hasattr(response, 'content') else str(response)}


def _build_agent_context(context_chunks, max_tokens):
    """Join the most relevant chunk texts (lowest distance first) within an estimated-token budget."""
    ranked = sorted(context_chunks, key=lambda item: item[1]) if context_chunks else []
    parts = []
    used_tokens = 0
    for text, _, _ in ranked:
        cost = estimate_tokens(text)
        if used_tokens + cost > max_tokens:
            break  # stop at the first chunk that would overflow the budget
        parts.append(text)
        used_tokens += cost
    context = "\n\n".join(parts).strip()
    return context or "No initial context provided from preliminary search."


def _tool_name(candidate):
    """Return a tool's name: ``.name`` for LangChain tools, ``__name__`` for plain functions."""
    return getattr(candidate, "name", None) or getattr(candidate, "__name__", None)


def _search_guidance(available_tools):
    """Build the tool-usage instructions for the system prompt from the available tool names."""
    names = {_tool_name(t) for t in (available_tools or ())}
    guidance_parts = []
    if "vector_database_search" in names:
        guidance_parts.append(
            "If the direct context (if any from preliminary search) is insufficient and the question seems answerable from the uploaded document, "
            "use the 'vector_database_search' tool to find relevant information within the document."
        )
    if "tavily_search_results_json" in names:
        guidance_parts.append(
            "If the information is not found in the document (after using 'vector_database_search' if appropriate) "
            "or the question is of a general nature not specific to the document, "
            "use the 'tavily_search_results_json' tool for web searches."
        )
    if guidance_parts:
        instructions = " ".join(guidance_parts)
    else:
        instructions = "If the context is insufficient, you *must* state that you don't know."
    return instructions + ("\n * If, after all steps and tool use (if any), you cannot find an answer, "
                           "respond with: \"Based on the available information, I don't know the answer.\"")
# NOTE(review): Disabled demo / manual chat harness, kept as a bare string
# literal so it never runs (the string is evaluated and discarded on import).
# It is stale against the current API and would fail if re-enabled as-is:
#   * create_embeddings() returns an (embeddings, chunks) tuple, so
#     build_faiss_index(embeddings) would receive a tuple, not an array.
#   * agentic_rag() has no `context` keyword; its parameter is `context_chunks`.
"""if __name__ == "__main__":
    # Process PDF and prepare index
    dotenv.load_dotenv()
    pdf_file = "JatinCV.pdf"
    llm = model_selection("meta-llama/llama-4-scout-17b-16e-instruct")
    texts = process_pdf_file(pdf_file)
    chunks = chunk_text(texts, max_length=1500)
    model = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = create_embeddings(chunks, model)
    index = build_faiss_index(embeddings)
    # Chat loop
    print("Chat with the assistant (type 'exit' or 'quit' to stop):")
    while True:
        query = input("User: ")
        if query.lower() in ["exit", "quit"]:
            break
        # Retrieve similar chunks
        similar_chunks = retrieve_similar_chunks(query, index, chunks, model, k=3)
        # context = "\n".join([chunk for chunk, _ in similar_chunks])
        # Generate response
        response = agentic_rag(llm, tools, query=query, context=similar_chunks, Use_Tavily=True, memory=memory)
        print("Assistant:", response["output"])"""