|  | """PyTorch OpenAI GPT-2 model modified with MultiQuery attention""" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import math | 
					
						
						|  | import os | 
					
						
						|  | from dataclasses import dataclass | 
					
						
						|  | from typing import Optional, Tuple, Union | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.utils.checkpoint | 
					
						
						|  | from torch import nn | 
					
						
						|  | from torch.cuda.amp import autocast | 
					
						
						|  | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | 
					
						
						|  |  | 
					
						
						|  | from transformers.activations import ACT2FN | 
					
						
						|  | from transformers.modeling_outputs import ( | 
					
						
						|  | BaseModelOutputWithPastAndCrossAttentions, | 
					
						
						|  | CausalLMOutputWithCrossAttentions, | 
					
						
						|  | SequenceClassifierOutputWithPast, | 
					
						
						|  | TokenClassifierOutput, | 
					
						
						|  | ) | 
					
						
						|  | from transformers.modeling_utils import PreTrainedModel, SequenceSummary | 
					
						
						|  | from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer | 
					
						
						|  |  | 
					
						
						|  | from transformers.utils import ( | 
					
						
						|  | ModelOutput, | 
					
						
						|  | add_code_sample_docstrings, | 
					
						
						|  | add_start_docstrings, | 
					
						
						|  | add_start_docstrings_to_model_forward, | 
					
						
						|  | logging, | 
					
						
						|  | replace_return_docstrings, | 
					
						
						|  | ) | 
					
						
						|  | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map | 
					
						
						|  | from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2Block, GPT2PreTrainedModel, GPT2LMHeadModel | 
					
						
						|  | from .configuration_gpt2_mq import GPT2CustomConfig, MULTI_QUERY, MULTI_HEAD | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						


class GPT2MQAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
        assert config.attention_head_type == MULTI_QUERY

        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
                1, 1, max_positions, max_positions
            ),
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        if is_cross_attention:
            raise NotImplementedError("Cross-attention not implemented for MQA")
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            # One query projection per head, as in multi-head attention.
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
            # A single key/value head of size head_dim, shared by all query heads.
            self.kv_attn = Conv1D(2 * self.head_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.pruned_heads = set()
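    # Note on the multi-query layout: unlike standard multi-head attention, which
    # projects hidden states to 3 * embed_dim (per-head Q, K and V), this module keeps
    # a per-head query projection (embed_dim -> embed_dim) but produces a single
    # key/value pair of size head_dim each (embed_dim -> 2 * head_dim) that is shared
    # by every head, so the KV cache shrinks by a factor of num_heads.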
					
						
    # Note: inherited verbatim from GPT2Attention. It references `self.c_attn`, which is
    # not created in the multi-query branch above, so head pruning is effectively
    # unsupported for MQA.
    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)
    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Shapes (see `forward`):
        #   query: (batch, num_heads * query_length, head_dim)
        #   key:   (batch, head_dim, key_length)
        #   value: (batch, key_length, head_dim)
        batch_size = query.size(0)
        query_length = query.size(1) // self.num_heads
        key_length = key.size(2)

        # (b, nh * sq, hd) x (b, hd, sk) -> (b, nh * sq, sk):
        # one batched matmul for all heads against the single shared key.
        attn_weights = torch.bmm(query, key)
        # -> (b, nh, sq, sk)
        attn_weights = attn_weights.view(batch_size, self.num_heads, query_length, key_length)

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.tensor(
                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # Causal mask: only the "normal" attention layer implements it.
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
            mask_value = torch.finfo(attn_weights.dtype).min
            # Must be a tensor on the same device/dtype to avoid type-promotion errors in torch.where.
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the (additive) attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype; no-op outside mixed precision.
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        # (b, nh, sq, sk) -> (b, nh * sq, sk) so that a single bmm against the shared
        # value tensor computes the outputs of all heads at once.
        _attn_weights = attn_weights.view(batch_size, self.num_heads * query_length, key_length)
        # (b, nh * sq, sk) x (b, sk, hd) -> (b, nh * sq, hd)
        attn_output = torch.bmm(_attn_weights, value)
        attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)

        return attn_output, attn_weights
    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Kept from GPT2Attention for reference; `forward` raises NotImplementedError before
        # calling it, and it still assumes the 4-D multi-head layout (bsz, num_heads, seq_len, head_dim).
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute scale factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and compute the scores in fp32
        with autocast(enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # Causal mask: only the "normal" attention layer implements it.
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
            mask_value = torch.finfo(attn_weights.dtype).min
            # Must be a tensor on the same device/dtype to avoid type-promotion errors in torch.where.
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the (additive) attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype; no-op outside mixed precision.
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights
    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if encoder_hidden_states is not None:
            raise NotImplementedError("Cross-attention not implemented for MQA")
            # Unreachable: kept from GPT2Attention for reference.
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query = self.q_attn(hidden_states)
            # (batch, seq_length, 2 * head_dim) -> one shared key and one shared value of size head_dim
            key, value = self.kv_attn(hidden_states).split(self.head_dim, dim=2)

        batch_size, seq_length = query.shape[:2]

        # (batch, seq_length, embed_dim) -> (batch, num_heads, seq_length, head_dim)
        query = query.view(batch_size, seq_length, self.num_heads, self.head_dim).permute([0, 2, 1, 3])
        # Fold the head dimension into the rows of a single bmm:
        # -> (batch, num_heads * seq_length, head_dim)
        query = query.reshape(batch_size, self.num_heads * seq_length, self.head_dim)

        # key: (batch, seq_length, head_dim) -> (batch, head_dim, seq_length)
        key = key.permute(0, 2, 1)
        # value stays (batch, seq_length, head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            # Concatenate along the sequence dimension of each cached tensor.
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            raise NotImplementedError("Reorder and upcast attention not implemented for MQA")
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
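

# Illustrative sketch, not used by the model classes in this file: GPT2MQAttention._attn
# scores all heads with a single batched matmul against the one shared key by folding the
# head dimension into the rows of the `bmm`. The hypothetical helper below reproduces that
# folding and checks it against an explicit per-head loop, using only plain torch ops.
def _mqa_scores_reference(query: torch.Tensor, key: torch.Tensor, num_heads: int) -> torch.Tensor:
    """query: (batch, num_heads * q_len, head_dim); key: (batch, head_dim, k_len)."""
    batch_size, _, head_dim = query.shape
    k_len = key.size(-1)
    q_len = query.size(1) // num_heads
    # Single bmm: every head's queries are matched against the same shared key.
    folded = torch.bmm(query, key).view(batch_size, num_heads, q_len, k_len)
    # Per-head loop, for comparison only.
    per_head = torch.stack(
        [torch.bmm(query.reshape(batch_size, num_heads, q_len, head_dim)[:, h], key) for h in range(num_heads)],
        dim=1,
    )
    assert torch.allclose(folded, per_head)
    return folded
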
					
						

class GPT2CustomBlock(GPT2Block):
    def __init__(self, config: GPT2CustomConfig, layer_idx=None):
        super().__init__(config, layer_idx)
        # Replace the default multi-head attention with multi-query attention when requested.
        if config.attention_head_type == MULTI_QUERY:
            self.attn = GPT2MQAttention(config, layer_idx=layer_idx)
            if config.add_cross_attention:
                raise NotImplementedError("Cross-attention not implemented for MQA")


class GPT2CustomModel(GPT2Model):
    config_class = GPT2CustomConfig

    def __init__(self, config):
        # Mirrors GPT2Model.__init__ but builds the custom blocks.
        GPT2PreTrainedModel.__init__(self, config)

        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2CustomBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()


class GPT2LMHeadCustomModel(GPT2LMHeadModel):
    config_class = GPT2CustomConfig

    def __init__(self, config):
        GPT2PreTrainedModel.__init__(self, config)
        self.transformer = GPT2CustomModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()
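

# Minimal smoke-test sketch. It assumes GPT2CustomConfig subclasses GPT2Config and accepts
# the usual GPT-2 sizing kwargs plus `attention_head_type` (mirroring the assert in
# GPT2MQAttention); adjust to the actual configuration_gpt2_mq API if it differs.
if __name__ == "__main__":
    config = GPT2CustomConfig(
        vocab_size=128,
        n_positions=64,
        n_embd=64,
        n_layer=2,
        n_head=4,
        attention_head_type=MULTI_QUERY,  # assumed kwarg name, taken from the assert above
    )
    model = GPT2LMHeadCustomModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids, use_cache=True)
    print(out.logits.shape)          # expected: torch.Size([2, 16, 128])
    print(len(out.past_key_values))  # one (key, value) pair per layer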