# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict, Optional

from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.models import FairseqDecoder
from torch import Tensor

logger = logging.getLogger(__name__)

@with_incremental_state
class FairseqIncrementalDecoder(FairseqDecoder):
| """Base class for incremental decoders. | |
| Incremental decoding is a special mode at inference time where the Model | |
| only receives a single timestep of input corresponding to the previous | |
| output token (for teacher forcing) and must produce the next output | |
| *incrementally*. Thus the model must cache any long-term state that is | |
| needed about the sequence, e.g., hidden states, convolutional states, etc. | |
| Compared to the standard :class:`FairseqDecoder` interface, the incremental | |
| decoder interface allows :func:`forward` functions to take an extra keyword | |
| argument (*incremental_state*) that can be used to cache state across | |
| time-steps. | |
| The :class:`FairseqIncrementalDecoder` interface also defines the | |
| :func:`reorder_incremental_state` method, which is used during beam search | |
| to select and reorder the incremental state based on the selection of beams. | |
| To learn more about how incremental decoding works, refer to `this blog | |
| <http://www.telesens.co/2019/04/21/understanding-incremental-decoding-in-fairseq/>`_. | |
| """ | |

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """
        Args:
            prev_output_tokens (LongTensor): shifted output tokens of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict, optional): dictionary used for storing
                state during :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """
        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder incremental state.

        This will be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """
        pass
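    # The base implementation above is a no-op; concrete decoders override it.
    # As an illustrative, hypothetical sketch (not fairseq's actual code), a
    # typical override re-indexes every cached tensor along the batch
    # dimension so that each surviving beam keeps the state of the hypothesis
    # it was selected from:
    #
    #     def reorder_incremental_state(self, incremental_state, new_order):
    #         for key, buffer in incremental_state.items():
    #             for name, tensor in buffer.items():
    #                 if tensor is not None:
    #                     buffer[name] = tensor.index_select(0, new_order)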

    def reorder_incremental_state_scripting(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Main entry point for reordering the incremental state.

        Due to limitations in TorchScript, we call this function in
        :class:`fairseq.sequence_generator.SequenceGenerator` instead of
        calling :func:`reorder_incremental_state` directly.
        """
        for module in self.modules():
            if hasattr(module, "reorder_incremental_state"):
                result = module.reorder_incremental_state(incremental_state, new_order)
                if result is not None:
                    incremental_state = result

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        if getattr(self, "_beam_size", -1) != beam_size:
            seen = set()

            def apply_set_beam_size(module):
                if (
                    module != self
                    and hasattr(module, "set_beam_size")
                    and module not in seen
                ):
                    seen.add(module)
                    module.set_beam_size(beam_size)

            self.apply(apply_set_beam_size)
            self._beam_size = beam_size
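

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of fairseq): a hypothetical minimal subclass
# showing how the interface above is typically used. ``ToyIncrementalDecoder``
# and everything inside it are made-up names for illustration only; real
# decoders (e.g. the Transformer decoder) cache attention keys/values rather
# than a running sum of token embeddings.
# ---------------------------------------------------------------------------
import torch
from torch import nn


class ToyIncrementalDecoder(FairseqIncrementalDecoder):
    """Hypothetical decoder that caches a running sum of token embeddings."""

    def __init__(self, dictionary, embed_dim=8):
        super().__init__(dictionary)
        self.embed = nn.Embedding(len(dictionary), embed_dim)
        self.proj = nn.Linear(embed_dim, len(dictionary))

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        if incremental_state is not None:
            # Incremental mode: process only the newest token and read the
            # long-term state (here, a running sum) from the cache.
            x = self.embed(prev_output_tokens[:, -1:])
            cached = incremental_state.get("toy_state", {}).get("sum")
            if cached is not None:
                x = x + cached.unsqueeze(1)
            incremental_state["toy_state"] = {"sum": x[:, -1, :]}
        else:
            # Full-sequence (teacher-forcing) mode: no caching needed.
            x = torch.cumsum(self.embed(prev_output_tokens), dim=1)
        return self.proj(x), {}

    def reorder_incremental_state(self, incremental_state, new_order):
        # Re-index every cached tensor along the batch dimension so that each
        # surviving beam keeps the state of the hypothesis it came from.
        for key, buffer in incremental_state.items():
            for name, tensor in buffer.items():
                if tensor is not None:
                    buffer[name] = tensor.index_select(0, new_order)


# Hypothetical step-by-step decoding loop using the sketch above:
#
#     decoder = ToyIncrementalDecoder(dictionary)
#     incremental_state = {}
#     tokens = torch.full((beam_size, 1), dictionary.eos(), dtype=torch.long)
#     for _ in range(max_len):
#         logits, _ = decoder(tokens, incremental_state=incremental_state)
#         next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
#         tokens = torch.cat([tokens, next_tok], dim=1)
#         # After beam selection, reorder the cache to match the chosen beams:
#         # decoder.reorder_incremental_state_scripting(incremental_state, new_order)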