Upload blocks.py with huggingface_hub
Browse files
blocks.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GPT Blocks used for the GPT Model."""
|
| 2 |
+
from typing import Dict, Optional, Tuple
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from .attention import ATTN_CLASS_REGISTRY
|
| 6 |
+
from .norm import NORM_CLASS_REGISTRY
|
| 7 |
+
|
| 8 |
+
class MPTMLP(nn.Module):
|
| 9 |
+
|
| 10 |
+
def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
|
| 13 |
+
self.act = nn.GELU(approximate='none')
|
| 14 |
+
self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
|
| 15 |
+
self.down_proj._is_residual = True
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
return self.down_proj(self.act(self.up_proj(x)))
|
| 19 |
+
|
| 20 |
+
class MPTBlock(nn.Module):
|
| 21 |
+
|
| 22 |
+
def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
|
| 23 |
+
del kwargs
|
| 24 |
+
super().__init__()
|
| 25 |
+
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
|
| 26 |
+
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
|
| 27 |
+
self.norm_1 = norm_class(d_model, device=device)
|
| 28 |
+
self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
|
| 29 |
+
self.norm_2 = norm_class(d_model, device=device)
|
| 30 |
+
self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
|
| 31 |
+
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
|
| 32 |
+
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
|
| 33 |
+
|
| 34 |
+
def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
|
| 35 |
+
a = self.norm_1(x)
|
| 36 |
+
(b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
|
| 37 |
+
x = x + self.resid_attn_dropout(b)
|
| 38 |
+
m = self.norm_2(x)
|
| 39 |
+
n = self.ffn(m)
|
| 40 |
+
x = x + self.resid_ffn_dropout(n)
|
| 41 |
+
return (x, past_key_value)
|