#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2025/4/25 19:52
# @Author  : hukangzhe
# @File    : splitter.py
# @Description : Module responsible for splitting text into chunks

import logging
from typing import List, Dict, Tuple
from .schema import Document, Chunk


class SemanticRecursiveSplitter:
    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50, separators: List[str] = None):
        """
        A semantic text splitter that performs genuinely recursive splitting.

        :param chunk_size: Target size of each text chunk.
        :param chunk_overlap: Number of overlapping characters between adjacent chunks.
        :param separators: Separators to split on, ordered from highest to lowest priority.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        if self.chunk_size <= self.chunk_overlap:
            raise ValueError("Chunk overlap must be smaller than chunk size.")

        self.separators = separators if separators else ['\n\n', '\n', " ", ""]  # default separators

    def text_split(self, text: str) -> List[str]:
        """
        Entry point for splitting a text into chunks.

        :param text: The raw text to split.
        :return: The list of text chunks.
        """
        logging.info("Starting semantic recursive splitting...")
        final_chunks = self._split(text, self.separators)
        # Apply the overlap once, after the recursion has produced the final chunks.
        if self.chunk_overlap > 0:
            final_chunks = self._handle_overlap(final_chunks)
        logging.info(f"Text successfully split into {len(final_chunks)} chunks.")
        return final_chunks

    def _split(self, text: str, separators: List[str]) -> List[str]:
        final_chunks = []
        # 1. If the text is already small enough, return it as a single chunk.
        if len(text) <= self.chunk_size:
            return [text]
        # 2. If no separators are left, or only the empty-string separator remains,
        #    fall back to hard slicing at chunk_size.
        if not separators or separators[0] == "":
            return [text[i:i + self.chunk_size] for i in range(0, len(text), self.chunk_size)]

        # 3. Try the highest-priority separator first.
        cur_separator = separators[0]
        if cur_separator in text:
            # Split into smaller parts.
            parts = text.split(cur_separator)

            buffer = ""  # accumulates small parts until they approach chunk_size
            for part in parts:
                # While the buffer stays within chunk_size, keep appending parts to it.
                if len(buffer) + len(part) + len(cur_separator) <= self.chunk_size:
                    buffer += part + cur_separator
                else:
                    # Flush the buffer before handling the current part.
                    if buffer:
                        final_chunks.append(buffer)
                        buffer = ""
                    # If the current part alone already exceeds chunk_size,
                    # recurse with the next, lower-priority separators.
                    if len(part) > self.chunk_size:
                        sub_chunks = self._split(part, separators[1:])
                        final_chunks.extend(sub_chunks)
                    else:  # otherwise it becomes the start of a new buffer
                        buffer = part + cur_separator

            if buffer:  # flush whatever remains in the buffer
                final_chunks.append(buffer.strip())

        else:
            # 4. Separator not present in the text: fall back to the next one.
            final_chunks = self._split(text, separators[1:])

        return final_chunks

    def _handle_overlap(self, final_chunks: List[str]) -> List[str]:
        overlap_chunks = []
        if not final_chunks:
            return []
        overlap_chunks.append(final_chunks[0])
        for i in range(1, len(final_chunks)):
            pre_chunk = final_chunks[i - 1]
            cur_chunk = final_chunks[i]
            # Take the overlap from the tail of the previous chunk and prepend it to the current chunk.
            overlap_part = pre_chunk[-self.chunk_overlap:]
            overlap_chunks.append(overlap_part + cur_chunk)

        return overlap_chunks
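

# Usage sketch (illustrative only; `long_text` stands in for any string to split):
#
#   splitter = SemanticRecursiveSplitter(chunk_size=500, chunk_overlap=50)
#   chunks = splitter.text_split(long_text)
#
# With chunk_overlap=3, adjacent chunks ["abcdef", "ghijkl"] come back as
# ["abcdef", "defghijkl"]: each chunk after the first is prefixed with the last
# chunk_overlap characters of the previous chunk.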


class HierarchicalSemanticSplitter:
    """

    结合了层次化(父/子)和递归语义分割策略。确保在创建父块和子块时,遵循文本的自然语义边界。

    """
    def __init__(self,

                 parent_chunk_size: int = 800,

                 parent_chunk_overlap: int = 100,

                 child_chunk_size: int = 250,

                 separators: List[str] = None):
        if parent_chunk_overlap >= parent_chunk_size:
            raise ValueError("Parent chunk overlap must be smaller than parent chunk size.")
        if child_chunk_size >= parent_chunk_size:
            raise ValueError("Child chunk size must be smaller than parent chunk size.")

        self.parent_chunk_size = parent_chunk_size
        self.parent_chunk_overlap = parent_chunk_overlap
        self.child_chunk_size = child_chunk_size
        self.separators = separators or ["\n\n", "\n", "。", ". ", "!", "!", "?", "?", " ", ""]

    def _recursive_semantic_split(self, text: str, chunk_size: int) -> List[str]:
        """
        Splits text into pieces of at most roughly chunk_size, preferring semantic boundaries.
        """
        if len(text) <= chunk_size:
            return [text]

        # Find the right-most occurrence of the highest-priority separator
        # within the first chunk_size characters.
        for sep in self.separators:
            if sep == "":
                # Empty separator means "cut anywhere": split exactly at chunk_size.
                split_point = chunk_size
                break
            found = text.rfind(sep, 0, chunk_size)
            if found > 0:
                # Keep the separator with the first chunk to preserve context.
                split_point = found + len(sep)
                break
        else:
            split_point = chunk_size

        chunk1 = text[:split_point]
        remaining_text = text[split_point:].lstrip()  # strip leading whitespace from the remainder

        # Recursively split the remaining text.
        if remaining_text:
            return [chunk1] + self._recursive_semantic_split(remaining_text, chunk_size)
        else:
            return [chunk1]

    def _apply_overlap(self, chunks: List[str], overlap: int) -> List[str]:
        """Applies overlap between adjacent chunks."""
        if not overlap or len(chunks) <= 1:
            return chunks

        overlapped_chunks = [chunks[0]]
        for i in range(1, len(chunks)):
            # Take the last `overlap` characters of the previous chunk.
            overlap_content = chunks[i - 1][-overlap:]
            overlapped_chunks.append(overlap_content + chunks[i])

        return overlapped_chunks

    def split_documents(self, documents: List[Document]) -> Tuple[Dict[int, Document], List[Chunk]]:
        """
        Two-pass splitting.

        :param documents: Documents to split.
        :return:
            - parent documents: {parent_id: Document}
            - child chunks: [Chunk, Chunk, ...]
        """
        parent_docs_dict: Dict[int, Document] = {}
        child_chunks_list: List[Chunk] = []
        parent_id_counter = 0

        logging.info("Starting robust hierarchical semantic splitting...")

        for doc in documents:
            # === PASS 1: Create parent chunks ===
            # 1. Split the whole document text into large semantic chunks.
            initial_parent_chunks = self._recursive_semantic_split(doc.text, self.parent_chunk_size)

            # 2. Apply overlap to the parent chunks.
            overlapped_parent_texts = self._apply_overlap(initial_parent_chunks, self.parent_chunk_overlap)

            for p_text in overlapped_parent_texts:
                parent_doc = Document(text=p_text, metadata=doc.metadata.copy())
                parent_docs_dict[parent_id_counter] = parent_doc

                # === PASS 2: Create Child Chunks from each Parent ===
                child_texts = self._recursive_semantic_split(p_text, self.child_chunk_size)

                for c_text in child_texts:
                    child_metadata = doc.metadata.copy()
                    child_metadata['parent_id'] = parent_id_counter
                    child_chunk = Chunk(text=c_text, metadata=child_metadata, parent_id=parent_id_counter)
                    child_chunks_list.append(child_chunk)

                parent_id_counter += 1

        logging.info(
            f"Splitting complete. Created {len(parent_docs_dict)} parent chunks and {len(child_chunks_list)} child chunks.")
        return parent_docs_dict, child_chunks_list
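

# Minimal usage sketch (illustrative, not part of the module's original API surface).
# It assumes Document(text=..., metadata=...) and the Chunk attributes .text and
# .parent_id behave as implied by their use above, and that this file is run as a
# module (e.g. `python -m <package>.splitter`) so the relative import resolves.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    sample = Document(
        text=("Retrieval systems work best with coherent chunks. "
              "This sentence pads the document so it exceeds the parent chunk size.\n\n") * 20,
        metadata={"source": "example.txt"},
    )

    splitter = HierarchicalSemanticSplitter(
        parent_chunk_size=800,
        parent_chunk_overlap=100,
        child_chunk_size=250,
    )
    parent_docs, child_chunks = splitter.split_documents([sample])

    print(f"{len(parent_docs)} parent chunks, {len(child_chunks)} child chunks")
    for chunk in child_chunks[:3]:
        print(chunk.parent_id, repr(chunk.text[:40]))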