# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import json
import logging
import math
import os
from argparse import ArgumentError, Namespace
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import fairseq
from fairseq import options, utils
from fairseq.logging import metrics
from fairseq.data import (
    FairseqDataset,
    LanguagePairDataset,
    NoisingDataset,
    PrependTokenDataset,
    RoundRobinZipDatasets,
    TransformEosLangPairDataset,
    data_utils,
    encoders,
)
from fairseq.sequence_generator import SequenceGenerator
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask, load_langpair_dataset


logger = logging.getLogger(__name__)

class PiecewiseLinearFn:
    """Piecewise linear function. Can be configured with a string."""

    def __init__(self, pieces: Sequence[Tuple[int, float]]):
        assert pieces == sorted(
            pieces
        ), f"PiecewiseLinearFn configuration should be sorted, received: {pieces}"

        self.pieces = pieces

    def __call__(self, x: int) -> float:
        for i, (x_a, y_a) in enumerate(self.pieces[:-1]):
            x_b, y_b = self.pieces[i + 1]
            if x_a <= x <= x_b:
                return y_a + (x - x_a) * (y_b - y_a) / (x_b - x_a)

        return self.pieces[-1][1]

    @staticmethod
    def from_string(configuration: str) -> "PiecewiseLinearFn":
        """
        Parse the configuration of lambda coefficient (for scheduling).
        x = "3"                  # lambda will be a constant equal to x
        x = "0:1,1000:0"         # lambda will start from 1 and linearly decrease
                                 # to 0 during the first 1000 iterations
        x = "0:0,1000:0,2000:1"  # lambda will be equal to 0 for the first 1000
                                 # iterations, then will linearly increase to 1
                                 # until iteration 2000
        """
        if isinstance(configuration, float):
            return PiecewiseLinearFn([(0, configuration)])

        try:
            parts = configuration.split(",")
            if len(parts) == 1:
                v = float(configuration)
                return PiecewiseLinearFn([(0, v)])

            split = [s.split(":") for s in parts]
            pieces = [(int(t), float(v)) for t, v in split]
            return PiecewiseLinearFn(pieces)
        except Exception:
            raise ValueError(
                f"Invalid PiecewiseLinearFn configuration: {configuration!r}"
            )

    @staticmethod
    def one() -> "PiecewiseLinearFn":
        return PiecewiseLinearFn([(0, 1.0)])
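
# Illustrative usage (a sketch, not part of the task API): the schedule string
# "0:1,1000:0" decays linearly from 1.0 to 0.0 over the first 1000 updates and
# then stays at the last value.
#
#     fn = PiecewiseLinearFn.from_string("0:1,1000:0")
#     fn(0)     # -> 1.0
#     fn(500)   # -> 0.5
#     fn(2000)  # -> 0.0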

@register_task("online_backtranslation")
class OnlineBackTranslationTask(TranslationTask):
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        # Generic translation args
        parser.add_argument('data', help='colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner; \
                            however, valid and test data are always in the first directory to \
                            avoid the need for repeating them in all directories')
        parser.add_argument('--mono-langs', metavar='MONO_LANGS',
                            help='monolingual languages for training')
        parser.add_argument('--valid-lang-pairs', default=None, metavar='VALID_LANG_PAIRS',
                            help='language pairs for validation')
        parser.add_argument('--load-alignments', action='store_true',
                            help='load the binarized alignments')
        parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL',
                            help='pad the source on the left')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left')
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
        try:
            parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                                help='max number of tokens in the source sequence')
            parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                                help='max number of tokens in the target sequence')
        except ArgumentError:
            # this might have already been defined. Once we transition this to hydra it should be fine to add it here.
            pass
        parser.add_argument('--truncate-source', action='store_true', default=False,
                            help='truncate source to max-source-positions')
        parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
                            help='if >0, then bucket source and target lengths into N '
                                 'buckets and pad accordingly; this is useful on TPUs '
                                 'to minimize the number of compilations')

        # Denoising args
        parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
                            help='maximum word shuffle distance for denoising autoencoding data generation')
        parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
                            help='word dropout probability for denoising autoencoding data generation')
        parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
                            help='word blanking probability for denoising autoencoding data generation')

        # Backtranslation args
        parser.add_argument('--lambda-bt', default="1.0", type=str, metavar='N',
                            help='back-translation weight')
        parser.add_argument('--lambda-dae', default="1.0", type=str, metavar='N',
                            help='denoising auto-encoder weight')

        # Evaluation args
        parser.add_argument('--generate-one-by-one', action='store_true',
                            help='generate one sentence at a time for backtranslation')

        parser.add_argument('--eval-bleu', action='store_true',
                            help='evaluation with BLEU scores')
        parser.add_argument('--eval-bleu-detok', type=str, default="space",
                            help='detokenize before computing BLEU (e.g., "moses"); '
                                 'required if using --eval-bleu; use "space" to '
                                 'disable detokenization; see fairseq.data.encoders '
                                 'for other options')
        parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
                            help='args for building the tokenizer, if needed')
        parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
                            help='compute tokenized BLEU instead of sacrebleu')
        parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
                            help='remove BPE before computing BLEU')
        parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
                            help='generation args for BLEU scoring, '
                                 'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
        parser.add_argument('--eval-bleu-print-samples', action='store_true',
                            help='print sample generations during validation')
        # fmt: on
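
    # Illustrative invocation (hypothetical paths and values, shown only to
    # document how the arguments above fit together; not a verified command):
    #
    #     fairseq-train $DATA_DIR --task online_backtranslation \
    #         --mono-langs en,ro --valid-lang-pairs en-ro \
    #         --lambda-bt "1.0" --lambda-dae "0:1,100000:0.1" \
    #         --eval-bleu --eval-bleu-detok space
    #
    # --lambda-bt and --lambda-dae accept either a constant or a
    # PiecewiseLinearFn schedule string (see PiecewiseLinearFn.from_string).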

    def __init__(self, args, common_dict, mono_langs, valid_lang_pairs):
        super().__init__(args, common_dict, common_dict)
        self.common_dict = common_dict
        self.mono_langs = mono_langs
        self.valid_lang_pairs = valid_lang_pairs

        self.SHOW_SAMPLES_INTERVAL = 1000
        # Start by showing samples
        self._show_samples_ctr = self.SHOW_SAMPLES_INTERVAL
        self.SHOW_SAMPLES_NUMBER = 5
        self.lambda_bt = PiecewiseLinearFn.from_string(args.lambda_bt)
        self.lambda_dae = PiecewiseLinearFn.from_string(args.lambda_dae)

        self.args = args
        self.data = utils.split_paths(self.args.data)
        if len(self.data) == 1:
            shards = list(Path(self.data[0]).glob("shard*"))
            if len(shards) > 0:
                # keep this as strings, since it can also be a manifold path
                old_data = self.data
                self.data = [str(shard) for shard in shards]
                logging.warning(f"Expanded data directory {old_data} to {self.data}")

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        assert args.mono_langs is not None

        mono_langs = args.mono_langs.split(",")
        valid_lang_pairs = args.valid_lang_pairs.split(",")

        # load dictionary
        dict_path = os.path.join(paths[0], "dict.txt")
        common_dict = cls.load_dictionary(dict_path)

        return cls(args, common_dict, mono_langs, valid_lang_pairs)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs) -> FairseqDataset:
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if split == "train":
            data_path = self.data[(epoch - 1) % len(self.data)]
            dataset = self.load_train_dataset(data_path)
        else:
            # valid/test should always be the same.
            dataset = self.load_translation_dataset(split, self.data[0])

        self.datasets[split] = dataset
        return dataset

    def load_train_dataset(self, data_path: str) -> FairseqDataset:
        """The training dataset consists of a backtranslation dataset and a denoising dataset."""
        data = []
        for lang in self.mono_langs:
            train_path = os.path.join(data_path, lang, "train")
            # TODO: could we do the BT using the denoised sample?
            # this would halve the data loading work
            data.append((f"{lang}-BT", self.load_bt_dataset(train_path, lang)))
            data.append(
                (f"{lang}-DENOISE", self.load_denoise_dataset(train_path, lang))
            )

        return RoundRobinZipDatasets(OrderedDict(data))

    def _langpair_dataset(
        self, src: FairseqDataset, tgt: FairseqDataset
    ) -> LanguagePairDataset:
        return LanguagePairDataset(
            src,
            src.sizes,
            self.dictionary,
            tgt=tgt,
            tgt_sizes=tgt.sizes,
            tgt_dict=self.dictionary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            # TODO: should we shuffle? we are already sorting batches by size, so?
            # shuffle=True,
        )

    def _prepend_lang_bos_to_target(
        self, dataset: LanguagePairDataset, lang: str
    ) -> LanguagePairDataset:
        bos = _lang_token_index(self.dictionary, lang)
        return TransformEosLangPairDataset(
            dataset,
            src_eos=self.dictionary.eos(),
            new_src_eos=self.dictionary.eos(),
            tgt_bos=self.dictionary.eos(),
            new_tgt_bos=bos,
        )

    def load_bt_dataset(self, data_path: str, lang: str) -> FairseqDataset:
        """The BT dataset is generated with (tgt, tgt) pairs.
        The actual translation to a (generated_src, tgt) pair
        is done on the fly during training.
        """
        mono_dataset = data_utils.load_indexed_dataset(
            data_path, self.common_dict, self.args.dataset_impl
        )
        assert mono_dataset is not None, f"No dataset found for {lang}"

        mono_dataset_src = PrependTokenDataset(
            mono_dataset, _lang_token_index(self.dictionary, lang)
        )

        mono_dataset_bt = self._langpair_dataset(mono_dataset_src, mono_dataset)
        logger.info(
            f"mono_lang = {lang} "
            f"lang token index = {_lang_token_index(self.dictionary, lang)} "
            f"lang token = {_lang_token(lang)}"
        )

        mono_dataset_bt = self._prepend_lang_bos_to_target(mono_dataset_bt, lang)
        return mono_dataset_bt

    def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset:
        """Classic denoising dataset"""
        dataset = data_utils.load_indexed_dataset(
            data_path, self.common_dict, self.args.dataset_impl
        )
        noisy_dataset = NoisingDataset(
            dataset,
            self.dictionary,
            seed=1,
            max_word_shuffle_distance=self.args.max_word_shuffle_distance,
            word_dropout_prob=self.args.word_dropout_prob,
            word_blanking_prob=self.args.word_blanking_prob,
        )
        noisy_dataset = PrependTokenDataset(
            noisy_dataset, _lang_token_index(self.dictionary, lang)
        )

        clean_dataset = data_utils.load_indexed_dataset(
            data_path, self.common_dict, self.args.dataset_impl
        )
        denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset)
        denoising_dataset = self._prepend_lang_bos_to_target(denoising_dataset, lang)

        return denoising_dataset

    def load_translation_dataset(
        self, split: str, data_path: str, combine: bool = False
    ):
        # only validating with one language pair for the moment,
        # since ConcatDataset doesn't work as expected
        assert len(self.valid_lang_pairs) == 1, "For now..."
        valid_lang_pair = self.valid_lang_pairs[0]
        src, tgt = valid_lang_pair.split("-")

        # use the same function as TranslationTask
        src_tgt_dt = load_langpair_dataset(
            data_path,
            split,
            src,
            self.common_dict,
            tgt,
            self.common_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            load_alignments=self.args.load_alignments,
            truncate_source=self.args.truncate_source,
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split != "test"),
            prepend_bos_src=_lang_token_index(self.dictionary, src),
        )

        src_tgt_eos_dt = self._prepend_lang_bos_to_target(src_tgt_dt, tgt)
        src_tgt_eos_dt.args = self.args
        return src_tgt_eos_dt

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        raise NotImplementedError

    def build_model(self, args, from_checkpoint=False):
        # torch.autograd.set_detect_anomaly(True)
        model = super().build_model(args, from_checkpoint)

        add_special_tokens_to_dict_and_model(self.common_dict, model, self.mono_langs)

        self.sequence_generators = {}
        for mono_lang in self.mono_langs:
            self.sequence_generators[mono_lang] = SequenceGenerator(
                [model],
                tgt_dict=self.dictionary,
                beam_size=1,
                max_len_a=1.3,
                max_len_b=5,
                min_len=5,
                # keep 1 to be able to prepend bos
                max_len=model.max_decoder_positions() - 1,
            )

        if getattr(args, "eval_bleu", False):
            assert getattr(args, "eval_bleu_detok", None) is not None, (
                "--eval-bleu-detok is required if using --eval-bleu; "
                "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
                "to disable detokenization, e.g., when using sentencepiece)"
            )
            detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}")
            self.tokenizer = encoders.build_tokenizer(
                Namespace(
                    tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args
                )
            )

            gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}")
            self.bleu_sequence_generator = self.build_generator(
                [model], Namespace(**gen_args)
            )

        return model

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)

    @property
    def dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.common_dict

    def display_samples_once_in_a_while(self, smp, mono_lang, other_lang):
        self._show_samples_ctr += 1
        if self._show_samples_ctr < self.SHOW_SAMPLES_INTERVAL:
            return
        self._show_samples_ctr = 0

        ln = smp["net_input"]["src_tokens"].shape[0]

        logger.info(
            f"(r:{self.args.distributed_rank}) : "
            f"{other_lang} ---> {mono_lang} "
            f"({other_lang} was generated by back-translation.) {ln} samples"
        )

        for i in range(min(ln, self.SHOW_SAMPLES_NUMBER)):
            src_tokens = smp["net_input"]["src_tokens"][i]
            tgt_tokens = smp["target"][i]

            src_str = self.dictionary.string(src_tokens, "sentencepiece")
            tgt_str = self.dictionary.string(tgt_tokens, "sentencepiece")
            logger.info(
                f"\n{i}\t\t[{other_lang} generated] {src_str}\n"
                f"\t\t[{mono_lang} original ] {tgt_str}\n"
                f"\t\t[ src tokens] {src_tokens}\n"
            )

    def backtranslate_sample(self, smp, orig_lang, other_lang) -> None:
        """
        * WARNING: smp is modified in place.
        * At the start of this function, `smp` has the same input and target:
          |--------------------------------------------------------|
          | smp['net_input']['src_tokens'] |  smp['target']         |
          | (from data) __en__ hello world |  __en__ hello world    |
          |--------------------------------------------------------|
        * We call generator.generate(smp, bos_token = token("ro")),
          and copy the result as input.
        * At the end, `smp` has the translation to the other language:
          |--------------------------------------------------------|
          | smp['net_input']['src_tokens'] |  smp['target']         |
          | (generated) __ro__ salut lume  |  __en__ hello world    |
          |--------------------------------------------------------|
        """
        bos_token = _lang_token_index(self.dictionary, other_lang)
        generated = self.sequence_generators[orig_lang].generate(
            models=[], sample=smp, bos_token=bos_token
        )

        max_length = max([gn[0]["tokens"].size(0) for gn in generated])
        net_input = smp["net_input"]
        n_src_tokens = torch.empty(
            size=(len(generated), max_length + 1), dtype=net_input["src_tokens"].dtype
        )
        n_src_lengths = torch.empty(
            len(generated), dtype=net_input["src_lengths"].dtype
        )

        for i, gn in enumerate(generated):
            tokens = gn[0]["tokens"]
            tokens_size = tokens.size(0)
            padding_needed = max_length - tokens_size
            tokens = torch.cat([tokens.new([bos_token]), tokens])
            tokens = F.pad(tokens, (0, padding_needed), value=self.dictionary.pad())
            n_src_tokens[i] = tokens
            n_src_lengths[i] = tokens_size + 1

        device = net_input["src_tokens"].device
        # This seems to be important
        del net_input["src_tokens"]
        del net_input["src_lengths"]
        net_input["src_tokens"] = n_src_tokens.to(device)
        net_input["src_lengths"] = n_src_lengths.to(device)

    def generate(self, smp, model):
        model.eval()
        orig_lang = (
            self.dictionary[smp["net_input"]["src_tokens"][0][0]]
            .replace(" ", "")
            .replace("_", "")
        )
        bos_token = smp["net_input"]["prev_output_tokens"][0][0]
        with torch.no_grad():
            generated = self.sequence_generators[orig_lang].generate(
                models=[model], sample=smp, bos_token=bos_token
            )
        return generated

    def get_other_lang(self, lang):
        # TODO: allow more complex mapping
        if lang != self.mono_langs[0]:
            return self.mono_langs[0]
        if len(self.mono_langs) == 2:
            return self.mono_langs[1]
        return self.mono_langs[np.random.randint(1, len(self.mono_langs))]
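
    # Illustrative behaviour (assuming mono_langs=["en", "ro"], a hypothetical
    # configuration used only for this comment):
    #     get_other_lang("en")  # -> "ro"
    #     get_other_lang("ro")  # -> "en"
    # With more than two languages, any non-first language maps back to
    # mono_langs[0], and mono_langs[0] maps to a randomly chosen other language.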

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        model.set_num_updates(update_num)

        agg_loss, agg_sample_size = 0.0, 0.0
        agg_logging_output: Dict[str, float] = defaultdict(float)

        dataset_keys = self.datasets["train"].datasets.keys()

        weights = {
            "BT": self.lambda_bt(update_num),
            "DENOISE": self.lambda_dae(update_num),
        }
        log_keys = {"BT": "bt_", "DENOISE": "dae_"}

        for dataset_key in dataset_keys:
            smp = sample[dataset_key]
            mono_lang, task_subtype = dataset_key.split("-")
            if weights[task_subtype] == 0:
                continue

            if task_subtype == "BT":
                with torch.autograd.profiler.record_function("backtranslation"):
                    model.eval()
                    # TODO: Could we translate to several languages at once?
                    # this would allow sharing encoder_out and maximizing GPU usage.
                    other_lang = self.get_other_lang(mono_lang)
                    self.backtranslate_sample(smp, mono_lang, other_lang)
                    self.display_samples_once_in_a_while(smp, mono_lang, other_lang)
                    model.train()

            # Like in FairseqTask.train_step
            with torch.autograd.profiler.record_function("forward"):
                loss, sample_size, logging_output = criterion(model, smp)
            loss *= weights[task_subtype]
            if ignore_grad:
                loss *= 0
            with torch.autograd.profiler.record_function("backward"):
                optimizer.backward(loss)

            agg_loss += loss.item()
            agg_sample_size += sample_size
            for k in logging_output:
                agg_logging_output[log_keys[task_subtype] + k] += logging_output[k]
                agg_logging_output[k] += logging_output[k]

        return agg_loss, agg_sample_size, agg_logging_output

    def get_bos_token_from_sample(self, sample):
        net_input = sample["net_input"]
        source_lang_token_id = torch.unique(net_input["src_tokens"][:, 0]).item()
        source_lang_token = self.dictionary[source_lang_token_id].replace("_", "")
        target_lang_token_id = _lang_token_index(
            self.dictionary, self.get_other_lang(source_lang_token)
        )

        return target_lang_token_id

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        bt_sample_size = sum(x.get("bt_sample_size", 0) for x in logging_outputs)
        if bt_sample_size:
            bt_loss_sum = sum(x.get("bt_loss", 0) for x in logging_outputs)
            bt_loss_sum *= 1 / bt_sample_size / math.log(2)
            metrics.log_scalar("bt_loss", bt_loss_sum, bt_sample_size, round=3)

            bt_nll_loss_sum = sum(x.get("bt_nll_loss", 0) for x in logging_outputs)
            bt_ntokens = sum(x.get("bt_ntokens", 0) for x in logging_outputs)
            bt_nll_loss_sum *= 1 / bt_ntokens / math.log(2)
            metrics.log_scalar("bt_nll_loss", bt_nll_loss_sum, bt_ntokens, round=3)
            metrics.log_derived(
                "bt_ppl", lambda meters: utils.get_perplexity(meters["bt_nll_loss"].avg)
            )

        dae_sample_size = sum(x.get("dae_sample_size", 0) for x in logging_outputs)
        if dae_sample_size:
            dae_loss_sum = sum(x.get("dae_loss", 0) for x in logging_outputs)
            dae_loss_sum *= 1 / dae_sample_size / math.log(2)
            metrics.log_scalar("dae_loss", dae_loss_sum, dae_sample_size, round=3)

            dae_nll_loss_sum = sum(x.get("dae_nll_loss", 0) for x in logging_outputs)
            dae_ntokens = sum(x.get("dae_ntokens", 0) for x in logging_outputs)
            dae_nll_loss_sum *= 1 / dae_ntokens / math.log(2)
            metrics.log_scalar("dae_nll_loss", dae_nll_loss_sum, dae_ntokens, round=3)
            metrics.log_derived(
                "dae_ppl",
                lambda meters: utils.get_perplexity(meters["dae_nll_loss"].avg),
            )


def extend_embedding(
    emb: nn.Module, new_vocab_size: int, copy_from_token_id: int
) -> None:
    old_emb_data = emb.weight.data
    (old_vocab_size, dim) = old_emb_data.shape
    assert new_vocab_size >= old_vocab_size

    if new_vocab_size > old_vocab_size:
        emb.weight.data = torch.zeros((new_vocab_size, dim))
        emb.weight.data[:old_vocab_size, :] = old_emb_data
        # initialize new embeddings
        emb.weight.data[old_vocab_size:, :] = old_emb_data[copy_from_token_id]
        if hasattr(emb, "num_embeddings"):
            emb.num_embeddings = new_vocab_size
        if hasattr(emb, "out_features"):
            emb.out_features = new_vocab_size

    if getattr(emb, "bias", None) is None:
        return

    # Fix the bias.
    # Bias shape can be different from the previous vocab size
    # if the weight matrix was shared and already extended but not the bias.
    (old_vocab_size,) = emb.bias.shape
    assert new_vocab_size >= old_vocab_size
    if new_vocab_size > old_vocab_size:
        old_bias = emb.bias.data
        new_bias = torch.zeros(
            (new_vocab_size,), dtype=old_bias.dtype, device=old_bias.device
        )
        new_bias[:old_vocab_size] = old_bias
        emb.bias.data = new_bias
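
# Illustrative effect of extend_embedding (hypothetical sizes, shown only as a
# sketch; the new rows are initialized from the row at copy_from_token_id):
#
#     emb = nn.Embedding(1000, 512)
#     extend_embedding(emb, new_vocab_size=1004, copy_from_token_id=0)
#     emb.weight.shape  # -> torch.Size([1004, 512])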


def add_special_tokens_to_dict_and_model(
    dictionary: "fairseq.data.Dictionary",
    model: nn.Module,
    mono_langs: Sequence[str],
) -> None:
    embs = model.encoder.embed_tokens
    vocab_size, embedding_dim = embs.weight.shape

    # The model may or may not have a '<mask>' embedding yet
    assert (
        len(dictionary) <= vocab_size <= len(dictionary) + 1
    ), f"Dictionary len ({len(dictionary)}) doesn't match embs shape ({embs.weight.shape})"
    # TODO: we should reuse the pretrained model dict which already has <mask>
    dictionary.add_symbol("<mask>")

    for lang in mono_langs:
        lang_token = _lang_token(lang)
        dictionary.add_symbol(lang_token)
    logger.info(
        f"dictionary: {vocab_size} -> {len(dictionary)} tokens "
        f"after adding {len(mono_langs)} lang tokens."
    )

    if len(dictionary) <= vocab_size:
        return

    extend_embedding(embs, len(dictionary), dictionary.bos())
    dec_embs = model.decoder.embed_tokens
    extend_embedding(dec_embs, len(dictionary), dictionary.bos())
    lm_head = model.decoder.output_projection
    extend_embedding(lm_head, len(dictionary), dictionary.bos())
    assert lm_head.weight.shape == (len(dictionary), embedding_dim)


def _lang_token(lang: str) -> str:
    return f"__{lang}__"


def _lang_token_index(dictionary, lang: str) -> int:
    return dictionary.index(_lang_token(lang))
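
# Illustrative mapping (hypothetical language code, shown only as a sketch):
#     _lang_token("ro")                    # -> "__ro__"
#     _lang_token_index(dictionary, "ro")  # -> index of "__ro__" in the dictionary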


@contextlib.contextmanager
def assert_weights_have_changed(model: nn.Module):
    def checksum(model: nn.Module) -> float:
        return sum(p.sum().item() for p in model.parameters())

    initial_checksum = checksum(model)
    yield model
    final_checksum = checksum(model)
    logger.info(
        f"initial_checksum={initial_checksum} -> final_checksum={final_checksum}"
    )
    assert initial_checksum != final_checksum, "Model hasn't changed!"
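
# Usage sketch for assert_weights_have_changed (illustrative only; `task`,
# `sample`, `criterion` and `optimizer` are assumed to exist in the caller):
#
#     with assert_weights_have_changed(model):
#         task.train_step(sample, model, criterion, optimizer, update_num=0)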