# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import typing as tp
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass
from multiprocessing import Pool

import torch

from fairseq.data import Dictionary, indexed_dataset
from fairseq.file_chunker_utils import Chunker, find_offsets
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line

logger = logging.getLogger("binarizer")
@dataclass
class BinarizeSummary:
    """
    Keep track of what's going on in the binarizer
    """

    num_seq: int = 0
    replaced: tp.Optional[Counter] = None
    num_tok: int = 0

    @property
    def num_replaced(self) -> int:
        if self.replaced is None:
            return 0
        return sum(self.replaced.values())

    @property
    def replaced_percent(self) -> float:
        return 100 * self.num_replaced / self.num_tok

    def __str__(self) -> str:
        base = f"{self.num_seq} sents, {self.num_tok} tokens"
        if self.replaced is None:
            return base
        return f"{base}, {self.replaced_percent:.3}% replaced"

    def merge(self, other: "BinarizeSummary"):
        replaced = None
        if self.replaced is not None:
            replaced = self.replaced
        if other.replaced is not None:
            if replaced is None:
                replaced = other.replaced
            else:
                replaced += other.replaced
        self.replaced = replaced
        self.num_seq += other.num_seq
        self.num_tok += other.num_tok
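
# Editor's sketch of how summaries accumulate and merge across workers; the
# values are made up for illustration, and this helper is not part of fairseq
# and is never called by the pipeline.
def _example_merge_summaries() -> BinarizeSummary:
    a = BinarizeSummary(num_seq=2, num_tok=10, replaced=Counter({"oov-word": 1}))
    b = BinarizeSummary(num_seq=1, num_tok=5)
    a.merge(b)
    # str(a) now reads "3 sents, 15 tokens, 6.67% replaced"
    return a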
class Binarizer(ABC):
    """
    a binarizer describes how to take a string and build a tensor out of it
    """

    @abstractmethod
    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ) -> torch.IntTensor:
        ...
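
# Editor's sketch of the minimal contract a Binarizer subclass must fulfil:
# turn one line into an IntTensor and update the shared summary. This toy class
# (not part of fairseq) encodes each whitespace token as its character length.
class _ExampleLengthBinarizer(Binarizer):
    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ) -> torch.IntTensor:
        ids = torch.IntTensor([len(tok) for tok in tokenize_line(line)])
        summary.num_seq += 1
        summary.num_tok += len(ids)
        return ids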
def _worker_prefix(output_prefix: str, worker_id: int):
    return f"{output_prefix}.pt{worker_id}"
class FileBinarizer:
    """
    A file binarizer can take a file, tokenize it, and binarize each line to a tensor
    """

    @classmethod
    def multiprocess_dataset(
        cls,
        input_file: str,
        dataset_impl: str,
        binarizer: Binarizer,
        output_prefix: str,
        vocab_size=None,
        num_workers=1,
    ) -> BinarizeSummary:
        final_summary = BinarizeSummary()

        offsets = find_offsets(input_file, num_workers)
        # find_offsets returns a list of positions [pos1, pos2, pos3, pos4] but we want pairs:
        # [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info.
        # We zip the list with itself shifted by one to get all the pairs.
        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            worker_results = [
                pool.apply_async(
                    cls._binarize_chunk_and_finalize,
                    args=(
                        binarizer,
                        input_file,
                        start_offset,
                        end_offset,
                        _worker_prefix(
                            output_prefix,
                            worker_id,
                        ),
                        dataset_impl,
                    ),
                    kwds={
                        "vocab_size": vocab_size,
                    }
                    if vocab_size is not None
                    else {},
                )
                for worker_id, (start_offset, end_offset) in enumerate(
                    more_chunks, start=1
                )
            ]
            pool.close()
            pool.join()
            for r in worker_results:
                summ = r.get()
                final_summary.merge(summ)

        # do not close the bin file as we need to merge the worker results into it
        final_ds, summ = cls._binarize_file_chunk(
            binarizer,
            input_file,
            offset_start=first_chunk[0],
            offset_end=first_chunk[1],
            output_prefix=output_prefix,
            dataset_impl=dataset_impl,
            vocab_size=vocab_size,
        )
        final_summary.merge(summ)

        if num_workers > 1:
            for worker_id in range(1, num_workers):
                # merge the worker outputs
                worker_output_prefix = _worker_prefix(
                    output_prefix,
                    worker_id,
                )
                final_ds.merge_file_(worker_output_prefix)
                try:
                    os.remove(indexed_dataset.data_file_path(worker_output_prefix))
                    os.remove(indexed_dataset.index_file_path(worker_output_prefix))
                except Exception as e:
                    logger.error(
                        f"couldn't remove {worker_output_prefix}.*", exc_info=e
                    )

        # now we can close the file
        idx_file = indexed_dataset.index_file_path(output_prefix)
        final_ds.finalize(idx_file)
        return final_summary
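
    # Editor's sketch of driving the full multiprocess pipeline end to end.
    # The paths and dictionary file below are hypothetical;
    # VocabularyDatasetBinarizer is defined further down in this module.
    #
    #   vocab = Dictionary.load("dict.txt")
    #   summary = FileBinarizer.multiprocess_dataset(
    #       input_file="train.src",
    #       dataset_impl="mmap",
    #       binarizer=VocabularyDatasetBinarizer(vocab),
    #       output_prefix="train.src-bin",
    #       vocab_size=len(vocab),
    #       num_workers=4,
    #   )
    #   print(summary)  # e.g. "10000 sents, 250000 tokens, 0.12% replaced"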
    @staticmethod
    def _binarize_file_chunk(
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ) -> tp.Tuple[tp.Any, BinarizeSummary]:  # (dataset builder, BinarizeSummary)
        """
        creates a dataset builder and appends binarized items to it. This function
        does not finalize the builder; this is useful if you want to do other things
        with your bin file, like appending/merging other files
        """
        bin_file = indexed_dataset.data_file_path(output_prefix)
        ds = indexed_dataset.make_builder(
            bin_file,
            impl=dataset_impl,
            vocab_size=vocab_size,
        )
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                ds.add_item(binarizer.binarize_line(line, summary))

        return ds, summary
    @classmethod
    def _binarize_chunk_and_finalize(
        cls,
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ):
        """
        same as above, but also finalizes the builder
        """
        ds, summ = cls._binarize_file_chunk(
            binarizer,
            filename,
            offset_start,
            offset_end,
            output_prefix,
            dataset_impl,
            vocab_size=vocab_size,
        )

        idx_file = indexed_dataset.index_file_path(output_prefix)
        ds.finalize(idx_file)

        return summ
class VocabularyDatasetBinarizer(Binarizer):
    """
    Takes a Dictionary/Vocabulary and assigns ids to each
    token using the dictionary's encode_line function.
    """

    def __init__(
        self,
        dict: Dictionary,
        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
        append_eos: bool = True,
        reverse_order: bool = False,
        already_numberized: bool = False,
    ) -> None:
        self.dict = dict
        self.tokenize = tokenize
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.already_numberized = already_numberized
        super().__init__()

    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ):
        if summary.replaced is None:
            summary.replaced = Counter()

        def replaced_consumer(word, idx):
            if idx == self.dict.unk_index and word != self.dict.unk_word:
                summary.replaced.update([word])

        if self.already_numberized:
            id_strings = line.strip().split()
            id_list = [int(id_string) for id_string in id_strings]
            if self.reverse_order:
                id_list.reverse()
            if self.append_eos:
                id_list.append(self.dict.eos())
            ids = torch.IntTensor(id_list)
        else:
            ids = self.dict.encode_line(
                line=line,
                line_tokenizer=self.tokenize,
                add_if_not_exist=False,
                consumer=replaced_consumer,
                append_eos=self.append_eos,
                reverse_order=self.reverse_order,
            )

        summary.num_seq += 1
        summary.num_tok += len(ids)
        return ids
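
# Editor's sketch (not part of fairseq) of binarizing a single line against a
# toy vocabulary built in place; the tokens added below are illustrative.
def _example_binarize_line() -> torch.IntTensor:
    vocab = Dictionary()
    for word in "hello world".split():
        vocab.add_symbol(word)
    summary = BinarizeSummary()
    ids = VocabularyDatasetBinarizer(vocab).binarize_line("hello world", summary)
    # ids ends with vocab.eos(); out-of-vocabulary tokens map to the unk index
    # and are tallied in summary.replaced
    return ids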
class AlignmentDatasetBinarizer(Binarizer):
    """
    binarize by parsing a set of alignments and packing
    them in a tensor (see utils.parse_alignment)
    """

    def __init__(
        self,
        alignment_parser: tp.Callable[[str], torch.IntTensor],
    ) -> None:
        super().__init__()
        self.alignment_parser = alignment_parser

    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ):
        ids = self.alignment_parser(line)
        summary.num_seq += 1
        summary.num_tok += len(ids)
        return ids
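
# Editor's sketch of an alignment parser compatible with this binarizer;
# fairseq ships a real one as fairseq.utils.parse_alignment. This toy stand-in
# flattens "src-tgt" pairs into one flat tensor.
def _example_toy_alignment_parser(line: str) -> torch.IntTensor:
    return torch.IntTensor(
        [int(i) for pair in line.split() for i in pair.split("-")]
    )

# usage sketch:
#   binarizer = AlignmentDatasetBinarizer(_example_toy_alignment_parser)
#   ids = binarizer.binarize_line("0-0 1-2 2-1", BinarizeSummary())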
class LegacyBinarizer:
    @classmethod
    def binarize(
        cls,
        filename: str,
        dico: Dictionary,
        consumer: tp.Callable[[torch.IntTensor], None],
        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
        append_eos: bool = True,
        reverse_order: bool = False,
        offset: int = 0,
        end: int = -1,
        already_numberized: bool = False,
    ) -> tp.Dict[str, int]:
        binarizer = VocabularyDatasetBinarizer(
            dict=dico,
            tokenize=tokenize,
            append_eos=append_eos,
            reverse_order=reverse_order,
            already_numberized=already_numberized,
        )
        return cls._consume_file(
            filename,
            binarizer,
            consumer,
            offset_start=offset,
            offset_end=end,
        )

    @classmethod
    def binarize_alignments(
        cls,
        filename: str,
        alignment_parser: tp.Callable[[str], torch.IntTensor],
        consumer: tp.Callable[[torch.IntTensor], None],
        offset: int = 0,
        end: int = -1,
    ) -> tp.Dict[str, int]:
        binarizer = AlignmentDatasetBinarizer(alignment_parser)
        return cls._consume_file(
            filename,
            binarizer,
            consumer,
            offset_start=offset,
            offset_end=end,
        )

    @staticmethod
    def _consume_file(
        filename: str,
        binarizer: Binarizer,
        consumer: tp.Callable[[torch.IntTensor], None],
        offset_start: int,
        offset_end: int,
    ) -> tp.Dict[str, int]:
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                consumer(binarizer.binarize_line(line, summary))

        return {
            "nseq": summary.num_seq,
            "nunk": summary.num_replaced,
            "ntok": summary.num_tok,
            "replaced": summary.replaced,
        }
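
# Editor's sketch of the legacy single-process path; the filename and dictionary
# file are hypothetical. The consumer is called once per line with the binarized
# tensor, and the defaults (offset=0, end=-1) cover the whole file.
#
#   tensors = []
#   stats = LegacyBinarizer.binarize(
#       "train.src",
#       Dictionary.load("dict.txt"),
#       consumer=tensors.append,
#   )
#   # stats == {"nseq": ..., "nunk": ..., "ntok": ..., "replaced": Counter(...)}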