import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from src.utils import LlamaRotaryEmbedding, repeat_kv
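
# Note: LlamaRotaryEmbedding and repeat_kv live in src.utils (not shown here).
# As used below, LlamaRotaryEmbedding(dim, max_seq_len) is assumed to apply rotary
# position embeddings (RoPE) to the query/key tensors and return them, and
# repeat_kv(x, n_rep) is assumed to repeat each key/value head n_rep times so the
# KV heads line up with the query heads for grouped-query attention (GQA).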


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Root Mean Square Layer Normalization
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight
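
# RMSNorm rescales by the root-mean-square of the features without mean-centering:
#   y = x / sqrt(mean(x_i^2) + eps) * weight
# (unlike LayerNorm, there is no mean subtraction and no bias term).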


class Attention(nn.Module):
    """Multi-head attention module with support for GQA (Grouped Query Attention)."""

    def __init__(self, config):
        super(Attention, self).__init__()
        self.emb_dim = config.emb_dim
        self.n_q_heads = config.n_q_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = self.emb_dim // self.n_q_heads
        self.n_rep = self.n_q_heads // self.n_kv_heads
        # Projections for Q, K, V & O
        self.q_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)
        self.k_proj = nn.Linear(
            self.emb_dim, self.head_dim * self.n_kv_heads, bias=False
        )
        self.v_proj = nn.Linear(
            self.emb_dim, self.head_dim * self.n_kv_heads, bias=False
        )
        self.o_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)
        # Initialize rotary embeddings
        self.rotary_embedding = LlamaRotaryEmbedding(
            dim=self.head_dim, max_seq_len=config.max_seq_len
        )
        # Dropout layers
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # Causal mask
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len)).view(
                1, 1, config.max_seq_len, config.max_seq_len
            ),
        )

    def forward(self, x):
        B, T, C = x.size()  # batch_size, seq_len, emb_dim
        # Project Q, K, V
        q = self.q_proj(x)  # (B, T, emb_dim)
        k = self.k_proj(x)  # (B, T, n_kv_heads * head_dim)
        v = self.v_proj(x)  # (B, T, n_kv_heads * head_dim)
        # Reshape Q, K, V
        q = q.view(B, T, self.n_q_heads, self.head_dim)  # (B, T, n_q_heads, head_dim)
        k = k.view(B, T, self.n_kv_heads, self.head_dim)  # (B, T, n_kv_heads, head_dim)
        v = v.view(B, T, self.n_kv_heads, self.head_dim)  # (B, T, n_kv_heads, head_dim)
        # Reshape for attention computation
        q = q.transpose(1, 2)  # (B, n_q_heads, T, head_dim)
        k = k.transpose(1, 2)  # (B, n_kv_heads, T, head_dim)
        v = v.transpose(1, 2)  # (B, n_kv_heads, T, head_dim)
        # Apply rotary embeddings
        q, k = self.rotary_embedding(q, k)
        # Repeat K and V for GQA
        k = repeat_kv(k, self.n_rep)  # (B, n_q_heads, T, head_dim)
        v = repeat_kv(v, self.n_rep)  # (B, n_q_heads, T, head_dim)
        # Compute attention scores
        scale = 1.0 / math.sqrt(self.head_dim)
        att = (q @ k.transpose(-2, -1)) * scale  # (B, n_q_heads, T, T)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        # Apply attention to values
        y = att @ v  # (B, n_q_heads, T, head_dim)
        # Reshape and project output
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, emb_dim)
        y = self.o_proj(y)
        y = self.resid_dropout(y)
        return y
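
# For reference, a sketch of what src.utils.repeat_kv is assumed to do (the standard
# Llama-style GQA helper): each of the n_kv_heads key/value heads is duplicated n_rep
# times along the head dimension so K and V match the n_q_heads query heads.
# The name _reference_repeat_kv is hypothetical and included only for illustration.
def _reference_repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: (B, n_kv_heads, T, head_dim) -> (B, n_kv_heads * n_rep, T, head_dim)
    B, n_kv_heads, T, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, None, :, :].expand(B, n_kv_heads, n_rep, T, head_dim)
    return x.reshape(B, n_kv_heads * n_rep, T, head_dim)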


class FeedForward(nn.Module):
    """Feed-forward module with SiLU activation."""

    def __init__(self, config):
        super(FeedForward, self).__init__()
        # Gate and up-projections project from hidden_size to intermediate_size
        self.gate_proj = nn.Linear(config.emb_dim, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.emb_dim, config.intermediate_size, bias=False)
        # Down projection brings the dimension back to hidden_size
        self.down_proj = nn.Linear(config.intermediate_size, config.emb_dim, bias=False)
        # SiLU activation function
        self.act_fn = F.silu
        # Dropout layer
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Apply gate and up projections
        gate_output = self.act_fn(self.gate_proj(x))  # SiLU activation
        up_output = self.up_proj(x)
        # Element-wise multiplication of gate and up projections
        intermediate_output = gate_output * up_output
        # Project back to hidden size
        output = self.down_proj(intermediate_output)
        output = self.dropout(output)
        return output
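
# This is the SwiGLU-style gated MLP used in Llama-family models:
#   output = down_proj(SiLU(gate_proj(x)) * up_proj(x))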


class TransformerBlock(nn.Module):
    """Transformer block with attention and feed-forward modules."""

    def __init__(self, config):
        super(TransformerBlock, self).__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.input_layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)
        self.attention_layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)

    def forward(self, x):
        x = x + self.attention(self.input_layernorm(x))
        x = x + self.feed_forward(self.attention_layernorm(x))
        return x
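
# Pre-norm residual layout: each sublayer normalizes its input with RMSNorm and adds
# its output back to the residual stream (x = x + sublayer(norm(x))).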


class SmolLM(nn.Module):
    """Small language model with transformer blocks."""

    def __init__(self, config):
        super(SmolLM, self).__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.emb_dim)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_layers)]
        )
        self.layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)
        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        self.apply(self._init_weights)
        # Weight sharing between the token embedding and the output head
        self.lm_head.weight = self.wte.weight

    def total_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x):
        x = self.wte(x)
        for block in self.transformer_blocks:
            x = block(x)
        x = self.layernorm(x)
        logits = self.lm_head(x)
        return logits
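
# Weight tying shares one (vocab_size x emb_dim) matrix between wte and lm_head, so
# the shared tensor is counted once by total_params() (parameters() deduplicates it).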


# @dataclass
# class Config:
#     vocab_size: int = 49152
#     emb_dim: int = 576
#     intermediate_size: int = 1536
#     num_layers: int = 10
#     n_q_heads: int = 9
#     n_kv_heads: int = 3
#     max_seq_len: int = 8192
#     dropout: float = 0.1
#     rms_norm_eps: float = 1e-05
#     init_std: float = 0.041666666666666664
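

# Minimal usage sketch (assumption): DemoConfig below mirrors the commented-out Config
# above (note init_std = 0.041666... = 1 / sqrt(emb_dim) = 1 / 24), except max_seq_len
# is reduced so the causal-mask buffer stays small for a quick forward-pass check.
# It also assumes src.utils provides the LlamaRotaryEmbedding and repeat_kv imported above.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class DemoConfig:  # hypothetical name, for illustration only
        vocab_size: int = 49152
        emb_dim: int = 576
        intermediate_size: int = 1536
        num_layers: int = 10
        n_q_heads: int = 9
        n_kv_heads: int = 3
        max_seq_len: int = 256  # reduced from 8192 for a lightweight smoke test
        dropout: float = 0.1
        rms_norm_eps: float = 1e-05
        init_std: float = 0.041666666666666664  # 1 / sqrt(emb_dim)

    config = DemoConfig()
    model = SmolLM(config)
    print(f"Total trainable parameters: {model.total_params():,}")

    tokens = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq_len)
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 49152])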