import torch
import uuid
import re
import os
import json
import chromadb
from .asg_splitter import TextSplitting
from langchain_community.embeddings import HuggingFaceEmbeddings
import time
import concurrent.futures
from .path_utils import get_path, setup_hf_cache
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Set up the Hugging Face cache directory
cache_dir = setup_hf_cache()

class Retriever:
    client = None
    cur_dir = os.getcwd()
    chromadb_path = os.path.join(cur_dir, "chromadb")

    def __init__(self):
        self.client = chromadb.PersistentClient(path=self.chromadb_path)

    def create_collection_chroma(self, collection_name: str):
        """
        Create a collection named collection_name. The name must follow these rules:
        0. The collection name must be unique; if a collection with this name already exists, it is reused.
        1. The length of the name must be between 3 and 63 characters.
        2. The name must start and end with a lowercase letter or a digit, and it can contain dots, dashes, and underscores in between.
        3. The name must not contain two consecutive dots.
        4. The name must not be a valid IP address.
        """
        try:
            self.client.create_collection(name=collection_name)
        except chromadb.db.base.UniqueConstraintError:
            self.get_collection_chroma(collection_name)
        return collection_name

    def get_collection_chroma(self, collection_name: str):
        collection = self.client.get_collection(name=collection_name)
        return collection

    def add_documents_chroma(self, collection_name: str, embeddings_list: list[list[float]], documents_list: list[str], metadata_list: list[dict]):
        """
        Make sure that embeddings_list and metadata_list are aligned with documents_list.
        Example of one metadata entry: {"doc_name": "Test2.pdf", "page": "9"}
        Ids are generated automatically as UUID v4.
        The chunk contents and metadata are logged (appended) to ./logs/<collection_name>.json.
        """
        collection = self.get_collection_chroma(collection_name)
        num = len(documents_list)
        ids = [str(uuid.uuid4()) for _ in range(num)]

        collection.add(
            documents=documents_list,
            metadatas=metadata_list,
            embeddings=embeddings_list,
            ids=ids
        )

        logpath = os.path.join(self.cur_dir, "logs", f"{collection_name}.json")
        os.makedirs(os.path.dirname(logpath), exist_ok=True)
        logs = []
        try:
            with open(logpath, 'r', encoding="utf-8") as chunklog:
                logs = json.load(chunklog)
        except (FileNotFoundError, json.decoder.JSONDecodeError):
            logs = []

        added_log = [{"chunk_id": ids[i], "metadata": metadata_list[i], "page_content": documents_list[i]}
                     for i in range(num)]
        logs.extend(added_log)

        # Write the updated log back to disk
        with open(logpath, "w", encoding="utf-8") as chunklog:
            json.dump(logs, chunklog, indent=4)
        print(f"Logged document information to '{logpath}'.")

    def query_chroma(self, collection_name: str, query_embeddings: list[list[float]], n_results: int = 5) -> dict:
        # Return the n closest results (chunks and metadata), ordered by distance
        collection = self.get_collection_chroma(collection_name)
        result = collection.query(
            query_embeddings=query_embeddings,
            n_results=n_results,
        )
        return result

    def update_chroma(self, collection_name: str, id_list: list[str], embeddings_list: list[list[float]], documents_list: list[str], metadata_list: list[dict]):
        collection = self.get_collection_chroma(collection_name)
        num = len(documents_list)
        collection.update(
            ids=id_list,
            embeddings=embeddings_list,
            metadatas=metadata_list,
            documents=documents_list,
        )
        update_list = [{"chunk_id": id_list[i], "metadata": metadata_list[i], "page_content": documents_list[i]} for i in range(num)]

        # Update the chunk log
        logs = []
        logpath = os.path.join(self.cur_dir, "logs", f"{collection_name}.json")
        try:
            with open(logpath, 'r', encoding="utf-8") as chunklog:
                logs = json.load(chunklog)
        except (FileNotFoundError, json.decoder.JSONDecodeError):
            logs = []  # the old log does not exist or is empty, so there is nothing to update
        else:
            for i in range(num):
                for log in logs:
                    if log["chunk_id"] == update_list[i]["chunk_id"]:
                        log["metadata"] = update_list[i]["metadata"]
                        log["page_content"] = update_list[i]["page_content"]
                        break

        with open(logpath, "w", encoding="utf-8") as chunklog:
            json.dump(logs, chunklog, indent=4)
        print(f"Updated log file at '{logpath}'.")

    def delete_collection_entries_chroma(self, collection_name: str, id_list: list[str]):
        collection = self.get_collection_chroma(collection_name)
        collection.delete(ids=id_list)
        print(f"Deleted entries with ids: {id_list} from collection '{collection_name}'.")

    def delete_collection_chroma(self, collection_name: str):
        print(f"The collection {collection_name} will be deleted forever!")
        self.client.delete_collection(collection_name)
        try:
            logpath = os.path.join(self.cur_dir, "logs", f"{collection_name}.json")
            print(f"Collection {collection_name} has been removed; deleting its log file")
            os.remove(logpath)
        except FileNotFoundError:
            print("The log file for this collection did not exist!")

    def list_collections_chroma(self):
        collections = self.client.list_collections()
        return collections
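
# Illustrative usage sketch of Retriever (not executed on import). The embedding
# vector below is a placeholder; in this codebase embeddings come from
# HuggingFaceEmbeddings, and their dimensionality must match between add and query.
#
#   retriever = Retriever()
#   retriever.create_collection_chroma("demo-paper")
#   retriever.add_documents_chroma(
#       collection_name="demo-paper",
#       embeddings_list=[[0.1, 0.2, 0.3]],
#       documents_list=["Example chunk text."],
#       metadata_list=[{"doc_name": "demo-paper.pdf"}],
#   )
#   results = retriever.query_chroma("demo-paper", [[0.1, 0.2, 0.3]], n_results=1)
#   print(results["documents"][0])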

# Generate a legal Chroma collection name from a PDF filename
def legal_pdf(filename: str) -> str:
    pdf_index = filename.lower().rfind('.pdf')
    if pdf_index != -1:
        name_before_pdf = filename[:pdf_index]
    else:
        name_before_pdf = filename
    name_before_pdf = name_before_pdf.strip()
    name = re.sub(r'[^a-zA-Z0-9._-]', '', name_before_pdf)
    name = name.lower()
    while '..' in name:
        name = name.replace('..', '.')
    name = name[:63]
    if len(name) < 3:
        name = name.ljust(3, '0')  # pad with '0' if the length is less than 3
    if not re.match(r'^[a-z0-9]', name):
        name = 'a' + name[1:]
    if not re.search(r'[a-z0-9]$', name):
        name = name[:-1] + 'a'
    ip_pattern = re.compile(r'^(\d{1,3}\.){3}\d{1,3}$')
    if ip_pattern.match(name):
        name = 'ip_' + name
    return name
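
# Example (illustrative): legal_pdf("A Survey of LLMs (2024).pdf") -> "asurveyofllms2024"
# The extension, spaces, and parentheses are stripped and the result is lowercased,
# satisfying the naming rules documented in create_collection_chroma.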

def process_pdf(file_path: str, survey_id: str, embedder: HuggingFaceEmbeddings, mode: str):
    # Load and split the PDF
    split_start_time = time.time()
    splitters = TextSplitting().mineru_recursive_splitter(file_path, survey_id, mode)

    documents_list = [document.page_content for document in splitters]
    for i in range(len(documents_list)):
        documents_list[i] = documents_list[i].replace('\n', ' ')
    print(f"Splitting took {time.time() - split_start_time} seconds.")

    # Embed the documents
    embed_start_time = time.time()
    doc_results = embedder.embed_documents(documents_list)
    if isinstance(doc_results, torch.Tensor):
        embeddings_list = doc_results.tolist()
    else:
        embeddings_list = doc_results
    print(f"Embedding took {time.time() - embed_start_time} seconds.")

    # Prepare metadata
    metadata_list = [{"doc_name": os.path.basename(file_path)} for _ in range(len(documents_list))]

    title = os.path.splitext(os.path.basename(file_path))[0]
    title_new = title.strip()
    invalid_chars = ['<', '>', ':', '"', '/', '\\', '|', '?', '*', '_']
    for char in invalid_chars:
        title_new = title_new.replace(char, ' ')

    collection_name = legal_pdf(title_new)

    retriever = Retriever()
    retriever.list_collections_chroma()
    retriever.create_collection_chroma(collection_name)
    retriever.add_documents_chroma(
        collection_name=collection_name,
        embeddings_list=embeddings_list,
        documents_list=documents_list,
        metadata_list=metadata_list
    )

    return collection_name, embeddings_list, documents_list, metadata_list, title_new
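
# Illustrative call sketch (the file path, survey id, and mode below are hypothetical
# placeholders; real values depend on the surrounding pipeline):
#
#   embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
#   collection_name, embeddings, chunks, metadata, title = process_pdf(
#       "uploads/example_survey_paper.pdf", "survey_001", embedder, mode="default"
#   )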

def query_embeddings(collection_name: str, query_list: list):
    try:
        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
    except Exception as e:
        print(f"Error initializing embedder: {e}")
        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    retriever = Retriever()
    final_context = ""
    seen_chunks = set()

    for query_text in query_list:
        query_embeddings = embedder.embed_query(query_text)
        query_result = retriever.query_chroma(collection_name=collection_name, query_embeddings=[query_embeddings], n_results=2)

        query_result_chunks = query_result["documents"][0]
        # query_result_ids = query_result["ids"][0]

        for chunk in query_result_chunks:
            if chunk not in seen_chunks:
                final_context += chunk.strip() + "//\n"
                seen_chunks.add(chunk)
    return final_context
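
# Illustrative call sketch (the collection name and queries are placeholders):
#
#   context = query_embeddings("asurveyofllms2024", ["What datasets are used?", "Which models are compared?"])
#
# The returned string concatenates the unique retrieved chunks, each terminated
# by the "//" separator followed by a newline.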

# Parallel variant: queries are embedded and retrieved concurrently
def query_embeddings_new(collection_name: str, query_list: list):
    try:
        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
    except Exception as e:
        print(f"Error initializing embedder: {e}")
        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    retriever = Retriever()
    final_context = ""
    seen_chunks = set()

    def process_query(query_text):
        query_embeddings = embedder.embed_query(query_text)
        query_result = retriever.query_chroma(
            collection_name=collection_name,
            query_embeddings=[query_embeddings],
            n_results=2
        )
        query_result_chunks = query_result["documents"][0]
        return query_result_chunks

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {executor.submit(process_query, query_text): query_text for query_text in query_list}
        for future in concurrent.futures.as_completed(futures):
            query_result_chunks = future.result()
            for chunk in query_result_chunks:
                if chunk not in seen_chunks:
                    final_context += chunk.strip() + "//\n"
                    seen_chunks.add(chunk)

    return final_context

# Variant that also returns citation data (chunk content, source collection, distance)
def query_embeddings_new_new(collection_name: str, query_list: list):
    try:
        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
    except Exception as e:
        print(f"Error initializing embedder: {e}")
        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    retriever = Retriever()

    final_context = ""        # Stores the concatenated context
    citation_data_list = []   # Stores chunk content plus the collection name as source
    seen_chunks = set()       # Ensures only unique chunks are added

    def process_query(query_text):
        # Embed the query text and retrieve relevant chunks
        query_embeddings = embedder.embed_query(query_text)
        query_result = retriever.query_chroma(
            collection_name=collection_name,
            query_embeddings=[query_embeddings],
            n_results=5  # Fixed number of results
        )
        return query_result

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_query = {executor.submit(process_query, q): q for q in query_list}
        for future in concurrent.futures.as_completed(future_to_query):
            query_text = future_to_query[future]
            try:
                query_result = future.result()
            except Exception as e:
                print(f"Query '{query_text}' failed with exception: {e}")
                continue

            if "documents" not in query_result or "distances" not in query_result:
                continue
            if not query_result["documents"] or not query_result["distances"]:
                continue

            docs_list = query_result["documents"][0] if query_result["documents"] else []
            dist_list = query_result["distances"][0] if query_result["distances"] else []
            if len(docs_list) != len(dist_list):
                continue

            for chunk, distance in zip(docs_list, dist_list):
                processed_chunk = chunk.strip()
                if processed_chunk not in seen_chunks:
                    final_context += processed_chunk + "//\n"
                    seen_chunks.add(processed_chunk)
                    citation_data_list.append({
                        "source": collection_name,
                        "distance": distance,
                        "content": processed_chunk,
                    })

    return final_context, citation_data_list
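
# Illustrative shape of the return value (values are placeholders):
#
#   final_context: "chunk one text//\nchunk two text//\n"
#   citation_data_list: [
#       {"source": "<collection_name>", "distance": 0.42, "content": "chunk one text"},
#       ...
#   ]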

# Concurrent version over both collection names and queries
def query_multiple_collections(collection_names: list[str], query_list: list[str], survey_id: str) -> dict:
    """
    Query multiple collections in parallel and return the combined results.

    Args:
        collection_names (list[str]): List of collection names to query.
        query_list (list[str]): List of queries to execute on each collection.
        survey_id (str): Survey id used to locate the output file.

    Returns:
        dict: Combined results from all collections, keyed by collection name.
    """
    # Define the embedder inside the function
    try:
        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
    except Exception as e:
        print(f"Error initializing embedder: {e}")
        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    retriever = Retriever()

    def query_single_collection(collection_name: str):
        """
        Query a single collection with every query in query_list.
        """
        final_context = ""
        seen_chunks = set()

        def process_query(query_text):
            # Embed the query
            query_embeddings = embedder.embed_query(query_text)
            # Query the collection
            query_result = retriever.query_chroma(
                collection_name=collection_name,
                query_embeddings=[query_embeddings],
                n_results=5
            )
            query_result_chunks = query_result["documents"][0]
            return query_result_chunks

        # Process all queries in parallel for the given collection
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {executor.submit(process_query, query_text): query_text for query_text in query_list}
            for future in concurrent.futures.as_completed(futures):
                query_result_chunks = future.result()
                for chunk in query_result_chunks:
                    if chunk not in seen_chunks:
                        final_context += chunk.strip() + "//\n"
                        seen_chunks.add(chunk)

        return final_context

    # Outer parallelism over the collections
    results = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {executor.submit(query_single_collection, collection_name): collection_name for collection_name in collection_names}
        for future in concurrent.futures.as_completed(futures):
            collection_name = futures[future]
            results[collection_name] = future.result()

    # Automatically save the results to a JSON file
    file_path = get_path('info', survey_id, 'retrieved_context.json')
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)
    print(f"Results saved to {file_path}")

    return results
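
# Illustrative usage sketch (collection names, queries, and survey id are placeholders):
#
#   results = query_multiple_collections(
#       collection_names=["asurveyofllms2024", "anotherpaper2023"],
#       query_list=["What benchmarks are reported?"],
#       survey_id="survey_001",
#   )
#
# results maps each collection name to its concatenated context string, and the
# same dictionary is written to retrieved_context.json via get_path(...).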