Fix precision error
modeling_chatglm.py  +4 -20  CHANGED

@@ -3,9 +3,7 @@
 import math
 import copy
 import warnings
-import re
 import sys
-import functools
 import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F

@@ -177,14 +175,13 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 
 
 class RMSNorm(torch.nn.Module):
-    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None,
+    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
         self.eps = eps
-        self.quantized = quantized
 
     def forward(self, hidden_states: torch.Tensor):
-        if
+        if hidden_states == torch.bfloat16:
             norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
             x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
             return self.weight * x_normed
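
Note on the RMSNorm hunk above: the commit title points at a precision error, and RMS statistics taken directly in float16/bfloat16 are a common source of one, since squaring and averaging activations in half precision discards mantissa bits. A widely used remedy, shown here only as a minimal sketch and not as the code shipped in modeling_chatglm.py, is to accumulate the mean of squares in float32 and cast back to the activation dtype. The class name RMSNormFP32Stats and the torch.ones weight initialization are assumptions made so the example runs standalone.

import torch


class RMSNormFP32Stats(torch.nn.Module):
    """Illustrative RMSNorm that computes its statistics in float32.

    A sketch of the general precision-safe pattern; not the implementation
    committed in modeling_chatglm.py.
    """

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        # torch.ones instead of torch.empty so the sketch runs without a checkpoint.
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Mean of squares accumulated in float32 to avoid bf16/fp16 rounding error.
        norm_x = hidden_states.float().pow(2).mean(dim=-1, keepdim=True)
        x_normed = hidden_states.float() * torch.rsqrt(norm_x + self.eps)
        # Cast back to the activation dtype before applying the learned scale.
        return self.weight * x_normed.to(input_dtype)


x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
norm = RMSNormFP32Stats(8, dtype=torch.bfloat16)
print(norm(x).dtype)  # torch.bfloat16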

@@ -521,14 +518,7 @@ class GLMBlock(torch.nn.Module):
 
         self.fp32_residual_connection = config.fp32_residual_connection
 
-        if config.rmsnorm:
-            if config.quantization_bit != 0:
-                LayerNormFunc = functools.partial(RMSNorm, quantized=True)
-            else:
-                LayerNormFunc = RMSNorm
-        else:
-            LayerNormFunc = LayerNorm
-
+        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
         self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                              dtype=config.torch_dtype)
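
The hunk above (and the matching one in GLMTransformer below) collapses the nested norm-class selection into a single conditional expression, which is also what makes the functools import removable. For reference, functools.partial in the deleted lines only pre-binds a constructor keyword so the result can still be instantiated like a class; the sketch below illustrates both patterns with toy stand-ins (Norm and OtherNorm are hypothetical helpers, not the file's RMSNorm and LayerNorm).

import functools


class Norm:
    """Toy stand-in used only to illustrate the two construction patterns."""

    def __init__(self, size, eps=1e-5, quantized=False):
        self.size, self.eps, self.quantized = size, eps, quantized


class OtherNorm(Norm):
    """Toy stand-in for the alternative norm class."""


# Removed pattern: partial pre-binds quantized=True, then behaves like a class.
QuantizedNorm = functools.partial(Norm, quantized=True)
assert QuantizedNorm(8).quantized is True

# Pattern kept by the commit: pick the class with a single conditional expression.
rmsnorm = True
NormFunc = Norm if rmsnorm else OtherNorm
assert isinstance(NormFunc(8), Norm) and NormFunc(8).quantized is False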

@@ -606,13 +596,7 @@ class GLMTransformer(torch.nn.Module):
         self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
 
         if self.post_layer_norm:
-            if config.rmsnorm:
-                if config.quantization_bit != 0:
-                    LayerNormFunc = functools.partial(RMSNorm, quantized=True)
-                else:
-                    LayerNormFunc = RMSNorm
-            else:
-                LayerNormFunc = LayerNorm
+            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)
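
Finally, a hedged sketch of how the simplified selection is used at construction time. The SimpleNamespace config is hypothetical and carries only the fields the diff references (rmsnorm, hidden_size, layernorm_epsilon, torch_dtype); with rmsnorm=False the expression falls back to torch.nn.LayerNorm, which takes the same eps/device/dtype keywords as in the diff, and the RMSNorm name is a placeholder for the class defined in modeling_chatglm.py.

from types import SimpleNamespace

import torch
from torch.nn import LayerNorm

# Hypothetical config object with only the fields referenced in the diff.
config = SimpleNamespace(rmsnorm=False, hidden_size=8,
                         layernorm_epsilon=1e-5, torch_dtype=torch.float32)

# Placeholder name so the conditional expression below mirrors the diff verbatim;
# the real RMSNorm lives in modeling_chatglm.py and is never instantiated here
# because config.rmsnorm is False.
RMSNorm = None

LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon,
                                device=None, dtype=config.torch_dtype)

x = torch.randn(2, 3, config.hidden_size, dtype=config.torch_dtype)
print(final_layernorm(x).shape)  # torch.Size([2, 3, 8])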