# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os
from collections import OrderedDict

import numpy as np

from fairseq import tokenizer, utils
from fairseq.data import ConcatDataset, Dictionary, TokenBlockDataset, data_utils
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)

@register_task("cross_lingual_lm")
class CrossLingualLMTask(LegacyFairseqTask):
    """
    Task for training cross-lingual language models.

    For more details, see: https://arxiv.org/pdf/1901.07291.pdf

    Args:
        dictionary (Dictionary): the dictionary for the input of the task
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="colon-separated list of data directories, "
            "iterated over in round-robin order across epochs",
        )
        parser.add_argument(
            "--tokens-per-sample",
            default=512,
            type=int,
            help="max number of total tokens over all segments per sample",
        )
        parser.add_argument(
            "--monolingual-langs",
            default="en",
            type=str,
            help="comma-separated list of languages to train XLM on",
        )
        parser.add_argument(
            "--shuffle",
            action="store_true",
            help="shuffle each monolingual dataset while training",
        )

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed
        self.distributed_world_size = args.distributed_world_size
        self.langs2id = self._lang_to_id(args.monolingual_langs)

    def _lang_to_id(self, languages: str):
        """
        Build a map from languages to ids. These ids are used as segment labels
        for cross-lingual LM training.
        """
        lang2id = {}
        langs = [l.strip() for l in languages.split(",")]
        for id, lang in enumerate(langs):
            lang2id[lang] = id
        return lang2id

    @classmethod
    def load_dictionary(cls, filename):
        return MaskedLMDictionary.load(filename)

    @classmethod
    def build_dictionary(
        cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
    ):
        d = MaskedLMDictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, workers
            )
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d

    @property
    def target_dictionary(self):
        return self.dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task."""
        dictionary = MaskedLMDictionary.load(os.path.join(args.data, "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        return cls(args, dictionary)

    def _load_single_lang_dataset(self, split, epoch):
        loaded_datasets = []

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else "")
            path = os.path.join(data_path, split_k)

            ds = data_utils.load_indexed_dataset(
                path, self.dictionary, self.args.dataset_impl
            )
            if ds is None:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({})".format(split, data_path)
                    )

            # Since we append each block with the classification_token,
            # we need to effectively create blocks of length
            # tokens_per_sample - 1
            loaded_datasets.append(
                TokenBlockDataset(
                    ds,
                    ds.sizes,
                    self.args.tokens_per_sample - 1,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                )
            )

            logger.info(
                "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
            )

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        return dataset, sizes

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        dataset_map = OrderedDict()
        for lang in self.langs2id.keys():
            # Datasets are expected to be in "split.lang" format (e.g., train.en)
            language_split = "{}.{}".format(split, lang)

            block_dataset, sizes = self._load_single_lang_dataset(
                split=language_split, epoch=epoch
            )

            dataset_map[lang] = MaskedLMDataset(
                dataset=block_dataset,
                sizes=sizes,
                vocab=self.dictionary,
                pad_idx=self.dictionary.pad(),
                mask_idx=self.dictionary.mask(),
                classif_token_idx=self.dictionary.eos(),
                sep_token_idx=self.dictionary.eos(),
                shuffle=getattr(self.args, "shuffle", False),
                has_pairs=False,
                segment_id=self.langs2id[lang],
                seed=self.seed,
            )

        self.datasets[split] = MultiCorpusSampledDataset(dataset_map)

        paths = utils.split_paths(self.args.data)
        logger.info(
            "{} {} {} examples".format(
                paths[(epoch - 1) % len(paths)],
                split,
                len(self.datasets[split]),
            )
        )
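

# A minimal programmatic sketch of driving this task directly (an illustrative
# assumption rather than an official fairseq example; the data directory is
# hypothetical and must contain dict.txt plus binarized train.en / train.fr
# splits):
#
#   from argparse import Namespace
#
#   args = Namespace(
#       data="data-bin/xlm",
#       tokens_per_sample=512,
#       monolingual_langs="en,fr",
#       shuffle=False,
#       seed=1,
#       distributed_world_size=1,
#       dataset_impl=None,  # let fairseq infer the on-disk dataset format
#   )
#   task = CrossLingualLMTask.setup_task(args)
#   task.load_dataset("train")
#   print(len(task.datasets["train"]))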