# STAR/fairseq/modules/conformer_layer.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from fairseq.modules import (
ESPNETMultiHeadedAttention,
LayerNorm,
MultiheadAttention,
RelPositionMultiHeadedAttention,
RotaryPositionMultiHeadedAttention,
)
from fairseq.utils import get_activation_fn
class ConvolutionModule(torch.nn.Module):
"""Convolution block used in the conformer block"""
def __init__(
self,
embed_dim,
channels,
depthwise_kernel_size,
dropout,
activation_fn="swish",
bias=False,
export=False,
):
"""
Args:
embed_dim: Embedding dimension
channels: Number of channels in depthwise conv layers
depthwise_kernel_size: Depthwise conv layer kernel size
dropout: dropout value
activation_fn: Activation function to use after depthwise convolution kernel
bias: If bias should be added to conv layers
export: If layernorm should be exported to jit
"""
super(ConvolutionModule, self).__init__()
        assert (
            depthwise_kernel_size - 1
        ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
self.layer_norm = LayerNorm(embed_dim, export=export)
self.pointwise_conv1 = torch.nn.Conv1d(
embed_dim,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.glu = torch.nn.GLU(dim=1)
self.depthwise_conv = torch.nn.Conv1d(
channels,
channels,
depthwise_kernel_size,
stride=1,
padding=(depthwise_kernel_size - 1) // 2,
groups=channels,
bias=bias,
)
self.batch_norm = torch.nn.BatchNorm1d(channels)
self.activation = get_activation_fn(activation_fn)(channels)
self.pointwise_conv2 = torch.nn.Conv1d(
channels,
embed_dim,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.dropout = torch.nn.Dropout(dropout)
def forward(self, x):
"""
Args:
x: Input of shape B X T X C
Returns:
Tensor of shape B X T X C
"""
x = self.layer_norm(x)
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2)
# GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
        x = self.glu(x)  # (batch, channels, time)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.batch_norm(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
x = self.dropout(x)
return x.transpose(1, 2)
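
# Usage sketch for ConvolutionModule (illustrative only, not part of the original
# module; the sizes below are arbitrary example values):
#
#     conv = ConvolutionModule(
#         embed_dim=64, channels=64, depthwise_kernel_size=31, dropout=0.1
#     )
#     x = torch.randn(2, 50, 64)  # B X T X C
#     y = conv(x)                 # B X T X C, same shape as the input
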
class FeedForwardModule(torch.nn.Module):
"""Positionwise feed forward layer used in conformer"""
def __init__(
self,
input_feat,
hidden_units,
dropout1,
dropout2,
activation_fn="swish",
bias=True,
):
"""
Args:
input_feat: Input feature dimension
hidden_units: Hidden unit dimension
dropout1: dropout value for layer1
dropout2: dropout value for layer2
activation_fn: Name of activation function
bias: If linear layers should have bias
"""
super(FeedForwardModule, self).__init__()
self.layer_norm = LayerNorm(input_feat)
self.w_1 = torch.nn.Linear(input_feat, hidden_units, bias=bias)
self.w_2 = torch.nn.Linear(hidden_units, input_feat, bias=bias)
self.dropout1 = torch.nn.Dropout(dropout1)
self.dropout2 = torch.nn.Dropout(dropout2)
self.activation = get_activation_fn(activation_fn)(hidden_units)
def forward(self, x):
"""
Args:
x: Input Tensor of shape T X B X C
Returns:
Tensor of shape T X B X C
"""
x = self.layer_norm(x)
x = self.w_1(x)
x = self.activation(x)
x = self.dropout1(x)
x = self.w_2(x)
return self.dropout2(x)
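
# Usage sketch for FeedForwardModule (illustrative only, not part of the original
# module; the sizes below are arbitrary example values):
#
#     ffn = FeedForwardModule(input_feat=64, hidden_units=256, dropout1=0.1, dropout2=0.1)
#     x = torch.randn(50, 2, 64)  # T X B X C
#     y = ffn(x)                  # T X B X C, same shape as the input
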
class ConformerEncoderLayer(torch.nn.Module):
"""Conformer block based on https://arxiv.org/abs/2005.08100. We currently don't support relative positional encoding in MHA"""
def __init__(
self,
embed_dim,
ffn_embed_dim,
attention_heads,
dropout,
use_fp16,
depthwise_conv_kernel_size=31,
activation_fn="swish",
attn_type=None,
pos_enc_type="abs",
):
"""
Args:
embed_dim: Input embedding dimension
ffn_embed_dim: FFN layer dimension
attention_heads: Number of attention heads in MHA
dropout: dropout value
depthwise_conv_kernel_size: Size of kernel in depthwise conv layer in convolution module
activation_fn: Activation function name to use in convulation block and feed forward block
attn_type: MHA implementation from ESPNET vs fairseq
pos_enc_type: Positional encoding type - abs, rope, rel_pos
"""
self.pos_enc_type = pos_enc_type
super(ConformerEncoderLayer, self).__init__()
self.ffn1 = FeedForwardModule(
embed_dim,
ffn_embed_dim,
dropout,
dropout,
)
self.self_attn_layer_norm = LayerNorm(embed_dim, export=False)
self.self_attn_dropout = torch.nn.Dropout(dropout)
if attn_type == "espnet":
if self.pos_enc_type == "rel_pos":
self.self_attn = RelPositionMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
)
elif self.pos_enc_type == "rope":
self.self_attn = RotaryPositionMultiHeadedAttention(
embed_dim, attention_heads, dropout=dropout, precision=use_fp16
)
elif self.pos_enc_type == "abs":
self.self_attn = ESPNETMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
)
else:
raise Exception(f"Unsupported attention type {self.pos_enc_type}")
else:
# Default to fairseq MHA
self.self_attn = MultiheadAttention(
embed_dim,
attention_heads,
dropout=dropout,
)
self.conv_module = ConvolutionModule(
embed_dim=embed_dim,
channels=embed_dim,
depthwise_kernel_size=depthwise_conv_kernel_size,
dropout=dropout,
activation_fn=activation_fn,
)
self.ffn2 = FeedForwardModule(
embed_dim,
ffn_embed_dim,
dropout,
dropout,
activation_fn=activation_fn,
)
self.final_layer_norm = LayerNorm(embed_dim, export=False)
def forward(
self,
x,
encoder_padding_mask: Optional[torch.Tensor],
position_emb: Optional[torch.Tensor] = None,
):
"""
Args:
x: Tensor of shape T X B X C
encoder_padding_mask: Optional mask tensor
positions:
Returns:
Tensor of shape T X B X C
"""
residual = x
x = self.ffn1(x)
x = x * 0.5 + residual
residual = x
x = self.self_attn_layer_norm(x)
if self.pos_enc_type == "rel_pos":
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
pos_emb=position_emb,
need_weights=False,
)
else:
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
)
x = self.self_attn_dropout(x)
x = x + residual
residual = x
# TBC to BTC
x = x.transpose(0, 1)
x = self.conv_module(x)
# BTC to TBC
x = x.transpose(0, 1)
x = residual + x
residual = x
x = self.ffn2(x)
layer_result = x
x = x * 0.5 + residual
x = self.final_layer_norm(x)
return x, (attn, layer_result)
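
# Output contract of ConformerEncoderLayer.forward (summary of the code above, with
# example shapes as assumptions): for x of shape T X B X C the layer returns
# (x_out, (attn, layer_result)), where x_out has the same T X B X C shape, attn is
# the self-attention weights (None here, since need_weights=False), and layer_result
# is the output of the second feed-forward module before the half-step residual
# addition and the final LayerNorm.
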
class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer):
"""Encoder layer for Wav2vec2 encoder"""
    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        att_args=None,
        position_emb=None,
    ):
        # self_attn_mask, need_weights and att_args are accepted for interface
        # compatibility with other fairseq encoder layers but are not used here.
        return super().forward(x, self_attn_padding_mask, position_emb)
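

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module. It assumes fairseq and
    # torch are importable and uses arbitrary example sizes (T=50, B=2, C=64) with
    # the default fairseq MultiheadAttention (attn_type=None, absolute positions).
    layer = ConformerWav2Vec2EncoderLayer(
        embed_dim=64,
        ffn_embed_dim=256,
        attention_heads=4,
        dropout=0.1,
        use_fp16=False,
    )
    x = torch.randn(50, 2, 64)  # T X B X C
    out, (attn, layer_result) = layer(x)
    print(out.shape)  # expected: torch.Size([50, 2, 64])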