| # Copyright 2022-2025 Xiaomi Corp. (authors: Daniel Povey | |
| # Wei Kang) | |
| # | |
| # See ../../../../LICENSE for clarification regarding multiple authors | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import logging | |
| import math | |
| import random | |
| import sys | |
| from typing import Optional, Tuple, Union | |
| try: | |
| import k2 | |
| except Exception as e: | |
| logging.warning( | |
| f"Failed import k2 with error {e}. Swoosh functions will fallback to PyTorch" | |
| f" implementation, leading to slower speed and higher memory consumption." | |
| ) | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| custom_bwd = lambda func: torch.amp.custom_bwd(func, device_type="cuda") | |
| custom_fwd = lambda func: torch.amp.custom_fwd(func, device_type="cuda") | |
| def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: | |
| max_value = torch.max(x, y) | |
| diff = torch.abs(x - y) | |
| return max_value + torch.log1p(torch.exp(-diff)) | |
| # RuntimeError: Exporting the operator logaddexp to ONNX opset version | |
| # 14 is not supported. Please feel free to request support or submit | |
| # a pull request on PyTorch GitHub. | |
| # | |
# The following function works around the above error when exporting
# models to ONNX via torch.jit.trace()
| def logaddexp(x: Tensor, y: Tensor) -> Tensor: | |
| # Caution(fangjun): Put torch.jit.is_scripting() before | |
| # torch.onnx.is_in_onnx_export(); | |
| # otherwise, it will cause errors for torch.jit.script(). | |
| # | |
| # torch.logaddexp() works for both torch.jit.script() and | |
| # torch.jit.trace() but it causes errors for ONNX export. | |
| # | |
| if torch.jit.is_scripting(): | |
| # Note: We cannot use torch.jit.is_tracing() here as it also | |
| # matches torch.onnx.export(). | |
| return torch.logaddexp(x, y) | |
| elif torch.onnx.is_in_onnx_export(): | |
| return logaddexp_onnx(x, y) | |
| else: | |
| # for torch.jit.trace() | |
| return torch.logaddexp(x, y) | |
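# Illustrative sketch (hypothetical helper, not part of the original file):
# logaddexp_onnx() implements log(exp(x) + exp(y)) = max(x, y) + log1p(exp(-|x - y|)),
# so it should agree closely with torch.logaddexp() for ordinary finite inputs.
def _demo_logaddexp_onnx():
    x = torch.randn(4)
    y = torch.randn(4)
    assert torch.allclose(logaddexp_onnx(x, y), torch.logaddexp(x, y), atol=1.0e-06)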
| class PiecewiseLinear(object): | |
| """ | |
Piecewise linear function, from float to float, specified as a nonempty list of (x,y)
pairs with the x values in increasing order. x values below [initial x] or above
[final x] are mapped to [initial y] and [final y] respectively.
| """ | |
| def __init__(self, *args): | |
| assert len(args) >= 1, len(args) | |
| if len(args) == 1 and isinstance(args[0], PiecewiseLinear): | |
| self.pairs = list(args[0].pairs) | |
| else: | |
| self.pairs = [(float(x), float(y)) for x, y in args] | |
| for x, y in self.pairs: | |
| assert isinstance(x, (float, int)), type(x) | |
| assert isinstance(y, (float, int)), type(y) | |
| for i in range(len(self.pairs) - 1): | |
| assert self.pairs[i + 1][0] > self.pairs[i][0], ( | |
| i, | |
| self.pairs[i], | |
| self.pairs[i + 1], | |
| ) | |
| def __str__(self): | |
| # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' | |
| return f"PiecewiseLinear({str(self.pairs)[1:-1]})" | |
| def __call__(self, x): | |
| if x <= self.pairs[0][0]: | |
| return self.pairs[0][1] | |
| elif x >= self.pairs[-1][0]: | |
| return self.pairs[-1][1] | |
| else: | |
| cur_x, cur_y = self.pairs[0] | |
| for i in range(1, len(self.pairs)): | |
| next_x, next_y = self.pairs[i] | |
| if x >= cur_x and x <= next_x: | |
| return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) | |
| cur_x, cur_y = next_x, next_y | |
| assert False | |
| def __mul__(self, alpha): | |
| return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) | |
| def __add__(self, x): | |
| if isinstance(x, (float, int)): | |
| return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) | |
| s, x = self.get_common_basis(x) | |
| return PiecewiseLinear( | |
| *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] | |
| ) | |
| def max(self, x): | |
| if isinstance(x, (float, int)): | |
| x = PiecewiseLinear((0, x)) | |
| s, x = self.get_common_basis(x, include_crossings=True) | |
| return PiecewiseLinear( | |
| *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] | |
| ) | |
| def min(self, x): | |
| if isinstance(x, float) or isinstance(x, int): | |
| x = PiecewiseLinear((0, x)) | |
| s, x = self.get_common_basis(x, include_crossings=True) | |
| return PiecewiseLinear( | |
| *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] | |
| ) | |
| def __eq__(self, other): | |
| return self.pairs == other.pairs | |
| def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): | |
| """ | |
| Returns (self_mod, p_mod) which are equivalent piecewise linear | |
| functions to self and p, but with the same x values. | |
p: the other piecewise linear function
include_crossings: if true, include in the x values the positions
where the functions represented by self and p cross.
| """ | |
| assert isinstance(p, PiecewiseLinear), type(p) | |
| # get sorted x-values without repetition. | |
| x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) | |
| y_vals1 = [self(x) for x in x_vals] | |
| y_vals2 = [p(x) for x in x_vals] | |
| if include_crossings: | |
| extra_x_vals = [] | |
| for i in range(len(x_vals) - 1): | |
| if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): | |
| # if the two lines in this subsegment potentially cross each other.. | |
| diff_cur = abs(y_vals1[i] - y_vals2[i]) | |
| diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) | |
| # `pos`, between 0 and 1, gives the relative x position, | |
| # with 0 being x_vals[i] and 1 being x_vals[i+1]. | |
| pos = diff_cur / (diff_cur + diff_next) | |
| extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) | |
| extra_x_vals.append(extra_x_val) | |
| if len(extra_x_vals) > 0: | |
| x_vals = sorted(set(x_vals + extra_x_vals)) | |
| y_vals1 = [self(x) for x in x_vals] | |
| y_vals2 = [p(x) for x in x_vals] | |
| return ( | |
| PiecewiseLinear(*zip(x_vals, y_vals1)), | |
| PiecewiseLinear(*zip(x_vals, y_vals2)), | |
| ) | |
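# Illustrative sketch (hypothetical helper, not part of the original file): outside
# the given x range the function is clamped to the first / last y value, and in
# between it interpolates linearly.
def _demo_piecewise_linear_clamping():
    p = PiecewiseLinear((0.0, 1.0), (10.0, 3.0))
    assert p(-5.0) == 1.0  # before the first x: first y
    assert p(15.0) == 3.0  # after the last x: last y
    assert abs(p(5.0) - 2.0) < 1.0e-06  # halfway: linear interpolation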
| class ScheduledFloat(torch.nn.Module): | |
| """ | |
| This object is a torch.nn.Module only because we want it to show up in | |
| [top_level module].modules(); it does not have a working forward() function. | |
| You are supposed to cast it to float, as in, float(parent_module.whatever), and use | |
| it as something like a dropout prob. | |
| It is a floating point value whose value changes depending on the batch count of the | |
| training loop. It is a piecewise linear function where you specify the (x,y) pairs | |
| in sorted order on x; x corresponds to the batch index. For batch-index values | |
| before the first x or after the last x, we just use the first or last y value. | |
| Example: | |
| self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) | |
`default` is used when self.batch_count is not set, when the module is not in
training mode, or in torch.jit scripting or tracing mode.
| """ | |
| def __init__(self, *args, default: float = 0.0): | |
| super().__init__() | |
| # self.batch_count and self.name will be written to in the training loop. | |
| self.batch_count = None | |
| self.name = None | |
| self.default = default | |
| self.schedule = PiecewiseLinear(*args) | |
| def extra_repr(self) -> str: | |
| return ( | |
| f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" | |
| ) | |
| def __float__(self): | |
| batch_count = self.batch_count | |
| if ( | |
| batch_count is None | |
| or not self.training | |
| or torch.jit.is_scripting() | |
| or torch.jit.is_tracing() | |
| ): | |
| return float(self.default) | |
| else: | |
| ans = self.schedule(self.batch_count) | |
| if random.random() < 0.0002: | |
| logging.debug( | |
| f"ScheduledFloat: name={self.name}, " | |
| f"batch_count={self.batch_count}, ans={ans}" | |
| ) | |
| return ans | |
| def __add__(self, x): | |
| if isinstance(x, float) or isinstance(x, int): | |
| return ScheduledFloat(self.schedule + x, default=self.default) | |
| else: | |
| return ScheduledFloat( | |
| self.schedule + x.schedule, default=self.default + x.default | |
| ) | |
| def max(self, x): | |
| if isinstance(x, float) or isinstance(x, int): | |
| return ScheduledFloat(self.schedule.max(x), default=self.default) | |
| else: | |
| return ScheduledFloat( | |
| self.schedule.max(x.schedule), | |
| default=max(self.default, x.default), | |
| ) | |
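# Illustrative usage sketch (hypothetical helper, not part of the original file):
# the training loop is expected to write batch_count; float() then evaluates the
# schedule, falling back to `default` when batch_count is unset or the module is
# not in training mode.
def _demo_scheduled_float():
    dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
    assert float(dropout) == 0.0  # batch_count is None -> default
    dropout.train()
    dropout.batch_count = 2000
    assert abs(float(dropout) - 0.1) < 1.0e-06  # halfway along the schedule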
| FloatLike = Union[float, ScheduledFloat] | |
| class CutoffEstimator: | |
| """ | |
| Estimates cutoffs of an arbitrary numerical quantity such that a specified | |
| proportion of items will be above the cutoff on average. | |
| p is the proportion of items that should be above the cutoff. | |
| """ | |
| def __init__(self, p: float): | |
| self.p = p | |
| # total count of items | |
| self.count = 0 | |
| # total count of items that were above the cutoff | |
| self.count_above = 0 | |
| # initial cutoff value | |
| self.cutoff = 0 | |
| def __call__(self, x: float) -> bool: | |
| """ | |
| Returns true if x is above the cutoff. | |
| """ | |
| ans = x > self.cutoff | |
| self.count += 1 | |
| if ans: | |
| self.count_above += 1 | |
| cur_p = self.count_above / self.count | |
| delta_p = cur_p - self.p | |
| if (delta_p > 0) == ans: | |
| q = abs(delta_p) | |
| self.cutoff = x * q + self.cutoff * (1 - q) | |
| return ans | |
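# Illustrative sketch (hypothetical helper, not part of the original file): with
# p=0.5 the estimated cutoff drifts towards the median of the inputs, so roughly
# half of the items end up above it.
def _demo_cutoff_estimator():
    est = CutoffEstimator(0.5)
    num_above = sum(est(random.random()) for _ in range(10000))
    assert 0.4 < num_above / 10000 < 0.6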
| class SoftmaxFunction(torch.autograd.Function): | |
| """ | |
| Tries to handle half-precision derivatives in a randomized way that should | |
| be more accurate for training than the default behavior. | |
| """ | |
| def forward(ctx, x: Tensor, dim: int): | |
| ans = x.softmax(dim=dim) | |
| # if x dtype is float16, x.softmax() returns a float32 because | |
| # (presumably) that op does not support float16, and autocast | |
| # is enabled. | |
| if torch.is_autocast_enabled(): | |
| ans = ans.to(torch.float16) | |
| ctx.save_for_backward(ans) | |
| ctx.x_dtype = x.dtype | |
| ctx.dim = dim | |
| return ans | |
| def backward(ctx, ans_grad: Tensor): | |
| (ans,) = ctx.saved_tensors | |
| with torch.amp.autocast("cuda", enabled=False): | |
| ans_grad = ans_grad.to(torch.float32) | |
| ans = ans.to(torch.float32) | |
| x_grad = ans_grad * ans | |
| x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) | |
| return x_grad, None | |
| def softmax(x: Tensor, dim: int): | |
| if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): | |
| return x.softmax(dim=dim) | |
| return SoftmaxFunction.apply(x, dim) | |
| class BiasNormFunction(torch.autograd.Function): | |
| # This computes: | |
# scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True)) ** -0.5 * log_scale.exp()
| # return x * scales | |
| # (after unsqueezing the bias), but it does it in a memory-efficient way so that | |
| # it can just store the returned value (chances are, this will also be needed for | |
| # some other reason, related to the next operation, so we can save memory). | |
| def forward( | |
| ctx, | |
| x: Tensor, | |
| bias: Tensor, | |
| log_scale: Tensor, | |
| channel_dim: int, | |
| store_output_for_backprop: bool, | |
| ) -> Tensor: | |
| assert bias.ndim == 1 | |
| if channel_dim < 0: | |
| channel_dim = channel_dim + x.ndim | |
| ctx.store_output_for_backprop = store_output_for_backprop | |
| ctx.channel_dim = channel_dim | |
| for _ in range(channel_dim + 1, x.ndim): | |
| bias = bias.unsqueeze(-1) | |
| scales = ( | |
| torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 | |
| ) * log_scale.exp() | |
| ans = x * scales | |
| ctx.save_for_backward( | |
| ans.detach() if store_output_for_backprop else x, | |
| scales.detach(), | |
| bias.detach(), | |
| log_scale.detach(), | |
| ) | |
| return ans | |
| def backward(ctx, ans_grad: Tensor) -> Tensor: | |
| ans_or_x, scales, bias, log_scale = ctx.saved_tensors | |
| if ctx.store_output_for_backprop: | |
| x = ans_or_x / scales | |
| else: | |
| x = ans_or_x | |
| x = x.detach() | |
| x.requires_grad = True | |
| bias.requires_grad = True | |
| log_scale.requires_grad = True | |
| with torch.enable_grad(): | |
| # recompute scales from x, bias and log_scale. | |
| scales = ( | |
| torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 | |
| ) * log_scale.exp() | |
| ans = x * scales | |
| ans.backward(gradient=ans_grad) | |
| return x.grad, bias.grad.flatten(), log_scale.grad, None, None | |
| class BiasNorm(torch.nn.Module): | |
| """ | |
| This is intended to be a simpler, and hopefully cheaper, replacement for | |
| LayerNorm. The observation this is based on, is that Transformer-type | |
| networks, especially with pre-norm, sometimes seem to set one of the | |
| feature dimensions to a large constant value (e.g. 50), which "defeats" | |
| the LayerNorm because the output magnitude is then not strongly dependent | |
| on the other (useful) features. Presumably the weight and bias of the | |
| LayerNorm are required to allow it to do this. | |
| Instead, we give the BiasNorm a trainable bias that it can use when | |
| computing the scale for normalization. We also give it a (scalar) | |
| trainable scale on the output. | |
| Args: | |
| num_channels: the number of channels, e.g. 512. | |
| channel_dim: the axis/dimension corresponding to the channel, | |
| interpreted as an offset from the input's ndim if negative. | |
| This is NOT the num_channels; it should typically be one of | |
| {-2, -1, 0, 1, 2, 3}. | |
| log_scale: the initial log-scale that we multiply the output by; this | |
| is learnable. | |
| log_scale_min: FloatLike, minimum allowed value of log_scale | |
| log_scale_max: FloatLike, maximum allowed value of log_scale | |
| store_output_for_backprop: only possibly affects memory use; recommend | |
| to set to True if you think the output of this module is more likely | |
| than the input of this module to be required to be stored for the | |
| backprop. | |
| """ | |
| def __init__( | |
| self, | |
| num_channels: int, | |
| channel_dim: int = -1, # CAUTION: see documentation. | |
| log_scale: float = 1.0, | |
| log_scale_min: float = -1.5, | |
| log_scale_max: float = 1.5, | |
| store_output_for_backprop: bool = False, | |
| ) -> None: | |
| super(BiasNorm, self).__init__() | |
| self.num_channels = num_channels | |
| self.channel_dim = channel_dim | |
| self.log_scale = nn.Parameter(torch.tensor(log_scale)) | |
| self.bias = nn.Parameter(torch.zeros(num_channels)) | |
| self.log_scale_min = log_scale_min | |
| self.log_scale_max = log_scale_max | |
| self.store_output_for_backprop = store_output_for_backprop | |
| def forward(self, x: Tensor) -> Tensor: | |
| assert x.shape[self.channel_dim] == self.num_channels | |
| if torch.jit.is_scripting() or torch.jit.is_tracing(): | |
| channel_dim = self.channel_dim | |
| if channel_dim < 0: | |
| channel_dim += x.ndim | |
| bias = self.bias | |
| for _ in range(channel_dim + 1, x.ndim): | |
| bias = bias.unsqueeze(-1) | |
| scales = ( | |
| torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 | |
| ) * self.log_scale.exp() | |
| return x * scales | |
| log_scale = limit_param_value( | |
| self.log_scale, | |
| min=float(self.log_scale_min), | |
| max=float(self.log_scale_max), | |
| training=self.training, | |
| ) | |
| return BiasNormFunction.apply( | |
| x, | |
| self.bias, | |
| log_scale, | |
| self.channel_dim, | |
| self.store_output_for_backprop, | |
| ) | |
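# Illustrative usage sketch (hypothetical helper, not part of the original file):
# BiasNorm is used like LayerNorm over the channel dim; the output keeps the input
# shape and the only learnable parameters are the per-channel bias and the scalar
# log_scale.
def _demo_bias_norm():
    m = BiasNorm(num_channels=64)
    x = torch.randn(10, 20, 64)
    y = m(x)
    assert y.shape == x.shape
    assert sum(p.numel() for p in m.parameters()) == 64 + 1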
| def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: | |
| """ | |
| Behaves like a constructor of a modified version of nn.Linear | |
| that gives an easy way to set the default initial parameter scale. | |
| Args: | |
| Accepts the standard args and kwargs that nn.Linear accepts | |
| e.g. in_features, out_features, bias=False. | |
initial_scale: you can override this if you want to increase
or decrease the initial magnitude of the module's output
(it scales the initial values of the weight and bias).
Another option, if you want to do something like this, is
to re-initialize the parameters.
| """ | |
| ans = nn.Linear(*args, **kwargs) | |
| with torch.no_grad(): | |
| ans.weight[:] *= initial_scale | |
| if ans.bias is not None: | |
| torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) | |
| return ans | |
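# Illustrative sketch (hypothetical helper, not part of the original file):
# initial_scale only rescales the initial parameters; the returned module is an
# ordinary nn.Linear, so its weight matches a same-seed nn.Linear scaled by 0.25.
def _demo_scaled_linear():
    torch.manual_seed(0)
    a = ScaledLinear(256, 256, initial_scale=0.25)
    torch.manual_seed(0)
    b = nn.Linear(256, 256)
    assert isinstance(a, nn.Linear)
    assert torch.allclose(a.weight, 0.25 * b.weight)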
| class BalancerFunction(torch.autograd.Function): | |
| def forward( | |
| ctx, | |
| x: Tensor, | |
| min_mean: float, | |
| max_mean: float, | |
| min_rms: float, | |
| max_rms: float, | |
| grad_scale: float, | |
| channel_dim: int, | |
| ) -> Tensor: | |
| if channel_dim < 0: | |
| channel_dim += x.ndim | |
| ctx.channel_dim = channel_dim | |
| ctx.save_for_backward(x) | |
| ctx.config = ( | |
| min_mean, | |
| max_mean, | |
| min_rms, | |
| max_rms, | |
| grad_scale, | |
| channel_dim, | |
| ) | |
| return x | |
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
| (x,) = ctx.saved_tensors | |
| ( | |
| min_mean, | |
| max_mean, | |
| min_rms, | |
| max_rms, | |
| grad_scale, | |
| channel_dim, | |
| ) = ctx.config | |
| try: | |
| with torch.enable_grad(): | |
| with torch.amp.autocast("cuda", enabled=False): | |
| x = x.to(torch.float32) | |
| x = x.detach() | |
| x.requires_grad = True | |
| mean_dims = [i for i in range(x.ndim) if i != channel_dim] | |
| uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) | |
| mean = x.mean(dim=mean_dims, keepdim=True) | |
| stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() | |
| rms = uncentered_var.clamp(min=1.0e-20).sqrt() | |
| m = mean / stddev | |
| # part of loss that relates to mean / stddev | |
| m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() | |
| # put a much larger scale on the RMS-max-limit loss, so that if both | |
| # it and the m_loss are violated we fix the RMS loss first. | |
| rms_clamped = rms.clamp(min=min_rms, max=max_rms) | |
| r_loss = (rms_clamped / rms).log().abs() | |
| loss = m_loss + r_loss | |
| loss.backward(gradient=torch.ones_like(loss)) | |
| loss_grad = x.grad | |
| loss_grad_rms = ( | |
| (loss_grad**2) | |
| .mean(dim=mean_dims, keepdim=True) | |
| .sqrt() | |
| .clamp(min=1.0e-20) | |
| ) | |
| loss_grad = loss_grad * (grad_scale / loss_grad_rms) | |
| x_grad_float = x_grad.to(torch.float32) | |
| # scale each element of loss_grad by the absolute value of the | |
| # corresponding element of x_grad, which we view as a noisy estimate | |
| # of its magnitude for that (frame and dimension). later we can | |
| # consider factored versions. | |
| x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) | |
| x_grad = x_grad_mod.to(x_grad.dtype) | |
| except Exception as e: | |
| logging.info( | |
| f"Caught exception in Balancer backward: {e}, " | |
| f"size={list(x_grad.shape)}, will continue." | |
| ) | |
| return x_grad, None, None, None, None, None, None | |
| class Balancer(torch.nn.Module): | |
| """ | |
Modifies the backpropped derivatives of a function to try to encourage, for
each channel, that the proportion of positive values lies between min_positive
and max_positive, and that the average absolute value lies between min_abs and
max_abs. It does this by adding a correction term to the gradient whenever these
constraints are violated; the size of the correction, relative to the magnitude
of the existing gradient, is controlled by grad_scale.
| Args: | |
| num_channels: the number of channels | |
| channel_dim: the dimension/axis corresponding to the channel, e.g. | |
| -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. | |
| min_positive: the minimum, per channel, of the proportion of the time | |
| that (x > 0), below which we start to modify the derivatives. | |
| max_positive: the maximum, per channel, of the proportion of the time | |
| that (x > 0), above which we start to modify the derivatives. | |
grad_scale: determines the scale of the gradient-correction term
relative to the RMS of the existing gradient, once any of the
constraints are violated.
| min_abs: the minimum average-absolute-value difference from the mean | |
| value per channel, which we allow, before we start to modify | |
| the derivatives to prevent this. | |
| max_abs: the maximum average-absolute-value difference from the mean | |
| value per channel, which we allow, before we start to modify | |
| the derivatives to prevent this. | |
| prob: determines the minimum probability with which we modify the | |
| gradients for the {min,max}_positive and {min,max}_abs constraints, | |
| on each forward(). This is done randomly to prevent all layers | |
| from doing it at the same time. | |
| """ | |
| def __init__( | |
| self, | |
| num_channels: int, | |
| channel_dim: int, | |
| min_positive: FloatLike = 0.05, | |
| max_positive: FloatLike = 0.95, | |
| min_abs: FloatLike = 0.2, | |
| max_abs: FloatLike = 100.0, | |
| grad_scale: FloatLike = 0.04, | |
| prob: Optional[FloatLike] = None, | |
| ): | |
| super().__init__() | |
| if prob is None: | |
| prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) | |
| self.prob = prob | |
| # 5% of the time we will return and do nothing because memory usage is | |
| # too high. | |
| self.mem_cutoff = CutoffEstimator(0.05) | |
| # actually self.num_channels is no longer needed except for an assertion. | |
| self.num_channels = num_channels | |
| self.channel_dim = channel_dim | |
| self.min_positive = min_positive | |
| self.max_positive = max_positive | |
| self.min_abs = min_abs | |
| self.max_abs = max_abs | |
| self.grad_scale = grad_scale | |
| def forward(self, x: Tensor) -> Tensor: | |
| if ( | |
| torch.jit.is_scripting() | |
| or not x.requires_grad | |
| or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) | |
| ): | |
| return _no_op(x) | |
| prob = float(self.prob) | |
| if random.random() < prob: | |
| # The following inner-functions convert from the way we historically | |
| # specified these limitations, as limits on the absolute value and the | |
| # proportion of positive values, to limits on the RMS value and | |
| # the (mean / stddev). | |
| def _abs_to_rms(x): | |
| # for normally distributed data, if the expected absolute value is x, | |
| # the expected rms value will be sqrt(pi/2) * x. | |
| return 1.25331413732 * x | |
| def _proportion_positive_to_mean(x): | |
| def _atanh(x): | |
| eps = 1.0e-10 | |
| # eps is to prevent crashes if x is exactly 0 or 1. | |
| # we'll just end up returning a fairly large value. | |
| return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 | |
| def _approx_inverse_erf(x): | |
| # 1 / (sqrt(pi) * ln(2)), | |
| # see https://math.stackexchange.com/questions/321569/ | |
| # approximating-the-error-function-erf-by-analytical-functions | |
| # this approximation is extremely crude and gets progressively worse | |
| # for x very close to -1 or +1, but we mostly care about the | |
| # "middle" region | |
| # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, | |
| # and math.erf(0.0407316414078772) = 0.045935330944660666, | |
| # which is pretty close to 0.05. | |
| return 0.8139535143 * _atanh(x) | |
| # first convert x from the range 0..1 to the range -1..1 which the error | |
| # function returns | |
| x = -1 + (2 * x) | |
| return _approx_inverse_erf(x) | |
| min_mean = _proportion_positive_to_mean(float(self.min_positive)) | |
| max_mean = _proportion_positive_to_mean(float(self.max_positive)) | |
| min_rms = _abs_to_rms(float(self.min_abs)) | |
| max_rms = _abs_to_rms(float(self.max_abs)) | |
| grad_scale = float(self.grad_scale) | |
| assert x.shape[self.channel_dim] == self.num_channels | |
| return BalancerFunction.apply( | |
| x, | |
| min_mean, | |
| max_mean, | |
| min_rms, | |
| max_rms, | |
| grad_scale, | |
| self.channel_dim, | |
| ) | |
| else: | |
| return _no_op(x) | |
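# Illustrative check (hypothetical helper, not part of the original file) of the
# constant used in _abs_to_rms() above: for zero-mean Gaussian data the RMS is
# sqrt(pi/2) ~= 1.2533 times the expected absolute value, which is how the
# min_abs/max_abs limits get re-expressed as RMS limits.
def _demo_abs_to_rms_constant():
    z = torch.randn(1000000)
    rms = (z**2).mean().sqrt()
    mean_abs = z.abs().mean()
    assert abs((rms / mean_abs).item() - 1.25331413732) < 0.01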
| def penalize_abs_values_gt( | |
| x: Tensor, limit: float, penalty: float, name: str = None | |
| ) -> Tensor: | |
| """ | |
| Returns x unmodified, but in backprop will put a penalty for the excess of | |
| the absolute values of elements of x over the limit "limit". E.g. if | |
| limit == 10.0, then if x has any values over 10 it will get a penalty. | |
Caution: the value of this penalty will be affected by grad scaling used
in automatic mixed precision training. For the purposes we use this for,
that shouldn't really matter, or may even be helpful; we just use this
to disallow really implausible values of scores to be given to softmax.
The name is used for randomly printed debug info.
| """ | |
| x_sign = x.sign() | |
| over_limit = (x.abs() - limit) > 0 | |
| # The following is a memory efficient way to penalize the absolute values of | |
| # x that's over the limit. (The memory efficiency comes when you think | |
| # about which items torch needs to cache for the autograd, and which ones it | |
| # can throw away). The numerical value of aux_loss as computed here will | |
| # actually be larger than it should be, by limit * over_limit.sum(), but it | |
| # has the same derivative as the real aux_loss which is penalty * (x.abs() - | |
| # limit).relu(). | |
| aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) | |
# note: we don't do sum() here on aux_loss, but it's as if we had done
# sum() due to how with_loss() works.
| x = with_loss(x, aux_loss, name) | |
| # you must use x for something, or this will be ineffective. | |
| return x | |
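# Illustrative sketch (hypothetical helper, not part of the original file): the
# forward value is x itself; in backprop, elements with |x| > limit receive an
# extra gradient of penalty * sign(x), and other elements are untouched.
def _demo_penalize_abs_values_gt():
    x = torch.tensor([0.5, -20.0, 30.0], requires_grad=True)
    y = penalize_abs_values_gt(x, limit=10.0, penalty=1.0)
    assert torch.equal(y, x)
    y.sum().backward()
    # the grad of sum() alone would be 1 everywhere; +/- penalty is added where
    # |x| > limit.
    assert torch.allclose(x.grad, torch.tensor([1.0, 0.0, 2.0]))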
| def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. | |
| if x.ndim == 2: | |
| return x.diag() | |
| else: | |
| (batch, dim, dim) = x.shape | |
| x = x.reshape(batch, dim * dim) | |
| x = x[:, :: dim + 1] | |
| assert x.shape == (batch, dim) | |
| return x | |
| def _whitening_metric(x: Tensor, num_groups: int): | |
| """ | |
| Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of | |
| of the centered feature covariance are the same within each group's covariance | |
| matrix and also between groups. | |
| Args: | |
| x: a Tensor of shape (*, num_channels) | |
| num_groups: the number of groups of channels, a number >=1 that divides | |
| num_channels | |
| Returns: | |
| Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and | |
| greater than 1.0 otherwise. | |
| """ | |
| assert x.dtype != torch.float16 | |
| x = x.reshape(-1, x.shape[-1]) | |
| (num_frames, num_channels) = x.shape | |
| assert num_channels % num_groups == 0 | |
| channels_per_group = num_channels // num_groups | |
| x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) | |
| # x now has shape (num_groups, num_frames, channels_per_group) | |
| # subtract the mean so we use the centered, not uncentered, covariance. | |
# My experience has been that when we "mess with the gradients" like this,
# it's better not to do anything that tries to move the mean around, because
# that can easily cause instability.
| x = x - x.mean(dim=1, keepdim=True) | |
| # x_covar: (num_groups, channels_per_group, channels_per_group) | |
| x_covar = torch.matmul(x.transpose(1, 2), x) | |
| x_covar_mean_diag = _diag(x_covar).mean() | |
| # the following expression is what we'd get if we took the matrix product | |
| # of each covariance and measured the mean of its trace, i.e. | |
| # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). | |
| x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) | |
| # this metric will be >= 1.0; the larger it is, the less 'white' the data was. | |
| metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) | |
| return metric | |
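# Illustrative sketch (hypothetical helper, not part of the original file): for
# approximately white data the metric is close to 1.0, while perfectly correlated
# channels push it up towards the number of channels per group.
def _demo_whitening_metric():
    white = torch.randn(10000, 4)
    assert _whitening_metric(white, num_groups=1).item() < 1.2
    correlated = torch.randn(10000, 1).expand(10000, 4).contiguous()
    assert _whitening_metric(correlated, num_groups=1).item() > 3.0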
| class WhiteningPenaltyFunction(torch.autograd.Function): | |
| def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: | |
| ctx.save_for_backward(x) | |
| ctx.module = module | |
| return x | |
| def backward(ctx, x_grad: Tensor): | |
| (x_orig,) = ctx.saved_tensors | |
| w = ctx.module | |
| try: | |
| with torch.enable_grad(): | |
| with torch.amp.autocast("cuda", enabled=False): | |
| x_detached = x_orig.to(torch.float32).detach() | |
| x_detached.requires_grad = True | |
| metric = _whitening_metric(x_detached, w.num_groups) | |
| if random.random() < 0.005 or __name__ == "__main__": | |
| logging.debug( | |
| f"Whitening: name={w.name}, num_groups={w.num_groups}," | |
| f"num_channels={x_orig.shape[-1]}, " | |
| f"metric={metric.item():.2f}" | |
| f" vs. limit={float(w.whitening_limit)}" | |
| ) | |
| if metric < float(w.whitening_limit): | |
| w.prob = w.min_prob | |
| return x_grad, None | |
| else: | |
| w.prob = w.max_prob | |
| metric.backward() | |
| penalty_grad = x_detached.grad | |
| scale = w.grad_scale * ( | |
| x_grad.to(torch.float32).norm() | |
| / (penalty_grad.norm() + 1.0e-20) | |
| ) | |
| penalty_grad = penalty_grad * scale | |
| return x_grad + penalty_grad.to(x_grad.dtype), None | |
| except Exception as e: | |
| logging.info( | |
| f"Caught exception in Whiten backward: {e}, " | |
| f"size={list(x_grad.shape)}, will continue." | |
| ) | |
| return x_grad, None | |
| class Whiten(nn.Module): | |
| def __init__( | |
| self, | |
| num_groups: int, | |
| whitening_limit: FloatLike, | |
| prob: Union[float, Tuple[float, float]], | |
| grad_scale: FloatLike, | |
| ): | |
| """ | |
| Args: | |
| num_groups: the number of groups to divide the channel dim into before | |
| whitening. We will attempt to make the feature covariance | |
| within each group, after mean subtraction, as "white" as possible, | |
| while having the same trace across all groups. | |
| whitening_limit: a value greater than 1.0, that dictates how much | |
| freedom we have to violate the constraints. 1.0 would mean perfectly | |
| white, with exactly the same trace across groups; larger values | |
| give more freedom. E.g. 2.0. | |
| prob: the probability with which we apply the gradient modification | |
| (also affects the grad scale). May be supplied as a float, | |
| or as a pair (min_prob, max_prob) | |
| grad_scale: determines the scale on the gradient term from this object, | |
| relative to the rest of the gradient on the attention weights. | |
| E.g. 0.02 (you may want to use smaller values than this if prob is large) | |
| """ | |
| super(Whiten, self).__init__() | |
| assert num_groups >= 1 | |
| assert float(whitening_limit) >= 1 | |
| assert grad_scale >= 0 | |
| self.num_groups = num_groups | |
| self.whitening_limit = whitening_limit | |
| self.grad_scale = grad_scale | |
| if isinstance(prob, float): | |
| prob = (prob, prob) | |
| (self.min_prob, self.max_prob) = prob | |
| assert 0 < self.min_prob <= self.max_prob <= 1 | |
| self.prob = self.max_prob | |
| self.name = None # will be set in training loop | |
| def forward(self, x: Tensor) -> Tensor: | |
| """ | |
| In the forward pass, this function just returns the input unmodified. | |
| In the backward pass, it will modify the gradients to ensure that the | |
| distribution in each group has close to (lambda times I) as the covariance | |
| after mean subtraction, with the same lambda across groups. | |
| For whitening_limit > 1, there will be more freedom to violate this | |
| constraint. | |
| Args: | |
| x: the input of shape (*, num_channels) | |
| Returns: | |
| x, unmodified. You should make sure | |
| you use the returned value, or the graph will be freed | |
| and nothing will happen in backprop. | |
| """ | |
| grad_scale = float(self.grad_scale) | |
| if not x.requires_grad or random.random() > self.prob or grad_scale == 0: | |
| return _no_op(x) | |
| else: | |
| return WhiteningPenaltyFunction.apply(x, self) | |
| class WithLoss(torch.autograd.Function): | |
| def forward(ctx, x: Tensor, y: Tensor, name: str): | |
| ctx.y_shape = y.shape | |
| if random.random() < 0.002 and name is not None: | |
| loss_sum = y.sum().item() | |
| logging.debug(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") | |
| return x | |
| def backward(ctx, ans_grad: Tensor): | |
| return ( | |
| ans_grad, | |
| torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), | |
| None, | |
| ) | |
| def with_loss(x, y, name): | |
| # returns x but adds y.sum() to the loss function. | |
| return WithLoss.apply(x, y, name) | |
| class LimitParamValue(torch.autograd.Function): | |
| def forward(ctx, x: Tensor, min: float, max: float): | |
| ctx.save_for_backward(x) | |
| assert max >= min | |
| ctx.min = min | |
| ctx.max = max | |
| return x | |
| def backward(ctx, x_grad: Tensor): | |
| (x,) = ctx.saved_tensors | |
| # where x < ctx.min, ensure all grads are negative (this will tend to make | |
| # x more positive). | |
| x_grad = x_grad * torch.where( | |
| torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 | |
| ) | |
| # where x > ctx.max, ensure all grads are positive (this will tend to make | |
| # x more negative). | |
| x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) | |
| return x_grad, None, None | |
| def limit_param_value( | |
| x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True | |
| ): | |
# You apply this to (typically) an nn.Parameter during training to ensure that its
# elements (mostly) stay within a supplied range. This is done by modifying the
# gradients in backprop.
| # It's not necessary to do this on every batch: do it only some of the time, | |
| # to save a little time. | |
| if training and random.random() < prob: | |
| return LimitParamValue.apply(x, min, max) | |
| else: | |
| return x | |
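# Illustrative sketch (hypothetical helper, not part of the original file): when the
# parameter is below `min`, any gradient component that would push it further down
# has its sign flipped, so plain gradient descent moves the value back into range.
def _demo_limit_param_value():
    p = torch.tensor([-2.0], requires_grad=True)
    y = limit_param_value(p, min=0.0, max=1.0, prob=1.0)
    y.sum().backward()
    # an unmodified gradient of +1 would be subtracted by gradient descent, pushing
    # p even lower; the flipped gradient is -1 instead.
    assert torch.equal(p.grad, torch.tensor([-1.0]))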
| def _no_op(x: Tensor) -> Tensor: | |
| if torch.jit.is_scripting() or torch.jit.is_tracing(): | |
| return x | |
| else: | |
| # a no-op function that will have a node in the autograd graph, | |
| # to avoid certain bugs relating to backward hooks | |
| return x.chunk(1, dim=-1)[0] | |
# An identity module that is more friendly to backward hooks than nn.Identity(); used for diagnostic reasons.
| class Identity(torch.nn.Module): | |
| def __init__(self): | |
| super(Identity, self).__init__() | |
| def forward(self, x): | |
| return _no_op(x) | |
| # Dropout2 is just like normal dropout, except it supports schedules | |
| # on the dropout rates. | |
| class Dropout2(nn.Module): | |
| def __init__(self, p: FloatLike): | |
| super().__init__() | |
| self.p = p | |
| def forward(self, x: Tensor) -> Tensor: | |
| return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) | |
| class MulForDropout3(torch.autograd.Function): | |
| # returns (x * y * alpha) where alpha is a float and y doesn't require | |
| # grad and is zero-or-one. | |
| def forward(ctx, x, y, alpha): | |
| assert not y.requires_grad | |
| ans = x * y * alpha | |
| ctx.save_for_backward(ans) | |
| ctx.alpha = alpha | |
| return ans | |
| def backward(ctx, ans_grad): | |
| (ans,) = ctx.saved_tensors | |
| x_grad = ctx.alpha * ans_grad * (ans != 0) | |
| return x_grad, None, None | |
| # Dropout3 is just like normal dropout, except it supports schedules on the dropout | |
| # rates, and it lets you choose one dimension to share the dropout mask over | |
| class Dropout3(nn.Module): | |
| def __init__(self, p: FloatLike, shared_dim: int): | |
| super().__init__() | |
| self.p = p | |
| self.shared_dim = shared_dim | |
| def forward(self, x: Tensor) -> Tensor: | |
| p = float(self.p) | |
| if not self.training or p == 0: | |
| return _no_op(x) | |
| scale = 1.0 / (1 - p) | |
| rand_shape = list(x.shape) | |
| rand_shape[self.shared_dim] = 1 | |
| mask = torch.rand(*rand_shape, device=x.device) > p | |
| ans = MulForDropout3.apply(x, mask, scale) | |
| return ans | |
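# Illustrative sketch (hypothetical helper, not part of the original file): the
# dropout mask has size 1 along `shared_dim`, so whole slices along that dimension
# are either kept (and rescaled by 1/(1-p)) or zeroed together.
def _demo_dropout3():
    m = Dropout3(p=0.5, shared_dim=-1)
    m.train()
    y = m(torch.ones(100, 8))
    row_dropped = (y == 0.0).all(dim=-1)
    row_kept = (y == 2.0).all(dim=-1)
    assert torch.all(row_dropped | row_kept)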
| class SwooshLFunction(torch.autograd.Function): | |
| """ | |
| swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 | |
| """ | |
| def forward(ctx, x: Tensor) -> Tensor: | |
| requires_grad = x.requires_grad | |
| if x.dtype == torch.float16: | |
| x = x.to(torch.float32) | |
| zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) | |
| coeff = -0.08 | |
| with torch.amp.autocast("cuda", enabled=False): | |
| with torch.enable_grad(): | |
| x = x.detach() | |
| x.requires_grad = True | |
| y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 | |
| if not requires_grad: | |
| return y | |
| y.backward(gradient=torch.ones_like(y)) | |
| grad = x.grad | |
| floor = coeff | |
| ceil = 1.0 + coeff + 0.005 | |
| d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( | |
| grad | |
| ) | |
| if __name__ == "__main__": | |
| # for self-testing only. | |
| assert d_scaled.min() >= 0.0 | |
| assert d_scaled.max() < 256.0 | |
| d_int = d_scaled.to(torch.uint8) | |
| ctx.save_for_backward(d_int) | |
| if x.dtype == torch.float16 or torch.is_autocast_enabled(): | |
| y = y.to(torch.float16) | |
| return y | |
| def backward(ctx, y_grad: Tensor) -> Tensor: | |
| (d,) = ctx.saved_tensors | |
| # the same constants as used in forward pass. | |
| coeff = -0.08 | |
| floor = coeff | |
| ceil = 1.0 + coeff + 0.005 | |
| d = d * ((ceil - floor) / 255.0) + floor | |
| return y_grad * d | |
| class SwooshL(torch.nn.Module): | |
| def forward(self, x: Tensor) -> Tensor: | |
| """Return Swoosh-L activation.""" | |
| if torch.jit.is_scripting() or torch.jit.is_tracing(): | |
| zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) | |
| return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 | |
| return SwooshLFunction.apply(x) | |
| class SwooshLOnnx(torch.nn.Module): | |
| def forward(self, x: Tensor) -> Tensor: | |
| """Return Swoosh-L activation.""" | |
| zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) | |
| return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 | |
| class SwooshRFunction(torch.autograd.Function): | |
| """ | |
| swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 | |
| derivatives are between -0.08 and 0.92. | |
| """ | |
| def forward(ctx, x: Tensor) -> Tensor: | |
| requires_grad = x.requires_grad | |
| if x.dtype == torch.float16: | |
| x = x.to(torch.float32) | |
| zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) | |
| with torch.amp.autocast("cuda", enabled=False): | |
| with torch.enable_grad(): | |
| x = x.detach() | |
| x.requires_grad = True | |
| y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 | |
| if not requires_grad: | |
| return y | |
| y.backward(gradient=torch.ones_like(y)) | |
| grad = x.grad | |
| floor = -0.08 | |
| ceil = 0.925 | |
| d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( | |
| grad | |
| ) | |
| if __name__ == "__main__": | |
| # for self-testing only. | |
| assert d_scaled.min() >= 0.0 | |
| assert d_scaled.max() < 256.0 | |
| d_int = d_scaled.to(torch.uint8) | |
| ctx.save_for_backward(d_int) | |
| if x.dtype == torch.float16 or torch.is_autocast_enabled(): | |
| y = y.to(torch.float16) | |
| return y | |
| def backward(ctx, y_grad: Tensor) -> Tensor: | |
| (d,) = ctx.saved_tensors | |
| # the same constants as used in forward pass. | |
| floor = -0.08 | |
| ceil = 0.925 | |
| d = d * ((ceil - floor) / 255.0) + floor | |
| return y_grad * d | |
| class SwooshR(torch.nn.Module): | |
| def forward(self, x: Tensor) -> Tensor: | |
| """Return Swoosh-R activation.""" | |
| if torch.jit.is_scripting() or torch.jit.is_tracing(): | |
| zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) | |
| return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 | |
| return SwooshRFunction.apply(x) | |
| class SwooshROnnx(torch.nn.Module): | |
| def forward(self, x: Tensor) -> Tensor: | |
| """Return Swoosh-R activation.""" | |
| zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) | |
| return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687 | |
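# Illustrative sketch (hypothetical helper, not part of the original file): the
# custom autograd functions only change how the gradient is stored; their forward
# values should match the stated formulas directly.
def _demo_swoosh_formulas():
    x = torch.randn(100)
    swoosh_l = torch.log1p((x - 4.0).exp()) - 0.08 * x - 0.035
    swoosh_r = torch.log1p((x - 1.0).exp()) - 0.08 * x - 0.313261687
    assert torch.allclose(SwooshL()(x), swoosh_l, atol=1.0e-05)
    assert torch.allclose(SwooshR()(x), swoosh_r, atol=1.0e-05)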
| # simple version of SwooshL that does not redefine the backprop, used in | |
| # ActivationDropoutAndLinearFunction. | |
| def SwooshLForward(x: Tensor): | |
| with torch.amp.autocast("cuda", enabled=False): | |
| x = x.to(torch.float32) | |
| x_offset = x - 4.0 | |
| log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) | |
| log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) | |
| return log_sum - 0.08 * x - 0.035 | |
| # simple version of SwooshR that does not redefine the backprop, used in | |
| # ActivationDropoutAndLinearFunction. | |
| def SwooshRForward(x: Tensor): | |
| with torch.amp.autocast("cuda", enabled=False): | |
| x = x.to(torch.float32) | |
| x_offset = x - 1.0 | |
| log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) | |
| log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) | |
| return log_sum - 0.08 * x - 0.313261687 | |
| class ActivationDropoutAndLinearFunction(torch.autograd.Function): | |
| def forward( | |
| ctx, | |
| x: Tensor, | |
| weight: Tensor, | |
| bias: Optional[Tensor], | |
| activation: str, | |
| dropout_p: float, | |
| dropout_shared_dim: Optional[int], | |
| ): | |
| if dropout_p != 0.0: | |
| dropout_shape = list(x.shape) | |
| if dropout_shared_dim is not None: | |
| dropout_shape[dropout_shared_dim] = 1 | |
| # else it won't be very memory efficient. | |
| dropout_mask = (1.0 / (1.0 - dropout_p)) * ( | |
| torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p | |
| ) | |
| else: | |
| dropout_mask = None | |
| ctx.save_for_backward(x, weight, bias, dropout_mask) | |
| ctx.activation = activation | |
| forward_activation_dict = { | |
| "SwooshL": k2.swoosh_l_forward, | |
| "SwooshR": k2.swoosh_r_forward, | |
| } | |
# this will raise a KeyError if the activation is unrecognized. That is a genuine
# error, so we let it propagate to the user.
| activation_func = forward_activation_dict[activation] | |
| x = activation_func(x) | |
| if dropout_mask is not None: | |
| x = x * dropout_mask | |
| x = torch.nn.functional.linear(x, weight, bias) | |
| return x | |
| def backward(ctx, ans_grad: Tensor): | |
| saved = ctx.saved_tensors | |
| (x, weight, bias, dropout_mask) = saved | |
| forward_and_deriv_activation_dict = { | |
| "SwooshL": k2.swoosh_l_forward_and_deriv, | |
| "SwooshR": k2.swoosh_r_forward_and_deriv, | |
| } | |
# the following line raises a KeyError if the activation is unrecognized.
# That is a genuine error, so we let it propagate to the user.
| func = forward_and_deriv_activation_dict[ctx.activation] | |
| y, func_deriv = func(x) | |
| if dropout_mask is not None: | |
| y = y * dropout_mask | |
| # now compute derivative of y w.r.t. weight and bias.. | |
| # y: (..., in_channels), ans_grad: (..., out_channels), | |
| (out_channels, in_channels) = weight.shape | |
| in_channels = y.shape[-1] | |
| g = ans_grad.reshape(-1, out_channels) | |
| weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) | |
| y_deriv = torch.matmul(ans_grad, weight) | |
| bias_deriv = None if bias is None else g.sum(dim=0) | |
| x_deriv = y_deriv * func_deriv | |
| if dropout_mask is not None: | |
| # order versus func_deriv does not matter | |
| x_deriv = x_deriv * dropout_mask | |
| return x_deriv, weight_deriv, bias_deriv, None, None, None | |
| class ActivationDropoutAndLinear(torch.nn.Module): | |
| """ | |
| This merges an activation function followed by dropout and then a nn.Linear module; | |
| it does so in a memory efficient way so that it only stores the input to the whole | |
| module. If activation == SwooshL and dropout_shared_dim != None, this will be | |
| equivalent to: | |
| nn.Sequential(SwooshL(), | |
| Dropout3(dropout_p, shared_dim=dropout_shared_dim), | |
| ScaledLinear(in_channels, out_channels, bias=bias, | |
| initial_scale=initial_scale)) | |
| If dropout_shared_dim is None, the dropout would be equivalent to | |
| Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout | |
| mask is smaller. | |
| Args: | |
| in_channels: number of input channels, e.g. 256 | |
| out_channels: number of output channels, e.g. 256 | |
| bias: if true, have a bias | |
| activation: the activation function, for now just support SwooshL. | |
| dropout_p: the dropout probability or schedule (happens after nonlinearity). | |
| dropout_shared_dim: the dimension, if any, across which the dropout mask is | |
| shared (e.g. the time dimension). If None, this may be less memory | |
| efficient if there are modules before this one that cache the input | |
| for their backprop (e.g. Balancer or Whiten). | |
| """ | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| out_channels: int, | |
| bias: bool = True, | |
| activation: str = "SwooshL", | |
| dropout_p: FloatLike = 0.0, | |
| dropout_shared_dim: Optional[int] = -1, | |
| initial_scale: float = 1.0, | |
| ): | |
| super().__init__() | |
| # create a temporary module of nn.Linear that we'll steal the | |
| # weights and bias from | |
| l = ScaledLinear( | |
| in_channels, out_channels, bias=bias, initial_scale=initial_scale | |
| ) | |
| self.weight = l.weight | |
| # register_parameter properly handles making it a parameter when l.bias | |
| # is None. I think there is some reason for doing it this way rather | |
| # than just setting it to None but I don't know what it is, maybe | |
| # something to do with exporting the module.. | |
| self.register_parameter("bias", l.bias) | |
| self.activation = activation | |
| self.dropout_p = dropout_p | |
| self.dropout_shared_dim = dropout_shared_dim | |
| def forward(self, x: Tensor): | |
| if ( | |
| torch.jit.is_scripting() | |
| or torch.jit.is_tracing() | |
| or "k2" not in sys.modules | |
| ): | |
| if self.activation == "SwooshL": | |
| x = SwooshLForward(x) | |
| elif self.activation == "SwooshR": | |
| x = SwooshRForward(x) | |
| else: | |
| assert False, self.activation | |
| return torch.nn.functional.linear(x, self.weight, self.bias) | |
| return ActivationDropoutAndLinearFunction.apply( | |
| x, | |
| self.weight, | |
| self.bias, | |
| self.activation, | |
| float(self.dropout_p), | |
| self.dropout_shared_dim, | |
| ) | |
| def _test_whiten(): | |
| for proportion in [0.1, 0.5, 10.0]: | |
| logging.info(f"_test_whiten(): proportion = {proportion}") | |
| x = torch.randn(100, 128) | |
| direction = torch.randn(128) | |
| coeffs = torch.randn(100, 1) | |
| x += proportion * direction * coeffs | |
| x.requires_grad = True | |
| m = Whiten( | |
num_groups=1, whitening_limit=5.0, prob=1.0, grad_scale=0.1
)
| for _ in range(4): | |
| y = m(x) | |
| y_grad = torch.randn_like(x) | |
| y.backward(gradient=y_grad) | |
| if proportion < 0.2: | |
| assert torch.allclose(x.grad, y_grad) | |
| elif proportion > 1.0: | |
| assert not torch.allclose(x.grad, y_grad) | |
| def _test_balancer_sign(): | |
| probs = torch.arange(0, 1, 0.01) | |
| N = 1000 | |
| x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) | |
| x = x.detach() | |
| x.requires_grad = True | |
| m = Balancer( | |
| probs.numel(), | |
| channel_dim=0, | |
| min_positive=0.05, | |
| max_positive=0.95, | |
| min_abs=0.0, | |
| prob=1.0, | |
| ) | |
| y_grad = torch.sign(torch.randn(probs.numel(), N)) | |
| y = m(x) | |
| y.backward(gradient=y_grad) | |
| print("_test_balancer_sign: x = ", x) | |
| print("_test_balancer_sign: y grad = ", y_grad) | |
| print("_test_balancer_sign: x grad = ", x.grad) | |
| def _test_balancer_magnitude(): | |
| magnitudes = torch.arange(0, 1, 0.01) | |
| N = 1000 | |
| x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) | |
| x = x.detach() | |
| x.requires_grad = True | |
| m = Balancer( | |
| magnitudes.numel(), | |
| channel_dim=0, | |
| min_positive=0.0, | |
| max_positive=1.0, | |
| min_abs=0.2, | |
| max_abs=0.7, | |
| prob=1.0, | |
| ) | |
| y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) | |
| y = m(x) | |
| y.backward(gradient=y_grad) | |
| print("_test_balancer_magnitude: x = ", x) | |
| print("_test_balancer_magnitude: y grad = ", y_grad) | |
| print("_test_balancer_magnitude: x grad = ", x.grad) | |
| def _test_swooshl_deriv(): | |
| x = torch.randn(10, 12, dtype=torch.double) * 3.0 | |
| x.requires_grad = True | |
| m = SwooshL() | |
| tol = 1.0 / 255.0 | |
| torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) | |
| # for self-test. | |
| x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 | |
| x.requires_grad = True | |
| y = m(x) | |
| return y | |
| def _test_swooshr_deriv(): | |
| x = torch.randn(10, 12, dtype=torch.double) * 3.0 | |
| x.requires_grad = True | |
| m = SwooshR() | |
| tol = 1.0 / 255.0 | |
| torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) | |
| # for self-test. | |
| x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 | |
| x.requires_grad = True | |
| y = m(x) | |
| return y | |
| def _test_softmax(): | |
| a = torch.randn(2, 10, dtype=torch.float64) | |
| b = a.clone() | |
| a.requires_grad = True | |
| b.requires_grad = True | |
| a.softmax(dim=1)[:, 0].sum().backward() | |
| print("a grad = ", a.grad) | |
| softmax(b, dim=1)[:, 0].sum().backward() | |
| print("b grad = ", b.grad) | |
| assert torch.allclose(a.grad, b.grad) | |
| def _test_piecewise_linear(): | |
| p = PiecewiseLinear((0, 10.0)) | |
| for x in [-100, 0, 100]: | |
| assert p(x) == 10.0 | |
| p = PiecewiseLinear((0, 10.0), (1, 0.0)) | |
| for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: | |
| print("x, y = ", x, y) | |
| assert p(x) == y, (x, p(x), y) | |
| q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) | |
| x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] | |
| pq = p.max(q) | |
| for x in x_vals: | |
| y1 = max(p(x), q(x)) | |
| y2 = pq(x) | |
| assert abs(y1 - y2) < 0.001 | |
| pq = p.min(q) | |
| for x in x_vals: | |
| y1 = min(p(x), q(x)) | |
| y2 = pq(x) | |
| assert abs(y1 - y2) < 0.001 | |
| pq = p + q | |
| for x in x_vals: | |
| y1 = p(x) + q(x) | |
| y2 = pq(x) | |
| assert abs(y1 - y2) < 0.001 | |
| def _test_activation_dropout_and_linear(): | |
| in_channels = 20 | |
| out_channels = 30 | |
| for bias in [True, False]: | |
# actually we don't test for dropout_p != 0.0 because the forward functions will
# give different answers. This is because we are using the k2 implementation of
# swoosh_l and swoosh_r inside SwooshL() and SwooshR(), and they call randn()
# internally, messing up the random state.
| for dropout_p in [0.0]: | |
| for activation in ["SwooshL", "SwooshR"]: | |
| m1 = nn.Sequential( | |
| SwooshL() if activation == "SwooshL" else SwooshR(), | |
| Dropout3(p=dropout_p, shared_dim=-1), | |
| ScaledLinear( | |
| in_channels, out_channels, bias=bias, initial_scale=0.5 | |
| ), | |
| ) | |
| m2 = ActivationDropoutAndLinear( | |
| in_channels, | |
| out_channels, | |
| bias=bias, | |
| initial_scale=0.5, | |
| activation=activation, | |
| dropout_p=dropout_p, | |
| ) | |
| with torch.no_grad(): | |
| m2.weight[:] = m1[2].weight | |
| if bias: | |
| m2.bias[:] = m1[2].bias | |
| # make sure forward gives same result. | |
| x1 = torch.randn(10, in_channels) | |
| x1.requires_grad = True | |
| # TEMP. | |
| assert torch.allclose( | |
| SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 | |
| ) | |
| x2 = x1.clone().detach() | |
| x2.requires_grad = True | |
| seed = 10 | |
| torch.manual_seed(seed) | |
| y1 = m1(x1) | |
| y_grad = torch.randn_like(y1) | |
| y1.backward(gradient=y_grad) | |
| torch.manual_seed(seed) | |
| y2 = m2(x2) | |
| y2.backward(gradient=y_grad) | |
| print( | |
| f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" | |
| ) | |
| print("y1 = ", y1) | |
| print("y2 = ", y2) | |
| assert torch.allclose(y1, y2, atol=0.02) | |
| assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) | |
| if bias: | |
| assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) | |
| print("x1.grad = ", x1.grad) | |
| print("x2.grad = ", x2.grad) | |
| def isclose(a, b): | |
| # return true if cosine similarity is > 0.9. | |
| return (a * b).sum() > 0.9 * ( | |
| (a**2).sum() * (b**2).sum() | |
| ).sqrt() | |
| # the SwooshL() implementation has a noisy gradient due to 1-byte | |
| # storage of it. | |
| assert isclose(x1.grad, x2.grad) | |
| if __name__ == "__main__": | |
| logging.getLogger().setLevel(logging.DEBUG) | |
| torch.set_num_threads(1) | |
| torch.set_num_interop_threads(1) | |
| _test_piecewise_linear() | |
| _test_softmax() | |
| _test_whiten() | |
| _test_balancer_sign() | |
| _test_balancer_magnitude() | |
| _test_swooshr_deriv() | |
| _test_swooshl_deriv() | |
| _test_activation_dropout_and_linear() | |