#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2025/4/24 14:04
# @Author  : hukangzhe
# @File    : rag_core.py
# @Description : A very simple RAG system

import fitz  # PyMuPDF, used for PDF text extraction
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import numpy as np
import faiss
import torch


class RAGSystem:
    def __init__(self, pdf_path):
        self.pdf_path = pdf_path
        self.texts = self._load_and_split_pdf()
        self.embedder = SentenceTransformer('moka-ai/m3e-base')
        self.reranker = CrossEncoder('BAAI/bge-reranker-base')  # load a reranker model (used in optimization 1)
        self.vector_store = self._create_vector_store()
        print("3. Initializing Generator Model...")
        model_name = "Qwen/Qwen1.5-1.8B-Chat"

        # Check whether a GPU is available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"   - Using device: {device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Note: for causal LMs such as Qwen, we typically use AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
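        # Note (optional, an assumption rather than the author's code): for larger
        # checkpoints it is common to load weights in half precision to reduce GPU
        # memory, e.g.
        #   AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)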

        self.generator = pipeline(
            'text-generation',
            model=model,
            tokenizer=self.tokenizer,
            device=device
        )

    # 1. Document loading & 2. Text splitting (combined here for simplicity)
    def _load_and_split_pdf(self):
        print("1. Loading and splitting PDF...")
        full_text = ""
        with fitz.open(self.pdf_path) as doc:
            for page in doc:
                full_text += page.get_text()

        # Very basic splitting: fixed-size character chunks with a small overlap
        chunk_size = 500
        overlap = 50
        chunks = [full_text[i: i+chunk_size] for i in range(0, len(full_text), chunk_size-overlap)]
        print(f"   - Splitted into {len(chunks)} chunks.")
        return chunks
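
    # Illustrative alternative (a sketch, not used by the pipeline above):
    # fixed-size character chunks can cut sentences in half. Assuming the
    # extracted text separates paragraphs with blank lines, a paragraph-aware
    # splitter could look like this (the method name is hypothetical):
    def _split_by_paragraphs(self, full_text, chunk_size=500):
        chunks, current = [], ""
        for para in full_text.split("\n\n"):
            # Close the current chunk once adding this paragraph would overflow it
            if current and len(current) + len(para) > chunk_size:
                chunks.append(current.strip())
                current = ""
            current += para + "\n\n"
        if current.strip():
            chunks.append(current.strip())
        return chunks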

    # 3. Text embedding & vector storage
    def _create_vector_store(self):
        print("2. Creating vector store...")
        # embedding
        embeddings = self.embedder.encode(self.texts)

        # Storing with faiss
        dim = embeddings.shape[1]
        index = faiss.IndexFlatL2(dim)  # use L2 distance for similarity search
        index.add(np.array(embeddings))
        print("   - Vector store created.")
        return index
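
    # Illustrative variant (a sketch under the assumption that cosine similarity
    # is preferred over L2 distance; not wired into the pipeline): L2-normalize
    # the embeddings and use an inner-product index instead.
    def _create_cosine_vector_store(self):
        embeddings = np.array(self.embedder.encode(self.texts), dtype="float32")
        faiss.normalize_L2(embeddings)  # normalize in place
        index = faiss.IndexFlatIP(embeddings.shape[1])  # inner product == cosine on unit vectors
        index.add(embeddings)
        return index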

    # 4. Retrieval
    def retrieve(self, query, k=3):
        print(f"3. Retrieving top {k} relevant chunks for query: '{query}' ")
        query_embeddings = self.embedder.encode([query])

        distances, indices = self.vector_store.search(np.array(query_embeddings), k=k)
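        # Note: IndexFlatL2.search returns hits sorted by ascending L2 distance,
        # so the chunks extracted below are already ordered from most to least similar.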
        retrieved_chunks = [self.texts[i] for i in indices[0]]
        print("   - Retrieval complete.")
        return retrieved_chunks

    # 5. Generation
    def generate(self, query, context_chunks):
        print("5. Generating answer...")
        context = "\n".join(context_chunks)

        # Prompts are kept in Chinese because the source document is Chinese.
        # System prompt: "You are a question-answering assistant. Answer based on
        # the provided context and do not make up information."
        # User prompt: "Context: --- {context} --- Please answer this question
        # based on the context above: {query}"
        messages = [
            {"role": "system", "content": "你是一个问答助手,请根据提供的上下文来回答问题,不要编造信息。"},
            {"role": "user", "content": f"上下文:\n---\n{context}\n---\n请根据以上上下文回答这个问题:{query}"}
        ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # print("Final Prompt:\n", prompt)
        # print("Prompt token length:", len(self.tokenizer.encode(prompt)))

        result = self.generator(prompt, max_new_tokens=200, num_return_sequences=1,
                                eos_token_id=self.tokenizer.eos_token_id)

        print("   - Generation complete.")
        # print("Raw results:", result)
        # Extract the generated text.
        # Note: the Qwen output contains the prompt itself, so we strip it to
        # keep only the answer portion.
        full_response = result[0]["generated_text"]
        answer = full_response[len(prompt):].strip()  # keep only the text after the prompt

        # print("Final Answer:", repr(answer))
        return answer
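
    # Note (alternative, an assumption about usage rather than the original code):
    # the text-generation pipeline also accepts return_full_text=False, which
    # returns only the newly generated text, so the manual prompt stripping above
    # could be avoided:
    #   result = self.generator(prompt, max_new_tokens=200, return_full_text=False)
    #   answer = result[0]["generated_text"].strip()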

    # Optimization 1: rerank retrieved chunks with a cross-encoder
    def rerank(self, query, chunks):
        print("   - Reranking retrieved chunks...")
        pairs = [[query, chunk] for chunk in chunks]
        scores = self.reranker.predict(pairs)

        # Pair chunks with their scores and sort by score in descending order
        ranked_chunks = sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
        return [chunk for chunk, score in ranked_chunks]

    def query(self, query_text):
        # 1. Retrieve (fetch more candidates than needed, e.g. the top 10)
        retrieved_chunks = self.retrieve(query_text, k=10)

        # 2. Rerank (keep the 3 most relevant of the 10)
        reranked_chunks = self.rerank(query_text, retrieved_chunks)
        top_k_reranked = reranked_chunks[:3]

        answer = self.generate(query_text, top_k_reranked)
        return answer


def main():
    # Make sure your data folder contains a file named chinese_document.pdf
    pdf_path = 'data/chinese_document.pdf'

    print("Initializing RAG System...")
    rag_system = RAGSystem(pdf_path)
    print("\nRAG System is ready. You can start asking questions.")
    print("Type 'q' to quit.")

    while True:
        user_query = input("\nYour Question: ")
        if user_query.lower() == 'q':
            break

        answer = rag_system.query(user_query)
        print("\nAnswer:", answer)


if __name__ == "__main__":
    main()