#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2025/4/29 19:54
# @Author  : hukangzhe
# @File    : generator.py
# @Description : Answer generation module
import os
import queue
import logging
import threading

import torch
from typing import Dict, List, Tuple, Generator
from sentence_transformers import CrossEncoder
from .schema import Document, Chunk
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer

class ThinkStreamer(TextStreamer):
    """Streamer that separates the model's <think> reasoning from the final answer."""

    def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool = True, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.is_thinking = True
        # Token id that marks the end of the thinking segment.
        self.think_end_token_id = self.tokenizer.encode("</think>", add_special_tokens=False)[0]
        self.output_queue = queue.Queue()

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.output_queue.put(text)
        if stream_end:
            self.output_queue.put(None)  # end-of-stream sentinel

    def __iter__(self):
        return self

    def __next__(self):
        value = self.output_queue.get()
        if value is None:
            raise StopIteration()
        return value

    def generate_output(self) -> Generator[Tuple[str, str], None, None]:
        """Yield ("thinking", text) and ("answer", text) tuples as generation streams in."""
        full_decode_text = ""
        already_yielded_len = 0
        for text_chunk in self:
            if not self.is_thinking:
                yield "answer", text_chunk
                continue

            full_decode_text += text_chunk
            tokens = self.tokenizer.encode(full_decode_text, add_special_tokens=False)

            if self.think_end_token_id in tokens:
                split_point = tokens.index(self.think_end_token_id)
                think_part_tokens = tokens[:split_point]
                thinking_text = self.tokenizer.decode(think_part_tokens)

                # Skip the </think> token itself so it does not leak into the answer.
                answer_part_tokens = tokens[split_point + 1:]
                answer_text = self.tokenizer.decode(answer_part_tokens)
                remaining_thinking = thinking_text[already_yielded_len:]
                if remaining_thinking:
                    yield "thinking", remaining_thinking

                if answer_text:
                    yield "answer", answer_text

                self.is_thinking = False
            else:
                yield "thinking", text_chunk
                already_yielded_len += len(text_chunk)
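
# Minimal consumption sketch for ThinkStreamer (hypothetical `model`/`tokenizer`/`inputs`
# names; any checkpoint that emits a "</think>" marker, e.g. a Qwen3-style chat model,
# should work). Generation runs in a background thread while the caller drains labelled
# (segment, text) tuples:
#
#   streamer = ThinkStreamer(tokenizer, skip_prompt=True)
#   threading.Thread(target=model.generate,
#                    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=256)).start()
#   for segment, text in streamer.generate_output():
#       print(f"[{segment}] {text}", end="", flush=True)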


class QueueTextStreamer(TextStreamer):
    def __init__(self, tokenizer: AutoTokenizer, skip_prompt: bool = True, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.output_queue = queue.Queue()

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Puts text into the queue; sends None as a sentinel value to signal the end."""
        self.output_queue.put(text)
        if stream_end:
            self.output_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.output_queue.get()
        if value is None:
            raise StopIteration()
        return value


class LLMInterface:
    def __init__(self, config: dict):
        self.config = config
        self.reranker = CrossEncoder(config['models']['reranker'])
        self.generator_new_tokens = config['generation']['max_new_tokens']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        generator_name = config['models']['llm_generator']
        logging.info(f"Initializing generator {generator_name}")
        self.generator_tokenizer = AutoTokenizer.from_pretrained(generator_name)
        self.generator_model = AutoModelForCausalLM.from_pretrained(
            generator_name,
            torch_dtype="auto",
            device_map="auto")

    def rerank(self, query: str, docs: List[Document]) -> List[Document]:
        pairs = [[query, doc.text] for doc in docs]
        scores = self.reranker.predict(pairs)
        ranked_docs = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
        return [doc for doc, score in ranked_docs]

    def _threaded_generate(self, streamer: QueueTextStreamer, generation_kwargs: dict):
        """Wrapper that runs model.generate inside a try...finally block."""
        try:
            self.generator_model.generate(**generation_kwargs)
        finally:
            # Always push the end-of-stream sentinel, whether generation succeeded or failed.
            streamer.output_queue.put(None)

    def generate_answer(self, query: str, context_docs: List[Document]) -> str:
        context_str = ""
        for doc in context_docs:
            context_str += f"Source: {os.path.basename(doc.metadata.get('source', ''))}, Page: {doc.metadata.get('page', 'N/A')}\n"
            context_str += f"Content: {doc.text}\n\n"
        # The prompts below are written in Chinese, so the model answers in Chinese; set the content to English for English answers.
        messages = [
            {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
            {"role": "user", "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题:{query}"}
        ]
        prompt = self.generator_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(self.device)
        output = self.generator_model.generate(
            **inputs,
            max_new_tokens=self.generator_new_tokens,
            num_return_sequences=1,
            eos_token_id=self.generator_tokenizer.eos_token_id,
        )
        generated_ids = output[0][inputs["input_ids"].shape[1]:]
        answer = self.generator_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        return answer

    def generate_answer_stream(self, query: str, context_docs: List[Document]) -> Generator[str, None, None]:
        context_str = ""
        for doc in context_docs:
            context_str += f"Content: {doc.text}\n\n"

        messages = [
            {"role": "system",
             "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
            {"role": "user",
             "content": f"上下文:\n---\n{context_str}\n---\n请根据以上上下文回答这个问题: {query}"}
        ]

        prompt = self.generator_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device)

        streamer = QueueTextStreamer(self.generator_tokenizer, skip_prompt=True)

        generation_kwargs = dict(
            **model_inputs,
            max_new_tokens=self.generator_new_tokens,
            streamer=streamer,
            pad_token_id=self.generator_tokenizer.eos_token_id,
        )

        thread = threading.Thread(target=self._threaded_generate, args=(streamer,generation_kwargs,))
        thread.start()
        # Drain the queue-backed streamer; iteration stops at the None sentinel.
        for new_text in streamer:
            yield new_text

    def generate_answer_stream_split(self, query: str, context_docs: List[Document]) -> Generator[Tuple[str, str], None, None]:
        """分离思考和回答的流式输出"""
        context_str = ""
        for doc in context_docs:
            context_str += f"Content: {doc.text}\n\n"

        messages = [
            {"role": "system",
             "content": "You are a helpful assistant. Please answer the question based on the provided context. First, think through the process in <think> tags, then provide the final answer."},
            {"role": "user",
             "content": f"Context:\n---\n{context_str}\n---\nBased on the context above, please answer the question: {query}"}
        ]

        prompt = self.generator_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True
        )
        model_inputs = self.generator_tokenizer([prompt], return_tensors="pt").to(self.device)

        streamer = ThinkStreamer(self.generator_tokenizer, skip_prompt=True)

        generation_kwargs = dict(
            **model_inputs,
            max_new_tokens=self.generator_new_tokens,
            streamer=streamer
        )

        # Run generation in a background thread via the wrapper, which guarantees the end
        # sentinel is sent even if generate() raises, so the consumer below cannot hang.
        thread = threading.Thread(target=self._threaded_generate, args=(streamer, generation_kwargs))
        thread.start()

        yield from streamer.generate_output()
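

if __name__ == "__main__":
    # Minimal usage sketch. The config keys mirror what LLMInterface reads above; the
    # checkpoint names and the Document(text=..., metadata=...) constructor are assumptions
    # for illustration and should be replaced with the project's real config and schema.
    logging.basicConfig(level=logging.INFO)
    demo_config = {
        "models": {
            "reranker": "cross-encoder/ms-marco-MiniLM-L-6-v2",  # assumed reranker checkpoint
            "llm_generator": "Qwen/Qwen3-0.6B",                  # assumed generator checkpoint
        },
        "generation": {"max_new_tokens": 256},
    }
    llm = LLMInterface(demo_config)
    docs = [Document(text="RAG first retrieves relevant chunks, then generates an answer.",
                     metadata={"source": "demo.pdf", "page": 1})]
    ranked = llm.rerank("What does RAG do?", docs)
    for piece in llm.generate_answer_stream("What does RAG do?", ranked):
        print(piece, end="", flush=True)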