# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import utils
from fairseq.models.transformer import TransformerConfig
from fairseq.modules import LayerNorm, MultiheadAttention
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise


class TransformerEncoderLayerBase(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *cfg.encoder.normalize_before* to ``True``.

    Args:
        cfg (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, cfg, return_fc=False):
        super().__init__()
        self.cfg = cfg
        self.return_fc = return_fc
        self.embed_dim = cfg.encoder.embed_dim
        self.quant_noise = cfg.quant_noise.pq
        self.quant_noise_block_size = cfg.quant_noise.pq_block_size
        self.self_attn = self.build_self_attention(self.embed_dim, cfg)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=self.__class__.__name__
        )
        self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn)
        activation_dropout_p = cfg.activation_dropout
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use cfg.relu_dropout
            activation_dropout_p = cfg.relu_dropout or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = cfg.encoder.normalize_before
        self.fc1 = self.build_fc1(
            self.embed_dim,
            cfg.encoder.ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            cfg.encoder.ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def _get_fc_rank(self, remove_num: int) -> List[int]:
        f1_filter_param = []
        for i in range(self.fc1.out_features):
            f1_filter_param.append(
                torch.sum(torch.abs(self.fc1.weight[i]))
                + torch.sum(torch.abs(self.fc2.weight[:, i]))
                + torch.abs(self.fc1.bias[i])
            )
        return sorted(
            range(len(f1_filter_param)), key=lambda k: f1_filter_param[k], reverse=False
        )[0:remove_num]
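
    # _get_fc_rank scores each hidden unit i of the FFN by the L1 norm of the
    # weights that touch it (row i of fc1.weight, column i of fc2.weight, plus
    # fc1.bias[i]) and returns the `remove_num` lowest-scoring indices, i.e. the
    # units that a pruning pass would drop first via _prune_fc_layer below.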

    def _prune_fc_layer(self, remove_index: List[int]):
        new_fc1_weight = []
        new_fc1_bias = []
        for i in range(self.fc1.out_features):
            if i not in remove_index:
                new_fc1_weight.append(self.fc1.weight[i])
                new_fc1_bias.append(self.fc1.bias[i])

        new_fc1_weight = torch.stack(new_fc1_weight).detach()
        new_fc1_weight.requires_grad = True

        new_fc1_bias = torch.stack(new_fc1_bias).detach()
        new_fc1_bias.requires_grad = True

        self.fc1 = quant_noise(
            nn.Linear(self.fc1.in_features, self.fc1.out_features - len(remove_index)),
            p=self.quant_noise,
            block_size=self.quant_noise_block_size,
        )
        self.fc1.weight = torch.nn.Parameter(new_fc1_weight)
        self.fc1.bias = torch.nn.Parameter(new_fc1_bias)

        new_fc2_weight = []
        for i in range(self.fc2.in_features):
            if i not in remove_index:
                new_fc2_weight.append(self.fc2.weight[:, i])

        new_fc2_weight = torch.stack(new_fc2_weight, dim=-1).detach()
        new_fc2_weight.requires_grad = True

        # fc2's output bias is unaffected by pruning input columns, so reuse it as-is.
        new_fc2_bias = self.fc2.bias.detach()
        new_fc2_bias.requires_grad = True

        self.fc2 = quant_noise(
            nn.Linear(self.fc2.in_features - len(remove_index), self.fc2.out_features),
            p=self.quant_noise,
            block_size=self.quant_noise_block_size,
        )
        self.fc2.weight = torch.nn.Parameter(new_fc2_weight)
        self.fc2.bias = torch.nn.Parameter(new_fc2_bias)

    def build_self_attention(self, embed_dim, cfg):
        return MultiheadAttention(
            embed_dim,
            cfg.encoder.attention_heads,
            dropout=cfg.attention_dropout,
            self_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            xformers_att_config=cfg.encoder.xformers_att_config,
        )

    def residual_connection(self, x, residual):
        return residual + x

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # anything in original attn_mask = 1, becomes -1e8
        # anything in original attn_mask = 0, becomes 0
        # Note that we cannot use -inf here, because at some edge cases,
        # the attention weight (before softmax) for some padded element in query
        # will become -inf, which results in NaN in model parameters
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(
                attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
            )

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)

        fc_result = x

        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        if self.return_fc and not torch.jit.is_scripting():
            return x, fc_result
        return x
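
# Minimal usage sketch for TransformerEncoderLayerBase (illustrative only, not
# executed; assumes the default values in TransformerConfig, e.g.
# encoder.embed_dim == 512):
#
#   cfg = TransformerConfig()
#   layer = TransformerEncoderLayerBase(cfg)
#   seq_len, batch = 10, 2
#   x = torch.randn(seq_len, batch, cfg.encoder.embed_dim)
#   padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
#   out = layer(x, encoder_padding_mask=padding_mask)  # (seq_len, batch, embed_dim)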


# backward compatible with the legacy argparse format
class TransformerEncoderLayer(TransformerEncoderLayerBase):
    def __init__(self, args):
        super().__init__(TransformerConfig.from_namespace(args))
        self.args = args

    def build_self_attention(self, embed_dim, args):
        return super().build_self_attention(
            embed_dim, TransformerConfig.from_namespace(args)
        )


class TransformerDecoderLayerBase(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *cfg.decoder.normalize_before* to ``True``.

    Args:
        cfg (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = cfg.decoder.embed_dim
        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=self.__class__.__name__
        )
        self.quant_noise = cfg.quant_noise.pq
        self.quant_noise_block_size = cfg.quant_noise.pq_block_size

        self.cross_self_attention = cfg.cross_self_attention

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            cfg,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        self.attn_ln = (
            LayerNorm(self.embed_dim)
            if utils.safe_getattr(cfg, "scale_attn", False)
            else None
        )
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim
        scale_heads = utils.safe_getattr(cfg, "scale_heads", False)
        self.c_attn = (
            nn.Parameter(torch.ones((self.nh,)), requires_grad=True)
            if scale_heads
            else None
        )

        self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn)
        activation_dropout_p = cfg.activation_dropout
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use cfg.relu_dropout
            activation_dropout_p = cfg.relu_dropout or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = cfg.decoder.normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)

        self.ffn_layernorm = (
            LayerNorm(cfg.decoder.ffn_embed_dim)
            if utils.safe_getattr(cfg, "scale_fc", False)
            else None
        )
        self.w_resid = (
            nn.Parameter(
                torch.ones(
                    self.embed_dim,
                ),
                requires_grad=True,
            )
            if utils.safe_getattr(cfg, "scale_resids", False)
            else None
        )

        self.fc1 = self.build_fc1(
            self.embed_dim,
            cfg.decoder.ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            cfg.decoder.ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )

        self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
        self.need_attn = True

        self.onnx_trace = False

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(
        self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
    ):
        return MultiheadAttention(
            embed_dim,
            cfg.decoder.attention_heads,
            dropout=cfg.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not cfg.cross_self_attention,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            xformers_att_config=cfg.decoder.xformers_att_config,
        )

    def build_encoder_attention(self, embed_dim, cfg):
        return MultiheadAttention(
            embed_dim,
            cfg.decoder.attention_heads,
            kdim=cfg.encoder.embed_dim,
            vdim=cfg.encoder.embed_dim,
            dropout=cfg.attention_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            xformers_att_config=cfg.encoder.xformers_att_config,
        )

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        return residual + x

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
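        # When cross_self_attention is enabled (and there is no cached key/value
        # state from incremental decoding to reuse), encoder_out is prepended to
        # the self-attention keys/values so that "self"-attention can also attend
        # to source states; self_attn_mask is left-padded with zeros (source
        # positions stay visible) and self_attn_padding_mask is extended with the
        # encoder padding mask.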
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
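        # Optional per-head output scaling (active only when the scale_heads
        # option created self.c_attn in __init__): the attention output is
        # reshaped to expose the head dimension, each head h is multiplied by the
        # learned scalar c_attn[h], and the result is flattened back to embed_dim.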
        if self.c_attn is not None:
            tgt_len, bsz = x.size(0), x.size(1)
            x = x.view(tgt_len, bsz, self.nh, self.head_dim)
            x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
            x = x.reshape(tgt_len, bsz, self.embed_dim)
        if self.attn_ln is not None:
            x = self.attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
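
        # Feed-forward block. ffn_layernorm (scale_fc) and w_resid (scale_resids)
        # are optional extras created in __init__; when disabled they are None and
        # this reduces to the standard fc1 -> activation -> dropout -> fc2 FFN
        # with a residual connection.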
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn


# backward compatible with the legacy argparse format
class TransformerDecoderLayer(TransformerDecoderLayerBase):
    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__(
            TransformerConfig.from_namespace(args),
            no_encoder_attn=no_encoder_attn,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        self.args = args

    def build_self_attention(
        self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
    ):
        return super().build_self_attention(
            embed_dim,
            TransformerConfig.from_namespace(args),
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )

    def build_encoder_attention(self, embed_dim, args):
        return super().build_encoder_attention(
            embed_dim,
            TransformerConfig.from_namespace(args),
        )
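

# Minimal decoder usage sketch (illustrative only, not executed; assumes the
# defaults in TransformerConfig, e.g. decoder.embed_dim == 512, and an encoder
# output of matching width):
#
#   cfg = TransformerConfig()
#   layer = TransformerDecoderLayerBase(cfg)
#   tgt_len, src_len, batch = 7, 10, 2
#   x = torch.randn(tgt_len, batch, cfg.decoder.embed_dim)
#   enc = torch.randn(src_len, batch, cfg.encoder.embed_dim)
#   out, attn, _ = layer(x, encoder_out=enc, encoder_padding_mask=None)
#   # for autoregressive decoding, a causal self_attn_mask would also be passed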