| """Custom layers for the transformer model.""" | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Tuple | |
| import warnings | |
| warnings.filterwarnings("ignore") | |


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states):
        # Compute the statistics in float32 for numerical stability,
        # then cast back to the input dtype.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(input_dtype)
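

# Illustrative only: a minimal smoke test for RMSNorm (not part of the model).
# It relies on the module-level imports above; the tensor shapes are arbitrary examples.
def _demo_rmsnorm():
    norm = RMSNorm(hidden_size=64)
    x = torch.randn(2, 8, 64)  # (batch, seq_len, hidden_size)
    y = norm(x)
    # With the default weight of ones, each hidden vector has roughly unit RMS.
    rms = y.float().pow(2).mean(-1).sqrt()
    print(y.shape, rms.mean().item())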


class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE)."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies, one per pair of head dimensions.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build cached cos/sin tables up to the maximum position.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Duplicate the frequencies so cos/sin cover the full head dimension.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None) -> Tuple[torch.Tensor, torch.Tensor]:
        if seq_len is None:
            seq_len = x.shape[-2]
        # Grow the cache if a longer sequence is requested.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotate half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply the rotary embedding to the query and key tensors."""
    # cos/sin: (seq_len, head_dim) -> (batch, 1, seq_len, head_dim) so they
    # broadcast over the head dimension of q and k.
    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
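

# Illustrative only: a minimal sketch of how the rotary embedding is used
# (not part of the model). Shapes follow the (batch, heads, seq_len, head_dim)
# convention assumed by apply_rotary_pos_emb above; sizes are arbitrary examples.
def _demo_rotary():
    batch, heads, seq_len, head_dim = 2, 4, 16, 32
    rope = RotaryEmbedding(dim=head_dim)
    q = torch.randn(batch, heads, seq_len, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)
    cos, sin = rope(q, seq_len=seq_len)
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    print(q_rot.shape, k_rot.shape)  # both (2, 4, 16, 32)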


class SwiGLU(nn.Module):
    """SwiGLU gated feed-forward block."""

    def __init__(self, hidden_size, intermediate_size, hidden_act="silu"):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = F.silu if hidden_act == "silu" else F.gelu

    def forward(self, x):
        # down_proj( act(gate_proj(x)) * up_proj(x) )
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
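

# Illustrative only: a quick shape check for the SwiGLU block when the module
# is run directly (not part of the model). Sizes are arbitrary examples.
if __name__ == "__main__":
    ffn = SwiGLU(hidden_size=64, intermediate_size=256)
    x = torch.randn(2, 8, 64)  # (batch, seq_len, hidden_size)
    out = ffn(x)
    print(out.shape)  # torch.Size([2, 8, 64])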