# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import lightconv_cuda
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
from torch import nn
from torch.autograd import Function


class lightconvFunction(Function):
    @staticmethod
    def forward(ctx, x, weights, padding_l):
        ctx.padding_l = padding_l
        outputs = lightconv_cuda.forward(x, weights, padding_l)
        variables = [x, weights]
        ctx.save_for_backward(*variables)
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        outputs = lightconv_cuda.backward(
            grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors
        )
        grad_input, grad_weights = outputs
        # padding_l is a plain int, so its gradient slot is None.
        return grad_input, grad_weights, None
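
# Note: lightconvFunction.apply(x, weight, padding_l) is called by
# LightconvLayer.forward() below with x in (B, C, T) layout and weight of
# shape (num_heads, kernel_size); see the training branch of forward().
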
@with_incremental_state
class LightconvLayer(nn.Module):
    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding_l=None,
        weight_softmax=False,
        num_heads=1,
        weight_dropout=0.0,
        bias=False,
    ):
        super(LightconvLayer, self).__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.padding_l = padding_l
        self.num_heads = num_heads
        self.weight_softmax = weight_softmax
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
        )

        # One convolution filter of length kernel_size per head, shared across
        # the channels within that head.
        self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(input_size))
        else:
            self.bias = None
        self.reset_parameters()
    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        for k, v in state_dict.items():
            if k.endswith(prefix + "weight"):
                if v.dim() == 3 and v.size(1) == 1:
                    state_dict[k] = v.squeeze(1)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)
    def forward(self, x, incremental_state=None):
        # During inference, the incremental BMM path is faster than the CUDA kernel.
        if incremental_state is not None:
            T, B, C = x.size()
            K, H = self.kernel_size, self.num_heads
            R = C // H
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = x.new()
            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            if self.kernel_size > 1:
                # Keep only the most recent K-1 time steps for the next call.
                self._set_input_buffer(
                    incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :]
                )
            x_unfold = x_unfold.view(T * B * H, R, -1)

            weight = self.weight
            if self.weight_softmax:
                weight = F.softmax(weight.float(), dim=1).type_as(weight)

            weight = weight[:, -x_unfold.size(2) :]
            K = weight.size(1)

            weight = (
                weight.view(1, H, K)
                .expand(T * B, H, K)
                .contiguous()
                .view(T * B * H, K, 1)
            )

            weight = self.weight_dropout_module(weight)
            output = torch.bmm(x_unfold, weight)  # T*B*H x R x 1
            output = output.view(T, B, C)
            return output

        # During training, use the CUDA kernel.
        else:
            # lightconv_cuda expects input in (B, C, T) layout.
            x = x.permute(1, 2, 0).contiguous()
            weight = self.weight
            if self.weight_softmax:
                weight = F.softmax(self.weight, -1)
            if self.weight_dropout_module.p:
                weight = self.weight_dropout_module(weight)
            return lightconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1)
    def reorder_incremental_state(self, incremental_state, new_order):
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            input_buffer = input_buffer.index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(self, incremental_state, "input_buffer")

    def _set_input_buffer(self, incremental_state, new_buffer):
        return utils.set_incremental_state(
            self, incremental_state, "input_buffer", new_buffer
        )

    def half(self):
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)
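

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a CUDA
    # device, that the lightconv_cuda extension has been built, and that
    # kernel_size=3 is among the filter sizes the kernel supports.
    if torch.cuda.is_available():
        layer = LightconvLayer(
            input_size=512,
            kernel_size=3,
            padding_l=2,  # left padding of kernel_size - 1 for a causal convolution
            num_heads=8,
            weight_softmax=True,
            weight_dropout=0.1,
        ).cuda()
        x = torch.randn(32, 2, 512, device="cuda")  # (T, B, C), as expected by forward()
        y = layer(x)
        print(y.shape)  # expected: torch.Size([32, 2, 512])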