# NOTE: the original capture began with hosting-page residue ("Spaces: Running
# Running") that was not part of the source file; preserved here as a comment.
"""
ViettelPay Knowledge Base with Contextual Retrieval

This updated version:
- Uses ContextualWordProcessor for all document processing
- Integrates OpenAI for contextual enhancement
- Processes all doc/docx files from a parent folder
- Removes CSV processor dependency
"""
| import os | |
| import pickle | |
| # import torch | |
| from typing import List, Optional | |
| from pathlib import Path | |
| from openai import OpenAI | |
| from langchain.schema import Document | |
| from langchain.retrievers import EnsembleRetriever | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain_core.runnables import ConfigurableField | |
| from langchain_cohere.rerank import CohereRerank | |
| # Use newest import paths for langchain | |
| try: | |
| from langchain_chroma import Chroma | |
| except ImportError: | |
| from langchain_community.vectorstores import Chroma | |
| # Use the new HuggingFaceEmbeddings from langchain-huggingface | |
| try: | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| except ImportError: | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from src.processor.contextual_word_processor import ContextualWordProcessor | |
| from src.processor.text_utils import VietnameseTextProcessor | |
| # Import configuration utility | |
| from src.utils.config import get_cohere_api_key, get_openai_api_key, get_embedding_model | |
class ViettelKnowledgeBase:
    """ViettelPay knowledge base with contextual retrieval enhancement"""

    def __init__(self, embedding_model: str = None):
        """
        Set up embeddings, Vietnamese text processing, and the reranker.

        Args:
            embedding_model: Vietnamese embedding model to use
        """
        embedding_model = embedding_model or get_embedding_model()

        # Vietnamese-aware tokenization/normalization helper (used by BM25)
        self.text_processor = VietnameseTextProcessor()

        # GPU selection is disabled for now; CPU keeps deployment simple.
        # self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = "cpu"
        print(f"[INFO] Using device: {self.device}")

        # trust_remote_code is required by some Vietnamese embedding models
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model,
            model_kwargs={"device": self.device, "trust_remote_code": True},
        )

        # Retrievers are populated by build_knowledge_base()/load_knowledge_base()
        self.chroma_retriever = None
        self.bm25_retriever = None
        self.ensemble_retriever = None

        # Cohere reranker is the final relevance-scoring stage in search()
        self.reranker = CohereRerank(
            model="rerank-v3.5",
            cohere_api_key=get_cohere_api_key(),
        )
