Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.autograd import Function | |
| from torch.cuda.amp import custom_bwd, custom_fwd | |
class _trunc_exp(Function):
    """Exponential whose gradient is computed from a clamped input.

    The forward pass returns ``exp(x)`` unchanged. The backward pass multiplies
    the incoming gradient by ``exp(clamp(x, -15, 15))`` instead of ``exp(x)``,
    so the local derivative never exceeds ``e**15``. Large inputs therefore
    cannot produce inf/NaN gradients, even though the forward value itself is
    not truncated.

    ``forward`` and ``backward`` must be static methods: ``Function.apply``
    rejects legacy autograd functions that define them as instance methods.
    """

    @staticmethod
    # When CUDA autocast is active, cast inputs to float32 and run forward with
    # autocast disabled: exp overflows float16 for inputs above roughly 11.
    # NOTE(review): newer torch versions prefer
    # torch.amp.custom_fwd(device_type="cuda", ...); the torch.cuda.amp
    # spelling is kept because that is what this file imports.
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):
        """Return ``exp(x)`` and save ``x`` for the backward pass."""
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd  # run backward under the same autocast state as forward
    def backward(ctx, g):
        """Return ``g * exp(clamp(x, -15, 15))`` as the gradient w.r.t. ``x``."""
        (x,) = ctx.saved_tensors
        return g * torch.exp(x.clamp(-15, 15))


# Public functional entry point: ``trunc_exp(x)`` runs the custom autograd op.
trunc_exp = _trunc_exp.apply