Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # | |
| import torch | |
class ScalarBias(torch.autograd.Function):
    """
    Adds a vector of scalars, used in self-attention mechanism to allow
    the model to optionally attend to this vector instead of the past.

    Concretely, the output is ``input`` with one extra slot prepended along
    ``dim``; that slot is filled with the constant ``bias_init`` and the
    original values follow it.
    """

    @staticmethod
    def forward(ctx, input, dim, bias_init):
        """Return ``input`` with a constant ``bias_init`` slot prepended along ``dim``.

        Args:
            ctx: autograd context; ``dim`` is stashed on it for ``backward``.
            input: tensor to extend.
            dim: dimension along which the extra slot is inserted at index 0.
            bias_init: scalar value written into the new slot.

        Returns:
            A new tensor with ``input``'s dtype and device whose size along
            ``dim`` is one larger. Index 0 along ``dim`` holds ``bias_init``;
            indices 1.. hold a copy of ``input``.
        """
        size = list(input.size())
        size[dim] += 1
        # Allocate with input's dtype/device and fill with the bias in one step.
        output = input.new_full(size, bias_init)
        # Everything after the first slot along `dim` receives the original values.
        output.narrow(dim, 1, size[dim] - 1).copy_(input)
        ctx.dim = dim
        return output

    @staticmethod
    def backward(ctx, grad):
        """Route the incoming gradient back to ``input``.

        The prepended slot is a constant, so its gradient is dropped by
        slicing it off along ``ctx.dim``. ``dim`` and ``bias_init`` are
        non-tensor arguments and therefore receive ``None``.
        """
        return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None
def scalar_bias(input, dim, bias_init=0):
    """Prepend a slot holding ``bias_init`` to ``input`` along ``dim``.

    Function-style entry point for the ``ScalarBias`` autograd function;
    the constant slot defaults to 0.
    """
    extended = ScalarBias.apply(input, dim, bias_init)
    return extended