import asyncio
from typing import Callable, Optional, Union

import huggingface_hub
import semchunk
import tiktoken
import tokenizers
from datasets import Dataset, concatenate_datasets, load_dataset
from rich.progress import track
from transformers import PreTrainedTokenizer

TOKENIZER_OR_TOKEN_COUNTER = Union[
    str,
    tiktoken.Encoding,
    PreTrainedTokenizer,
    tokenizers.Tokenizer,
    Callable[[str], int],
]
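
# Illustrative examples of values that satisfy TOKENIZER_OR_TOKEN_COUNTER
# (assumptions for illustration, not an exhaustive list):
#   "o200k_base"                                        -> a tiktoken encoding name (str)
#   tiktoken.get_encoding("o200k_base")                 -> tiktoken.Encoding
#   transformers.AutoTokenizer.from_pretrained("gpt2")  -> PreTrainedTokenizer
#   lambda text: len(text.split())                      -> Callable[[str], int]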


class SemanticChunker:
    """
    SemanticChunker is a class that chunks documents into smaller segments and
    publishes them as datasets.

    This class uses the `semchunk` library to break down large documents into
    smaller, manageable chunks based on a specified tokenizer or token counter.
    This is particularly useful for processing large text datasets where
    smaller segments are needed for analysis or other operations.

    !!! example "Example Usage"
        ```python
        from medrag_multi_modal.semantic_chunking import SemanticChunker

        chunker = SemanticChunker(chunk_size=256)
        chunker.chunk(
            document_dataset="geekyrakshit/grays-anatomy-test",
            chunk_dataset_repo_id="geekyrakshit/grays-anatomy-chunks-test",
        )
        ```

    Args:
        tokenizer_or_token_counter (TOKENIZER_OR_TOKEN_COUNTER): The tokenizer or
            token counter to be used for chunking.
        chunk_size (Optional[int]): The size of each chunk. If not specified, the
            default chunk size from `semchunk` will be used.
        max_token_chars (Optional[int]): The maximum number of characters per token.
            If not specified, the default value from `semchunk` will be used.
        memoize (bool): Whether to memoize the chunking process for efficiency.
            Defaults to True.
    """

    def __init__(
        self,
        tokenizer_or_token_counter: TOKENIZER_OR_TOKEN_COUNTER = "o200k_base",
        chunk_size: Optional[int] = None,
        max_token_chars: Optional[int] = None,
        memoize: bool = True,
    ) -> None:
        self.chunker = semchunk.chunkerify(
            tokenizer_or_token_counter,
            chunk_size=chunk_size,
            max_token_chars=max_token_chars,
            memoize=memoize,
        )

    def chunk(
        self,
        document_dataset: Union[Dataset, str],
        chunk_dataset_repo_id: Optional[str] = None,
        overwrite_dataset: bool = False,
    ) -> Dataset:
        """
        Chunks a document dataset into smaller segments and publishes them as a new dataset.

        This function takes a document dataset, either as a HuggingFace Dataset object or a string
        representing the dataset repository ID, and chunks the documents into smaller segments using
        the specified chunker. The resulting chunks are then optionally published to a HuggingFace
        dataset repository.

        Args:
            document_dataset (Union[Dataset, str]): The document dataset to be chunked. It can be either
                a HuggingFace Dataset object or a string representing the dataset repository ID.
            chunk_dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish
                the chunks to, if provided. Defaults to None.
            overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.

        Returns:
            Dataset: A HuggingFace Dataset object containing the chunks.
        """
        # Resolve the corpus: load the "corpus" split from the Hub if a repo ID
        # string was given, otherwise use the provided Dataset as-is.
        document_dataset = (
            load_dataset(document_dataset, split="corpus")
            if isinstance(document_dataset, str)
            else document_dataset
        ).to_list()

        chunks = []

        async def process_document(idx, document):
            # Chunk a single document and carry all of its other columns over
            # to every chunk it produces.
            document_chunks = self.chunker.chunk(str(document["text"]))
            for chunk in document_chunks:
                chunk_dict = {"document_idx": idx, "text": chunk}
                for key, value in document.items():
                    if key not in chunk_dict:
                        chunk_dict[key] = value
                chunks.append(chunk_dict)

        async def process_all_documents():
            tasks = []
            for idx, document in track(
                enumerate(document_dataset),
                total=len(document_dataset),
                description="Chunking documents",
            ):
                tasks.append(process_document(idx, document))
            await asyncio.gather(*tasks)

        asyncio.run(process_all_documents())

        # Keep chunks grouped and ordered by their source document.
        chunks.sort(key=lambda x: x["document_idx"])

        dataset = Dataset.from_list(chunks)
        if chunk_dataset_repo_id:
            # Append to any existing "chunks" split unless overwriting was requested.
            if huggingface_hub.repo_exists(chunk_dataset_repo_id, repo_type="dataset"):
                if not overwrite_dataset:
                    dataset = concatenate_datasets(
                        [
                            dataset,
                            load_dataset(chunk_dataset_repo_id, split="chunks"),
                        ]
                    )
            dataset.push_to_hub(repo_id=chunk_dataset_repo_id, split="chunks")
        return dataset
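

# A minimal, self-contained usage sketch (not part of the library): the tiny
# in-memory corpus below is hypothetical, and no `chunk_dataset_repo_id` is
# passed, so nothing is pushed to the HuggingFace Hub.
if __name__ == "__main__":
    # `chunk` accepts a `Dataset` object directly; only the "text" column is
    # required, other columns are copied onto each chunk.
    corpus = Dataset.from_list(
        [
            {"text": "The heart is a muscular organ. " * 50, "title": "doc-0"},
            {"text": "The femur is the longest bone. " * 50, "title": "doc-1"},
        ]
    )
    # A plain callable token counter is also accepted in place of a tokenizer
    # name; when using one, a chunk_size must be given explicitly.
    chunker = SemanticChunker(
        tokenizer_or_token_counter=lambda text: len(text.split()),
        chunk_size=128,
    )
    chunk_dataset = chunker.chunk(document_dataset=corpus)
    print(chunk_dataset)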