# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import os
import torch
import json

from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional, Any

from fairseq.data import AddTargetDataset, Dictionary, encoders
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import GenerationConfig
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel

from . import register_task
from .. import utils
from ..logging import metrics


logger = logging.getLogger(__name__)


class LabelEncoder(object):
    def __init__(self, dictionary):
        self.dictionary = dictionary

    def __call__(self, label):
        return self.dictionary.encode_line(
            label, append_eos=False, add_if_not_exist=False
        )


def label_len_fn(label):
    return len(label.split(" "))
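

# A minimal illustration (not part of the original file): each label is a raw
# text line read from f"{split}.{labels}" in load_dataset below; LabelEncoder
# maps it to dictionary ids via Dictionary.encode_line, and label_len_fn
# measures target length as the number of space-separated tokens, e.g.
# label_len_fn("[IN:GET_WEATHER _ w ]") == 4 (the bracketed intent token is a
# purely hypothetical example).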


@dataclass
class NLUFinetuningConfig(AudioPretrainingConfig):
    # Options for reporting WER metrics during validation. Only applicable to
    # Seq2Seq models during fine-tuning.
    eval_wer: bool = field(
        default=False, metadata={"help": "compute WER for Seq2Seq models"}
    )
    eval_wer_parse: bool = field(
        default=False,
        metadata={"help": "compute parse-level WER/exact-match metrics for Seq2Seq models"},
    )
    eval_wer_config: GenerationConfig = field(
        default_factory=lambda: GenerationConfig(),
        metadata={"help": "beam search config for evaluating wer during training"},
    )
    eval_wer_tokenizer: Any = field(
        default=None,
        metadata={"help": "tokenizer config for evaluating wer during training"},
    )
    eval_wer_post_process: str = field(
        default="letter",
        metadata={
            "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)"
        },
    )
    eval_bleu: bool = field(
        default=False, metadata={"help": "evaluation with BLEU scores"}
    )
    eval_bleu_detok: Optional[str] = field(
        default=None,
        metadata={
            "help": "detokenize before computing BLEU (e.g., 'moses'); "
            "required if using --eval-bleu; use 'space' to disable "
            "detokenization; see fairseq.data.encoders for other options"
        },
    )
    eval_bleu_detok_args: str = field(
        default="{}", metadata={"help": "args for building the tokenizer, if needed"}
    )
    eval_tokenized_bleu: bool = field(
        default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
    )
    eval_bleu_remove_bpe: Optional[str] = field(
        default=None, metadata={"help": "remove BPE before computing BLEU"}
    )
    eval_bleu_args: str = field(
        default="{}",
        metadata={
            "help": "generation args for BLEU scoring, e.g., "
            '\'{"beam": 4, "lenpen": 0.6}\''
        },
    )
    eval_bleu_print_samples: bool = field(
        default=False, metadata={"help": "print sample generations during validation"}
    )
    autoregressive: bool = field(
        default=False,
        metadata={
            "help": "required for autoregressive decoders (like seq2seq models); "
            "adds 'prev_output_tokens' to input and appends eos to target"
        },
    )
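

# A minimal sketch (an assumption, not taken from this file) of how the config
# above might be filled in for a seq2seq fine-tuning run with WER validation;
# the manifest directory and the "parse" label suffix are hypothetical
# placeholders:
#
#     cfg = NLUFinetuningConfig(
#         data="/path/to/manifests",    # directory with {split}.tsv audio manifests
#         labels="parse",               # target files are {split}.parse
#         autoregressive=True,          # seq2seq decoder needs prev_output_tokens
#         eval_wer=True,                # report WER in valid_step/reduce_metrics
#         eval_wer_post_process="letter",
#     )
#     task = NLUFinetuningTask(cfg)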


@register_task("nlu_finetuning", dataclass=NLUFinetuningConfig)
class NLUFinetuningTask(AudioPretrainingTask):
    """Audio-to-NLU fine-tuning task: pairs an audio pretraining dataset with
    text targets (e.g., semantic parses) and reports WER/EM/BLEU during
    validation."""

    cfg: NLUFinetuningConfig

    def __init__(
        self,
        cfg: NLUFinetuningConfig,
    ):
        super().__init__(cfg)
        self.blank_symbol = "<s>"

        self.state.add_factory("target_dictionary", self.load_target_dictionary)

    def load_target_dictionary(self):
        if self.cfg.labels:
            dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt")
            return Dictionary.load(dict_path)
        return None

    def load_dataset(self, split: str, task_cfg: NLUFinetuningConfig = None, **kwargs):
        super().load_dataset(split, task_cfg, **kwargs)

        task_cfg = task_cfg or self.cfg
        assert task_cfg.labels is not None
        text_compression_level = getattr(
            TextCompressionLevel, str(self.cfg.text_compression_level)
        )
        data_path = self.cfg.data
        label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
        skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
        text_compressor = TextCompressor(level=text_compression_level)
        with open(label_path, "r") as f:
            labels = [
                text_compressor.compress(l)
                for i, l in enumerate(f)
                if i not in skipped_indices
            ]

        assert len(labels) == len(self.datasets[split]), (
            f"labels length ({len(labels)}) and dataset length "
            f"({len(self.datasets[split])}) do not match"
        )

        process_label = LabelEncoder(self.target_dictionary)

        self.datasets[split] = AddTargetDataset(
            self.datasets[split],
            labels,
            pad=self.target_dictionary.pad(),
            eos=self.target_dictionary.eos(),
            batch_targets=True,
            process_label=process_label,
            label_len_fn=label_len_fn,
            add_to_input=task_cfg.get("autoregressive", False),
            text_compression_level=text_compression_level,
        )
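
    # Expected on-disk layout, inferred from the paths built above (nothing
    # beyond the dict/label patterns is guaranteed by this file): self.cfg.data
    # holds one label file per split, f"{split}.{labels}", with one target line
    # per audio example, plus f"dict.{labels}.txt" loaded by
    # load_target_dictionary; label line i must correspond to example i of the
    # audio dataset (minus any skipped_indices).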

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.state.target_dictionary

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        if self.cfg.eval_wer_parse and self.cfg.autoregressive:
            metrics = self._inference_with_wer_parse(
                self.sequence_generator, sample, model
            )
            logging_output["_num_char_errors"] = metrics["num_char_errors"]
            logging_output["_num_chars"] = metrics["num_chars"]
            logging_output["_num_word_errors"] = metrics["num_word_errors"]
            logging_output["_num_words"] = metrics["num_words"]
            logging_output["_num_em_errors"] = metrics["num_em_errors"]
            logging_output["_num_ems"] = metrics["num_ems"]
            logging_output["_num_tree_errors"] = metrics["num_tree_errors"]
            logging_output["_num_trees"] = metrics["num_trees"]
        if self.cfg.eval_wer and self.cfg.autoregressive:
            metrics = self._inference_with_wer(self.sequence_generator, sample, model)
            logging_output["_num_char_errors"] = metrics["num_char_errors"]
            logging_output["_num_chars"] = metrics["num_chars"]
            logging_output["_num_word_errors"] = metrics["num_word_errors"]
            logging_output["_num_words"] = metrics["num_words"]
        if self.cfg.eval_bleu and self.cfg.autoregressive:
            metrics = self._inference_with_bleu(self.sequence_generator, sample, model)
            logging_output["_bleu_sys_len"] = metrics.sys_len
            logging_output["_bleu_ref_len"] = metrics.ref_len
            # we split counts into separate entries so that they can be
            # summed efficiently across workers using fast-stat-sync
            assert len(metrics.counts) == 4
            for i in range(4):
                logging_output[f"_bleu_counts_{i}"] = metrics.counts[i]
                logging_output[f"_bleu_totals_{i}"] = metrics.totals[i]
        return loss, sample_size, logging_output

    def build_model(self, model_cfg: FairseqDataclass):
        model = super().build_model(model_cfg)

        if (self.cfg.eval_wer or self.cfg.eval_wer_parse) and self.cfg.autoregressive:
            self.sequence_generator = self.build_generator(
                [model],
                self.cfg.eval_wer_config,
            )
            if self.cfg.eval_wer_tokenizer:
                self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer)
            else:
                self.tokenizer = None
        if self.cfg.eval_bleu and self.cfg.autoregressive:
            assert self.cfg.eval_bleu_detok is not None, (
                "--eval-bleu-detok is required if using --eval-bleu; "
                "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
                "to disable detokenization, e.g., when using sentencepiece)"
            )
            detok_args = json.loads(self.cfg.eval_bleu_detok_args)
            self.tokenizer = encoders.build_tokenizer(
                Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
            )
            gen_args = json.loads(self.cfg.eval_bleu_args)
            gen_args = Namespace(**gen_args)
            self.sequence_generator = self.build_generator([model], gen_args)

        return model

    def _inference_with_wer_parse(self, generator, sample, model):
        import editdistance

        def decode(toks):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_wer_post_process,
                escape_unk=True,
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        def decode_to_list(toks):
            def token_string(i):
                if i == self.target_dictionary.unk():
                    return self.target_dictionary.unk_string(False)
                else:
                    return self.target_dictionary[i]

            return [token_string(i) for i in toks]

        def is_ont_token(token):
            return "[" in token or "]" in token

        def post_process(l):
            o = []
            for w in l:
                if w == self.target_dictionary.eos_word or w == "|":
                    continue
                if w == "_":
                    o.append(" ")
                else:
                    o.append(w)
                if is_ont_token(w):
                    o.append(" ")
            return o
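
        # Hypothetical example of the post-processing above: eos and "|"
        # tokens are dropped, "_" becomes a space, and a space is appended
        # after ontology tokens (those containing "[" or "]"), so
        #   ["[SL:LOCATION", "n", "y", "c", "]", "|"]
        # becomes ["[SL:LOCATION", " ", "n", "y", "c", "]", " "], which joins
        # to the string "[SL:LOCATION nyc]" after strip().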

        num_word_errors, num_char_errors = 0, 0
        num_chars, num_words = 0, 0
        num_em_errors, num_ems = 0, 0
        num_tree_errors, num_trees = 0, 0
        gen_out = self.inference_step(generator, [model], sample, None)
        for i in range(len(gen_out)):
            hyp_tokens = gen_out[i][0]["tokens"]
            # hyp = decode(hyp_tokens)
            ref_tokens = utils.strip_pad(
                sample["target"][i], self.target_dictionary.pad()
            )
            # ref = decode(ref_tokens)
            hyp_list = decode_to_list(hyp_tokens)
            ref_list = decode_to_list(ref_tokens)

            hyp_list = post_process(hyp_list)
            ref_list = post_process(ref_list)

            hyp = "".join(hyp_list).strip()
            ref = "".join(ref_list).strip()
            num_chars += len(ref)
            num_char_errors += editdistance.eval(hyp, ref)
            hyp_words = hyp.split()
            ref_words = ref.split()
            hyp_tree = [word for word in hyp_list if ("[" in word or "]" in word)]
            ref_tree = [word for word in ref_list if ("[" in word or "]" in word)]
            # num_word_errors += editdistance.eval(hyp_words, ref_words)
            hyp_before = decode(hyp_tokens).split()
            ref_before = decode(ref_tokens).split()
            num_word_errors += editdistance.eval(hyp_before, ref_before)
            num_words += len(ref_before)
            if hyp != ref:
                num_em_errors += 1
            if hyp_tree != ref_tree:
                num_tree_errors += 1
            num_ems += 1
            num_trees += 1

        return {
            "num_char_errors": num_char_errors,
            "num_chars": num_chars,
            "num_word_errors": num_word_errors,
            "num_words": num_words,
            "num_ems": num_ems,
            "num_em_errors": num_em_errors,
            "num_trees": num_trees,
            "num_tree_errors": num_tree_errors,
        }

    def _inference_with_wer(self, generator, sample, model):
        import editdistance

        def decode(toks):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_wer_post_process,
                escape_unk=True,
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        num_word_errors, num_char_errors = 0, 0
        num_chars, num_words = 0, 0
        gen_out = self.inference_step(generator, [model], sample, None)
        for i in range(len(gen_out)):
            hyp = decode(gen_out[i][0]["tokens"])
            ref = decode(
                utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
            )
            num_char_errors += editdistance.eval(hyp, ref)
            num_chars += len(ref)
            hyp_words = hyp.split()
            ref_words = ref.split()
            num_word_errors += editdistance.eval(hyp_words, ref_words)
            num_words += len(ref_words)

        return {
            "num_char_errors": num_char_errors,
            "num_chars": num_chars,
            "num_word_errors": num_word_errors,
            "num_words": num_words,
        }
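
    # The counts returned above are summed across workers in reduce_metrics and
    # turned into percentages there; with hypothetical totals of 2 word errors
    # over 8 reference words, wer = 2 * 100.0 / 8 = 25.0 (and analogously the
    # character-level "uer").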

    def _inference_with_bleu(self, generator, sample, model):
        import sacrebleu

        def decode(toks, is_ref):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_bleu_remove_bpe,
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"),
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample)
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False))
            refs.append(
                decode(
                    utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
                    is_ref=True,  # don't count <unk> as matches to the hypo
                )
            )
        if self.cfg.eval_bleu_print_samples:
            logger.info("H-{} {}".format(sample["id"][0], hyps[0]))
            logger.info("T-{} {}".format(sample["id"][0], refs[0]))

        eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a"
        return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization)

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

        if self.cfg.eval_wer or self.cfg.eval_wer_parse:
            zero = torch.scalar_tensor(0.0)
            num_char_errors = sum(
                log.get("_num_char_errors", zero) for log in logging_outputs
            )
            num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs)
            num_word_errors = sum(
                log.get("_num_word_errors", zero) for log in logging_outputs
            )
            num_words = sum(log.get("_num_words", zero) for log in logging_outputs)
            metrics.log_scalar("_num_char_errors", num_char_errors)
            metrics.log_scalar("_num_chars", num_chars)
            metrics.log_scalar("_num_word_errors", num_word_errors)
            metrics.log_scalar("_num_words", num_words)
            if num_chars > 0:
                metrics.log_derived(
                    "uer",
                    lambda meters: meters["_num_char_errors"].sum
                    * 100.0
                    / meters["_num_chars"].sum
                    if meters["_num_chars"].sum > 0
                    else float("nan"),
                )
            if num_words > 0:
                metrics.log_derived(
                    "wer",
                    lambda meters: meters["_num_word_errors"].sum
                    * 100.0
                    / meters["_num_words"].sum
                    if meters["_num_words"].sum > 0
                    else float("nan"),
                )
            if self.cfg.eval_wer_parse:
                num_em_errors = sum(
                    log.get("_num_em_errors", zero) for log in logging_outputs
                )
                num_ems = sum(log.get("_num_ems", zero) for log in logging_outputs)
                metrics.log_scalar("_num_em_errors", num_em_errors)
                metrics.log_scalar("_num_ems", num_ems)
                num_tree_errors = sum(
                    log.get("_num_tree_errors", zero) for log in logging_outputs
                )
                num_trees = sum(log.get("_num_trees", zero) for log in logging_outputs)
                metrics.log_scalar("_num_tree_errors", num_tree_errors)
                metrics.log_scalar("_num_trees", num_trees)
                if num_ems > 0:
                    metrics.log_derived(
                        "em_error",
                        lambda meters: meters["_num_em_errors"].sum
                        * 100.0
                        / meters["_num_ems"].sum
                        if meters["_num_ems"].sum > 0
                        else float("nan"),
                    )
                if num_trees > 0:
                    metrics.log_derived(
                        "tree_error",
                        lambda meters: meters["_num_tree_errors"].sum
                        * 100.0
                        / meters["_num_trees"].sum
                        if meters["_num_trees"].sum > 0
                        else float("nan"),
                    )

        if self.cfg.eval_bleu:
            len_keys = ["_bleu_sys_len", "_bleu_ref_len"]
            count_keys = [f"_bleu_counts_{i}" for i in range(4)]
            total_keys = [f"_bleu_totals_{i}" for i in range(4)]
            for k in len_keys + count_keys + total_keys:
                metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs))

            import sacrebleu

            metrics.log_derived(
                "bleu",
                lambda meters: sacrebleu.compute_bleu(
                    correct=[meters[k].sum for k in count_keys],
                    total=[meters[k].sum for k in total_keys],
                    sys_len=meters["_bleu_sys_len"].sum,
                    ref_len=meters["_bleu_ref_len"].sum,
                    smooth_method="exp",
                ).score,
            )