Commit d8c6d94
Parent(s): 92d8c87
update cache

Changed files:
- src/demo/asg_retriever.py +25 -5
- src/demo/category_and_tsne.py +17 -4
- src/demo/main.py +9 -1
- src/demo/path_utils.py +15 -0
- src/demo/survey_generation_pipeline/asg_retriever.py +72 -42
- src/demo/survey_generation_pipeline/category_and_tsne.py +25 -7
- src/demo/survey_generation_pipeline/main.py +17 -9
- src/demo/survey_generator_api.py +75 -124
- src/demo/views.py +13 -3
src/demo/asg_retriever.py
CHANGED

@@ -8,7 +8,11 @@ from .asg_splitter import TextSplitting
 from langchain_huggingface import HuggingFaceEmbeddings
 import time
 import concurrent.futures
-from .path_utils import get_path
+from .path_utils import get_path, setup_hf_cache
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+# Set the Hugging Face cache directory
+cache_dir = setup_hf_cache()
 
 class Retriever:
     client = None
@@ -201,7 +205,11 @@ def process_pdf(file_path: str, survey_id: str, embedder: HuggingFaceEmbeddings,
     return collection_name, embeddings_list, documents_list, metadata_list,title_new
 
 def query_embeddings(collection_name: str, query_list: list):
-
+    try:
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+    except Exception as e:
+        print(f"Error initializing embedder: {e}")
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
     retriever = Retriever()
 
     final_context = ""
@@ -222,7 +230,11 @@ def query_embeddings(collection_name: str, query_list: list):
 
 # new, may be in parallel
 def query_embeddings_new(collection_name: str, query_list: list):
-
+    try:
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+    except Exception as e:
+        print(f"Error initializing embedder: {e}")
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
     retriever = Retriever()
 
     final_context = ""
@@ -250,7 +262,11 @@ def query_embeddings_new(collection_name: str, query_list: list):
 
 # wza
 def query_embeddings_new_new(collection_name: str, query_list: list):
-
+    try:
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+    except Exception as e:
+        print(f"Error initializing embedder: {e}")
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
     retriever = Retriever()
 
     final_context = "" # Stores concatenated context
@@ -313,7 +329,11 @@ def query_multiple_collections(collection_names: list[str], query_list: list[str
         dict: Combined results from all collections, grouped by collection.
     """
     # Define embedder inside the function
-
+    try:
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+    except Exception as e:
+        print(f"Error initializing embedder: {e}")
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
     retriever = Retriever()
 
     def query_single_collection(collection_name: str):
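Note that the same try/except fallback around the embedder construction is now repeated in each of the query functions above. As a minimal sketch only (the helper name make_embedder is not part of this commit), the repeated block could be expressed once:

    from langchain_huggingface import HuggingFaceEmbeddings

    def make_embedder(cache_dir=None, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        """Build the embedder, preferring the Spaces cache directory and falling back to the default cache."""
        try:
            return HuggingFaceEmbeddings(model_name=model_name, cache_folder=cache_dir)
        except Exception as e:
            print(f"Error initializing embedder: {e}")
            return HuggingFaceEmbeddings(model_name=model_name)

Each query function could then call make_embedder(cache_dir) instead of repeating the block inline.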
src/demo/category_and_tsne.py
CHANGED

@@ -7,6 +7,8 @@ import seaborn as sns
 import json
 from sklearn.manifold import TSNE
 from sklearn.cluster import AgglomerativeClustering
+import os
+import tempfile
 
 from sentence_transformers import SentenceTransformer
 from bertopic import BERTopic
@@ -14,7 +16,7 @@ from bertopic.representation import KeyBERTInspired
 from sklearn.feature_extraction.text import CountVectorizer
 from bertopic.vectorizers import ClassTfidfTransformer
 from umap import UMAP
-from .path_utils import get_path
+from .path_utils import get_path, setup_hf_cache
 
 plt.switch_backend('agg')
 device = 0
@@ -35,6 +37,9 @@ import matplotlib.pyplot as plt
 from sklearn.manifold import TSNE
 import seaborn as sns
 
+# Set the Hugging Face cache directory
+cache_dir = setup_hf_cache()
+
 class DimensionalityReduction:
     def fit(self, X):
         return self
@@ -44,7 +49,11 @@ class DimensionalityReduction:
 
 class ClusteringWithTopic:
     def __init__(self, df, n_topics=3):
-
+        try:
+            embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True, cache_folder=cache_dir)
+        except Exception as e:
+            print(f"Error initializing SentenceTransformer: {e}")
+            embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True)
         # umap_model = DimensionalityReduction()
         umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', init = 'pca')
         hdbscan_model = AgglomerativeClustering(n_clusters=n_topics)
@@ -81,7 +90,11 @@ class ClusteringWithTopic:
        Initialize ClusteringWithTopic with an n_topics_list containing multiple candidate cluster counts,
        and keep the result with the highest silhouette_score.
        """
-
+        try:
+            embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True, cache_folder=cache_dir)
+        except Exception as e:
+            print(f"Error initializing SentenceTransformer: {e}")
+            embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True)
         self.embeddings = embedding_model.encode(df, show_progress_bar=True)
 
         self.df = df
@@ -97,7 +110,7 @@ class ClusteringWithTopic:
         # Used to store results for different numbers of clusters
         self.best_n_topics = None
         self.best_labels = None
-        self.best_score = -1
+        self.best_score = -1
     # def fit_and_get_labels(self, X):
     #     topics, probs = self.topic_model.fit_transform(self.df, self.embeddings)
    #     return topics
src/demo/main.py
CHANGED

@@ -20,6 +20,10 @@ from asg_outline import OutlineGenerator, generateSurvey_qwen_new
 import os
 from markdown_pdf import MarkdownPdf, Section # Assuming you are using markdown_pdf
 from typing import Any
+from .path_utils import get_path, setup_hf_cache
+
+# Set the Hugging Face cache directory
+cache_dir = setup_hf_cache()
 
 def clean_str(input_str):
     input_str = str(input_str).strip().lower()
@@ -135,7 +139,11 @@ class ASG_system:
 
 
         model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-
+        try:
+            self.embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+        except Exception as e:
+            print(f"Error initializing embedder: {e}")
+            self.embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
         self.pipeline = transformers.pipeline(
             "text-generation",
             model=model_id,
src/demo/path_utils.py
CHANGED

@@ -1,6 +1,21 @@
 import os
 import tempfile
 
+# Set the Hugging Face cache directory
+def setup_hf_cache():
+    """Set the Hugging Face cache directory; on Hugging Face Spaces, use a temporary directory."""
+    if os.environ.get('SPACE_ID') or os.environ.get('HF_SPACE_ID'):
+        # On Hugging Face Spaces, use a temporary directory as the cache
+        cache_dir = tempfile.mkdtemp()
+        os.environ['HF_HOME'] = cache_dir
+        os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_dir, 'transformers')
+        os.environ['HF_HUB_CACHE'] = os.path.join(cache_dir, 'hub')
+        print(f"Using Hugging Face cache directory: {cache_dir}")
+        return cache_dir
+    else:
+        # Local environment: use the default cache directory
+        return None
+
 # Check whether we are running in a Hugging Face Spaces environment
 def get_data_paths():
     # If on Hugging Face Spaces, use a temporary directory
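For context, setup_hf_cache() is called once at module import time, before any Hugging Face model is constructed; its return value (a temporary directory on Spaces, None locally) is then passed as cache_folder. A minimal sketch of that usage, mirroring the pattern applied in the other files of this commit:

    from langchain_huggingface import HuggingFaceEmbeddings
    from .path_utils import setup_hf_cache

    # temporary directory on Hugging Face Spaces, None in a local environment
    cache_dir = setup_hf_cache()

    try:
        # cache_folder=None lets the library fall back to its default cache location
        embedder = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            cache_folder=cache_dir,
        )
    except Exception as e:
        print(f"Error initializing embedder: {e}")
        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")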
src/demo/survey_generation_pipeline/asg_retriever.py
CHANGED

@@ -8,7 +8,11 @@ from .asg_splitter import TextSplitting
 from langchain_huggingface import HuggingFaceEmbeddings
 import time
 import concurrent.futures
-from ..path_utils import get_path
+from ..path_utils import get_path, setup_hf_cache
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+# Set the Hugging Face cache directory
+cache_dir = setup_hf_cache()
 
 class Retriever:
     client = None
@@ -223,7 +227,11 @@ def process_pdf(file_path: str, survey_id: str, embedder: HuggingFaceEmbeddings,
     return collection_name, embeddings_list, documents_list, metadata_list,title_new
 
 def query_embeddings(collection_name: str, query_list: list):
-
+    try:
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+    except Exception as e:
+        print(f"Error initializing embedder: {e}")
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
     retriever = Retriever()
 
     final_context = ""
@@ -244,7 +252,11 @@ def query_embeddings(collection_name: str, query_list: list):
 
 # new, may be in parallel
 def query_embeddings_new(collection_name: str, query_list: list):
-
+    try:
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+    except Exception as e:
+        print(f"Error initializing embedder: {e}")
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
     retriever = Retriever()
 
     final_context = ""
@@ -270,45 +282,59 @@ def query_embeddings_new(collection_name: str, query_list: list):
             seen_chunks.add(chunk)
     return final_context
 
- [39 removed lines not shown in this diff view]
+# wza
+def query_embeddings_new_new(collection_name: str, query_list: list):
+    try:
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+    except Exception as e:
+        print(f"Error initializing embedder: {e}")
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+    retriever = Retriever()
+
+    final_context = "" # Stores concatenated context
+    citation_data_list = [] # Stores chunk content and collection name as source
+    seen_chunks = set() # Ensures unique chunks are added
+
+    def process_query(query_text):
+        # Embed the query text and retrieve relevant chunks
+        query_embeddings = embedder.embed_query(query_text)
+        query_result = retriever.query_chroma(
+            collection_name=collection_name,
+            query_embeddings=[query_embeddings],
+            n_results=5 # Fixed number of results
+        )
+        return query_result
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        future_to_query = {executor.submit(process_query, q): q for q in query_list}
+        for future in concurrent.futures.as_completed(future_to_query):
+            query_text = future_to_query[future]
+            try:
+                query_result = future.result()
+            except Exception as e:
+                print(f"Query '{query_text}' failed with exception: {e}")
+                continue
+
+            if "documents" not in query_result or "distances" not in query_result:
+                continue
+            if not query_result["documents"] or not query_result["distances"]:
+                continue
+            docs_list = query_result["documents"][0] if query_result["documents"] else []
+            dist_list = query_result["distances"][0] if query_result["distances"] else []
+
+            if len(docs_list) != len(dist_list):
+                continue
+
+            for chunk, distance in zip(docs_list, dist_list):
+                processed_chunk = chunk.strip()
+                if processed_chunk not in seen_chunks:
+                    final_context += processed_chunk + "//\n"
+                    seen_chunks.add(processed_chunk)
+                    citation_data_list.append({
+                        "source": collection_name,
+                        "distance": distance,
+                        "content": processed_chunk,
+                    })
 
     return final_context, citation_data_list
 
@@ -325,7 +351,11 @@ def query_multiple_collections(collection_names: list[str], query_list: list[str
         dict: Combined results from all collections, grouped by collection.
     """
     # Define embedder inside the function
-
+    try:
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+    except Exception as e:
+        print(f"Error initializing embedder: {e}")
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
     retriever = Retriever()
 
     def query_single_collection(collection_name: str):
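For orientation, the rewritten query_embeddings_new_new embeds each query in a thread pool, deduplicates the retrieved chunks, and returns the concatenated context together with a list of citation records. A hedged usage sketch (the collection name and queries below are illustrative placeholders):

    # sketch: inputs are placeholders, not values from this commit
    context, citations = query_embeddings_new_new(
        collection_name="example_paper_collection",
        query_list=["What methods are compared?", "Which datasets are used?"],
    )
    for record in citations:
        # each record carries the source collection, the retrieval distance, and the chunk text
        print(record["source"], record["distance"], record["content"][:80])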
src/demo/survey_generation_pipeline/category_and_tsne.py
CHANGED

@@ -1,15 +1,22 @@
 from sklearn.metrics import silhouette_score
 
 import numpy as np
+import pandas as pd
 import matplotlib.pyplot as plt
 import seaborn as sns
-import
+import json
 from sklearn.manifold import TSNE
 from sklearn.cluster import AgglomerativeClustering
-import
-
+import os
+import tempfile
 
-
+from sentence_transformers import SentenceTransformer
+from bertopic import BERTopic
+from bertopic.representation import KeyBERTInspired
+from sklearn.feature_extraction.text import CountVectorizer
+from bertopic.vectorizers import ClassTfidfTransformer
+from umap import UMAP
+from ..path_utils import get_path, setup_hf_cache
 
 plt.switch_backend('agg')
 device = 0
@@ -30,6 +37,9 @@ import matplotlib.pyplot as plt
 from sklearn.manifold import TSNE
 import seaborn as sns
 
+# Set the Hugging Face cache directory
+cache_dir = setup_hf_cache()
+
 class DimensionalityReduction:
     def fit(self, X):
         return self
@@ -39,7 +49,11 @@ class DimensionalityReduction:
 
 class ClusteringWithTopic:
     def __init__(self, df, n_topics=3):
-
+        try:
+            embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True, cache_folder=cache_dir)
+        except Exception as e:
+            print(f"Error initializing SentenceTransformer: {e}")
+            embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True)
         # umap_model = DimensionalityReduction()
         umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', init = 'pca')
         hdbscan_model = AgglomerativeClustering(n_clusters=n_topics)
@@ -76,7 +90,11 @@ class ClusteringWithTopic:
        Initialize ClusteringWithTopic with an n_topics_list containing multiple candidate cluster counts,
        and keep the result with the highest silhouette_score.
        """
-
+        try:
+            embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True, cache_folder=cache_dir)
+        except Exception as e:
+            print(f"Error initializing SentenceTransformer: {e}")
+            embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True)
         self.embeddings = embedding_model.encode(df, show_progress_bar=True)
 
         self.df = df
@@ -92,7 +110,7 @@ class ClusteringWithTopic:
         # Used to store results for different numbers of clusters
         self.best_n_topics = None
         self.best_labels = None
-        self.best_score = -1
+        self.best_score = -1
     # def fit_and_get_labels(self, X):
     #     topics, probs = self.topic_model.fit_transform(self.df, self.embeddings)
    #     return topics
src/demo/survey_generation_pipeline/main.py
CHANGED

@@ -27,6 +27,10 @@ import os
 from markdown_pdf import MarkdownPdf, Section # Assuming you are using markdown_pdf
 from typing import Any
 import xml.etree.ElementTree as ET
+from .path_utils import get_path, setup_hf_cache
+
+# Set the Hugging Face cache directory
+cache_dir = setup_hf_cache()
 
 def clean_str(input_str):
     input_str = str(input_str).strip().lower()
@@ -286,15 +290,19 @@ class ASG_system:
         self.pipeline = None
 
 
- [9 removed lines not shown in this diff view]
+        model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+        try:
+            self.embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+        except Exception as e:
+            print(f"Error initializing embedder: {e}")
+            self.embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+        self.pipeline = transformers.pipeline(
+            "text-generation",
+            model=model_id,
+            model_kwargs={"torch_dtype": torch.bfloat16},
+            token = os.getenv('HF_API_KEY'),
+            device_map="auto",
+        )
         # self.pipeline.model.load_adapter(peft_model_id = "technicolor/llama3.1_8b_outline_generation", adapter_name="outline")
         # self.pipeline.model.load_adapter(peft_model_id ="technicolor/llama3.1_8b_abstract_generation", adapter_name="abstract")
         # self.pipeline.model.load_adapter(peft_model_id ="technicolor/llama3.1_8b_conclusion_generation", adapter_name="conclusion")
src/demo/survey_generator_api.py
CHANGED

@@ -9,6 +9,10 @@ import numpy as np
 from numpy.linalg import norm
 import openai
 from .asg_retriever import Retriever
+from .path_utils import get_path, setup_hf_cache
+
+# Set the Hugging Face cache directory
+cache_dir = setup_hf_cache()
 
 def getQwenClient():
     # openai_api_key = os.environ.get("OPENAI_API_KEY")
@@ -506,7 +510,7 @@ Survey Paper Content for "{section_title}":
     response = generateResponse(client, formatted_prompt).strip()
     sentences = re.split(r'(?<=[.!?])\s+', response.strip())
 
-    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
     sentence_embeddings = embedder.embed_documents(sentences)
     chunk_texts = [c["content"] for c in citation_data_list]
     chunk_sources = [c["source"] for c in citation_data_list]
@@ -627,7 +631,7 @@ Survey Paper Content for "{section_title}":
             para_index_map.append(p_idx)
 
     # -- 3. Embed all sentences (same logic: process the whole text at once) ---
-    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
     sentence_embeddings = embedder.embed_documents(all_sentences)
 
     # -- 4. Embed citation_data_list ---
@@ -763,25 +767,17 @@ def query_embedding_for_title(
     n_results: int = 1,
     embedder: HuggingFaceEmbeddings = None
 ):
-
+    if embedder is None:
+        try:
+            embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+        except Exception as e:
+            print(f"Error initializing embedder: {e}")
+            embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+
     retriever = Retriever()
-
-
-    query_result
-        collection_name=collection_name,
-        query_embeddings=[title_embedding],
-        n_results=n_results
-    )
-    # old
-    # query_result_chunks = query_result["documents"][0]
-    # for chunk in query_result_chunks:
-    #     final_context += chunk.strip() + "//\n"
-
-    # 2025
-    if "documents" in query_result and len(query_result["documents"]) > 0:
-        for chunk in query_result["documents"][0]:
-            final_context += chunk.strip() + "//\n"
-    return final_context
+    query_embeddings = embedder.embed_query(title)
+    query_result = retriever.query_chroma(collection_name=collection_name, query_embeddings=[query_embeddings], n_results=n_results)
+    return query_result
 
 # old
 def generate_context_list(outline, collection_list):
@@ -812,32 +808,32 @@ def generate_context_list(outline, collection_list):
 
 # 2025
 def generate_context_list(outline, collection_list):
-    context_list_final = []
-    [remaining removed lines not shown in this diff view]
+    try:
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+    except Exception as e:
+        print(f"Error initializing embedder: {e}")
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+
+    retriever = Retriever()
+    context_list = []
+
+    for section_title in outline:
+        query_embeddings = embedder.embed_query(section_title)
+        final_context = ""
+        seen_chunks = set()
+
+        for collection_name in collection_list:
+            query_result = retriever.query_chroma(collection_name=collection_name, query_embeddings=[query_embeddings], n_results=2)
+            query_result_chunks = query_result["documents"][0]
+
+            for chunk in query_result_chunks:
+                if chunk not in seen_chunks:
+                    final_context += chunk.strip() + "//\n"
+                    seen_chunks.add(chunk)
+
+        context_list.append(final_context)
+
+    return context_list
 
 # 1.8 Given an introduction, output the introduction with citations (collection names)
 def introduction_with_citations(
@@ -847,110 +843,65 @@ def introduction_with_citations(
     dynamic_threshold: bool = True,
     diversity_limit: int = 3
 ) -> str:
-    [removed setup lines not fully captured in this view]
-    if not paragraphs:
-        return intro_text
-
-    # 2. Split each paragraph into sentences, recording the paragraph index of every sentence
-    all_sentences = []
-    para_index_map = []
-    for p_idx, para in enumerate(paragraphs):
-        if not para.strip():
-            # Empty paragraph: skip sentence splitting, keep the paragraph separation
-            continue
-        # Split the paragraph into sentences on .!? with a regex
-        sentences_in_para = re.split(r'(?<=[.!?])\s+', para)
-        for sent in sentences_in_para:
-            if sent:
-                all_sentences.append(sent)
-                para_index_map.append(p_idx)
-
-    # If no sentences could be extracted, return the text unchanged
-    if not all_sentences:
-        return intro_text
-
-    # 3. Embed all sentences
-    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
-    sentence_embeddings = embedder.embed_documents(all_sentences)
-
-    # 4. Embed every literature chunk in citation_data_list
+    # Split the introduction text into sentences
+    sentences = re.split(r'(?<=[.!?])\s+', intro_text.strip())
+
+    # Initialize the embedder
+    try:
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
+    except Exception as e:
+        print(f"Error initializing embedder: {e}")
+        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+
+    # Embed the sentences and the citation data
+    sentence_embeddings = embedder.embed_documents(sentences)
     chunk_texts = [c["content"] for c in citation_data_list]
     chunk_sources = [c["source"] for c in citation_data_list]
     chunk_embeddings = embedder.embed_documents(chunk_texts)
-
+
+    # Compute cosine similarity
     def cosine_sim(a, b):
         return np.dot(a, b) / (norm(a) * norm(b) + 1e-9)
-
-    #
+
+    # Build the similarity matrix
     sim_matrix = []
     for s_emb in sentence_embeddings:
         row = [cosine_sim(s_emb, c_emb) for c_emb in chunk_embeddings]
         sim_matrix.append(row)
     sim_matrix = np.array(sim_matrix)
-
-    #
+
+    # Compute the dynamic threshold
     all_sims = sim_matrix.flatten()
     mean_sim = np.mean(all_sims)
-    std_sim
+    std_sim = np.std(all_sims)
     k = 0.5
     threshold = max(base_threshold, mean_sim + k * std_sim) if dynamic_threshold else base_threshold
-
-    #
+
+    # Find candidate citations
     candidates = []
-    for i in
-        for j in
-            if
-                candidates.append((i, j,
-
-    #
+    for i, sent in enumerate(sentences):
+        for j, sim in enumerate(sim_matrix[i]):
+            if sim >= threshold:
+                candidates.append((i, j, sim))
+
+    # Sort by similarity and assign citations
+    source_count = {s: 0 for s in chunk_sources}
     candidates.sort(key=lambda x: x[2], reverse=True)
-
-    # Record sentence -> assigned source; cap how many times each source may be cited
-    source_count = {src: 0 for src in chunk_sources}
     assigned = {}
-
-    for (sent_id, chk_id,
+
+    for (sent_id, chk_id, sim) in candidates:
         if sent_id not in assigned:
             src = chunk_sources[chk_id]
            if source_count[src] < diversity_limit:
                 assigned[sent_id] = src
                 source_count[src] += 1
-
-    #
+
+    # Update the sentences
     updated_sentences = []
-    for i, sentence in enumerate(
+    for i, sentence in enumerate(sentences):
         if i in assigned:
             updated_sentences.append(sentence + f" [{assigned[i]}]")
         else:
             updated_sentences.append(sentence)
-
-
-    updated_paras = [""] * len(paragraphs)
-    para_sentences_map = [[] for _ in range(len(paragraphs))]
-
-    for s_idx, sent in enumerate(updated_sentences):
-        p_idx = para_index_map[s_idx]
-        para_sentences_map[p_idx].append(sent)
-
-    for i in range(len(paragraphs)):
-        if not paragraphs[i].strip():
-            # Keep empty paragraphs unchanged
-            updated_paras[i] = paragraphs[i]
-        else:
-            # Join the sentences within the same paragraph with spaces
-            updated_paras[i] = " ".join(para_sentences_map[i])
-
-    # 11. Rejoin with the original newline separators
-    updated_intro = "\n\n".join(updated_paras)
-    return updated_intro
+
+    return " ".join(updated_sentences)
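For orientation, the rewritten introduction_with_citations appends a [collection_name] marker to each sentence whose cosine similarity against a retrieved chunk clears the (optionally dynamic) threshold, with at most diversity_limit citations per source. A hedged usage sketch; only the tail of the signature is visible in this hunk, so the assumption that the leading parameters are the introduction text and citation_data_list is mine:

    # sketch: inputs are illustrative; parameter order before dynamic_threshold is assumed
    citation_data_list = [
        {"source": "paper_a", "content": "Retrieval-augmented generation improves factual accuracy."},
    ]
    intro = "Transformers dominate NLP. Retrieval improves factuality."
    cited = introduction_with_citations(intro, citation_data_list)
    # sentences that clear the similarity threshold gain a trailing "[paper_a]" marker
    print(cited)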
src/demo/views.py
CHANGED

@@ -44,7 +44,7 @@ from dotenv import load_dotenv
 from pathlib import Path
 from markdown_pdf import MarkdownPdf, Section
 import tempfile
-from .path_utils import get_path
+from .path_utils import get_path, setup_hf_cache
 
 dotenv_path = os.path.join(os.path.dirname(__file__), ".env")
 load_dotenv()
@@ -59,6 +59,9 @@ load_dotenv()
 # print(f"OPENAI_API_KEY: {openai_api_key}")
 # print(f"OPENAI_API_BASE: {openai_api_base}")
 
+# Set the Hugging Face cache directory
+cache_dir = setup_hf_cache()
+
 # Get the path configuration
 paths_config = get_path('pdf') # Use the get_path function to obtain the path configuration
 DATA_PATH = get_path('pdf')
@@ -144,8 +147,15 @@ Global_cluster_names = []
 Global_citation_data = []
 Global_cluster_num = 4
 
-
-embedder = HuggingFaceEmbeddings(
+try:
+    embedder = HuggingFaceEmbeddings(
+        model_name="sentence-transformers/all-MiniLM-L6-v2",
+        cache_folder=cache_dir
+    )
+except Exception as e:
+    print(f"Error initializing embedder: {e}")
+    # If initialization fails, fall back to the default settings
+    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
 from demo.category_and_tsne import clustering
 