# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from collections import OrderedDict, defaultdict

from fairseq import utils
from fairseq.data import (
    BacktranslationDataset,
    IndexedCachedDataset,
    IndexedDataset,
    IndexedRawTextDataset,
    LanguagePairDataset,
    NoisingDataset,
    RoundRobinZipDatasets,
    data_utils,
    indexed_dataset,
)
from fairseq.models import FairseqMultiModel
from fairseq.sequence_generator import SequenceGenerator

from . import register_task
from .multilingual_translation import MultilingualTranslationTask


logger = logging.getLogger(__name__)


def _get_bt_dataset_key(lang_pair):
    return "bt:" + lang_pair


def _get_denoising_dataset_key(lang_pair):
    return "denoising:" + lang_pair


# ported from UnsupervisedMT
def parse_lambda_config(x):
    """
    Parse the configuration of lambda coefficient (for scheduling).
    x = "3"                  # lambda will be a constant equal to x
    x = "0:1,1000:0"         # lambda will start from 1 and linearly decrease
                             # to 0 during the first 1000 iterations
    x = "0:0,1000:0,2000:1"  # lambda will be equal to 0 for the first 1000
                             # iterations, then will linearly increase to 1 until iteration 2000
    """
    split = x.split(",")
    if len(split) == 1:
        return float(x), None
    else:
        split = [s.split(":") for s in split]
        assert all(len(s) == 2 for s in split)
        assert all(k.isdigit() for k, _ in split)
        assert all(
            int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1)
        )
        return float(split[0][1]), [(int(k), float(v)) for k, v in split]
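
# Illustrative only (not part of the original file): expected return values for the
# two formats documented in the docstring, with "step:weight" pairs separated by ':'.
#   parse_lambda_config("3")          -> (3.0, None)
#   parse_lambda_config("0:1,1000:0") -> (1.0, [(0, 1.0), (1000, 0.0)])
# The first element is the initial lambda; the optional schedule list is consumed by
# lambda_step_func() in update_step() below.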


@register_task("semisupervised_translation")
class SemisupervisedTranslationTask(MultilingualTranslationTask):
    """A task for training multiple translation models simultaneously.

    We iterate round-robin over batches from multiple language pairs, ordered
    according to the `--lang-pairs` argument.

    The training loop is roughly:

        for i in range(len(epoch)):
            for lang_pair in args.lang_pairs:
                batch = next_batch_for_lang_pair(lang_pair)
                loss = criterion(model_for_lang_pair(lang_pair), batch)
                loss.backward()
            optimizer.step()

    In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
    (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
    implements the `FairseqMultiModel` interface.

    During inference it is required to specify a single `--source-lang` and
    `--target-lang`, instead of `--lang-pairs`.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        MultilingualTranslationTask.add_args(parser)
        parser.add_argument('--lambda-parallel-config', default="1.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (parallel data). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: step0:weight0,step1:weight1,...')
        parser.add_argument('--lambda-denoising-config', default="0.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (denoising autoencoding). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: step0:weight0,step1:weight1,...')
        parser.add_argument('--lambda-otf-bt-config', default="0.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (on-the-fly back-translation parallel data). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: step0:weight0,step1:weight1,...')
        parser.add_argument('--bt-max-len-a', default=1.1, type=float, metavar='N',
                            help='generate back-translated sequences of maximum length ax + b, where x is the '
                                 'source length')
        parser.add_argument('--bt-max-len-b', default=10.0, type=float, metavar='N',
                            help='generate back-translated sequences of maximum length ax + b, where x is the '
                                 'source length')
        parser.add_argument('--bt-beam-size', default=1, type=int, metavar='N',
                            help='beam size used in beam search of online back-translation')
        parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
                            help='maximum word shuffle distance for denoising autoencoding data generation')
        parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
                            help='word dropout probability for denoising autoencoding data generation')
        parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
                            help='word blanking probability for denoising autoencoding data generation')
        # fmt: on
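
        # Example (hypothetical values, not defaults): keeping parallel data at full weight
        # while ramping on-the-fly back-translation in over the first 100000 updates would
        # use `--lambda-parallel-config 1.0 --lambda-otf-bt-config 0:0,100000:1`, matching
        # the step:weight format parsed by parse_lambda_config() above.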

    def __init__(self, args, dicts, training):
        super().__init__(args, dicts, training)
        self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config(
            args.lambda_parallel_config
        )
        self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config(
            args.lambda_otf_bt_config
        )
        self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config(
            args.lambda_denoising_config
        )
        if self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None:
            denoising_lang_pairs = [
                "%s-%s" % (tgt, tgt)
                for tgt in {lang_pair.split("-")[1] for lang_pair in args.lang_pairs}
            ]
            self.model_lang_pairs = self.model_lang_pairs + denoising_lang_pairs
        self.backtranslate_datasets = {}
        self.backtranslators = {}
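
        # Illustrative only: with --lang-pairs de-en,en-de and denoising enabled, the set of
        # target languages above is {en, de}, so the self-pairs "en-en" and "de-de" are
        # appended to self.model_lang_pairs for the denoising autoencoding objective.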

    @classmethod
    def setup_task(cls, args, **kwargs):
        dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
        return cls(args, dicts, training)

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        def split_exists(split, src, tgt, lang):
            if src is not None:
                filename = os.path.join(
                    data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
                )
            else:
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, src, tgt)
                )
            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

        def load_indexed_dataset(path, dictionary):
            return data_utils.load_indexed_dataset(
                path, dictionary, self.args.dataset_impl
            )

        # load parallel datasets
        src_datasets, tgt_datasets = {}, {}
        if (
            self.lambda_parallel > 0.0
            or self.lambda_parallel_steps is not None
            or not split.startswith("train")
        ):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if split_exists(split, src, tgt, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, src, tgt)
                    )
                elif split_exists(split, tgt, src, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, tgt, src)
                    )
                else:
                    continue
                src_datasets[lang_pair] = load_indexed_dataset(
                    prefix + src, self.dicts[src]
                )
                tgt_datasets[lang_pair] = load_indexed_dataset(
                    prefix + tgt, self.dicts[tgt]
                )
                logger.info(
                    "parallel-{} {} {} examples".format(
                        data_path, split, len(src_datasets[lang_pair])
                    )
                )
            if len(src_datasets) == 0:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        # back translation datasets
        backtranslate_datasets = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    raise FileNotFoundError(
                        "Dataset not found: backtranslation {} ({})".format(
                            split, data_path
                        )
                    )
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                dataset = load_indexed_dataset(filename, self.dicts[tgt])
                lang_pair_dataset_tgt = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                lang_pair_dataset = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    src_dict=self.dicts[src],
                    tgt=dataset,
                    tgt_sizes=dataset.sizes,
                    tgt_dict=self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                backtranslate_datasets[lang_pair] = BacktranslationDataset(
                    tgt_dataset=self.alter_dataset_langtok(
                        lang_pair_dataset_tgt,
                        src_eos=self.dicts[tgt].eos(),
                        src_lang=tgt,
                        tgt_lang=src,
                    ),
                    backtranslation_fn=self.backtranslators[lang_pair],
                    src_dict=self.dicts[src],
                    tgt_dict=self.dicts[tgt],
                    output_collater=self.alter_dataset_langtok(
                        lang_pair_dataset=lang_pair_dataset,
                        src_eos=self.dicts[src].eos(),
                        src_lang=src,
                        tgt_eos=self.dicts[tgt].eos(),
                        tgt_lang=tgt,
                    ).collater,
                )
                logger.info(
                    "backtranslate-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(backtranslate_datasets[lang_pair]),
                    )
                )
                self.backtranslate_datasets[lang_pair] = backtranslate_datasets[
                    lang_pair
                ]

        # denoising autoencoder
        noising_datasets = {}
        if (
            self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
        ) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    continue
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
                tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
                noising_dataset = NoisingDataset(
                    tgt_dataset1,
                    self.dicts[tgt],
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
                noising_datasets[lang_pair] = self.alter_dataset_langtok(
                    LanguagePairDataset(
                        noising_dataset,
                        tgt_dataset1.sizes,
                        self.dicts[tgt],
                        tgt_dataset2,
                        tgt_dataset2.sizes,
                        self.dicts[tgt],
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    ),
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                )
                logger.info(
                    "denoising-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(noising_datasets[lang_pair]),
                    )
                )

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split("-")
            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    self.dicts[src],
                    tgt_dataset,
                    tgt_dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                self.dicts[src].eos(),
                src,
                self.dicts[tgt].eos(),
                tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    (lang_pair, language_pair_dataset(lang_pair))
                    for lang_pair in src_datasets.keys()
                ]
                + [
                    (_get_bt_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in backtranslate_datasets.items()
                ]
                + [
                    (_get_denoising_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in noising_datasets.items()
                ]
            ),
            eval_key=None
            if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )
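
        # Illustrative only: for a training split with --lang-pairs de-en this produces keys
        # such as "de-en" (parallel), "bt:de-en" (online back-translation) and
        # "denoising:de-en" (denoising autoencoding); train_step() below looks samples up
        # under these same keys.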

    def build_model(self, args, from_checkpoint=False):
        from fairseq import models

        model = models.build_model(args, self, from_checkpoint)
        if not isinstance(model, FairseqMultiModel):
            raise ValueError(
                "SemisupervisedTranslationTask requires a FairseqMultiModel architecture"
            )

        # create SequenceGenerator for each model that has backtranslation dependency on it
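        # (illustrative: for the lang pair "de-en" the generator key below is the reverse
        # direction "en-de", so English monolingual target data is back-translated into
        # synthetic German sources on the fly while training the de-en model)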
        self.sequence_generators = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and self.training:
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                key = "{}-{}".format(tgt, src)
                self.sequence_generators[key] = SequenceGenerator(
                    [model.models[key]],
                    tgt_dict=self.dicts[src],
                    beam_size=args.bt_beam_size,
                    max_len_a=args.bt_max_len_a,
                    max_len_b=args.bt_max_len_b,
                )
                decoder_lang_tok_idx = self.get_decoder_langtok(src)

                def backtranslate_fn(
                    sample,
                    model=model.models[key],
                    bos_token=decoder_lang_tok_idx,
                    sequence_generator=self.sequence_generators[key],
                ):
                    return sequence_generator.generate(
                        [model],
                        sample,
                        bos_token=bos_token,
                    )

                self.backtranslators[lang_pair] = backtranslate_fn

        return model

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        if update_num > 0:
            self.update_step(update_num)
        agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float)

        def forward_backward(model, samples, logging_output_key, weight):
            nonlocal agg_loss, agg_sample_size, agg_logging_output
            if samples is None or len(samples) == 0:
                return
            loss, sample_size, logging_output = criterion(model, samples)
            if ignore_grad:
                loss *= 0
            else:
                loss *= weight
            optimizer.backward(loss)
            agg_loss += loss.detach().item()
            # TODO make summing of the sample sizes configurable
            agg_sample_size += sample_size
            for k in logging_output:
                agg_logging_output[k] += logging_output[k]
                agg_logging_output["{}:{}".format(logging_output_key, k)] += logging_output[k]
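
        # Illustrative weighting: each objective's loss is scaled by its current lambda before
        # backward(), so one update accumulates
        #   lambda_parallel * parallel losses
        #   + lambda_otf_bt * back-translation losses
        #   + lambda_denoising * denoising losses.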

        if self.lambda_parallel > 0.0:
            for lang_pair in self.lang_pairs:
                forward_backward(
                    model.models[lang_pair],
                    sample[lang_pair],
                    lang_pair,
                    self.lambda_parallel,
                )

        if self.lambda_otf_bt > 0.0:
            for lang_pair in self.lang_pairs:
                sample_key = _get_bt_dataset_key(lang_pair)
                forward_backward(
                    model.models[lang_pair],
                    sample[sample_key],
                    sample_key,
                    self.lambda_otf_bt,
                )

        if self.lambda_denoising > 0.0:
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                sample_key = _get_denoising_dataset_key(lang_pair)
                forward_backward(
                    model.models["{0}-{0}".format(tgt)],
                    sample[sample_key],
                    sample_key,
                    self.lambda_denoising,
                )

        return agg_loss, agg_sample_size, agg_logging_output

    def update_step(self, num_updates):
        def lambda_step_func(config, n_iter):
            """
            Update a lambda value according to its schedule configuration.
            """
            ranges = [
                i
                for i in range(len(config) - 1)
                if config[i][0] <= n_iter < config[i + 1][0]
            ]
            if len(ranges) == 0:
                assert n_iter >= config[-1][0]
                return config[-1][1]
            assert len(ranges) == 1
            i = ranges[0]
            x_a, y_a = config[i]
            x_b, y_b = config[i + 1]
            return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)
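
        # Illustrative only: with config [(0, 0.0), (1000, 0.0), (2000, 1.0)]
        # (i.e. "0:0,1000:0,2000:1") and n_iter = 1500, the active segment is
        # (1000, 0.0) -> (2000, 1.0), giving
        # 0.0 + (1500 - 1000) * (1.0 - 0.0) / (2000 - 1000) = 0.5.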

        if self.lambda_parallel_steps is not None:
            self.lambda_parallel = lambda_step_func(
                self.lambda_parallel_steps, num_updates
            )
        if self.lambda_denoising_steps is not None:
            self.lambda_denoising = lambda_step_func(
                self.lambda_denoising_steps, num_updates
            )
        if self.lambda_otf_bt_steps is not None:
            self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates)