from itertools import chain
from typing import List

import torch
from datasets import IterableDataset

from .prompt_tokenizers import PromptTokenizingStrategy

# We want this to be a wrapper for an existing dataset that we have loaded.
# Let's use the concept of middlewares to wrap each dataset, for example:
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# Let's check to ensure we don't truncate an item in the middle; we'll use
# the collators later on to pad the datasets.


class TokenizedPromptDataset(IterableDataset):
    """Iterable dataset that applies a prompt tokenizing strategy to each example."""

    def __init__(
        self,
        prompt_tokenizer: PromptTokenizingStrategy,
        dataset: IterableDataset,
    ):
        self.prompt_tokenizer = prompt_tokenizer
        self.dataset = dataset

    def __iter__(self):
        # Tokenize every example in the wrapped dataset, not just the first one.
        for example in self.dataset:
            yield self.prompt_tokenizer.tokenize_prompt(example)


class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from a stream of text examples.
        Args:
            tokenizer (Tokenizer): The processor used for processing the data.
            datasets (List[IterableDataset]): Datasets that yield raw text examples.
            infinite (bool): If True the iterator is reset after the datasets reach the end, else it stops.
            seq_length (int): Length of token sequences to return.
            num_of_sequences (int): Number of sequences to buffer before packing.
            chars_per_token (float): Number of characters per token, used to estimate the number of tokens in the text buffer.
    """

    def __init__(
        self,
        tokenizer,
        datasets,
        infinite=False,
        seq_length=2048,
        num_of_sequences=1024,
        chars_per_token=3.6,
    ):
        self.tokenizer = tokenizer
        if tokenizer.eos_token_id is None:
            raise ValueError(
                "tokenizer must have an eos_token_id to use as the concatenation token"
            )
        self.concat_token_id = tokenizer.eos_token_id
        self.datasets: List[IterableDataset] = datasets
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
        # Iterate over the examples of every wrapped dataset, not over the list of datasets itself.
        iterator = chain.from_iterable(self.datasets)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(next(iterator))
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = chain.from_iterable(self.datasets)
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    self.current_size += 1
                    yield {
                        "input_ids": torch.LongTensor(input_ids),
                        "labels": torch.LongTensor(input_ids),
                        # Every position in a packed chunk is a real token, so the mask is all ones.
                        "attention_mask": torch.ones(self.seq_length, dtype=torch.long),
                    }
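

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library proper: it shows how a list of
    # text-yielding sources gets packed into fixed-length chunks. The "gpt2" checkpoint
    # and the in-memory `text_stream` list below are stand-ins chosen for illustration;
    # swap in your real tokenizer and wrapped (Tokenized/Shuffled) datasets.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    text_stream = ["The quick brown fox jumps over the lazy dog."] * 1024
    packed = ConstantLengthDataset(
        tokenizer,
        [text_stream],
        seq_length=128,
        num_of_sequences=4,
    )
    for sample in packed:
        # Each sample is a dict of 128-token tensors ready for a causal LM collator.
        print(sample["input_ids"].shape, sample["labels"].shape)
        break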