Spaces:
Runtime error
Runtime error
| """ Activations (memory-efficient w/ custom autograd) | |
| A collection of activation functions and modules with a common interface so that they can | |
| easily be swapped. All have an `inplace` arg even if not used. | |
| These activations are not compatible with jit scripting or ONNX export of the model, please use either | |
| the JIT or basic versions of the activations. | |
| Hacked together by / Copyright 2020 Ross Wightman | |
| """ | |
| import torch | |
| from torch import nn as nn | |
| from torch.nn import functional as F | |
def swish_jit_fwd(x):
    """Swish / SiLU forward pass: ``x * sigmoid(x)``.

    NOTE: despite the ``_jit`` suffix, no ``torch.jit.script`` is applied to
    this function here; it runs as ordinary eager PyTorch.
    """
    gate = torch.sigmoid(x)
    return x * gate
def swish_jit_bwd(x, grad_output):
    """Swish backward pass recomputed from the saved input ``x``.

    Returns ``grad_output`` scaled by the Swish derivative
    ``s * (1 + x * (1 - s))`` where ``s = sigmoid(x)``.
    """
    sig = torch.sigmoid(x)
    local_grad = sig * (1 + x * (1 - sig))
    return grad_output * local_grad
class SwishJitAutoFn(torch.autograd.Function):
    """Memory-efficient Swish (``x * sigmoid(x)``) with a hand-written backward.

    Only the input ``x`` is saved for backward; the gradient is recomputed from
    it by ``swish_jit_bwd``. A ``symbolic`` method is provided so the op can be
    emitted as ONNX ``Mul(x, Sigmoid(x))``.

    Inspired by conversation btw Jeremy Howard & Adam Paszke
    https://twitter.com/jeremyphoward/status/1188251041835315200
    """

    @staticmethod
    def symbolic(g, x):
        # ONNX export graph: Mul(x, Sigmoid(x)).
        return g.op("Mul", x, g.op("Sigmoid", x))

    @staticmethod
    def forward(ctx, x):
        # Save just the input; the backward recomputes sigmoid(x) from it.
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return swish_jit_bwd(x, grad_output)
def swish_me(x, inplace: bool = False):
    """Apply memory-efficient Swish via ``SwishJitAutoFn``.

    Args:
        x: input tensor.
        inplace: accepted only so the signature matches the other
            activations; it has no effect.
    """
    return SwishJitAutoFn.apply(x)
class SwishMe(nn.Module):
    """``nn.Module`` wrapper around the memory-efficient Swish autograd Function.

    ``inplace`` is accepted for interface parity with other activations and is
    ignored. The module holds no parameters.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = SwishJitAutoFn.apply(x)
        return out
def mish_jit_fwd(x):
    """Mish forward pass: ``x * tanh(softplus(x))``.

    NOTE: despite the ``_jit`` suffix, no ``torch.jit.script`` is applied to
    this function here; it runs as ordinary eager PyTorch.
    """
    softplus_x = F.softplus(x)
    return x * torch.tanh(softplus_x)
def mish_jit_bwd(x, grad_output):
    """Mish backward pass recomputed from the saved input ``x``.

    With ``t = tanh(softplus(x))`` and ``s = sigmoid(x)``, the derivative is
    ``t + x * s * (1 - t^2)``; the result is ``grad_output`` times that.
    """
    tanh_sp = torch.tanh(F.softplus(x))
    sig = x.sigmoid()
    return grad_output.mul(tanh_sp + x * sig * (1 - tanh_sp * tanh_sp))
class MishJitAutoFn(torch.autograd.Function):
    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681

    Memory-efficient Mish with a hand-written backward: only the input ``x`` is
    saved, and ``mish_jit_bwd`` recomputes the gradient from it. The helpers
    carry a ``_jit`` suffix, but no ``torch.jit.script`` is applied in this file.
    No ``symbolic`` (ONNX) method is defined for this Function.
    """

    @staticmethod
    def forward(ctx, x):
        # Save just the input; intermediates are recomputed in backward.
        ctx.save_for_backward(x)
        return mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return mish_jit_bwd(x, grad_output)
def mish_me(x, inplace: bool = False):
    """Apply memory-efficient Mish via ``MishJitAutoFn``.

    Args:
        x: input tensor.
        inplace: accepted only so the signature matches the other
            activations; it has no effect.
    """
    return MishJitAutoFn.apply(x)
class MishMe(nn.Module):
    """``nn.Module`` wrapper around the memory-efficient Mish autograd Function.

    ``inplace`` is accepted for interface parity with other activations and is
    ignored. The module holds no parameters.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = MishJitAutoFn.apply(x)
        return out
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
    """Hard-sigmoid forward pass: ``clamp(x + 3, 0, 6) / 6``.

    ``inplace`` is accepted for interface parity and is unused; a new tensor
    is returned.
    """
    clipped = torch.clamp(x + 3, min=0, max=6)
    return clipped / 6.
def hard_sigmoid_jit_bwd(x, grad_output):
    """Hard-sigmoid backward pass.

    The slope is ``1/6`` on the closed interval ``[-3, 3]`` (endpoints
    included) and ``0`` elsewhere; the result is ``grad_output`` times it.
    """
    in_range = (x >= -3.) & (x <= 3.)
    slope = torch.ones_like(x) * in_range / 6.
    return grad_output * slope
class HardSigmoidJitAutoFn(torch.autograd.Function):
    """Memory-efficient hard sigmoid (``clamp(x + 3, 0, 6) / 6``).

    Only the input ``x`` is saved for backward; ``hard_sigmoid_jit_bwd``
    rebuilds the piecewise-constant slope from it. No ``symbolic`` (ONNX)
    method is defined for this Function.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_sigmoid_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_sigmoid_jit_bwd(x, grad_output)
