diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..da15360707a7132317001f9c9f7ac950a39b05b4 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/1.png filter=lfs diff=lfs merge=lfs -text
+assets/2.png filter=lfs diff=lfs merge=lfs -text
+data/chinese_document.pdf filter=lfs diff=lfs merge=lfs -text
+data/english_document.pdf filter=lfs diff=lfs merge=lfs -text
+Mini_RAG/assets/1.png filter=lfs diff=lfs merge=lfs -text
+Mini_RAG/assets/2.png filter=lfs diff=lfs merge=lfs -text
+Mini_RAG/data/chinese_document.pdf filter=lfs diff=lfs merge=lfs -text
+Mini_RAG/data/english_document.pdf filter=lfs diff=lfs merge=lfs -text
diff --git a/Mini_RAG/.gitattributes b/Mini_RAG/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..dab9a4e17afd2ef39d90ccb0b40ef2786fe77422
--- /dev/null
+++ b/Mini_RAG/.gitattributes
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/Mini_RAG/.idea/.gitignore b/Mini_RAG/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..1c2fda565b94d0f2b94cb65ba7cca866e7a25478
--- /dev/null
+++ b/Mini_RAG/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/Mini_RAG/.idea/RAG_Min.iml b/Mini_RAG/.idea/RAG_Min.iml
new file mode 100644
index 0000000000000000000000000000000000000000..527b363ce16e0ead03139f7e7ccee15f64a86ba5
--- /dev/null
+++ b/Mini_RAG/.idea/RAG_Min.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Mini_RAG/.idea/inspectionProfiles/Project_Default.xml b/Mini_RAG/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000000000000000000000000000000000000..9680032c8f34359f630a6d6d887cc6f1c4d5d9c3
--- /dev/null
+++ b/Mini_RAG/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,19 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Mini_RAG/.idea/inspectionProfiles/profiles_settings.xml b/Mini_RAG/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99
--- /dev/null
+++ b/Mini_RAG/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Mini_RAG/.idea/misc.xml b/Mini_RAG/.idea/misc.xml
new file mode 100644
index 0000000000000000000000000000000000000000..437b754c07ba5e4ee060d3231c0520ff6c3c8ebe
--- /dev/null
+++ b/Mini_RAG/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Mini_RAG/.idea/modules.xml b/Mini_RAG/.idea/modules.xml
new file mode 100644
index 0000000000000000000000000000000000000000..d0ece5d0e26d05bc29e01f2286873d2e9c7a6707
--- /dev/null
+++ b/Mini_RAG/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Mini_RAG/.idea/vcs.xml b/Mini_RAG/.idea/vcs.xml
new file mode 100644
index 0000000000000000000000000000000000000000..9661ac713428efbad557d3ba3a62216b5bb7d226
--- /dev/null
+++ b/Mini_RAG/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Mini_RAG/.idea/workspace.xml b/Mini_RAG/.idea/workspace.xml
new file mode 100644
index 0000000000000000000000000000000000000000..18c0482fe1342a0a6752f8fd66652308d62ab26f
--- /dev/null
+++ b/Mini_RAG/.idea/workspace.xml
@@ -0,0 +1,236 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {
+ "customColor": "",
+ "associatedIndex": 3
+}
+
+
+
+
+
+ {
+ "keyToString": {
+ "ModuleVcsDetector.initialDetectionPerformed": "true",
+ "Python.llm_stearm.executor": "Run",
+ "Python.main.executor": "Run",
+ "Python.rag_mini (1).executor": "Run",
+ "Python.test_qw.executor": "Run",
+ "RunOnceActivity.OpenProjectViewOnStart": "true",
+ "RunOnceActivity.ShowReadmeOnStart": "true",
+ "RunOnceActivity.TerminalTabsStorage.copyFrom.TerminalArrangementManager.252": "true",
+ "RunOnceActivity.git.unshallow": "true",
+ "WebServerToolWindowFactoryState": "false",
+ "git-widget-placeholder": "main",
+ "ignore.virus.scanning.warn.message": "true",
+ "last_opened_file_path": "D:/LLms/Tiny-RAG",
+ "node.js.detected.package.eslint": "true",
+ "node.js.detected.package.tslint": "true",
+ "node.js.selected.package.eslint": "(autodetect)",
+ "node.js.selected.package.tslint": "(autodetect)",
+ "nodejs_package_manager_path": "npm",
+ "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable",
+ "vue.rearranger.settings.migration": "true"
+ }
+}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1756273884601
+
+
+ 1756273884601
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1756863588308
+
+
+
+ 1756863588308
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Mini_RAG/LICENSE b/Mini_RAG/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..5eea95b8154bf1a59f0bf164075a1ab20d726899
--- /dev/null
+++ b/Mini_RAG/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 TuNan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Mini_RAG/README.md b/Mini_RAG/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f47dbce21502cc0ac57a7b25aef7c87b1b7146c1
--- /dev/null
+++ b/Mini_RAG/README.md
@@ -0,0 +1,185 @@
+# RAG_Mini
+---
+
+# Enterprise-Ready RAG System with Gradio Interface
+
+This is a powerful, enterprise-grade Retrieval-Augmented Generation (RAG) system designed to transform your documents into an interactive and intelligent knowledge base. Users can upload their own documents (PDFs, TXT files), build a searchable vector index, and ask complex questions in natural language to receive accurate, context-aware answers sourced directly from the provided materials.
+
+The entire application is wrapped in a clean, user-friendly web interface powered by Gradio.
+
+![Application screenshot](assets/1.png)
+
+![Application screenshot](assets/2.png)
+
+## ✨ Features
+
+- **Intuitive Web UI**: Simple, clean interface built with Gradio for uploading documents and chatting.
+- **Multi-Document Support**: Natively handles PDF and TXT files.
+- **Advanced Text Splitting**: Uses a `HierarchicalSemanticSplitter` that first splits documents into large parent chunks (for context) and then into smaller child chunks (for precise search), respecting semantic boundaries.
+- **Hybrid Search**: Combines the strengths of dense vector search (FAISS) and sparse keyword search (BM25) for robust and accurate retrieval.
+- **Reranking for Accuracy**: Employs a Cross-Encoder model to rerank the retrieved documents, ensuring the most relevant context is passed to the language model.
+- **Persistent Knowledge Base**: Automatically saves the built vector index and metadata, allowing you to load an existing knowledge base instantly on startup.
+- **Modular & Extensible Codebase**: The project is logically structured into services for loading, splitting, embedding, and generation, making it easy to maintain and extend.
+
+## 🏛️ System Architecture
+
+The RAG pipeline follows a logical, multi-step process to ensure high-quality answers:
+
+1. **Load**: Documents are loaded from various formats and parsed into a standardized `Document` object, preserving metadata like source and page number.
+2. **Split**: The raw text is processed by the `HierarchicalSemanticSplitter`, creating parent and child text chunks. This provides both broad context and fine-grained detail.
+3. **Embed & Index**: The child chunks are converted into vector embeddings using a `SentenceTransformer` model and indexed in a FAISS vector store. A parallel BM25 index is also built for keyword search.
+4. **Retrieve**: When a user asks a question, a hybrid search query is performed against the FAISS and BM25 indices to retrieve the most relevant child chunks.
+5. **Fetch Context**: The parent chunks corresponding to the retrieved child chunks are fetched. This ensures the LLM receives a wider, more complete context.
+6. **Rerank**: A powerful Cross-Encoder model re-evaluates the relevance of the parent chunks against the query, pushing the best matches to the top.
+7. **Generate**: The top-ranked, reranked documents are combined with the user's query into a final prompt. This prompt is sent to a Large Language Model (LLM) to generate a final, coherent answer.
+
+```
+[User Uploads Docs] -> [Loader] -> [Splitter] -> [Embedder & Vector Store] -> [Knowledge Base Saved]
+
+[User Asks Question] -> [Hybrid Search] -> [Get Parent Docs] -> [Reranker] -> [LLM] -> [Answer & Sources]
+```
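+
+The same pipeline can also be driven programmatically; the Gradio UI is a thin layer on top of `RAGService`. A minimal sketch, assuming you run it from the project root with the dependencies installed (adjust the file paths to your own documents):
+
+```python
+import yaml
+from service.rag_service import RAGService
+
+with open("configs/config.yaml", "r", encoding="utf-8") as f:
+    config = yaml.safe_load(f)
+
+rag = RAGService(config)
+
+# Build (or rebuild) the knowledge base from your own documents.
+for status in rag.build_knowledge_base(["data/english_document.pdf"]):
+    print(status)
+
+# Ask a question: retrieval, reranking, and generation happen in one call.
+answer, sources = rag.get_response_full("What is this document about?")
+print(answer)
+print(rag.get_context_string(sources))
+```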
+
+## 🛠️ Tech Stack
+
+- **Backend**: Python 3.9+
+- **UI**: Gradio
+- **LLM & Embedding Framework**: Hugging Face Transformers, Sentence-Transformers
+- **Vector Search**: Faiss (from Facebook AI)
+- **Keyword Search**: rank-bm25
+- **PDF Parsing**: PyMuPDF (fitz)
+- **Configuration**: PyYAML
+
+## 🚀 Getting Started
+
+Follow these steps to set up and run the project on your local machine.
+
+### 1. Prerequisites
+
+- Python 3.9 or higher
+- `pip` for package management
+
+### 2. Dependencies (`requirements.txt`)
+
+The repository ships with a `requirements.txt` that lists the required dependencies. If you change your environment and need to regenerate it, run the following in your activated terminal:
+
+```bash
+pip freeze > requirements.txt
+```
+The key packages are: `gradio`, `torch`, `transformers`, `sentence-transformers`, `faiss-cpu`, `rank_bm25`, `PyMuPDF`, `pyyaml`, `numpy`.
+
+### 3. Installation & Setup
+
+**1. Clone the repository:**
+```bash
+git clone https://github.com/YOUR_USERNAME/YOUR_REPOSITORY_NAME.git
+cd YOUR_REPOSITORY_NAME
+```
+
+**2. Create and activate a virtual environment (recommended):**
+```bash
+# For Windows
+python -m venv venv
+.\venv\Scripts\activate
+
+# For macOS/Linux
+python3 -m venv venv
+source venv/bin/activate
+```
+
+**3. Install the required packages:**
+```bash
+pip install -r requirements.txt
+```
+
+**4. Configure the system:**
+Review the `configs/config.yaml` file. You can change the models, chunk sizes, and other parameters here. The default settings are a good starting point.
+
+> **Note:** The first time you run the application, the models specified in the config file will be downloaded from Hugging Face. This may take some time depending on your internet connection.
+
+### 4. Running the Application
+
+To start the Gradio web server, run the `main.py` script:
+
+```bash
+python main.py
+```
+
+The application will be available at **`http://localhost:7860`**.
+
+## 📖 How to Use
+
+The application has two primary workflows:
+
+**1. Build a New Knowledge Base:**
+ - Drag and drop one or more `.pdf` or `.txt` files into the "Upload New Docs to Build" area.
+ - Click the **"Build New KB"** button.
+ - The system status will show the progress (Loading -> Splitting -> Indexing).
+ - Once complete, the status will confirm that the knowledge base is ready, and the chat window will appear.
+
+**2. Load an Existing Knowledge Base:**
+ - If you have previously built a knowledge base, simply click the **"Load Existing KB"** button.
+ - The system will load the saved FAISS index and metadata from the `storage` directory.
+ - The chat window will appear, and you can start asking questions immediately.
+
+**Chatting with Your Documents:**
+ - Once the knowledge base is ready, type your question into the chat box at the bottom and press Enter or click "Submit".
+ - The model will generate an answer based on the documents you provided.
+ - The sources used to generate the answer will be displayed below the chat window.
+
+## 📂 Project Structure
+
+```
+.
+├── configs/
+│ └── config.yaml # Main configuration file for models, paths, etc.
+├── core/
+│ ├── embedder.py # Handles text embedding.
+│ ├── llm_interface.py # Handles reranking and answer generation.
+│ ├── loader.py # Loads and parses documents.
+│ ├── schema.py # Defines data structures (Document, Chunk).
+│ ├── splitter.py # Splits documents into chunks.
+│ └── vector_store.py # Manages FAISS & BM25 indices.
+├── service/
+│ └── rag_service.py # Orchestrates the entire RAG pipeline.
+├── storage/ # Default location for saved indices (auto-generated).
+│ └── ...
+├── ui/
+│ └── app.py # Contains the Gradio UI logic.
+├── utils/
+│ └── logger.py # Logging configuration.
+├── assets/
+│ └── 1.png # Screenshot of the application.
+├── main.py # Entry point to run the application.
+└── requirements.txt # Python package dependencies.
+```
+
+## 🔧 Configuration Details (`config.yaml`)
+
+You can customize the RAG pipeline by modifying `configs/config.yaml`:
+
+- **`models`**: Specify the Hugging Face models for embedding, reranking, and generation.
+- **`vector_store`**: Define the paths where the FAISS index and metadata will be saved.
+- **`splitter`**: Control the `HierarchicalSemanticSplitter` behavior.
+ - `parent_chunk_size`: The target size for larger context chunks.
+ - `parent_chunk_overlap`: The overlap between parent chunks.
+ - `child_chunk_size`: The target size for smaller, searchable chunks.
+- **`retrieval`**: Tune the retrieval and reranking process.
+ - `retrieval_top_k`: How many initial candidates to retrieve with hybrid search.
+ - `rerank_top_k`: How many final documents to pass to the LLM after reranking.
+  - `hybrid_search_alpha`: The weighting between vector search (`alpha`) and BM25 search (`1 - alpha`). `1.0` is pure vector search, `0.0` is pure keyword search (see the sketch below).
+- **`generation`**: Set parameters for the final answer generation, like `max_new_tokens`.
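+
+For reference, the hybrid score is a weighted blend of the two retrieval scores. A simplified sketch of the blending performed in `core/vector_store.py` (the real implementation also normalizes each score by its maximum first):
+
+```python
+def hybrid_score(vector_score: float, bm25_score: float, alpha: float = 0.5) -> float:
+    # Both scores are assumed to be normalized to [0, 1] before blending.
+    return alpha * vector_score + (1 - alpha) * bm25_score
+```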
+
+## 🛣️ Future Roadmap
+
+- [ ] Support for more document types (e.g., `.docx`, `.pptx`, `.html`).
+- [ ] Implement response streaming for a more interactive chat experience.
+- [ ] Integrate with other vector databases like ChromaDB or Pinecone.
+- [ ] Create API endpoints for programmatic access to the RAG service.
+- [ ] Add more advanced logging and monitoring for enterprise use.
+
+## 🤝 Contributing
+
+Contributions are welcome! If you have ideas for improvements or find a bug, please feel free to open an issue or submit a pull request.
+
+## 📄 License
+
+This project is licensed under the MIT License. See the `LICENSE` file for details.
diff --git a/Mini_RAG/app.py b/Mini_RAG/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..4720bf8d24ffaf91dda540a08e105848a3b24a27
--- /dev/null
+++ b/Mini_RAG/app.py
@@ -0,0 +1,19 @@
+import yaml
+from service.rag_service import RAGService
+from ui.app import GradioApp
+from utils.logger import setup_logger
+
+
+def main():
+ setup_logger()
+
+ with open('configs/config.yaml', 'r', encoding='utf-8') as f:
+ config = yaml.safe_load(f)
+
+ rag_service = RAGService(config)
+ app = GradioApp(rag_service)
+ app.launch()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/Mini_RAG/assets/1.png b/Mini_RAG/assets/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..def6f1b1f4f00ea6b493b7cdece6a228480c5934
--- /dev/null
+++ b/Mini_RAG/assets/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3cfd5a1e1f78574b7ad19ba82bc14aefedc30e198293a36b34fddbe58ae9572a
+size 117169
diff --git a/Mini_RAG/assets/2.png b/Mini_RAG/assets/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..04960ccb2184a5a930b0bd4be061c294a142b68e
--- /dev/null
+++ b/Mini_RAG/assets/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:488e38ddb8e075942a427f604b290abc27121023b44ceab74b2447ccd5d39ecd
+size 144857
diff --git a/Mini_RAG/configs/config.yaml b/Mini_RAG/configs/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..249028c3740d4f6e40b7e12d273319203640579d
--- /dev/null
+++ b/Mini_RAG/configs/config.yaml
@@ -0,0 +1,26 @@
+# Path configuration
+storage_path: "./storage"
+vector_store:
+ index_path: "./storage/faiss_index/index.faiss"
+ metadata_path: "./storage/faiss_index/chunks.pkl"
+
+# Splitter configuration
+splitter:
+ parent_chunk_size: 800
+ parent_chunk_overlap: 100
+ child_chunk_size: 250
+
+# Model configuration
+models:
+ embedding: "moka-ai/m3e-base"
+ reranker: "BAAI/bge-reranker-base"
+ llm_generator: "Qwen/Qwen3-0.6B" # Qwen/Qwen1.5-1.8B-Chat or Qwen/Qwen3-0.6B
+
+
+# Retrieval and generation parameters
+retrieval:
+  hybrid_search_alpha: 0.5 # vector weight (0-1) in hybrid retrieval; keyword weight is 1 - alpha
+ retrieval_top_k: 20
+ rerank_top_k: 5
+generation:
+ max_new_tokens: 512
diff --git a/Mini_RAG/core/__init__.py b/Mini_RAG/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3f5a12faa99758192ecc4ed3fc22c9249232e86
--- /dev/null
+++ b/Mini_RAG/core/__init__.py
@@ -0,0 +1 @@
+
diff --git a/Mini_RAG/core/__pycache__/__init__.cpython-39.pyc b/Mini_RAG/core/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..490fc55d5de07f75e7b79a8526715e75207b5e33
Binary files /dev/null and b/Mini_RAG/core/__pycache__/__init__.cpython-39.pyc differ
diff --git a/Mini_RAG/core/__pycache__/embedder.cpython-39.pyc b/Mini_RAG/core/__pycache__/embedder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b33e5057272a15b09b1bf751a6082bfd8a72e68
Binary files /dev/null and b/Mini_RAG/core/__pycache__/embedder.cpython-39.pyc differ
diff --git a/Mini_RAG/core/__pycache__/llm_interface.cpython-39.pyc b/Mini_RAG/core/__pycache__/llm_interface.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17317a98ba48afdddb28d1cdcf421f61843a88ed
Binary files /dev/null and b/Mini_RAG/core/__pycache__/llm_interface.cpython-39.pyc differ
diff --git a/Mini_RAG/core/__pycache__/loader.cpython-39.pyc b/Mini_RAG/core/__pycache__/loader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e832df0a9e3abbeb0443ef56884f79577401fee
Binary files /dev/null and b/Mini_RAG/core/__pycache__/loader.cpython-39.pyc differ
diff --git a/Mini_RAG/core/__pycache__/schema.cpython-39.pyc b/Mini_RAG/core/__pycache__/schema.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d379d281390f817f93027d9084d701b0040cdd5a
Binary files /dev/null and b/Mini_RAG/core/__pycache__/schema.cpython-39.pyc differ
diff --git a/Mini_RAG/core/__pycache__/splitter.cpython-39.pyc b/Mini_RAG/core/__pycache__/splitter.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7bf2fd07fd733b59cd01dc4b4785fd01d31c20f
Binary files /dev/null and b/Mini_RAG/core/__pycache__/splitter.cpython-39.pyc differ
diff --git a/Mini_RAG/core/__pycache__/vector_store.cpython-39.pyc b/Mini_RAG/core/__pycache__/vector_store.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01876546ce639ba8b2e37b0bf6b5a77b1bf79bf7
Binary files /dev/null and b/Mini_RAG/core/__pycache__/vector_store.cpython-39.pyc differ
diff --git a/Mini_RAG/core/embedder.py b/Mini_RAG/core/embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d85c5cef0adb20c2de3013eb0290f62252e02f0
--- /dev/null
+++ b/Mini_RAG/core/embedder.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/24 11:50
+# @Author : hukangzhe
+# @File : embedder.py
+# @Description :
+
+from sentence_transformers import SentenceTransformer
+from typing import List
+import numpy as np
+
+
+class EmbeddingModel:
+ def __init__(self, model_name: str):
+ self.embedding_model = SentenceTransformer(model_name)
+
+ def embed(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
+ return self.embedding_model.encode(texts, batch_size=batch_size, convert_to_numpy=True)
+
diff --git a/Mini_RAG/core/llm_interface.py b/Mini_RAG/core/llm_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..bff2ea38acbc6524c9a8e2844464abaea3ba23a5
--- /dev/null
+++ b/Mini_RAG/core/llm_interface.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/29 19:54
+# @Author : hukangzhe
+# @File : generator.py
+# @Description : Answer generation module (reranking and LLM generation)
+import os
+import queue
+import logging
+import threading
+
+import torch
+from typing import Dict, List, Tuple, Generator
+from sentence_transformers import CrossEncoder
+from .schema import Document, Chunk
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer
+
+class ThinkStreamer(TextStreamer):
+ def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool =True, **decode_kwargs):
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+ self.is_thinking = True
+        self.think_end_token_id = self.tokenizer.encode("</think>", add_special_tokens=False)[0]
+ self.output_queue = queue.Queue()
+
+ def on_finalized_text(self, text: str, stream_end: bool = False):
+ self.output_queue.put(text)
+ if stream_end:
+            self.output_queue.put(None) # send end-of-stream signal
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ value = self.output_queue.get()
+ if value is None:
+ raise StopIteration()
+ return value
+
+ def generate_output(self) -> Generator[Tuple[str, str], None, None]:
+ """
+        Separate the thinking part from the answer part of the stream.
+ :return:
+ """
+ full_decode_text = ""
+ already_yielded_len = 0
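+        # Accumulate streamed text until the </think> token id appears in the re-encoded
+        # buffer, then emit everything before it as "thinking" and the rest as "answer".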
+ for text_chunk in self:
+ if not self.is_thinking:
+ yield "answer", text_chunk
+ continue
+
+ full_decode_text += text_chunk
+ tokens = self.tokenizer.encode(full_decode_text, add_special_tokens=False)
+
+ if self.think_end_token_id in tokens:
+ spilt_point = tokens.index(self.think_end_token_id)
+ think_part_tokens = tokens[:spilt_point]
+ thinking_text = self.tokenizer.decode(think_part_tokens)
+
+ answer_part_tokens = tokens[spilt_point:]
+ answer_text = self.tokenizer.decode(answer_part_tokens)
+ remaining_thinking = thinking_text[already_yielded_len:]
+ if remaining_thinking:
+ yield "thinking", remaining_thinking
+
+ if answer_text:
+ yield "answer", answer_text
+
+ self.is_thinking = False
+ already_yielded_len = len(thinking_text) + len(self.tokenizer.decode(self.think_end_token_id))
+ else:
+ yield "thinking", text_chunk
+ already_yielded_len += len(text_chunk)
+
+
+class QueueTextStreamer(TextStreamer):
+ def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool = True, **decode_kwargs):
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+ self.output_queue = queue.Queue()
+
+ def on_finalized_text(self, text: str, stream_end: bool = False):
+ """Puts text into the queue; sends None as a sentinel value to signal the end."""
+ self.output_queue.put(text)
+ if stream_end:
+ self.output_queue.put(None)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ value = self.output_queue.get()
+ if value is None:
+ raise StopIteration()
+ return value
+
+
+class LLMInterface:
+ def __init__(self, config: dict):
+ self.config = config
+ self.reranker = CrossEncoder(config['models']['reranker'])
+ self.generator_new_tokens = config['generation']['max_new_tokens']
+ self.device =torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ generator_name = config['models']['llm_generator']
+ logging.info(f"Initializing generator {generator_name}")
+ self.generator_tokenizer = AutoTokenizer.from_pretrained(generator_name)
+ self.generator_model = AutoModelForCausalLM.from_pretrained(
+ generator_name,
+ torch_dtype="auto",
+ device_map="auto")
+
+ def rerank(self, query: str, docs: List[Document]) -> List[Document]:
+ pairs = [[query, doc.text] for doc in docs]
+ scores = self.reranker.predict(pairs)
+ ranked_docs = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
+ return [doc for doc, score in ranked_docs]
+
+ def _threaded_generate(self, streamer: QueueTextStreamer, generation_kwargs: dict):
+ """
+        A wrapper that runs model.generate inside a try...finally block.
+ """
+ try:
+ self.generator_model.generate(**generation_kwargs)
+ finally:
+            # Always send the end-of-stream signal, whether generation succeeded or failed
+ streamer.output_queue.put(None)
+
+ def generate_answer(self, query: str, context_docs: List[Document]) -> str:
+ context_str = ""
+ for doc in context_docs:
+ context_str += f"Source: {os.path.basename(doc.metadata.get('source', ''))}, Page: {doc.metadata.get('page', 'N/A')}\n"
+ context_str += f"Content: {doc.text}\n\n"
+        # Writing the message content in English makes the model answer in English
+ messages = [
+ {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
+ {"role": "user", "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题:{query}"}
+ ]
+ prompt = self.generator_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(self.device)
+ output = self.generator_model.generate(**inputs,
+ max_new_tokens=self.generator_new_tokens, num_return_sequences=1,
+ eos_token_id=self.generator_tokenizer.eos_token_id)
+ generated_ids = output[0][inputs["input_ids"].shape[1]:]
+ answer = self.generator_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
+ return answer
+
+ def generate_answer_stream(self, query: str, context_docs: List[Document]) -> Generator[str, None, None]:
+ context_str = ""
+ for doc in context_docs:
+ context_str += f"Content: {doc.text}\n\n"
+
+ messages = [
+ {"role": "system",
+ "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
+ {"role": "user",
+ "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题: {query}"}
+ ]
+
+ prompt = self.generator_tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device)
+
+ streamer = QueueTextStreamer(self.generator_tokenizer, skip_prompt=True)
+
+ generation_kwargs = dict(
+ **model_inputs,
+ max_new_tokens=self.generator_new_tokens,
+ streamer=streamer,
+ pad_token_id=self.generator_tokenizer.eos_token_id,
+ )
+
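+        # Run generation in a background thread; tokens are pushed into the streamer's
+        # queue and drained below, so the caller receives text as soon as it is produced.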
+ thread = threading.Thread(target=self._threaded_generate, args=(streamer,generation_kwargs,))
+ thread.start()
+ for new_text in streamer:
+ if new_text is not None:
+ yield new_text
+
+ def generate_answer_stream_split(self, query: str, context_docs: List[Document]) -> Generator[Tuple[str, str], None, None]:
+ """分离思考和回答的流式输出"""
+ context_str = ""
+ for doc in context_docs:
+ context_str += f"Content: {doc.text}\n\n"
+
+ messages = [
+ {"role": "system",
+ "content": "You are a helpful assistant. Please answer the question based on the provided context. First, think through the process in tags, then provide the final answer."},
+ {"role": "user",
+ "content": f"Context:\n---\n{context_str}\n---\nBased on the context above, please answer the question: {query}"}
+ ]
+
+ prompt = self.generator_tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=True
+ )
+ model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device)
+
+ streamer = ThinkStreamer(self.generator_tokenizer, skip_prompt=True)
+
+ generation_kwargs = dict(
+ **model_inputs,
+ max_new_tokens=self.generator_new_tokens,
+ streamer=streamer
+ )
+
+ thread = threading.Thread(target=self.generator_model.generate, kwargs=generation_kwargs)
+ thread.start()
+
+ yield from streamer.generate_output()
+
+
+
+
diff --git a/Mini_RAG/core/loader.py b/Mini_RAG/core/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..169895436c9119c68856abfc9cafbc8055a6c078
--- /dev/null
+++ b/Mini_RAG/core/loader.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/24 19:51
+# @Author : hukangzhe
+# @File : loader.py
+# @Description : Document loading module. Unlike rag_mini, it returns a list of Document objects (each carrying source metadata) instead of one long string.
+import logging
+from .schema import Document
+from typing import List
+import fitz
+import os
+
+
+class DocumentLoader:
+
+ @staticmethod
+ def _load_pdf(file_path):
+ logging.info(f"Loading PDF file from {file_path}")
+ try:
+ with fitz.open(file_path) as f:
+ text = "".join(page.get_text() for page in f)
+ logging.info(f"Successfully loaded {len(f)} pages.")
+ return text
+ except Exception as e:
+ logging.error(f"Failed to load PDF {file_path}: {e}")
+ return None
+
+
+class MultiDocumentLoader:
+ def __init__(self, paths: List[str]):
+ self.paths = paths
+
+ def load(self) -> List[Document]:
+ docs = []
+ for file_path in self.paths:
+ file_extension = os.path.splitext(file_path)[1].lower()
+ if file_extension == '.pdf':
+ docs.extend(self._load_pdf(file_path))
+ elif file_extension == '.txt':
+ docs.extend(self._load_txt(file_path))
+ else:
+ logging.warning(f"Unsupported file type:{file_extension}. Skipping {file_path}")
+
+ return docs
+
+ def _load_pdf(self, file_path: str) -> List[Document]:
+ logging.info(f"Loading PDF file from {file_path}")
+ try:
+ pdf_docs = []
+ with fitz.open(file_path) as doc:
+ for i, page in enumerate(doc):
+ pdf_docs.append(Document(
+ text=page.get_text(),
+ metadata={'source': file_path, 'page': i + 1}
+ ))
+ return pdf_docs
+ except Exception as e:
+ logging.error(f"Failed to load PDF {file_path}: {e}")
+ return []
+
+ def _load_txt(self, file_path: str) -> List[Document]:
+ logging.info(f"Loading txt file from {file_path}")
+ try:
+ txt_docs = []
+ with open(file_path, 'r', encoding="utf-8") as f:
+ txt_docs.append(Document(
+ text=f.read(),
+ metadata={'source': file_path}
+ ))
+ return txt_docs
+ except Exception as e:
+ logging.error(f"Failed to load txt {file_path}:{e}")
+ return []
diff --git a/Mini_RAG/core/schema.py b/Mini_RAG/core/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..647924b047ac0f8f8225a5d85ecb3191f499ef7d
--- /dev/null
+++ b/Mini_RAG/core/schema.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/24 21:11
+# @Author : hukangzhe
+# @File : schema.py
+# @Description : Instead of handling raw text strings directly, introduce a standardized data structure that carries both chunk content and metadata.
+
+from dataclasses import dataclass, field
+from typing import Dict, Any, Optional
+
+
+@dataclass
+class Document:
+ text: str
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class Chunk(Document):
+    parent_id: Optional[int] = None
\ No newline at end of file
diff --git a/Mini_RAG/core/splitter.py b/Mini_RAG/core/splitter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea8df02e6bd6fddbb1a8157211ba9983a10dfe63
--- /dev/null
+++ b/Mini_RAG/core/splitter.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/25 19:52
+# @Author : hukangzhe
+# @File : splitter.py
+# @Description : Text splitting module
+
+import logging
+from typing import List, Dict, Tuple
+from .schema import Document, Chunk
+
+
+class SemanticRecursiveSplitter:
+ def __init__(self, chunk_size: int=500, chunk_overlap:int = 50, separators: List[str] = None):
+ """
+        A genuinely recursive semantic text splitter.
+        :param chunk_size: target size of each text chunk.
+        :param chunk_overlap: overlap between consecutive chunks.
+        :param separators: semantic separators to split on, ordered from highest to lowest priority.
+ """
+ self.chunk_size = chunk_size
+ self.chunk_overlap = chunk_overlap
+ if self.chunk_size <= self.chunk_overlap:
+ raise ValueError("Chunk overlap must be smaller than chunk size.")
+
+        self.separators = separators if separators else ['\n\n', '\n', " ", ""] # default separators
+
+ def text_split(self, text: str) -> List[str]:
+ """
+        Entry point for splitting.
+        :param text: the raw text to split.
+        :return: a list of text chunks.
+ """
+ logging.info("Starting semantic recursive splitting...")
+ final_chunks = self._split(text, self.separators)
+ logging.info(f"Text successfully split into {len(final_chunks)} chunks.")
+ return final_chunks
+
+ def _split(self, text: str, separators: List[str]) -> List[str]:
+ final_chunks = []
+        # 1. If the text is already small enough, return it as-is
+ if len(text) < self.chunk_size:
+ return [text]
+        # 2. Try the highest-priority separator first
+ cur_separator = separators[0]
+
+        # 3. If the text can be split on this separator
+ if cur_separator in text:
+            # split it into smaller parts
+ parts = text.split(cur_separator)
+
+ buffer="" # 用来合并小部分
+ for i, part in enumerate(parts):
+                # if still under chunk_size, append another part so the buffer approaches chunk_size
+ if len(buffer) + len(part) + len(cur_separator) <= self.chunk_size:
+ buffer += part+cur_separator
+ else:
+                    # flush the buffer if it is not empty
+ if buffer:
+ final_chunks.append(buffer)
+                    # if the current part alone already exceeds chunk_size
+ if len(part) > self.chunk_size:
+                        # recurse with the next-level separators
+ sub_chunks = self._split(part, separators = separators[1:])
+ final_chunks.extend(sub_chunks)
+                    else: # start a new buffer with this part
+ buffer = part + cur_separator
+
+            if buffer: # flush the final buffer
+ final_chunks.append(buffer.strip())
+
+ else:
+            # 4. Fall back to the next-level separator
+ final_chunks = self._split(text, separators[1:])
+
+        # handle overlap
+ if self.chunk_overlap > 0:
+ return self._handle_overlap(final_chunks)
+ else:
+ return final_chunks
+
+ def _handle_overlap(self, final_chunks: List[str]) -> List[str]:
+ overlap_chunks = []
+ if not final_chunks:
+ return []
+ overlap_chunks.append(final_chunks[0])
+ for i in range(1, len(final_chunks)):
+ pre_chunk = overlap_chunks[-1]
+ cur_chunk = final_chunks[i]
+            # take the overlap from the previous chunk and prepend it to the current chunk
+ overlap_part = pre_chunk[-self.chunk_overlap:]
+ overlap_chunks.append(overlap_part+cur_chunk)
+
+ return overlap_chunks
+
+
+class HierarchicalSemanticSplitter:
+ """
+    Combines hierarchical (parent/child) and recursive semantic splitting strategies, so that parent and child chunks respect the text's natural semantic boundaries.
+ """
+ def __init__(self,
+ parent_chunk_size: int = 800,
+ parent_chunk_overlap: int = 100,
+ child_chunk_size: int = 250,
+ separators: List[str] = None):
+ if parent_chunk_overlap >= parent_chunk_size:
+ raise ValueError("Parent chunk overlap must be smaller than parent chunk size.")
+ if child_chunk_size >= parent_chunk_size:
+ raise ValueError("Child chunk size must be smaller than parent chunk size.")
+
+ self.parent_chunk_size = parent_chunk_size
+ self.parent_chunk_overlap = parent_chunk_overlap
+ self.child_chunk_size = child_chunk_size
+ self.separators = separators or ["\n\n", "\n", "。", ". ", "!", "!", "?", "?", " ", ""]
+
+ def _recursive_semantic_split(self, text: str, chunk_size: int) -> List[str]:
+ """
+        Split text while preferring semantic boundaries.
+ """
+ if len(text) <= chunk_size:
+ return [text]
+
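+        # Try separators in priority order and cut at the last occurrence that still fits
+        # within chunk_size; fall back to a hard cut at chunk_size if none is found.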
+ for sep in self.separators:
+ split_point = text.rfind(sep, 0, chunk_size)
+ if split_point != -1:
+ break
+ else:
+ split_point = chunk_size
+
+ chunk1 = text[:split_point]
+        remaining_text = text[split_point:].lstrip() # strip leading whitespace from the remainder
+
+        # Recursively split the remaining text.
+        # The separator is appended back to the first chunk to preserve context.
+ if remaining_text:
+ return [chunk1 + (sep if sep in " \n" else "")] + self._recursive_semantic_split(remaining_text, chunk_size)
+ else:
+ return [chunk1]
+
+ def _apply_overlap(self, chunks: List[str], overlap: int) -> List[str]:
+ """处理重叠部分chunk"""
+ if not overlap or len(chunks) <= 1:
+ return chunks
+
+ overlapped_chunks = [chunks[0]]
+ for i in range(1, len(chunks)):
+            # take the last "overlap" characters from the previous chunk
+ overlap_content = chunks[i - 1][-overlap:]
+ overlapped_chunks.append(overlap_content + chunks[i])
+
+ return overlapped_chunks
+
+ def split_documents(self, documents: List[Document]) -> Tuple[Dict[int, Document], List[Chunk]]:
+ """
+        Two-pass splitting.
+ :param documents:
+ :return:
+ - parent documents: {parent_id: Document}
+ - child chunks: [Chunk, Chunk, ...]
+ """
+ parent_docs_dict: Dict[int, Document] = {}
+ child_chunks_list: List[Chunk] = []
+ parent_id_counter = 0
+
+ logging.info("Starting robust hierarchical semantic splitting...")
+
+ for doc in documents:
+            # === PASS 1: Create parent chunks ===
+            # 1. Split the full document text into large semantic chunks
+ initial_parent_chunks = self._recursive_semantic_split(doc.text, self.parent_chunk_size)
+
+            # 2. Apply overlap to the parent chunks
+ overlapped_parent_texts = self._apply_overlap(initial_parent_chunks, self.parent_chunk_overlap)
+
+ for p_text in overlapped_parent_texts:
+ parent_doc = Document(text=p_text, metadata=doc.metadata.copy())
+ parent_docs_dict[parent_id_counter] = parent_doc
+
+ # === PASS 2: Create Child Chunks from each Parent ===
+ child_texts = self._recursive_semantic_split(p_text, self.child_chunk_size)
+
+ for c_text in child_texts:
+ child_metadata = doc.metadata.copy()
+ child_metadata['parent_id'] = parent_id_counter
+ child_chunk = Chunk(text=c_text, metadata=child_metadata, parent_id=parent_id_counter)
+ child_chunks_list.append(child_chunk)
+
+ parent_id_counter += 1
+
+ logging.info(
+ f"Splitting complete. Created {len(parent_docs_dict)} parent chunks and {len(child_chunks_list)} child chunks.")
+ return parent_docs_dict, child_chunks_list
diff --git a/Mini_RAG/core/test_qw.py b/Mini_RAG/core/test_qw.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac3f83f1b3c1366dee367520cf6c1129b9be3535
--- /dev/null
+++ b/Mini_RAG/core/test_qw.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/29 20:03
+# @Author : hukangzhe
+# @File : test_qw.py
+# @Description : Test that both models (Qwen1.5 and Qwen3) produce correct output in both modes (full or stream)
+import os
+import queue
+import logging
+import threading
+import torch
+from typing import Tuple, Generator
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer
+
+class ThinkStreamer(TextStreamer):
+ def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool =True, **decode_kwargs):
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+ self.is_thinking = True
+        self.think_end_token_id = self.tokenizer.encode("</think>", add_special_tokens=False)[0]
+ self.output_queue = queue.Queue()
+
+ def on_finalized_text(self, text: str, stream_end: bool = False):
+ self.output_queue.put(text)
+ if stream_end:
+            self.output_queue.put(None) # send end-of-stream signal
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ value = self.output_queue.get()
+ if value is None:
+ raise StopIteration()
+ return value
+
+ def generate_output(self) -> Generator[Tuple[str, str], None, None]:
+ full_decode_text = ""
+ already_yielded_len = 0
+ for text_chunk in self:
+ if not self.is_thinking:
+ yield "answer", text_chunk
+ continue
+
+ full_decode_text += text_chunk
+ tokens = self.tokenizer.encode(full_decode_text, add_special_tokens=False)
+
+ if self.think_end_token_id in tokens:
+ spilt_point = tokens.index(self.think_end_token_id)
+ think_part_tokens = tokens[:spilt_point]
+ thinking_text = self.tokenizer.decode(think_part_tokens)
+
+ answer_part_tokens = tokens[spilt_point:]
+ answer_text = self.tokenizer.decode(answer_part_tokens)
+ remaining_thinking = thinking_text[already_yielded_len:]
+ if remaining_thinking:
+ yield "thinking", remaining_thinking
+
+ if answer_text:
+ yield "answer", answer_text
+
+ self.is_thinking = False
+ already_yielded_len = len(thinking_text) + len(self.tokenizer.decode(self.think_end_token_id))
+ else:
+ yield "thinking", text_chunk
+ already_yielded_len += len(text_chunk)
+
+
+
+class LLMInterface:
+ def __init__(self, model_name: str= "Qwen/Qwen3-0.6B"):
+ logging.info(f"Initializing generator {model_name}")
+ self.generator_tokenizer = AutoTokenizer.from_pretrained(model_name)
+ self.generator_model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ torch_dtype="auto",
+ device_map="auto")
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ def generate_answer(self, query: str, context_str: str) -> str:
+ messages = [
+ {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
+ {"role": "user", "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题:{query}"}
+ ]
+ prompt = self.generator_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(self.device)
+ output = self.generator_model.generate(**inputs,
+ max_new_tokens=256, num_return_sequences=1,
+ eos_token_id=self.generator_tokenizer.eos_token_id)
+ generated_ids = output[0][inputs["input_ids"].shape[1]:]
+ answer = self.generator_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
+ return answer
+
+ def generate_answer_stream(self, query: str, context_str: str) -> Generator[Tuple[str, str], None, None]:
+ """Generates an answer as a stream of (state, content) tuples."""
+ messages = [
+ {"role": "system",
+ "content": "You are a helpful assistant. Please answer the question based on the provided context. First, think through the process in tags, then provide the final answer."},
+ {"role": "user",
+ "content": f"Context:\n---\n{context_str}\n---\nBased on the context above, please answer the question: {query}"}
+ ]
+
+ # Use the template that enables thinking for Qwen models
+ prompt = self.generator_tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=True
+ )
+ model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device)
+
+ streamer = ThinkStreamer(self.generator_tokenizer, skip_prompt=True)
+
+ generation_kwargs = dict(
+ **model_inputs,
+ max_new_tokens=512,
+ streamer=streamer
+ )
+
+ thread = threading.Thread(target=self.generator_model.generate, kwargs=generation_kwargs)
+ thread.start()
+
+ yield from streamer.generate_output()
+
+
+# if __name__ == "__main__":
+# qwen = LLMInterface("Qwen/Qwen3-0.6B")
+# answer = qwen.generate_answer("儒家思想的创始人是谁?", "中国传统哲学以儒家、道家和法家为主要流派。儒家思想由孔子创立,强调“仁”、“义”、“礼”、“智”、“信”,主张修身齐家治国平天下,对中国社会产生了深远的影响。其核心价值观如“己所不欲,勿施于人”至今仍具有普世意义。"+
+#
+# "道家思想以老子和庄子为代表,主张“道法自然”,追求人与自然的和谐统一,强调无为而治、清静无为。道家思想对中国人的审美情趣、艺术创作以及养生之道都有着重要的影响。"+
+#
+# "法家思想以韩非子为集大成者,主张以法治国,强调君主的权威和法律的至高无上。尽管法家思想在历史上曾被用于强化中央集权,但其对建立健全的法律体系也提供了重要的理论基础。")
+#
+# print(answer)
diff --git a/Mini_RAG/core/vector_store.py b/Mini_RAG/core/vector_store.py
new file mode 100644
index 0000000000000000000000000000000000000000..6792e7151477341db43f19e148c5734d94e95bab
--- /dev/null
+++ b/Mini_RAG/core/vector_store.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/27 19:52
+# @Author : hukangzhe
+# @File : retriever.py
+# @Description : Embedding, storage, and retrieval module
+import os
+import faiss
+import numpy as np
+import pickle
+import logging
+from rank_bm25 import BM25Okapi
+from typing import List, Dict, Tuple
+from .schema import Document, Chunk
+
+
+class HybridVectorStore:
+ def __init__(self, config: dict, embedder):
+ self.config = config["vector_store"]
+ self.embedder = embedder
+ self.faiss_index = None
+ self.bm25_index = None
+ self.parent_docs: Dict[int, Document] = {}
+ self.child_chunks: List[Chunk] = []
+
+ def build(self, parent_docs: Dict[int, Document], child_chunks: List[Chunk]):
+ self.parent_docs = parent_docs
+ self.child_chunks = child_chunks
+
+ # Build Faiss index
+ child_text = [child.text for child in child_chunks]
+ embeddings = self.embedder.embed(child_text)
+ dim = embeddings.shape[1]
+ self.faiss_index = faiss.IndexFlatL2(dim)
+ self.faiss_index.add(embeddings)
+ logging.info(f"FAISS index built with {len(child_chunks)} vectors.")
+
+ # Build BM25 index
+ tokenize_chunks = [doc.text.split(" ") for doc in child_chunks]
+ self.bm25_index = BM25Okapi(tokenize_chunks)
+ logging.info(f"BM25 index built for {len(child_chunks)} documents.")
+
+ self.save()
+
+ def search(self, query: str, top_k: int , alpha: float) -> List[Tuple[int, float]]:
+ # Vector Search
+ query_embedding = self.embedder.embed([query])
+ distances, indices = self.faiss_index.search(query_embedding, k=top_k)
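+        # Convert L2 distances into similarity-style scores: smaller distance -> score closer to 1.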
+ vector_scores = {idx : 1.0/(1.0 + dist) for idx, dist in zip(indices[0], distances[0])}
+
+ # BM25 Search
+ tokenize_query = query.split(" ")
+ bm25_scores = self.bm25_index.get_scores(tokenize_query)
+ bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
+ bm25_scores = {idx: bm25_scores[idx] for idx in bm25_top_indices}
+
+ # Hybrid Search
+        all_indices = set(vector_scores.keys()) | set(bm25_scores.keys()) # union of candidate indices
+        hybrid_scores = {}
+
+ # Normalization
+ max_v_score = max(vector_scores.values()) if vector_scores else 1.0
+ max_b_score = max(bm25_scores.values()) if bm25_scores else 1.0
+ for idx in all_indices:
+ v_score = (vector_scores.get(idx, 0))/max_v_score
+ b_score = (bm25_scores.get(idx, 0))/max_b_score
+            hybrid_scores[idx] = alpha * v_score + (1 - alpha) * b_score
+
+        sorted_indices = sorted(hybrid_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
+ return sorted_indices
+
+ def get_chunks(self, indices: List[int]) -> List[Chunk]:
+ return [self.child_chunks[i] for i in indices]
+
+ def get_parent_docs(self, chunks: List[Chunk]) -> List[Document]:
+ parent_ids = sorted(list(set(chunk.parent_id for chunk in chunks)))
+ return [self.parent_docs[pid] for pid in parent_ids]
+
+ def save(self):
+ index_path = self.config['index_path']
+ metadata_path = self.config['metadata_path']
+
+ os.makedirs(os.path.dirname(index_path), exist_ok=True)
+ os.makedirs(os.path.dirname(metadata_path), exist_ok=True)
+ logging.info(f"Saving FAISS index to: {index_path}")
+ try:
+ faiss.write_index(self.faiss_index, index_path)
+ except Exception as e:
+ logging.error(f"Failed to save FAISS index: {e}")
+ raise
+
+ logging.info(f"Saving metadata data to: {metadata_path}")
+ try:
+ with open(metadata_path, 'wb') as f:
+ metadata = {
+ 'parent_docs': self.parent_docs,
+ 'child_chunks': self.child_chunks,
+ 'bm25_index': self.bm25_index
+ }
+ pickle.dump(metadata, f)
+ except Exception as e:
+ logging.error(f"Failed to save metadata: {e}")
+ raise
+
+ logging.info("Vector store saved successfully.")
+
+ def load(self) -> bool:
+ """
+        Load the entire vector store state from disk. Returns True on success, False on failure.
+ """
+ index_path = self.config['index_path']
+ metadata_path = self.config['metadata_path']
+
+ if not os.path.exists(index_path) or not os.path.exists(metadata_path):
+ logging.warning("Index files not found. Cannot load vector store.")
+ return False
+
+ logging.info(f"Loading vector store from disk...")
+ try:
+ # Load FAISS index
+ logging.info(f"Loading FAISS index from: {index_path}")
+ self.faiss_index = faiss.read_index(index_path)
+
+ # Load metadata
+ logging.info(f"Loading metadata from: {metadata_path}")
+ with open(metadata_path, 'rb') as f:
+ metadata = pickle.load(f)
+ self.parent_docs = metadata['parent_docs']
+ self.child_chunks = metadata['child_chunks']
+ self.bm25_index = metadata['bm25_index']
+
+ logging.info("Vector store loaded successfully.")
+ return True
+
+ except Exception as e:
+ logging.error(f"Failed to load vector store from disk: {e}")
+ self.faiss_index = None
+ self.bm25_index = None
+ self.parent_docs = {}
+ self.child_chunks = []
+ return False
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/Mini_RAG/data/chinese_document.pdf b/Mini_RAG/data/chinese_document.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..cf3c783f36f435c8a29f5d77e52efceb04852205
--- /dev/null
+++ b/Mini_RAG/data/chinese_document.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bde4aba4245bb013438240cabf4c149d8aec76536f7a1459a165db02ec3695c
+size 317035
diff --git a/Mini_RAG/data/english_document.pdf b/Mini_RAG/data/english_document.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..214d2ea71708ac3cd25d9dc02c8a9213fdefe17c
--- /dev/null
+++ b/Mini_RAG/data/english_document.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5aa48e1b7d30d8cf2ac87cae6e2fee73e2de4d87178d0878d0d34b499fc7b26d
+size 166331
diff --git a/Mini_RAG/main.py b/Mini_RAG/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..032ece81fc60f6a0415af85f2359e0b1511ad874
--- /dev/null
+++ b/Mini_RAG/main.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/5/1 14:00
+# @Author : hukangzhe
+# @File : main.py
+# @Description :
+import yaml
+from service.rag_service import RAGService
+from ui.app import GradioApp
+from utils.logger import setup_logger
+
+
+def main():
+ setup_logger()
+
+ with open('configs/config.yaml', 'r', encoding='utf-8') as f:
+ config = yaml.safe_load(f)
+
+ rag_service = RAGService(config)
+ app = GradioApp(rag_service)
+ app.launch()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/Mini_RAG/rag_mini.py b/Mini_RAG/rag_mini.py
new file mode 100644
index 0000000000000000000000000000000000000000..813c892b4ba93a3bfd33089fc2baea497b6fab59
--- /dev/null
+++ b/Mini_RAG/rag_mini.py
@@ -0,0 +1,148 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/24 14:04
+# @Author : hukangzhe
+# @File : rag_core.py
+# @Description : A very simple RAG system
+
+import PyPDF2
+import fitz
+from sentence_transformers import SentenceTransformer, CrossEncoder
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+import numpy as np
+import faiss
+import torch
+
+
+class RAGSystem:
+ def __init__(self, pdf_path):
+ self.pdf_path = pdf_path
+ self.texts = self._load_and_spilt_pdf()
+ self.embedder = SentenceTransformer('moka-ai/m3e-base')
+        self.reranker = CrossEncoder('BAAI/bge-reranker-base') # load a reranker model
+ self.vector_store = self._create_vector_store()
+ print("3. Initializing Generator Model...")
+ model_name = "Qwen/Qwen1.5-1.8B-Chat"
+
+        # Check whether a GPU is available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f" - Using device: {device}")
+
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        # Note: for models like Qwen, we typically use AutoModelForCausalLM
+ model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+
+ self.generator = pipeline(
+ 'text-generation',
+ model=model,
+ tokenizer=self.tokenizer,
+ device=device
+ )
+
+    # 1. Document loading & 2. text splitting (combined here for simplicity)
+ def _load_and_spilt_pdf(self):
+ print("1. Loading and splitting PDF...")
+ full_text = ""
+ with fitz.open(self.pdf_path) as doc:
+ for page in doc:
+ full_text += page.get_text()
+
+        # Very basic splitting: fixed-size chunks with overlap
+ chunk_size = 500
+ overlap = 50
+ chunks = [full_text[i: i+chunk_size] for i in range(0, len(full_text), chunk_size-overlap)]
+ print(f" - Splitted into {len(chunks)} chunks.")
+ return chunks
+
+    # 3. Text embedding & vector storage
+ def _create_vector_store(self):
+ print("2. Creating vector store...")
+ # embedding
+ embeddings = self.embedder.encode(self.texts)
+
+ # Storing with faiss
+ dim = embeddings.shape[1]
+        index = faiss.IndexFlatL2(dim) # use L2 distance for similarity
+ index.add(np.array(embeddings))
+ print(" - Created vector store")
+ return index
+
+    # 4. Retrieval
+ def retrieve(self, query, k=3):
+ print(f"3. Retrieving top {k} relevant chunks for query: '{query}' ")
+ query_embeddings = self.embedder.encode([query])
+
+ distances, indices = self.vector_store.search(np.array(query_embeddings), k=k)
+ retrieved_chunks = [self.texts[i] for i in indices[0]]
+ print(" - Retrieval complete.")
+ return retrieved_chunks
+
+    # 5. Generation
+ def generate(self, query, context_chunks):
+ print("4. Generate answer...")
+ context = "\n".join(context_chunks)
+
+ messages = [
+ {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
+ {"role": "user", "content": f"上下文:\n---\n{context}\n---\n请根据以上上下文回答这个问题:{query}"}
+ ]
+ prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+ # print("Final Prompt:\n", prompt)
+ # print("Prompt token length:", len(self.tokenizer.encode(prompt)))
+
+ result = self.generator(prompt, max_new_tokens=200, num_return_sequences=1,
+ eos_token_id=self.tokenizer.eos_token_id)
+
+ print(" - Generation complete.")
+ # print("Raw results:", result)
+        # Extract the generated text.
+        # Note: the text returned by the Qwen pipeline includes the prompt, so we slice out the answer part.
+ full_response = result[0]["generated_text"]
+        answer = full_response[len(prompt):].strip() # take everything after the prompt
+
+ # print("Final Answer:", repr(answer))
+ return answer
+
+    # Optimization 1: reranking
+ def rerank(self, query, chunks):
+ print(" - Reranking retrieved chunks...")
+ pairs = [[query, chunk] for chunk in chunks]
+ scores = self.reranker.predict(pairs)
+
+        # Zip chunks with scores and sort by score in descending order
+ ranked_chunks = sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
+ return [chunk for chunk, score in ranked_chunks]
+
+ def query(self, query_text):
+        # 1. Retrieve (fetch more candidates, e.g. top 10)
+ retrieved_chunks = self.retrieve(query_text, k=10)
+
+        # 2. Rerank (pick the 3 most relevant out of the 10)
+ reranked_chunks = self.rerank(query_text, retrieved_chunks)
+ top_k_reranked = reranked_chunks[:3]
+
+ answer = self.generate(query_text, top_k_reranked)
+ return answer
+
+
+def main():
+    # Make sure your data folder contains the PDF referenced below
+ pdf_path = 'data/chinese_document.pdf'
+
+ print("Initializing RAG System...")
+ rag_system = RAGSystem(pdf_path)
+ print("\nRAG System is ready. You can start asking questions.")
+ print("Type 'q' to quit.")
+
+ while True:
+ user_query = input("\nYour Question: ")
+ if user_query.lower() == 'q':
+ break
+
+ answer = rag_system.query(user_query)
+ print("\nAnswer:", answer)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/Mini_RAG/requirements.txt b/Mini_RAG/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c5a1afd4a08714741949343784a3e878e0e69ec8
--- /dev/null
+++ b/Mini_RAG/requirements.txt
@@ -0,0 +1,9 @@
+pyyaml
+pypdf2
+PyMuPDF
+sentence-transformers
+faiss-cpu
+rank_bm25
+transformers
+torch
+gradio
\ No newline at end of file
diff --git a/Mini_RAG/service/__init__.py b/Mini_RAG/service/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Mini_RAG/service/__pycache__/__init__.cpython-39.pyc b/Mini_RAG/service/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e4b381495c7815b79f4444a2a3d2f5cad4027ae
Binary files /dev/null and b/Mini_RAG/service/__pycache__/__init__.cpython-39.pyc differ
diff --git a/Mini_RAG/service/__pycache__/rag_service.cpython-39.pyc b/Mini_RAG/service/__pycache__/rag_service.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f04ec204633814e2a126918295c729eb3e500c98
Binary files /dev/null and b/Mini_RAG/service/__pycache__/rag_service.cpython-39.pyc differ
diff --git a/Mini_RAG/service/rag_service.py b/Mini_RAG/service/rag_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7fb58099fc1372b269ae32c312e083f330adfe0
--- /dev/null
+++ b/Mini_RAG/service/rag_service.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/30 11:50
+# @Author : hukangzhe
+# @File : rag_service.py
+# @Description :
+import logging
+import os
+from typing import List, Generator, Tuple
+from core.schema import Document
+from core.embedder import EmbeddingModel
+from core.loader import MultiDocumentLoader
+from core.splitter import HierarchicalSemanticSplitter
+from core.vector_store import HybridVectorStore
+from core.llm_interface import LLMInterface
+
+
+class RAGService:
+ def __init__(self, config: dict):
+ self.config = config
+ logging.info("Initializing RAG Service...")
+ self.embedder = EmbeddingModel(config['models']['embedding'])
+ self.vector_store = HybridVectorStore(config, self.embedder)
+ self.llm = LLMInterface(config)
+ self.is_ready = False # whether the knowledge base is ready for queries
+ logging.info("RAG Service initialized. Knowledge base is not loaded.")
+
+ def load_knowledge_base(self) -> Tuple[bool, str]:
+ """
+ Try to load an existing knowledge base from disk.
+ Returns:
+ A tuple (success: bool, message: str)
+ """
+ if self.is_ready:
+ return True, "Knowledge base is already loaded."
+
+ logging.info("Attempting to load knowledge base from disk...")
+ success = self.vector_store.load()
+ if success:
+ self.is_ready = True
+ message = "Knowledge base loaded successfully from disk."
+ logging.info(message)
+ return True, message
+ else:
+ self.is_ready = False
+ message = "No existing knowledge base found or failed to load. Please build a new one."
+ logging.warning(message)
+ return False, message
+
+ def build_knowledge_base(self, file_paths: List[str]) -> Generator[str, None, None]:
+ self.is_ready = False
+ yield "Step 1/3: Loading documents..."
+ loader = MultiDocumentLoader(file_paths)
+ docs = loader.load()
+
+ yield "Step 2/3: Splitting documents into hierarchical chunks..."
+ splitter = HierarchicalSemanticSplitter(
+ parent_chunk_size=self.config['splitter']['parent_chunk_size'],
+ parent_chunk_overlap=self.config['splitter']['parent_chunk_overlap'],
+ child_chunk_size=self.config['splitter']['child_chunk_size']
+ )
+ parent_docs, child_chunks = splitter.split_documents(docs)
+
+ yield "Step 3/3: Building and saving vector index..."
+ self.vector_store.build(parent_docs, child_chunks)
+ self.is_ready = True
+ yield "Knowledge base built and ready!"
+
+ def _get_context_and_sources(self, query: str) -> List[Document]:
+ if not self.is_ready:
+ raise Exception("Knowledge base is not ready. Please build it first.")
+
+ # Hybrid Search to get child chunks
+ retrieved_child_indices_scores = self.vector_store.search(
+ query,
+ top_k=self.config['retrieval']['retrieval_top_k'],
+ alpha=self.config['retrieval']['hybrid_search_alpha']
+ )
+ retrieved_child_indices = [idx for idx, score in retrieved_child_indices_scores]
+ retrieved_child_chunks = self.vector_store.get_chunks(retrieved_child_indices)
+
+ # Get Parent Documents
+ retrieved_parent_docs = self.vector_store.get_parent_docs(retrieved_child_chunks)
+
+ # Rerank Parent Documents
+ reranked_docs = self.llm.rerank(query, retrieved_parent_docs)
+ final_context_docs = reranked_docs[:self.config['retrieval']['rerank_top_k']]
+
+ return final_context_docs
+
+ def get_response_full(self, query: str) -> Tuple[str, List[Document]]:
+ final_context_docs = self._get_context_and_sources(query)
+ answer = self.llm.generate_answer(query, final_context_docs)
+ return answer, final_context_docs
+
+ def get_response_stream(self, query: str) -> Tuple[Generator[str, None, None], List[Document]]:
+ final_context_docs = self._get_context_and_sources(query)
+ answer_generator = self.llm.generate_answer_stream(query, final_context_docs)
+ return answer_generator, final_context_docs
+
+ def get_context_string(self, context_docs: List[Document]) -> str:
+ context_str = "引用上下文 (Context Sources):\n\n"
+ for doc in context_docs:
+ source_info = f"--- (来源: {os.path.basename(doc.metadata.get('source', ''))}, 页码: {doc.metadata.get('page', 'N/A')}) ---\n"
+ content = doc.text[:200]+"..." if len(doc.text) > 200 else doc.text
+ context_str += source_info + content + "\n\n"
+ return context_str.strip()
+
+
+
diff --git a/Mini_RAG/storage/faiss_index/chunks.pkl b/Mini_RAG/storage/faiss_index/chunks.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..a084a4cd7d4834a6e8929987bce62976f5b9101b
--- /dev/null
+++ b/Mini_RAG/storage/faiss_index/chunks.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da7fae509a97a93f69b8048e375fca25dd45f0e26a40bff83c30759c5853e473
+size 27663
diff --git a/Mini_RAG/storage/faiss_index/index.faiss b/Mini_RAG/storage/faiss_index/index.faiss
new file mode 100644
index 0000000000000000000000000000000000000000..3c2e5a8e2cd09faee8ced8ee789d661f50e1a434
Binary files /dev/null and b/Mini_RAG/storage/faiss_index/index.faiss differ
diff --git a/Mini_RAG/ui/__init__.py b/Mini_RAG/ui/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Mini_RAG/ui/__pycache__/__init__.cpython-39.pyc b/Mini_RAG/ui/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15bb473a52d09cfd63c269b912e6a93bbf57995a
Binary files /dev/null and b/Mini_RAG/ui/__pycache__/__init__.cpython-39.pyc differ
diff --git a/Mini_RAG/ui/__pycache__/app.cpython-39.pyc b/Mini_RAG/ui/__pycache__/app.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ceb60199ef7e2d2327109894ee63fa5ae7a69b8f
Binary files /dev/null and b/Mini_RAG/ui/__pycache__/app.cpython-39.pyc differ
diff --git a/Mini_RAG/ui/app.py b/Mini_RAG/ui/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..5990d22ce4ec0ec2019beb57c8fe5a20555225bd
--- /dev/null
+++ b/Mini_RAG/ui/app.py
@@ -0,0 +1,202 @@
+import gradio as gr
+import os
+from typing import List, Tuple
+
+from service.rag_service import RAGService
+
+
+class GradioApp:
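+ """Gradio front-end that wires a RAGService into a build/load-and-chat web UI."""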
+ def __init__(self, rag_service: RAGService):
+ self.rag_service = rag_service
+ self._build_ui()
+
+ def _build_ui(self):
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
+ title="Enterprise RAG System") as self.demo:
+ gr.Markdown("# 企业级RAG智能问答系统 (Enterprise RAG System)")
+ gr.Markdown("您可以**加载现有知识库**快速开始,或**上传新文档**构建一个全新的知识库。")
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown("### 控制面板 (Control Panel)")
+
+ self.load_kb_button = gr.Button("加载已有知识库 (Load Existing KB)")
+
+ gr.Markdown("
")
+
+ self.file_uploader = gr.File(
+ label="上传新文档以构建 (Upload New Docs to Build)",
+ file_count="multiple",
+ file_types=[".pdf", ".txt"],
+ interactive=True
+ )
+ self.build_kb_button = gr.Button("构建新知识库 (Build New KB)", variant="primary")
+
+ self.status_box = gr.Textbox(
+ label="系统状态 (System Status)",
+ value="系统已初始化,等待加载或构建知识库。",
+ interactive=False,
+ lines=4
+ )
+
+ # --- Hidden at first; shown once the knowledge base has been built or loaded ---
+ with gr.Column(scale=2, visible=False) as self.chat_area:
+ gr.Markdown("### 对话窗口 (Chat Window)")
+ self.chatbot = gr.Chatbot(label="RAG Chatbot", bubble_full_width=False, height=500)
+ self.mode_selector = gr.Radio(
+ ["流式输出(Streaming)","一次性输出(Full)"],
+ label="输出模式:(Output Mode)",
+ value="流式输出(Streaming)"
+ )
+ self.question_box = gr.Textbox(label="您的问题", placeholder="请在此处输入您的问题...",
+ show_label=False)
+ with gr.Row():
+ self.submit_btn = gr.Button("提交 (Submit)", variant="primary")
+ self.clear_btn = gr.Button("清空历史 (Clear History)")
+
+ gr.Markdown("---")
+ self.source_display = gr.Markdown("### 引用来源 (Sources)")
+
+ # --- Event Listeners ---
+ self.load_kb_button.click(
+ fn=self._handle_load_kb,
+ inputs=None,
+ outputs=[self.status_box, self.chat_area]
+ )
+
+ self.build_kb_button.click(
+ fn=self._handle_build_kb,
+ inputs=[self.file_uploader],
+ outputs=[self.status_box, self.chat_area]
+ )
+
+ self.submit_btn.click(
+ fn=self._handle_chat_submission,
+ inputs=[self.question_box, self.chatbot, self.mode_selector],
+ outputs=[self.chatbot, self.question_box, self.source_display]
+ )
+
+ self.question_box.submit(
+ fn=self._handle_chat_submission,
+ inputs=[self.question_box, self.chatbot, self.mode_selector],
+ outputs=[self.chatbot, self.question_box, self.source_display]
+ )
+
+ self.clear_btn.click(
+ fn=self._clear_chat,
+ inputs=None,
+ outputs=[self.chatbot, self.question_box, self.source_display]
+ )
+
+ def _handle_load_kb(self):
+ """处理现有知识库的加载。返回更新字典。"""
+ success, message = self.rag_service.load_knowledge_base()
+ if success:
+ return {
+ self.status_box: gr.update(value=message),
+ self.chat_area: gr.update(visible=True)
+ }
+ else:
+ return {
+ self.status_box: gr.update(value=message),
+ self.chat_area: gr.update(visible=False)
+ }
+
+ def _handle_build_kb(self, files: List[str], progress=gr.Progress(track_tqdm=True)):
+ """构建新知识库,返回更新的字典."""
+ if not files:
+ return {
+ self.status_box: gr.update(value="错误:请至少上传一个文档。"),
+ self.chat_area: gr.update(visible=False)
+ }
+
+ file_paths = [file.name for file in files]
+
+ try:
+ for status in self.rag_service.build_knowledge_base(file_paths):
+ progress(0.5, desc=status)
+
+ final_status = "知识库构建完成并已就绪!√"
+ return {
+ self.status_box: gr.update(value=final_status),
+ self.chat_area: gr.update(visible=True)
+ }
+ except Exception as e:
+ error_message = f"构建失败: {e}"
+ return {
+ self.status_box: gr.update(value=error_message),
+ self.chat_area: gr.update(visible=False)
+ }
+
+ def _handle_chat_submission(self, question: str, history: List[Tuple[str, str]], mode: str):
+ if not question or not question.strip():
+ yield history, "", "### 引用来源 (Sources)\n"
+ return
+
+ history.append((question, ""))
+
+ try:
+ # Full (non-streaming) output
+ if "Full" in mode:
+ yield history, "", "### 引用来源 (Sources)\n"
+
+ answer, sources = self.rag_service.get_response_full(question)
+ # Fetch the quoted context for display
+ context_string_for_display = self.rag_service.get_context_string(sources)
+ # Format the sources panel
+ source_text_for_panel = self._format_sources(sources)
+ # Full message: context + answer
+ full_response = f"{context_string_for_display}\n\n---\n\n**回答 (Answer):**\n{answer}"
+ history[-1] = (question, full_response)
+
+ yield history, "", source_text_for_panel
+
+ # Streaming output
+ else:
+ answer_generator, sources = self.rag_service.get_response_stream(question)
+
+ context_string_for_display = self.rag_service.get_context_string(sources)
+
+ source_text_for_panel = self._format_sources(sources)
+
+ yield history, "", source_text_for_panel
+
+ response_prefix = f"{context_string_for_display}\n\n---\n\n**回答 (Answer):**\n"
+ history[-1] = (question, response_prefix)
+ yield history, "", source_text_for_panel
+
+ answer_log = ""
+ for text_chunk in answer_generator:
+ answer_log += text_chunk
+ history[-1] = (question, response_prefix + answer_log)
+ yield history, "", source_text_for_panel
+
+ except Exception as e:
+ error_response = f"处理请求时出错: {e}"
+ history[-1] = (question, error_response)
+ yield history, "", "### 引用来源 (Sources)\n"
+
+ def _format_sources(self, sources: List) -> str:
+ source_text = "### 引用来源 (sources)\n)"
+ if not sources:
+ return source_text
+
+ unique_sources = set()
+ for doc in sources:
+ source_name = os.path.basename(doc.metadata.get('source', 'Unknown'))
+ page_num = doc.metadata.get('page', 'N/A')
+ unique_sources.add(f"- **{source_name}** (Page: {page_num})")
+
+ source_text += "\n".join(sorted(list(unique_sources)))
+ return source_text
+
+ def _clear_chat(self):
+ """清理聊天内容"""
+ return None, "", "### 引用来源 (Sources)\n"
+
+ def launch(self):
+ self.demo.queue().launch()
+
diff --git a/Mini_RAG/utils/__init__.py b/Mini_RAG/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Mini_RAG/utils/__pycache__/__init__.cpython-39.pyc b/Mini_RAG/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ddccd3328327de77cb1351661e3593a951a5768
Binary files /dev/null and b/Mini_RAG/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/Mini_RAG/utils/__pycache__/logger.cpython-39.pyc b/Mini_RAG/utils/__pycache__/logger.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a39a45115fb2ce92fd3a19ffa7ec666683f1e0b
Binary files /dev/null and b/Mini_RAG/utils/__pycache__/logger.cpython-39.pyc differ
diff --git a/Mini_RAG/utils/logger.py b/Mini_RAG/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..975c1081b7df6e96b5cc94194dfbeddeffe15567
--- /dev/null
+++ b/Mini_RAG/utils/logger.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/25 19:51
+# @Author : hukangzhe
+# @File : logger.py
+# @Description : Logging configuration module
+import logging
+import os
+import sys
+
+
+def setup_logger():
+ # Create the log directory up front; FileHandler raises if it does not exist.
+ os.makedirs("storage/logs", exist_ok=True)
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s",
+ handlers=[
+ logging.FileHandler("storage/logs/app.log"),
+ logging.StreamHandler(sys.stdout)
+ ]
+ )
+
diff --git a/README.md b/README.md
index fe64a7ec2784660c2eee38447c44c3060e4f4caf..f47dbce21502cc0ac57a7b25aef7c87b1b7146c1 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,185 @@
----
-title: Mini RAG
-emoji: 💻
-colorFrom: blue
-colorTo: yellow
-sdk: gradio
-sdk_version: 5.44.1
-app_file: app.py
-pinned: false
-short_description: 一个小的RAG
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# RAG_Mini
+---
+
+# Enterprise-Ready RAG System with Gradio Interface
+
+This is a powerful, enterprise-grade Retrieval-Augmented Generation (RAG) system designed to transform your documents into an interactive and intelligent knowledge base. Users can upload their own documents (PDFs, TXT files), build a searchable vector index, and ask complex questions in natural language to receive accurate, context-aware answers sourced directly from the provided materials.
+
+The entire application is wrapped in a clean, user-friendly web interface powered by Gradio.
+
+
+
+
+## ✨ Features
+
+- **Intuitive Web UI**: Simple, clean interface built with Gradio for uploading documents and chatting.
+- **Multi-Document Support**: Natively handles PDF and TXT files.
+- **Advanced Text Splitting**: Uses a `HierarchicalSemanticSplitter` that first splits documents into large parent chunks (for context) and then into smaller child chunks (for precise search), respecting semantic boundaries.
+- **Hybrid Search**: Combines the strengths of dense vector search (FAISS) and sparse keyword search (BM25) for robust and accurate retrieval.
+- **Reranking for Accuracy**: Employs a Cross-Encoder model to rerank the retrieved documents, ensuring the most relevant context is passed to the language model.
+- **Persistent Knowledge Base**: Automatically saves the built vector index and metadata, allowing you to load an existing knowledge base instantly on startup.
+- **Modular & Extensible Codebase**: The project is logically structured into services for loading, splitting, embedding, and generation, making it easy to maintain and extend.
+
+## 🏛️ System Architecture
+
+The RAG pipeline follows a logical, multi-step process to ensure high-quality answers:
+
+1. **Load**: Documents are loaded from various formats and parsed into a standardized `Document` object, preserving metadata like source and page number.
+2. **Split**: The raw text is processed by the `HierarchicalSemanticSplitter`, creating parent and child text chunks. This provides both broad context and fine-grained detail.
+3. **Embed & Index**: The child chunks are converted into vector embeddings using a `SentenceTransformer` model and indexed in a FAISS vector store. A parallel BM25 index is also built for keyword search.
+4. **Retrieve**: When a user asks a question, a hybrid search query is performed against the FAISS and BM25 indices to retrieve the most relevant child chunks (a scoring sketch follows the diagram below).
+5. **Fetch Context**: The parent chunks corresponding to the retrieved child chunks are fetched. This ensures the LLM receives a wider, more complete context.
+6. **Rerank**: A powerful Cross-Encoder model re-evaluates the relevance of the parent chunks against the query, pushing the best matches to the top.
+7. **Generate**: The top-ranked, reranked documents are combined with the user's query into a final prompt. This prompt is sent to a Large Language Model (LLM) to generate a final, coherent answer.
+
+```
+[User Uploads Docs] -> [Loader] -> [Splitter] -> [Embedder & Vector Store] -> [Knowledge Base Saved]
+
+[User Asks Question] -> [Hybrid Search] -> [Get Parent Docs] -> [Reranker] -> [LLM] -> [Answer & Sources]
+```
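+
+The hybrid score in step 4 is a weighted blend of the normalized dense and sparse scores. The snippet below is a simplified illustration of the rule applied by `HybridVectorStore.search`; `alpha` maps to `retrieval.hybrid_search_alpha` in the config, while the function and argument names are illustrative only:
+
+```python
+def hybrid_score(vector_score, bm25_score, max_vector, max_bm25, alpha):
+    """Blend a normalized dense (vector) score with a normalized sparse (BM25) score."""
+    v = vector_score / max_vector if max_vector else 0.0  # normalize the dense score
+    b = bm25_score / max_bm25 if max_bm25 else 0.0        # normalize the sparse score
+    return alpha * v + (1 - alpha) * b
+
+# e.g. alpha = 0.5 weighs both signals equally, matching the default config
+score = hybrid_score(0.8, 12.0, 1.0, 15.0, alpha=0.5)
+```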
+
+## 🛠️ Tech Stack
+
+- **Backend**: Python 3.9+
+- **UI**: Gradio
+- **LLM & Embedding Framework**: Hugging Face Transformers, Sentence-Transformers
+- **Vector Search**: Faiss (from Facebook AI)
+- **Keyword Search**: rank-bm25
+- **PDF Parsing**: PyMuPDF (fitz)
+- **Configuration**: PyYAML
+
+## 🚀 Getting Started
+
+Follow these steps to set up and run the project on your local machine.
+
+### 1. Prerequisites
+
+- Python 3.9 or higher
+- `pip` for package management
+
+### 2. Review the `requirements.txt` file
+
+The repository already ships a `requirements.txt` with the key dependencies: `gradio`, `torch`, `transformers`, `sentence-transformers`, `faiss-cpu`, `rank_bm25`, `PyMuPDF`, `pyyaml`. If you add or upgrade packages while developing, regenerate it from your activated environment and commit the result:
+
+```bash
+pip freeze > requirements.txt
+```
+
+### 3. Installation & Setup
+
+**1. Clone the repository:**
+```bash
+git clone https://github.com/YOUR_USERNAME/YOUR_REPOSITORY_NAME.git
+cd YOUR_REPOSITORY_NAME
+```
+
+**2. Create and activate a virtual environment (recommended):**
+```bash
+# For Windows
+python -m venv venv
+.\venv\Scripts\activate
+
+# For macOS/Linux
+python3 -m venv venv
+source venv/bin/activate
+```
+
+**3. Install the required packages:**
+```bash
+pip install -r requirements.txt
+```
+
+**4. Configure the system:**
+Review the `configs/config.yaml` file. You can change the models, chunk sizes, and other parameters here. The default settings are a good starting point.
+
+> **Note:** The first time you run the application, the models specified in the config file will be downloaded from Hugging Face. This may take some time depending on your internet connection.
+
+### 4. Running the Application
+
+To start the Gradio web server, run the `main.py` script:
+
+```bash
+python main.py
+```
+
+The application will be available at **`http://localhost:7860`**.
+
+## 📖 How to Use
+
+The application has two primary workflows:
+
+**1. Build a New Knowledge Base:**
+ - Drag and drop one or more `.pdf` or `.txt` files into the "Upload New Docs to Build" area.
+ - Click the **"Build New KB"** button.
+ - The system status will show the progress (Loading -> Splitting -> Indexing).
+ - Once complete, the status will confirm that the knowledge base is ready, and the chat window will appear.
+
+**2. Load an Existing Knowledge Base:**
+ - If you have previously built a knowledge base, simply click the **"Load Existing KB"** button.
+ - The system will load the saved FAISS index and metadata from the `storage` directory.
+ - The chat window will appear, and you can start asking questions immediately.
+
+**Chatting with Your Documents:**
+ - Once the knowledge base is ready, type your question into the chat box at the bottom and press Enter or click "Submit".
+ - The model will generate an answer based on the documents you provided.
+ - The sources used to generate the answer will be displayed below the chat window.
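+
+**Programmatic Use (optional):**
+If you prefer to drive the pipeline from Python instead of the Gradio UI, the sketch below shows one way to wire it up with the classes in this repository. It assumes the default `configs/config.yaml` and the sample PDF under `data/`; adjust the paths and the question to your own setup.
+
+```python
+import yaml
+
+from service.rag_service import RAGService
+
+with open("configs/config.yaml", "r", encoding="utf-8") as f:
+    config = yaml.safe_load(f)
+
+rag = RAGService(config)
+
+# Either build a fresh knowledge base from your own documents...
+for status in rag.build_knowledge_base(["data/chinese_document.pdf"]):
+    print(status)  # progress: loading -> splitting -> indexing
+
+# ...or load one previously saved in the storage directory:
+# loaded, message = rag.load_knowledge_base()
+
+answer, sources = rag.get_response_full("What is this document about?")
+print(answer)
+for doc in sources:
+    print(doc.metadata.get("source"), doc.metadata.get("page"))
+```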
+
+## 📂 Project Structure
+
+```
+.
+├── configs/
+│ └── config.yaml # Main configuration file for models, paths, etc.
+├── core/
+│ ├── embedder.py # Handles text embedding.
+│ ├── llm_interface.py # Handles reranking and answer generation.
+│ ├── loader.py # Loads and parses documents.
+│ ├── schema.py # Defines data structures (Document, Chunk).
+│ ├── splitter.py # Splits documents into chunks.
+│ └── vector_store.py # Manages FAISS & BM25 indices.
+├── service/
+│ └── rag_service.py # Orchestrates the entire RAG pipeline.
+├── storage/ # Default location for saved indices (auto-generated).
+│ └── ...
+├── ui/
+│ └── app.py # Contains the Gradio UI logic.
+├── utils/
+│ └── logger.py # Logging configuration.
+├── assets/
+│ └── 1.png # Screenshot of the application.
+├── main.py # Entry point to run the application.
+└── requirements.txt # Python package dependencies.
+```
+
+## 🔧 Configuration Details (`config.yaml`)
+
+You can customize the RAG pipeline by modifying `configs/config.yaml`:
+
+- **`models`**: Specify the Hugging Face models for embedding, reranking, and generation.
+- **`vector_store`**: Define the paths where the FAISS index and metadata will be saved.
+- **`splitter`**: Control the `HierarchicalSemanticSplitter` behavior.
+ - `parent_chunk_size`: The target size for larger context chunks.
+ - `parent_chunk_overlap`: The overlap between parent chunks.
+ - `child_chunk_size`: The target size for smaller, searchable chunks.
+- **`retrieval`**: Tune the retrieval and reranking process.
+ - `retrieval_top_k`: How many initial candidates to retrieve with hybrid search.
+ - `rerank_top_k`: How many final documents to pass to the LLM after reranking.
+ - `hybrid_search_alpha`: The weighting between vector search (`alpha`) and BM25 search (`1 - alpha`). `1.0` is pure vector search, `0.0` is pure keyword search.
+- **`generation`**: Set parameters for the final answer generation, like `max_new_tokens`.
+
+## 🛣️ Future Roadmap
+
+- [ ] Support for more document types (e.g., `.docx`, `.pptx`, `.html`).
+- [x] Implement response streaming for a more interactive chat experience (available via the streaming output mode in the UI).
+- [ ] Integrate with other vector databases like ChromaDB or Pinecone.
+- [ ] Create API endpoints for programmatic access to the RAG service.
+- [ ] Add more advanced logging and monitoring for enterprise use.
+
+## 🤝 Contributing
+
+Contributions are welcome! If you have ideas for improvements or find a bug, please feel free to open an issue or submit a pull request.
+
+## 📄 License
+
+This project is licensed under the MIT License. See the `LICENSE` file for details.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..1864745105bf97476bc002569275138b8e910375
--- /dev/null
+++ b/app.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/9/3 09:38
+# @Author : hukangzhe
+# @File : app.py
+# @Description :
+import yaml
+from service.rag_service import RAGService
+from ui.app import GradioApp
+from utils.logger import setup_logger
+
+
+def main():
+ setup_logger()
+
+ with open('configs/config.yaml', 'r', encoding='utf-8') as f:
+ config = yaml.safe_load(f)
+
+ rag_service = RAGService(config)
+ app = GradioApp(rag_service)
+ app.launch()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/assets/1.png b/assets/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..def6f1b1f4f00ea6b493b7cdece6a228480c5934
--- /dev/null
+++ b/assets/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3cfd5a1e1f78574b7ad19ba82bc14aefedc30e198293a36b34fddbe58ae9572a
+size 117169
diff --git a/assets/2.png b/assets/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..04960ccb2184a5a930b0bd4be061c294a142b68e
--- /dev/null
+++ b/assets/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:488e38ddb8e075942a427f604b290abc27121023b44ceab74b2447ccd5d39ecd
+size 144857
diff --git a/configs/config.yaml b/configs/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..249028c3740d4f6e40b7e12d273319203640579d
--- /dev/null
+++ b/configs/config.yaml
@@ -0,0 +1,26 @@
+# Path configuration
+storage_path: "./storage"
+vector_store:
+ index_path: "./storage/faiss_index/index.faiss"
+ metadata_path: "./storage/faiss_index/chunks.pkl"
+
+# Splitter configuration
+splitter:
+ parent_chunk_size: 800
+ parent_chunk_overlap: 100
+ child_chunk_size: 250
+
+# Model configuration
+models:
+ embedding: "moka-ai/m3e-base"
+ reranker: "BAAI/bge-reranker-base"
+ llm_generator: "Qwen/Qwen3-0.6B" # Qwen/Qwen1.5-1.8B-Chat or Qwen/Qwen3-0.6B
+
+
+# Retrieval and generation parameters
+retrieval:
+ hybrid_search_alpha: 0.5 # vector-search weight in hybrid retrieval (0-1); keyword weight is 1 - alpha
+ retrieval_top_k: 20
+ rerank_top_k: 5
+generation:
+ max_new_tokens: 512
diff --git a/core/__init__.py b/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3f5a12faa99758192ecc4ed3fc22c9249232e86
--- /dev/null
+++ b/core/__init__.py
@@ -0,0 +1 @@
+
diff --git a/core/__pycache__/__init__.cpython-39.pyc b/core/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..490fc55d5de07f75e7b79a8526715e75207b5e33
Binary files /dev/null and b/core/__pycache__/__init__.cpython-39.pyc differ
diff --git a/core/__pycache__/embedder.cpython-39.pyc b/core/__pycache__/embedder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b33e5057272a15b09b1bf751a6082bfd8a72e68
Binary files /dev/null and b/core/__pycache__/embedder.cpython-39.pyc differ
diff --git a/core/__pycache__/llm_interface.cpython-39.pyc b/core/__pycache__/llm_interface.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17317a98ba48afdddb28d1cdcf421f61843a88ed
Binary files /dev/null and b/core/__pycache__/llm_interface.cpython-39.pyc differ
diff --git a/core/__pycache__/loader.cpython-39.pyc b/core/__pycache__/loader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e832df0a9e3abbeb0443ef56884f79577401fee
Binary files /dev/null and b/core/__pycache__/loader.cpython-39.pyc differ
diff --git a/core/__pycache__/schema.cpython-39.pyc b/core/__pycache__/schema.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d379d281390f817f93027d9084d701b0040cdd5a
Binary files /dev/null and b/core/__pycache__/schema.cpython-39.pyc differ
diff --git a/core/__pycache__/splitter.cpython-39.pyc b/core/__pycache__/splitter.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7bf2fd07fd733b59cd01dc4b4785fd01d31c20f
Binary files /dev/null and b/core/__pycache__/splitter.cpython-39.pyc differ
diff --git a/core/__pycache__/vector_store.cpython-39.pyc b/core/__pycache__/vector_store.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01876546ce639ba8b2e37b0bf6b5a77b1bf79bf7
Binary files /dev/null and b/core/__pycache__/vector_store.cpython-39.pyc differ
diff --git a/core/embedder.py b/core/embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d85c5cef0adb20c2de3013eb0290f62252e02f0
--- /dev/null
+++ b/core/embedder.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/24 11:50
+# @Author : hukangzhe
+# @File : embedder.py
+# @Description :
+
+from sentence_transformers import SentenceTransformer
+from typing import List
+import numpy as np
+
+
+class EmbeddingModel:
+ def __init__(self, model_name: str):
+ self.embedding_model = SentenceTransformer(model_name)
+
+ def embed(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
+ return self.embedding_model.encode(texts, batch_size=batch_size, convert_to_numpy=True)
+
diff --git a/core/llm_interface.py b/core/llm_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..bff2ea38acbc6524c9a8e2844464abaea3ba23a5
--- /dev/null
+++ b/core/llm_interface.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/29 19:54
+# @Author : hukangzhe
+# @File : llm_interface.py
+# @Description : Reranking and answer-generation module
+import os
+import queue
+import logging
+import threading
+
+import torch
+from typing import Dict, List, Tuple, Generator
+from sentence_transformers import CrossEncoder
+from .schema import Document, Chunk
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer
+
+class ThinkStreamer(TextStreamer):
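+ """TextStreamer that separates the model's <think>...</think> reasoning from the final answer while streaming (assumes a Qwen3-style thinking tag)."""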
+ def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool =True, **decode_kwargs):
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+ self.is_thinking = True
+ self.think_end_token_id = self.tokenizer.encode("</think>", add_special_tokens=False)[0] # id of the end-of-thinking tag
+ self.output_queue = queue.Queue()
+
+ def on_finalized_text(self, text: str, stream_end: bool = False):
+ self.output_queue.put(text)
+ if stream_end:
+ self.output_queue.put(None) # signal the end of the stream
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ value = self.output_queue.get()
+ if value is None:
+ raise StopIteration()
+ return value
+
+ def generate_output(self) -> Generator[Tuple[str, str], None, None]:
+ """
+ Split the stream into ("thinking", text) and ("answer", text) tuples.
+ :return:
+ """
+ full_decode_text = ""
+ already_yielded_len = 0
+ for text_chunk in self:
+ if not self.is_thinking:
+ yield "answer", text_chunk
+ continue
+
+ full_decode_text += text_chunk
+ tokens = self.tokenizer.encode(full_decode_text, add_special_tokens=False)
+
+ if self.think_end_token_id in tokens:
+ spilt_point = tokens.index(self.think_end_token_id)
+ think_part_tokens = tokens[:spilt_point]
+ thinking_text = self.tokenizer.decode(think_part_tokens)
+
+ answer_part_tokens = tokens[spilt_point:]
+ answer_text = self.tokenizer.decode(answer_part_tokens)
+ remaining_thinking = thinking_text[already_yielded_len:]
+ if remaining_thinking:
+ yield "thinking", remaining_thinking
+
+ if answer_text:
+ yield "answer", answer_text
+
+ self.is_thinking = False
+ already_yielded_len = len(thinking_text) + len(self.tokenizer.decode(self.think_end_token_id))
+ else:
+ yield "thinking", text_chunk
+ already_yielded_len += len(text_chunk)
+
+
+class QueueTextStreamer(TextStreamer):
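+ """TextStreamer that pushes finalized text onto a queue so another thread can consume it as an iterator."""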
+ def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool = True, **decode_kwargs):
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+ self.output_queue = queue.Queue()
+
+ def on_finalized_text(self, text: str, stream_end: bool = False):
+ """Puts text into the queue; sends None as a sentinel value to signal the end."""
+ self.output_queue.put(text)
+ if stream_end:
+ self.output_queue.put(None)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ value = self.output_queue.get()
+ if value is None:
+ raise StopIteration()
+ return value
+
+
+class LLMInterface:
+ def __init__(self, config: dict):
+ self.config = config
+ self.reranker = CrossEncoder(config['models']['reranker'])
+ self.generator_new_tokens = config['generation']['max_new_tokens']
+ self.device =torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ generator_name = config['models']['llm_generator']
+ logging.info(f"Initializing generator {generator_name}")
+ self.generator_tokenizer = AutoTokenizer.from_pretrained(generator_name)
+ self.generator_model = AutoModelForCausalLM.from_pretrained(
+ generator_name,
+ torch_dtype="auto",
+ device_map="auto")
+
+ def rerank(self, query: str, docs: List[Document]) -> List[Document]:
+ pairs = [[query, doc.text] for doc in docs]
+ scores = self.reranker.predict(pairs)
+ ranked_docs = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
+ return [doc for doc, score in ranked_docs]
+
+ def _threaded_generate(self, streamer: QueueTextStreamer, generation_kwargs: dict):
+ """
+ A wrapper that runs model.generate inside a try...finally block.
+ """
+ try:
+ self.generator_model.generate(**generation_kwargs)
+ finally:
+ # Always send the end-of-stream sentinel, whether generation succeeded or failed
+ streamer.output_queue.put(None)
+
+ def generate_answer(self, query: str, context_docs: List[Document]) -> str:
+ context_str = ""
+ for doc in context_docs:
+ context_str += f"Source: {os.path.basename(doc.metadata.get('source', ''))}, Page: {doc.metadata.get('page', 'N/A')}\n"
+ context_str += f"Content: {doc.text}\n\n"
+ # If the prompt content is written in English, the model answers in English
+ messages = [
+ {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
+ {"role": "user", "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题:{query}"}
+ ]
+ prompt = self.generator_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(self.device)
+ output = self.generator_model.generate(**inputs,
+ max_new_tokens=self.generator_new_tokens, num_return_sequences=1,
+ eos_token_id=self.generator_tokenizer.eos_token_id)
+ generated_ids = output[0][inputs["input_ids"].shape[1]:]
+ answer = self.generator_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
+ return answer
+
+ def generate_answer_stream(self, query: str, context_docs: List[Document]) -> Generator[str, None, None]:
+ context_str = ""
+ for doc in context_docs:
+ context_str += f"Content: {doc.text}\n\n"
+
+ messages = [
+ {"role": "system",
+ "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
+ {"role": "user",
+ "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题: {query}"}
+ ]
+
+ prompt = self.generator_tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device)
+
+ streamer = QueueTextStreamer(self.generator_tokenizer, skip_prompt=True)
+
+ generation_kwargs = dict(
+ **model_inputs,
+ max_new_tokens=self.generator_new_tokens,
+ streamer=streamer,
+ pad_token_id=self.generator_tokenizer.eos_token_id,
+ )
+
+ thread = threading.Thread(target=self._threaded_generate, args=(streamer,generation_kwargs,))
+ thread.start()
+ for new_text in streamer:
+ if new_text is not None:
+ yield new_text
+
+ def generate_answer_stream_split(self, query: str, context_docs: List[Document]) -> Generator[Tuple[str, str], None, None]:
+ """分离思考和回答的流式输出"""
+ context_str = ""
+ for doc in context_docs:
+ context_str += f"Content: {doc.text}\n\n"
+
+ messages = [
+ {"role": "system",
+ "content": "You are a helpful assistant. Please answer the question based on the provided context. First, think through the process in tags, then provide the final answer."},
+ {"role": "user",
+ "content": f"Context:\n---\n{context_str}\n---\nBased on the context above, please answer the question: {query}"}
+ ]
+
+ prompt = self.generator_tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=True
+ )
+ model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device)
+
+ streamer = ThinkStreamer(self.generator_tokenizer, skip_prompt=True)
+
+ generation_kwargs = dict(
+ **model_inputs,
+ max_new_tokens=self.generator_new_tokens,
+ streamer=streamer
+ )
+
+ thread = threading.Thread(target=self.generator_model.generate, kwargs=generation_kwargs)
+ thread.start()
+
+ yield from streamer.generate_output()
+
+
+
+
diff --git a/core/loader.py b/core/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..169895436c9119c68856abfc9cafbc8055a6c078
--- /dev/null
+++ b/core/loader.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/24 19:51
+# @Author : hukangzhe
+# @File : loader.py
+# @Description : Document loading module. Unlike rag_mini, it returns a list of Document objects (each carrying source metadata) instead of one long string.
+import logging
+from .schema import Document
+from typing import List
+import fitz
+import os
+
+
+class DocumentLoader:
+
+ @staticmethod
+ def _load_pdf(file_path):
+ logging.info(f"Loading PDF file from {file_path}")
+ try:
+ with fitz.open(file_path) as f:
+ text = "".join(page.get_text() for page in f)
+ logging.info(f"Successfully loaded {len(f)} pages.")
+ return text
+ except Exception as e:
+ logging.error(f"Failed to load PDF {file_path}: {e}")
+ return None
+
+
+class MultiDocumentLoader:
+ def __init__(self, paths: List[str]):
+ self.paths = paths
+
+ def load(self) -> List[Document]:
+ docs = []
+ for file_path in self.paths:
+ file_extension = os.path.splitext(file_path)[1].lower()
+ if file_extension == '.pdf':
+ docs.extend(self._load_pdf(file_path))
+ elif file_extension == '.txt':
+ docs.extend(self._load_txt(file_path))
+ else:
+ logging.warning(f"Unsupported file type:{file_extension}. Skipping {file_path}")
+
+ return docs
+
+ def _load_pdf(self, file_path: str) -> List[Document]:
+ logging.info(f"Loading PDF file from {file_path}")
+ try:
+ pdf_docs = []
+ with fitz.open(file_path) as doc:
+ for i, page in enumerate(doc):
+ pdf_docs.append(Document(
+ text=page.get_text(),
+ metadata={'source': file_path, 'page': i + 1}
+ ))
+ return pdf_docs
+ except Exception as e:
+ logging.error(f"Failed to load PDF {file_path}: {e}")
+ return []
+
+ def _load_txt(self, file_path: str) -> List[Document]:
+ logging.info(f"Loading txt file from {file_path}")
+ try:
+ txt_docs = []
+ with open(file_path, 'r', encoding="utf-8") as f:
+ txt_docs.append(Document(
+ text=f.read(),
+ metadata={'source': file_path}
+ ))
+ return txt_docs
+ except Exception as e:
+ logging.error(f"Failed to load txt {file_path}:{e}")
+ return []
diff --git a/core/schema.py b/core/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..647924b047ac0f8f8225a5d85ecb3191f499ef7d
--- /dev/null
+++ b/core/schema.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/24 21:11
+# @Author : hukangzhe
+# @File : schema.py
+# @Description : Instead of passing raw text strings around, a standardized data structure carries each text chunk together with its metadata.
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
+
+@dataclass
+class Document:
+ text: str
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class Chunk(Document):
+ parent_id: Optional[int] = None
\ No newline at end of file
diff --git a/core/splitter.py b/core/splitter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea8df02e6bd6fddbb1a8157211ba9983a10dfe63
--- /dev/null
+++ b/core/splitter.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/25 19:52
+# @Author : hukangzhe
+# @File : splitter.py
+# @Description : Text-splitting module
+
+import logging
+from typing import List, Dict, Tuple
+from .schema import Document, Chunk
+
+
+class SemanticRecursiveSplitter:
+ def __init__(self, chunk_size: int=500, chunk_overlap:int = 50, separators: List[str] = None):
+ """
+ A genuinely recursive semantic text splitter.
+ :param chunk_size: target size of each chunk.
+ :param chunk_overlap: overlap between consecutive chunks.
+ :param separators: semantic separators to split on, ordered from highest to lowest priority.
+ """
+ self.chunk_size = chunk_size
+ self.chunk_overlap = chunk_overlap
+ if self.chunk_size <= self.chunk_overlap:
+ raise ValueError("Chunk overlap must be smaller than chunk size.")
+
+ self.separators = separators if separators else ['\n\n', '\n', " ", ""] # default separators
+
+ def text_split(self, text: str) -> List[str]:
+ """
+ Entry point: split the text into chunks.
+ :param text: the text to split.
+ :return: the list of resulting chunks.
+ """
+ logging.info("Starting semantic recursive splitting...")
+ final_chunks = self._split(text, self.separators)
+ logging.info(f"Text successfully split into {len(final_chunks)} chunks.")
+ return final_chunks
+
+ def _split(self, text: str, separators: List[str]) -> List[str]:
+ final_chunks = []
+ # 1. If the text is already small enough, return it as a single chunk
+ if len(text) < self.chunk_size:
+ return [text]
+ # 2. Try the highest-priority separator first
+ cur_separator = separators[0]
+
+ # 3. If this separator occurs in the text, split on it
+ if cur_separator in text:
+ # split the text into smaller parts
+ parts = text.split(cur_separator)
+
+ buffer="" # 用来合并小部分
+ for i, part in enumerate(parts):
+ # 如果小于chunk_size,就再加一小部分,使buffer接近chunk_size
+ if len(buffer) + len(part) + len(cur_separator) <= self.chunk_size:
+ buffer += part+cur_separator
+ else:
+ # flush the buffer if it is not empty
+ if buffer:
+ final_chunks.append(buffer)
+ # if this part alone already exceeds chunk_size
+ if len(part) > self.chunk_size:
+ # recurse with the next-priority separators
+ sub_chunks = self._split(part, separators = separators[1:])
+ final_chunks.extend(sub_chunks)
+ else: # start a new buffer with this part
+ buffer = part + cur_separator
+
+ if buffer: # flush the final buffer
+ final_chunks.append(buffer.strip())
+
+ else:
+ # 4. Fall back to the next-priority separator
+ final_chunks = self._split(text, separators[1:])
+
+ # apply overlap between chunks
+ if self.chunk_overlap > 0:
+ return self._handle_overlap(final_chunks)
+ else:
+ return final_chunks
+
+ def _handle_overlap(self, final_chunks: List[str]) -> List[str]:
+ overlap_chunks = []
+ if not final_chunks:
+ return []
+ overlap_chunks.append(final_chunks[0])
+ for i in range(1, len(final_chunks)):
+ pre_chunk = overlap_chunks[-1]
+ cur_chunk = final_chunks[i]
+ # prepend the tail of the previous chunk to the current chunk
+ overlap_part = pre_chunk[-self.chunk_overlap:]
+ overlap_chunks.append(overlap_part+cur_chunk)
+
+ return overlap_chunks
+
+
+class HierarchicalSemanticSplitter:
+ """
+ Combines hierarchical (parent/child) and recursive semantic splitting, so that both parent and child chunks respect natural semantic boundaries in the text.
+ """
+ def __init__(self,
+ parent_chunk_size: int = 800,
+ parent_chunk_overlap: int = 100,
+ child_chunk_size: int = 250,
+ separators: List[str] = None):
+ if parent_chunk_overlap >= parent_chunk_size:
+ raise ValueError("Parent chunk overlap must be smaller than parent chunk size.")
+ if child_chunk_size >= parent_chunk_size:
+ raise ValueError("Child chunk size must be smaller than parent chunk size.")
+
+ self.parent_chunk_size = parent_chunk_size
+ self.parent_chunk_overlap = parent_chunk_overlap
+ self.child_chunk_size = child_chunk_size
+ self.separators = separators or ["\n\n", "\n", "。", ". ", "!", "!", "?", "?", " ", ""]
+
+ def _recursive_semantic_split(self, text: str, chunk_size: int) -> List[str]:
+ """
+ Split text, preferring semantic boundaries over hard cuts.
+ """
+ if len(text) <= chunk_size:
+ return [text]
+
+ for sep in self.separators:
+ split_point = text.rfind(sep, 0, chunk_size)
+ if split_point != -1:
+ break
+ else:
+ split_point = chunk_size
+
+ chunk1 = text[:split_point]
+ remaining_text = text[split_point:].lstrip() # strip leading whitespace from the remainder
+
+ # Recursively split the remaining text.
+ # The separator is appended back to the first chunk to preserve context.
+ if remaining_text:
+ return [chunk1 + (sep if sep in " \n" else "")] + self._recursive_semantic_split(remaining_text, chunk_size)
+ else:
+ return [chunk1]
+
+ def _apply_overlap(self, chunks: List[str], overlap: int) -> List[str]:
+ """处理重叠部分chunk"""
+ if not overlap or len(chunks) <= 1:
+ return chunks
+
+ overlapped_chunks = [chunks[0]]
+ for i in range(1, len(chunks)):
+ # take the last `overlap` characters of the previous chunk
+ overlap_content = chunks[i - 1][-overlap:]
+ overlapped_chunks.append(overlap_content + chunks[i])
+
+ return overlapped_chunks
+
+ def split_documents(self, documents: List[Document]) -> Tuple[Dict[int, Document], List[Chunk]]:
+ """
+ Two-pass splitting: parent chunks first, then child chunks within each parent.
+ :param documents:
+ :return:
+ - parent documents: {parent_id: Document}
+ - child chunks: [Chunk, Chunk, ...]
+ """
+ parent_docs_dict: Dict[int, Document] = {}
+ child_chunks_list: List[Chunk] = []
+ parent_id_counter = 0
+
+ logging.info("Starting robust hierarchical semantic splitting...")
+
+ for doc in documents:
+ # === PASS 1: create parent chunks ===
+ # 1. Split the whole document text into large semantic chunks
+ initial_parent_chunks = self._recursive_semantic_split(doc.text, self.parent_chunk_size)
+
+ # 2. Apply overlap to the parent chunks
+ overlapped_parent_texts = self._apply_overlap(initial_parent_chunks, self.parent_chunk_overlap)
+
+ for p_text in overlapped_parent_texts:
+ parent_doc = Document(text=p_text, metadata=doc.metadata.copy())
+ parent_docs_dict[parent_id_counter] = parent_doc
+
+ # === PASS 2: Create Child Chunks from each Parent ===
+ child_texts = self._recursive_semantic_split(p_text, self.child_chunk_size)
+
+ for c_text in child_texts:
+ child_metadata = doc.metadata.copy()
+ child_metadata['parent_id'] = parent_id_counter
+ child_chunk = Chunk(text=c_text, metadata=child_metadata, parent_id=parent_id_counter)
+ child_chunks_list.append(child_chunk)
+
+ parent_id_counter += 1
+
+ logging.info(
+ f"Splitting complete. Created {len(parent_docs_dict)} parent chunks and {len(child_chunks_list)} child chunks.")
+ return parent_docs_dict, child_chunks_list
diff --git a/core/test_qw.py b/core/test_qw.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac3f83f1b3c1366dee367520cf6c1129b9be3535
--- /dev/null
+++ b/core/test_qw.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/29 20:03
+# @Author : hukangzhe
+# @File : test_qw.py
+# @Description : Check that both models (Qwen1.5, Qwen3) behave correctly in both output modes (full and streaming)
+import os
+import queue
+import logging
+import threading
+import torch
+from typing import Tuple, Generator
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer
+
+class ThinkStreamer(TextStreamer):
+ def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool =True, **decode_kwargs):
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+ self.is_thinking = True
+ self.think_end_token_id = self.tokenizer.encode("</think>", add_special_tokens=False)[0] # id of the end-of-thinking tag
+ self.output_queue = queue.Queue()
+
+ def on_finalized_text(self, text: str, stream_end: bool = False):
+ self.output_queue.put(text)
+ if stream_end:
+ self.output_queue.put(None) # signal the end of the stream
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ value = self.output_queue.get()
+ if value is None:
+ raise StopIteration()
+ return value
+
+ def generate_output(self) -> Generator[Tuple[str, str], None, None]:
+ full_decode_text = ""
+ already_yielded_len = 0
+ for text_chunk in self:
+ if not self.is_thinking:
+ yield "answer", text_chunk
+ continue
+
+ full_decode_text += text_chunk
+ tokens = self.tokenizer.encode(full_decode_text, add_special_tokens=False)
+
+ if self.think_end_token_id in tokens:
+ spilt_point = tokens.index(self.think_end_token_id)
+ think_part_tokens = tokens[:spilt_point]
+ thinking_text = self.tokenizer.decode(think_part_tokens)
+
+ answer_part_tokens = tokens[spilt_point:]
+ answer_text = self.tokenizer.decode(answer_part_tokens)
+ remaining_thinking = thinking_text[already_yielded_len:]
+ if remaining_thinking:
+ yield "thinking", remaining_thinking
+
+ if answer_text:
+ yield "answer", answer_text
+
+ self.is_thinking = False
+ already_yielded_len = len(thinking_text) + len(self.tokenizer.decode(self.think_end_token_id))
+ else:
+ yield "thinking", text_chunk
+ already_yielded_len += len(text_chunk)
+
+
+
+class LLMInterface:
+ def __init__(self, model_name: str= "Qwen/Qwen3-0.6B"):
+ logging.info(f"Initializing generator {model_name}")
+ self.generator_tokenizer = AutoTokenizer.from_pretrained(model_name)
+ self.generator_model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ torch_dtype="auto",
+ device_map="auto")
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ def generate_answer(self, query: str, context_str: str) -> str:
+ messages = [
+ {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
+ {"role": "user", "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题:{query}"}
+ ]
+ prompt = self.generator_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(self.device)
+ output = self.generator_model.generate(**inputs,
+ max_new_tokens=256, num_return_sequences=1,
+ eos_token_id=self.generator_tokenizer.eos_token_id)
+ generated_ids = output[0][inputs["input_ids"].shape[1]:]
+ answer = self.generator_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
+ return answer
+
+ def generate_answer_stream(self, query: str, context_str: str) -> Generator[Tuple[str, str], None, None]:
+ """Generates an answer as a stream of (state, content) tuples."""
+ messages = [
+ {"role": "system",
+ "content": "You are a helpful assistant. Please answer the question based on the provided context. First, think through the process in tags, then provide the final answer."},
+ {"role": "user",
+ "content": f"Context:\n---\n{context_str}\n---\nBased on the context above, please answer the question: {query}"}
+ ]
+
+ # Use the template that enables thinking for Qwen models
+ prompt = self.generator_tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=True
+ )
+ model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device)
+
+ streamer = ThinkStreamer(self.generator_tokenizer, skip_prompt=True)
+
+ generation_kwargs = dict(
+ **model_inputs,
+ max_new_tokens=512,
+ streamer=streamer
+ )
+
+ thread = threading.Thread(target=self.generator_model.generate, kwargs=generation_kwargs)
+ thread.start()
+
+ yield from streamer.generate_output()
+
+
+# if __name__ == "__main__":
+# qwen = LLMInterface("Qwen/Qwen3-0.6B")
+# answer = qwen.generate_answer("儒家思想的创始人是谁?", "中国传统哲学以儒家、道家和法家为主要流派。儒家思想由孔子创立,强调“仁”、“义”、“礼”、“智”、“信”,主张修身齐家治国平天下,对中国社会产生了深远的影响。其核心价值观如“己所不欲,勿施于人”至今仍具有普世意义。"+
+#
+# "道家思想以老子和庄子为代表,主张“道法自然”,追求人与自然的和谐统一,强调无为而治、清静无为。道家思想对中国人的审美情趣、艺术创作以及养生之道都有着重要的影响。"+
+#
+# "法家思想以韩非子为集大成者,主张以法治国,强调君主的权威和法律的至高无上。尽管法家思想在历史上曾被用于强化中央集权,但其对建立健全的法律体系也提供了重要的理论基础。")
+#
+# print(answer)
diff --git a/core/vector_store.py b/core/vector_store.py
new file mode 100644
index 0000000000000000000000000000000000000000..6792e7151477341db43f19e148c5734d94e95bab
--- /dev/null
+++ b/core/vector_store.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/27 19:52
+# @Author : hukangzhe
+# @File : vector_store.py
+# @Description : Embedding, storage, and retrieval module
+import os
+import faiss
+import numpy as np
+import pickle
+import logging
+from rank_bm25 import BM25Okapi
+from typing import List, Dict, Tuple
+from .schema import Document, Chunk
+
+
+class HybridVectorStore:
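+ """Hybrid dense (FAISS) + sparse (BM25) index over child chunks, with lookup of their parent documents."""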
+ def __init__(self, config: dict, embedder):
+ self.config = config["vector_store"]
+ self.embedder = embedder
+ self.faiss_index = None
+ self.bm25_index = None
+ self.parent_docs: Dict[int, Document] = {}
+ self.child_chunks: List[Chunk] = []
+
+ def build(self, parent_docs: Dict[int, Document], child_chunks: List[Chunk]):
+ self.parent_docs = parent_docs
+ self.child_chunks = child_chunks
+
+ # Build Faiss index
+ child_text = [child.text for child in child_chunks]
+ embeddings = self.embedder.embed(child_text)
+ dim = embeddings.shape[1]
+ self.faiss_index = faiss.IndexFlatL2(dim)
+ self.faiss_index.add(embeddings)
+ logging.info(f"FAISS index built with {len(child_chunks)} vectors.")
+
+ # Build BM25 index
+ tokenize_chunks = [doc.text.split(" ") for doc in child_chunks]
+ self.bm25_index = BM25Okapi(tokenize_chunks)
+ logging.info(f"BM25 index built for {len(child_chunks)} documents.")
+
+ self.save()
+
+ def search(self, query: str, top_k: int , alpha: float) -> List[Tuple[int, float]]:
+ # Vector Search
+ query_embedding = self.embedder.embed([query])
+ distances, indices = self.faiss_index.search(query_embedding, k=top_k)
+ vector_scores = {idx : 1.0/(1.0 + dist) for idx, dist in zip(indices[0], distances[0])}
+
+ # BM25 Search
+ tokenize_query = query.split(" ")
+ bm25_scores = self.bm25_index.get_scores(tokenize_query)
+ bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
+ bm25_scores = {idx: bm25_scores[idx] for idx in bm25_top_indices}
+
+ # Hybrid Search
+ all_indices = set(vector_scores.keys()) | set(bm25_scores.keys()) # union of candidate indices
+ hybrid_scores = {}
+
+ # Normalization
+ max_v_score = max(vector_scores.values()) if vector_scores else 1.0
+ max_b_score = max(bm25_scores.values()) if bm25_scores else 1.0
+ for idx in all_indices:
+ v_score = (vector_scores.get(idx, 0))/max_v_score
+ b_score = (bm25_scores.get(idx, 0))/max_b_score
+ hybrid_scores[idx] = alpha * v_score + (1 - alpha) * b_score
+
+ sorted_indices = sorted(hybrid_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
+ return sorted_indices
+
+ def get_chunks(self, indices: List[int]) -> List[Chunk]:
+ return [self.child_chunks[i] for i in indices]
+
+ def get_parent_docs(self, chunks: List[Chunk]) -> List[Document]:
+ parent_ids = sorted(list(set(chunk.parent_id for chunk in chunks)))
+ return [self.parent_docs[pid] for pid in parent_ids]
+
+ def save(self):
+ index_path = self.config['index_path']
+ metadata_path = self.config['metadata_path']
+
+ os.makedirs(os.path.dirname(index_path), exist_ok=True)
+ os.makedirs(os.path.dirname(metadata_path), exist_ok=True)
+ logging.info(f"Saving FAISS index to: {index_path}")
+ try:
+ faiss.write_index(self.faiss_index, index_path)
+ except Exception as e:
+ logging.error(f"Failed to save FAISS index: {e}")
+ raise
+
+ logging.info(f"Saving metadata data to: {metadata_path}")
+ try:
+ with open(metadata_path, 'wb') as f:
+ metadata = {
+ 'parent_docs': self.parent_docs,
+ 'child_chunks': self.child_chunks,
+ 'bm25_index': self.bm25_index
+ }
+ pickle.dump(metadata, f)
+ except Exception as e:
+ logging.error(f"Failed to save metadata: {e}")
+ raise
+
+ logging.info("Vector store saved successfully.")
+
+ def load(self) -> bool:
+ """
+ Load the vector store state from disk. Returns True on success, False on failure.
+ """
+ index_path = self.config['index_path']
+ metadata_path = self.config['metadata_path']
+
+ if not os.path.exists(index_path) or not os.path.exists(metadata_path):
+ logging.warning("Index files not found. Cannot load vector store.")
+ return False
+
+ logging.info(f"Loading vector store from disk...")
+ try:
+ # Load FAISS index
+ logging.info(f"Loading FAISS index from: {index_path}")
+ self.faiss_index = faiss.read_index(index_path)
+
+ # Load metadata
+ logging.info(f"Loading metadata from: {metadata_path}")
+ with open(metadata_path, 'rb') as f:
+ metadata = pickle.load(f)
+ self.parent_docs = metadata['parent_docs']
+ self.child_chunks = metadata['child_chunks']
+ self.bm25_index = metadata['bm25_index']
+
+ logging.info("Vector store loaded successfully.")
+ return True
+
+ except Exception as e:
+ logging.error(f"Failed to load vector store from disk: {e}")
+ self.faiss_index = None
+ self.bm25_index = None
+ self.parent_docs = {}
+ self.child_chunks = []
+ return False
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/data/chinese_document.pdf b/data/chinese_document.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..cf3c783f36f435c8a29f5d77e52efceb04852205
--- /dev/null
+++ b/data/chinese_document.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bde4aba4245bb013438240cabf4c149d8aec76536f7a1459a165db02ec3695c
+size 317035
diff --git a/data/english_document.pdf b/data/english_document.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..214d2ea71708ac3cd25d9dc02c8a9213fdefe17c
--- /dev/null
+++ b/data/english_document.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5aa48e1b7d30d8cf2ac87cae6e2fee73e2de4d87178d0878d0d34b499fc7b26d
+size 166331
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..032ece81fc60f6a0415af85f2359e0b1511ad874
--- /dev/null
+++ b/main.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/5/1 14:00
+# @Author : hukangzhe
+# @File : main.py
+# @Description :
+import yaml
+from service.rag_service import RAGService
+from ui.app import GradioApp
+from utils.logger import setup_logger
+
+
+def main():
+ setup_logger()
+
+ with open('configs/config.yaml', 'r', encoding='utf-8') as f:
+ config = yaml.safe_load(f)
+
+ rag_service = RAGService(config)
+ app = GradioApp(rag_service)
+ app.launch()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/rag_mini.py b/rag_mini.py
new file mode 100644
index 0000000000000000000000000000000000000000..813c892b4ba93a3bfd33089fc2baea497b6fab59
--- /dev/null
+++ b/rag_mini.py
@@ -0,0 +1,148 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/24 14:04
+# @Author : hukangzhe
+# @File : rag_mini.py
+# @Description : A very simple RAG system
+
+import PyPDF2
+import fitz
+from sentence_transformers import SentenceTransformer, CrossEncoder
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+import numpy as np
+import faiss
+import torch
+
+
+class RAGSystem:
+ def __init__(self, pdf_path):
+ self.pdf_path = pdf_path
+ self.texts = self._load_and_spilt_pdf()
+ self.embedder = SentenceTransformer('moka-ai/m3e-base')
+ self.reranker = CrossEncoder('BAAI/bge-reranker-base') # load a cross-encoder reranker model
+ self.vector_store = self._create_vector_store()
+ print("3. Initializing Generator Model...")
+ model_name = "Qwen/Qwen1.5-1.8B-Chat"
+
+ # Check whether a GPU is available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f" - Using device: {device}")
+
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+ # Note: for causal LMs like Qwen we use AutoModelForCausalLM
+ model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+
+ self.generator = pipeline(
+ 'text-generation',
+ model=model,
+ tokenizer=self.tokenizer,
+ device=device
+ )
+
+ # 1. 文档加载 & 2.文本切分 (为了简化,合在一起)
+ def _load_and_spilt_pdf(self):
+ print("1. Loading and splitting PDF...")
+ full_text = ""
+ with fitz.open(self.pdf_path) as doc:
+ for page in doc:
+ full_text += page.get_text()
+
+ # Very basic splitting: fixed-size chunks with overlap
+ chunk_size = 500
+ overlap = 50
+ chunks = [full_text[i: i+chunk_size] for i in range(0, len(full_text), chunk_size-overlap)]
+ print(f" - Splitted into {len(chunks)} chunks.")
+ return chunks
+
+ # 3. 文本向量化 & 向量存储
+ def _create_vector_store(self):
+ print("2. Creating vector store...")
+ # embedding
+ embeddings = self.embedder.encode(self.texts)
+
+ # Storing with faiss
+ dim = embeddings.shape[1]
+ index = faiss.IndexFlatL2(dim) # use L2 distance for similarity
+ index.add(np.array(embeddings))
+ print(" - Created vector store")
+ return index
+
+ # 4. Retrieval
+ def retrieve(self, query, k=3):
+ print(f"3. Retrieving top {k} relevant chunks for query: '{query}' ")
+ query_embeddings = self.embedder.encode([query])
+
+ distances, indices = self.vector_store.search(np.array(query_embeddings), k=k)
+ retrieved_chunks = [self.texts[i] for i in indices[0]]
+ print(" - Retrieval complete.")
+ return retrieved_chunks
+
+    # 5. Generation
+ def generate(self, query, context_chunks):
+ print("4. Generate answer...")
+ context = "\n".join(context_chunks)
+
+ messages = [
+ {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
+ {"role": "user", "content": f"上下文:\n---\n{context}\n---\n请根据以上上下文回答这个问题:{query}"}
+ ]
+ prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
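+        # apply_chat_template renders the messages in Qwen's chat format; add_generation_prompt=True
+        # appends the assistant header so the model continues with the answer instead of the prompt.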
+
+ # print("Final Prompt:\n", prompt)
+ # print("Prompt token length:", len(self.tokenizer.encode(prompt)))
+
+ result = self.generator(prompt, max_new_tokens=200, num_return_sequences=1,
+ eos_token_id=self.tokenizer.eos_token_id)
+
+ print(" - Generation complete.")
+ # print("Raw results:", result)
+        # Extract the generated text
+        # Note: the pipeline output includes the prompt, so the answer must be sliced out of it
+ full_response = result[0]["generated_text"]
+        answer = full_response[len(prompt):].strip()  # keep only the text after the prompt
+
+ # print("Final Answer:", repr(answer))
+ return answer
+
+    # Optimization 1: reranking
+ def rerank(self, query, chunks):
+ print(" - Reranking retrieved chunks...")
+ pairs = [[query, chunk] for chunk in chunks]
+ scores = self.reranker.predict(pairs)
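+        # The cross-encoder scores each (query, chunk) pair jointly, which is more accurate than the
+        # bi-encoder used for first-stage retrieval but too slow to run over the whole corpus;
+        # hence the retrieve-then-rerank pattern.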
+
+        # Pair each chunk with its score and sort by score in descending order
+ ranked_chunks = sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
+ return [chunk for chunk, score in ranked_chunks]
+
+ def query(self, query_text):
+        # 1. Retrieve (cast a wider net, e.g. top 10)
+ retrieved_chunks = self.retrieve(query_text, k=10)
+
+        # 2. Rerank (keep the 3 most relevant of the 10)
+ reranked_chunks = self.rerank(query_text, retrieved_chunks)
+ top_k_reranked = reranked_chunks[:3]
+
+ answer = self.generate(query_text, top_k_reranked)
+ return answer
+
+
+def main():
+    # Make sure this PDF exists under the data folder
+ pdf_path = 'data/chinese_document.pdf'
+
+ print("Initializing RAG System...")
+ rag_system = RAGSystem(pdf_path)
+ print("\nRAG System is ready. You can start asking questions.")
+ print("Type 'q' to quit.")
+
+ while True:
+ user_query = input("\nYour Question: ")
+ if user_query.lower() == 'q':
+ break
+
+ answer = rag_system.query(user_query)
+ print("\nAnswer:", answer)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c5a1afd4a08714741949343784a3e878e0e69ec8
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+pyyaml
+pypdf2
+PyMuPDF
+sentence-transformers
+faiss-cpu
+rank_bm25
+transformers
+torch
+gradio
\ No newline at end of file
diff --git a/service/__init__.py b/service/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/service/__pycache__/__init__.cpython-39.pyc b/service/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e4b381495c7815b79f4444a2a3d2f5cad4027ae
Binary files /dev/null and b/service/__pycache__/__init__.cpython-39.pyc differ
diff --git a/service/__pycache__/rag_service.cpython-39.pyc b/service/__pycache__/rag_service.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f04ec204633814e2a126918295c729eb3e500c98
Binary files /dev/null and b/service/__pycache__/rag_service.cpython-39.pyc differ
diff --git a/service/rag_service.py b/service/rag_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7fb58099fc1372b269ae32c312e083f330adfe0
--- /dev/null
+++ b/service/rag_service.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/30 11:50
+# @Author : hukangzhe
+# @File : rag_service.py
+# @Description : Service layer that orchestrates building, loading, and querying the knowledge base
+import logging
+import os
+from typing import List, Generator, Tuple
+from core.schema import Document
+from core.embedder import EmbeddingModel
+from core.loader import MultiDocumentLoader
+from core.splitter import HierarchicalSemanticSplitter
+from core.vector_store import HybridVectorStore
+from core.llm_interface import LLMInterface
+
+
+class RAGService:
+ def __init__(self, config: dict):
+ self.config = config
+ logging.info("Initializing RAG Service...")
+ self.embedder = EmbeddingModel(config['models']['embedding'])
+ self.vector_store = HybridVectorStore(config, self.embedder)
+ self.llm = LLMInterface(config)
+        self.is_ready = False  # whether the knowledge base is ready for querying
+ logging.info("RAG Service initialized. Knowledge base is not loaded.")
+
+ def load_knowledge_base(self) -> Tuple[bool, str]:
+ """
+        Try to load an existing knowledge base from disk.
+ Returns:
+ A tuple (success: bool, message: str)
+ """
+ if self.is_ready:
+ return True, "Knowledge base is already loaded."
+
+ logging.info("Attempting to load knowledge base from disk...")
+ success = self.vector_store.load()
+ if success:
+ self.is_ready = True
+ message = "Knowledge base loaded successfully from disk."
+ logging.info(message)
+ return True, message
+ else:
+ self.is_ready = False
+ message = "No existing knowledge base found or failed to load. Please build a new one."
+ logging.warning(message)
+ return False, message
+
+ def build_knowledge_base(self, file_paths: List[str]) -> Generator[str, None, None]:
+ self.is_ready = False
+ yield "Step 1/3: Loading documents..."
+ loader = MultiDocumentLoader(file_paths)
+ docs = loader.load()
+
+ yield "Step 2/3: Splitting documents into hierarchical chunks..."
+ splitter = HierarchicalSemanticSplitter(
+ parent_chunk_size=self.config['splitter']['parent_chunk_size'],
+ parent_chunk_overlap=self.config['splitter']['parent_chunk_overlap'],
+ child_chunk_size=self.config['splitter']['child_chunk_size']
+ )
+ parent_docs, child_chunks = splitter.split_documents(docs)
+
+ yield "Step 3/3: Building and saving vector index..."
+ self.vector_store.build(parent_docs, child_chunks)
+ self.is_ready = True
+ yield "Knowledge base built and ready!"
+
+ def _get_context_and_sources(self, query: str) -> List[Document]:
+ if not self.is_ready:
+ raise Exception("Knowledge base is not ready. Please build it first.")
+
+ # Hybrid Search to get child chunks
+ retrieved_child_indices_scores = self.vector_store.search(
+ query,
+ top_k=self.config['retrieval']['retrieval_top_k'],
+ alpha=self.config['retrieval']['hybrid_search_alpha']
+ )
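+        # Note: alpha presumably blends the dense (FAISS) and sparse (BM25) scores inside
+        # HybridVectorStore; the exact weighting lives in core/vector_store and is not shown here.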
+ retrieved_child_indices = [idx for idx, score in retrieved_child_indices_scores]
+ retrieved_child_chunks = self.vector_store.get_chunks(retrieved_child_indices)
+
+ # Get Parent Documents
+ retrieved_parent_docs = self.vector_store.get_parent_docs(retrieved_child_chunks)
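+        # Small-to-big retrieval: the search runs over small child chunks for precision, and their
+        # parent documents are returned so the LLM sees the surrounding context.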
+
+ # Rerank Parent Documents
+ reranked_docs = self.llm.rerank(query, retrieved_parent_docs)
+ final_context_docs = reranked_docs[:self.config['retrieval']['rerank_top_k']]
+
+ return final_context_docs
+
+    def get_response_full(self, query: str) -> Tuple[str, List[Document]]:
+ final_context_docs = self._get_context_and_sources(query)
+ answer = self.llm.generate_answer(query, final_context_docs)
+ return answer, final_context_docs
+
+    def get_response_stream(self, query: str) -> Tuple[Generator[str, None, None], List[Document]]:
+ final_context_docs = self._get_context_and_sources(query)
+ answer_generator = self.llm.generate_answer_stream(query, final_context_docs)
+ return answer_generator, final_context_docs
+
+ def get_context_string(self, context_docs: List[Document]) -> str:
+ context_str = "引用上下文 (Context Sources):\n\n"
+ for doc in context_docs:
+ source_info = f"--- (来源: {os.path.basename(doc.metadata.get('source', ''))}, 页码: {doc.metadata.get('page', 'N/A')}) ---\n"
+ content = doc.text[:200]+"..." if len(doc.text) > 200 else doc.text
+ context_str += source_info + content + "\n\n"
+ return context_str.strip()
+
+
+
diff --git a/storage/faiss_index/chunks.pkl b/storage/faiss_index/chunks.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..a084a4cd7d4834a6e8929987bce62976f5b9101b
--- /dev/null
+++ b/storage/faiss_index/chunks.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da7fae509a97a93f69b8048e375fca25dd45f0e26a40bff83c30759c5853e473
+size 27663
diff --git a/storage/faiss_index/index.faiss b/storage/faiss_index/index.faiss
new file mode 100644
index 0000000000000000000000000000000000000000..3c2e5a8e2cd09faee8ced8ee789d661f50e1a434
Binary files /dev/null and b/storage/faiss_index/index.faiss differ
diff --git a/ui/__init__.py b/ui/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ui/__pycache__/__init__.cpython-39.pyc b/ui/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15bb473a52d09cfd63c269b912e6a93bbf57995a
Binary files /dev/null and b/ui/__pycache__/__init__.cpython-39.pyc differ
diff --git a/ui/__pycache__/app.cpython-39.pyc b/ui/__pycache__/app.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ceb60199ef7e2d2327109894ee63fa5ae7a69b8f
Binary files /dev/null and b/ui/__pycache__/app.cpython-39.pyc differ
diff --git a/ui/app.py b/ui/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..5990d22ce4ec0ec2019beb57c8fe5a20555225bd
--- /dev/null
+++ b/ui/app.py
@@ -0,0 +1,202 @@
+import gradio as gr
+import os
+from typing import List, Tuple
+
+from service.rag_service import RAGService
+
+
+class GradioApp:
+ def __init__(self, rag_service: RAGService):
+ self.rag_service = rag_service
+ self._build_ui()
+
+ def _build_ui(self):
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
+ title="Enterprise RAG System") as self.demo:
+ gr.Markdown("# 企业级RAG智能问答系统 (Enterprise RAG System)")
+ gr.Markdown("您可以**加载现有知识库**快速开始,或**上传新文档**构建一个全新的知识库。")
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown("### 控制面板 (Control Panel)")
+
+ self.load_kb_button = gr.Button("加载已有知识库 (Load Existing KB)")
+
+ gr.Markdown("
")
+
+ self.file_uploader = gr.File(
+ label="上传新文档以构建 (Upload New Docs to Build)",
+ file_count="multiple",
+ file_types=[".pdf", ".txt"],
+ interactive=True
+ )
+ self.build_kb_button = gr.Button("构建新知识库 (Build New KB)", variant="primary")
+
+ self.status_box = gr.Textbox(
+ label="系统状态 (System Status)",
+ value="系统已初始化,等待加载或构建知识库。",
+ interactive=False,
+ lines=4
+ )
+
+            # --- Hidden at first; shown once a knowledge base is loaded or built ---
+ with gr.Column(scale=2, visible=False) as self.chat_area:
+ gr.Markdown("### 对话窗口 (Chat Window)")
+ self.chatbot = gr.Chatbot(label="RAG Chatbot", bubble_full_width=False, height=500)
+ self.mode_selector = gr.Radio(
+ ["流式输出(Streaming)","一次性输出(Full)"],
+ label="输出模式:(Output Mode)",
+ value="流式输出(Streaming)"
+ )
+ self.question_box = gr.Textbox(label="您的问题", placeholder="请在此处输入您的问题...",
+ show_label=False)
+ with gr.Row():
+ self.submit_btn = gr.Button("提交 (Submit)", variant="primary")
+ self.clear_btn = gr.Button("清空历史 (Clear History)")
+
+ gr.Markdown("---")
+ self.source_display = gr.Markdown("### 引用来源 (Sources)")
+
+ # --- Event Listeners ---
+ self.load_kb_button.click(
+ fn=self._handle_load_kb,
+ inputs=None,
+ outputs=[self.status_box, self.chat_area]
+ )
+
+ self.build_kb_button.click(
+ fn=self._handle_build_kb,
+ inputs=[self.file_uploader],
+ outputs=[self.status_box, self.chat_area]
+ )
+
+ self.submit_btn.click(
+ fn=self._handle_chat_submission,
+ inputs=[self.question_box, self.chatbot, self.mode_selector],
+ outputs=[self.chatbot, self.question_box, self.source_display]
+ )
+
+ self.question_box.submit(
+ fn=self._handle_chat_submission,
+ inputs=[self.question_box, self.chatbot, self.mode_selector],
+ outputs=[self.chatbot, self.question_box, self.source_display]
+ )
+
+ self.clear_btn.click(
+ fn=self._clear_chat,
+ inputs=None,
+ outputs=[self.chatbot, self.question_box, self.source_display]
+ )
+
+ def _handle_load_kb(self):
+ """处理现有知识库的加载。返回更新字典。"""
+ success, message = self.rag_service.load_knowledge_base()
+ if success:
+ return {
+ self.status_box: gr.update(value=message),
+ self.chat_area: gr.update(visible=True)
+ }
+ else:
+ return {
+ self.status_box: gr.update(value=message),
+ self.chat_area: gr.update(visible=False)
+ }
+
+ def _handle_build_kb(self, files: List[str], progress=gr.Progress(track_tqdm=True)):
+ """构建新知识库,返回更新的字典."""
+ if not files:
+ return {
+ self.status_box: gr.update(value="错误:请至少上传一个文档。"),
+ self.chat_area: gr.update(visible=False)
+ }
+
+ file_paths = [file.name for file in files]
+
+ try:
+ for status in self.rag_service.build_knowledge_base(file_paths):
+ progress(0.5, desc=status)
+
+ final_status = "知识库构建完成并已就绪!√"
+ return {
+ self.status_box: gr.update(value=final_status),
+ self.chat_area: gr.update(visible=True)
+ }
+ except Exception as e:
+ error_message = f"构建失败: {e}"
+ return {
+ self.status_box: gr.update(value=error_message),
+ self.chat_area: gr.update(visible=False)
+ }
+
+ def _handle_chat_submission(self, question: str, history: List[Tuple[str, str]], mode: str):
+ if not question or not question.strip():
+ yield history, "", "### 引用来源 (Sources)\n"
+ return
+
+ history.append((question, ""))
+
+ try:
+            # Full (non-streaming) output
+ if "Full" in mode:
+ yield history, "", "### 引用来源 (Sources)\n"
+
+ answer, sources = self.rag_service.get_response_full(question)
+                # Fetch the cited context for display
+ context_string_for_display = self.rag_service.get_context_string(sources)
+                # Format the sources for the side panel
+ source_text_for_panel = self._format_sources(sources)
+                # Full message: citations + answer
+ full_response = f"{context_string_for_display}\n\n---\n\n**回答 (Answer):**\n{answer}"
+ history[-1] = (question, full_response)
+
+ yield history, "", source_text_for_panel
+
+            # Streaming output
+ else:
+ answer_generator, sources = self.rag_service.get_response_stream(question)
+
+ context_string_for_display = self.rag_service.get_context_string(sources)
+
+ source_text_for_panel = self._format_sources(sources)
+
+ yield history, "", source_text_for_panel
+
+ response_prefix = f"{context_string_for_display}\n\n---\n\n**回答 (Answer):**\n"
+ history[-1] = (question, response_prefix)
+ yield history, "", source_text_for_panel
+
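+                # Each yield below pushes the partially generated answer back to the Chatbot,
+                # producing the token-by-token streaming effect in the UI.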
+ answer_log = ""
+ for text_chunk in answer_generator:
+ answer_log += text_chunk
+ history[-1] = (question, response_prefix + answer_log)
+ yield history, "", source_text_for_panel
+
+ except Exception as e:
+ error_response = f"处理请求时出错: {e}"
+ history[-1] = (question, error_response)
+ yield history, "", "### 引用来源 (Sources)\n"
+
+ def _format_sources(self, sources: List) -> str:
+ source_text = "### 引用来源 (sources)\n)"
+ if not sources:
+ return source_text
+
+ unique_sources = set()
+ for doc in sources:
+ source_name = os.path.basename(doc.metadata.get('source', 'Unknown'))
+ page_num = doc.metadata.get('page', 'N/A')
+ unique_sources.add(f"- **{source_name}** (Page: {page_num})")
+
+ source_text += "\n".join(sorted(list(unique_sources)))
+ return source_text
+
+ def _clear_chat(self):
+ """清理聊天内容"""
+ return None, "", "### 引用来源 (Sources)\n"
+
+ def launch(self):
+ self.demo.queue().launch()
+
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ddccd3328327de77cb1351661e3593a951a5768
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/utils/__pycache__/logger.cpython-39.pyc b/utils/__pycache__/logger.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a39a45115fb2ce92fd3a19ffa7ec666683f1e0b
Binary files /dev/null and b/utils/__pycache__/logger.cpython-39.pyc differ
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..975c1081b7df6e96b5cc94194dfbeddeffe15567
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2025/4/25 19:51
+# @Author : hukangzhe
+# @File : logger.py
+# @Description : Logging configuration module
+import logging
+import os
+import sys
+
+
+def setup_logger():
+    os.makedirs("storage/logs", exist_ok=True)  # ensure the log directory exists
+    logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s",
+ handlers=[
+ logging.FileHandler("storage/logs/app.log"),
+ logging.StreamHandler(sys.stdout)
+ ]
+ )
+