Spaces:
Runtime error
Runtime error
| import torch | |
class RMSNorm(torch.nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.

    Computes ``weight * x * rsqrt(mean(x * x, dim) + eps)``. The statistic is
    computed in float32 and the result is cast back to the input dtype.

    Args:
        size: Length of the learnable scale ``weight``; expected to equal
            ``x.shape[dim]``.
        dim: Axis over which the mean square is taken and along which
            ``weight`` is applied.
        eps: Added to the mean square inside the reciprocal square root for
            numerical stability.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` along ``self.dim`` and apply the learned scale.

        Returns a tensor with the same shape and dtype as ``x``.
        """
        dtype = x.dtype
        # Compute the statistic in float32 regardless of the input dtype.
        x = x.float()
        # NOTE: the original RMSNorm paper implementation is not equivalent;
        # here eps is added to the mean square inside the rsqrt.
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        weight = self.weight
        if x.ndim > 0:
            # Align the 1-D weight with the normalized axis. A bare (size,)
            # weight always broadcasts against the LAST axis, which is only
            # correct for dim=-1. For that default this reshape is a no-op
            # numerically.
            shape = [1] * x.ndim
            shape[self.dim] = -1
            weight = weight.reshape(shape)
        return (weight * x_normed).to(dtype=dtype)

    def reset_parameters(self) -> None:
        """Re-initialize the scale ``weight`` to ones."""
        torch.nn.init.ones_(self.weight)