Spaces:
Runtime error
Runtime error
| /** | |
| * Copyright (c) Facebook, Inc. and its affiliates. | |
| * | |
| * This source code is licensed under the MIT license found in the | |
| * LICENSE file in the root directory of this source tree. | |
| */ | |
| /** | |
| * Forward `lightconv` CUDA kernel (declaration only; the body is defined elsewhere). | |
| * Reads `input` and `filters` through const pointers and writes `output`. | |
| * @tparam FS, SB, padding_l compile-time kernel parameters; NOTE(review): presumably filter size, | |
| *         sequence block size and left padding — confirm against the definition. | |
| * @tparam scalar_t element type shared by `input`, `filters` and `output`. | |
| * NOTE(review): the four int arguments look like tensor extents going by their names; the memory | |
| *       layout they describe is not visible here. | |
| */ | |
| template <int FS, int SB, int padding_l, typename scalar_t> | |
| __global__ void lightconv_forward_kernel( | |
| const scalar_t* input, const scalar_t* filters, | |
| int minibatch, int sequenceLength, int numFeatures, int numFiltersInBlock, | |
| scalar_t* output); | |
| /** | |
| * `lightconv` CUDA kernel for the gradient with respect to the input (declaration only; body defined elsewhere). | |
| * Reads `input` and `filters` through const pointers and writes `output`. | |
| * NOTE(review): in this kernel `input` is presumably the upstream gradient rather than the forward | |
| *       activations — confirm at the launch site. | |
| * @tparam FS, SB, padding_l compile-time kernel parameters (presumably filter size, sequence block | |
| *         size, left padding — TODO confirm). | |
| * @tparam scalar_t element type shared by all three pointers. | |
| */ | |
| template <int FS, int SB, int padding_l, typename scalar_t> | |
| __global__ void lightconv_grad_wrt_input_kernel( | |
| const scalar_t* input, const scalar_t* filters, | |
| int minibatch, int sequenceLength, int numFeatures, int numFiltersInBlock, | |
| scalar_t* output); | |
| /** | |
| * First pass of the "short" weight-gradient computation for `lightconv` (declaration only; body defined elsewhere). | |
| * Reads `input` and `gradInput` (const scalar_t) and writes `output` as float, independent of scalar_t. | |
| * NOTE(review): the float `output` is presumably an intermediate buffer consumed by | |
| *       lightconv_grad_wrt_weights_secondpass_short_kernel — confirm at the launch site. | |
| * NOTE(review): unlike lightconv_grad_wrt_weights_firstpass_kernel this variant also takes `numHeads`; | |
| *       its role is not visible here. | |
| */ | |
| template <int FS, int SB, int padding_l, typename scalar_t> | |
| __global__ void lightconv_grad_wrt_weights_firstpass_short_kernel( | |
| const scalar_t* input, const scalar_t* gradInput, | |
| int minibatch, int sequenceLength, int numFeatures, int numFiltersInBlock, int numHeads, | |
| float* output); | |
| /** | |
| * Second pass of the "short" weight-gradient computation for `lightconv` (declaration only; body defined elsewhere). | |
| * Reads a const float buffer `input` and writes `output` as scalar_t; unlike the first pass it takes no | |
| * padding_l template parameter. | |
| * NOTE(review): `input` is presumably the float buffer written by | |
| *       lightconv_grad_wrt_weights_firstpass_short_kernel — confirm at the launch site. | |
| * By-value ints are declared without top-level `const`, matching the sibling kernels. Such `const` is not | |
| * part of the function type, so this still matches a definition written with `const int`. | |
| */ | |
| template <int FS, int SB, typename scalar_t> | |
| __global__ void lightconv_grad_wrt_weights_secondpass_short_kernel( | |
| const float* input, | |
| int minibatch, | |
| int numFiltersInBlock, | |
| scalar_t* output); | |
| /** | |
| * First pass of the general weight-gradient computation for `lightconv` (declaration only; body defined elsewhere). | |
| * Reads `input` and `gradInput` (const scalar_t) and writes `output` as float, independent of scalar_t. | |
| * NOTE(review): the float `output` is presumably consumed by lightconv_grad_wrt_weights_secondpass_kernel — | |
| *       confirm at the launch site. | |
| * @tparam FS, SB, padding_l compile-time kernel parameters (presumably filter size, sequence block | |
| *         size, left padding — TODO confirm). | |
| */ | |
| template <int FS, int SB, int padding_l, typename scalar_t> | |
| __global__ void lightconv_grad_wrt_weights_firstpass_kernel( | |
| const scalar_t* input, const scalar_t* gradInput, | |
| int minibatch, int sequenceLength, int numFeatures, int numFiltersInBlock, | |
| float* output); | |
| /** | |
| * Second pass of the general weight-gradient computation for `lightconv` (declaration only; body defined elsewhere). | |
| * Reads a const float buffer `input` and writes `output` as scalar_t; no padding_l template parameter. | |
| * NOTE(review): `input` is presumably the float buffer written by | |
| *       lightconv_grad_wrt_weights_firstpass_kernel — confirm at the launch site. | |
| * By-value ints are declared without top-level `const`, matching the sibling kernels. Such `const` is not | |
| * part of the function type, so this still matches a definition written with `const int`. | |
| */ | |
| template <int FS, int SB, typename scalar_t> | |
| __global__ void lightconv_grad_wrt_weights_secondpass_kernel( | |
| const float* input, | |
| int minibatch, | |
| int numFiltersInBlock, | |
| scalar_t* output); | |