# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import tests.utils as test_utils
import torch
from fairseq.data import (
    BacktranslationDataset,
    LanguagePairDataset,
    TransformEosDataset,
)
from fairseq.sequence_generator import SequenceGenerator
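
# These tests exercise BacktranslationDataset end to end: batches drawn from a
# target-side (monolingual) dataset are run through a SequenceGenerator to
# synthesize source sentences, and the resulting (synthetic source, original
# target) pairs are collated back into parallel batches. TransformEosDataset
# wraps both sides to check that <eos> handling survives the round trip.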


class TestBacktranslationDataset(unittest.TestCase):
    def setUp(self):
        (
            self.tgt_dict,
            self.w1,
            self.w2,
            self.src_tokens,
            self.src_lengths,
            self.model,
        ) = test_utils.sequence_generator_setup()

        dummy_src_samples = self.src_tokens

        self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
        self.cuda = torch.cuda.is_available()
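
    # Shared driver for the three test cases below. The two flags control
    # <eos> placement: remove_eos_from_input_src strips <eos> from the
    # target-side samples before they reach the generator, while
    # remove_eos_from_output_src strips <eos> from the synthetic source side
    # of the final collated batch.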
    def _backtranslation_dataset_helper(
        self,
        remove_eos_from_input_src,
        remove_eos_from_output_src,
    ):
        tgt_dataset = LanguagePairDataset(
            src=self.tgt_dataset,
            src_sizes=self.tgt_dataset.sizes,
            src_dict=self.tgt_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )

        generator = SequenceGenerator(
            [self.model],
            tgt_dict=self.tgt_dict,
            max_len_a=0,
            max_len_b=200,
            beam_size=2,
            unk_penalty=0,
        )
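
        # The generator above doubles as the backtranslation_fn below: each
        # collated batch of target sentences is decoded into a synthetic
        # source batch.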
        backtranslation_dataset = BacktranslationDataset(
            tgt_dataset=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # remove eos from the input src
                remove_eos_from_src=remove_eos_from_input_src,
            ),
            src_dict=self.tgt_dict,
            backtranslation_fn=(
                lambda sample: generator.generate([self.model], sample)
            ),
            output_collater=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # if we remove eos from the input src, then we need to add it
                # back to the output tgt
                append_eos_to_tgt=remove_eos_from_input_src,
                remove_eos_from_src=remove_eos_from_output_src,
            ).collater,
            cuda=self.cuda,
        )
        dataloader = torch.utils.data.DataLoader(
            backtranslation_dataset,
            batch_size=2,
            collate_fn=backtranslation_dataset.collater,
        )
        backtranslation_batch_result = next(iter(dataloader))
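
        # The expected tensors below can be read straight off the generator
        # output for this batch: one sample backtranslates to
        # [w1, w2, w1, eos] and the other to [w1, eos], and the shorter
        # hypothesis is left-padded to the batch width.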
        eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

        # Note that we sort by src_lengths and add left padding, so the
        # ids will actually look like: [1, 0]
        expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
        if remove_eos_from_output_src:
            expected_src = expected_src[:, :-1]
        expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
        generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
        tgt_tokens = backtranslation_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def test_backtranslation_dataset_no_eos_in_output_src(self):
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=False,
            remove_eos_from_output_src=True,
        )

    def test_backtranslation_dataset_with_eos_in_output_src(self):
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=False,
            remove_eos_from_output_src=False,
        )

    def test_backtranslation_dataset_no_eos_in_input_src(self):
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=True,
            remove_eos_from_output_src=False,
        )

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)
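

# Allow running this file directly; test runners such as pytest will discover
# TestBacktranslationDataset on their own.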
if __name__ == "__main__":
    unittest.main()