|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
from typing import List, Dict, Tuple
|
|
|
from .schema import Document, Chunk
|
|
|
|
|
|
|
|
|
class SemanticRecursiveSplitter:
    """Recursively split text on semantic separators into size-bounded chunks."""

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50,
                 separators: List[str] = None):
        """
        A truly recursive semantic text splitter.

        :param chunk_size: target size (in characters) of each chunk.
        :param chunk_overlap: number of overlapping characters between chunks.
        :param separators: separators to split on, highest priority first.
            The empty string ``""`` is a sentinel meaning "hard cut at
            chunk_size".
        :raises ValueError: if chunk_overlap >= chunk_size.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        if self.chunk_size <= self.chunk_overlap:
            raise ValueError("Chunk overlap must be smaller than chunk size.")

        self.separators = separators if separators else ['\n\n', '\n', " ", ""]

    def text_split(self, text: str) -> List[str]:
        """
        Split entry point.

        Overlap is applied exactly once, here, over the final chunk list.
        (The original applied it inside every recursion level, compounding
        the overlap.)

        :param text: the text to split.
        :return: list of chunks.
        """
        logging.info("Starting semantic recursive splitting...")
        final_chunks = self._split(text, self.separators)
        if self.chunk_overlap > 0:
            final_chunks = self._handle_overlap(final_chunks)
        logging.info(f"Text successfully split into {len(final_chunks)} chunks.")
        return final_chunks

    def _hard_cut(self, text: str) -> List[str]:
        """Last resort: slice into fixed-size pieces with no semantic boundary."""
        return [text[i:i + self.chunk_size]
                for i in range(0, len(text), self.chunk_size)]

    def _split(self, text: str, separators: List[str]) -> List[str]:
        """
        Recursive worker: split on separators[0]; recurse with the remaining
        separators on any part that is still too large.
        """
        if len(text) <= self.chunk_size:
            return [text] if text else []

        # Out of separators, or reached the "" sentinel: hard cut.
        # Bug fix: the original crashed here — text.split("") raises
        # ValueError, and an empty separator list raised IndexError.
        if not separators or separators[0] == "":
            return self._hard_cut(text)

        cur_separator = separators[0]
        if cur_separator not in text:
            return self._split(text, separators[1:])

        final_chunks: List[str] = []
        buffer = ""
        for part in text.split(cur_separator):
            piece = part + cur_separator
            if len(buffer) + len(piece) <= self.chunk_size:
                buffer += piece
            else:
                if buffer:
                    final_chunks.append(buffer.strip())
                    # Bug fix: the original never cleared the flushed buffer
                    # on the oversized-part path, duplicating its text.
                    buffer = ""
                if len(part) > self.chunk_size:
                    # Part alone exceeds the budget: recurse with the
                    # lower-priority separators.
                    final_chunks.extend(self._split(part, separators[1:]))
                else:
                    buffer = piece
        if buffer:
            final_chunks.append(buffer.strip())

        return final_chunks

    def _handle_overlap(self, final_chunks: List[str]) -> List[str]:
        """
        Prefix each chunk with the tail of its ORIGINAL predecessor.

        Bug fix: the original read the tail from the already-overlapped
        output chunk, so overlap accumulated chunk after chunk.
        """
        if not final_chunks:
            return []
        overlap_chunks = [final_chunks[0]]
        for i in range(1, len(final_chunks)):
            overlap_part = final_chunks[i - 1][-self.chunk_overlap:]
            overlap_chunks.append(overlap_part + final_chunks[i])
        return overlap_chunks
|
|
|
|
|
|
|
|
|
class HierarchicalSemanticSplitter:
    """
    Combines hierarchical (parent/child) and recursive semantic splitting,
    ensuring both parent and child chunks follow the text's natural semantic
    boundaries.
    """

    def __init__(self,
                 parent_chunk_size: int = 800,
                 parent_chunk_overlap: int = 100,
                 child_chunk_size: int = 250,
                 separators: List[str] = None):
        """
        :param parent_chunk_size: target size of the large (parent) chunks.
        :param parent_chunk_overlap: overlap between consecutive parent chunks.
        :param child_chunk_size: target size of the small (child) chunks.
        :param separators: boundaries to prefer, highest priority first;
            ``""`` acts as a hard-cut sentinel.
        :raises ValueError: if the size parameters are inconsistent.
        """
        if parent_chunk_overlap >= parent_chunk_size:
            raise ValueError("Parent chunk overlap must be smaller than parent chunk size.")
        if child_chunk_size >= parent_chunk_size:
            raise ValueError("Child chunk size must be smaller than parent chunk size.")

        self.parent_chunk_size = parent_chunk_size
        self.parent_chunk_overlap = parent_chunk_overlap
        self.child_chunk_size = child_chunk_size
        self.separators = separators or ["\n\n", "\n", "。", ". ", "!", "!", "?", "?", " ", ""]

    def _recursive_semantic_split(self, text: str, chunk_size: int) -> List[str]:
        """
        Split *text* into pieces of at most *chunk_size* characters,
        preferring the highest-priority separator found within the window.

        The separator stays attached to the end of the left piece, so no
        boundary characters are lost (the original silently dropped
        paragraph-break separators). A cut is only accepted at index > 0,
        fixing an infinite recursion when a non-whitespace separator's only
        occurrence was at position 0.
        """
        if len(text) <= chunk_size:
            return [text]

        split_point = -1
        for sep in self.separators:
            if not sep:
                continue  # "" is the hard-cut sentinel, handled below
            # rfind bounds the whole match inside the window, so
            # idx + len(sep) <= chunk_size is guaranteed.
            idx = text.rfind(sep, 0, chunk_size)
            if idx > 0:
                split_point = idx + len(sep)  # keep separator with left piece
                break
        if split_point <= 0:
            split_point = chunk_size  # no usable boundary: hard cut

        head = text[:split_point]
        remaining_text = text[split_point:].lstrip()

        if remaining_text:
            return [head] + self._recursive_semantic_split(remaining_text, chunk_size)
        return [head]

    def _apply_overlap(self, chunks: List[str], overlap: int) -> List[str]:
        """Prefix each chunk (after the first) with the tail of its predecessor."""
        if not overlap or len(chunks) <= 1:
            return chunks

        overlapped_chunks = [chunks[0]]
        for i in range(1, len(chunks)):
            overlap_content = chunks[i - 1][-overlap:]
            overlapped_chunks.append(overlap_content + chunks[i])

        return overlapped_chunks

    def split_documents(self, documents: List["Document"]) -> Tuple[Dict[int, "Document"], List["Chunk"]]:
        """
        Two-pass split.

        Pass 1 cuts each document into overlapping parent chunks; pass 2 cuts
        each parent into child chunks that reference their parent.

        :param documents: source documents to split.
        :return:
            - parent documents: {parent_id: Document}
            - child chunks: [Chunk, Chunk, ...]
        """
        parent_docs_dict: Dict[int, "Document"] = {}
        child_chunks_list: List["Chunk"] = []
        parent_id_counter = 0

        logging.info("Starting robust hierarchical semantic splitting...")

        for doc in documents:
            # Pass 1: semantic parent chunks, then apply parent overlap.
            initial_parent_chunks = self._recursive_semantic_split(doc.text, self.parent_chunk_size)
            overlapped_parent_texts = self._apply_overlap(initial_parent_chunks, self.parent_chunk_overlap)

            for p_text in overlapped_parent_texts:
                parent_docs_dict[parent_id_counter] = Document(text=p_text, metadata=doc.metadata.copy())

                # Pass 2: child chunks within this parent, tagged with its id.
                for c_text in self._recursive_semantic_split(p_text, self.child_chunk_size):
                    child_metadata = doc.metadata.copy()
                    child_metadata['parent_id'] = parent_id_counter
                    child_chunks_list.append(
                        Chunk(text=c_text, metadata=child_metadata, parent_id=parent_id_counter))

                parent_id_counter += 1

        logging.info(
            f"Splitting complete. Created {len(parent_docs_dict)} parent chunks and {len(child_chunks_list)} child chunks.")
        return parent_docs_dict, child_chunks_list
|
|
|
|