diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..da15360707a7132317001f9c9f7ac950a39b05b4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/1.png filter=lfs diff=lfs merge=lfs -text +assets/2.png filter=lfs diff=lfs merge=lfs -text +data/chinese_document.pdf filter=lfs diff=lfs merge=lfs -text +data/english_document.pdf filter=lfs diff=lfs merge=lfs -text +Mini_RAG/assets/1.png filter=lfs diff=lfs merge=lfs -text +Mini_RAG/assets/2.png filter=lfs diff=lfs merge=lfs -text +Mini_RAG/data/chinese_document.pdf filter=lfs diff=lfs merge=lfs -text +Mini_RAG/data/english_document.pdf filter=lfs diff=lfs merge=lfs -text diff --git a/Mini_RAG/.gitattributes b/Mini_RAG/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..dab9a4e17afd2ef39d90ccb0b40ef2786fe77422 --- /dev/null +++ b/Mini_RAG/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/Mini_RAG/.idea/.gitignore b/Mini_RAG/.idea/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1c2fda565b94d0f2b94cb65ba7cca866e7a25478 --- /dev/null +++ b/Mini_RAG/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/Mini_RAG/.idea/RAG_Min.iml b/Mini_RAG/.idea/RAG_Min.iml new file mode 100644 index 0000000000000000000000000000000000000000..527b363ce16e0ead03139f7e7ccee15f64a86ba5 --- /dev/null +++ b/Mini_RAG/.idea/RAG_Min.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/Mini_RAG/.idea/inspectionProfiles/Project_Default.xml b/Mini_RAG/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 
0000000000000000000000000000000000000000..9680032c8f34359f630a6d6d887cc6f1c4d5d9c3 --- /dev/null +++ b/Mini_RAG/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,19 @@ + + + + \ No newline at end of file diff --git a/Mini_RAG/.idea/inspectionProfiles/profiles_settings.xml b/Mini_RAG/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99 --- /dev/null +++ b/Mini_RAG/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/Mini_RAG/.idea/misc.xml b/Mini_RAG/.idea/misc.xml new file mode 100644 index 0000000000000000000000000000000000000000..437b754c07ba5e4ee060d3231c0520ff6c3c8ebe --- /dev/null +++ b/Mini_RAG/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/Mini_RAG/.idea/modules.xml b/Mini_RAG/.idea/modules.xml new file mode 100644 index 0000000000000000000000000000000000000000..d0ece5d0e26d05bc29e01f2286873d2e9c7a6707 --- /dev/null +++ b/Mini_RAG/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/Mini_RAG/.idea/vcs.xml b/Mini_RAG/.idea/vcs.xml new file mode 100644 index 0000000000000000000000000000000000000000..9661ac713428efbad557d3ba3a62216b5bb7d226 --- /dev/null +++ b/Mini_RAG/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Mini_RAG/.idea/workspace.xml b/Mini_RAG/.idea/workspace.xml new file mode 100644 index 0000000000000000000000000000000000000000..18c0482fe1342a0a6752f8fd66652308d62ab26f --- /dev/null +++ b/Mini_RAG/.idea/workspace.xml @@ -0,0 +1,236 @@ + + + + + + + + + + + + + + + + + + + + + + + { + "customColor": "", + "associatedIndex": 3 +} + + + + { + "keyToString": { + "ModuleVcsDetector.initialDetectionPerformed": "true", + "Python.llm_stearm.executor": "Run", + "Python.main.executor": "Run", + "Python.rag_mini (1).executor": "Run", + "Python.test_qw.executor": "Run", + "RunOnceActivity.OpenProjectViewOnStart": "true", + "RunOnceActivity.ShowReadmeOnStart": "true", + "RunOnceActivity.TerminalTabsStorage.copyFrom.TerminalArrangementManager.252": "true", + "RunOnceActivity.git.unshallow": "true", + "WebServerToolWindowFactoryState": "false", + "git-widget-placeholder": "main", + "ignore.virus.scanning.warn.message": "true", + "last_opened_file_path": "D:/LLms/Tiny-RAG", + "node.js.detected.package.eslint": "true", + "node.js.detected.package.tslint": "true", + "node.js.selected.package.eslint": "(autodetect)", + "node.js.selected.package.tslint": "(autodetect)", + "nodejs_package_manager_path": "npm", + "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable", + "vue.rearranger.settings.migration": "true" + } +} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1756273884601 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/Mini_RAG/LICENSE b/Mini_RAG/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..5eea95b8154bf1a59f0bf164075a1ab20d726899 --- /dev/null +++ b/Mini_RAG/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 TuNan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons 
to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Mini_RAG/README.md b/Mini_RAG/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f47dbce21502cc0ac57a7b25aef7c87b1b7146c1
--- /dev/null
+++ b/Mini_RAG/README.md
@@ -0,0 +1,185 @@
+# RAG_Mini
+---
+
+# Enterprise-Ready RAG System with Gradio Interface
+
+This is a powerful, enterprise-grade Retrieval-Augmented Generation (RAG) system designed to transform your documents into an interactive and intelligent knowledge base. Users can upload their own documents (PDFs, TXT files), build a searchable vector index, and ask complex questions in natural language to receive accurate, context-aware answers sourced directly from the provided materials.
+
+The entire application is wrapped in a clean, user-friendly web interface powered by Gradio.
+
+![App Screenshot](assets/1.png)
+![App Screenshot](assets/2.png)
+
+## ✨ Features
+
+- **Intuitive Web UI**: Simple, clean interface built with Gradio for uploading documents and chatting.
+- **Multi-Document Support**: Natively handles PDF and TXT files.
+- **Advanced Text Splitting**: Uses a `HierarchicalSemanticSplitter` that first splits documents into large parent chunks (for context) and then into smaller child chunks (for precise search), respecting semantic boundaries.
+- **Hybrid Search**: Combines the strengths of dense vector search (FAISS) and sparse keyword search (BM25) for robust and accurate retrieval.
+- **Reranking for Accuracy**: Employs a Cross-Encoder model to rerank the retrieved documents, ensuring the most relevant context is passed to the language model.
+- **Persistent Knowledge Base**: Automatically saves the built vector index and metadata, allowing you to load an existing knowledge base instantly on startup.
+- **Modular & Extensible Codebase**: The project is logically structured into services for loading, splitting, embedding, and generation, making it easy to maintain and extend.
+
+## 🏛️ System Architecture
+
+The RAG pipeline follows a logical, multi-step process to ensure high-quality answers:
+
+1. **Load**: Documents are loaded from various formats and parsed into a standardized `Document` object, preserving metadata like source and page number.
+2. **Split**: The raw text is processed by the `HierarchicalSemanticSplitter`, creating parent and child text chunks. This provides both broad context and fine-grained detail.
+3. **Embed & Index**: The child chunks are converted into vector embeddings using a `SentenceTransformer` model and indexed in a FAISS vector store. A parallel BM25 index is also built for keyword search.
+4. **Retrieve**: When a user asks a question, a hybrid search query is performed against the FAISS and BM25 indices to retrieve the most relevant child chunks.
+5. **Fetch Context**: The parent chunks corresponding to the retrieved child chunks are fetched.
This ensures the LLM receives a wider, more complete context.
+6. **Rerank**: A powerful Cross-Encoder model re-evaluates the relevance of the parent chunks against the query, pushing the best matches to the top.
+7. **Generate**: The top-ranked, reranked documents are combined with the user's query into a final prompt. This prompt is sent to a Large Language Model (LLM) to generate a final, coherent answer.
+
+```
+[User Uploads Docs] -> [Loader] -> [Splitter] -> [Embedder & Vector Store] -> [Knowledge Base Saved]
+
+[User Asks Question] -> [Hybrid Search] -> [Get Parent Docs] -> [Reranker] -> [LLM] -> [Answer & Sources]
+```
+
+## 🛠️ Tech Stack
+
+- **Backend**: Python 3.9+
+- **UI**: Gradio
+- **LLM & Embedding Framework**: Hugging Face Transformers, Sentence-Transformers
+- **Vector Search**: Faiss (from Facebook AI)
+- **Keyword Search**: rank-bm25
+- **PDF Parsing**: PyMuPDF (fitz)
+- **Configuration**: PyYAML
+
+## 🚀 Getting Started
+
+Follow these steps to set up and run the project on your local machine.
+
+### 1. Prerequisites
+
+- Python 3.9 or higher
+- `pip` for package management
+
+### 2. The `requirements.txt` file
+
+The repository ships a `requirements.txt` so that the necessary dependencies can be installed in one step. The key packages it should contain are: `gradio`, `torch`, `transformers`, `sentence-transformers`, `faiss-cpu`, `rank_bm25`, `PyMuPDF`, `pyyaml`, `numpy`. If you add new dependencies while developing, regenerate the file from your activated environment and commit the update:
+
+```bash
+pip freeze > requirements.txt
+```
+
+### 3. Installation & Setup
+
+**1. Clone the repository:**
+```bash
+git clone https://github.com/YOUR_USERNAME/YOUR_REPOSITORY_NAME.git
+cd YOUR_REPOSITORY_NAME
+```
+
+**2. Create and activate a virtual environment (recommended):**
+```bash
+# For Windows
+python -m venv venv
+.\venv\Scripts\activate
+
+# For macOS/Linux
+python3 -m venv venv
+source venv/bin/activate
+```
+
+**3. Install the required packages:**
+```bash
+pip install -r requirements.txt
+```
+
+**4. Configure the system:**
+Review the `configs/config.yaml` file. You can change the models, chunk sizes, and other parameters here. The default settings are a good starting point.
+
+> **Note:** The first time you run the application, the models specified in the config file will be downloaded from Hugging Face. This may take some time depending on your internet connection.
+
+### 4. Running the Application
+
+To start the Gradio web server, run the `main.py` script:
+
+```bash
+python main.py
+```
+
+The application will be available at **`http://localhost:7860`**.
+
+## 📖 How to Use
+
+The application has two primary workflows:
+
+**1. Build a New Knowledge Base:**
+ - Drag and drop one or more `.pdf` or `.txt` files into the "Upload New Docs to Build" area.
+ - Click the **"Build New KB"** button.
+ - The system status will show the progress (Loading -> Splitting -> Indexing).
+ - Once complete, the status will confirm that the knowledge base is ready, and the chat window will appear.
+
+**2. Load an Existing Knowledge Base:**
+ - If you have previously built a knowledge base, simply click the **"Load Existing KB"** button.
+ - The system will load the saved FAISS index and metadata from the `storage` directory.
+ - The chat window will appear, and you can start asking questions immediately.
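
As an alternative to the two UI workflows above, the same `RAGService` that backs the Gradio app can be driven from a short script. The following is a minimal sketch (not part of the repository), assuming it is run from the project root so the relative `configs/config.yaml` and `data/` paths resolve; the document path and the question are placeholders:

```python
# query_kb.py - minimal programmatic use of the RAG service, bypassing the Gradio UI.
import yaml

from service.rag_service import RAGService
from utils.logger import setup_logger


def main() -> None:
    setup_logger()
    with open("configs/config.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    service = RAGService(config)

    # Try to load a previously built knowledge base from ./storage ...
    loaded, message = service.load_knowledge_base()
    print(message)
    if not loaded:
        # ... otherwise build a new one; progress is yielded as status strings.
        for status in service.build_knowledge_base(["data/english_document.pdf"]):
            print(status)

    answer, sources = service.get_response_full("What is this document about?")
    print(answer)
    print(service.get_context_string(sources))


if __name__ == "__main__":
    main()
```

This mirrors the flow behind the UI buttons: load the YAML config, construct `RAGService`, then either load the saved index from `storage/` or build a new one before answering queries.
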
+
+**Chatting with Your Documents:**
+ - Once the knowledge base is ready, type your question into the chat box at the bottom and press Enter or click "Submit".
+ - The model will generate an answer based on the documents you provided.
+ - The sources used to generate the answer will be displayed below the chat window.
+
+## 📂 Project Structure
+
+```
+.
+├── configs/
+│   └── config.yaml        # Main configuration file for models, paths, etc.
+├── core/
+│   ├── embedder.py        # Handles text embedding.
+│   ├── llm_interface.py   # Handles reranking and answer generation.
+│   ├── loader.py          # Loads and parses documents.
+│   ├── schema.py          # Defines data structures (Document, Chunk).
+│   ├── splitter.py        # Splits documents into chunks.
+│   └── vector_store.py    # Manages FAISS & BM25 indices.
+├── service/
+│   └── rag_service.py     # Orchestrates the entire RAG pipeline.
+├── storage/               # Default location for saved indices (auto-generated).
+│   └── ...
+├── ui/
+│   └── app.py             # Contains the Gradio UI logic.
+├── utils/
+│   └── logger.py          # Logging configuration.
+├── assets/
+│   └── 1.png              # Screenshot of the application.
+├── main.py                # Entry point to run the application.
+└── requirements.txt       # Python package dependencies.
+```
+
+## 🔧 Configuration Details (`config.yaml`)
+
+You can customize the RAG pipeline by modifying `configs/config.yaml`:
+
+- **`models`**: Specify the Hugging Face models for embedding, reranking, and generation.
+- **`vector_store`**: Define the paths where the FAISS index and metadata will be saved.
+- **`splitter`**: Control the `HierarchicalSemanticSplitter` behavior.
+  - `parent_chunk_size`: The target size for larger context chunks.
+  - `parent_chunk_overlap`: The overlap between parent chunks.
+  - `child_chunk_size`: The target size for smaller, searchable chunks.
+- **`retrieval`**: Tune the retrieval and reranking process.
+  - `retrieval_top_k`: How many initial candidates to retrieve with hybrid search.
+  - `rerank_top_k`: How many final documents to pass to the LLM after reranking.
+  - `hybrid_search_alpha`: The weighting between vector search (`alpha`) and BM25 search (`1 - alpha`). `1.0` is pure vector search, `0.0` is pure keyword search.
+- **`generation`**: Set parameters for the final answer generation, like `max_new_tokens`.
+
+## 🛣️ Future Roadmap
+
+- [ ] Support for more document types (e.g., `.docx`, `.pptx`, `.html`).
+- [ ] Implement response streaming for a more interactive chat experience.
+- [ ] Integrate with other vector databases like ChromaDB or Pinecone.
+- [ ] Create API endpoints for programmatic access to the RAG service.
+- [ ] Add more advanced logging and monitoring for enterprise use.
+
+## 🤝 Contributing
+
+Contributions are welcome! If you have ideas for improvements or find a bug, please feel free to open an issue or submit a pull request.
+
+## 📄 License
+
+This project is licensed under the MIT License. See the `LICENSE` file for details.
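
To make the `hybrid_search_alpha` setting concrete: retrieval fuses the two retrievers with a max-normalized weighted sum, after converting FAISS L2 distances into similarities via `1 / (1 + distance)`. Below is a small, self-contained sketch of just the fusion step, mirroring the logic in `core/vector_store.py`; the toy score dictionaries are invented for illustration:

```python
from typing import Dict, List, Tuple


def fuse_scores(vector_scores: Dict[int, float],
                bm25_scores: Dict[int, float],
                alpha: float,
                top_k: int) -> List[Tuple[int, float]]:
    """alpha * normalized vector score + (1 - alpha) * normalized BM25 score."""
    all_ids = set(vector_scores) | set(bm25_scores)
    max_v = max(vector_scores.values()) if vector_scores else 1.0
    max_b = max(bm25_scores.values()) if bm25_scores else 1.0
    fused = {
        i: alpha * (vector_scores.get(i, 0.0) / max_v)
           + (1 - alpha) * (bm25_scores.get(i, 0.0) / max_b)
        for i in all_ids
    }
    return sorted(fused.items(), key=lambda kv: kv[1], reverse=True)[:top_k]


# Toy example: chunk 2 scores well in both retrievers, so with alpha=0.5 it ranks first.
print(fuse_scores({0: 0.9, 2: 0.8}, {1: 7.2, 2: 6.0}, alpha=0.5, top_k=3))
```

With the default `alpha: 0.5` from `configs/config.yaml`, both retrievers contribute equally; `1.0` reduces to pure vector search and `0.0` to pure BM25.
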
diff --git a/Mini_RAG/app.py b/Mini_RAG/app.py new file mode 100644 index 0000000000000000000000000000000000000000..4720bf8d24ffaf91dda540a08e105848a3b24a27 --- /dev/null +++ b/Mini_RAG/app.py @@ -0,0 +1,19 @@ +import yaml +from service.rag_service import RAGService +from ui.app import GradioApp +from utils.logger import setup_logger + + +def main(): + setup_logger() + + with open('configs/config.yaml', 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + + rag_service = RAGService(config) + app = GradioApp(rag_service) + app.launch() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Mini_RAG/assets/1.png b/Mini_RAG/assets/1.png new file mode 100644 index 0000000000000000000000000000000000000000..def6f1b1f4f00ea6b493b7cdece6a228480c5934 --- /dev/null +++ b/Mini_RAG/assets/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3cfd5a1e1f78574b7ad19ba82bc14aefedc30e198293a36b34fddbe58ae9572a +size 117169 diff --git a/Mini_RAG/assets/2.png b/Mini_RAG/assets/2.png new file mode 100644 index 0000000000000000000000000000000000000000..04960ccb2184a5a930b0bd4be061c294a142b68e --- /dev/null +++ b/Mini_RAG/assets/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:488e38ddb8e075942a427f604b290abc27121023b44ceab74b2447ccd5d39ecd +size 144857 diff --git a/Mini_RAG/configs/config.yaml b/Mini_RAG/configs/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..249028c3740d4f6e40b7e12d273319203640579d --- /dev/null +++ b/Mini_RAG/configs/config.yaml @@ -0,0 +1,26 @@ +# 路径配置 +storage_path: "./storage" +vector_store: + index_path: "./storage/faiss_index/index.faiss" + metadata_path: "./storage/faiss_index/chunks.pkl" + +# 切分器配置 +splitter: + parent_chunk_size: 800 + parent_chunk_overlap: 100 + child_chunk_size: 250 + +# 模型配置 +models: + embedding: "moka-ai/m3e-base" + reranker: "BAAI/bge-reranker-base" + llm_generator: "Qwen/Qwen3-0.6B" # Qwen/Qwen1.5-1.8B-Chat or Qwen/Qwen3-0.6B + + +# 检索与生成参数 +retrieval: + hybrid_search_alpha: 0.5 # 混合检索中向量权(0-1), 关键词权为1-alpha + retrieval_top_k: 20 + rerank_top_k: 5 +generation: + max_new_tokens: 512 diff --git a/Mini_RAG/core/__init__.py b/Mini_RAG/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f5a12faa99758192ecc4ed3fc22c9249232e86 --- /dev/null +++ b/Mini_RAG/core/__init__.py @@ -0,0 +1 @@ + diff --git a/Mini_RAG/core/__pycache__/__init__.cpython-39.pyc b/Mini_RAG/core/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..490fc55d5de07f75e7b79a8526715e75207b5e33 Binary files /dev/null and b/Mini_RAG/core/__pycache__/__init__.cpython-39.pyc differ diff --git a/Mini_RAG/core/__pycache__/embedder.cpython-39.pyc b/Mini_RAG/core/__pycache__/embedder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b33e5057272a15b09b1bf751a6082bfd8a72e68 Binary files /dev/null and b/Mini_RAG/core/__pycache__/embedder.cpython-39.pyc differ diff --git a/Mini_RAG/core/__pycache__/llm_interface.cpython-39.pyc b/Mini_RAG/core/__pycache__/llm_interface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17317a98ba48afdddb28d1cdcf421f61843a88ed Binary files /dev/null and b/Mini_RAG/core/__pycache__/llm_interface.cpython-39.pyc differ diff --git a/Mini_RAG/core/__pycache__/loader.cpython-39.pyc b/Mini_RAG/core/__pycache__/loader.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0e832df0a9e3abbeb0443ef56884f79577401fee Binary files /dev/null and b/Mini_RAG/core/__pycache__/loader.cpython-39.pyc differ diff --git a/Mini_RAG/core/__pycache__/schema.cpython-39.pyc b/Mini_RAG/core/__pycache__/schema.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d379d281390f817f93027d9084d701b0040cdd5a Binary files /dev/null and b/Mini_RAG/core/__pycache__/schema.cpython-39.pyc differ diff --git a/Mini_RAG/core/__pycache__/splitter.cpython-39.pyc b/Mini_RAG/core/__pycache__/splitter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7bf2fd07fd733b59cd01dc4b4785fd01d31c20f Binary files /dev/null and b/Mini_RAG/core/__pycache__/splitter.cpython-39.pyc differ diff --git a/Mini_RAG/core/__pycache__/vector_store.cpython-39.pyc b/Mini_RAG/core/__pycache__/vector_store.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01876546ce639ba8b2e37b0bf6b5a77b1bf79bf7 Binary files /dev/null and b/Mini_RAG/core/__pycache__/vector_store.cpython-39.pyc differ diff --git a/Mini_RAG/core/embedder.py b/Mini_RAG/core/embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..8d85c5cef0adb20c2de3013eb0290f62252e02f0 --- /dev/null +++ b/Mini_RAG/core/embedder.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/24 11:50 +# @Author : hukangzhe +# @File : embedder.py +# @Description : + +from sentence_transformers import SentenceTransformer +from typing import List +import numpy as np + + +class EmbeddingModel: + def __init__(self, model_name: str): + self.embedding_model = SentenceTransformer(model_name) + + def embed(self, texts: List[str], batch_size: int = 32) -> np.ndarray: + return self.embedding_model.encode(texts, batch_size=batch_size, convert_to_numpy=True) + diff --git a/Mini_RAG/core/llm_interface.py b/Mini_RAG/core/llm_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..bff2ea38acbc6524c9a8e2844464abaea3ba23a5 --- /dev/null +++ b/Mini_RAG/core/llm_interface.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/29 19:54 +# @Author : hukangzhe +# @File : generator.py +# @Description : 负责生成答案模块 +import os +import queue +import logging +import threading + +import torch +from typing import Dict, List, Tuple, Generator +from sentence_transformers import CrossEncoder +from .schema import Document, Chunk +from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer + +class ThinkStreamer(TextStreamer): + def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool =True, **decode_kwargs): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.is_thinking = True + self.think_end_token_id = self.tokenizer.encode("", add_special_tokens=False)[0] + self.output_queue = queue.Queue() + + def on_finalized_text(self, text: str, stream_end: bool = False): + self.output_queue.put(text) + if stream_end: + self.output_queue.put(None) # 发送结束信号 + + def __iter__(self): + return self + + def __next__(self): + value = self.output_queue.get() + if value is None: + raise StopIteration() + return value + + def generate_output(self) -> Generator[Tuple[str, str], None, None]: + """ + 分离Think和回答 + :return: + """ + full_decode_text = "" + already_yielded_len = 0 + for text_chunk in self: + if not self.is_thinking: + yield "answer", text_chunk + continue + + full_decode_text += text_chunk + tokens = 
self.tokenizer.encode(full_decode_text, add_special_tokens=False) + + if self.think_end_token_id in tokens: + spilt_point = tokens.index(self.think_end_token_id) + think_part_tokens = tokens[:spilt_point] + thinking_text = self.tokenizer.decode(think_part_tokens) + + answer_part_tokens = tokens[spilt_point:] + answer_text = self.tokenizer.decode(answer_part_tokens) + remaining_thinking = thinking_text[already_yielded_len:] + if remaining_thinking: + yield "thinking", remaining_thinking + + if answer_text: + yield "answer", answer_text + + self.is_thinking = False + already_yielded_len = len(thinking_text) + len(self.tokenizer.decode(self.think_end_token_id)) + else: + yield "thinking", text_chunk + already_yielded_len += len(text_chunk) + + +class QueueTextStreamer(TextStreamer): + def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool = True, **decode_kwargs): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.output_queue = queue.Queue() + + def on_finalized_text(self, text: str, stream_end: bool = False): + """Puts text into the queue; sends None as a sentinel value to signal the end.""" + self.output_queue.put(text) + if stream_end: + self.output_queue.put(None) + + def __iter__(self): + return self + + def __next__(self): + value = self.output_queue.get() + if value is None: + raise StopIteration() + return value + + +class LLMInterface: + def __init__(self, config: dict): + self.config = config + self.reranker = CrossEncoder(config['models']['reranker']) + self.generator_new_tokens = config['generation']['max_new_tokens'] + self.device =torch.device("cuda" if torch.cuda.is_available() else "cpu") + + generator_name = config['models']['llm_generator'] + logging.info(f"Initializing generator {generator_name}") + self.generator_tokenizer = AutoTokenizer.from_pretrained(generator_name) + self.generator_model = AutoModelForCausalLM.from_pretrained( + generator_name, + torch_dtype="auto", + device_map="auto") + + def rerank(self, query: str, docs: List[Document]) -> List[Document]: + pairs = [[query, doc.text] for doc in docs] + scores = self.reranker.predict(pairs) + ranked_docs = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True) + return [doc for doc, score in ranked_docs] + + def _threaded_generate(self, streamer: QueueTextStreamer, generation_kwargs: dict): + """ + 一个包装函数,将 model.generate 放入 try...finally 块中。 + """ + try: + self.generator_model.generate(**generation_kwargs) + finally: + # 无论成功还是失败,都确保在最后发送结束信号 + streamer.output_queue.put(None) + + def generate_answer(self, query: str, context_docs: List[Document]) -> str: + context_str = "" + for doc in context_docs: + context_str += f"Source: {os.path.basename(doc.metadata.get('source', ''))}, Page: {doc.metadata.get('page', 'N/A')}\n" + context_str += f"Content: {doc.text}\n\n" + # content设置为英文,回答则为英文 + messages = [ + {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"}, + {"role": "user", "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题:{query}"} + ] + prompt = self.generator_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(self.device) + output = self.generator_model.generate(**inputs, + max_new_tokens=self.generator_new_tokens, num_return_sequences=1, + eos_token_id=self.generator_tokenizer.eos_token_id) + generated_ids = output[0][inputs["input_ids"].shape[1]:] + answer = self.generator_tokenizer.decode(generated_ids, skip_special_tokens=True).strip() + return answer + + 
def generate_answer_stream(self, query: str, context_docs: List[Document]) -> Generator[str, None, None]: + context_str = "" + for doc in context_docs: + context_str += f"Content: {doc.text}\n\n" + + messages = [ + {"role": "system", + "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"}, + {"role": "user", + "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题: {query}"} + ] + + prompt = self.generator_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device) + + streamer = QueueTextStreamer(self.generator_tokenizer, skip_prompt=True) + + generation_kwargs = dict( + **model_inputs, + max_new_tokens=self.generator_new_tokens, + streamer=streamer, + pad_token_id=self.generator_tokenizer.eos_token_id, + ) + + thread = threading.Thread(target=self._threaded_generate, args=(streamer,generation_kwargs,)) + thread.start() + for new_text in streamer: + if new_text is not None: + yield new_text + + def generate_answer_stream_split(self, query: str, context_docs: List[Document]) -> Generator[Tuple[str, str], None, None]: + """分离思考和回答的流式输出""" + context_str = "" + for doc in context_docs: + context_str += f"Content: {doc.text}\n\n" + + messages = [ + {"role": "system", + "content": "You are a helpful assistant. Please answer the question based on the provided context. First, think through the process in tags, then provide the final answer."}, + {"role": "user", + "content": f"Context:\n---\n{context_str}\n---\nBased on the context above, please answer the question: {query}"} + ] + + prompt = self.generator_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True + ) + model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device) + + streamer = ThinkStreamer(self.generator_tokenizer, skip_prompt=True) + + generation_kwargs = dict( + **model_inputs, + max_new_tokens=self.generator_new_tokens, + streamer=streamer + ) + + thread = threading.Thread(target=self.generator_model.generate, kwargs=generation_kwargs) + thread.start() + + yield from streamer.generate_output() + + + + diff --git a/Mini_RAG/core/loader.py b/Mini_RAG/core/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..169895436c9119c68856abfc9cafbc8055a6c078 --- /dev/null +++ b/Mini_RAG/core/loader.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/24 19:51 +# @Author : hukangzhe +# @File : loader.py +# @Description : 文档加载模块,不在像reg_mini返回长字符串,而是返回一个document对象列表,每个document都带有源文件信息 +import logging +from .schema import Document +from typing import List +import fitz +import os + + +class DocumentLoader: + + @staticmethod + def _load_pdf(file_path): + logging.info(f"Loading PDF file from {file_path}") + try: + with fitz.open(file_path) as f: + text = "".join(page.get_text() for page in f) + logging.info(f"Successfully loaded {len(f)} pages.") + return text + except Exception as e: + logging.error(f"Failed to load PDF {file_path}: {e}") + return None + + +class MultiDocumentLoader: + def __init__(self, paths: List[str]): + self.paths = paths + + def load(self) -> List[Document]: + docs = [] + for file_path in self.paths: + file_extension = os.path.splitext(file_path)[1].lower() + if file_extension == '.pdf': + docs.extend(self._load_pdf(file_path)) + elif file_extension == '.txt': + docs.extend(self._load_txt(file_path)) + else: + logging.warning(f"Unsupported file 
type:{file_extension}. Skipping {file_path}") + + return docs + + def _load_pdf(self, file_path: str) -> List[Document]: + logging.info(f"Loading PDF file from {file_path}") + try: + pdf_docs = [] + with fitz.open(file_path) as doc: + for i, page in enumerate(doc): + pdf_docs.append(Document( + text=page.get_text(), + metadata={'source': file_path, 'page': i + 1} + )) + return pdf_docs + except Exception as e: + logging.error(f"Failed to load PDF {file_path}: {e}") + return [] + + def _load_txt(self, file_path: str) -> List[Document]: + logging.info(f"Loading txt file from {file_path}") + try: + txt_docs = [] + with open(file_path, 'r', encoding="utf-8") as f: + txt_docs.append(Document( + text=f.read(), + metadata={'source': file_path} + )) + return txt_docs + except Exception as e: + logging.error(f"Failed to load txt {file_path}:{e}") + return [] diff --git a/Mini_RAG/core/schema.py b/Mini_RAG/core/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..647924b047ac0f8f8225a5d85ecb3191f499ef7d --- /dev/null +++ b/Mini_RAG/core/schema.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/24 21:11 +# @Author : hukangzhe +# @File : schema.py +# @Description : 不直接处理纯文本字符串. 引入一个标准化的数据结构来承载文本块,既有内容,也有元数据。 + +from dataclasses import dataclass, field +from typing import Dict, Any + + +@dataclass +class Document: + text: str + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class Chunk(Document): + parent_id: int = None \ No newline at end of file diff --git a/Mini_RAG/core/splitter.py b/Mini_RAG/core/splitter.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8df02e6bd6fddbb1a8157211ba9983a10dfe63 --- /dev/null +++ b/Mini_RAG/core/splitter.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/25 19:52 +# @Author : hukangzhe +# @File : splitter.py +# @Description : 负责切分文本的模块 + +import logging +from typing import List, Dict, Tuple +from .schema import Document, Chunk + + +class SemanticRecursiveSplitter: + def __init__(self, chunk_size: int=500, chunk_overlap:int = 50, separators: List[str] = None): + """ + 一个真正实现递归的语义文本切分器。 + :param chunk_size: 每个文本块的目标大小。 + :param chunk_overlap: 文本块之间的重叠大小。 + :param separators: 用于切分的语义分隔符列表,按优先级从高到低排列。 + """ + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + if self.chunk_size <= self.chunk_overlap: + raise ValueError("Chunk overlap must be smaller than chunk size.") + + self.separators = separators if separators else ['\n\n', '\n', " ", ""] # 默认分割符 + + def text_split(self, text: str) -> List[str]: + """ + 切分入口 + :param text: + :return: + """ + logging.info("Starting semantic recursive splitting...") + final_chunks = self._split(text, self.separators) + logging.info(f"Text successfully split into {len(final_chunks)} chunks.") + return final_chunks + + def _split(self, text: str, separators: List[str]) -> List[str]: + final_chunks = [] + # 1. 如果文本足够小,直接返回 + if len(text) < self.chunk_size: + return [text] + # 2. 先尝试最高优先的分割符 + cur_separator = separators[0] + + # 3. 
如果可以分割 + if cur_separator in text: + # 分割成多个小部分 + parts = text.split(cur_separator) + + buffer="" # 用来合并小部分 + for i, part in enumerate(parts): + # 如果小于chunk_size,就再加一小部分,使buffer接近chunk_size + if len(buffer) + len(part) + len(cur_separator) <= self.chunk_size: + buffer += part+cur_separator + else: + # 如果buffer 不为空 + if buffer: + final_chunks.append(buffer) + # 如果当前part就已经超过chunk_size + if len(part) > self.chunk_size: + # 递归调用下一级 + sub_chunks = self._split(part, separators = separators[1:]) + final_chunks.extend(sub_chunks) + else: # 成为新的缓冲区 + buffer = part + cur_separator + + if buffer: # 最后一部分的缓冲区 + final_chunks.append(buffer.strip()) + + else: + # 4. 使用下一级分隔符 + final_chunks = self._split(text, separators[1:]) + + # 处理重叠 + if self.chunk_overlap > 0: + return self._handle_overlap(final_chunks) + else: + return final_chunks + + def _handle_overlap(self, final_chunks: List[str]) -> List[str]: + overlap_chunks = [] + if not final_chunks: + return [] + overlap_chunks.append(final_chunks[0]) + for i in range(1, len(final_chunks)): + pre_chunk = overlap_chunks[-1] + cur_chunk = final_chunks[i] + # 从前一个chunk取出重叠部分与当前chunk合并 + overlap_part = pre_chunk[-self.chunk_overlap:] + overlap_chunks.append(overlap_part+cur_chunk) + + return overlap_chunks + + +class HierarchicalSemanticSplitter: + """ + 结合了层次化(父/子)和递归语义分割策略。确保在创建父块和子块时,遵循文本的自然语义边界。 + """ + def __init__(self, + parent_chunk_size: int = 800, + parent_chunk_overlap: int = 100, + child_chunk_size: int = 250, + separators: List[str] = None): + if parent_chunk_overlap >= parent_chunk_size: + raise ValueError("Parent chunk overlap must be smaller than parent chunk size.") + if child_chunk_size >= parent_chunk_size: + raise ValueError("Child chunk size must be smaller than parent chunk size.") + + self.parent_chunk_size = parent_chunk_size + self.parent_chunk_overlap = parent_chunk_overlap + self.child_chunk_size = child_chunk_size + self.separators = separators or ["\n\n", "\n", "。", ". ", "!", "!", "?", "?", " ", ""] + + def _recursive_semantic_split(self, text: str, chunk_size: int) -> List[str]: + """ + 优先考虑语义边界 + """ + if len(text) <= chunk_size: + return [text] + + for sep in self.separators: + split_point = text.rfind(sep, 0, chunk_size) + if split_point != -1: + break + else: + split_point = chunk_size + + chunk1 = text[:split_point] + remaining_text = text[split_point:].lstrip() # 删除剩余部分的前空格 + + # 递归拆分剩余文本 + # 分隔符将添加回第一个块以保持上下文 + if remaining_text: + return [chunk1 + (sep if sep in " \n" else "")] + self._recursive_semantic_split(remaining_text, chunk_size) + else: + return [chunk1] + + def _apply_overlap(self, chunks: List[str], overlap: int) -> List[str]: + """处理重叠部分chunk""" + if not overlap or len(chunks) <= 1: + return chunks + + overlapped_chunks = [chunks[0]] + for i in range(1, len(chunks)): + # 从前一个chunk中获取最后的“重叠”字符 + overlap_content = chunks[i - 1][-overlap:] + overlapped_chunks.append(overlap_content + chunks[i]) + + return overlapped_chunks + + def split_documents(self, documents: List[Document]) -> Tuple[Dict[int, Document], List[Chunk]]: + """ + 两次切分 + :param documents: + :return: + - parent documents: {parent_id: Document} + - child chunks: [Chunk, Chunk, ...] + """ + parent_docs_dict: Dict[int, Document] = {} + child_chunks_list: List[Chunk] = [] + parent_id_counter = 0 + + logging.info("Starting robust hierarchical semantic splitting...") + + for doc in documents: + # === PASS 1: 创建父chunks === + # 1. 将整个文档text分割成大的语义chunks + initial_parent_chunks = self._recursive_semantic_split(doc.text, self.parent_chunk_size) + + # 2. 
父chunks进行重叠处理 + overlapped_parent_texts = self._apply_overlap(initial_parent_chunks, self.parent_chunk_overlap) + + for p_text in overlapped_parent_texts: + parent_doc = Document(text=p_text, metadata=doc.metadata.copy()) + parent_docs_dict[parent_id_counter] = parent_doc + + # === PASS 2: Create Child Chunks from each Parent === + child_texts = self._recursive_semantic_split(p_text, self.child_chunk_size) + + for c_text in child_texts: + child_metadata = doc.metadata.copy() + child_metadata['parent_id'] = parent_id_counter + child_chunk = Chunk(text=c_text, metadata=child_metadata, parent_id=parent_id_counter) + child_chunks_list.append(child_chunk) + + parent_id_counter += 1 + + logging.info( + f"Splitting complete. Created {len(parent_docs_dict)} parent chunks and {len(child_chunks_list)} child chunks.") + return parent_docs_dict, child_chunks_list diff --git a/Mini_RAG/core/test_qw.py b/Mini_RAG/core/test_qw.py new file mode 100644 index 0000000000000000000000000000000000000000..ac3f83f1b3c1366dee367520cf6c1129b9be3535 --- /dev/null +++ b/Mini_RAG/core/test_qw.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/29 20:03 +# @Author : hukangzhe +# @File : test_qw.py +# @Description : 测试两个模型(qwen1.5 qwen3)的两种输出方式(full or stream)是否正确 +import os +import queue +import logging +import threading +import torch +from typing import Tuple, Generator +from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer + +class ThinkStreamer(TextStreamer): + def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool =True, **decode_kwargs): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.is_thinking = True + self.think_end_token_id = self.tokenizer.encode("", add_special_tokens=False)[0] + self.output_queue = queue.Queue() + + def on_finalized_text(self, text: str, stream_end: bool = False): + self.output_queue.put(text) + if stream_end: + self.output_queue.put(None) # 发送结束信号 + + def __iter__(self): + return self + + def __next__(self): + value = self.output_queue.get() + if value is None: + raise StopIteration() + return value + + def generate_output(self) -> Generator[Tuple[str, str], None, None]: + full_decode_text = "" + already_yielded_len = 0 + for text_chunk in self: + if not self.is_thinking: + yield "answer", text_chunk + continue + + full_decode_text += text_chunk + tokens = self.tokenizer.encode(full_decode_text, add_special_tokens=False) + + if self.think_end_token_id in tokens: + spilt_point = tokens.index(self.think_end_token_id) + think_part_tokens = tokens[:spilt_point] + thinking_text = self.tokenizer.decode(think_part_tokens) + + answer_part_tokens = tokens[spilt_point:] + answer_text = self.tokenizer.decode(answer_part_tokens) + remaining_thinking = thinking_text[already_yielded_len:] + if remaining_thinking: + yield "thinking", remaining_thinking + + if answer_text: + yield "answer", answer_text + + self.is_thinking = False + already_yielded_len = len(thinking_text) + len(self.tokenizer.decode(self.think_end_token_id)) + else: + yield "thinking", text_chunk + already_yielded_len += len(text_chunk) + + + +class LLMInterface: + def __init__(self, model_name: str= "Qwen/Qwen3-0.6B"): + logging.info(f"Initializing generator {model_name}") + self.generator_tokenizer = AutoTokenizer.from_pretrained(model_name) + self.generator_model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype="auto", + device_map="auto") + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + 
def generate_answer(self, query: str, context_str: str) -> str: + messages = [ + {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"}, + {"role": "user", "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题:{query}"} + ] + prompt = self.generator_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(self.device) + output = self.generator_model.generate(**inputs, + max_new_tokens=256, num_return_sequences=1, + eos_token_id=self.generator_tokenizer.eos_token_id) + generated_ids = output[0][inputs["input_ids"].shape[1]:] + answer = self.generator_tokenizer.decode(generated_ids, skip_special_tokens=True).strip() + return answer + + def generate_answer_stream(self, query: str, context_str: str) -> Generator[Tuple[str, str], None, None]: + """Generates an answer as a stream of (state, content) tuples.""" + messages = [ + {"role": "system", + "content": "You are a helpful assistant. Please answer the question based on the provided context. First, think through the process in tags, then provide the final answer."}, + {"role": "user", + "content": f"Context:\n---\n{context_str}\n---\nBased on the context above, please answer the question: {query}"} + ] + + # Use the template that enables thinking for Qwen models + prompt = self.generator_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True + ) + model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device) + + streamer = ThinkStreamer(self.generator_tokenizer, skip_prompt=True) + + generation_kwargs = dict( + **model_inputs, + max_new_tokens=512, + streamer=streamer + ) + + thread = threading.Thread(target=self.generator_model.generate, kwargs=generation_kwargs) + thread.start() + + yield from streamer.generate_output() + + +# if __name__ == "__main__": +# qwen = LLMInterface("Qwen/Qwen3-0.6B") +# answer = qwen.generate_answer("儒家思想的创始人是谁?", "中国传统哲学以儒家、道家和法家为主要流派。儒家思想由孔子创立,强调“仁”、“义”、“礼”、“智”、“信”,主张修身齐家治国平天下,对中国社会产生了深远的影响。其核心价值观如“己所不欲,勿施于人”至今仍具有普世意义。"+ +# +# "道家思想以老子和庄子为代表,主张“道法自然”,追求人与自然的和谐统一,强调无为而治、清静无为。道家思想对中国人的审美情趣、艺术创作以及养生之道都有着重要的影响。"+ +# +# "法家思想以韩非子为集大成者,主张以法治国,强调君主的权威和法律的至高无上。尽管法家思想在历史上曾被用于强化中央集权,但其对建立健全的法律体系也提供了重要的理论基础。") +# +# print(answer) diff --git a/Mini_RAG/core/vector_store.py b/Mini_RAG/core/vector_store.py new file mode 100644 index 0000000000000000000000000000000000000000..6792e7151477341db43f19e148c5734d94e95bab --- /dev/null +++ b/Mini_RAG/core/vector_store.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/27 19:52 +# @Author : hukangzhe +# @File : retriever.py +# @Description : 负责向量化、存储、检索的模块 +import os +import faiss +import numpy as np +import pickle +import logging +from rank_bm25 import BM25Okapi +from typing import List, Dict, Tuple +from .schema import Document, Chunk + + +class HybridVectorStore: + def __init__(self, config: dict, embedder): + self.config = config["vector_store"] + self.embedder = embedder + self.faiss_index = None + self.bm25_index = None + self.parent_docs: Dict[int, Document] = {} + self.child_chunks: List[Chunk] = [] + + def build(self, parent_docs: Dict[int, Document], child_chunks: List[Chunk]): + self.parent_docs = parent_docs + self.child_chunks = child_chunks + + # Build Faiss index + child_text = [child.text for child in child_chunks] + embeddings = self.embedder.embed(child_text) + dim = embeddings.shape[1] + self.faiss_index = faiss.IndexFlatL2(dim) + 
self.faiss_index.add(embeddings) + logging.info(f"FAISS index built with {len(child_chunks)} vectors.") + + # Build BM25 index + tokenize_chunks = [doc.text.split(" ") for doc in child_chunks] + self.bm25_index = BM25Okapi(tokenize_chunks) + logging.info(f"BM25 index built for {len(child_chunks)} documents.") + + self.save() + + def search(self, query: str, top_k: int , alpha: float) -> List[Tuple[int, float]]: + # Vector Search + query_embedding = self.embedder.embed([query]) + distances, indices = self.faiss_index.search(query_embedding, k=top_k) + vector_scores = {idx : 1.0/(1.0 + dist) for idx, dist in zip(indices[0], distances[0])} + + # BM25 Search + tokenize_query = query.split(" ") + bm25_scores = self.bm25_index.get_scores(tokenize_query) + bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k] + bm25_scores = {idx: bm25_scores[idx] for idx in bm25_top_indices} + + # Hybrid Search + all_indices = set(vector_scores.keys()) | set(bm25_scores.keys()) # 求并集 + hybrid_scors = {} + + # Normalization + max_v_score = max(vector_scores.values()) if vector_scores else 1.0 + max_b_score = max(bm25_scores.values()) if bm25_scores else 1.0 + for idx in all_indices: + v_score = (vector_scores.get(idx, 0))/max_v_score + b_score = (bm25_scores.get(idx, 0))/max_b_score + hybrid_scors[idx] = alpha * v_score + (1 - alpha) * b_score + + sorted_indices = sorted(hybrid_scors.items(), key=lambda item: item[1], reverse=True)[:top_k] + return sorted_indices + + def get_chunks(self, indices: List[int]) -> List[Chunk]: + return [self.child_chunks[i] for i in indices] + + def get_parent_docs(self, chunks: List[Chunk]) -> List[Document]: + parent_ids = sorted(list(set(chunk.parent_id for chunk in chunks))) + return [self.parent_docs[pid] for pid in parent_ids] + + def save(self): + index_path = self.config['index_path'] + metadata_path = self.config['metadata_path'] + + os.makedirs(os.path.dirname(index_path), exist_ok=True) + os.makedirs(os.path.dirname(metadata_path), exist_ok=True) + logging.info(f"Saving FAISS index to: {index_path}") + try: + faiss.write_index(self.faiss_index, index_path) + except Exception as e: + logging.error(f"Failed to save FAISS index: {e}") + raise + + logging.info(f"Saving metadata data to: {metadata_path}") + try: + with open(metadata_path, 'wb') as f: + metadata = { + 'parent_docs': self.parent_docs, + 'child_chunks': self.child_chunks, + 'bm25_index': self.bm25_index + } + pickle.dump(metadata, f) + except Exception as e: + logging.error(f"Failed to save metadata: {e}") + raise + + logging.info("Vector store saved successfully.") + + def load(self) -> bool: + """ + 从磁盘加载整个向量存储状态,成功时返回 True,失败时返回 False。 + """ + index_path = self.config['index_path'] + metadata_path = self.config['metadata_path'] + + if not os.path.exists(index_path) or not os.path.exists(metadata_path): + logging.warning("Index files not found. 
Cannot load vector store.") + return False + + logging.info(f"Loading vector store from disk...") + try: + # Load FAISS index + logging.info(f"Loading FAISS index from: {index_path}") + self.faiss_index = faiss.read_index(index_path) + + # Load metadata + logging.info(f"Loading metadata from: {metadata_path}") + with open(metadata_path, 'rb') as f: + metadata = pickle.load(f) + self.parent_docs = metadata['parent_docs'] + self.child_chunks = metadata['child_chunks'] + self.bm25_index = metadata['bm25_index'] + + logging.info("Vector store loaded successfully.") + return True + + except Exception as e: + logging.error(f"Failed to load vector store from disk: {e}") + self.faiss_index = None + self.bm25_index = None + self.parent_docs = {} + self.child_chunks = [] + return False + + + + + + + + + + + + + + + + + diff --git a/Mini_RAG/data/chinese_document.pdf b/Mini_RAG/data/chinese_document.pdf new file mode 100644 index 0000000000000000000000000000000000000000..cf3c783f36f435c8a29f5d77e52efceb04852205 --- /dev/null +++ b/Mini_RAG/data/chinese_document.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0bde4aba4245bb013438240cabf4c149d8aec76536f7a1459a165db02ec3695c +size 317035 diff --git a/Mini_RAG/data/english_document.pdf b/Mini_RAG/data/english_document.pdf new file mode 100644 index 0000000000000000000000000000000000000000..214d2ea71708ac3cd25d9dc02c8a9213fdefe17c --- /dev/null +++ b/Mini_RAG/data/english_document.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5aa48e1b7d30d8cf2ac87cae6e2fee73e2de4d87178d0878d0d34b499fc7b26d +size 166331 diff --git a/Mini_RAG/main.py b/Mini_RAG/main.py new file mode 100644 index 0000000000000000000000000000000000000000..032ece81fc60f6a0415af85f2359e0b1511ad874 --- /dev/null +++ b/Mini_RAG/main.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/5/1 14:00 +# @Author : hukangzhe +# @File : main.py +# @Description : +import yaml +from service.rag_service import RAGService +from ui.app import GradioApp +from utils.logger import setup_logger + + +def main(): + setup_logger() + + with open('configs/config.yaml', 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + + rag_service = RAGService(config) + app = GradioApp(rag_service) + app.launch() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Mini_RAG/rag_mini.py b/Mini_RAG/rag_mini.py new file mode 100644 index 0000000000000000000000000000000000000000..813c892b4ba93a3bfd33089fc2baea497b6fab59 --- /dev/null +++ b/Mini_RAG/rag_mini.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/24 14:04 +# @Author : hukangzhe +# @File : rag_core.py +# @Description :非常简单的RAG系统 + +import PyPDF2 +import fitz +from sentence_transformers import SentenceTransformer, CrossEncoder +from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM +import numpy as np +import faiss +import torch + + +class RAGSystem: + def __init__(self, pdf_path): + self.pdf_path = pdf_path + self.texts = self._load_and_spilt_pdf() + self.embedder = SentenceTransformer('moka-ai/m3e-base') + self.reranker = CrossEncoder('BAAI/bge-reranker-base') # 加载一个reranker模型 + self.vector_store = self._create_vector_store() + print("3. 
Initializing Generator Model...") + model_name = "Qwen/Qwen1.5-1.8B-Chat" + + # 检查是否有可用的GPU + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f" - Using device: {device}") + + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + # 注意:对于像Qwen这样的模型,我们通常使用 AutoModelForCausalLM + model = AutoModelForCausalLM.from_pretrained(model_name).to(device) + + self.generator = pipeline( + 'text-generation', + model=model, + tokenizer=self.tokenizer, + device=device + ) + + # 1. 文档加载 & 2.文本切分 (为了简化,合在一起) + def _load_and_spilt_pdf(self): + print("1. Loading and splitting PDF...") + full_text = "" + with fitz.open(self.pdf_path) as doc: + for page in doc: + full_text += page.get_text() + + # 非常基础的切分:根据固定大小 + chunk_size = 500 + overlap = 50 + chunks = [full_text[i: i+chunk_size] for i in range(0, len(full_text), chunk_size-overlap)] + print(f" - Splitted into {len(chunks)} chunks.") + return chunks + + # 3. 文本向量化 & 向量存储 + def _create_vector_store(self): + print("2. Creating vector store...") + # embedding + embeddings = self.embedder.encode(self.texts) + + # Storing with faiss + dim = embeddings.shape[1] + index = faiss.IndexFlatL2(dim) # 使用L2距离进行相似度计算 + index.add(np.array(embeddings)) + print(" - Created vector store") + return index + + # 4.检索 + def retrieve(self, query, k=3): + print(f"3. Retrieving top {k} relevant chunks for query: '{query}' ") + query_embeddings = self.embedder.encode([query]) + + distances, indices = self.vector_store.search(np.array(query_embeddings), k=k) + retrieved_chunks = [self.texts[i] for i in indices[0]] + print(" - Retrieval complete.") + return retrieved_chunks + + # 5.生成 + def generate(self, query, context_chunks): + print("4. Generate answer...") + context = "\n".join(context_chunks) + + messages = [ + {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"}, + {"role": "user", "content": f"上下文:\n---\n{context}\n---\n请根据以上上下文回答这个问题:{query}"} + ] + prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + # print("Final Prompt:\n", prompt) + # print("Prompt token length:", len(self.tokenizer.encode(prompt))) + + result = self.generator(prompt, max_new_tokens=200, num_return_sequences=1, + eos_token_id=self.tokenizer.eos_token_id) + + print(" - Generation complete.") + # print("Raw results:", result) + # 提取生成的文本 + # 注意:Qwen模型返回的文本包含了prompt,我们需要从中提取出答案部分 + full_response = result[0]["generated_text"] + answer = full_response[len(prompt):].strip() # 从prompt之后开始截取 + + # print("Final Answer:", repr(answer)) + return answer + + # 优化1 + def rerank(self, query, chunks): + print(" - Reranking retrieved chunks...") + pairs = [[query, chunk] for chunk in chunks] + scores = self.reranker.predict(pairs) + + # 将chunks和scores打包,并按score降序排序 + ranked_chunks = sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True) + return [chunk for chunk, score in ranked_chunks] + + def query(self, query_text): + # 1. 检索(可以检索更多结果,如top 10) + retrieved_chunks = self.retrieve(query_text, k=10) + + # 2. 重排(从10个中选出最相关的3个) + reranked_chunks = self.rerank(query_text, retrieved_chunks) + top_k_reranked = reranked_chunks[:3] + + answer = self.generate(query_text, top_k_reranked) + return answer + + +def main(): + # 确保你的data文件夹里有一个叫做sample.pdf的文件 + pdf_path = 'data/chinese_document.pdf' + + print("Initializing RAG System...") + rag_system = RAGSystem(pdf_path) + print("\nRAG System is ready. 
You can start asking questions.") + print("Type 'q' to quit.") + + while True: + user_query = input("\nYour Question: ") + if user_query.lower() == 'q': + break + + answer = rag_system.query(user_query) + print("\nAnswer:", answer) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Mini_RAG/requirements.txt b/Mini_RAG/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c5a1afd4a08714741949343784a3e878e0e69ec8 --- /dev/null +++ b/Mini_RAG/requirements.txt @@ -0,0 +1,9 @@ +pyyaml +pypdf2 +PyMuPDF +sentence-transformers +faiss-cpu +rank_bm25 +transformers +torch +gradio \ No newline at end of file diff --git a/Mini_RAG/service/__init__.py b/Mini_RAG/service/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Mini_RAG/service/__pycache__/__init__.cpython-39.pyc b/Mini_RAG/service/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e4b381495c7815b79f4444a2a3d2f5cad4027ae Binary files /dev/null and b/Mini_RAG/service/__pycache__/__init__.cpython-39.pyc differ diff --git a/Mini_RAG/service/__pycache__/rag_service.cpython-39.pyc b/Mini_RAG/service/__pycache__/rag_service.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f04ec204633814e2a126918295c729eb3e500c98 Binary files /dev/null and b/Mini_RAG/service/__pycache__/rag_service.cpython-39.pyc differ diff --git a/Mini_RAG/service/rag_service.py b/Mini_RAG/service/rag_service.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fb58099fc1372b269ae32c312e083f330adfe0 --- /dev/null +++ b/Mini_RAG/service/rag_service.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/30 11:50 +# @Author : hukangzhe +# @File : rag_service.py +# @Description : +import logging +import os +from typing import List, Generator, Tuple +from core.schema import Document +from core.embedder import EmbeddingModel +from core.loader import MultiDocumentLoader +from core.splitter import HierarchicalSemanticSplitter +from core.vector_store import HybridVectorStore +from core.llm_interface import LLMInterface + + +class RAGService: + def __init__(self, config: dict): + self.config = config + logging.info("Initializing RAG Service...") + self.embedder = EmbeddingModel(config['models']['embedding']) + self.vector_store = HybridVectorStore(config, self.embedder) + self.llm = LLMInterface(config) + self.is_ready = False # 是否准备好进行查询 + logging.info("RAG Service initialized. Knowledge base is not loaded.") + + def load_knowledge_base(self) -> Tuple[bool, str]: + """ + 尝试从磁盘加载 + Returns: + A tuple (success: bool, message: str) + """ + if self.is_ready: + return True, "Knowledge base is already loaded." + + logging.info("Attempting to load knowledge base from disk...") + success = self.vector_store.load() + if success: + self.is_ready = True + message = "Knowledge base loaded successfully from disk." + logging.info(message) + return True, message + else: + self.is_ready = False + message = "No existing knowledge base found or failed to load. Please build a new one." + logging.warning(message) + return False, message + + def build_knowledge_base(self, file_paths: List[str]) -> Generator[str, None, None]: + self.is_ready = False + yield "Step 1/3: Loading documents..." + loader = MultiDocumentLoader(file_paths) + docs = loader.load() + + yield "Step 2/3: Splitting documents into hierarchical chunks..." 
+ splitter = HierarchicalSemanticSplitter( + parent_chunk_size=self.config['splitter']['parent_chunk_size'], + parent_chunk_overlap=self.config['splitter']['parent_chunk_overlap'], + child_chunk_size=self.config['splitter']['child_chunk_size'] + ) + parent_docs, child_chunks = splitter.split_documents(docs) + + yield "Step 3/3: Building and saving vector index..." + self.vector_store.build(parent_docs, child_chunks) + self.is_ready = True + yield "Knowledge base built and ready!" + + def _get_context_and_sources(self, query: str) -> List[Document]: + if not self.is_ready: + raise Exception("Knowledge base is not ready. Please build it first.") + + # Hybrid Search to get child chunks + retrieved_child_indices_scores = self.vector_store.search( + query, + top_k=self.config['retrieval']['retrieval_top_k'], + alpha=self.config['retrieval']['hybrid_search_alpha'] + ) + retrieved_child_indices = [idx for idx, score in retrieved_child_indices_scores] + retrieved_child_chunks = self.vector_store.get_chunks(retrieved_child_indices) + + # Get Parent Documents + retrieved_parent_docs = self.vector_store.get_parent_docs(retrieved_child_chunks) + + # Rerank Parent Documents + reranked_docs = self.llm.rerank(query, retrieved_parent_docs) + final_context_docs = reranked_docs[:self.config['retrieval']['rerank_top_k']] + + return final_context_docs + + def get_response_full(self, query: str) ->(str, List[Document]): + final_context_docs = self._get_context_and_sources(query) + answer = self.llm.generate_answer(query, final_context_docs) + return answer, final_context_docs + + def get_response_stream(self, query: str) ->(Generator[str, None, None], List[Document]): + final_context_docs = self._get_context_and_sources(query) + answer_generator = self.llm.generate_answer_stream(query, final_context_docs) + return answer_generator, final_context_docs + + def get_context_string(self, context_docs: List[Document]) -> str: + context_str = "引用上下文 (Context Sources):\n\n" + for doc in context_docs: + source_info = f"--- (来源: {os.path.basename(doc.metadata.get('source', ''))}, 页码: {doc.metadata.get('page', 'N/A')}) ---\n" + content = doc.text[:200]+"..." 
if len(doc.text) > 200 else doc.text + context_str += source_info + content + "\n\n" + return context_str.strip() + + + diff --git a/Mini_RAG/storage/faiss_index/chunks.pkl b/Mini_RAG/storage/faiss_index/chunks.pkl new file mode 100644 index 0000000000000000000000000000000000000000..a084a4cd7d4834a6e8929987bce62976f5b9101b --- /dev/null +++ b/Mini_RAG/storage/faiss_index/chunks.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da7fae509a97a93f69b8048e375fca25dd45f0e26a40bff83c30759c5853e473 +size 27663 diff --git a/Mini_RAG/storage/faiss_index/index.faiss b/Mini_RAG/storage/faiss_index/index.faiss new file mode 100644 index 0000000000000000000000000000000000000000..3c2e5a8e2cd09faee8ced8ee789d661f50e1a434 Binary files /dev/null and b/Mini_RAG/storage/faiss_index/index.faiss differ diff --git a/Mini_RAG/ui/__init__.py b/Mini_RAG/ui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Mini_RAG/ui/__pycache__/__init__.cpython-39.pyc b/Mini_RAG/ui/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15bb473a52d09cfd63c269b912e6a93bbf57995a Binary files /dev/null and b/Mini_RAG/ui/__pycache__/__init__.cpython-39.pyc differ diff --git a/Mini_RAG/ui/__pycache__/app.cpython-39.pyc b/Mini_RAG/ui/__pycache__/app.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ceb60199ef7e2d2327109894ee63fa5ae7a69b8f Binary files /dev/null and b/Mini_RAG/ui/__pycache__/app.cpython-39.pyc differ diff --git a/Mini_RAG/ui/app.py b/Mini_RAG/ui/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5990d22ce4ec0ec2019beb57c8fe5a20555225bd --- /dev/null +++ b/Mini_RAG/ui/app.py @@ -0,0 +1,202 @@ +import gradio as gr +import os +from typing import List, Tuple + +from service.rag_service import RAGService + + +class GradioApp: + def __init__(self, rag_service: RAGService): + self.rag_service = rag_service + self._build_ui() + + def _build_ui(self): + with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), + title="Enterprise RAG System") as self.demo: + gr.Markdown("# 企业级RAG智能问答系统 (Enterprise RAG System)") + gr.Markdown("您可以**加载现有知识库**快速开始,或**上传新文档**构建一个全新的知识库。") + + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("### 控制面板 (Control Panel)") + + self.load_kb_button = gr.Button("加载已有知识库 (Load Existing KB)") + + gr.Markdown("
") + + self.file_uploader = gr.File( + label="上传新文档以构建 (Upload New Docs to Build)", + file_count="multiple", + file_types=[".pdf", ".txt"], + interactive=True + ) + self.build_kb_button = gr.Button("构建新知识库 (Build New KB)", variant="primary") + + self.status_box = gr.Textbox( + label="系统状态 (System Status)", + value="系统已初始化,等待加载或构建知识库。", + interactive=False, + lines=4 + ) + + # --- 刚开始隐藏,构建了数据库再显示 --- + with gr.Column(scale=2, visible=False) as self.chat_area: + gr.Markdown("### 对话窗口 (Chat Window)") + self.chatbot = gr.Chatbot(label="RAG Chatbot", bubble_full_width=False, height=500) + self.mode_selector = gr.Radio( + ["流式输出(Streaming)","一次性输出(Full)"], + label="输出模式:(Output Mode)", + value="流式输出(Streaming)" + ) + self.question_box = gr.Textbox(label="您的问题", placeholder="请在此处输入您的问题...", + show_label=False) + with gr.Row(): + self.submit_btn = gr.Button("提交 (Submit)", variant="primary") + self.clear_btn = gr.Button("清空历史 (Clear History)") + + gr.Markdown("---") + self.source_display = gr.Markdown("### 引用来源 (Sources)") + + # --- Event Listeners --- + self.load_kb_button.click( + fn=self._handle_load_kb, + inputs=None, + outputs=[self.status_box, self.chat_area] + ) + + self.build_kb_button.click( + fn=self._handle_build_kb, + inputs=[self.file_uploader], + outputs=[self.status_box, self.chat_area] + ) + + self.submit_btn.click( + fn=self._handle_chat_submission, + inputs=[self.question_box, self.chatbot, self.mode_selector], + outputs=[self.chatbot, self.question_box, self.source_display] + ) + + self.question_box.submit( + fn=self._handle_chat_submission, + inputs=[self.question_box, self.chatbot, self.mode_selector], + outputs=[self.chatbot, self.question_box, self.source_display] + ) + + self.clear_btn.click( + fn=self._clear_chat, + inputs=None, + outputs=[self.chatbot, self.question_box, self.source_display] + ) + + def _handle_load_kb(self): + """处理现有知识库的加载。返回更新字典。""" + success, message = self.rag_service.load_knowledge_base() + if success: + return { + self.status_box: gr.update(value=message), + self.chat_area: gr.update(visible=True) + } + else: + return { + self.status_box: gr.update(value=message), + self.chat_area: gr.update(visible=False) + } + + def _handle_build_kb(self, files: List[str], progress=gr.Progress(track_tqdm=True)): + """构建新知识库,返回更新的字典.""" + if not files: + # --- MODIFIED LINE --- + return { + self.status_box: gr.update(value="错误:请至少上传一个文档。"), + self.chat_area: gr.update(visible=False) + } + + file_paths = [file.name for file in files] + + try: + for status in self.rag_service.build_knowledge_base(file_paths): + progress(0.5, desc=status) + + final_status = "知识库构建完成并已就绪!√" + # --- MODIFIED LINE --- + return { + self.status_box: gr.update(value=final_status), + self.chat_area: gr.update(visible=True) + } + except Exception as e: + error_message = f"构建失败: {e}" + # --- MODIFIED LINE --- + return { + self.status_box: gr.update(value=error_message), + self.chat_area: gr.update(visible=False) + } + + def _handle_chat_submission(self, question: str, history: List[Tuple[str, str]], mode: str): + if not question or not question.strip(): + yield history, "", "### 引用来源 (Sources)\n" + return + + history.append((question, "")) + + try: + # 一次全部输出 + if "Full" in mode: + yield history, "", "### 引用来源 (Sources)\n" + + answer, sources = self.rag_service.get_response_full(question) + # 获取引用内容 + context_string_for_display = self.rag_service.get_context_string(sources) + # 修改格式 + source_text_for_panel = self._format_sources(sources) + #完整内容:引用+回答 + full_response = 
f"{context_string_for_display}\n\n---\n\n**回答 (Answer):**\n{answer}" + history[-1] = (question, full_response) + + yield history, "", source_text_for_panel + + # 流式输出 + else: + answer_generator, sources = self.rag_service.get_response_stream(question) + + context_string_for_display = self.rag_service.get_context_string(sources) + + source_text_for_panel = self._format_sources(sources) + + yield history, "", source_text_for_panel + + response_prefix = f"{context_string_for_display}\n\n---\n\n**回答 (Answer):**\n" + history[-1] = (question, response_prefix) + yield history, "", source_text_for_panel + + answer_log = "" + for text_chunk in answer_generator: + answer_log += text_chunk + history[-1] = (question, response_prefix + answer_log) + yield history, "", source_text_for_panel + + except Exception as e: + error_response = f"处理请求时出错: {e}" + history[-1] = (question, error_response) + yield history, "", "### 引用来源 (Sources)\n" + + def _format_sources(self, sources: List) -> str: + source_text = "### 引用来源 (sources)\n)" + if not sources: + return source_text + + unique_sources = set() + for doc in sources: + source_name = os.path.basename(doc.metadata.get('source', 'Unknown')) + page_num = doc.metadata.get('page', 'N/A') + unique_sources.add(f"- **{source_name}** (Page: {page_num})") + + source_text += "\n".join(sorted(list(unique_sources))) + return source_text + + def _clear_chat(self): + """清理聊天内容""" + return None, "", "### 引用来源 (Sources)\n" + + def launch(self): + self.demo.queue().launch() + diff --git a/Mini_RAG/utils/__init__.py b/Mini_RAG/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Mini_RAG/utils/__pycache__/__init__.cpython-39.pyc b/Mini_RAG/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ddccd3328327de77cb1351661e3593a951a5768 Binary files /dev/null and b/Mini_RAG/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/Mini_RAG/utils/__pycache__/logger.cpython-39.pyc b/Mini_RAG/utils/__pycache__/logger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a39a45115fb2ce92fd3a19ffa7ec666683f1e0b Binary files /dev/null and b/Mini_RAG/utils/__pycache__/logger.cpython-39.pyc differ diff --git a/Mini_RAG/utils/logger.py b/Mini_RAG/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..975c1081b7df6e96b5cc94194dfbeddeffe15567 --- /dev/null +++ b/Mini_RAG/utils/logger.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/25 19:51 +# @Author : hukangzhe +# @File : logger.py.py +# @Description : 日志配置模块 +import logging +import sys + + +def setup_logger(): + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s", + handlers=[ + logging.FileHandler("storage/logs/app.log"), + logging.StreamHandler(sys.stdout) + ] + ) + diff --git a/README.md b/README.md index fe64a7ec2784660c2eee38447c44c3060e4f4caf..f47dbce21502cc0ac57a7b25aef7c87b1b7146c1 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,185 @@ ---- -title: Mini RAG -emoji: 💻 -colorFrom: blue -colorTo: yellow -sdk: gradio -sdk_version: 5.44.1 -app_file: app.py -pinned: false -short_description: 一个小的RAG ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# RAG_Mini +--- + +# Enterprise-Ready RAG System with Gradio Interface + +This is a powerful, enterprise-grade 
Retrieval-Augmented Generation (RAG) system designed to transform your documents into an interactive and intelligent knowledge base. Users can upload their own documents (PDFs, TXT files), build a searchable vector index, and ask complex questions in natural language to receive accurate, context-aware answers sourced directly from the provided materials. + +The entire application is wrapped in a clean, user-friendly web interface powered by Gradio. + +![App Screenshot](assets/1.png) +![App Screenshot](assets/2.png) + +## ✨ Features + +- **Intuitive Web UI**: Simple, clean interface built with Gradio for uploading documents and chatting. +- **Multi-Document Support**: Natively handles PDF and TXT files. +- **Advanced Text Splitting**: Uses a `HierarchicalSemanticSplitter` that first splits documents into large parent chunks (for context) and then into smaller child chunks (for precise search), respecting semantic boundaries. +- **Hybrid Search**: Combines the strengths of dense vector search (FAISS) and sparse keyword search (BM25) for robust and accurate retrieval. +- **Reranking for Accuracy**: Employs a Cross-Encoder model to rerank the retrieved documents, ensuring the most relevant context is passed to the language model. +- **Persistent Knowledge Base**: Automatically saves the built vector index and metadata, allowing you to load an existing knowledge base instantly on startup. +- **Modular & Extensible Codebase**: The project is logically structured into services for loading, splitting, embedding, and generation, making it easy to maintain and extend. + +## 🏛️ System Architecture + +The RAG pipeline follows a logical, multi-step process to ensure high-quality answers: + +1. **Load**: Documents are loaded from various formats and parsed into a standardized `Document` object, preserving metadata like source and page number. +2. **Split**: The raw text is processed by the `HierarchicalSemanticSplitter`, creating parent and child text chunks. This provides both broad context and fine-grained detail. +3. **Embed & Index**: The child chunks are converted into vector embeddings using a `SentenceTransformer` model and indexed in a FAISS vector store. A parallel BM25 index is also built for keyword search. +4. **Retrieve**: When a user asks a question, a hybrid search query is performed against the FAISS and BM25 indices to retrieve the most relevant child chunks. +5. **Fetch Context**: The parent chunks corresponding to the retrieved child chunks are fetched. This ensures the LLM receives a wider, more complete context. +6. **Rerank**: A powerful Cross-Encoder model re-evaluates the relevance of the parent chunks against the query, pushing the best matches to the top. +7. **Generate**: The top-ranked, reranked documents are combined with the user's query into a final prompt. This prompt is sent to a Large Language Model (LLM) to generate a final, coherent answer. + +``` +[User Uploads Docs] -> [Loader] -> [Splitter] -> [Embedder & Vector Store] -> [Knowledge Base Saved] + +[User Asks Question] -> [Hybrid Search] -> [Get Parent Docs] -> [Reranker] -> [LLM] -> [Answer & Sources] +``` + +## 🛠️ Tech Stack + +- **Backend**: Python 3.9+ +- **UI**: Gradio +- **LLM & Embedding Framework**: Hugging Face Transformers, Sentence-Transformers +- **Vector Search**: Faiss (from Facebook AI) +- **Keyword Search**: rank-bm25 +- **PDF Parsing**: PyMuPDF (fitz) +- **Configuration**: PyYAML + +## 🚀 Getting Started + +Follow these steps to set up and run the project on your local machine. + +### 1. 
Prerequisites + +- Python 3.9 or higher +- `pip` for package management + +### 2. Create a `requirements.txt` file + +Before proceeding, it's crucial to have a `requirements.txt` file so others can easily install the necessary dependencies. In your activated terminal, run: + +```bash +pip freeze > requirements.txt +``` +This will save all the packages from your environment into the file. Make sure this file is committed to your GitHub repository. The key packages it should contain are: `gradio`, `torch`, `transformers`, `sentence-transformers`, `faiss-cpu`, `rank_bm25`, `PyMuPDF`, `pyyaml`, `numpy`. + +### 3. Installation & Setup + +**1. Clone the repository:** +```bash +git clone https://github.com/YOUR_USERNAME/YOUR_REPOSITORY_NAME.git +cd YOUR_REPOSITORY_NAME +``` + +**2. Create and activate a virtual environment (recommended):** +```bash +# For Windows +python -m venv venv +.\venv\Scripts\activate + +# For macOS/Linux +python3 -m venv venv +source venv/bin/activate +``` + +**3. Install the required packages:** +```bash +pip install -r requirements.txt +``` + +**4. Configure the system:** +Review the `configs/config.yaml` file. You can change the models, chunk sizes, and other parameters here. The default settings are a good starting point. + +> **Note:** The first time you run the application, the models specified in the config file will be downloaded from Hugging Face. This may take some time depending on your internet connection. + +### 4. Running the Application + +To start the Gradio web server, run the `main.py` script: + +```bash +python main.py +``` + +The application will be available at **`http://localhost:7860`**. + +## 📖 How to Use + +The application has two primary workflows: + +**1. Build a New Knowledge Base:** + - Drag and drop one or more `.pdf` or `.txt` files into the "Upload New Docs to Build" area. + - Click the **"Build New KB"** button. + - The system status will show the progress (Loading -> Splitting -> Indexing). + - Once complete, the status will confirm that the knowledge base is ready, and the chat window will appear. + +**2. Load an Existing Knowledge Base:** + - If you have previously built a knowledge base, simply click the **"Load Existing KB"** button. + - The system will load the saved FAISS index and metadata from the `storage` directory. + - The chat window will appear, and you can start asking questions immediately. + +**Chatting with Your Documents:** + - Once the knowledge base is ready, type your question into the chat box at the bottom and press Enter or click "Submit". + - The model will generate an answer based on the documents you provided. + - The sources used to generate the answer will be displayed below the chat window. + +## 📂 Project Structure + +``` +. +├── configs/ +│ └── config.yaml # Main configuration file for models, paths, etc. +├── core/ +│ ├── embedder.py # Handles text embedding. +│ ├── llm_interface.py # Handles reranking and answer generation. +│ ├── loader.py # Loads and parses documents. +│ ├── schema.py # Defines data structures (Document, Chunk). +│ ├── splitter.py # Splits documents into chunks. +│ └── vector_store.py # Manages FAISS & BM25 indices. +├── service/ +│ └── rag_service.py # Orchestrates the entire RAG pipeline. +├── storage/ # Default location for saved indices (auto-generated). +│ └── ... +├── ui/ +│ └── app.py # Contains the Gradio UI logic. +├── utils/ +│ └── logger.py # Logging configuration. +├── assets/ +│ └── 1.png # Screenshot of the application. +├── main.py # Entry point to run the application. 
+└── requirements.txt # Python package dependencies. +``` + +## 🔧 Configuration Details (`config.yaml`) + +You can customize the RAG pipeline by modifying `configs/config.yaml`: + +- **`models`**: Specify the Hugging Face models for embedding, reranking, and generation. +- **`vector_store`**: Define the paths where the FAISS index and metadata will be saved. +- **`splitter`**: Control the `HierarchicalSemanticSplitter` behavior. + - `parent_chunk_size`: The target size for larger context chunks. + - `parent_chunk_overlap`: The overlap between parent chunks. + - `child_chunk_size`: The target size for smaller, searchable chunks. +- **`retrieval`**: Tune the retrieval and reranking process. + - `retrieval_top_k`: How many initial candidates to retrieve with hybrid search. + - `rerank_top_k`: How many final documents to pass to the LLM after reranking. + - `hybrid_search_alpha`: The weighting between vector search (`alpha`) and BM25 search (`1 - alpha`). `1.0` is pure vector search, `0.0` is pure keyword search. +- **`generation`**: Set parameters for the final answer generation, like `max_new_tokens`. + +## 🛣️ Future Roadmap + +- [ ] Support for more document types (e.g., `.docx`, `.pptx`, `.html`). +- [ ] Implement response streaming for a more interactive chat experience. +- [ ] Integrate with other vector databases like ChromaDB or Pinecone. +- [ ] Create API endpoints for programmatic access to the RAG service. +- [ ] Add more advanced logging and monitoring for enterprise use. + +## 🤝 Contributing + +Contributions are welcome! If you have ideas for improvements or find a bug, please feel free to open an issue or submit a pull request. + +## 📄 License + +This project is licensed under the MIT License. See the `LICENSE` file for details. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..1864745105bf97476bc002569275138b8e910375 --- /dev/null +++ b/app.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/9/3 09:38 +# @Author : hukangzhe +# @File : app.py +# @Description : +import yaml +from service.rag_service import RAGService +from ui.app import GradioApp +from utils.logger import setup_logger + + +def main(): + setup_logger() + + with open('configs/config.yaml', 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + + rag_service = RAGService(config) + app = GradioApp(rag_service) + app.launch() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/assets/1.png b/assets/1.png new file mode 100644 index 0000000000000000000000000000000000000000..def6f1b1f4f00ea6b493b7cdece6a228480c5934 --- /dev/null +++ b/assets/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3cfd5a1e1f78574b7ad19ba82bc14aefedc30e198293a36b34fddbe58ae9572a +size 117169 diff --git a/assets/2.png b/assets/2.png new file mode 100644 index 0000000000000000000000000000000000000000..04960ccb2184a5a930b0bd4be061c294a142b68e --- /dev/null +++ b/assets/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:488e38ddb8e075942a427f604b290abc27121023b44ceab74b2447ccd5d39ecd +size 144857 diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..249028c3740d4f6e40b7e12d273319203640579d --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,26 @@ +# 路径配置 +storage_path: "./storage" +vector_store: + index_path: "./storage/faiss_index/index.faiss" + metadata_path: "./storage/faiss_index/chunks.pkl" + +# 切分器配置 
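+# Parent chunks (~parent_chunk_size chars) are the context passed to the LLM; child chunks
+# (~child_chunk_size chars) are what gets embedded and searched. Overlap is applied between
+# parent chunks only; child chunks inherit it by being re-split from the overlapped parent text.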
+splitter: + parent_chunk_size: 800 + parent_chunk_overlap: 100 + child_chunk_size: 250 + +# 模型配置 +models: + embedding: "moka-ai/m3e-base" + reranker: "BAAI/bge-reranker-base" + llm_generator: "Qwen/Qwen3-0.6B" # Qwen/Qwen1.5-1.8B-Chat or Qwen/Qwen3-0.6B + + +# 检索与生成参数 +retrieval: + hybrid_search_alpha: 0.5 # 混合检索中向量权(0-1), 关键词权为1-alpha + retrieval_top_k: 20 + rerank_top_k: 5 +generation: + max_new_tokens: 512 diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f5a12faa99758192ecc4ed3fc22c9249232e86 --- /dev/null +++ b/core/__init__.py @@ -0,0 +1 @@ + diff --git a/core/__pycache__/__init__.cpython-39.pyc b/core/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..490fc55d5de07f75e7b79a8526715e75207b5e33 Binary files /dev/null and b/core/__pycache__/__init__.cpython-39.pyc differ diff --git a/core/__pycache__/embedder.cpython-39.pyc b/core/__pycache__/embedder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b33e5057272a15b09b1bf751a6082bfd8a72e68 Binary files /dev/null and b/core/__pycache__/embedder.cpython-39.pyc differ diff --git a/core/__pycache__/llm_interface.cpython-39.pyc b/core/__pycache__/llm_interface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17317a98ba48afdddb28d1cdcf421f61843a88ed Binary files /dev/null and b/core/__pycache__/llm_interface.cpython-39.pyc differ diff --git a/core/__pycache__/loader.cpython-39.pyc b/core/__pycache__/loader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e832df0a9e3abbeb0443ef56884f79577401fee Binary files /dev/null and b/core/__pycache__/loader.cpython-39.pyc differ diff --git a/core/__pycache__/schema.cpython-39.pyc b/core/__pycache__/schema.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d379d281390f817f93027d9084d701b0040cdd5a Binary files /dev/null and b/core/__pycache__/schema.cpython-39.pyc differ diff --git a/core/__pycache__/splitter.cpython-39.pyc b/core/__pycache__/splitter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7bf2fd07fd733b59cd01dc4b4785fd01d31c20f Binary files /dev/null and b/core/__pycache__/splitter.cpython-39.pyc differ diff --git a/core/__pycache__/vector_store.cpython-39.pyc b/core/__pycache__/vector_store.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01876546ce639ba8b2e37b0bf6b5a77b1bf79bf7 Binary files /dev/null and b/core/__pycache__/vector_store.cpython-39.pyc differ diff --git a/core/embedder.py b/core/embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..8d85c5cef0adb20c2de3013eb0290f62252e02f0 --- /dev/null +++ b/core/embedder.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/24 11:50 +# @Author : hukangzhe +# @File : embedder.py +# @Description : + +from sentence_transformers import SentenceTransformer +from typing import List +import numpy as np + + +class EmbeddingModel: + def __init__(self, model_name: str): + self.embedding_model = SentenceTransformer(model_name) + + def embed(self, texts: List[str], batch_size: int = 32) -> np.ndarray: + return self.embedding_model.encode(texts, batch_size=batch_size, convert_to_numpy=True) + diff --git a/core/llm_interface.py b/core/llm_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..bff2ea38acbc6524c9a8e2844464abaea3ba23a5 --- 
/dev/null +++ b/core/llm_interface.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/29 19:54 +# @Author : hukangzhe +# @File : generator.py +# @Description : 负责生成答案模块 +import os +import queue +import logging +import threading + +import torch +from typing import Dict, List, Tuple, Generator +from sentence_transformers import CrossEncoder +from .schema import Document, Chunk +from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer + +class ThinkStreamer(TextStreamer): + def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool =True, **decode_kwargs): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.is_thinking = True + self.think_end_token_id = self.tokenizer.encode("
</think>
", add_special_tokens=False)[0] + self.output_queue = queue.Queue() + + def on_finalized_text(self, text: str, stream_end: bool = False): + self.output_queue.put(text) + if stream_end: + self.output_queue.put(None) # 发送结束信号 + + def __iter__(self): + return self + + def __next__(self): + value = self.output_queue.get() + if value is None: + raise StopIteration() + return value + + def generate_output(self) -> Generator[Tuple[str, str], None, None]: + """ + 分离Think和回答 + :return: + """ + full_decode_text = "" + already_yielded_len = 0 + for text_chunk in self: + if not self.is_thinking: + yield "answer", text_chunk + continue + + full_decode_text += text_chunk + tokens = self.tokenizer.encode(full_decode_text, add_special_tokens=False) + + if self.think_end_token_id in tokens: + spilt_point = tokens.index(self.think_end_token_id) + think_part_tokens = tokens[:spilt_point] + thinking_text = self.tokenizer.decode(think_part_tokens) + + answer_part_tokens = tokens[spilt_point:] + answer_text = self.tokenizer.decode(answer_part_tokens) + remaining_thinking = thinking_text[already_yielded_len:] + if remaining_thinking: + yield "thinking", remaining_thinking + + if answer_text: + yield "answer", answer_text + + self.is_thinking = False + already_yielded_len = len(thinking_text) + len(self.tokenizer.decode(self.think_end_token_id)) + else: + yield "thinking", text_chunk + already_yielded_len += len(text_chunk) + + +class QueueTextStreamer(TextStreamer): + def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool = True, **decode_kwargs): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.output_queue = queue.Queue() + + def on_finalized_text(self, text: str, stream_end: bool = False): + """Puts text into the queue; sends None as a sentinel value to signal the end.""" + self.output_queue.put(text) + if stream_end: + self.output_queue.put(None) + + def __iter__(self): + return self + + def __next__(self): + value = self.output_queue.get() + if value is None: + raise StopIteration() + return value + + +class LLMInterface: + def __init__(self, config: dict): + self.config = config + self.reranker = CrossEncoder(config['models']['reranker']) + self.generator_new_tokens = config['generation']['max_new_tokens'] + self.device =torch.device("cuda" if torch.cuda.is_available() else "cpu") + + generator_name = config['models']['llm_generator'] + logging.info(f"Initializing generator {generator_name}") + self.generator_tokenizer = AutoTokenizer.from_pretrained(generator_name) + self.generator_model = AutoModelForCausalLM.from_pretrained( + generator_name, + torch_dtype="auto", + device_map="auto") + + def rerank(self, query: str, docs: List[Document]) -> List[Document]: + pairs = [[query, doc.text] for doc in docs] + scores = self.reranker.predict(pairs) + ranked_docs = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True) + return [doc for doc, score in ranked_docs] + + def _threaded_generate(self, streamer: QueueTextStreamer, generation_kwargs: dict): + """ + 一个包装函数,将 model.generate 放入 try...finally 块中。 + """ + try: + self.generator_model.generate(**generation_kwargs) + finally: + # 无论成功还是失败,都确保在最后发送结束信号 + streamer.output_queue.put(None) + + def generate_answer(self, query: str, context_docs: List[Document]) -> str: + context_str = "" + for doc in context_docs: + context_str += f"Source: {os.path.basename(doc.metadata.get('source', ''))}, Page: {doc.metadata.get('page', 'N/A')}\n" + context_str += f"Content: {doc.text}\n\n" + # content设置为英文,回答则为英文 + messages = [ + {"role": "system", 
"content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"}, + {"role": "user", "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题:{query}"} + ] + prompt = self.generator_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(self.device) + output = self.generator_model.generate(**inputs, + max_new_tokens=self.generator_new_tokens, num_return_sequences=1, + eos_token_id=self.generator_tokenizer.eos_token_id) + generated_ids = output[0][inputs["input_ids"].shape[1]:] + answer = self.generator_tokenizer.decode(generated_ids, skip_special_tokens=True).strip() + return answer + + def generate_answer_stream(self, query: str, context_docs: List[Document]) -> Generator[str, None, None]: + context_str = "" + for doc in context_docs: + context_str += f"Content: {doc.text}\n\n" + + messages = [ + {"role": "system", + "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"}, + {"role": "user", + "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题: {query}"} + ] + + prompt = self.generator_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device) + + streamer = QueueTextStreamer(self.generator_tokenizer, skip_prompt=True) + + generation_kwargs = dict( + **model_inputs, + max_new_tokens=self.generator_new_tokens, + streamer=streamer, + pad_token_id=self.generator_tokenizer.eos_token_id, + ) + + thread = threading.Thread(target=self._threaded_generate, args=(streamer,generation_kwargs,)) + thread.start() + for new_text in streamer: + if new_text is not None: + yield new_text + + def generate_answer_stream_split(self, query: str, context_docs: List[Document]) -> Generator[Tuple[str, str], None, None]: + """分离思考和回答的流式输出""" + context_str = "" + for doc in context_docs: + context_str += f"Content: {doc.text}\n\n" + + messages = [ + {"role": "system", + "content": "You are a helpful assistant. Please answer the question based on the provided context. 
First, think through the process in tags, then provide the final answer."}, + {"role": "user", + "content": f"Context:\n---\n{context_str}\n---\nBased on the context above, please answer the question: {query}"} + ] + + prompt = self.generator_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True + ) + model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device) + + streamer = ThinkStreamer(self.generator_tokenizer, skip_prompt=True) + + generation_kwargs = dict( + **model_inputs, + max_new_tokens=self.generator_new_tokens, + streamer=streamer + ) + + thread = threading.Thread(target=self.generator_model.generate, kwargs=generation_kwargs) + thread.start() + + yield from streamer.generate_output() + + + + diff --git a/core/loader.py b/core/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..169895436c9119c68856abfc9cafbc8055a6c078 --- /dev/null +++ b/core/loader.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/24 19:51 +# @Author : hukangzhe +# @File : loader.py +# @Description : 文档加载模块,不在像reg_mini返回长字符串,而是返回一个document对象列表,每个document都带有源文件信息 +import logging +from .schema import Document +from typing import List +import fitz +import os + + +class DocumentLoader: + + @staticmethod + def _load_pdf(file_path): + logging.info(f"Loading PDF file from {file_path}") + try: + with fitz.open(file_path) as f: + text = "".join(page.get_text() for page in f) + logging.info(f"Successfully loaded {len(f)} pages.") + return text + except Exception as e: + logging.error(f"Failed to load PDF {file_path}: {e}") + return None + + +class MultiDocumentLoader: + def __init__(self, paths: List[str]): + self.paths = paths + + def load(self) -> List[Document]: + docs = [] + for file_path in self.paths: + file_extension = os.path.splitext(file_path)[1].lower() + if file_extension == '.pdf': + docs.extend(self._load_pdf(file_path)) + elif file_extension == '.txt': + docs.extend(self._load_txt(file_path)) + else: + logging.warning(f"Unsupported file type:{file_extension}. Skipping {file_path}") + + return docs + + def _load_pdf(self, file_path: str) -> List[Document]: + logging.info(f"Loading PDF file from {file_path}") + try: + pdf_docs = [] + with fitz.open(file_path) as doc: + for i, page in enumerate(doc): + pdf_docs.append(Document( + text=page.get_text(), + metadata={'source': file_path, 'page': i + 1} + )) + return pdf_docs + except Exception as e: + logging.error(f"Failed to load PDF {file_path}: {e}") + return [] + + def _load_txt(self, file_path: str) -> List[Document]: + logging.info(f"Loading txt file from {file_path}") + try: + txt_docs = [] + with open(file_path, 'r', encoding="utf-8") as f: + txt_docs.append(Document( + text=f.read(), + metadata={'source': file_path} + )) + return txt_docs + except Exception as e: + logging.error(f"Failed to load txt {file_path}:{e}") + return [] diff --git a/core/schema.py b/core/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..647924b047ac0f8f8225a5d85ecb3191f499ef7d --- /dev/null +++ b/core/schema.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/24 21:11 +# @Author : hukangzhe +# @File : schema.py +# @Description : 不直接处理纯文本字符串. 
引入一个标准化的数据结构来承载文本块,既有内容,也有元数据。 + +from dataclasses import dataclass, field +from typing import Dict, Any + + +@dataclass +class Document: + text: str + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class Chunk(Document): + parent_id: int = None \ No newline at end of file diff --git a/core/splitter.py b/core/splitter.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8df02e6bd6fddbb1a8157211ba9983a10dfe63 --- /dev/null +++ b/core/splitter.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/25 19:52 +# @Author : hukangzhe +# @File : splitter.py +# @Description : 负责切分文本的模块 + +import logging +from typing import List, Dict, Tuple +from .schema import Document, Chunk + + +class SemanticRecursiveSplitter: + def __init__(self, chunk_size: int=500, chunk_overlap:int = 50, separators: List[str] = None): + """ + 一个真正实现递归的语义文本切分器。 + :param chunk_size: 每个文本块的目标大小。 + :param chunk_overlap: 文本块之间的重叠大小。 + :param separators: 用于切分的语义分隔符列表,按优先级从高到低排列。 + """ + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + if self.chunk_size <= self.chunk_overlap: + raise ValueError("Chunk overlap must be smaller than chunk size.") + + self.separators = separators if separators else ['\n\n', '\n', " ", ""] # 默认分割符 + + def text_split(self, text: str) -> List[str]: + """ + 切分入口 + :param text: + :return: + """ + logging.info("Starting semantic recursive splitting...") + final_chunks = self._split(text, self.separators) + logging.info(f"Text successfully split into {len(final_chunks)} chunks.") + return final_chunks + + def _split(self, text: str, separators: List[str]) -> List[str]: + final_chunks = [] + # 1. 如果文本足够小,直接返回 + if len(text) < self.chunk_size: + return [text] + # 2. 先尝试最高优先的分割符 + cur_separator = separators[0] + + # 3. 如果可以分割 + if cur_separator in text: + # 分割成多个小部分 + parts = text.split(cur_separator) + + buffer="" # 用来合并小部分 + for i, part in enumerate(parts): + # 如果小于chunk_size,就再加一小部分,使buffer接近chunk_size + if len(buffer) + len(part) + len(cur_separator) <= self.chunk_size: + buffer += part+cur_separator + else: + # 如果buffer 不为空 + if buffer: + final_chunks.append(buffer) + # 如果当前part就已经超过chunk_size + if len(part) > self.chunk_size: + # 递归调用下一级 + sub_chunks = self._split(part, separators = separators[1:]) + final_chunks.extend(sub_chunks) + else: # 成为新的缓冲区 + buffer = part + cur_separator + + if buffer: # 最后一部分的缓冲区 + final_chunks.append(buffer.strip()) + + else: + # 4. 
使用下一级分隔符 + final_chunks = self._split(text, separators[1:]) + + # 处理重叠 + if self.chunk_overlap > 0: + return self._handle_overlap(final_chunks) + else: + return final_chunks + + def _handle_overlap(self, final_chunks: List[str]) -> List[str]: + overlap_chunks = [] + if not final_chunks: + return [] + overlap_chunks.append(final_chunks[0]) + for i in range(1, len(final_chunks)): + pre_chunk = overlap_chunks[-1] + cur_chunk = final_chunks[i] + # 从前一个chunk取出重叠部分与当前chunk合并 + overlap_part = pre_chunk[-self.chunk_overlap:] + overlap_chunks.append(overlap_part+cur_chunk) + + return overlap_chunks + + +class HierarchicalSemanticSplitter: + """ + 结合了层次化(父/子)和递归语义分割策略。确保在创建父块和子块时,遵循文本的自然语义边界。 + """ + def __init__(self, + parent_chunk_size: int = 800, + parent_chunk_overlap: int = 100, + child_chunk_size: int = 250, + separators: List[str] = None): + if parent_chunk_overlap >= parent_chunk_size: + raise ValueError("Parent chunk overlap must be smaller than parent chunk size.") + if child_chunk_size >= parent_chunk_size: + raise ValueError("Child chunk size must be smaller than parent chunk size.") + + self.parent_chunk_size = parent_chunk_size + self.parent_chunk_overlap = parent_chunk_overlap + self.child_chunk_size = child_chunk_size + self.separators = separators or ["\n\n", "\n", "。", ". ", "!", "!", "?", "?", " ", ""] + + def _recursive_semantic_split(self, text: str, chunk_size: int) -> List[str]: + """ + 优先考虑语义边界 + """ + if len(text) <= chunk_size: + return [text] + + for sep in self.separators: + split_point = text.rfind(sep, 0, chunk_size) + if split_point != -1: + break + else: + split_point = chunk_size + + chunk1 = text[:split_point] + remaining_text = text[split_point:].lstrip() # 删除剩余部分的前空格 + + # 递归拆分剩余文本 + # 分隔符将添加回第一个块以保持上下文 + if remaining_text: + return [chunk1 + (sep if sep in " \n" else "")] + self._recursive_semantic_split(remaining_text, chunk_size) + else: + return [chunk1] + + def _apply_overlap(self, chunks: List[str], overlap: int) -> List[str]: + """处理重叠部分chunk""" + if not overlap or len(chunks) <= 1: + return chunks + + overlapped_chunks = [chunks[0]] + for i in range(1, len(chunks)): + # 从前一个chunk中获取最后的“重叠”字符 + overlap_content = chunks[i - 1][-overlap:] + overlapped_chunks.append(overlap_content + chunks[i]) + + return overlapped_chunks + + def split_documents(self, documents: List[Document]) -> Tuple[Dict[int, Document], List[Chunk]]: + """ + 两次切分 + :param documents: + :return: + - parent documents: {parent_id: Document} + - child chunks: [Chunk, Chunk, ...] + """ + parent_docs_dict: Dict[int, Document] = {} + child_chunks_list: List[Chunk] = [] + parent_id_counter = 0 + + logging.info("Starting robust hierarchical semantic splitting...") + + for doc in documents: + # === PASS 1: 创建父chunks === + # 1. 将整个文档text分割成大的语义chunks + initial_parent_chunks = self._recursive_semantic_split(doc.text, self.parent_chunk_size) + + # 2. 
父chunks进行重叠处理 + overlapped_parent_texts = self._apply_overlap(initial_parent_chunks, self.parent_chunk_overlap) + + for p_text in overlapped_parent_texts: + parent_doc = Document(text=p_text, metadata=doc.metadata.copy()) + parent_docs_dict[parent_id_counter] = parent_doc + + # === PASS 2: Create Child Chunks from each Parent === + child_texts = self._recursive_semantic_split(p_text, self.child_chunk_size) + + for c_text in child_texts: + child_metadata = doc.metadata.copy() + child_metadata['parent_id'] = parent_id_counter + child_chunk = Chunk(text=c_text, metadata=child_metadata, parent_id=parent_id_counter) + child_chunks_list.append(child_chunk) + + parent_id_counter += 1 + + logging.info( + f"Splitting complete. Created {len(parent_docs_dict)} parent chunks and {len(child_chunks_list)} child chunks.") + return parent_docs_dict, child_chunks_list diff --git a/core/test_qw.py b/core/test_qw.py new file mode 100644 index 0000000000000000000000000000000000000000..ac3f83f1b3c1366dee367520cf6c1129b9be3535 --- /dev/null +++ b/core/test_qw.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/29 20:03 +# @Author : hukangzhe +# @File : test_qw.py +# @Description : 测试两个模型(qwen1.5 qwen3)的两种输出方式(full or stream)是否正确 +import os +import queue +import logging +import threading +import torch +from typing import Tuple, Generator +from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer + +class ThinkStreamer(TextStreamer): + def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool =True, **decode_kwargs): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.is_thinking = True + self.think_end_token_id = self.tokenizer.encode("", add_special_tokens=False)[0] + self.output_queue = queue.Queue() + + def on_finalized_text(self, text: str, stream_end: bool = False): + self.output_queue.put(text) + if stream_end: + self.output_queue.put(None) # 发送结束信号 + + def __iter__(self): + return self + + def __next__(self): + value = self.output_queue.get() + if value is None: + raise StopIteration() + return value + + def generate_output(self) -> Generator[Tuple[str, str], None, None]: + full_decode_text = "" + already_yielded_len = 0 + for text_chunk in self: + if not self.is_thinking: + yield "answer", text_chunk + continue + + full_decode_text += text_chunk + tokens = self.tokenizer.encode(full_decode_text, add_special_tokens=False) + + if self.think_end_token_id in tokens: + spilt_point = tokens.index(self.think_end_token_id) + think_part_tokens = tokens[:spilt_point] + thinking_text = self.tokenizer.decode(think_part_tokens) + + answer_part_tokens = tokens[spilt_point:] + answer_text = self.tokenizer.decode(answer_part_tokens) + remaining_thinking = thinking_text[already_yielded_len:] + if remaining_thinking: + yield "thinking", remaining_thinking + + if answer_text: + yield "answer", answer_text + + self.is_thinking = False + already_yielded_len = len(thinking_text) + len(self.tokenizer.decode(self.think_end_token_id)) + else: + yield "thinking", text_chunk + already_yielded_len += len(text_chunk) + + + +class LLMInterface: + def __init__(self, model_name: str= "Qwen/Qwen3-0.6B"): + logging.info(f"Initializing generator {model_name}") + self.generator_tokenizer = AutoTokenizer.from_pretrained(model_name) + self.generator_model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype="auto", + device_map="auto") + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + def generate_answer(self, 
query: str, context_str: str) -> str: + messages = [ + {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"}, + {"role": "user", "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题:{query}"} + ] + prompt = self.generator_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(self.device) + output = self.generator_model.generate(**inputs, + max_new_tokens=256, num_return_sequences=1, + eos_token_id=self.generator_tokenizer.eos_token_id) + generated_ids = output[0][inputs["input_ids"].shape[1]:] + answer = self.generator_tokenizer.decode(generated_ids, skip_special_tokens=True).strip() + return answer + + def generate_answer_stream(self, query: str, context_str: str) -> Generator[Tuple[str, str], None, None]: + """Generates an answer as a stream of (state, content) tuples.""" + messages = [ + {"role": "system", + "content": "You are a helpful assistant. Please answer the question based on the provided context. First, think through the process in tags, then provide the final answer."}, + {"role": "user", + "content": f"Context:\n---\n{context_str}\n---\nBased on the context above, please answer the question: {query}"} + ] + + # Use the template that enables thinking for Qwen models + prompt = self.generator_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True + ) + model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device) + + streamer = ThinkStreamer(self.generator_tokenizer, skip_prompt=True) + + generation_kwargs = dict( + **model_inputs, + max_new_tokens=512, + streamer=streamer + ) + + thread = threading.Thread(target=self.generator_model.generate, kwargs=generation_kwargs) + thread.start() + + yield from streamer.generate_output() + + +# if __name__ == "__main__": +# qwen = LLMInterface("Qwen/Qwen3-0.6B") +# answer = qwen.generate_answer("儒家思想的创始人是谁?", "中国传统哲学以儒家、道家和法家为主要流派。儒家思想由孔子创立,强调“仁”、“义”、“礼”、“智”、“信”,主张修身齐家治国平天下,对中国社会产生了深远的影响。其核心价值观如“己所不欲,勿施于人”至今仍具有普世意义。"+ +# +# "道家思想以老子和庄子为代表,主张“道法自然”,追求人与自然的和谐统一,强调无为而治、清静无为。道家思想对中国人的审美情趣、艺术创作以及养生之道都有着重要的影响。"+ +# +# "法家思想以韩非子为集大成者,主张以法治国,强调君主的权威和法律的至高无上。尽管法家思想在历史上曾被用于强化中央集权,但其对建立健全的法律体系也提供了重要的理论基础。") +# +# print(answer) diff --git a/core/vector_store.py b/core/vector_store.py new file mode 100644 index 0000000000000000000000000000000000000000..6792e7151477341db43f19e148c5734d94e95bab --- /dev/null +++ b/core/vector_store.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/27 19:52 +# @Author : hukangzhe +# @File : retriever.py +# @Description : 负责向量化、存储、检索的模块 +import os +import faiss +import numpy as np +import pickle +import logging +from rank_bm25 import BM25Okapi +from typing import List, Dict, Tuple +from .schema import Document, Chunk + + +class HybridVectorStore: + def __init__(self, config: dict, embedder): + self.config = config["vector_store"] + self.embedder = embedder + self.faiss_index = None + self.bm25_index = None + self.parent_docs: Dict[int, Document] = {} + self.child_chunks: List[Chunk] = [] + + def build(self, parent_docs: Dict[int, Document], child_chunks: List[Chunk]): + self.parent_docs = parent_docs + self.child_chunks = child_chunks + + # Build Faiss index + child_text = [child.text for child in child_chunks] + embeddings = self.embedder.embed(child_text) + dim = embeddings.shape[1] + self.faiss_index = faiss.IndexFlatL2(dim) + self.faiss_index.add(embeddings) + logging.info(f"FAISS index 
built with {len(child_chunks)} vectors.") + + # Build BM25 index + tokenize_chunks = [doc.text.split(" ") for doc in child_chunks] + self.bm25_index = BM25Okapi(tokenize_chunks) + logging.info(f"BM25 index built for {len(child_chunks)} documents.") + + self.save() + + def search(self, query: str, top_k: int , alpha: float) -> List[Tuple[int, float]]: + # Vector Search + query_embedding = self.embedder.embed([query]) + distances, indices = self.faiss_index.search(query_embedding, k=top_k) + vector_scores = {idx : 1.0/(1.0 + dist) for idx, dist in zip(indices[0], distances[0])} + + # BM25 Search + tokenize_query = query.split(" ") + bm25_scores = self.bm25_index.get_scores(tokenize_query) + bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k] + bm25_scores = {idx: bm25_scores[idx] for idx in bm25_top_indices} + + # Hybrid Search + all_indices = set(vector_scores.keys()) | set(bm25_scores.keys()) # 求并集 + hybrid_scors = {} + + # Normalization + max_v_score = max(vector_scores.values()) if vector_scores else 1.0 + max_b_score = max(bm25_scores.values()) if bm25_scores else 1.0 + for idx in all_indices: + v_score = (vector_scores.get(idx, 0))/max_v_score + b_score = (bm25_scores.get(idx, 0))/max_b_score + hybrid_scors[idx] = alpha * v_score + (1 - alpha) * b_score + + sorted_indices = sorted(hybrid_scors.items(), key=lambda item: item[1], reverse=True)[:top_k] + return sorted_indices + + def get_chunks(self, indices: List[int]) -> List[Chunk]: + return [self.child_chunks[i] for i in indices] + + def get_parent_docs(self, chunks: List[Chunk]) -> List[Document]: + parent_ids = sorted(list(set(chunk.parent_id for chunk in chunks))) + return [self.parent_docs[pid] for pid in parent_ids] + + def save(self): + index_path = self.config['index_path'] + metadata_path = self.config['metadata_path'] + + os.makedirs(os.path.dirname(index_path), exist_ok=True) + os.makedirs(os.path.dirname(metadata_path), exist_ok=True) + logging.info(f"Saving FAISS index to: {index_path}") + try: + faiss.write_index(self.faiss_index, index_path) + except Exception as e: + logging.error(f"Failed to save FAISS index: {e}") + raise + + logging.info(f"Saving metadata data to: {metadata_path}") + try: + with open(metadata_path, 'wb') as f: + metadata = { + 'parent_docs': self.parent_docs, + 'child_chunks': self.child_chunks, + 'bm25_index': self.bm25_index + } + pickle.dump(metadata, f) + except Exception as e: + logging.error(f"Failed to save metadata: {e}") + raise + + logging.info("Vector store saved successfully.") + + def load(self) -> bool: + """ + 从磁盘加载整个向量存储状态,成功时返回 True,失败时返回 False。 + """ + index_path = self.config['index_path'] + metadata_path = self.config['metadata_path'] + + if not os.path.exists(index_path) or not os.path.exists(metadata_path): + logging.warning("Index files not found. 
Cannot load vector store.") + return False + + logging.info(f"Loading vector store from disk...") + try: + # Load FAISS index + logging.info(f"Loading FAISS index from: {index_path}") + self.faiss_index = faiss.read_index(index_path) + + # Load metadata + logging.info(f"Loading metadata from: {metadata_path}") + with open(metadata_path, 'rb') as f: + metadata = pickle.load(f) + self.parent_docs = metadata['parent_docs'] + self.child_chunks = metadata['child_chunks'] + self.bm25_index = metadata['bm25_index'] + + logging.info("Vector store loaded successfully.") + return True + + except Exception as e: + logging.error(f"Failed to load vector store from disk: {e}") + self.faiss_index = None + self.bm25_index = None + self.parent_docs = {} + self.child_chunks = [] + return False + + + + + + + + + + + + + + + + + diff --git a/data/chinese_document.pdf b/data/chinese_document.pdf new file mode 100644 index 0000000000000000000000000000000000000000..cf3c783f36f435c8a29f5d77e52efceb04852205 --- /dev/null +++ b/data/chinese_document.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0bde4aba4245bb013438240cabf4c149d8aec76536f7a1459a165db02ec3695c +size 317035 diff --git a/data/english_document.pdf b/data/english_document.pdf new file mode 100644 index 0000000000000000000000000000000000000000..214d2ea71708ac3cd25d9dc02c8a9213fdefe17c --- /dev/null +++ b/data/english_document.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5aa48e1b7d30d8cf2ac87cae6e2fee73e2de4d87178d0878d0d34b499fc7b26d +size 166331 diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..032ece81fc60f6a0415af85f2359e0b1511ad874 --- /dev/null +++ b/main.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/5/1 14:00 +# @Author : hukangzhe +# @File : main.py +# @Description : +import yaml +from service.rag_service import RAGService +from ui.app import GradioApp +from utils.logger import setup_logger + + +def main(): + setup_logger() + + with open('configs/config.yaml', 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + + rag_service = RAGService(config) + app = GradioApp(rag_service) + app.launch() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/rag_mini.py b/rag_mini.py new file mode 100644 index 0000000000000000000000000000000000000000..813c892b4ba93a3bfd33089fc2baea497b6fab59 --- /dev/null +++ b/rag_mini.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/24 14:04 +# @Author : hukangzhe +# @File : rag_core.py +# @Description :非常简单的RAG系统 + +import PyPDF2 +import fitz +from sentence_transformers import SentenceTransformer, CrossEncoder +from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM +import numpy as np +import faiss +import torch + + +class RAGSystem: + def __init__(self, pdf_path): + self.pdf_path = pdf_path + self.texts = self._load_and_spilt_pdf() + self.embedder = SentenceTransformer('moka-ai/m3e-base') + self.reranker = CrossEncoder('BAAI/bge-reranker-base') # 加载一个reranker模型 + self.vector_store = self._create_vector_store() + print("3. 
Initializing Generator Model...") + model_name = "Qwen/Qwen1.5-1.8B-Chat" + + # 检查是否有可用的GPU + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f" - Using device: {device}") + + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + # 注意:对于像Qwen这样的模型,我们通常使用 AutoModelForCausalLM + model = AutoModelForCausalLM.from_pretrained(model_name).to(device) + + self.generator = pipeline( + 'text-generation', + model=model, + tokenizer=self.tokenizer, + device=device + ) + + # 1. 文档加载 & 2.文本切分 (为了简化,合在一起) + def _load_and_spilt_pdf(self): + print("1. Loading and splitting PDF...") + full_text = "" + with fitz.open(self.pdf_path) as doc: + for page in doc: + full_text += page.get_text() + + # 非常基础的切分:根据固定大小 + chunk_size = 500 + overlap = 50 + chunks = [full_text[i: i+chunk_size] for i in range(0, len(full_text), chunk_size-overlap)] + print(f" - Splitted into {len(chunks)} chunks.") + return chunks + + # 3. 文本向量化 & 向量存储 + def _create_vector_store(self): + print("2. Creating vector store...") + # embedding + embeddings = self.embedder.encode(self.texts) + + # Storing with faiss + dim = embeddings.shape[1] + index = faiss.IndexFlatL2(dim) # 使用L2距离进行相似度计算 + index.add(np.array(embeddings)) + print(" - Created vector store") + return index + + # 4.检索 + def retrieve(self, query, k=3): + print(f"3. Retrieving top {k} relevant chunks for query: '{query}' ") + query_embeddings = self.embedder.encode([query]) + + distances, indices = self.vector_store.search(np.array(query_embeddings), k=k) + retrieved_chunks = [self.texts[i] for i in indices[0]] + print(" - Retrieval complete.") + return retrieved_chunks + + # 5.生成 + def generate(self, query, context_chunks): + print("4. Generate answer...") + context = "\n".join(context_chunks) + + messages = [ + {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"}, + {"role": "user", "content": f"上下文:\n---\n{context}\n---\n请根据以上上下文回答这个问题:{query}"} + ] + prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + # print("Final Prompt:\n", prompt) + # print("Prompt token length:", len(self.tokenizer.encode(prompt))) + + result = self.generator(prompt, max_new_tokens=200, num_return_sequences=1, + eos_token_id=self.tokenizer.eos_token_id) + + print(" - Generation complete.") + # print("Raw results:", result) + # 提取生成的文本 + # 注意:Qwen模型返回的文本包含了prompt,我们需要从中提取出答案部分 + full_response = result[0]["generated_text"] + answer = full_response[len(prompt):].strip() # 从prompt之后开始截取 + + # print("Final Answer:", repr(answer)) + return answer + + # 优化1 + def rerank(self, query, chunks): + print(" - Reranking retrieved chunks...") + pairs = [[query, chunk] for chunk in chunks] + scores = self.reranker.predict(pairs) + + # 将chunks和scores打包,并按score降序排序 + ranked_chunks = sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True) + return [chunk for chunk, score in ranked_chunks] + + def query(self, query_text): + # 1. 检索(可以检索更多结果,如top 10) + retrieved_chunks = self.retrieve(query_text, k=10) + + # 2. 重排(从10个中选出最相关的3个) + reranked_chunks = self.rerank(query_text, retrieved_chunks) + top_k_reranked = reranked_chunks[:3] + + answer = self.generate(query_text, top_k_reranked) + return answer + + +def main(): + # 确保你的data文件夹里有一个叫做sample.pdf的文件 + pdf_path = 'data/chinese_document.pdf' + + print("Initializing RAG System...") + rag_system = RAGSystem(pdf_path) + print("\nRAG System is ready. 
You can start asking questions.") + print("Type 'q' to quit.") + + while True: + user_query = input("\nYour Question: ") + if user_query.lower() == 'q': + break + + answer = rag_system.query(user_query) + print("\nAnswer:", answer) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c5a1afd4a08714741949343784a3e878e0e69ec8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +pyyaml +pypdf2 +PyMuPDF +sentence-transformers +faiss-cpu +rank_bm25 +transformers +torch +gradio \ No newline at end of file diff --git a/service/__init__.py b/service/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/service/__pycache__/__init__.cpython-39.pyc b/service/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e4b381495c7815b79f4444a2a3d2f5cad4027ae Binary files /dev/null and b/service/__pycache__/__init__.cpython-39.pyc differ diff --git a/service/__pycache__/rag_service.cpython-39.pyc b/service/__pycache__/rag_service.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f04ec204633814e2a126918295c729eb3e500c98 Binary files /dev/null and b/service/__pycache__/rag_service.cpython-39.pyc differ diff --git a/service/rag_service.py b/service/rag_service.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fb58099fc1372b269ae32c312e083f330adfe0 --- /dev/null +++ b/service/rag_service.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/30 11:50 +# @Author : hukangzhe +# @File : rag_service.py +# @Description : +import logging +import os +from typing import List, Generator, Tuple +from core.schema import Document +from core.embedder import EmbeddingModel +from core.loader import MultiDocumentLoader +from core.splitter import HierarchicalSemanticSplitter +from core.vector_store import HybridVectorStore +from core.llm_interface import LLMInterface + + +class RAGService: + def __init__(self, config: dict): + self.config = config + logging.info("Initializing RAG Service...") + self.embedder = EmbeddingModel(config['models']['embedding']) + self.vector_store = HybridVectorStore(config, self.embedder) + self.llm = LLMInterface(config) + self.is_ready = False # 是否准备好进行查询 + logging.info("RAG Service initialized. Knowledge base is not loaded.") + + def load_knowledge_base(self) -> Tuple[bool, str]: + """ + 尝试从磁盘加载 + Returns: + A tuple (success: bool, message: str) + """ + if self.is_ready: + return True, "Knowledge base is already loaded." + + logging.info("Attempting to load knowledge base from disk...") + success = self.vector_store.load() + if success: + self.is_ready = True + message = "Knowledge base loaded successfully from disk." + logging.info(message) + return True, message + else: + self.is_ready = False + message = "No existing knowledge base found or failed to load. Please build a new one." + logging.warning(message) + return False, message + + def build_knowledge_base(self, file_paths: List[str]) -> Generator[str, None, None]: + self.is_ready = False + yield "Step 1/3: Loading documents..." + loader = MultiDocumentLoader(file_paths) + docs = loader.load() + + yield "Step 2/3: Splitting documents into hierarchical chunks..." 
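+        # Parent/child ("small-to-big") splitting: the small child chunks are indexed
+        # for precise hybrid (FAISS + BM25) retrieval, while their larger parent chunks
+        # are what is later returned as LLM context in _get_context_and_sources().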
+        splitter = HierarchicalSemanticSplitter(
+            parent_chunk_size=self.config['splitter']['parent_chunk_size'],
+            parent_chunk_overlap=self.config['splitter']['parent_chunk_overlap'],
+            child_chunk_size=self.config['splitter']['child_chunk_size']
+        )
+        parent_docs, child_chunks = splitter.split_documents(docs)
+
+        yield "Step 3/3: Building and saving vector index..."
+        self.vector_store.build(parent_docs, child_chunks)
+        self.is_ready = True
+        yield "Knowledge base built and ready!"
+
+    def _get_context_and_sources(self, query: str) -> List[Document]:
+        if not self.is_ready:
+            raise Exception("Knowledge base is not ready. Please build it first.")
+
+        # Hybrid search over the child chunks
+        retrieved_child_indices_scores = self.vector_store.search(
+            query,
+            top_k=self.config['retrieval']['retrieval_top_k'],
+            alpha=self.config['retrieval']['hybrid_search_alpha']
+        )
+        retrieved_child_indices = [idx for idx, score in retrieved_child_indices_scores]
+        retrieved_child_chunks = self.vector_store.get_chunks(retrieved_child_indices)
+
+        # Map the child chunks back to their parent documents
+        retrieved_parent_docs = self.vector_store.get_parent_docs(retrieved_child_chunks)
+
+        # Rerank the parent documents and keep the top results
+        reranked_docs = self.llm.rerank(query, retrieved_parent_docs)
+        final_context_docs = reranked_docs[:self.config['retrieval']['rerank_top_k']]
+
+        return final_context_docs
+
+    def get_response_full(self, query: str) -> Tuple[str, List[Document]]:
+        final_context_docs = self._get_context_and_sources(query)
+        answer = self.llm.generate_answer(query, final_context_docs)
+        return answer, final_context_docs
+
+    def get_response_stream(self, query: str) -> Tuple[Generator[str, None, None], List[Document]]:
+        final_context_docs = self._get_context_and_sources(query)
+        answer_generator = self.llm.generate_answer_stream(query, final_context_docs)
+        return answer_generator, final_context_docs
+
+    def get_context_string(self, context_docs: List[Document]) -> str:
+        context_str = "引用上下文 (Context Sources):\n\n"
+        for doc in context_docs:
+            source_info = f"--- (来源: {os.path.basename(doc.metadata.get('source', ''))}, 页码: {doc.metadata.get('page', 'N/A')}) ---\n"
+            content = doc.text[:200]+"..." 
if len(doc.text) > 200 else doc.text + context_str += source_info + content + "\n\n" + return context_str.strip() + + + diff --git a/storage/faiss_index/chunks.pkl b/storage/faiss_index/chunks.pkl new file mode 100644 index 0000000000000000000000000000000000000000..a084a4cd7d4834a6e8929987bce62976f5b9101b --- /dev/null +++ b/storage/faiss_index/chunks.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da7fae509a97a93f69b8048e375fca25dd45f0e26a40bff83c30759c5853e473 +size 27663 diff --git a/storage/faiss_index/index.faiss b/storage/faiss_index/index.faiss new file mode 100644 index 0000000000000000000000000000000000000000..3c2e5a8e2cd09faee8ced8ee789d661f50e1a434 Binary files /dev/null and b/storage/faiss_index/index.faiss differ diff --git a/ui/__init__.py b/ui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ui/__pycache__/__init__.cpython-39.pyc b/ui/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15bb473a52d09cfd63c269b912e6a93bbf57995a Binary files /dev/null and b/ui/__pycache__/__init__.cpython-39.pyc differ diff --git a/ui/__pycache__/app.cpython-39.pyc b/ui/__pycache__/app.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ceb60199ef7e2d2327109894ee63fa5ae7a69b8f Binary files /dev/null and b/ui/__pycache__/app.cpython-39.pyc differ diff --git a/ui/app.py b/ui/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5990d22ce4ec0ec2019beb57c8fe5a20555225bd --- /dev/null +++ b/ui/app.py @@ -0,0 +1,202 @@ +import gradio as gr +import os +from typing import List, Tuple + +from service.rag_service import RAGService + + +class GradioApp: + def __init__(self, rag_service: RAGService): + self.rag_service = rag_service + self._build_ui() + + def _build_ui(self): + with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), + title="Enterprise RAG System") as self.demo: + gr.Markdown("# 企业级RAG智能问答系统 (Enterprise RAG System)") + gr.Markdown("您可以**加载现有知识库**快速开始,或**上传新文档**构建一个全新的知识库。") + + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("### 控制面板 (Control Panel)") + + self.load_kb_button = gr.Button("加载已有知识库 (Load Existing KB)") + + gr.Markdown("
") + + self.file_uploader = gr.File( + label="上传新文档以构建 (Upload New Docs to Build)", + file_count="multiple", + file_types=[".pdf", ".txt"], + interactive=True + ) + self.build_kb_button = gr.Button("构建新知识库 (Build New KB)", variant="primary") + + self.status_box = gr.Textbox( + label="系统状态 (System Status)", + value="系统已初始化,等待加载或构建知识库。", + interactive=False, + lines=4 + ) + + # --- 刚开始隐藏,构建了数据库再显示 --- + with gr.Column(scale=2, visible=False) as self.chat_area: + gr.Markdown("### 对话窗口 (Chat Window)") + self.chatbot = gr.Chatbot(label="RAG Chatbot", bubble_full_width=False, height=500) + self.mode_selector = gr.Radio( + ["流式输出(Streaming)","一次性输出(Full)"], + label="输出模式:(Output Mode)", + value="流式输出(Streaming)" + ) + self.question_box = gr.Textbox(label="您的问题", placeholder="请在此处输入您的问题...", + show_label=False) + with gr.Row(): + self.submit_btn = gr.Button("提交 (Submit)", variant="primary") + self.clear_btn = gr.Button("清空历史 (Clear History)") + + gr.Markdown("---") + self.source_display = gr.Markdown("### 引用来源 (Sources)") + + # --- Event Listeners --- + self.load_kb_button.click( + fn=self._handle_load_kb, + inputs=None, + outputs=[self.status_box, self.chat_area] + ) + + self.build_kb_button.click( + fn=self._handle_build_kb, + inputs=[self.file_uploader], + outputs=[self.status_box, self.chat_area] + ) + + self.submit_btn.click( + fn=self._handle_chat_submission, + inputs=[self.question_box, self.chatbot, self.mode_selector], + outputs=[self.chatbot, self.question_box, self.source_display] + ) + + self.question_box.submit( + fn=self._handle_chat_submission, + inputs=[self.question_box, self.chatbot, self.mode_selector], + outputs=[self.chatbot, self.question_box, self.source_display] + ) + + self.clear_btn.click( + fn=self._clear_chat, + inputs=None, + outputs=[self.chatbot, self.question_box, self.source_display] + ) + + def _handle_load_kb(self): + """处理现有知识库的加载。返回更新字典。""" + success, message = self.rag_service.load_knowledge_base() + if success: + return { + self.status_box: gr.update(value=message), + self.chat_area: gr.update(visible=True) + } + else: + return { + self.status_box: gr.update(value=message), + self.chat_area: gr.update(visible=False) + } + + def _handle_build_kb(self, files: List[str], progress=gr.Progress(track_tqdm=True)): + """构建新知识库,返回更新的字典.""" + if not files: + # --- MODIFIED LINE --- + return { + self.status_box: gr.update(value="错误:请至少上传一个文档。"), + self.chat_area: gr.update(visible=False) + } + + file_paths = [file.name for file in files] + + try: + for status in self.rag_service.build_knowledge_base(file_paths): + progress(0.5, desc=status) + + final_status = "知识库构建完成并已就绪!√" + # --- MODIFIED LINE --- + return { + self.status_box: gr.update(value=final_status), + self.chat_area: gr.update(visible=True) + } + except Exception as e: + error_message = f"构建失败: {e}" + # --- MODIFIED LINE --- + return { + self.status_box: gr.update(value=error_message), + self.chat_area: gr.update(visible=False) + } + + def _handle_chat_submission(self, question: str, history: List[Tuple[str, str]], mode: str): + if not question or not question.strip(): + yield history, "", "### 引用来源 (Sources)\n" + return + + history.append((question, "")) + + try: + # 一次全部输出 + if "Full" in mode: + yield history, "", "### 引用来源 (Sources)\n" + + answer, sources = self.rag_service.get_response_full(question) + # 获取引用内容 + context_string_for_display = self.rag_service.get_context_string(sources) + # 修改格式 + source_text_for_panel = self._format_sources(sources) + #完整内容:引用+回答 + full_response = 
f"{context_string_for_display}\n\n---\n\n**回答 (Answer):**\n{answer}" + history[-1] = (question, full_response) + + yield history, "", source_text_for_panel + + # 流式输出 + else: + answer_generator, sources = self.rag_service.get_response_stream(question) + + context_string_for_display = self.rag_service.get_context_string(sources) + + source_text_for_panel = self._format_sources(sources) + + yield history, "", source_text_for_panel + + response_prefix = f"{context_string_for_display}\n\n---\n\n**回答 (Answer):**\n" + history[-1] = (question, response_prefix) + yield history, "", source_text_for_panel + + answer_log = "" + for text_chunk in answer_generator: + answer_log += text_chunk + history[-1] = (question, response_prefix + answer_log) + yield history, "", source_text_for_panel + + except Exception as e: + error_response = f"处理请求时出错: {e}" + history[-1] = (question, error_response) + yield history, "", "### 引用来源 (Sources)\n" + + def _format_sources(self, sources: List) -> str: + source_text = "### 引用来源 (sources)\n)" + if not sources: + return source_text + + unique_sources = set() + for doc in sources: + source_name = os.path.basename(doc.metadata.get('source', 'Unknown')) + page_num = doc.metadata.get('page', 'N/A') + unique_sources.add(f"- **{source_name}** (Page: {page_num})") + + source_text += "\n".join(sorted(list(unique_sources))) + return source_text + + def _clear_chat(self): + """清理聊天内容""" + return None, "", "### 引用来源 (Sources)\n" + + def launch(self): + self.demo.queue().launch() + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ddccd3328327de77cb1351661e3593a951a5768 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/utils/__pycache__/logger.cpython-39.pyc b/utils/__pycache__/logger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a39a45115fb2ce92fd3a19ffa7ec666683f1e0b Binary files /dev/null and b/utils/__pycache__/logger.cpython-39.pyc differ diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..975c1081b7df6e96b5cc94194dfbeddeffe15567 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2025/4/25 19:51 +# @Author : hukangzhe +# @File : logger.py.py +# @Description : 日志配置模块 +import logging +import sys + + +def setup_logger(): + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s", + handlers=[ + logging.FileHandler("storage/logs/app.log"), + logging.StreamHandler(sys.stdout) + ] + ) +