Upload norm.py with huggingface_hub
norm.py ADDED
    
         @@ -0,0 +1,56 @@ 
import torch

def _cast_if_autocast_enabled(tensor):
    # Downcast to the active autocast dtype for the tensor's device type.
    if torch.is_autocast_enabled():
        if tensor.device.type == 'cuda':
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == 'cpu':
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor

class LPLayerNorm(torch.nn.LayerNorm):
    # LayerNorm that runs in low precision under autocast: the input and the
    # parameters are cast to the autocast dtype, then layer_norm is executed
    # with autocast disabled so it stays in that dtype.

    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
        super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)

    def forward(self, x):
        module_device = x.device
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
        downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
        with torch.autocast(enabled=False, device_type=module_device.type):
            return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)

def rms_norm(x, weight=None, eps=1e-05):
    # Normalize by the root mean square of the last dimension: x * rsqrt(mean(x^2) + eps).
    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        return output * weight
    return output

class RMSNorm(torch.nn.Module):

    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
        super().__init__()
        self.eps = eps
        if weight:
            self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
        else:
            self.register_parameter('weight', None)

    def forward(self, x):
        # Compute in float32 for numerical stability, then cast back to the input dtype.
        return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)

class LPRMSNorm(RMSNorm):
    # Low-precision RMSNorm, mirroring LPLayerNorm's autocast handling.

    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
        super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)

    def forward(self, x):
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
        with torch.autocast(enabled=False, device_type=x.device.type):
            return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)

NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
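
For context, a minimal usage sketch (not part of the uploaded file) showing how NORM_CLASS_REGISTRY can be consumed. The import path `norm`, the registry key chosen here, and `normalized_shape=512` are illustrative assumptions:

import torch
from norm import NORM_CLASS_REGISTRY  # assumes norm.py is importable as `norm`

# Look up an implementation by name; it constructs like a regular LayerNorm.
norm_cls = NORM_CLASS_REGISTRY['low_precision_layernorm']
layer = norm_cls(normalized_shape=512)

x = torch.randn(2, 16, 512)
if torch.cuda.is_available():
    layer, x = layer.cuda(), x.cuda()
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        # LPLayerNorm downcasts x and its parameters to the autocast dtype, then
        # runs layer_norm with autocast disabled, so the op executes in bfloat16.
        y = layer(x)
else:
    # Without autocast the cast helper is a no-op and the norm runs in float32.
    y = layer(x)
print(y.shape, y.dtype)

The low-precision variants matter because autocast ordinarily promotes layer_norm to float32; casting the operands explicitly and then disabling autocast keeps the op in the half-precision dtype.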