Mini-RAG / Mini_RAG /core /splitter.py
TuNan52's picture
Upload 88 files
c69a4d6 verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/25 19:52
# @Author : hukangzhe
# @File : splitter.py
# @Description : 负责切分文本的模块
import logging
from typing import List, Dict, Tuple
from .schema import Document, Chunk
class SemanticRecursiveSplitter:
    """Recursively split text along a prioritized list of semantic separators.

    The text is split on the highest-priority separator that occurs in it and
    the resulting parts are greedily merged into chunks of about
    ``chunk_size`` characters. A part that is still too large is split again
    with the next separator; when no usable separator is left, the text is cut
    into fixed-size slices. Overlap between neighbouring chunks is applied
    exactly once, after the whole text has been split.
    """

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50, separators: List[str] = None):
        """
        Configure the splitter.

        :param chunk_size: target size (in characters) of each chunk.
        :param chunk_overlap: number of characters shared by adjacent chunks.
        :param separators: separators ordered from highest to lowest priority;
            an empty string means "cut at a fixed size".
        :raises ValueError: if ``chunk_size`` is not positive or is not larger
            than ``chunk_overlap``.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        if self.chunk_size <= self.chunk_overlap:
            raise ValueError("Chunk overlap must be smaller than chunk size.")
        if self.chunk_size <= 0:
            # A non-positive size would make the fixed-size fallback impossible.
            raise ValueError("Chunk size must be positive.")
        self.separators = separators if separators else ['\n\n', '\n', " ", ""]  # default separators

    def text_split(self, text: str) -> List[str]:
        """
        Entry point: split ``text`` and add the configured overlap.

        :param text: the text to split.
        :return: the list of chunks, in document order.
        """
        logging.info("Starting semantic recursive splitting...")
        final_chunks = self._split(text, self.separators)
        # Overlap is added here, once, so that nested recursion levels do not
        # stack several overlap prefixes onto the same chunk.
        if self.chunk_overlap > 0:
            final_chunks = self._handle_overlap(final_chunks)
        logging.info("Text successfully split into %d chunks.", len(final_chunks))
        return final_chunks

    def _split(self, text: str, separators: List[str]) -> List[str]:
        """
        Split ``text`` into chunks without overlap.

        :param text: the text to split.
        :param separators: separators still available, highest priority first.
        :return: chunks in order. A chunk may exceed ``chunk_size`` by at most
            the length of the separator that was re-attached to it.
        """
        # 1. Small enough already: nothing to do.
        if len(text) <= self.chunk_size:
            return [text]

        # 2. No usable separator left (list exhausted, or the "" sentinel):
        #    fall back to fixed-size slices. str.split("") would raise.
        if not separators or separators[0] == "":
            size = self.chunk_size
            return [text[start:start + size] for start in range(0, len(text), size)]

        cur_separator = separators[0]

        # 3. Separator absent from this text: try the next one.
        if cur_separator not in text:
            return self._split(text, separators[1:])

        # 4. Split on the separator and greedily merge the parts.
        final_chunks: List[str] = []
        buffer = ""  # accumulates small parts until it approaches chunk_size
        for part in text.split(cur_separator):
            piece = part + cur_separator
            if len(buffer) + len(piece) <= self.chunk_size:
                buffer += piece
                continue

            # The buffer is full: emit it and reset it so that its content
            # cannot be emitted a second time later on.
            if buffer:
                final_chunks.append(buffer)
                buffer = ""

            if len(part) > self.chunk_size:
                # This part alone is too large: recurse with finer separators.
                final_chunks.extend(self._split(part, separators[1:]))
            else:
                # The part starts a new buffer.
                buffer = piece

        if buffer:  # whatever is left after the last part
            final_chunks.append(buffer.strip())
        return final_chunks

    def _handle_overlap(self, final_chunks: List[str]) -> List[str]:
        """
        Prefix every chunk but the first with the tail of the chunk before it.

        :param final_chunks: chunks without overlap, in order.
        :return: a new list in which chunk ``i`` (``i > 0``) starts with the
            last ``chunk_overlap`` characters of the original chunk ``i - 1``.
        """
        if not final_chunks:
            return []
        if self.chunk_overlap <= 0:
            # s[-0:] is the whole string, so guard against a zero overlap.
            return list(final_chunks)
        overlap_chunks = [final_chunks[0]]
        for prev_chunk, cur_chunk in zip(final_chunks, final_chunks[1:]):
            # Take the tail from the original previous chunk, not from the
            # already-prefixed one.
            overlap_chunks.append(prev_chunk[-self.chunk_overlap:] + cur_chunk)
        return overlap_chunks
class HierarchicalSemanticSplitter:
    """
    Two-level (parent / child) splitter that cuts on semantic boundaries.

    Each document is first cut into large parent chunks, which receive an
    overlap prefix from their predecessor; every parent chunk is then cut into
    smaller child chunks that remember the id of their parent.
    """

    def __init__(self,
                 parent_chunk_size: int = 800,
                 parent_chunk_overlap: int = 100,
                 child_chunk_size: int = 250,
                 separators: List[str] = None):
        """
        Configure the splitter.

        :param parent_chunk_size: maximum size of a parent chunk before overlap.
        :param parent_chunk_overlap: characters copied from the previous parent.
        :param child_chunk_size: maximum size of a child chunk.
        :param separators: separators ordered from highest to lowest priority;
            an empty string means "cut at a fixed size".
        :raises ValueError: if the sizes are inconsistent or not positive.
        """
        if parent_chunk_overlap >= parent_chunk_size:
            raise ValueError("Parent chunk overlap must be smaller than parent chunk size.")
        if child_chunk_size >= parent_chunk_size:
            raise ValueError("Child chunk size must be smaller than parent chunk size.")
        if child_chunk_size <= 0:
            raise ValueError("Child chunk size must be positive.")

        self.parent_chunk_size = parent_chunk_size
        self.parent_chunk_overlap = parent_chunk_overlap
        self.child_chunk_size = child_chunk_size
        # "\uff01" and "\uff1f" are the full-width "!" and "?". They replace
        # duplicated ASCII entries, which could never match a second time.
        self.separators = separators or [
            "\n\n", "\n", "。", ". ", "\uff01", "!", "\uff1f", "?", " ", "",
        ]

    def _recursive_semantic_split(self, text: str, chunk_size: int) -> List[str]:
        """
        Cut ``text`` into pieces of at most ``chunk_size`` characters,
        preferring semantic boundaries.

        For each piece the separators are tried in priority order and the
        right-most occurrence inside the first ``chunk_size`` characters is
        used; the separator stays at the end of the piece it terminates. When
        no separator matches, the piece is cut at exactly ``chunk_size``.
        Leading whitespace of the remaining text is dropped after every cut.

        The work is done in a loop rather than by recursion so that long
        documents cannot exhaust the interpreter's recursion limit.

        :param text: the text to split.
        :param chunk_size: maximum length of each returned piece.
        :return: the pieces in order; ``[text]`` if it already fits.
        :raises ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive.")

        chunks: List[str] = []
        remaining = text
        while len(remaining) > chunk_size:
            split_point = chunk_size  # hard cut unless a separator is found
            for sep in self.separators:
                if not sep:
                    # The empty separator explicitly requests the hard cut.
                    break
                found_at = remaining.rfind(sep, 0, chunk_size)
                if found_at != -1:
                    # Cut after the separator: it belongs to the text it ends,
                    # and this guarantees that at least one character is
                    # consumed even when the separator sits at index 0.
                    split_point = found_at + len(sep)
                    break
            chunks.append(remaining[:split_point])
            remaining = remaining[split_point:].lstrip()  # drop leading whitespace

        # Keep the tail; an input that never needed splitting is returned
        # unchanged (even when empty).
        if remaining or not chunks:
            chunks.append(remaining)
        return chunks

    def _apply_overlap(self, chunks: List[str], overlap: int) -> List[str]:
        """
        Prefix every chunk but the first with the tail of its predecessor.

        :param chunks: chunks without overlap, in order.
        :param overlap: number of trailing characters to copy.
        :return: ``chunks`` itself when there is nothing to do, otherwise a
            new list with the overlap applied.
        """
        if not overlap or len(chunks) <= 1:
            return chunks
        overlapped_chunks = [chunks[0]]
        for previous, current in zip(chunks, chunks[1:]):
            # The tail is taken from the original previous chunk.
            overlapped_chunks.append(previous[-overlap:] + current)
        return overlapped_chunks

    def split_documents(self, documents: List["Document"]) -> Tuple[Dict[int, "Document"], List["Chunk"]]:
        """
        Split documents in two passes.

        :param documents: the documents to split; each must expose ``text``
            and a ``metadata`` object with a ``copy()`` method.
        :return: a pair of
            - parent documents: {parent_id: Document}
            - child chunks: [Chunk, Chunk, ...]
        """
        parent_docs_dict: Dict[int, "Document"] = {}
        child_chunks_list: List["Chunk"] = []
        parent_id_counter = 0

        logging.info("Starting robust hierarchical semantic splitting...")

        for doc in documents:
            # === PASS 1: create parent chunks ===
            # 1. Cut the whole document text into large semantic chunks.
            initial_parent_chunks = self._recursive_semantic_split(doc.text, self.parent_chunk_size)
            # 2. Add overlap between consecutive parent chunks.
            overlapped_parent_texts = self._apply_overlap(initial_parent_chunks, self.parent_chunk_overlap)

            for p_text in overlapped_parent_texts:
                parent_docs_dict[parent_id_counter] = Document(text=p_text, metadata=doc.metadata.copy())

                # === PASS 2: create child chunks from each parent ===
                for c_text in self._recursive_semantic_split(p_text, self.child_chunk_size):
                    child_metadata = doc.metadata.copy()
                    child_metadata['parent_id'] = parent_id_counter
                    child_chunks_list.append(
                        Chunk(text=c_text, metadata=child_metadata, parent_id=parent_id_counter)
                    )

                # Parent ids are unique across all documents.
                parent_id_counter += 1

        logging.info(
            f"Splitting complete. Created {len(parent_docs_dict)} parent chunks and {len(child_chunks_list)} child chunks.")
        return parent_docs_dict, child_chunks_list