# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
import itertools
import logging
import os

import numpy as np
import torch

from fairseq import metrics
from fairseq.data import (
    ConcatDataset,
    ConcatSentencesDataset,
    data_utils,
    Dictionary,
    IdDataset,
    indexed_dataset,
    NestedDictionaryDataset,
    NumSamplesDataset,
    NumelDataset,
    PrependTokenDataset,
    RawLabelDataset,
    RightPadDataset,
    SortDataset,
    TruncateDataset,
    TokenBlockDataset,
)
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II, MISSING


EVAL_BLEU_ORDER = 4
TARGET_METRIC_CHOICES = ChoiceEnum(["bleu", "ter"])

logger = logging.getLogger(__name__)


@dataclass
class DiscriminativeRerankingNMTConfig(FairseqDataclass):
    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
    num_data_splits: int = field(
        default=1, metadata={"help": "total number of data splits"}
    )
    no_shuffle: bool = field(
        default=False, metadata={"help": "do not shuffle training data"}
    )
    max_positions: int = field(
        default=512, metadata={"help": "number of positional embeddings to learn"}
    )
    include_src: bool = field(
        default=False, metadata={"help": "include source sentence"}
    )
    mt_beam: int = field(default=50, metadata={"help": "beam size of input hypotheses"})
    eval_target_metric: bool = field(
        default=False,
        metadata={"help": "evaluation with the target metric during validation"},
    )
    target_metric: TARGET_METRIC_CHOICES = field(
        default="bleu", metadata={"help": "name of the target metric to optimize for"}
    )
    train_subset: str = field(
        default=II("dataset.train_subset"),
        metadata={"help": "data subset to use for training (e.g. train, valid, test)"},
    )
    seed: int = field(
        default=II("common.seed"),
        metadata={"help": "pseudo random number generator seed"},
    )
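
# Illustrative sketch (hypothetical paths): since the config above is a
# FairseqDataclass, it can also be built programmatically, e.g. in tests:
#
#     cfg = DiscriminativeRerankingNMTConfig(
#         data="/path/to/rerank-data",  # assumed layout, see load_dataset below
#         mt_beam=5,
#         include_src=True,
#     )
#     task = DiscriminativeRerankingNMTTask.setup_task(cfg)
#
# The field names surface as the task's command-line options in the usual
# fairseq dataclass/argparse integration.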


class RerankerScorer(object):
    """Scores the target for a given (source (optional), target) input."""

    def __init__(self, args, mt_beam):
        self.mt_beam = mt_beam

    def generate(self, models, sample, **kwargs):
        """Score a batch of translations."""
        net_input = sample["net_input"]
        assert len(models) == 1, "does not support model ensemble"
        model = models[0]

        bs = net_input["src_tokens"].shape[0]
        assert (
            model.joint_classification == "none" or bs % self.mt_beam == 0
        ), f"invalid batch size ({bs}) for joint classification with beam size ({self.mt_beam})"

        model.eval()
        logits = model(**net_input)

        batch_out = model.sentence_forward(logits, net_input["src_tokens"])
        if model.joint_classification == "sent":
            batch_out = model.joint_forward(
                batch_out.view(self.mt_beam, bs // self.mt_beam, -1)
            )
        scores = model.classification_forward(
            batch_out.view(bs, 1, -1)
        )  # input: B x T x C

        return scores
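
# Usage sketch (hypothetical variable names): with a single trained reranker,
# scoring a prepared batch looks roughly like
#
#     scorer = task.build_generator(args)
#     scores = scorer.generate([model], sample)  # one score per hypothesis
#
# The batch is expected to hold the hypotheses of each source sentence as a
# contiguous block of `mt_beam` rows; the assertion above enforces this when
# joint (per-sentence-group) classification is enabled.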


@register_task(
    "discriminative_reranking_nmt", dataclass=DiscriminativeRerankingNMTConfig
)
class DiscriminativeRerankingNMTTask(FairseqTask):
    """
    Translation rerank task.
    The input can be either (src, tgt) sentence pairs or tgt sentence only.
    """

    cfg: DiscriminativeRerankingNMTConfig

    def __init__(self, cfg: DiscriminativeRerankingNMTConfig, data_dictionary=None):
        super().__init__(cfg)
        self.dictionary = data_dictionary
        self._max_positions = cfg.max_positions
        # args.tokens_per_sample = self._max_positions
        # self.num_classes = 1  # for model

    @classmethod
    def load_dictionary(cls, cfg, filename):
        """Load the dictionary from the filename"""
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol("<mask>")  # for loading pretrained XLMR model

        return dictionary

    @classmethod
    def setup_task(cls, cfg: DiscriminativeRerankingNMTConfig, **kwargs):
        # load data dictionary (assume joint dictionary)
        data_path = cfg.data
        data_dict = cls.load_dictionary(
            cfg, os.path.join(data_path, "input_src/dict.txt")
        )

        logger.info("[input] src dictionary: {} types".format(len(data_dict)))

        return DiscriminativeRerankingNMTTask(cfg, data_dict)

    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""
        if self.cfg.data.endswith("1"):
            data_shard = (epoch - 1) % self.cfg.num_data_splits + 1
            data_path = self.cfg.data[:-1] + str(data_shard)
        else:
            data_path = self.cfg.data

        def get_path(type, data_split):
            return os.path.join(data_path, str(type), data_split)

        def make_dataset(type, dictionary, data_split, combine):
            split_path = get_path(type, data_split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                dictionary,
                combine=combine,
            )
            return dataset
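
        # Expected on-disk layout (inferred from the path helpers above):
        #
        #     <data>/input_src/<split>        binarized source sentences (optional)
        #     <data>/input_tgt/<split>        binarized hypothesis sentences
        #     <data>/<metric>/<split>.<metric>  one plain-text score row per hypothesis
        #
        # The exact indexed-dataset file extensions depend on the dataset
        # implementation chosen at binarization time.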

        def load_split(data_split, metric):
            input_src = None
            if self.cfg.include_src:
                input_src = make_dataset(
                    "input_src", self.dictionary, data_split, combine=False
                )
                assert input_src is not None, "could not find dataset: {}".format(
                    get_path("input_src", data_split)
                )

            input_tgt = make_dataset(
                "input_tgt", self.dictionary, data_split, combine=False
            )
            assert input_tgt is not None, "could not find dataset: {}".format(
                get_path("input_tgt", data_split)
            )

            label_path = f"{get_path(metric, data_split)}.{metric}"
            assert os.path.exists(label_path), f"could not find dataset: {label_path}"

            np_labels = np.loadtxt(label_path)
            if self.cfg.target_metric == "ter":
                # TER is lower-is-better; negate so that a larger label is
                # always better, matching the BLEU case.
                np_labels = -np_labels
            label = RawLabelDataset(np_labels)

            return input_src, input_tgt, label

        src_datasets = []
        tgt_datasets = []
        label_datasets = []

        if split == self.cfg.train_subset:
            for k in itertools.count():
                split_k = "train" + (str(k) if k > 0 else "")
                prefix = os.path.join(data_path, "input_tgt", split_k)
                if not indexed_dataset.dataset_exists(prefix, impl=None):
                    if k > 0:
                        break
                    else:
                        raise FileNotFoundError(f"Dataset not found: {prefix}")
                input_src, input_tgt, label = load_split(
                    split_k, self.cfg.target_metric
                )
                src_datasets.append(input_src)
                tgt_datasets.append(input_tgt)
                label_datasets.append(label)
        else:
            input_src, input_tgt, label = load_split(split, self.cfg.target_metric)
            src_datasets.append(input_src)
            tgt_datasets.append(input_tgt)
            label_datasets.append(label)

        if len(tgt_datasets) == 1:
            input_tgt, label = tgt_datasets[0], label_datasets[0]
            if self.cfg.include_src:
                input_src = src_datasets[0]
        else:
            input_tgt = ConcatDataset(tgt_datasets)
            label = ConcatDataset(label_datasets)
            if self.cfg.include_src:
                input_src = ConcatDataset(src_datasets)

        input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
        if self.cfg.include_src:
            input_src = PrependTokenDataset(input_src, self.dictionary.bos())
            input_src = TruncateDataset(input_src, self.cfg.max_positions)
            src_lengths = NumelDataset(input_src, reduce=False)
            src_tokens = ConcatSentencesDataset(input_src, input_tgt)
        else:
            src_tokens = PrependTokenDataset(input_tgt, self.dictionary.bos())
            src_lengths = NumelDataset(src_tokens, reduce=False)

        dataset = {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                "src_lengths": src_lengths,
            },
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
            "target": label,
        }

        dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        assert len(dataset) % self.cfg.mt_beam == 0, (
            "dataset size (%d) is not a multiple of beam size (%d)"
            % (len(dataset), self.cfg.mt_beam)
        )
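
        # Shuffle sketch: with mt_beam=2 and 6 hypotheses, start_idx is
        # [0, 2, 4]; if the shuffle reorders it to [4, 0, 2], the final index
        # order becomes [4, 5, 0, 1, 2, 3] -- each block of mt_beam hypotheses
        # for a source sentence stays contiguous, only the block order changes.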

        # no need to shuffle valid/test sets
        if not self.cfg.no_shuffle and split == self.cfg.train_subset:
            # need to keep all hypotheses together
            start_idx = np.arange(0, len(dataset), self.cfg.mt_beam)
            with data_utils.numpy_seed(self.cfg.seed + epoch):
                np.random.shuffle(start_idx)

            idx = np.arange(0, self.cfg.mt_beam)
            shuffle = np.tile(idx, (len(start_idx), 1)).reshape(-1) + np.tile(
                start_idx, (self.cfg.mt_beam, 1)
            ).transpose().reshape(-1)

            dataset = SortDataset(
                dataset,
                sort_order=[shuffle],
            )

        logger.info(f"Loaded {split} with #samples: {len(dataset)}")

        self.datasets[split] = dataset
        return self.datasets[split]
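
    # build_dataset_for_inference receives pre-tokenized hypotheses: each entry
    # of `src_tokens` holds either (source, hypothesis) when include_src is set,
    # or just the hypothesis otherwise, with `src_lengths` laid out the same way.
    # This mirrors the (src, tgt) pairing described in the class docstring.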
    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        assert not self.cfg.include_src or len(src_tokens[0]) == 2
        input_src = None
        if self.cfg.include_src:
            input_src = TokenBlockDataset(
                [t[0] for t in src_tokens],
                [l[0] for l in src_lengths],
                block_size=None,  # ignored for "eos" break mode
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode="eos",
            )
            input_src = PrependTokenDataset(input_src, self.dictionary.bos())
            input_src = TruncateDataset(input_src, self.cfg.max_positions)

        input_tgt = TokenBlockDataset(
            [t[-1] for t in src_tokens],
            [l[-1] for l in src_lengths],
            block_size=None,  # ignored for "eos" break mode
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode="eos",
        )
        input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
        if self.cfg.include_src:
            src_tokens = ConcatSentencesDataset(input_src, input_tgt)
            src_lengths = NumelDataset(input_src, reduce=False)
        else:
            input_tgt = PrependTokenDataset(input_tgt, self.dictionary.bos())
            src_tokens = input_tgt
            src_lengths = NumelDataset(src_tokens, reduce=False)

        dataset = {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                "src_lengths": src_lengths,
            },
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
        }

        return NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

    def build_model(self, cfg: FairseqDataclass):
        return super().build_model(cfg)

    def build_generator(self, args):
        return RerankerScorer(args, mt_beam=self.cfg.mt_beam)

    def max_positions(self):
        return self._max_positions

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary

    def create_dummy_batch(self, device):
        # dummy target shape must match the per-metric layout checked in valid_step
        dummy_target = (
            torch.zeros(self.cfg.mt_beam, EVAL_BLEU_ORDER * 2 + 3).long().to(device)
            if self.cfg.target_metric == "bleu"
            else torch.zeros(self.cfg.mt_beam, 3).long().to(device)
        )
        return {
            "id": torch.zeros(self.cfg.mt_beam, 1).long().to(device),
            "net_input": {
                "src_tokens": torch.zeros(self.cfg.mt_beam, 4).long().to(device),
                "src_lengths": torch.ones(self.cfg.mt_beam, 1).long().to(device),
            },
            "nsentences": 0,
            "ntokens": 0,
            "target": dummy_target,
        }

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        if ignore_grad and sample is None:
            sample = self.create_dummy_batch(model.device)

        return super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )
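
    # Target layout consumed by valid_step below (derived from the indexing
    # there; column 0 is the training label, the rest is selected via
    # `[select_id, 1:]`):
    #   BLEU: [label, sys_len, ref_len, counts_1..4, totals_1..4]  -> EVAL_BLEU_ORDER * 2 + 3 columns
    #   TER:  [label, -num_edits, -ref_len]                        -> 3 columns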
    def valid_step(self, sample, model, criterion):
        if sample is None:
            sample = self.create_dummy_batch(model.device)

        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)

        if not self.cfg.eval_target_metric:
            return loss, sample_size, logging_output

        scores = logging_output["scores"]

        if self.cfg.target_metric == "bleu":
            assert sample["target"].shape[1] == EVAL_BLEU_ORDER * 2 + 3, (
                "target does not contain enough information ("
                + str(sample["target"].shape[1])
                + ") for evaluating BLEU"
            )

            max_id = torch.argmax(scores, dim=1)
            select_id = max_id + torch.arange(
                0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam
            ).to(max_id.device)
            bleu_data = sample["target"][select_id, 1:].sum(0).data

            logging_output["_bleu_sys_len"] = bleu_data[0]
            logging_output["_bleu_ref_len"] = bleu_data[1]

            for i in range(EVAL_BLEU_ORDER):
                logging_output["_bleu_counts_" + str(i)] = bleu_data[2 + i]
                logging_output["_bleu_totals_" + str(i)] = bleu_data[
                    2 + EVAL_BLEU_ORDER + i
                ]
        elif self.cfg.target_metric == "ter":
            assert sample["target"].shape[1] == 3, (
                "target does not contain enough information ("
                + str(sample["target"].shape[1])
                + ") for evaluating TER"
            )

            max_id = torch.argmax(scores, dim=1)
            select_id = max_id + torch.arange(
                0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam
            ).to(max_id.device)
            ter_data = sample["target"][select_id, 1:].sum(0).data

            logging_output["_ter_num_edits"] = -ter_data[0]
            logging_output["_ter_ref_len"] = -ter_data[1]

        return loss, sample_size, logging_output

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

        if not self.cfg.eval_target_metric:
            return

        def sum_logs(key):
            return sum(log.get(key, 0) for log in logging_outputs)

        if self.cfg.target_metric == "bleu":
            counts, totals = [], []
            for i in range(EVAL_BLEU_ORDER):
                counts.append(sum_logs("_bleu_counts_" + str(i)))
                totals.append(sum_logs("_bleu_totals_" + str(i)))

            if max(totals) > 0:
                # log counts as numpy arrays -- log_scalar will sum them correctly
                metrics.log_scalar("_bleu_counts", np.array(counts))
                metrics.log_scalar("_bleu_totals", np.array(totals))
                metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
                metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))

                def compute_bleu(meters):
                    import inspect
                    import sacrebleu

                    fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
                    if "smooth_method" in fn_sig:
                        smooth = {"smooth_method": "exp"}
                    else:
                        smooth = {"smooth": "exp"}
                    bleu = sacrebleu.compute_bleu(
                        correct=meters["_bleu_counts"].sum,
                        total=meters["_bleu_totals"].sum,
                        sys_len=meters["_bleu_sys_len"].sum,
                        ref_len=meters["_bleu_ref_len"].sum,
                        **smooth,
                    )
                    return round(bleu.score, 2)

                metrics.log_derived("bleu", compute_bleu)
        elif self.cfg.target_metric == "ter":
            num_edits = sum_logs("_ter_num_edits")
            ref_len = sum_logs("_ter_ref_len")

            if ref_len > 0:
                metrics.log_scalar("_ter_num_edits", num_edits)
                metrics.log_scalar("_ter_ref_len", ref_len)

                def compute_ter(meters):
                    score = meters["_ter_num_edits"].sum / meters["_ter_ref_len"].sum
                    return round(score.item(), 2)

                metrics.log_derived("ter", compute_ter)
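
# Note on the metric reduction above: `metrics.log_derived` registers a callable
# that is evaluated only when fairseq aggregates logging outputs, which is why
# the raw counts and lengths are logged as scalars first and the BLEU/TER scores
# are computed lazily from the accumulated meters.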