# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F


class LocationAttention(nn.Module):
    """Location-sensitive attention from "Attention-Based Models for Speech
    Recognition" (https://arxiv.org/pdf/1506.07503.pdf).

    Scores each encoder frame with a shared feed-forward network applied to
    the projected encoder state, the projected decoder state, and a
    convolutional summary of the previous attention weights.

    :param int attn_dim: attention dimension
    :param int encoder_dim: # projection-units of encoder
    :param int decoder_dim: # units of decoder
    :param int attn_state_kernel_size: # input channels (K) of the attention
        convolution, i.e. how many past attention distributions are stacked
    :param int conv_dim: # channels of attention convolution
    :param int conv_kernel_size: filter size of attention convolution
    :param float scaling: sharpening factor applied to scores before softmax
    """

    def __init__(
        self,
        attn_dim,
        encoder_dim,
        decoder_dim,
        attn_state_kernel_size,
        conv_dim,
        conv_kernel_size,
        scaling=2.0,
    ):
        super().__init__()
        self.attn_dim = attn_dim
        self.decoder_dim = decoder_dim
        self.scaling = scaling
        self.proj_enc = nn.Linear(encoder_dim, attn_dim)
        self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False)
        self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False)
        self.conv = nn.Conv1d(
            attn_state_kernel_size,
            conv_dim,
            2 * conv_kernel_size + 1,
            padding=conv_kernel_size,
            bias=False,
        )
        self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1))

        self.proj_enc_out = None  # cache

    def clear_cache(self):
        self.proj_enc_out = None

    def forward(self, encoder_out, encoder_padding_mask, decoder_h, attn_state):
        """
        :param torch.Tensor encoder_out: padded encoder hidden states B x T x D
        :param torch.Tensor encoder_padding_mask: encoder padding mask B x T
        :param torch.Tensor decoder_h: decoder hidden state B x D
        :param torch.Tensor attn_state: previous attention weights B x K x T
        :return: attention-weighted encoder state (B x D)
        :rtype: torch.Tensor
        :return: updated attention weights (B x T)
        :rtype: torch.Tensor
        """
        bsz, seq_len, _ = encoder_out.size()

        # the encoder projection is identical across decoder steps, so
        # compute it once and reuse it until clear_cache() is called
        if self.proj_enc_out is None:
            self.proj_enc_out = self.proj_enc(encoder_out)

        # location features: B x K x T -> B x C x T
        attn = self.conv(attn_state)
        # B x C x T -> B x T x C -> B x T x D
        attn = self.proj_attn(attn.transpose(1, 2))

        if decoder_h is None:
            decoder_h = encoder_out.new_zeros(bsz, self.decoder_dim)
        dec_h = self.proj_dec(decoder_h).view(bsz, 1, self.attn_dim)

        # score each frame, mask out padding, and normalize
        out = self.proj_out(attn + self.proj_enc_out + dec_h).squeeze(2)
        out.masked_fill_(encoder_padding_mask, -float("inf"))
        w = F.softmax(self.scaling * out, dim=1)
        # context vector: attention-weighted sum of encoder states, B x D
        c = torch.sum(encoder_out * w.view(bsz, seq_len, 1), dim=1)
        return c, w
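

# A minimal, hypothetical usage sketch (not part of the original module):
# the dimensions below are made up for illustration. It runs a single
# decoder step with a zero initial attention state, the way Tacotron-style
# decoders typically start.
if __name__ == "__main__":
    bsz, src_len, encoder_dim, decoder_dim = 2, 50, 512, 1024
    attention = LocationAttention(
        attn_dim=128,
        encoder_dim=encoder_dim,
        decoder_dim=decoder_dim,
        attn_state_kernel_size=1,  # feed only the previous step's weights
        conv_dim=32,
        conv_kernel_size=15,
    )
    encoder_out = torch.randn(bsz, src_len, encoder_dim)
    padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)  # no padding
    decoder_h = torch.randn(bsz, decoder_dim)
    attn_state = encoder_out.new_zeros(bsz, 1, src_len)  # B x K x T
    context, weights = attention(encoder_out, padding_mask, decoder_h, attn_state)
    print(context.shape, weights.shape)  # torch.Size([2, 512]) torch.Size([2, 50])
    attention.clear_cache()  # reset cached encoder projection between utterances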