# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import torch
import torch.nn as nn

from fairseq.modules import (
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    MultiheadAttention,
    PositionalEmbedding,
    TransformerSentenceEncoderLayer,
)
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_

def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.
        1. If normal_init_linear_weights is set then weights of linear
           layer will be initialized using the normal distribution and
           bias will be set to the specified value.
        2. If normal_init_embed_weights is set then weights of embedding
           layer will be initialized using the normal distribution.
        3. If normal_init_proj_weights is set then weights of
           in_project_weight for MultiHeadAttention are initialized using
           the normal distribution (to be validated).
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)
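

# The helper below is not part of the original fairseq module; it is a minimal,
# hypothetical sketch of how init_bert_params is typically used. nn.Module.apply
# walks every submodule, so each nn.Linear, nn.Embedding and MultiheadAttention
# it encounters receives the N(0, 0.02) initialization defined above.
def _example_apply_bert_init() -> nn.Module:
    # Toy two-layer model, used purely for illustration.
    toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
    # Re-initialize weights in place; nn.Linear biases are zeroed as well.
    toy.apply(init_bert_params)
    return toy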


class TransformerSentenceEncoder(nn.Module):
    """
    Implementation for a Bi-directional Transformer based Sentence Encoder used
    in BERT/XLM style pre-trained models.

    This first computes the token embedding using the token embedding matrix,
    position embeddings (if specified) and segment embeddings
    (if specified). After applying the specified number of
    TransformerEncoderLayers, it outputs all the internal states of the
    encoder as well as the final representation associated with the first
    token (usually the CLS token).

    Input:
        - tokens: B x T matrix representing sentences
        - segment_labels: B x T matrix representing segment labels for tokens

    Output:
        - a tuple of the following:
            - a list of internal model states used to compute the
              predictions where each tensor has shape T x B x C
            - sentence representation associated with the first input token
              in format B x C.

    A minimal usage sketch is provided at the bottom of this file.
    """

    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        layerdrop: float = 0.0,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        traceable: bool = False,
        q_noise: float = 0.0,
        qn_block_size: int = 8,
    ) -> None:
        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.layerdrop = layerdrop
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.num_segments = num_segments
        self.use_position_embeddings = use_position_embeddings
        self.apply_bert_init = apply_bert_init
        self.learned_pos_embedding = learned_pos_embedding
        self.traceable = traceable

        self.embed_tokens = self.build_embedding(
            self.vocab_size, self.embedding_dim, self.padding_idx
        )
        self.embed_scale = embed_scale

        if q_noise > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(self.embedding_dim, self.embedding_dim, bias=False),
                q_noise,
                qn_block_size,
            )
        else:
            self.quant_noise = None

        self.segment_embeddings = (
            nn.Embedding(self.num_segments, self.embedding_dim, padding_idx=None)
            if self.num_segments > 0
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                self.max_seq_len,
                self.embedding_dim,
                padding_idx=(self.padding_idx if offset_positions_by_padding else None),
                learned=self.learned_pos_embedding,
            )
            if self.use_position_embeddings
            else None
        )

        if encoder_normalize_before:
            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
        else:
            self.emb_layer_norm = None

        if self.layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                self.build_transformer_sentence_encoder_layer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=self.dropout_module.p,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    export=export,
                    q_noise=q_noise,
                    qn_block_size=qn_block_size,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        if freeze_embeddings:
            freeze_module_params(self.embed_tokens)
            freeze_module_params(self.segment_embeddings)
            freeze_module_params(self.embed_positions)
            freeze_module_params(self.emb_layer_norm)

        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])

    def build_embedding(self, vocab_size, embedding_dim, padding_idx):
        return nn.Embedding(vocab_size, embedding_dim, padding_idx)

    def build_transformer_sentence_encoder_layer(
        self,
        embedding_dim,
        ffn_embedding_dim,
        num_attention_heads,
        dropout,
        attention_dropout,
        activation_dropout,
        activation_fn,
        export,
        q_noise,
        qn_block_size,
    ):
        return TransformerSentenceEncoderLayer(
            embedding_dim=embedding_dim,
            ffn_embedding_dim=ffn_embedding_dim,
            num_attention_heads=num_attention_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            activation_fn=activation_fn,
            export=export,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )

    def forward(
        self,
        tokens: torch.Tensor,
        segment_labels: torch.Tensor = None,
        last_state_only: bool = False,
        positions: Optional[torch.Tensor] = None,
        token_embeddings: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        is_tpu = tokens.device.type == "xla"

        # compute padding mask. This is needed for multi-head attention
        padding_mask = tokens.eq(self.padding_idx)
        if not self.traceable and not is_tpu and not padding_mask.any():
            padding_mask = None

        if token_embeddings is not None:
            x = token_embeddings
        else:
            x = self.embed_tokens(tokens)

        if self.embed_scale is not None:
            x = x * self.embed_scale

        if self.embed_positions is not None:
            x = x + self.embed_positions(tokens, positions=positions)

        if self.segment_embeddings is not None and segment_labels is not None:
            x = x + self.segment_embeddings(segment_labels)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.emb_layer_norm is not None:
            x = self.emb_layer_norm(x)

        x = self.dropout_module(x)

        # account for padding while computing the representation
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        inner_states = []
        if not last_state_only:
            inner_states.append(x)

        for layer in self.layers:
            x, _ = layer(
                x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask
            )
            if not last_state_only:
                inner_states.append(x)

        sentence_rep = x[0, :, :]

        if last_state_only:
            inner_states = [x]

        if self.traceable:
            return torch.stack(inner_states), sentence_rep
        else:
            return inner_states, sentence_rep
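

# Hypothetical subclass sketch (not part of fairseq): the build_* hooks above
# exist so that subclasses can swap in customized encoder layers without
# re-implementing __init__. Here a subclass forces GELU activations in every
# layer; __init__ passes all builder arguments by keyword, so **kwargs works.
class GeluSentenceEncoder(TransformerSentenceEncoder):
    def build_transformer_sentence_encoder_layer(self, **kwargs):
        kwargs["activation_fn"] = "gelu"
        return super().build_transformer_sentence_encoder_layer(**kwargs)


# Minimal usage sketch, added for illustration rather than taken from the
# fairseq API: it builds a small encoder over random token ids and checks the
# shapes produced by forward(). The vocabulary size, padding index, batch size
# and sequence length below are arbitrary example values.
def _example_usage() -> None:
    vocab_size, padding_idx, batch, seq_len, dim = 100, 1, 2, 8, 64

    encoder = TransformerSentenceEncoder(
        padding_idx=padding_idx,
        vocab_size=vocab_size,
        num_encoder_layers=2,
        embedding_dim=dim,
        ffn_embedding_dim=128,
        num_attention_heads=4,
        apply_bert_init=True,
    )

    # Random token ids in [2, vocab_size); index `padding_idx` marks padding.
    tokens = torch.randint(2, vocab_size, (batch, seq_len))

    inner_states, sentence_rep = encoder(tokens)
    # Each inner state is T x B x C; the sentence representation is B x C.
    assert inner_states[-1].shape == (seq_len, batch, dim)
    assert sentence_rep.shape == (batch, dim)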