Upload 88 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
- .gitattributes +8 -0
- Mini_RAG/.gitattributes +35 -0
- Mini_RAG/.idea/.gitignore +8 -0
- Mini_RAG/.idea/RAG_Min.iml +8 -0
- Mini_RAG/.idea/inspectionProfiles/Project_Default.xml +19 -0
- Mini_RAG/.idea/inspectionProfiles/profiles_settings.xml +6 -0
- Mini_RAG/.idea/misc.xml +7 -0
- Mini_RAG/.idea/modules.xml +8 -0
- Mini_RAG/.idea/vcs.xml +6 -0
- Mini_RAG/.idea/workspace.xml +236 -0
- Mini_RAG/LICENSE +21 -0
- Mini_RAG/README.md +185 -0
- Mini_RAG/app.py +19 -0
- Mini_RAG/assets/1.png +3 -0
- Mini_RAG/assets/2.png +3 -0
- Mini_RAG/configs/config.yaml +26 -0
- Mini_RAG/core/__init__.py +1 -0
- Mini_RAG/core/__pycache__/__init__.cpython-39.pyc +0 -0
- Mini_RAG/core/__pycache__/embedder.cpython-39.pyc +0 -0
- Mini_RAG/core/__pycache__/llm_interface.cpython-39.pyc +0 -0
- Mini_RAG/core/__pycache__/loader.cpython-39.pyc +0 -0
- Mini_RAG/core/__pycache__/schema.cpython-39.pyc +0 -0
- Mini_RAG/core/__pycache__/splitter.cpython-39.pyc +0 -0
- Mini_RAG/core/__pycache__/vector_store.cpython-39.pyc +0 -0
- Mini_RAG/core/embedder.py +19 -0
- Mini_RAG/core/llm_interface.py +216 -0
- Mini_RAG/core/loader.py +73 -0
- Mini_RAG/core/schema.py +20 -0
- Mini_RAG/core/splitter.py +192 -0
- Mini_RAG/core/test_qw.py +133 -0
- Mini_RAG/core/vector_store.py +158 -0
- Mini_RAG/data/chinese_document.pdf +3 -0
- Mini_RAG/data/english_document.pdf +3 -0
- Mini_RAG/main.py +25 -0
- Mini_RAG/rag_mini.py +148 -0
- Mini_RAG/requirements.txt +9 -0
- Mini_RAG/service/__init__.py +0 -0
- Mini_RAG/service/__pycache__/__init__.cpython-39.pyc +0 -0
- Mini_RAG/service/__pycache__/rag_service.cpython-39.pyc +0 -0
- Mini_RAG/service/rag_service.py +110 -0
- Mini_RAG/storage/faiss_index/chunks.pkl +3 -0
- Mini_RAG/storage/faiss_index/index.faiss +0 -0
- Mini_RAG/ui/__init__.py +0 -0
- Mini_RAG/ui/__pycache__/__init__.cpython-39.pyc +0 -0
- Mini_RAG/ui/__pycache__/app.cpython-39.pyc +0 -0
- Mini_RAG/ui/app.py +202 -0
- Mini_RAG/utils/__init__.py +0 -0
- Mini_RAG/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- Mini_RAG/utils/__pycache__/logger.cpython-39.pyc +0 -0
- Mini_RAG/utils/logger.py +20 -0
.gitattributes
CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/1.png filter=lfs diff=lfs merge=lfs -text
+assets/2.png filter=lfs diff=lfs merge=lfs -text
+data/chinese_document.pdf filter=lfs diff=lfs merge=lfs -text
+data/english_document.pdf filter=lfs diff=lfs merge=lfs -text
+Mini_RAG/assets/1.png filter=lfs diff=lfs merge=lfs -text
+Mini_RAG/assets/2.png filter=lfs diff=lfs merge=lfs -text
+Mini_RAG/data/chinese_document.pdf filter=lfs diff=lfs merge=lfs -text
+Mini_RAG/data/english_document.pdf filter=lfs diff=lfs merge=lfs -text
Mini_RAG/.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
Mini_RAG/.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
Mini_RAG/.idea/RAG_Min.iml
ADDED
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="loracode" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
Mini_RAG/.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,19 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<option name="ignoredErrors">
<list>
<option value="N812" />
</list>
</option>
</inspection_tool>
<inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredIdentifiers">
<list>
<option value="utils.nearest_neighbors.lib.python.*" />
</list>
</option>
</inspection_tool>
</profile>
</component>
Mini_RAG/.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
Mini_RAG/.idea/misc.xml
ADDED
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.9 (bilitorch)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="bilitorch" project-jdk-type="Python SDK" />
</project>
Mini_RAG/.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/RAG_Min.iml" filepath="$PROJECT_DIR$/.idea/RAG_Min.iml" />
</modules>
</component>
</project>
Mini_RAG/.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>
Mini_RAG/.idea/workspace.xml
ADDED
@@ -0,0 +1,236 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AutoImportSettings">
<option name="autoReloadType" value="SELECTIVE" />
</component>
<component name="ChangeListManager">
<list default="true" id="9a6c89b5-3e5a-45c1-bb02-b7d6e7e251f9" name="Changes" comment="Changes">
<change beforePath="$PROJECT_DIR$/ui/app.py" beforeDir="false" afterPath="$PROJECT_DIR$/ui/app.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="FileTemplateManagerImpl">
<option name="RECENT_TEMPLATES">
<list>
<option value="Python Script" />
</list>
</option>
</component>
<component name="FlaskConsoleOptions" custom-start-script="import sys; print('Python %s on %s' % (sys.version, sys.platform)); sys.path.extend([WORKING_DIR_AND_PYTHON_PATHS]) from flask.cli import ScriptInfo, NoAppException for module in ["main.py", "wsgi.py", "app.py"]: try: locals().update(ScriptInfo(app_import_path=module, create_app=None).load_app().make_shell_context()); print("\nFlask App: %s" % app.import_name); break except NoAppException: pass">
<envs>
<env key="FLASK_APP" value="app" />
</envs>
<option name="myCustomStartScript" value="import sys; print('Python %s on %s' % (sys.version, sys.platform)); sys.path.extend([WORKING_DIR_AND_PYTHON_PATHS]) from flask.cli import ScriptInfo, NoAppException for module in ["main.py", "wsgi.py", "app.py"]: try: locals().update(ScriptInfo(app_import_path=module, create_app=None).load_app().make_shell_context()); print("\nFlask App: %s" % app.import_name); break except NoAppException: pass" />
<option name="myEnvs">
<map>
<entry key="FLASK_APP" value="app" />
</map>
</option>
</component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="MarkdownSettingsMigration">
<option name="stateVersion" value="1" />
</component>
<component name="ProjectColorInfo">{
"customColor": "",
"associatedIndex": 3
}</component>
<component name="ProjectId" id="31rEeJzXYWkXkU1zZalZWqADOi1" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent">{
"keyToString": {
"ModuleVcsDetector.initialDetectionPerformed": "true",
"Python.llm_stearm.executor": "Run",
"Python.main.executor": "Run",
"Python.rag_mini (1).executor": "Run",
"Python.test_qw.executor": "Run",
"RunOnceActivity.OpenProjectViewOnStart": "true",
"RunOnceActivity.ShowReadmeOnStart": "true",
"RunOnceActivity.TerminalTabsStorage.copyFrom.TerminalArrangementManager.252": "true",
"RunOnceActivity.git.unshallow": "true",
"WebServerToolWindowFactoryState": "false",
"git-widget-placeholder": "main",
"ignore.virus.scanning.warn.message": "true",
"last_opened_file_path": "D:/LLms/Tiny-RAG",
"node.js.detected.package.eslint": "true",
"node.js.detected.package.tslint": "true",
"node.js.selected.package.eslint": "(autodetect)",
"node.js.selected.package.tslint": "(autodetect)",
"nodejs_package_manager_path": "npm",
"settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable",
"vue.rearranger.settings.migration": "true"
}
}</component>
<component name="RecentsManager">
<key name="MoveFile.RECENT_KEYS">
<recent name="D:\LLms\RAG_Min\data" />
</key>
</component>
<component name="RunManager" selected="Python.main">
<configuration name="llm_stearm" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="RAG_Min" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/core" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/core/llm_stearm.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="main" type="PythonConfigurationType" factoryName="Python" nameIsGenerated="true">
<module name="RAG_Min" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="rag_mini" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="RAG_Min" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="D:\LLms\RAG_Min\rag_mini.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="test_qw" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="RAG_Min" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/core" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/core/test_qw.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<recent_temporary>
<list>
<item itemvalue="Python.test_qw" />
<item itemvalue="Python.llm_stearm" />
<item itemvalue="Python.rag_mini" />
</list>
</recent_temporary>
</component>
<component name="SharedIndexes">
<attachedChunks>
<set>
<option value="bundled-js-predefined-d6986cc7102b-b598e85cdad2-JavaScript-PY-252.25557.130" />
<option value="bundled-python-sdk-1bde30d8e611-7b97d883f26b-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-252.25557.130" />
</set>
</attachedChunks>
</component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="9a6c89b5-3e5a-45c1-bb02-b7d6e7e251f9" name="Changes" comment="" />
<created>1756273884601</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1756273884601</updated>
<workItem from="1756273885673" duration="25695000" />
<workItem from="1756341369847" duration="35394000" />
<workItem from="1756387476000" duration="997000" />
<workItem from="1756388504411" duration="9000" />
<workItem from="1756428572914" duration="40000" />
<workItem from="1756563941144" duration="10000" />
<workItem from="1756564148012" duration="284000" />
<workItem from="1756604160932" duration="210000" />
<workItem from="1756625407567" duration="827000" />
<workItem from="1756626254182" duration="19274000" />
<workItem from="1756649657021" duration="45000" />
<workItem from="1756687922446" duration="4042000" />
<workItem from="1756692506309" duration="15000" />
<workItem from="1756694800870" duration="1304000" />
<workItem from="1756821488685" duration="524000" />
<workItem from="1756862778374" duration="1566000" />
</task>
<task id="LOCAL-00001" summary="Changes">
<option name="closed" value="true" />
<created>1756863588308</created>
<option name="number" value="00001" />
<option name="presentableId" value="LOCAL-00001" />
<option name="project" value="LOCAL" />
<updated>1756863588308</updated>
</task>
<option name="localTasksCounter" value="2" />
<servers />
</component>
<component name="TypeScriptGeneratedFilesManager">
<option name="version" value="3" />
</component>
<component name="VcsManagerConfiguration">
<MESSAGE value="Changes" />
<option name="LAST_COMMIT_MESSAGE" value="Changes" />
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/RAG_Min$rag_mini__1_.coverage" NAME="rag_mini (1) 覆盖结果" MODIFIED="1756604236123" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/RAG_Min$llm_stearm.coverage" NAME="llm_stearm 覆盖结果" MODIFIED="1756626842854" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/core" />
<SUITE FILE_PATH="coverage/RAG_Min$rag_core.coverage" NAME="rag_core Coverage Results" MODIFIED="1756292009397" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/RAG_Min$test_qw.coverage" NAME="test_qw 覆盖结果" MODIFIED="1756643881941" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/core" />
<SUITE FILE_PATH="coverage/RAG_Min$main.coverage" NAME="main 覆盖结果" MODIFIED="1756689369726" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component>
</project>
Mini_RAG/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 TuNan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Mini_RAG/README.md
ADDED
@@ -0,0 +1,185 @@
# RAG_Mini
---

# Enterprise-Ready RAG System with Gradio Interface

This is a powerful, enterprise-grade Retrieval-Augmented Generation (RAG) system designed to transform your documents into an interactive and intelligent knowledge base. Users can upload their own documents (PDFs, TXT files), build a searchable vector index, and ask complex questions in natural language to receive accurate, context-aware answers sourced directly from the provided materials.

The entire application is wrapped in a clean, user-friendly web interface powered by Gradio.




## ✨ Features

- **Intuitive Web UI**: Simple, clean interface built with Gradio for uploading documents and chatting.
- **Multi-Document Support**: Natively handles PDF and TXT files.
- **Advanced Text Splitting**: Uses a `HierarchicalSemanticSplitter` that first splits documents into large parent chunks (for context) and then into smaller child chunks (for precise search), respecting semantic boundaries.
- **Hybrid Search**: Combines the strengths of dense vector search (FAISS) and sparse keyword search (BM25) for robust and accurate retrieval.
- **Reranking for Accuracy**: Employs a Cross-Encoder model to rerank the retrieved documents, ensuring the most relevant context is passed to the language model.
- **Persistent Knowledge Base**: Automatically saves the built vector index and metadata, allowing you to load an existing knowledge base instantly on startup.
- **Modular & Extensible Codebase**: The project is logically structured into services for loading, splitting, embedding, and generation, making it easy to maintain and extend.

## 🏛️ System Architecture

The RAG pipeline follows a logical, multi-step process to ensure high-quality answers:

1. **Load**: Documents are loaded from various formats and parsed into a standardized `Document` object, preserving metadata like source and page number.
2. **Split**: The raw text is processed by the `HierarchicalSemanticSplitter`, creating parent and child text chunks. This provides both broad context and fine-grained detail.
3. **Embed & Index**: The child chunks are converted into vector embeddings using a `SentenceTransformer` model and indexed in a FAISS vector store. A parallel BM25 index is also built for keyword search.
4. **Retrieve**: When a user asks a question, a hybrid search query is performed against the FAISS and BM25 indices to retrieve the most relevant child chunks.
5. **Fetch Context**: The parent chunks corresponding to the retrieved child chunks are fetched. This ensures the LLM receives a wider, more complete context.
6. **Rerank**: A powerful Cross-Encoder model re-evaluates the relevance of the parent chunks against the query, pushing the best matches to the top.
7. **Generate**: The top-ranked, reranked documents are combined with the user's query into a final prompt. This prompt is sent to a Large Language Model (LLM) to generate a final, coherent answer.

```
[User Uploads Docs] -> [Loader] -> [Splitter] -> [Embedder & Vector Store] -> [Knowledge Base Saved]

[User Asks Question] -> [Hybrid Search] -> [Get Parent Docs] -> [Reranker] -> [LLM] -> [Answer & Sources]
```

## 🛠️ Tech Stack

- **Backend**: Python 3.9+
- **UI**: Gradio
- **LLM & Embedding Framework**: Hugging Face Transformers, Sentence-Transformers
- **Vector Search**: Faiss (from Facebook AI)
- **Keyword Search**: rank-bm25
- **PDF Parsing**: PyMuPDF (fitz)
- **Configuration**: PyYAML

## 🚀 Getting Started

Follow these steps to set up and run the project on your local machine.

### 1. Prerequisites

- Python 3.9 or higher
- `pip` for package management

### 2. Create a `requirements.txt` file

Before proceeding, it's crucial to have a `requirements.txt` file so others can easily install the necessary dependencies. In your activated terminal, run:

```bash
pip freeze > requirements.txt
```
This will save all the packages from your environment into the file. Make sure this file is committed to your GitHub repository. The key packages it should contain are: `gradio`, `torch`, `transformers`, `sentence-transformers`, `faiss-cpu`, `rank_bm25`, `PyMuPDF`, `pyyaml`, `numpy`.

### 3. Installation & Setup

**1. Clone the repository:**
```bash
git clone https://github.com/YOUR_USERNAME/YOUR_REPOSITORY_NAME.git
cd YOUR_REPOSITORY_NAME
```

**2. Create and activate a virtual environment (recommended):**
```bash
# For Windows
python -m venv venv
.\venv\Scripts\activate

# For macOS/Linux
python3 -m venv venv
source venv/bin/activate
```

**3. Install the required packages:**
```bash
pip install -r requirements.txt
```

**4. Configure the system:**
Review the `configs/config.yaml` file. You can change the models, chunk sizes, and other parameters here. The default settings are a good starting point.

> **Note:** The first time you run the application, the models specified in the config file will be downloaded from Hugging Face. This may take some time depending on your internet connection.

### 4. Running the Application

To start the Gradio web server, run the `main.py` script:

```bash
python main.py
```

The application will be available at **`http://localhost:7860`**.

## 📖 How to Use

The application has two primary workflows:

**1. Build a New Knowledge Base:**
- Drag and drop one or more `.pdf` or `.txt` files into the "Upload New Docs to Build" area.
- Click the **"Build New KB"** button.
- The system status will show the progress (Loading -> Splitting -> Indexing).
- Once complete, the status will confirm that the knowledge base is ready, and the chat window will appear.

**2. Load an Existing Knowledge Base:**
- If you have previously built a knowledge base, simply click the **"Load Existing KB"** button.
- The system will load the saved FAISS index and metadata from the `storage` directory.
- The chat window will appear, and you can start asking questions immediately.

**Chatting with Your Documents:**
- Once the knowledge base is ready, type your question into the chat box at the bottom and press Enter or click "Submit".
- The model will generate an answer based on the documents you provided.
- The sources used to generate the answer will be displayed below the chat window.

## 📂 Project Structure

```
.
├── configs/
│   └── config.yaml         # Main configuration file for models, paths, etc.
├── core/
│   ├── embedder.py         # Handles text embedding.
│   ├── llm_interface.py    # Handles reranking and answer generation.
│   ├── loader.py           # Loads and parses documents.
│   ├── schema.py           # Defines data structures (Document, Chunk).
│   ├── splitter.py         # Splits documents into chunks.
│   └── vector_store.py     # Manages FAISS & BM25 indices.
├── service/
│   └── rag_service.py      # Orchestrates the entire RAG pipeline.
├── storage/                # Default location for saved indices (auto-generated).
│   └── ...
├── ui/
│   └── app.py              # Contains the Gradio UI logic.
├── utils/
│   └── logger.py           # Logging configuration.
├── assets/
│   └── 1.png               # Screenshot of the application.
├── main.py                 # Entry point to run the application.
└── requirements.txt        # Python package dependencies.
```

## 🔧 Configuration Details (`config.yaml`)

You can customize the RAG pipeline by modifying `configs/config.yaml`:

- **`models`**: Specify the Hugging Face models for embedding, reranking, and generation.
- **`vector_store`**: Define the paths where the FAISS index and metadata will be saved.
- **`splitter`**: Control the `HierarchicalSemanticSplitter` behavior.
  - `parent_chunk_size`: The target size for larger context chunks.
  - `parent_chunk_overlap`: The overlap between parent chunks.
  - `child_chunk_size`: The target size for smaller, searchable chunks.
- **`retrieval`**: Tune the retrieval and reranking process.
  - `retrieval_top_k`: How many initial candidates to retrieve with hybrid search.
  - `rerank_top_k`: How many final documents to pass to the LLM after reranking.
  - `hybrid_search_alpha`: The weighting between vector search (`alpha`) and BM25 search (`1 - alpha`). `1.0` is pure vector search, `0.0` is pure keyword search.
- **`generation`**: Set parameters for the final answer generation, like `max_new_tokens`.

## 🛣️ Future Roadmap

- [ ] Support for more document types (e.g., `.docx`, `.pptx`, `.html`).
- [ ] Implement response streaming for a more interactive chat experience.
- [ ] Integrate with other vector databases like ChromaDB or Pinecone.
- [ ] Create API endpoints for programmatic access to the RAG service.
- [ ] Add more advanced logging and monitoring for enterprise use.

## 🤝 Contributing

Contributions are welcome! If you have ideas for improvements or find a bug, please feel free to open an issue or submit a pull request.

## 📄 License

This project is licensed under the MIT License. See the `LICENSE` file for details.
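The retrieve-and-rerank flow described in the System Architecture section can be sketched in a few lines of Python. This is an illustrative sketch only, not the project's `vector_store.py` implementation (that file is not shown in this view); the toy `child_chunks` corpus, the `hybrid_search` helper, and the inverse-distance and max normalization of scores are assumptions, while the model names and libraries come from the README and `configs/config.yaml`.

```python
# Illustrative sketch of hybrid (FAISS + BM25) retrieval followed by cross-encoder reranking.
import numpy as np
import faiss
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder

embedder = SentenceTransformer("moka-ai/m3e-base")            # embedding model from config.yaml
reranker = CrossEncoder("BAAI/bge-reranker-base")             # reranker from config.yaml

child_chunks = ["FAISS is a vector search library.",          # toy corpus; real chunks come from the splitter
                "BM25 is a sparse keyword ranking function."]
emb = embedder.encode(child_chunks, convert_to_numpy=True).astype("float32")
index = faiss.IndexFlatL2(emb.shape[1])
index.add(emb)
bm25 = BM25Okapi([c.split() for c in child_chunks])

def hybrid_search(query: str, alpha: float = 0.5, top_k: int = 2) -> list:
    q = embedder.encode([query], convert_to_numpy=True).astype("float32")
    dists, ids = index.search(q, len(child_chunks))
    vec_scores = np.zeros(len(child_chunks))
    vec_scores[ids[0]] = 1.0 / (1.0 + dists[0])               # smaller L2 distance -> higher score
    kw_scores = np.asarray(bm25.get_scores(query.split()))
    if kw_scores.max() > 0:
        kw_scores = kw_scores / kw_scores.max()               # scale keyword scores to [0, 1]
    combined = alpha * vec_scores + (1 - alpha) * kw_scores   # the hybrid_search_alpha weighting
    candidates = [child_chunks[i] for i in np.argsort(combined)[::-1][:top_k]]
    order = np.argsort(reranker.predict([[query, c] for c in candidates]))[::-1]
    return [candidates[i] for i in order]                     # best-ranked context first

print(hybrid_search("What is FAISS?"))
```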
Mini_RAG/app.py
ADDED
@@ -0,0 +1,19 @@
import yaml
from service.rag_service import RAGService
from ui.app import GradioApp
from utils.logger import setup_logger


def main():
    setup_logger()

    with open('configs/config.yaml', 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    rag_service = RAGService(config)
    app = GradioApp(rag_service)
    app.launch()


if __name__ == "__main__":
    main()
Mini_RAG/assets/1.png
ADDED
Git LFS Details
Mini_RAG/assets/2.png
ADDED
Git LFS Details
Mini_RAG/configs/config.yaml
ADDED
@@ -0,0 +1,26 @@
# Path configuration
storage_path: "./storage"
vector_store:
  index_path: "./storage/faiss_index/index.faiss"
  metadata_path: "./storage/faiss_index/chunks.pkl"

# Splitter configuration
splitter:
  parent_chunk_size: 800
  parent_chunk_overlap: 100
  child_chunk_size: 250

# Model configuration
models:
  embedding: "moka-ai/m3e-base"
  reranker: "BAAI/bge-reranker-base"
  llm_generator: "Qwen/Qwen3-0.6B"  # Qwen/Qwen1.5-1.8B-Chat or Qwen/Qwen3-0.6B


# Retrieval and generation parameters
retrieval:
  hybrid_search_alpha: 0.5  # vector weight (0-1) in hybrid search; keyword weight is 1 - alpha
  retrieval_top_k: 20
  rerank_top_k: 5
generation:
  max_new_tokens: 512
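These keys are consumed with `yaml.safe_load`, exactly as `app.py` later in this diff does. A minimal sketch of reading a few of them (the commented values are simply the defaults defined above):

```python
import yaml

# Load the project configuration; paths are relative to the project root.
with open("configs/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

print(config["models"]["embedding"])               # moka-ai/m3e-base
print(config["retrieval"]["hybrid_search_alpha"])  # 0.5
print(config["generation"]["max_new_tokens"])      # 512
```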
Mini_RAG/core/__init__.py
ADDED
@@ -0,0 +1 @@

Mini_RAG/core/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (123 Bytes)
Mini_RAG/core/__pycache__/embedder.cpython-39.pyc
ADDED
Binary file (850 Bytes)
Mini_RAG/core/__pycache__/llm_interface.cpython-39.pyc
ADDED
Binary file (7.92 kB)
Mini_RAG/core/__pycache__/loader.cpython-39.pyc
ADDED
Binary file (2.76 kB)
Mini_RAG/core/__pycache__/schema.cpython-39.pyc
ADDED
Binary file (692 Bytes)
Mini_RAG/core/__pycache__/splitter.cpython-39.pyc
ADDED
Binary file (5.52 kB)
Mini_RAG/core/__pycache__/vector_store.cpython-39.pyc
ADDED
Binary file (5.66 kB)
Mini_RAG/core/embedder.py
ADDED
@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/24 11:50
# @Author : hukangzhe
# @File : embedder.py
# @Description :

from sentence_transformers import SentenceTransformer
from typing import List
import numpy as np


class EmbeddingModel:
    def __init__(self, model_name: str):
        self.embedding_model = SentenceTransformer(model_name)

    def embed(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        return self.embedding_model.encode(texts, batch_size=batch_size, convert_to_numpy=True)
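A small usage sketch for `EmbeddingModel`, assuming it is run from the project root and that the `moka-ai/m3e-base` weights named in `configs/config.yaml` can be downloaded on first use; the exact embedding dimension depends on the chosen model (a BERT-base-sized encoder yields 768).

```python
from core.embedder import EmbeddingModel

embedder = EmbeddingModel("moka-ai/m3e-base")   # model name taken from configs/config.yaml
vectors = embedder.embed(["什么是RAG?", "What is retrieval-augmented generation?"])
print(vectors.shape)                            # (2, embedding_dim), e.g. (2, 768)
```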
Mini_RAG/core/llm_interface.py
ADDED
@@ -0,0 +1,216 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/29 19:54
# @Author : hukangzhe
# @File : generator.py
# @Description : Module responsible for generating answers
import os
import queue
import logging
import threading

import torch
from typing import Dict, List, Tuple, Generator
from sentence_transformers import CrossEncoder
from .schema import Document, Chunk
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer


class ThinkStreamer(TextStreamer):
    def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool = True, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.is_thinking = True
        self.think_end_token_id = self.tokenizer.encode("</think>", add_special_tokens=False)[0]
        self.output_queue = queue.Queue()

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.output_queue.put(text)
        if stream_end:
            self.output_queue.put(None)  # send the end-of-stream signal

    def __iter__(self):
        return self

    def __next__(self):
        value = self.output_queue.get()
        if value is None:
            raise StopIteration()
        return value

    def generate_output(self) -> Generator[Tuple[str, str], None, None]:
        """
        Separate the thinking section from the answer.
        :return:
        """
        full_decode_text = ""
        already_yielded_len = 0
        for text_chunk in self:
            if not self.is_thinking:
                yield "answer", text_chunk
                continue

            full_decode_text += text_chunk
            tokens = self.tokenizer.encode(full_decode_text, add_special_tokens=False)

            if self.think_end_token_id in tokens:
                split_point = tokens.index(self.think_end_token_id)
                think_part_tokens = tokens[:split_point]
                thinking_text = self.tokenizer.decode(think_part_tokens)

                answer_part_tokens = tokens[split_point:]
                answer_text = self.tokenizer.decode(answer_part_tokens)
                remaining_thinking = thinking_text[already_yielded_len:]
                if remaining_thinking:
                    yield "thinking", remaining_thinking

                if answer_text:
                    yield "answer", answer_text

                self.is_thinking = False
                already_yielded_len = len(thinking_text) + len(self.tokenizer.decode(self.think_end_token_id))
            else:
                yield "thinking", text_chunk
                already_yielded_len += len(text_chunk)


class QueueTextStreamer(TextStreamer):
    def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool = True, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.output_queue = queue.Queue()

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Puts text into the queue; sends None as a sentinel value to signal the end."""
        self.output_queue.put(text)
        if stream_end:
            self.output_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.output_queue.get()
        if value is None:
            raise StopIteration()
        return value


class LLMInterface:
    def __init__(self, config: dict):
        self.config = config
        self.reranker = CrossEncoder(config['models']['reranker'])
        self.generator_new_tokens = config['generation']['max_new_tokens']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        generator_name = config['models']['llm_generator']
        logging.info(f"Initializing generator {generator_name}")
        self.generator_tokenizer = AutoTokenizer.from_pretrained(generator_name)
        self.generator_model = AutoModelForCausalLM.from_pretrained(
            generator_name,
            torch_dtype="auto",
            device_map="auto")

    def rerank(self, query: str, docs: List[Document]) -> List[Document]:
        pairs = [[query, doc.text] for doc in docs]
        scores = self.reranker.predict(pairs)
        ranked_docs = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
        return [doc for doc, score in ranked_docs]

    def _threaded_generate(self, streamer: QueueTextStreamer, generation_kwargs: dict):
        """
        A wrapper that runs model.generate inside a try...finally block.
        """
        try:
            self.generator_model.generate(**generation_kwargs)
        finally:
            # Always send the end-of-stream sentinel, whether generation succeeds or fails
            streamer.output_queue.put(None)

    def generate_answer(self, query: str, context_docs: List[Document]) -> str:
        context_str = ""
        for doc in context_docs:
            context_str += f"Source: {os.path.basename(doc.metadata.get('source', ''))}, Page: {doc.metadata.get('page', 'N/A')}\n"
            context_str += f"Content: {doc.text}\n\n"
        # If the system/user content is written in English, the model answers in English
        messages = [
            {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
            {"role": "user", "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题:{query}"}
        ]
        prompt = self.generator_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(self.device)
        output = self.generator_model.generate(**inputs,
                                               max_new_tokens=self.generator_new_tokens, num_return_sequences=1,
                                               eos_token_id=self.generator_tokenizer.eos_token_id)
        generated_ids = output[0][inputs["input_ids"].shape[1]:]
        answer = self.generator_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        return answer

    def generate_answer_stream(self, query: str, context_docs: List[Document]) -> Generator[str, None, None]:
        context_str = ""
        for doc in context_docs:
            context_str += f"Content: {doc.text}\n\n"

        messages = [
            {"role": "system",
             "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
            {"role": "user",
             "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题: {query}"}
        ]

        prompt = self.generator_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device)

        streamer = QueueTextStreamer(self.generator_tokenizer, skip_prompt=True)

        generation_kwargs = dict(
            **model_inputs,
            max_new_tokens=self.generator_new_tokens,
            streamer=streamer,
            pad_token_id=self.generator_tokenizer.eos_token_id,
        )

        thread = threading.Thread(target=self._threaded_generate, args=(streamer, generation_kwargs,))
        thread.start()
        for new_text in streamer:
            if new_text is not None:
                yield new_text

    def generate_answer_stream_split(self, query: str, context_docs: List[Document]) -> Generator[Tuple[str, str], None, None]:
        """Streaming output with the thinking section and the answer separated."""
        context_str = ""
        for doc in context_docs:
            context_str += f"Content: {doc.text}\n\n"

        messages = [
            {"role": "system",
             "content": "You are a helpful assistant. Please answer the question based on the provided context. First, think through the process in <think> tags, then provide the final answer."},
            {"role": "user",
             "content": f"Context:\n---\n{context_str}\n---\nBased on the context above, please answer the question: {query}"}
        ]

        prompt = self.generator_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True
        )
        model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device)

        streamer = ThinkStreamer(self.generator_tokenizer, skip_prompt=True)

        generation_kwargs = dict(
            **model_inputs,
            max_new_tokens=self.generator_new_tokens,
            streamer=streamer
        )

        thread = threading.Thread(target=self.generator_model.generate, kwargs=generation_kwargs)
        thread.start()

        yield from streamer.generate_output()
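A usage sketch for the streaming path above, assuming the working directory is the project root and that the models named in `configs/config.yaml` are available locally or downloadable; the sample `Document` content is made up for illustration.

```python
import yaml

from core.llm_interface import LLMInterface
from core.schema import Document

with open("configs/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

llm = LLMInterface(config)
docs = [Document(text="FAISS is a library for efficient similarity search of dense vectors.",
                 metadata={"source": "data/english_document.pdf", "page": 1})]
docs = llm.rerank("What is FAISS?", docs)               # cross-encoder reranking
for token in llm.generate_answer_stream("What is FAISS?", docs):
    print(token, end="", flush=True)                    # tokens arrive as they are generated
```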
Mini_RAG/core/loader.py
ADDED
@@ -0,0 +1,73 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/24 19:51
# @Author : hukangzhe
# @File : loader.py
# @Description : Document loading module. Unlike rag_mini, it no longer returns one long string; it returns a list of Document objects, each carrying source-file metadata.
import logging
from .schema import Document
from typing import List
import fitz
import os


class DocumentLoader:

    @staticmethod
    def _load_pdf(file_path):
        logging.info(f"Loading PDF file from {file_path}")
        try:
            with fitz.open(file_path) as f:
                text = "".join(page.get_text() for page in f)
                logging.info(f"Successfully loaded {len(f)} pages.")
                return text
        except Exception as e:
            logging.error(f"Failed to load PDF {file_path}: {e}")
            return None


class MultiDocumentLoader:
    def __init__(self, paths: List[str]):
        self.paths = paths

    def load(self) -> List[Document]:
        docs = []
        for file_path in self.paths:
            file_extension = os.path.splitext(file_path)[1].lower()
            if file_extension == '.pdf':
                docs.extend(self._load_pdf(file_path))
            elif file_extension == '.txt':
                docs.extend(self._load_txt(file_path))
            else:
                logging.warning(f"Unsupported file type:{file_extension}. Skipping {file_path}")

        return docs

    def _load_pdf(self, file_path: str) -> List[Document]:
        logging.info(f"Loading PDF file from {file_path}")
        try:
            pdf_docs = []
            with fitz.open(file_path) as doc:
                for i, page in enumerate(doc):
                    pdf_docs.append(Document(
                        text=page.get_text(),
                        metadata={'source': file_path, 'page': i + 1}
                    ))
            return pdf_docs
        except Exception as e:
            logging.error(f"Failed to load PDF {file_path}: {e}")
            return []

    def _load_txt(self, file_path: str) -> List[Document]:
        logging.info(f"Loading txt file from {file_path}")
        try:
            txt_docs = []
            with open(file_path, 'r', encoding="utf-8") as f:
                txt_docs.append(Document(
                    text=f.read(),
                    metadata={'source': file_path}
                ))
            return txt_docs
        except Exception as e:
            logging.error(f"Failed to load txt {file_path}:{e}")
            return []
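A usage sketch for `MultiDocumentLoader`, run from the project root against the two sample PDFs shipped in `data/`; per the code above, PDFs yield one `Document` per page and TXT files yield a single `Document`.

```python
from core.loader import MultiDocumentLoader

loader = MultiDocumentLoader(["data/english_document.pdf", "data/chinese_document.pdf"])
documents = loader.load()
print(len(documents))                                    # one Document per PDF page
for doc in documents[:3]:
    print(doc.metadata, doc.text[:60].replace("\n", " "))
```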
Mini_RAG/core/schema.py
ADDED
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/24 21:11
# @Author : hukangzhe
# @File : schema.py
# @Description : Instead of passing raw text strings around, introduce a standardized data structure that carries each text chunk together with its content and metadata.

from dataclasses import dataclass, field
from typing import Dict, Any


@dataclass
class Document:
    text: str
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class Chunk(Document):
    parent_id: int = None
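A tiny illustration of the two dataclasses: `Chunk` extends `Document` with a `parent_id`, which is how child chunks can later point back at their parent context. The concrete field values here are made up.

```python
from core.schema import Chunk, Document

parent = Document(text="A large parent chunk that preserves broad context ...",
                  metadata={"source": "data/english_document.pdf", "page": 3})
child = Chunk(text="A small, precisely searchable child chunk.",
              metadata={"source": "data/english_document.pdf", "page": 3},
              parent_id=0)
print(isinstance(child, Document), child.parent_id)  # True 0
```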
Mini_RAG/core/splitter.py
ADDED
@@ -0,0 +1,192 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/25 19:52
# @Author : hukangzhe
# @File : splitter.py
# @Description : Module responsible for splitting text

import logging
from typing import List, Dict, Tuple
from .schema import Document, Chunk


class SemanticRecursiveSplitter:
    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50, separators: List[str] = None):
        """
        A semantic text splitter that is genuinely recursive.
        :param chunk_size: Target size of each text chunk.
        :param chunk_overlap: Overlap size between adjacent chunks.
        :param separators: Semantic separators used for splitting, ordered from highest to lowest priority.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        if self.chunk_size <= self.chunk_overlap:
            raise ValueError("Chunk overlap must be smaller than chunk size.")

        self.separators = separators if separators else ['\n\n', '\n', " ", ""]  # default separators

    def text_split(self, text: str) -> List[str]:
        """
        Entry point for splitting.
        :param text:
        :return:
        """
        logging.info("Starting semantic recursive splitting...")
        final_chunks = self._split(text, self.separators)
        logging.info(f"Text successfully split into {len(final_chunks)} chunks.")
        return final_chunks

    def _split(self, text: str, separators: List[str]) -> List[str]:
        final_chunks = []
        # 1. If the text is already small enough, return it directly
        if len(text) < self.chunk_size:
            return [text]
        # 2. Try the highest-priority separator first
        cur_separator = separators[0]

        # 3. If the text can be split on it
        if cur_separator in text:
            # Split into smaller parts
            parts = text.split(cur_separator)

            buffer = ""  # used to merge small parts
            for i, part in enumerate(parts):
                # If still below chunk_size, append another part so the buffer approaches chunk_size
                if len(buffer) + len(part) + len(cur_separator) <= self.chunk_size:
                    buffer += part + cur_separator
                else:
                    # If the buffer is not empty, flush it
                    if buffer:
                        final_chunks.append(buffer)
                    # If the current part alone already exceeds chunk_size
                    if len(part) > self.chunk_size:
                        # Recurse with the next-level separators
                        sub_chunks = self._split(part, separators=separators[1:])
                        final_chunks.extend(sub_chunks)
                    else:  # it becomes the new buffer
                        buffer = part + cur_separator

            if buffer:  # buffer holding the last part
                final_chunks.append(buffer.strip())

        else:
            # 4. Fall back to the next-level separator
            final_chunks = self._split(text, separators[1:])

        # Handle overlap
        if self.chunk_overlap > 0:
            return self._handle_overlap(final_chunks)
        else:
            return final_chunks

    def _handle_overlap(self, final_chunks: List[str]) -> List[str]:
        overlap_chunks = []
        if not final_chunks:
            return []
        overlap_chunks.append(final_chunks[0])
        for i in range(1, len(final_chunks)):
            pre_chunk = overlap_chunks[-1]
            cur_chunk = final_chunks[i]
            # Take the overlap from the previous chunk and prepend it to the current one
            overlap_part = pre_chunk[-self.chunk_overlap:]
            overlap_chunks.append(overlap_part + cur_chunk)

        return overlap_chunks


class HierarchicalSemanticSplitter:
    """
    Combines a hierarchical (parent/child) strategy with recursive semantic splitting, so that both parent and child chunks respect the text's natural semantic boundaries.
    """
    def __init__(self,
                 parent_chunk_size: int = 800,
                 parent_chunk_overlap: int = 100,
                 child_chunk_size: int = 250,
                 separators: List[str] = None):
        if parent_chunk_overlap >= parent_chunk_size:
            raise ValueError("Parent chunk overlap must be smaller than parent chunk size.")
        if child_chunk_size >= parent_chunk_size:
            raise ValueError("Child chunk size must be smaller than parent chunk size.")

        self.parent_chunk_size = parent_chunk_size
        self.parent_chunk_overlap = parent_chunk_overlap
        self.child_chunk_size = child_chunk_size
        self.separators = separators or ["\n\n", "\n", "。", ". ", "!", "!", "?", "?", " ", ""]

    def _recursive_semantic_split(self, text: str, chunk_size: int) -> List[str]:
        """
        Prefer semantic boundaries when choosing split points.
        """
        if len(text) <= chunk_size:
            return [text]

        for sep in self.separators:
            split_point = text.rfind(sep, 0, chunk_size)
            if split_point != -1:
                break
        else:
            split_point = chunk_size

        chunk1 = text[:split_point]
        remaining_text = text[split_point:].lstrip()  # strip leading whitespace from the remainder

        # Recursively split the remaining text.
        # The separator is appended back to the first chunk to preserve context.
        if remaining_text:
            return [chunk1 + (sep if sep in " \n" else "")] + self._recursive_semantic_split(remaining_text, chunk_size)
        else:
            return [chunk1]

    def _apply_overlap(self, chunks: List[str], overlap: int) -> List[str]:
        """Apply overlap between consecutive chunks."""
        if not overlap or len(chunks) <= 1:
            return chunks

        overlapped_chunks = [chunks[0]]
        for i in range(1, len(chunks)):
            # Take the trailing "overlap" characters from the previous chunk
            overlap_content = chunks[i - 1][-overlap:]
            overlapped_chunks.append(overlap_content + chunks[i])

        return overlapped_chunks

    def split_documents(self, documents: List[Document]) -> Tuple[Dict[int, Document], List[Chunk]]:
        """
        Two-pass splitting.
        :param documents:
        :return:
            - parent documents: {parent_id: Document}
            - child chunks: [Chunk, Chunk, ...]
        """
        parent_docs_dict: Dict[int, Document] = {}
        child_chunks_list: List[Chunk] = []
        parent_id_counter = 0

        logging.info("Starting robust hierarchical semantic splitting...")

        for doc in documents:
            # === PASS 1: Create parent chunks ===
            # 1. Split the whole document text into large semantic chunks
            initial_parent_chunks = self._recursive_semantic_split(doc.text, self.parent_chunk_size)

            # 2. Apply overlap to the parent chunks
            overlapped_parent_texts = self._apply_overlap(initial_parent_chunks, self.parent_chunk_overlap)

            for p_text in overlapped_parent_texts:
                parent_doc = Document(text=p_text, metadata=doc.metadata.copy())
                parent_docs_dict[parent_id_counter] = parent_doc

                # === PASS 2: Create Child Chunks from each Parent ===
                child_texts = self._recursive_semantic_split(p_text, self.child_chunk_size)

                for c_text in child_texts:
                    child_metadata = doc.metadata.copy()
                    child_metadata['parent_id'] = parent_id_counter
                    child_chunk = Chunk(text=c_text, metadata=child_metadata, parent_id=parent_id_counter)
                    child_chunks_list.append(child_chunk)

                parent_id_counter += 1

        logging.info(
            f"Splitting complete. Created {len(parent_docs_dict)} parent chunks and {len(child_chunks_list)} child chunks.")
        return parent_docs_dict, child_chunks_list
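A minimal sketch of how the hierarchical splitter is driven, assuming the Document schema above and the default chunk sizes, run from the project root (illustrative only):

# Sketch: split one Document into parent passages and child chunks.
from core.schema import Document
from core.splitter import HierarchicalSemanticSplitter

splitter = HierarchicalSemanticSplitter(parent_chunk_size=800, parent_chunk_overlap=100, child_chunk_size=250)
doc = Document(text="First paragraph.\n\nSecond paragraph. It has two sentences.", metadata={"source": "demo.txt"})
parent_docs, child_chunks = splitter.split_documents([doc])
print(len(parent_docs), "parents,", len(child_chunks), "children")
print(child_chunks[0].parent_id, child_chunks[0].text)  # each child points back at its parent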
Mini_RAG/core/test_qw.py
ADDED
@@ -0,0 +1,133 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/29 20:03
# @Author : hukangzhe
# @File : test_qw.py
# @Description : Test that both models (Qwen1.5 and Qwen3) behave correctly in both output modes (full or stream)
import os
import queue
import logging
import threading
import torch
from typing import Tuple, Generator
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer


class ThinkStreamer(TextStreamer):
    def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool = True, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.is_thinking = True
        self.think_end_token_id = self.tokenizer.encode("</think>", add_special_tokens=False)[0]
        self.output_queue = queue.Queue()

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.output_queue.put(text)
        if stream_end:
            self.output_queue.put(None)  # send the end-of-stream signal

    def __iter__(self):
        return self

    def __next__(self):
        value = self.output_queue.get()
        if value is None:
            raise StopIteration()
        return value

    def generate_output(self) -> Generator[Tuple[str, str], None, None]:
        full_decode_text = ""
        already_yielded_len = 0
        for text_chunk in self:
            if not self.is_thinking:
                yield "answer", text_chunk
                continue

            full_decode_text += text_chunk
            tokens = self.tokenizer.encode(full_decode_text, add_special_tokens=False)

            if self.think_end_token_id in tokens:
                split_point = tokens.index(self.think_end_token_id)
                think_part_tokens = tokens[:split_point]
                thinking_text = self.tokenizer.decode(think_part_tokens)

                answer_part_tokens = tokens[split_point:]
                answer_text = self.tokenizer.decode(answer_part_tokens)
                remaining_thinking = thinking_text[already_yielded_len:]
                if remaining_thinking:
                    yield "thinking", remaining_thinking

                if answer_text:
                    yield "answer", answer_text

                self.is_thinking = False
                already_yielded_len = len(thinking_text) + len(self.tokenizer.decode(self.think_end_token_id))
            else:
                yield "thinking", text_chunk
                already_yielded_len += len(text_chunk)


class LLMInterface:
    def __init__(self, model_name: str = "Qwen/Qwen3-0.6B"):
        logging.info(f"Initializing generator {model_name}")
        self.generator_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.generator_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto")
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def generate_answer(self, query: str, context_str: str) -> str:
        messages = [
            {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
            {"role": "user", "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题:{query}"}
        ]
        prompt = self.generator_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(self.device)
        output = self.generator_model.generate(**inputs,
                                               max_new_tokens=256, num_return_sequences=1,
                                               eos_token_id=self.generator_tokenizer.eos_token_id)
        generated_ids = output[0][inputs["input_ids"].shape[1]:]
        answer = self.generator_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        return answer

    def generate_answer_stream(self, query: str, context_str: str) -> Generator[Tuple[str, str], None, None]:
        """Generates an answer as a stream of (state, content) tuples."""
        messages = [
            {"role": "system",
             "content": "You are a helpful assistant. Please answer the question based on the provided context. First, think through the process in <think> tags, then provide the final answer."},
            {"role": "user",
             "content": f"Context:\n---\n{context_str}\n---\nBased on the context above, please answer the question: {query}"}
        ]

        # Use the template that enables thinking for Qwen models
        prompt = self.generator_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True
        )
        model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device)

        streamer = ThinkStreamer(self.generator_tokenizer, skip_prompt=True)

        generation_kwargs = dict(
            **model_inputs,
            max_new_tokens=512,
            streamer=streamer
        )

        thread = threading.Thread(target=self.generator_model.generate, kwargs=generation_kwargs)
        thread.start()

        yield from streamer.generate_output()


# if __name__ == "__main__":
#     qwen = LLMInterface("Qwen/Qwen3-0.6B")
#     answer = qwen.generate_answer("儒家思想的创始人是谁?", "中国传统哲学以儒家、道家和法家为主要流派。儒家思想由孔子创立,强调“仁”、“义”、“礼”、“智”、“信”,主张修身齐家治国平天下,对中国社会产生了深远的影响。其核心价值观如“己所不欲,勿施于人”至今仍具有普世意义。" +
#                                   "道家思想以老子和庄子为代表,主张“道法自然”,追求人与自然的和谐统一,强调无为而治、清静无为。道家思想对中国人的审美情趣、艺术创作以及养生之道都有着重要的影响。" +
#                                   "法家思想以韩非子为集大成者,主张以法治国,强调君主的权威和法律的至高无上。尽管法家思想在历史上曾被用于强化中央集权,但其对建立健全的法律体系也提供了重要的理论基础。")
#     print(answer)
Mini_RAG/core/vector_store.py
ADDED
@@ -0,0 +1,158 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/27 19:52
# @Author : hukangzhe
# @File : vector_store.py
# @Description : Module responsible for embedding, storage, and retrieval
import os
import faiss
import numpy as np
import pickle
import logging
from rank_bm25 import BM25Okapi
from typing import List, Dict, Tuple
from .schema import Document, Chunk


class HybridVectorStore:
    def __init__(self, config: dict, embedder):
        self.config = config["vector_store"]
        self.embedder = embedder
        self.faiss_index = None
        self.bm25_index = None
        self.parent_docs: Dict[int, Document] = {}
        self.child_chunks: List[Chunk] = []

    def build(self, parent_docs: Dict[int, Document], child_chunks: List[Chunk]):
        self.parent_docs = parent_docs
        self.child_chunks = child_chunks

        # Build FAISS index
        child_text = [child.text for child in child_chunks]
        embeddings = self.embedder.embed(child_text)
        dim = embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatL2(dim)
        self.faiss_index.add(embeddings)
        logging.info(f"FAISS index built with {len(child_chunks)} vectors.")

        # Build BM25 index
        tokenize_chunks = [doc.text.split(" ") for doc in child_chunks]
        self.bm25_index = BM25Okapi(tokenize_chunks)
        logging.info(f"BM25 index built for {len(child_chunks)} documents.")

        self.save()

    def search(self, query: str, top_k: int, alpha: float) -> List[Tuple[int, float]]:
        # Vector Search
        query_embedding = self.embedder.embed([query])
        distances, indices = self.faiss_index.search(query_embedding, k=top_k)
        vector_scores = {idx: 1.0 / (1.0 + dist) for idx, dist in zip(indices[0], distances[0])}

        # BM25 Search
        tokenize_query = query.split(" ")
        bm25_scores = self.bm25_index.get_scores(tokenize_query)
        bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
        bm25_scores = {idx: bm25_scores[idx] for idx in bm25_top_indices}

        # Hybrid Search
        all_indices = set(vector_scores.keys()) | set(bm25_scores.keys())  # union of both candidate sets
        hybrid_scores = {}

        # Normalization
        max_v_score = max(vector_scores.values()) if vector_scores else 1.0
        max_b_score = max(bm25_scores.values()) if bm25_scores else 1.0
        for idx in all_indices:
            v_score = vector_scores.get(idx, 0) / max_v_score
            b_score = bm25_scores.get(idx, 0) / max_b_score
            hybrid_scores[idx] = alpha * v_score + (1 - alpha) * b_score

        sorted_indices = sorted(hybrid_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
        return sorted_indices

    def get_chunks(self, indices: List[int]) -> List[Chunk]:
        return [self.child_chunks[i] for i in indices]

    def get_parent_docs(self, chunks: List[Chunk]) -> List[Document]:
        parent_ids = sorted(list(set(chunk.parent_id for chunk in chunks)))
        return [self.parent_docs[pid] for pid in parent_ids]

    def save(self):
        index_path = self.config['index_path']
        metadata_path = self.config['metadata_path']

        os.makedirs(os.path.dirname(index_path), exist_ok=True)
        os.makedirs(os.path.dirname(metadata_path), exist_ok=True)
        logging.info(f"Saving FAISS index to: {index_path}")
        try:
            faiss.write_index(self.faiss_index, index_path)
        except Exception as e:
            logging.error(f"Failed to save FAISS index: {e}")
            raise

        logging.info(f"Saving metadata to: {metadata_path}")
        try:
            with open(metadata_path, 'wb') as f:
                metadata = {
                    'parent_docs': self.parent_docs,
                    'child_chunks': self.child_chunks,
                    'bm25_index': self.bm25_index
                }
                pickle.dump(metadata, f)
        except Exception as e:
            logging.error(f"Failed to save metadata: {e}")
            raise

        logging.info("Vector store saved successfully.")

    def load(self) -> bool:
        """
        Load the full vector store state from disk; return True on success, False on failure.
        """
        index_path = self.config['index_path']
        metadata_path = self.config['metadata_path']

        if not os.path.exists(index_path) or not os.path.exists(metadata_path):
            logging.warning("Index files not found. Cannot load vector store.")
            return False

        logging.info("Loading vector store from disk...")
        try:
            # Load FAISS index
            logging.info(f"Loading FAISS index from: {index_path}")
            self.faiss_index = faiss.read_index(index_path)

            # Load metadata
            logging.info(f"Loading metadata from: {metadata_path}")
            with open(metadata_path, 'rb') as f:
                metadata = pickle.load(f)
                self.parent_docs = metadata['parent_docs']
                self.child_chunks = metadata['child_chunks']
                self.bm25_index = metadata['bm25_index']

            logging.info("Vector store loaded successfully.")
            return True

        except Exception as e:
            logging.error(f"Failed to load vector store from disk: {e}")
            self.faiss_index = None
            self.bm25_index = None
            self.parent_docs = {}
            self.child_chunks = []
            return False
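The hybrid score used in search() is a convex combination of the normalized FAISS similarity (1 / (1 + L2 distance)) and the normalized BM25 score: score = alpha * v + (1 - alpha) * b. A small standalone sketch of just that blending step, in plain Python with made-up scores and no index required:

# Sketch: the score-blending rule from HybridVectorStore.search, in isolation.
def blend(vector_scores, bm25_scores, alpha=0.5, top_k=3):
    max_v = max(vector_scores.values()) if vector_scores else 1.0
    max_b = max(bm25_scores.values()) if bm25_scores else 1.0
    candidates = set(vector_scores) | set(bm25_scores)
    hybrid = {i: alpha * vector_scores.get(i, 0) / max_v + (1 - alpha) * bm25_scores.get(i, 0) / max_b
              for i in candidates}
    return sorted(hybrid.items(), key=lambda kv: kv[1], reverse=True)[:top_k]

print(blend({0: 0.9, 2: 0.5}, {1: 7.2, 2: 3.1}))  # chunk ids with their blended scores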
Mini_RAG/data/chinese_document.pdf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0bde4aba4245bb013438240cabf4c149d8aec76536f7a1459a165db02ec3695c
size 317035
Mini_RAG/data/english_document.pdf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5aa48e1b7d30d8cf2ac87cae6e2fee73e2de4d87178d0878d0d34b499fc7b26d
size 166331
Mini_RAG/main.py
ADDED
@@ -0,0 +1,25 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/5/1 14:00
# @Author : hukangzhe
# @File : main.py
# @Description :
import yaml
from service.rag_service import RAGService
from ui.app import GradioApp
from utils.logger import setup_logger


def main():
    setup_logger()

    with open('configs/config.yaml', 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    rag_service = RAGService(config)
    app = GradioApp(rag_service)
    app.launch()


if __name__ == "__main__":
    main()
Mini_RAG/rag_mini.py
ADDED
@@ -0,0 +1,148 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/24 14:04
# @Author : hukangzhe
# @File : rag_mini.py
# @Description : A very simple RAG system

import PyPDF2
import fitz
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import numpy as np
import faiss
import torch


class RAGSystem:
    def __init__(self, pdf_path):
        self.pdf_path = pdf_path
        self.texts = self._load_and_split_pdf()
        self.embedder = SentenceTransformer('moka-ai/m3e-base')
        self.reranker = CrossEncoder('BAAI/bge-reranker-base')  # load a reranker model
        self.vector_store = self._create_vector_store()
        print("3. Initializing Generator Model...")
        model_name = "Qwen/Qwen1.5-1.8B-Chat"

        # Check whether a GPU is available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f" - Using device: {device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Note: for models like Qwen we normally use AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

        self.generator = pipeline(
            'text-generation',
            model=model,
            tokenizer=self.tokenizer,
            device=device
        )

    # 1. Document loading & 2. text splitting (combined for simplicity)
    def _load_and_split_pdf(self):
        print("1. Loading and splitting PDF...")
        full_text = ""
        with fitz.open(self.pdf_path) as doc:
            for page in doc:
                full_text += page.get_text()

        # Very basic splitting: fixed-size windows
        chunk_size = 500
        overlap = 50
        chunks = [full_text[i: i + chunk_size] for i in range(0, len(full_text), chunk_size - overlap)]
        print(f" - Split into {len(chunks)} chunks.")
        return chunks

    # 3. Text embedding & vector storage
    def _create_vector_store(self):
        print("2. Creating vector store...")
        # embedding
        embeddings = self.embedder.encode(self.texts)

        # Storing with faiss
        dim = embeddings.shape[1]
        index = faiss.IndexFlatL2(dim)  # use L2 distance for similarity
        index.add(np.array(embeddings))
        print(" - Created vector store")
        return index

    # 4. Retrieval
    def retrieve(self, query, k=3):
        print(f"3. Retrieving top {k} relevant chunks for query: '{query}' ")
        query_embeddings = self.embedder.encode([query])

        distances, indices = self.vector_store.search(np.array(query_embeddings), k=k)
        retrieved_chunks = [self.texts[i] for i in indices[0]]
        print(" - Retrieval complete.")
        return retrieved_chunks

    # 5. Generation
    def generate(self, query, context_chunks):
        print("4. Generate answer...")
        context = "\n".join(context_chunks)

        messages = [
            {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
            {"role": "user", "content": f"上下文:\n---\n{context}\n---\n请根据以上上下文回答这个问题:{query}"}
        ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # print("Final Prompt:\n", prompt)
        # print("Prompt token length:", len(self.tokenizer.encode(prompt)))

        result = self.generator(prompt, max_new_tokens=200, num_return_sequences=1,
                                eos_token_id=self.tokenizer.eos_token_id)

        print(" - Generation complete.")
        # print("Raw results:", result)
        # Extract the generated text.
        # Note: the text returned by the Qwen pipeline includes the prompt, so the answer has to be sliced out of it.
        full_response = result[0]["generated_text"]
        answer = full_response[len(prompt):].strip()  # keep only the part after the prompt

        # print("Final Answer:", repr(answer))
        return answer

    # Optimization 1
    def rerank(self, query, chunks):
        print(" - Reranking retrieved chunks...")
        pairs = [[query, chunk] for chunk in chunks]
        scores = self.reranker.predict(pairs)

        # Zip chunks with scores and sort by score in descending order
        ranked_chunks = sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
        return [chunk for chunk, score in ranked_chunks]

    def query(self, query_text):
        # 1. Retrieve (fetch more candidates, e.g. top 10)
        retrieved_chunks = self.retrieve(query_text, k=10)

        # 2. Rerank (keep the 3 most relevant of the 10)
        reranked_chunks = self.rerank(query_text, retrieved_chunks)
        top_k_reranked = reranked_chunks[:3]

        answer = self.generate(query_text, top_k_reranked)
        return answer


def main():
    # Make sure your data folder contains the PDF referenced below
    pdf_path = 'data/chinese_document.pdf'

    print("Initializing RAG System...")
    rag_system = RAGSystem(pdf_path)
    print("\nRAG System is ready. You can start asking questions.")
    print("Type 'q' to quit.")

    while True:
        user_query = input("\nYour Question: ")
        if user_query.lower() == 'q':
            break

        answer = rag_system.query(user_query)
        print("\nAnswer:", answer)


if __name__ == "__main__":
    main()
Mini_RAG/requirements.txt
ADDED
@@ -0,0 +1,9 @@
pyyaml
pypdf2
PyMuPDF
sentence-transformers
faiss-cpu
rank_bm25
transformers
torch
gradio
Mini_RAG/service/__init__.py
ADDED
File without changes
Mini_RAG/service/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (126 Bytes).
Mini_RAG/service/__pycache__/rag_service.cpython-39.pyc
ADDED
Binary file (4.49 kB).
Mini_RAG/service/rag_service.py
ADDED
@@ -0,0 +1,110 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/30 11:50
# @Author : hukangzhe
# @File : rag_service.py
# @Description :
import logging
import os
from typing import List, Generator, Tuple
from core.schema import Document
from core.embedder import EmbeddingModel
from core.loader import MultiDocumentLoader
from core.splitter import HierarchicalSemanticSplitter
from core.vector_store import HybridVectorStore
from core.llm_interface import LLMInterface


class RAGService:
    def __init__(self, config: dict):
        self.config = config
        logging.info("Initializing RAG Service...")
        self.embedder = EmbeddingModel(config['models']['embedding'])
        self.vector_store = HybridVectorStore(config, self.embedder)
        self.llm = LLMInterface(config)
        self.is_ready = False  # whether the service is ready to answer queries
        logging.info("RAG Service initialized. Knowledge base is not loaded.")

    def load_knowledge_base(self) -> Tuple[bool, str]:
        """
        Try to load the knowledge base from disk.
        Returns:
            A tuple (success: bool, message: str)
        """
        if self.is_ready:
            return True, "Knowledge base is already loaded."

        logging.info("Attempting to load knowledge base from disk...")
        success = self.vector_store.load()
        if success:
            self.is_ready = True
            message = "Knowledge base loaded successfully from disk."
            logging.info(message)
            return True, message
        else:
            self.is_ready = False
            message = "No existing knowledge base found or failed to load. Please build a new one."
            logging.warning(message)
            return False, message

    def build_knowledge_base(self, file_paths: List[str]) -> Generator[str, None, None]:
        self.is_ready = False
        yield "Step 1/3: Loading documents..."
        loader = MultiDocumentLoader(file_paths)
        docs = loader.load()

        yield "Step 2/3: Splitting documents into hierarchical chunks..."
        splitter = HierarchicalSemanticSplitter(
            parent_chunk_size=self.config['splitter']['parent_chunk_size'],
            parent_chunk_overlap=self.config['splitter']['parent_chunk_overlap'],
            child_chunk_size=self.config['splitter']['child_chunk_size']
        )
        parent_docs, child_chunks = splitter.split_documents(docs)

        yield "Step 3/3: Building and saving vector index..."
        self.vector_store.build(parent_docs, child_chunks)
        self.is_ready = True
        yield "Knowledge base built and ready!"

    def _get_context_and_sources(self, query: str) -> List[Document]:
        if not self.is_ready:
            raise Exception("Knowledge base is not ready. Please build it first.")

        # Hybrid Search to get child chunks
        retrieved_child_indices_scores = self.vector_store.search(
            query,
            top_k=self.config['retrieval']['retrieval_top_k'],
            alpha=self.config['retrieval']['hybrid_search_alpha']
        )
        retrieved_child_indices = [idx for idx, score in retrieved_child_indices_scores]
        retrieved_child_chunks = self.vector_store.get_chunks(retrieved_child_indices)

        # Get Parent Documents
        retrieved_parent_docs = self.vector_store.get_parent_docs(retrieved_child_chunks)

        # Rerank Parent Documents
        reranked_docs = self.llm.rerank(query, retrieved_parent_docs)
        final_context_docs = reranked_docs[:self.config['retrieval']['rerank_top_k']]

        return final_context_docs

    def get_response_full(self, query: str) -> Tuple[str, List[Document]]:
        final_context_docs = self._get_context_and_sources(query)
        answer = self.llm.generate_answer(query, final_context_docs)
        return answer, final_context_docs

    def get_response_stream(self, query: str) -> Tuple[Generator[str, None, None], List[Document]]:
        final_context_docs = self._get_context_and_sources(query)
        answer_generator = self.llm.generate_answer_stream(query, final_context_docs)
        return answer_generator, final_context_docs

    def get_context_string(self, context_docs: List[Document]) -> str:
        context_str = "引用上下文 (Context Sources):\n\n"
        for doc in context_docs:
            source_info = f"--- (来源: {os.path.basename(doc.metadata.get('source', ''))}, 页码: {doc.metadata.get('page', 'N/A')}) ---\n"
            content = doc.text[:200] + "..." if len(doc.text) > 200 else doc.text
            context_str += source_info + content + "\n\n"
        return context_str.strip()
Mini_RAG/storage/faiss_index/chunks.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da7fae509a97a93f69b8048e375fca25dd45f0e26a40bff83c30759c5853e473
size 27663
Mini_RAG/storage/faiss_index/index.faiss
ADDED
Binary file (92.2 kB).
Mini_RAG/ui/__init__.py
ADDED
File without changes
Mini_RAG/ui/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (121 Bytes).
Mini_RAG/ui/__pycache__/app.cpython-39.pyc
ADDED
Binary file (6.13 kB).
Mini_RAG/ui/app.py
ADDED
@@ -0,0 +1,202 @@
import gradio as gr
import os
from typing import List, Tuple

from service.rag_service import RAGService


class GradioApp:
    def __init__(self, rag_service: RAGService):
        self.rag_service = rag_service
        self._build_ui()

    def _build_ui(self):
        with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
                       title="Enterprise RAG System") as self.demo:
            gr.Markdown("# 企业级RAG智能问答系统 (Enterprise RAG System)")
            gr.Markdown("您可以**加载现有知识库**快速开始,或**上传新文档**构建一个全新的知识库。")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 控制面板 (Control Panel)")

                    self.load_kb_button = gr.Button("加载已有知识库 (Load Existing KB)")

                    gr.Markdown("<hr style='border: 1px solid #ddd; margin: 1rem 0;'>")

                    self.file_uploader = gr.File(
                        label="上传新文档以构建 (Upload New Docs to Build)",
                        file_count="multiple",
                        file_types=[".pdf", ".txt"],
                        interactive=True
                    )
                    self.build_kb_button = gr.Button("构建新知识库 (Build New KB)", variant="primary")

                    self.status_box = gr.Textbox(
                        label="系统状态 (System Status)",
                        value="系统已初始化,等待加载或构建知识库。",
                        interactive=False,
                        lines=4
                    )

                # --- Hidden at first; shown once a knowledge base is loaded or built ---
                with gr.Column(scale=2, visible=False) as self.chat_area:
                    gr.Markdown("### 对话窗口 (Chat Window)")
                    self.chatbot = gr.Chatbot(label="RAG Chatbot", bubble_full_width=False, height=500)
                    self.mode_selector = gr.Radio(
                        ["流式输出(Streaming)", "一次性输出(Full)"],
                        label="输出模式:(Output Mode)",
                        value="流式输出(Streaming)"
                    )
                    self.question_box = gr.Textbox(label="您的问题", placeholder="请在此处输入您的问题...",
                                                   show_label=False)
                    with gr.Row():
                        self.submit_btn = gr.Button("提交 (Submit)", variant="primary")
                        self.clear_btn = gr.Button("清空历史 (Clear History)")

                    gr.Markdown("---")
                    self.source_display = gr.Markdown("### 引用来源 (Sources)")

            # --- Event Listeners ---
            self.load_kb_button.click(
                fn=self._handle_load_kb,
                inputs=None,
                outputs=[self.status_box, self.chat_area]
            )

            self.build_kb_button.click(
                fn=self._handle_build_kb,
                inputs=[self.file_uploader],
                outputs=[self.status_box, self.chat_area]
            )

            self.submit_btn.click(
                fn=self._handle_chat_submission,
                inputs=[self.question_box, self.chatbot, self.mode_selector],
                outputs=[self.chatbot, self.question_box, self.source_display]
            )

            self.question_box.submit(
                fn=self._handle_chat_submission,
                inputs=[self.question_box, self.chatbot, self.mode_selector],
                outputs=[self.chatbot, self.question_box, self.source_display]
            )

            self.clear_btn.click(
                fn=self._clear_chat,
                inputs=None,
                outputs=[self.chatbot, self.question_box, self.source_display]
            )

    def _handle_load_kb(self):
        """Handle loading an existing knowledge base. Returns a dict of component updates."""
        success, message = self.rag_service.load_knowledge_base()
        if success:
            return {
                self.status_box: gr.update(value=message),
                self.chat_area: gr.update(visible=True)
            }
        else:
            return {
                self.status_box: gr.update(value=message),
                self.chat_area: gr.update(visible=False)
            }

    def _handle_build_kb(self, files: List[str], progress=gr.Progress(track_tqdm=True)):
        """Build a new knowledge base. Returns a dict of component updates."""
        if not files:
            # --- MODIFIED LINE ---
            return {
                self.status_box: gr.update(value="错误:请至少上传一个文档。"),
                self.chat_area: gr.update(visible=False)
            }

        file_paths = [file.name for file in files]

        try:
            for status in self.rag_service.build_knowledge_base(file_paths):
                progress(0.5, desc=status)

            final_status = "知识库构建完成并已就绪!√"
            # --- MODIFIED LINE ---
            return {
                self.status_box: gr.update(value=final_status),
                self.chat_area: gr.update(visible=True)
            }
        except Exception as e:
            error_message = f"构建失败: {e}"
            # --- MODIFIED LINE ---
            return {
                self.status_box: gr.update(value=error_message),
                self.chat_area: gr.update(visible=False)
            }

    def _handle_chat_submission(self, question: str, history: List[Tuple[str, str]], mode: str):
        if not question or not question.strip():
            yield history, "", "### 引用来源 (Sources)\n"
            return

        history.append((question, ""))

        try:
            # Full (non-streaming) output
            if "Full" in mode:
                yield history, "", "### 引用来源 (Sources)\n"

                answer, sources = self.rag_service.get_response_full(question)
                # Build the citation context
                context_string_for_display = self.rag_service.get_context_string(sources)
                # Format the sources panel
                source_text_for_panel = self._format_sources(sources)
                # Full message: citations + answer
                full_response = f"{context_string_for_display}\n\n---\n\n**回答 (Answer):**\n{answer}"
                history[-1] = (question, full_response)

                yield history, "", source_text_for_panel

            # Streaming output
            else:
                answer_generator, sources = self.rag_service.get_response_stream(question)

                context_string_for_display = self.rag_service.get_context_string(sources)

                source_text_for_panel = self._format_sources(sources)

                yield history, "", source_text_for_panel

                response_prefix = f"{context_string_for_display}\n\n---\n\n**回答 (Answer):**\n"
                history[-1] = (question, response_prefix)
                yield history, "", source_text_for_panel

                answer_log = ""
                for text_chunk in answer_generator:
                    answer_log += text_chunk
                    history[-1] = (question, response_prefix + answer_log)
                    yield history, "", source_text_for_panel

        except Exception as e:
            error_response = f"处理请求时出错: {e}"
            history[-1] = (question, error_response)
            yield history, "", "### 引用来源 (Sources)\n"

    def _format_sources(self, sources: List) -> str:
        source_text = "### 引用来源 (Sources)\n"
        if not sources:
            return source_text

        unique_sources = set()
        for doc in sources:
            source_name = os.path.basename(doc.metadata.get('source', 'Unknown'))
            page_num = doc.metadata.get('page', 'N/A')
            unique_sources.add(f"- **{source_name}** (Page: {page_num})")

        source_text += "\n".join(sorted(list(unique_sources)))
        return source_text

    def _clear_chat(self):
        """Clear the chat history."""
        return None, "", "### 引用来源 (Sources)\n"

    def launch(self):
        self.demo.queue().launch()
Mini_RAG/utils/__init__.py
ADDED
File without changes
Mini_RAG/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (124 Bytes).
Mini_RAG/utils/__pycache__/logger.cpython-39.pyc
ADDED
Binary file (489 Bytes).
Mini_RAG/utils/logger.py
ADDED
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/25 19:51
# @Author : hukangzhe
# @File : logger.py
# @Description : Logging configuration module
import logging
import sys


def setup_logger():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s",
        handlers=[
            logging.FileHandler("storage/logs/app.log"),
            logging.StreamHandler(sys.stdout)
        ]
    )
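One caveat: logging.FileHandler does not create missing directories, so storage/logs must already exist before setup_logger() runs. A defensive variant, shown only as a sketch rather than the shipped code, would create the directory first:

# Sketch: same logger setup, but creating the log directory before attaching the file handler.
import logging
import os
import sys

def setup_logger(log_path: str = "storage/logs/app.log"):
    os.makedirs(os.path.dirname(log_path), exist_ok=True)  # avoid FileNotFoundError on first run
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s",
        handlers=[
            logging.FileHandler(log_path),
            logging.StreamHandler(sys.stdout)
        ]
    )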