# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import namedtuple

import numpy as np
import torch

from fairseq import utils


DecoderOut = namedtuple(
    "IterativeRefinementDecoderOut",
    ["output_tokens", "output_scores", "attn", "step", "max_step", "history"],
)
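
# Fields of the decoder state threaded between refinement iterations:
#   output_tokens/output_scores: current hypothesis tokens and per-token scores
#   attn: decoder attention over source positions (may be None)
#   step/max_step: current and maximum refinement iteration
#   history: snapshots of the tokens from earlier iterations (when retained)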


class IterativeRefinementGenerator(object):
    def __init__(
        self,
        tgt_dict,
        models=None,
        eos_penalty=0.0,
        max_iter=10,
        max_ratio=2,
        beam_size=1,
        decoding_format=None,
        retain_dropout=False,
        adaptive=True,
        retain_history=False,
        reranking=False,
    ):
| """ | |
| Generates translations based on iterative refinement. | |
| Args: | |
| tgt_dict: target dictionary | |
| eos_penalty: if > 0.0, it penalized early-stopping in decoding | |
| max_iter: maximum number of refinement iterations | |
| max_ratio: generate sequences of maximum length ax, where x is the source length | |
| decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'} | |
| retain_dropout: retaining dropout in the inference | |
| adaptive: decoding with early stop | |
| """ | |
        self.bos = tgt_dict.bos()
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.eos_penalty = eos_penalty
        self.max_iter = max_iter
        self.max_ratio = max_ratio
        self.beam_size = beam_size
        self.reranking = reranking
        self.decoding_format = decoding_format
        self.retain_dropout = retain_dropout
        self.retain_history = retain_history
        self.adaptive = adaptive
        self.models = models
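
    # A minimal usage sketch (`task`, `models`, and `itr` are hypothetical
    # placeholders for a fairseq translation task, trained non-autoregressive
    # model(s), and a batched dataset iterator):
    #
    #   generator = IterativeRefinementGenerator(
    #       task.target_dictionary, models=models, max_iter=10
    #   )
    #   for sent_id, src, ref, hypos in generator.generate_batched_itr(itr, cuda=True):
    #       best = hypos[0]  # each yielded item is a list of hypothesis dicts
    #       print(task.target_dictionary.string(best["tokens"]))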

    def generate_batched_itr(
        self,
        data_itr,
        maxlen_a=None,
        maxlen_b=None,
        cuda=False,
        timer=None,
        prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
            cuda: use GPU for generation
            timer: StopwatchMeter for timing generations.
        """
        for sample in data_itr:
            if "net_input" not in sample:
                continue
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    self.models,
                    sample,
                    prefix_tokens=sample["target"][:, :prefix_size]
                    if prefix_size > 0
                    else None,
                )
            if timer is not None:
                timer.stop(sample["ntokens"])
            for i, id in enumerate(sample["id"]):
                # remove padding
                src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
                ref = utils.strip_pad(sample["target"][i, :], self.pad)
                yield id, src, ref, hypos[i]

    def generate(self, models, sample, prefix_tokens=None, constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the IterativeRefinementGenerator is not supported"
            )

        # TODO: iterative refinement generator does not support ensemble for now.
        if not self.retain_dropout:
            for model in models:
                model.eval()

        model, reranker = models[0], None
        if self.reranking:
            assert len(models) > 1, "Assuming the last checkpoint is the reranker"
            assert (
                self.beam_size > 1
            ), "Reranking requires multiple translations for each example"
            reranker = models[-1]
            models = models[:-1]

        if len(models) > 1 and hasattr(model, "enable_ensemble"):
            assert model.allow_ensemble, "{} does not support ensembling".format(
                model.__class__.__name__
            )
            model.enable_ensemble(models)

        # TODO: better encoder inputs?
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len = src_tokens.size()

        # initialize
        encoder_out = model.forward_encoder([src_tokens, src_lengths])
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)

        if self.beam_size > 1:
            assert (
                model.allow_length_beam
            ), "{} does not support decoding with length beam.".format(
                model.__class__.__name__
            )

            # regenerate data based on length-beam
            length_beam_order = (
                utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, length_beam_order
            )
            prev_decoder_out = model.regenerate_length_beam(
                prev_decoder_out, self.beam_size
            )
            bsz = bsz * self.beam_size

        sent_idxs = torch.arange(bsz)
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.retain_history:
            prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])

        finalized = [[] for _ in range(bsz)]
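
        # Helper for adaptive early stopping: pads the previous and current
        # outputs to the same length, then flags each sentence whose tokens are
        # unchanged since the last iteration (i.e. refinement has converged or
        # entered a loop).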
        def is_a_loop(x, y, s, a):
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
            return (x == y).all(1), y, s, a
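
        # Converts a finished decoder state into the hypothesis dict format used
        # throughout fairseq: padding is stripped, the sentence score is the mean
        # of the per-token scores, and the alignment is the argmax of the
        # attention over source positions for each target token.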
        def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            if prev_out_score is None:
                scores, score = None, None
            else:
                scores = prev_out_score[cutoff]
                score = scores.mean()

            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                "steps": step,
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "hypo_attn": hypo_attn,
                "alignment": alignment,
            }
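
        # Main refinement loop: each iteration runs the decoder once on the
        # current hypotheses. Sentences that have terminated (converged under
        # adaptive decoding, or at the final iteration) are finalized and removed
        # from the batch, so later iterations only refine the remaining ones.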
        for step in range(self.max_iter + 1):
            decoder_options = {
                "eos_penalty": self.eos_penalty,
                "max_ratio": self.max_ratio,
                "decoding_format": self.decoding_format,
            }
            prev_decoder_out = prev_decoder_out._replace(
                step=step,
                max_step=self.max_iter + 1,
            )

            decoder_out = model.forward_decoder(
                prev_decoder_out, encoder_out, **decoder_options
            )

            if self.adaptive:
                # terminate if there is a loop
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_output_tokens,
                    decoder_out.output_tokens,
                    decoder_out.output_scores,
                    decoder_out.attn,
                )
                decoder_out = decoder_out._replace(
                    output_tokens=out_tokens,
                    output_scores=out_scores,
                    attn=out_attn,
                )
            else:
                terminated = decoder_out.output_tokens.new_zeros(
                    decoder_out.output_tokens.size(0)
                ).bool()

            if step == self.max_iter:  # reach last iteration, terminate
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated.to(sent_idxs.device)]
            finalized_tokens = decoder_out.output_tokens[terminated]
            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
                None
                if (decoder_out.attn is None or decoder_out.attn.size(0) == 0)
                else decoder_out.attn[terminated]
            )

            if self.retain_history:
                finalized_history_tokens = [h[terminated] for h in decoder_out.history]

            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step,
                        finalized_tokens[i],
                        finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i],
                    )
                ]

                if self.retain_history:
                    finalized[finalized_idxs[i]][0]["history"] = []
                    for j in range(len(finalized_history_tokens)):
                        finalized[finalized_idxs[i]][0]["history"].append(
                            finalized_hypos(
                                step, finalized_history_tokens[j][i], None, None
                            )
                        )

            # check if all terminated
            if terminated.sum() == terminated.size(0):
                break

            # for next step
            not_terminated = ~terminated
            prev_decoder_out = decoder_out._replace(
                output_tokens=decoder_out.output_tokens[not_terminated],
                output_scores=decoder_out.output_scores[not_terminated],
                attn=decoder_out.attn[not_terminated]
                if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
                else None,
                history=[h[not_terminated] for h in decoder_out.history]
                if decoder_out.history is not None
                else None,
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()
            )
            sent_idxs = sent_idxs[not_terminated.to(sent_idxs.device)]
            prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.beam_size > 1:
            if reranker is not None:
                finalized = self.rerank(
                    reranker, finalized, [src_tokens, src_lengths], self.beam_size
                )

            # aggregate information from length beam
            finalized = [
                finalized[
                    np.argmax(
                        [
                            finalized[self.beam_size * i + j][0]["score"]
                            for j in range(self.beam_size)
                        ]
                    )
                    + self.beam_size * i
                ]
                for i in range(len(finalized) // self.beam_size)
            ]
        return finalized
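
    # Reranking (used when `reranking=True`): each length-beam candidate is
    # teacher-forced through the autoregressive reranker, and its score is
    # replaced by the mean per-token log-probability under that model, which
    # `generate` then uses when aggregating the length beam.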

    def rerank(self, reranker, finalized, encoder_input, beam_size):
        def rebuild_batch(finalized):
            finalized_tokens = [f[0]["tokens"] for f in finalized]
            finalized_maxlen = max(f.size(0) for f in finalized_tokens)
            final_output_tokens = (
                finalized_tokens[0]
                .new_zeros(len(finalized_tokens), finalized_maxlen)
                .fill_(self.pad)
            )
            for i, f in enumerate(finalized_tokens):
                final_output_tokens[i, : f.size(0)] = f
            return final_output_tokens

        final_output_tokens = rebuild_batch(finalized)
        final_output_tokens[
            :, 0
        ] = self.eos  # autoregressive model assumes starting with EOS

        reranker_encoder_out = reranker.encoder(*encoder_input)
        length_beam_order = (
            utils.new_arange(
                final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)
            )
            .t()
            .reshape(-1)
        )
        reranker_encoder_out = reranker.encoder.reorder_encoder_out(
            reranker_encoder_out, length_beam_order
        )
        reranking_scores = reranker.get_normalized_probs(
            reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
            True,
            None,
        )
        reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
        reranking_masks = final_output_tokens[:, 1:].ne(self.pad)
        reranking_scores = (
            reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
        )
        reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
            reranking_scores
        )

        for i in range(len(finalized)):
            finalized[i][0]["score"] = reranking_scores[i]

        return finalized