Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
class Normalize(nn.Module):
    """Reversible (RevIN-style) normalization layer that stores its statistics.

    ``forward(x, 'norm')`` computes statistics of ``x`` over dims
    ``1 .. x.ndim - 2`` (dim 0 and the last dim are kept), stores them on the
    module and returns the normalized tensor.  ``forward(y, 'denorm')`` undoes
    that transform using the statistics stored by the most recent ``'norm'``
    call, so the two calls are meant to be paired.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
        """
        :param num_features: the number of features or channels; size of the
            affine parameters, which broadcast against the last dim of the input
        :param eps: a value added for numerical stability (added to the variance
            before the square root; its square guards the affine division in denorm)
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, centre on the element at index -1 of dim 1
            instead of on the mean
        :param non_norm: if True, the normalize/denormalize steps return their
            input unchanged (statistics are still recorded in 'norm' mode)
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.non_norm = non_norm
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Normalize (``mode='norm'``) or denormalize (``mode='denorm'``) ``x``.

        :param x: input tensor; presumably (batch, time, num_features) — TODO confirm
        :param mode: ``'norm'`` or ``'denorm'``
        :return: the transformed tensor
        :raises NotImplementedError: if ``mode`` is neither ``'norm'`` nor ``'denorm'``
        """
        if mode == 'norm':
            # Record fresh statistics first: _normalize and any later 'denorm' read them.
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            # Same exception type as before (callers may catch it), but with a diagnostic.
            raise NotImplementedError(f"Normalize mode must be 'norm' or 'denorm', got {mode!r}")
        return x

    def _init_params(self):
        """Create the learnable per-channel scale (ones) and shift (zeros)."""
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        """Compute and store the centring/scaling statistics of ``x``.

        Reduces over dims ``1 .. x.ndim - 2`` with ``keepdim=True`` so the stored
        tensors broadcast back against ``x``.  Sets ``self.last`` (when
        ``subtract_last``) or ``self.mean``, and always ``self.stdev``.
        """
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            # Slice index -1 of dim 1 and re-insert dim 1 so it broadcasts.
            # NOTE(review): unlike mean/stdev this is not detached, so gradients
            # flow through it — confirm that is intended.
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        # Population variance (unbiased=False); eps keeps the sqrt away from zero.
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        """Centre, scale and (optionally) apply the affine transform to ``x``."""
        if self.non_norm:
            return x
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        """Invert :meth:`_normalize` using the statistics stored by the last 'norm' call."""
        if self.non_norm:
            return x
        if self.affine:
            x = x - self.affine_bias
            # eps**2 prevents division by zero if a learned weight reaches 0;
            # the inverse is therefore approximate rather than exact.
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x