Spaces:
Runtime error
Runtime error
| /** | |
| * Copyright (c) Facebook, Inc. and its affiliates. | |
| * | |
| * This source code is licensed under the MIT license found in the | |
| * LICENSE file in the root directory of this source tree. | |
| */ | |
| std::vector<at::Tensor> | |
| lightconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l); | |
| std::vector<at::Tensor> lightconv_cuda_backward( | |
| at::Tensor gradOutput, | |
| int padding_l, | |
| at::Tensor input, | |
| at::Tensor filters); | |
| std::vector<at::Tensor> | |
| lightconv_forward(at::Tensor input, at::Tensor filters, int padding_l) { | |
| CHECK_INPUT(input); | |
| CHECK_INPUT(filters); | |
| return lightconv_cuda_forward(input, filters, padding_l); | |
| } | |
| std::vector<at::Tensor> lightconv_backward( | |
| at::Tensor gradOutput, | |
| int padding_l, | |
| at::Tensor input, | |
| at::Tensor filters) { | |
| CHECK_INPUT(gradOutput); | |
| CHECK_INPUT(input); | |
| CHECK_INPUT(filters); | |
| return lightconv_cuda_backward(gradOutput, padding_l, input, filters); | |
| } | |
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |
| m.def("forward", &lightconv_forward, "lighconv forward (CUDA)"); | |
| m.def("backward", &lightconv_backward, "lighconv backward (CUDA)"); | |
| } | |