# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from typing import Dict, List

from torch import nn, Tensor

from fairseq import checkpoint_utils
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.speech_to_text import (
    ConvTransformerModel,
    convtransformer_espnet,
    ConvTransformerEncoder,
)
from fairseq.models.speech_to_text.modules.augmented_memory_attention import (
    augmented_memory,
    SequenceEncoder,
    AugmentedMemoryConvTransformerEncoder,
)
from fairseq.models.speech_to_text.modules.emformer import (
    NoSegAugmentedMemoryTransformerEncoderLayer,
)


@register_model("convtransformer_simul_trans")
class SimulConvTransformerModel(ConvTransformerModel):
    """
    Implementation of the paper:

    SimulMT to SimulST: Adapting Simultaneous Text Translation to
    End-to-End Simultaneous Speech Translation

    https://www.aclweb.org/anthology/2020.aacl-main.58.pdf
    """

    @staticmethod
    def add_args(parser):
        super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser)
        parser.add_argument(
            "--train-monotonic-only",
            action="store_true",
            default=False,
            help="Only train monotonic attention",
        )

    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
        tgt_dict = task.tgt_dict

        # Local import: the monotonic decoder lives in the
        # simultaneous_translation examples package.
        from examples.simultaneous_translation.models.transformer_monotonic_attention import (
            TransformerMonotonicDecoder,
        )

        decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)

        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_decoder_from
            )
        return decoder


@register_model_architecture(
    "convtransformer_simul_trans", "convtransformer_simul_trans_espnet"
)
def convtransformer_simul_trans_espnet(args):
    convtransformer_espnet(args)
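
# Hypothetical training invocation (not part of this file): the architecture
# registered above is selected by name at training time. The data path,
# --simul-type value, and hyperparameters below are illustrative assumptions;
# see the fairseq simultaneous translation docs for exact usage.
#
#   fairseq-train ${DATA_ROOT} \
#       --task speech_to_text \
#       --arch convtransformer_simul_trans_espnet \
#       --simul-type waitk_fixed_pre_decision \
#       --waitk-lagging 3 \
#       --optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt \
#       --criterion label_smoothed_cross_entropy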


@register_model("convtransformer_augmented_memory")
@augmented_memory
class AugmentedMemoryConvTransformerModel(SimulConvTransformerModel):
    @classmethod
    def build_encoder(cls, args):
        encoder = SequenceEncoder(args, AugmentedMemoryConvTransformerEncoder(args))

        if getattr(args, "load_pretrained_encoder_from", None) is not None:
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )

        return encoder


@register_model_architecture(
    "convtransformer_augmented_memory", "convtransformer_augmented_memory"
)
def augmented_memory_convtransformer_espnet(args):
    convtransformer_espnet(args)
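
# Hypothetical invocation of the augmented-memory variant. The
# @augmented_memory wrapper contributes its own streaming flags (segment
# size, contexts, memory size); the flag names and values below are
# assumptions, not prescribed by this file:
#
#   fairseq-train ... \
#       --arch convtransformer_augmented_memory \
#       --simul-type waitk_fixed_pre_decision \
#       --segment-size 40 --max-memory-size 32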


# ============================================================================ #
#   Convtransformer
#   with monotonic attention decoder
#   with emformer encoder
# ============================================================================ #


class ConvTransformerEmformerEncoder(ConvTransformerEncoder):
    def __init__(self, args):
        super().__init__(args)
        # Context sizes are given in input frames; convert them to the
        # post-convolution time scale before passing them to the Emformer layer.
        stride = self.conv_layer_stride(args)
        trf_left_context = args.segment_left_context // stride
        trf_right_context = args.segment_right_context // stride
        context_config = [trf_left_context, trf_right_context]
        self.transformer_layers = nn.ModuleList(
            [
                NoSegAugmentedMemoryTransformerEncoderLayer(
                    input_dim=args.encoder_embed_dim,
                    num_heads=args.encoder_attention_heads,
                    ffn_dim=args.encoder_ffn_embed_dim,
                    num_layers=args.encoder_layers,
                    dropout_in_attn=args.dropout,
                    dropout_on_attn=args.dropout,
                    dropout_on_fc1=args.dropout,
                    dropout_on_fc2=args.dropout,
                    activation_fn=args.activation_fn,
                    context_config=context_config,
                    segment_size=args.segment_length,
                    max_memory_size=args.max_memory_size,
                    scaled_init=True,  # TODO: use constant for now.
                    tanh_on_mem=args.amtrf_tanh_on_mem,
                )
            ]
        )
        self.conv_transformer_encoder = ConvTransformerEncoder(args)
| def forward(self, src_tokens, src_lengths): | |
| encoder_out: Dict[str, List[Tensor]] = self.conv_transformer_encoder(src_tokens, src_lengths.to(src_tokens.device)) | |
| output = encoder_out["encoder_out"][0] | |
| encoder_padding_masks = encoder_out["encoder_padding_mask"] | |
| return { | |
| "encoder_out": [output], | |
| # This is because that in the original implementation | |
| # the output didn't consider the last segment as right context. | |
| "encoder_padding_mask": [encoder_padding_masks[0][:, : output.size(0)]] if len(encoder_padding_masks) > 0 | |
| else [], | |
| "encoder_embedding": [], | |
| "encoder_states": [], | |
| "src_tokens": [], | |
| "src_lengths": [], | |
| } | |

    @staticmethod
    def conv_layer_stride(args):
        # TODO: make this configurable from args. The ConvTransformer front
        # end subsamples by 4x (two stride-2 convolutions), hence the 4.
        return 4


@register_model("convtransformer_emformer")
class ConvtransformerEmformer(SimulConvTransformerModel):
    @staticmethod
    def add_args(parser):
        super(ConvtransformerEmformer, ConvtransformerEmformer).add_args(parser)
        parser.add_argument(
            "--segment-length",
            type=int,
            metavar="N",
            help="length of each segment (not including left context / right context)",
        )
        parser.add_argument(
            "--segment-left-context",
            type=int,
            help="length of left context in a segment",
        )
        parser.add_argument(
            "--segment-right-context",
            type=int,
            help="length of right context in a segment",
        )
        parser.add_argument(
            "--max-memory-size",
            type=int,
            default=-1,
            help="maximum size of the memory bank (-1 for unlimited)",
        )
        parser.add_argument(
            "--amtrf-tanh-on-mem",
            default=False,
            action="store_true",
            help="whether to apply tanh to the memory vectors",
        )

    @classmethod
    def build_encoder(cls, args):
        encoder = ConvTransformerEmformerEncoder(args)
        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )
        return encoder


@register_model_architecture(
    "convtransformer_emformer", "convtransformer_emformer"
)
def convtransformer_emformer_base(args):
    convtransformer_espnet(args)
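
# Hypothetical invocation of the Emformer variant (values are placeholders,
# not recommendations). Segment and context lengths are specified in input
# frames and are divided by the convolutional stride inside the encoder:
#
#   fairseq-train ... \
#       --arch convtransformer_emformer \
#       --segment-length 40 \
#       --segment-left-context 32 \
#       --segment-right-context 16 \
#       --max-memory-size 32 \
#       --amtrf-tanh-on-mem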