| def build_knowledge_base( | |
| self, | |
| documents_folder: str, | |
| persist_dir: str = "./knowledge_base", | |
| reset: bool = True, | |
| openai_api_key: Optional[str] = None, | |
| ) -> None: | |
| """ | |
| Build knowledge base from all Word documents in a folder | |
| Args: | |
| documents_folder: Path to folder containing doc/docx files | |
| persist_dir: Directory to persist the knowledge base | |
| reset: Whether to reset existing knowledge base | |
| openai_api_key: OpenAI API key for contextual enhancement (optional) | |
| Returns: | |
| None. Use the search() method to perform searches. | |
| """ | |
| print( | |
| "[INFO] Building ViettelPay knowledge base with contextual enhancement..." | |
| ) | |
| # Initialize OpenAI client for contextual enhancement if API key provided | |
| openai_client = None | |
| if openai_api_key: | |
| openai_client = OpenAI(api_key=openai_api_key) | |
| print(f"[INFO] OpenAI client initialized for contextual enhancement") | |
| elif get_openai_api_key(): | |
| api_key = get_openai_api_key() | |
| openai_client = OpenAI(api_key=api_key) | |
| print(f"[INFO] OpenAI client initialized from configuration") | |
| else: | |
| print( | |
| f"[WARNING] No OpenAI API key provided. Contextual enhancement disabled." | |
| ) | |
| # Initialize the contextual word processor with OpenAI client | |
| word_processor = ContextualWordProcessor(llm_client=openai_client) | |
| # Find all Word documents in the folder | |
| word_files = self._find_word_documents(documents_folder) | |
| if not word_files: | |
| raise ValueError(f"No Word documents found in {documents_folder}") | |
| print(f"[INFO] Found {len(word_files)} Word documents to process") | |
| # Process all documents | |
| all_documents = self._process_all_word_files(word_files, word_processor) | |
| print(f"[INFO] Total documents processed: {len(all_documents)}") | |
| # Create directories | |
| os.makedirs(persist_dir, exist_ok=True) | |
| chroma_dir = os.path.join(persist_dir, "chroma") | |
| bm25_path = os.path.join(persist_dir, "bm25_retriever.pkl") | |
| # Build ChromaDB retriever (uses contextualized content) | |
| print("[INFO] Building ChromaDB retriever with contextualized content...") | |
| self.chroma_retriever = self._build_chroma_retriever( | |
| all_documents, chroma_dir, reset | |
| ) | |
| # Build BM25 retriever (uses contextualized content with Vietnamese tokenization) | |
| print("[INFO] Building BM25 retriever with Vietnamese tokenization...") | |
| self.bm25_retriever = self._build_bm25_retriever( | |
| all_documents, bm25_path, reset | |
| ) | |
| # Create ensemble retriever with configurable top-k | |
| print("[INFO] Creating ensemble retriever...") | |
| self.ensemble_retriever = self._build_retriever( | |
| self.bm25_retriever, self.chroma_retriever | |
| ) | |
| print("[SUCCESS] Contextual knowledge base built successfully!") | |
| print("[INFO] Use kb.search(query, top_k) to perform searches.") | |
| def load_knowledge_base(self, persist_dir: str = "./knowledge_base") -> bool: | |
| """ | |
| Load existing knowledge base from disk and rebuild BM25 from ChromaDB documents | |
| Args: | |
| persist_dir: Directory where the knowledge base is stored | |
| Returns: | |
| bool: True if loaded successfully, False otherwise | |
| """ | |
| print("[INFO] Loading knowledge base from disk...") | |
| chroma_dir = os.path.join(persist_dir, "chroma") | |
| try: | |
| # Load ChromaDB | |
| if os.path.exists(chroma_dir): | |
| vectorstore = Chroma( | |
| persist_directory=chroma_dir, embedding_function=self.embeddings | |
| ) | |
| self.chroma_retriever = vectorstore.as_retriever(search_kwargs={"k": 5}) | |
| print("[SUCCESS] ChromaDB loaded") | |
| else: | |
| print("[ERROR] ChromaDB not found") | |
| return False | |
| # Extract all documents from ChromaDB to rebuild BM25 | |
| print("[INFO] Extracting documents from ChromaDB to rebuild BM25...") | |
| try: | |
| # Get all documents and metadata from ChromaDB | |
| all_docs = vectorstore.get(include=["documents", "metadatas"]) | |
| documents = all_docs["documents"] | |
| metadatas = all_docs["metadatas"] | |
| # Reconstruct Document objects | |
| doc_objects = [] | |
| for i, (doc_content, metadata) in enumerate(zip(documents, metadatas)): | |
| # Handle case where metadata might be None | |
| if metadata is None: | |
| metadata = {} | |
| doc_obj = Document(page_content=doc_content, metadata=metadata) | |
| doc_objects.append(doc_obj) | |
| print(f"[INFO] Extracted {len(doc_objects)} documents from ChromaDB") | |
| # Rebuild BM25 retriever using existing method | |
| self.bm25_retriever = self._build_bm25_retriever( | |
| documents=doc_objects, | |
| bm25_path=None, # Not used anymore | |
| reset=False, # Not relevant for rebuilding | |
| ) | |
| except Exception as e: | |
| print(f"[ERROR] Error rebuilding BM25 from ChromaDB: {e}") | |
| return False | |
| # Create ensemble retriever with configurable top-k | |
| self.ensemble_retriever = self._build_retriever( | |
| self.bm25_retriever, self.chroma_retriever | |
| ) | |
| print("[SUCCESS] Knowledge base loaded successfully!") | |
| print("[INFO] Use kb.search(query, top_k) to perform searches.") | |
| return True | |
| except Exception as e: | |
| print(f"[ERROR] Error loading knowledge base: {e}") | |
| return False | |
| def search(self, query: str, top_k: int = 10) -> List[Document]: | |
| """ | |
| Main search method using ensemble retriever with configurable top-k | |
| Args: | |
| query: Search query | |
| top_k: Number of documents to return from each retriever (default: 5) | |
| Returns: | |
| List of retrieved documents | |
| """ | |
| if not self.ensemble_retriever: | |
| raise ValueError( | |
| "Knowledge base not loaded. Call build_knowledge_base() or load_knowledge_base() first." | |
| ) | |
| # Build config based on top_k parameter | |
| config = { | |
| "configurable": { | |
| "bm25_k": top_k * 5, | |
| "chroma_search_kwargs": {"k": top_k * 5}, | |
| } | |
| } | |
| results = self.ensemble_retriever.invoke(query, config=config) | |
| reranked_results = self.reranker.rerank(results, query, top_n=top_k) | |
| final_results = [] | |
| for rerank_item in reranked_results: | |
| # Get the original document using the index | |
| original_doc = results[rerank_item["index"]] | |
| # Create a new document with the relevance score added to metadata | |
| reranked_doc = Document( | |
| page_content=original_doc.page_content, | |
| metadata={ | |
| **original_doc.metadata, | |
| "relevance_score": rerank_item["relevance_score"], | |
| }, | |
| ) | |
| final_results.append(reranked_doc) | |
| return final_results | |
| def get_stats(self) -> dict: | |
| """Get statistics about the knowledge base""" | |
| stats = {} | |
| if self.chroma_retriever: | |
| try: | |
| vectorstore = self.chroma_retriever.vectorstore | |
| collection = vectorstore._collection | |
| stats["chroma_documents"] = collection.count() | |
| except: | |
| stats["chroma_documents"] = "Unknown" | |
| if self.bm25_retriever: | |
| try: | |
| stats["bm25_documents"] = len(self.bm25_retriever.docs) | |
| except: | |
| stats["bm25_documents"] = "Unknown" | |
| stats["ensemble_available"] = self.ensemble_retriever is not None | |
| stats["device"] = self.device | |
| stats["vietnamese_tokenizer"] = "Vietnamese BM25 tokenizer (underthesea)" | |
| return stats | |
| def _find_word_documents(self, folder_path: str) -> List[str]: | |
| """ | |
| Find all Word documents (.doc, .docx) in the given folder | |
| Args: | |
| folder_path: Path to the folder to search | |
| Returns: | |
| List of full paths to Word documents | |
| """ | |
| word_files = [] | |
| folder = Path(folder_path) | |
| if not folder.exists(): | |
| raise FileNotFoundError(f"Folder not found: {folder_path}") | |
| # Search for Word documents | |
| for pattern in ["*.doc", "*.docx"]: | |
| word_files.extend(folder.glob(pattern)) | |
| # Convert to string paths and sort for consistent processing order | |
| word_files = [str(f) for f in word_files] | |
| word_files.sort() | |
| print(f"[INFO] Found Word documents: {[Path(f).name for f in word_files]}") | |
| return word_files | |
| def _process_all_word_files( | |
| self, word_files: List[str], word_processor: ContextualWordProcessor | |
| ) -> List[Document]: | |
| """Process all Word files into unified chunks with contextual enhancement""" | |
| all_documents = [] | |
| for file_path in word_files: | |
| try: | |
| print(f"[INFO] Processing: {Path(file_path).name}") | |
| chunks = word_processor.process_word_document(file_path) | |
| all_documents.extend(chunks) | |
| # Print processing stats for this file | |
| stats = word_processor.get_document_stats(chunks) | |
| print( | |
| f"[SUCCESS] Processed {Path(file_path).name}: {len(chunks)} chunks" | |
| ) | |
| print(f" - Contextualized: {stats.get('contextualized_docs', 0)}") | |
| print( | |
| f" - Non-contextualized: {stats.get('non_contextualized_docs', 0)}" | |
| ) | |
| except Exception as e: | |
| print(f"[ERROR] Error processing {Path(file_path).name}: {e}") | |
| return all_documents | |
| def _build_retriever(self, bm25_retriever, chroma_retriever): | |
| """ | |
| Build ensemble retriever with configurable top-k parameters | |
| Args: | |
| bm25_retriever: BM25 retriever with configurable fields | |
| chroma_retriever: Chroma retriever with configurable fields | |
| Returns: | |
| EnsembleRetriever with configurable retrievers | |
| """ | |
| return EnsembleRetriever( | |
| retrievers=[bm25_retriever, chroma_retriever], | |
| weights=[0.2, 0.8], # Slightly favor semantic search | |
| ) | |
| def _build_chroma_retriever( | |
| self, documents: List[Document], chroma_dir: str, reset: bool | |
| ): | |
| """Build ChromaDB retriever with configurable search parameters""" | |
| if reset and os.path.exists(chroma_dir): | |
| import shutil | |
| shutil.rmtree(chroma_dir) | |
| print("[INFO] Removed existing ChromaDB for rebuild") | |
| # Create Chroma vectorstore (uses contextualized content) | |
| vectorstore = Chroma.from_documents( | |
| documents=documents, embedding=self.embeddings, persist_directory=chroma_dir | |
| ) | |
| # Create retriever with configurable search_kwargs | |
| retriever = vectorstore.as_retriever( | |
| search_kwargs={"k": 5} # default value | |
| ).configurable_fields( | |
| search_kwargs=ConfigurableField( | |
| id="chroma_search_kwargs", | |
| name="Chroma Search Kwargs", | |
| description="Search kwargs for Chroma DB retriever", | |
| ) | |
| ) | |
| print( | |
| f"[SUCCESS] ChromaDB created with {len(documents)} contextualized documents" | |
| ) | |
| return retriever | |
| def _build_bm25_retriever( | |
| self, documents: List[Document], bm25_path: str, reset: bool | |
| ): | |
| """Build BM25 retriever with Vietnamese tokenization and configurable k parameter""" | |
| # Note: We no longer save BM25 to pickle file to avoid Streamlit Cloud compatibility issues | |
| # BM25 will be rebuilt from ChromaDB documents when loading the knowledge base | |
| # Create BM25 retriever with Vietnamese tokenizer as preprocess_func | |
| print("[INFO] Using Vietnamese tokenizer for BM25 on contextualized content...") | |
| retriever = BM25Retriever.from_documents( | |
| documents=documents, | |
| preprocess_func=self.text_processor.bm25_tokenizer, | |
| k=5, # default value | |
| ).configurable_fields( | |
| k=ConfigurableField( | |
| id="bm25_k", | |
| name="BM25 Top K", | |
| description="Number of documents to return from BM25", | |
| ) | |
| ) | |
| print( | |
| f"[SUCCESS] BM25 retriever created with {len(documents)} contextualized documents" | |
| ) | |
| return retriever | |
def test_contextual_kb(kb: ViettelKnowledgeBase, test_queries: List[str]):
    """Test function for the contextual knowledge base"""
    print("\n[INFO] Testing Contextual Knowledge Base")
    print("=" * 60)
    for i, query in enumerate(test_queries, 1):
        print(f"\n#{i} Query: '{query}'")
        print("-" * 40)
        try:
            # Ensemble search + rerank with a small top-k
            results = kb.search(query, top_k=3)
            if not results:
                print("   No results found")
                continue
            for j, doc in enumerate(results, 1):
                content_preview = doc.page_content[:150].replace("\n", " ")
                doc_type = doc.metadata.get("doc_type", "unknown")
                # Mark whether the chunk carries an LLM-generated context prefix
                if doc.metadata.get("has_context", False):
                    context_indicator = " [CONTEXTUAL]"
                else:
                    context_indicator = " [ORIGINAL]"
                print(f"   {j}. [{doc_type}]{context_indicator} {content_preview}...")
        except Exception as e:
            print(f"   [ERROR] Error: {e}")
