import os

from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from transformers import pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter


class KnowledgeManager:
    def __init__(self, root_dir="."):
        self.root_dir = root_dir
        self.docsearch = None
        self.qa_chain = None
        self.llm = None
        self.embeddings = None
        self._initialize_llm()
        self._initialize_embeddings()
        self._load_knowledge_base()

    def _initialize_llm(self):
        # Load local text2text model using HuggingFace pipeline (FLAN-T5 small)
        local_pipe = pipeline("text2text-generation", model="google/flan-t5-small", max_length=1024)
        self.llm = HuggingFacePipeline(pipeline=local_pipe)

    def _initialize_embeddings(self):
        # Use general-purpose sentence transformer
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    def _load_knowledge_base(self):
        # Automatically find all .txt files in the root directory
        txt_files = [f for f in os.listdir(self.root_dir) if f.endswith(".txt")]
        if not txt_files:
            raise FileNotFoundError("No .txt files found in root directory.")

        all_texts = []
        for filename in txt_files:
            path = os.path.join(self.root_dir, filename)
            with open(path, "r", encoding="utf-8") as f:
                all_texts.append(f.read())
        full_text = "\n\n".join(all_texts)

        # Split text into overlapping chunks for embedding
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        docs = text_splitter.create_documents([full_text])

        # Create FAISS vector store
        self.docsearch = FAISS.from_documents(docs, self.embeddings)

        # Build the retrieval QA chain
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.docsearch.as_retriever(),
            return_source_documents=True,
        )

    def ask(self, query):
        if not self.qa_chain:
            raise ValueError("Knowledge base not initialized.")
        result = self.qa_chain(query)
        return result["result"]