edgellm / rag_system.py
"""
Simple RAG (Retrieval-Augmented Generation) System using LangChain
"""
import os
import tempfile
from typing import List, Dict, Any, Optional
from pathlib import Path
import uuid
try:
from langchain_community.document_loaders import (
PyPDFLoader,
TextLoader,
UnstructuredWordDocumentLoader
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
LANGCHAIN_AVAILABLE = True
except ImportError:
print("LangChain not installed. Install with: pip install langchain langchain-community langchain-huggingface pypdf python-docx faiss-cpu sentence-transformers")
LANGCHAIN_AVAILABLE = False
# Fallback Document class for type hints
class Document:
        def __init__(self, page_content: str, metadata: Optional[dict] = None):
self.page_content = page_content
self.metadata = metadata or {}
class SimpleRAGSystem:
def __init__(self):
"""Initialize the RAG system with embeddings and vector store."""
if not LANGCHAIN_AVAILABLE:
print("LangChain not available. RAG functionality disabled.")
self.embeddings = None
self.vector_store = None
self.documents_metadata = {}
self.text_splitter = None
return
        # Vector store and bookkeeping start empty and are populated as documents arrive
self.embeddings = None
self.vector_store = None
self.documents_metadata = {}
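        # The recursive splitter prefers paragraph and sentence boundaries;
        # the 200-character overlap preserves context across chunk edges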
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
length_function=len
)
        # Initialize embeddings with all-MiniLM-L6-v2, a lightweight
        # 384-dimensional sentence-transformers model that runs well on CPU
try:
self.embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2",
model_kwargs={'device': 'cpu'}
)
except Exception as e:
print(f"Failed to initialize embeddings: {e}")
self.embeddings = None
def _load_document(self, file_path: str, file_type: str) -> List[Document]:
"""Load a document based on its type."""
if not LANGCHAIN_AVAILABLE:
return [Document(
page_content="LangChain not available",
metadata={"source": file_path, "error": True}
)]
try:
            if file_type == 'application/pdf' or file_path.endswith('.pdf'):
                loader = PyPDFLoader(file_path)
            elif file_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' or file_path.endswith('.docx'):
                loader = UnstructuredWordDocumentLoader(file_path)
            else:
                # Plain text, Markdown, and any unrecognized type fall back
                # to the generic UTF-8 text loader
                loader = TextLoader(file_path, encoding='utf-8')
return loader.load()
except Exception as e:
print(f"Error loading document {file_path}: {e}")
# Return empty document with error info
return [Document(
page_content=f"Error loading document: {str(e)}",
metadata={"source": file_path, "error": True}
)]
def add_document(self, file_content: bytes, filename: str, file_type: str) -> Dict[str, Any]:
"""Add a document to the RAG system."""
if not LANGCHAIN_AVAILABLE:
return {"success": False, "error": "LangChain not available"}
if not self.embeddings:
return {"success": False, "error": "Embeddings not initialized"}
        try:
            # Persist the upload to a temporary file so the LangChain loaders,
            # which expect a file path, can read it
            doc_id = str(uuid.uuid4())
            with tempfile.NamedTemporaryFile(delete=False, suffix=Path(filename).suffix) as tmp_file:
                tmp_file.write(file_content)
                tmp_path = tmp_file.name
            try:
                # Load the document and split it into overlapping chunks
                documents = self._load_document(tmp_path, file_type)
                texts = self.text_splitter.split_documents(documents)
            finally:
                # Remove the temporary file even if loading or splitting fails
                os.unlink(tmp_path)
            # Tag each chunk so search hits can be traced back to their source
            for text in texts:
                text.metadata.update({
                    "doc_id": doc_id,
                    "filename": filename,
                    "file_type": file_type
                })
            # Create the vector store on first use, extend it afterwards
            if self.vector_store is None:
                self.vector_store = FAISS.from_documents(texts, self.embeddings)
            else:
                self.vector_store.add_documents(texts)
            # Record per-document bookkeeping
            self.documents_metadata[doc_id] = {
                "filename": filename,
                "file_type": file_type,
                "chunks": len(texts),
                "status": "processed"
            }
return {
"success": True,
"doc_id": doc_id,
"chunks": len(texts),
"message": f"Document '{filename}' processed successfully"
}
except Exception as e:
print(f"Error processing document {filename}: {e}")
return {"success": False, "error": str(e)}
def remove_document(self, doc_id: str) -> Dict[str, Any]:
"""Remove a document from the RAG system."""
try:
if doc_id in self.documents_metadata:
                # FAISS has no cheap per-document delete, so this only drops the
                # bookkeeping entry; a rebuild-based removal is sketched below
del self.documents_metadata[doc_id]
return {"success": True, "message": "Document removed"}
else:
return {"success": False, "error": "Document not found"}
except Exception as e:
return {"success": False, "error": str(e)}
def search_similar(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
"""Search for similar documents."""
if not LANGCHAIN_AVAILABLE:
return []
if not self.vector_store:
return []
try:
docs = self.vector_store.similarity_search(query, k=k)
results = []
for doc in docs:
results.append({
"content": doc.page_content,
"metadata": doc.metadata,
"filename": doc.metadata.get("filename", "Unknown")
})
return results
except Exception as e:
print(f"Error searching documents: {e}")
return []
def get_context_for_query(self, query: str, max_chunks: int = 3) -> str:
"""Get relevant context for a query."""
if not LANGCHAIN_AVAILABLE:
return ""
if not self.vector_store:
return ""
try:
similar_docs = self.search_similar(query, k=max_chunks)
context_parts = []
for doc in similar_docs:
context_parts.append(f"From '{doc['filename']}':\n{doc['content']}")
return "\n\n---\n\n".join(context_parts)
except Exception as e:
print(f"Error getting context: {e}")
return ""
def get_documents_info(self) -> Dict[str, Any]:
"""Get information about stored documents."""
return {
"total_documents": len(self.documents_metadata),
"documents": self.documents_metadata,
"vector_store_ready": self.vector_store is not None
}
# Global RAG system instance
rag_system = SimpleRAGSystem()
def get_rag_system() -> SimpleRAGSystem:
"""Get the global RAG system instance."""
return rag_system
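# A minimal usage sketch, assuming the LangChain extras are installed and a
# local notes.txt exists (both are assumptions, not part of this module):
if __name__ == "__main__":
    rag = get_rag_system()
    with open("notes.txt", "rb") as f:
        result = rag.add_document(f.read(), "notes.txt", "text/plain")
    print(result)  # e.g. {"success": True, "doc_id": "...", "chunks": 4, ...}
    print(rag.get_context_for_query("What do my notes say about deadlines?"))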