# Example usage
if __name__ == "__main__":
    kb = ViettelKnowledgeBase(
        embedding_model="dangvantuan/vietnamese-document-embedding"
    )

    # Folder containing the .doc/.docx source documents
    documents_folder = "./viettelpay_docs"
    try:
        # Build from scratch (pass an OpenAI API key for contextual enhancement)
        kb.build_knowledge_base(
            documents_folder,
            "./contextual_kb",
            reset=True,
            openai_api_key="your-openai-api-key-here",  # or None to use env variable
        )

        # Alternative: load an existing knowledge base instead of rebuilding
        # success = kb.load_knowledge_base("./contextual_kb")
        # if not success:
        #     print("[ERROR] Failed to load knowledge base")

        test_queries = [
            "lỗi 606",
            "không nạp được tiền",
            "hướng dẫn nạp cước",
            "quy định hủy giao dịch",
            "mệnh giá thẻ cào",
        ]
        test_contextual_kb(kb, test_queries)

        # Demonstrate runtime-configurable top-k values
        print(f"\n[INFO] Example of runtime configuration:")
        print("=" * 50)
        sample_query = "lỗi 606"
        for k in (3, 8):
            hits = kb.search(sample_query, top_k=k)
            print(f"Search with top_k={k}: {len(hits)} total results")

        print(f"\n[INFO] Knowledge Base Stats: {kb.get_stats()}")
    except Exception as e:
        print(f"[ERROR] Error building knowledge base: {e}")
        print("[INFO] Make sure you have:")
        print("   1. Valid OpenAI API key")
        print("   2. Word documents in the specified folder")
        print("   3. Required dependencies installed (openai, markitdown, etc.)")