# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn
from torch import Tensor


EncoderOut = NamedTuple(
    "EncoderOut",
    [
        ("encoder_out", Tensor),  # T x B x C
        ("encoder_padding_mask", Optional[Tensor]),  # B x T
        ("encoder_embedding", Optional[Tensor]),  # B x T x C
        ("encoder_states", Optional[List[Tensor]]),  # List[T x B x C]
        ("src_tokens", Optional[Tensor]),  # B x T
        ("src_lengths", Optional[Tensor]),  # B x 1
    ],
)
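
# A minimal sketch (not part of this file) of how a concrete encoder's
# forward() might fill in EncoderOut, assuming hypothetical tensors `x`
# (T x B x C) and `padding_mask` (B x T) computed earlier in forward():
#
#     return EncoderOut(
#         encoder_out=x,
#         encoder_padding_mask=padding_mask,
#         encoder_embedding=None,
#         encoder_states=None,
#         src_tokens=None,
#         src_lengths=None,
#     )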


class FairseqEncoder(nn.Module):
    """Base class for encoders."""

    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of
                shape `(batch)`
        """
        raise NotImplementedError

    def forward_torchscript(self, net_input: Dict[str, Tensor]):
        """A TorchScript-compatible version of forward.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.
        """
        if torch.jit.is_scripting():
            return self.forward(
                src_tokens=net_input["src_tokens"],
                src_lengths=net_input["src_lengths"],
            )
        else:
            return self.forward_non_torchscript(net_input)

    @torch.jit.unused
    def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
        # Excluded from TorchScript compilation: unpacking a dict via
        # **kwargs is not scriptable, so this path only runs in eager mode.
        encoder_input = {
            k: v for k, v in net_input.items() if k != "prev_output_tokens"
        }
        return self.forward(**encoder_input)
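
    # Usage sketch (an assumption about the caller, not defined in this
    # file): generation code would typically pass the batch's "net_input"
    # dict straight through:
    #
    #     encoder_out = encoder.forward_torchscript({
    #         "src_tokens": src_tokens,    # LongTensor, B x T
    #         "src_lengths": src_lengths,  # LongTensor, B
    #     })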

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        raise NotImplementedError
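
    # A typical override (a sketch mirroring concrete fairseq encoders, not
    # part of this base class) reorders each tensor along its batch
    # dimension with index_select; encoder_out is T x B x C, so batch is
    # dim 1, while the padding mask is B x T, so batch is dim 0:
    #
    #     def reorder_encoder_out(self, encoder_out, new_order):
    #         return encoder_out._replace(
    #             encoder_out=encoder_out.encoder_out.index_select(1, new_order),
    #             encoder_padding_mask=(
    #                 encoder_out.encoder_padding_mask.index_select(0, new_order)
    #                 if encoder_out.encoder_padding_mask is not None
    #                 else None
    #             ),
    #         )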

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return 1e6  # an arbitrary large number

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code."""
        return state_dict

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""

        def _apply(m):
            if hasattr(m, "set_num_updates") and m != self:
                m.set_num_updates(num_updates)

        self.apply(_apply)
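

# A minimal end-to-end sketch of a subclass (illustrative only; the
# embedding layer, dimension, and transpose are assumptions, not part of
# this module):
#
#     class SimpleEncoder(FairseqEncoder):
#         def __init__(self, dictionary, embed_dim=512):
#             super().__init__(dictionary)
#             self.embed = nn.Embedding(len(dictionary), embed_dim)
#
#         def forward(self, src_tokens, src_lengths=None, **kwargs):
#             # B x T x C -> T x B x C, matching the EncoderOut convention
#             x = self.embed(src_tokens).transpose(0, 1)
#             return EncoderOut(
#                 encoder_out=x,
#                 encoder_padding_mask=None,
#                 encoder_embedding=None,
#                 encoder_states=None,
#                 src_tokens=None,
#                 src_lengths=None,
#             )
#
#         def reorder_encoder_out(self, encoder_out, new_order):
#             return encoder_out._replace(
#                 encoder_out=encoder_out.encoder_out.index_select(1, new_order)
#             )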