import torch
from torch import nn

from .layer_norm import rms_norm_fn


class LlamaRMSNorm(nn.Module):
    """
    RMS Layer Norm for Llama models.

    Triton-optimized RMS layer norm. The interface is compatible with `LlamaRMSNorm` in
    `transformers`.

    Attributes:
        weight (`torch.Tensor`): The learnable scaling parameter.
        variance_epsilon (`float`): The epsilon value for numerical stability.
    """

    weight: torch.Tensor
    variance_epsilon: float
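
    # NOTE (assumption): no `__init__` is defined on purpose; `weight` and
    # `variance_epsilon` are declared as annotations only and are expected to be
    # present on the instance, e.g. when this class stands in for an existing
    # `transformers` `LlamaRMSNorm` module that already carries both attributes.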
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization to the input hidden states.

        Args:
            hidden_states (`torch.Tensor`):
                Input tensor of shape `(batch_size, sequence_length, hidden_size)` or any shape
                where the last dimension is the feature dimension to be normalized.

        Returns:
            `torch.Tensor`:
                The normalized tensor with the same shape as the input `hidden_states`.
        """
        return rms_norm_fn(
            hidden_states,
            self.weight,
            bias=None,
            residual=None,
            eps=self.variance_epsilon,
            dropout_p=0.0,
            prenorm=False,
            residual_in_fp32=False,
        )


__all__ = ["LlamaRMSNorm"]
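

if __name__ == "__main__":
    # Minimal usage sketch, not part of the public API. Assumptions: a CUDA
    # device and the Triton `rms_norm_fn` kernel are available, and this file
    # is run as part of its package (e.g. `python -m <package>.<module>`) so
    # the relative import above resolves; `hidden_size` below is arbitrary.
    hidden_size = 16
    norm = LlamaRMSNorm()  # nn.Module.__init__ is inherited; no arguments needed
    norm.weight = nn.Parameter(
        torch.ones(hidden_size, device="cuda", dtype=torch.float16)
    )
    norm.variance_epsilon = 1e-6

    x = torch.randn(2, 4, hidden_size, device="cuda", dtype=torch.float16)
    y = norm(x)
    assert y.shape == x.shape  # normalization preserves the input shape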