Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| from .fairseq_encoder import FairseqEncoder | |
class CompositeEncoder(FairseqEncoder):
    """
    A wrapper around a dictionary of :class:`FairseqEncoder` objects.

    We run forward on each encoder and return a dictionary of outputs. The first
    encoder's dictionary is used for initialization.

    Args:
        encoders (dict): a non-empty dictionary mapping names to
            :class:`FairseqEncoder` objects. Each encoder is also registered
            as a named submodule under its key via ``add_module``.

    Raises:
        ValueError: if ``encoders`` is empty.
    """

    def __init__(self, encoders):
        if not encoders:
            # next(iter(...)) on an empty dict would leak a bare StopIteration
            # with no context; fail with an explicit message instead.
            raise ValueError("CompositeEncoder requires at least one encoder")
        # The shared dictionary comes from the first encoder in iteration order.
        super().__init__(next(iter(encoders.values())).dictionary)
        self.encoders = encoders
        # Register every encoder as a named submodule under its key.
        for key, encoder in self.encoders.items():
            self.add_module(key, encoder)

    def forward(self, src_tokens, src_lengths):
        """
        Run every wrapped encoder on the same inputs.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`

        Returns:
            dict: maps each encoder's key to that encoder's output for
            ``(src_tokens, src_lengths)``.
        """
        return {
            key: encoder(src_tokens, src_lengths)
            for key, encoder in self.encoders.items()
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output according to *new_order*.

        Args:
            encoder_out (dict): output of :meth:`forward`, keyed like
                ``self.encoders``.
            new_order: passed unchanged to each wrapped encoder's
                ``reorder_encoder_out``.

        Returns:
            dict: a new dictionary holding each wrapped encoder's reordered
            output. Entries of *encoder_out* whose key is not an encoder name
            are copied over unchanged. *encoder_out* itself is not modified.
        """
        # Work on a shallow copy so the caller's dict is not mutated in place.
        reordered = dict(encoder_out)
        for key, encoder in self.encoders.items():
            reordered[key] = encoder.reorder_encoder_out(encoder_out[key], new_order)
        return reordered

    def max_positions(self):
        """Return the minimum of the wrapped encoders' ``max_positions()``."""
        return min(encoder.max_positions() for encoder in self.encoders.values())

    def upgrade_state_dict(self, state_dict):
        """Let each wrapped encoder upgrade *state_dict* in turn.

        Each encoder receives the dict produced by the previous one, so
        encoders that return a new dict (rather than mutating in place) are
        honoured. If an encoder returns ``None``, the current dict is kept.

        Returns:
            The upgraded state dict.
        """
        for encoder in self.encoders.values():
            upgraded = encoder.upgrade_state_dict(state_dict)
            # Propagate a returned dict; tolerate in-place-only upgrades
            # that return None.
            if upgraded is not None:
                state_dict = upgraded
        return state_dict