STAR / fairseq /modules /transformer_layer_aug.py
Yixuan Li
add fairseq folder
85ba398
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional
import torch
from numpy.random import uniform
from torch import Tensor
from fairseq.modules import LayerNorm
from fairseq.modules.transformer_layer import TransformerDecoderLayerBase
class AugTransformerDecoderLayerBase(TransformerDecoderLayerBase):
"""Decoder layer block augmented with an additional cross-attention.
This decoder block is processed with the sequence of the following sub-modules.
self-attention -> cross-attention (first) -> cross-attention (second) -> FFN
Args:
cfg (argparse.Namespace): parsed command-line arguments
encoder_attn_merge_type (str, optional): the way to combine outputs from
two cross-attention modules. If "sequential" is set, two cross-attention
modules are stacked sequentially. If "parallel" is set, they are processed
in parallel and combined before feeding it to FFN (default: sequential).
dropnet_ratio (float, optional): a probability to drop each cross-attention
module during training (default: 0.0).
"""
def __init__(
self,
cfg,
add_bias_kv=False,
add_zero_attn=False,
encoder_attn_merge_type="sequential",
dropnet_ratio=0.0,
):
super().__init__(
cfg,
no_encoder_attn=False,
add_bias_kv=add_bias_kv,
add_zero_attn=False,
)
self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.encoder_attn2 = self.build_encoder_attention(self.embed_dim, cfg)
if encoder_attn_merge_type == "sequential":
self.encoder_attn_layer_norm2 = LayerNorm(self.embed_dim, export=cfg.export)
else:
self.encoder_attn_layer_norm2 = None
self.encoder_attn_merge_type = encoder_attn_merge_type
self.dropnet_ratio = dropnet_ratio
def forward(
self,
x,
encoder_out: Optional[torch.Tensor] = None,
encoder_padding_mask: Optional[torch.Tensor] = None,
encoder_out_aug: Optional[torch.Tensor] = None,
encoder_padding_mask2: Optional[torch.Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
prev_attn_state: Optional[List[torch.Tensor]] = None,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
need_attn: bool = False,
need_head_weights: bool = False,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if prev_self_attn_state is not None:
prev_key, prev_value = prev_self_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
assert incremental_state is not None
self.self_attn._set_input_buffer(incremental_state, saved_state)
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
if self.cross_self_attention and not (
incremental_state is not None
and _self_attn_input_buffer is not None
and "prev_key" in _self_attn_input_buffer
):
if self_attn_mask is not None:
assert encoder_out is not None
self_attn_mask = torch.cat(
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
)
if self_attn_padding_mask is not None:
if encoder_padding_mask is None:
assert encoder_out is not None
encoder_padding_mask = self_attn_padding_mask.new_zeros(
encoder_out.size(1), encoder_out.size(0)
)
self_attn_padding_mask = torch.cat(
(encoder_padding_mask, self_attn_padding_mask), dim=1
)
assert encoder_out is not None
y = torch.cat((encoder_out, x), dim=0)
else:
y = x
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
if self.c_attn is not None:
tgt_len, bsz = x.size(0), x.size(1)
x = x.view(tgt_len, bsz, self.nh, self.head_dim)
x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
x = x.reshape(tgt_len, bsz, self.embed_dim)
if self.attn_ln is not None:
x = self.attn_ln(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
assert encoder_out is not None
assert encoder_out_aug is not None
if self.encoder_attn_merge_type == "sequential":
ratios = self.get_dropnet_ratio()
# first encoder attention
if ratios[0] > 0:
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
x = ratios[0] * x
# second encoder attention
if ratios[1] > 0:
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm2(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn2._set_input_buffer(incremental_state, saved_state)
x, attn2 = self.encoder_attn2(
query=x,
key=encoder_out_aug,
value=encoder_out_aug,
key_padding_mask=encoder_padding_mask2,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm2(x)
x = ratios[1] * x
elif self.encoder_attn_merge_type == "parallel":
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x1, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x2, attn2 = self.encoder_attn2(
query=x,
key=encoder_out_aug,
value=encoder_out_aug,
key_padding_mask=encoder_padding_mask2,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x1 = self.dropout_module(x1)
x2 = self.dropout_module(x2)
ratios = self.get_dropnet_ratio()
x = ratios[0] * x1 + ratios[1] * x2
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
else:
raise NotImplementedError(self.encoder_attn_merge_type)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
if self.ffn_layernorm is not None:
x = self.ffn_layernorm(x)
x = self.fc2(x)
x = self.dropout_module(x)
if self.w_resid is not None:
residual = torch.mul(self.w_resid, residual)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
assert saved_state is not None
if self_attn_padding_mask is not None:
self_attn_state = [
saved_state["prev_key"],
saved_state["prev_value"],
saved_state["prev_key_padding_mask"],
]
else:
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
return x, attn, attn2, self_attn_state
return x, attn, attn2, None
def get_dropnet_ratio(self):
if self.encoder_attn_merge_type == "sequential":
if self.dropnet_ratio > 0:
frand = float(uniform(0, 1))
if frand < self.dropnet_ratio and self.training:
return [2, 0]
elif frand > 1 - self.dropnet_ratio and self.training:
return [0, 2]
else:
return [1, 1]
else:
return [1, 1]
elif self.encoder_attn_merge_type == "parallel":
if self.dropnet_ratio > 0:
frand = float(uniform(0, 1))
if frand < self.dropnet_ratio and self.training:
return [1, 0]
elif frand > 1 - self.dropnet_ratio and self.training:
return [0, 1]
else:
return [0.5, 0.5]
else:
return [0.5, 0.5]