def hard_sigmoid_me(x, inplace: bool = False):
    """Functional memory-efficient hard sigmoid.

    Dispatches to ``HardSigmoidJitAutoFn``. ``inplace`` is accepted only so the
    signature matches the other activations; it has no effect.
    """
    result = HardSigmoidJitAutoFn.apply(x)
    return result
class HardSigmoidMe(nn.Module):
    """``nn.Module`` wrapper around the memory-efficient hard-sigmoid Function.

    ``inplace`` is accepted for interface parity with other activations and is
    ignored. The module holds no parameters.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = HardSigmoidJitAutoFn.apply(x)
        return out
def hard_swish_jit_fwd(x):
    """HardSwish forward pass: ``x * (clamp(x + 3, 0, 6) / 6)``."""
    gate = torch.clamp(x + 3, min=0, max=6) / 6.
    return x * gate
def hard_swish_jit_bwd(x, grad_output):
    """HardSwish backward pass.

    The slope is 0 below -3, ``x / 3 + 0.5`` on ``[-3, 3]``, and 1 above 3.
    At exactly ``x == 3`` the ramp branch wins (slope 1.5), because the
    ``torch.where`` over ``[-3, 3]`` is applied after the ``x >= 3`` mask.
    """
    saturated = torch.ones_like(x) * (x >= 3.)
    on_ramp = (x >= -3.) & (x <= 3.)
    slope = torch.where(on_ramp, x / 3. + .5, saturated)
    return grad_output * slope
class HardSwishJitAutoFn(torch.autograd.Function):
    """Memory-efficient HardSwish (``x * clamp(x + 3, 0, 6) / 6``).

    Only the input ``x`` is saved for backward; ``hard_swish_jit_bwd``
    recomputes the piecewise slope from it. ``symbolic`` emits an equivalent
    ONNX graph so the op can be exported.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_swish_jit_bwd(x, grad_output)

    @staticmethod
    def symbolic(g, self):
        # ``self`` is the graph value of the input tensor (this is a
        # staticmethod, not an instance method).
        # Emits: Mul(self, Div(Clip(Add(self, 3), 0, 6), 6)). Clip receives
        # min/max as tensor inputs (the operand form from ONNX opset 11).
        def const(value):
            return g.op('Constant', value_t=torch.tensor(value, dtype=torch.float))

        shifted = g.op("Add", self, const(3))
        clipped = g.op("Clip", shifted, const(0), const(6))
        scaled = g.op("Div", clipped, const(6))
        return g.op("Mul", self, scaled)
def hard_swish_me(x, inplace: bool = False):
    """Apply memory-efficient HardSwish via ``HardSwishJitAutoFn``.

    Args:
        x: input tensor.
        inplace: accepted only so the signature matches the other
            activations; it has no effect.
    """
    return HardSwishJitAutoFn.apply(x)
class HardSwishMe(nn.Module):
    """``nn.Module`` wrapper around the memory-efficient HardSwish Function.

    ``inplace`` is accepted for interface parity with other activations and is
    ignored. The module holds no parameters.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = HardSwishJitAutoFn.apply(x)
        return out
def hard_mish_jit_fwd(x):
    """Hard-Mish forward pass: ``0.5 * x * clamp(x + 2, 0, 2)``.

    This gives 0 for ``x < -2``, ``0.5 * x * (x + 2)`` on ``[-2, 0]`` and
    ``x`` for ``x > 0``.
    """
    half_x = 0.5 * x
    return half_x * torch.clamp(x + 2, min=0, max=2)
def hard_mish_jit_bwd(x, grad_output):
    """Hard-Mish backward pass.

    The slope is 0 below -2, ``x + 1`` on ``[-2, 0]``, and 1 above 0. The
    ramp branch is applied last, so it decides the endpoints (-1 at ``x == -2``
    and 1 at ``x == 0``).
    """
    active = torch.ones_like(x) * (x >= -2.)
    on_ramp = (x >= -2.) & (x <= 0.)
    slope = torch.where(on_ramp, x + 1., active)
    return grad_output * slope
class HardMishJitAutoFn(torch.autograd.Function):
    """ A memory efficient variant of Hard Mish with a hand-written backward.

    Only the input ``x`` is saved; ``hard_mish_jit_bwd`` recomputes the
    gradient from it. The helpers carry a ``_jit`` suffix, but no
    ``torch.jit.script`` is applied in this file. No ``symbolic`` (ONNX) method
    is defined for this Function.

    Experimental, based on notes by Mish author Diganta Misra at
    https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_mish_jit_bwd(x, grad_output)
def hard_mish_me(x, inplace: bool = False):
    """Functional memory-efficient Hard Mish.

    Dispatches to ``HardMishJitAutoFn``. ``inplace`` is accepted only so the
    signature matches the other activations; it has no effect.
    """
    result = HardMishJitAutoFn.apply(x)
    return result
class HardMishMe(nn.Module):
    """``nn.Module`` wrapper around the memory-efficient Hard Mish Function.

    ``inplace`` is accepted for interface parity with other activations and is
    ignored. The module holds no parameters.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = HardMishJitAutoFn.apply(x)
        return out