Spaces:
Running
Running
| import os | |
| import shutil | |
| import tempfile | |
| import unittest | |
| from fairseq import options | |
| from fairseq.dataclass.utils import convert_namespace_to_omegaconf | |
| from fairseq.data.data_utils import raise_if_valid_subsets_unintentionally_ignored | |
| from .utils import create_dummy_data, preprocess_lm_data, train_language_model | |
def make_lm_config(
    data_dir=None,
    extra_flags=None,
    task="language_modeling",
    arch="transformer_lm_gpt2_tiny",
):
    """Build a fairseq OmegaConf training config for a small language model.

    Args:
        data_dir: Optional data directory. When given, it is passed as the
            task's positional data argument and as ``--save-dir``.
        extra_flags: Optional list of extra command-line flags, appended after
            the defaults (so they can override them).
        task: Task name passed to ``--task``.
        arch: Architecture name passed to ``--arch``.

    Returns:
        The result of ``convert_namespace_to_omegaconf`` applied to the parsed
        training arguments.
    """
    task_args = [task]
    save_dir_args = []
    if data_dir is not None:
        task_args.append(data_dir)
        # argv entries must be strings: omit --save-dir entirely rather than
        # passing ``None`` through to argparse when there is no directory.
        save_dir_args = ["--save-dir", data_dir]

    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            *task_args,
            "--arch",
            arch,
            "--optimizer",
            "adam",
            "--lr",
            "0.0001",
            "--max-tokens",
            "500",
            "--tokens-per-sample",
            "500",
            *save_dir_args,
            "--max-epoch",
            "1",
        ]
        + (extra_flags or []),
    )
    return convert_namespace_to_omegaconf(train_args)
def write_empty_file(path):
    """Create an empty file at ``path``, truncating any existing content.

    Args:
        path: Filesystem path of the file to create.

    Raises:
        AssertionError: If ``path`` does not exist afterwards.
    """
    # Opening in "w" mode creates or truncates the file; nothing is written.
    open(path, "w").close()
    assert os.path.exists(path)
class TestValidSubsetsErrors(unittest.TestCase):
    """Check when ``raise_if_valid_subsets_unintentionally_ignored`` raises.

    Each case combines a set of ``.bin`` files on disk with command-line flags
    and asserts whether a ``ValueError`` is raised.
    """

    def _test_case(self, paths, extra_flags):
        """Create empty subset files, build a config and run the check.

        Args:
            paths: Subset names (e.g. ``"valid"``, ``"valid1"``); an empty
                ``<name>.bin`` is created for each, plus ``train.bin``.
            extra_flags: Extra command-line flags forwarded to ``make_lm_config``.

        Raises:
            Whatever ``raise_if_valid_subsets_unintentionally_ignored`` raises
            for the resulting config (these tests expect ``ValueError``).
        """
        with tempfile.TemporaryDirectory() as data_dir:
            # Plain loop: the files are created only for their side effect.
            for name in [*paths, "train"]:
                write_empty_file(os.path.join(data_dir, f"{name}.bin"))
            cfg = make_lm_config(data_dir, extra_flags=extra_flags)
            raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_default_raises(self):
        """Extra valid subsets on disk that are not selected cause an error."""
        with self.assertRaises(ValueError):
            self._test_case(["valid", "valid1"], [])
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )

    def test_partially_specified_valid_subsets(self):
        """Selecting only some existing subsets raises unless explicitly ignored."""
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )
        # Fix with ignore unused
        self._test_case(
            ["valid", "valid1", "valid2"],
            ["--valid-subset", "valid,valid1", "--ignore-unused-valid-subsets"],
        )

    # Backward-compatible alias for the former name. That name lacked the
    # ``test_`` prefix, so the test runner never collected it.
    partially_specified_valid_subsets = test_partially_specified_valid_subsets

    def test_legal_configs(self):
        """Configurations that select, combine or ignore all subsets do not raise."""
        self._test_case(["valid"], [])
        self._test_case(["valid", "valid1"], ["--ignore-unused-valid-subsets"])
        self._test_case(["valid", "valid1"], ["--combine-val"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid,valid1"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid1"])
        self._test_case(
            ["valid", "valid1"], ["--combine-val", "--ignore-unused-valid-subsets"]
        )
        # No valid.bin is created here, so nothing is left unused.
        self._test_case(["valid1"], ["--valid-subset", "valid1"])

    def test_disable_validation(self):
        """--disable-validation never raises, with or without valid files."""
        self._test_case([], ["--disable-validation"])
        self._test_case(["valid", "valid1"], ["--disable-validation"])

    def test_dummy_task(self):
        """The check does not raise for a ``dummy_lm`` config built without data_dir."""
        cfg = make_lm_config(task="dummy_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_masked_dummy_task(self):
        """The check does not raise for a ``dummy_masked_lm`` config built without data_dir."""
        cfg = make_lm_config(task="dummy_masked_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)
class TestCombineValidSubsets(unittest.TestCase):
    """End-to-end checks of how two validation subsets show up in training logs."""

    def _train(self, extra_flags):
        """Train a tiny LM on dummy data with a duplicated ``valid1`` subset.

        ``valid.bin``/``valid.idx`` are copied to ``valid1.bin``/``valid1.idx``,
        so two validation subsets exist in the data directory.

        Args:
            extra_flags: Extra command-line flags forwarded to
                ``train_language_model``.

        Returns:
            The messages of all log records captured while training.
        """
        with self.assertLogs() as logs:
            with tempfile.TemporaryDirectory(suffix="test_transformer_lm") as data_dir:
                create_dummy_data(data_dir, num_examples=20)
                preprocess_lm_data(data_dir)
                shutil.copyfile(f"{data_dir}/valid.bin", f"{data_dir}/valid1.bin")
                shutil.copyfile(f"{data_dir}/valid.idx", f"{data_dir}/valid1.idx")
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--max-update", "0", "--log-format", "json"] + extra_flags,
                    run_validation=False,
                )
            return [record.getMessage() for record in logs.records]

    def test_combined(self):
        """With --combine-valid-subsets, valid1 is loaded but not reported on its own."""
        logs = self._train(["--combine-valid-subsets"])
        # valid1 was loaded (some log message mentions it) ...
        self.assertTrue(any("valid1" in msg for msg in logs))
        # ... but metrics are combined, so no separate valid1_ppl is logged.
        self.assertFalse(any("valid1_ppl" in msg for msg in logs))

    def test_subsets(self):
        """With --valid-subset valid,valid1, each subset reports its own perplexity."""
        logs = self._train(["--valid-subset", "valid,valid1"])
        # Metrics are NOT combined: both subsets log a ppl entry.
        self.assertTrue(any("valid_ppl" in msg for msg in logs))
        self.assertTrue(any("valid1_ppl" in msg for msg in logs))