/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Includes reconstructed for a self-contained build; the original project may
// pull these in through its own header instead.
#include <torch/extension.h> // torch::Tensor, torch::empty, AT_DISPATCH_ALL_TYPES
#include <ATen/cuda/CUDAContext.h> // at::cuda::getCurrentCUDAStream
#include <utility> // std::pair
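
// Edit operations are encoded as integers throughout this file:
//   0 = padding / end of sequence, 1 = insertion, 2 = deletion, 3 = keep.
// Each CUDA block runs a single thread and processes one sentence pair;
// blockIdx.x selects the batch element.
//
// generate_deletion_label_kernel turns an operation sequence into per-token
// deletion labels for the source: a kept token (op 3) gets label 0 and a
// deleted token (op 2) gets label 1, via `3 - op`; insertions (op 1) do not
// consume a source token and are skipped. E.g. ops [3, 2, 3] yield labels
// [0, 1, 0].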
template <typename scalar_t>
__global__ void generate_deletion_label_kernel(
    const scalar_t* __restrict__ source,
    const size_t source_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels) {
  const int index = blockIdx.x;
  const int offset = index * operation_size;
  const int offset_label = index * source_size;

  for (int i = 0; i < source_size; i++) {
    labels[offset_label + i] = 0;
  }

  int k = 0;
  for (int i = 0; i < operation_size; i++) {
    if (operations[offset + i] == 0) {
      break; // 0 marks the end of the operation sequence
    } else if (operations[offset + i] == 1) {
      continue; // insertions do not consume a source token
    } else {
      labels[offset_label + k] = 3 - operations[offset + i]; // keep -> 0, delete -> 1
      k++;
    }
  }
}
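
// generate_insertion_label_kernel walks the operation sequence from the
// target side: `u` counts consecutive insertions (op 1) and is flushed into
// `labels` at the next kept token (op 3), so labels[k] is the number of
// tokens to insert before the k-th kept target token. `masks` records, per
// target position, whether the token was inserted (1) or kept (0).
// Deletions (op 2) do not consume a target token and are skipped.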
template <typename scalar_t>
__global__ void generate_insertion_label_kernel(
    const scalar_t* __restrict__ target,
    const size_t target_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels,
    int* __restrict__ masks) {
  const int index = blockIdx.x;
  const int offset = index * operation_size;
  const int offset_label = index * target_size;
  int k = 0;
  int u = 0;
  int m = 0;

  for (int i = 0; i < target_size; i++) {
    labels[offset_label + i] = 0;
    masks[offset_label + i] = 0;
  }

  for (int i = 0; i < operation_size - 1; i++) {
    if (operations[offset + i] == 0) {
      break; // end of the operation sequence
    } else if (operations[offset + i] == 2) {
      continue; // deletions do not consume a target token
    } else if (operations[offset + i] == 1) {
      masks[offset_label + m] = 1; // inserted position
      u++;
      m++;
    } else {
      labels[offset_label + k] = u; // insertions accumulated before this kept token
      masks[offset_label + m] = 0; // kept position
      k++;
      m++;
      u = 0;
    }
  }
}
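
// levenshtein_distance_kernel fills a full (hyp_len + 1) x (ref_len + 1)
// edit-distance table in global memory, one sentence pair per block, using
// the standard recurrence with substitution cost 2:
//
//   d[i][j] = min( min(d[i-1][j], d[i][j-1]) + 1,
//                  d[i-1][j-1] + 2 * (hyp[i-1] != ref[j-1]) )
//
// Pricing a substitution like a deletion plus an insertion means the
// back-trace only ever emits insert (1), delete (2), and keep (3) ops, which
// is exactly the operation alphabet the label kernels above expect. The op
// sequence is written right-to-left during back-tracing and then shifted to
// the start of the buffer, with zeros as padding.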
template <typename scalar_t>
__global__ void levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations,
    int* __restrict__ errors_curr) {
  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);
  const int d = index * (source_size + 1) * (target_size + 1);
  const int t = target_size + 1;

  auto err_idx = [d, t](int i, int j) { return d + i * t + j; };
  auto opt_idx = [offset](int k) { return offset + k; };

  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;

  // dynamic programming
  for (int i = 0; i <= hyp_len; i++) {
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++) {
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++) {
    for (int j = 1; j <= ref_len; j++) {
      errors_curr[err_idx(i, j)] = min(
          min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
              1,
          errors_curr[err_idx(i - 1, j - 1)] +
              2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
    }
  }

  // back-tracing from (hyp_len, ref_len) to (0, 0)
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;
  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }
  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }
    if ((j > 0) &&
        (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 1;
      j--; // insertion
    } else if (
        (i > 0) &&
        (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 2;
      i--; // deletion
    } else {
      o--;
      operations[opt_idx(o)] = 3;
      i--;
      j--; // keep (tokens match)
    }
  }

  // shift the op sequence to the start of the buffer; pad the tail with zeros
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len) {
      operations[opt_idx(k)] = operations[opt_idx(k + o)];
    } else {
      operations[opt_idx(k)] = 0; // padding
    }
  }
}
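
// faster_levenshtein_distance_kernel is the same algorithm, but the DP table
// lives in dynamically sized shared memory (one `short` per cell) instead of
// a global-memory tensor, avoiding a large intermediate allocation. The host
// wrapper below falls back to the global-memory kernel when the table would
// not fit in the shared-memory budget.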
template <typename scalar_t>
__global__ void faster_levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations) {
  extern __shared__ short errors[];
  auto errors_curr = errors;

  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);
  const int t = target_size + 1;

  auto err_idx = [t](int i, int j) { return i * t + j; };
  auto opt_idx = [offset](int k) { return offset + k; };

  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;

  // dynamic programming
  for (int i = 0; i <= hyp_len; i++) {
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++) {
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++) {
    for (int j = 1; j <= ref_len; j++) {
      errors_curr[err_idx(i, j)] = min(
          min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
              1,
          errors_curr[err_idx(i - 1, j - 1)] +
              2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
    }
  }

  // back-tracing from (hyp_len, ref_len) to (0, 0)
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;
  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }
  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }
    if ((j > 0) &&
        (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 1;
      j--; // insertion
    } else if (
        (i > 0) &&
        (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 2;
      i--; // deletion
    } else {
      o--;
      operations[opt_idx(o)] = 3;
      i--;
      j--; // keep (tokens match)
    }
  }

  // shift the op sequence to the start of the buffer; pad the tail with zeros
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len) {
      operations[opt_idx(k)] = operations[opt_idx(k + o)];
    } else {
      operations[opt_idx(k)] = 0; // padding
    }
  }
}
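
// ---------------------------------------------------------------------------
// Host-side entry points. Each launches one single-threaded block per batch
// element (<<<batch_size, 1, ...>>>) on the current CUDA stream, so
// parallelism comes from processing the batch's sentence pairs concurrently,
// not from threads within a pair.
// ---------------------------------------------------------------------------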
torch::Tensor GenerateDeletionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations) {
  const auto batch_size = source.size(0);
  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, source.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
    generate_deletion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
        source.data_ptr<scalar_t>(),
        source.size(1),
        operations.size(1),
        operations.data_ptr<int>(),
        labels.data_ptr<int>());
  }));

  return labels;
}
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
    torch::Tensor target,
    torch::Tensor operations) {
  const auto batch_size = target.size(0);
  at::TensorOptions options(target.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, target.size(1)}, options);
  auto masks = torch::empty({batch_size, target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(target.device().index());

  AT_DISPATCH_ALL_TYPES(target.scalar_type(), "generate_insertion_labels", ([&] {
    generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
        target.data_ptr<scalar_t>(),
        target.size(1),
        operations.size(1),
        operations.data_ptr<int>(),
        labels.data_ptr<int>(),
        masks.data_ptr<int>());
  }));

  return std::make_pair(labels, masks);
}
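
// LevenshteinDistanceCuda picks between the two kernels based on the size of
// the DP table. The 40000-byte cutoff is presumably chosen to stay safely
// under the common 48 KB default shared-memory limit per block; larger tables
// fall back to a global-memory scratch tensor.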
torch::Tensor LevenshteinDistanceCuda(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {
  const auto batch_size = source.size(0);
  const auto shared_size =
      (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short);
  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto operations =
      torch::empty({batch_size, source.size(1) + target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  if (shared_size > 40000) {
    auto distances = torch::empty(
        {batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
    AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
      levenshtein_distance_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
          source.data_ptr<scalar_t>(),
          target.data_ptr<scalar_t>(),
          source_length.data_ptr<int>(),
          target_length.data_ptr<int>(),
          source.size(1),
          target.size(1),
          operations.data_ptr<int>(),
          distances.data_ptr<int>());
    }));
  } else {
    AT_DISPATCH_ALL_TYPES(
        source.scalar_type(), "faster_levenshtein_distance", ([&] {
          faster_levenshtein_distance_kernel<scalar_t>
              <<<batch_size, 1, shared_size, stream>>>(
                  source.data_ptr<scalar_t>(),
                  target.data_ptr<scalar_t>(),
                  source_length.data_ptr<int>(),
                  target_length.data_ptr<int>(),
                  source.size(1),
                  target.size(1),
                  operations.data_ptr<int>());
        }));
  }
  return operations;
}
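
// ---------------------------------------------------------------------------
// A minimal sketch of how these entry points could be exposed to Python as a
// PyTorch extension. The module and function names below are illustrative
// assumptions, not the original project's binding, which lives in its own
// header and binding files.
// ---------------------------------------------------------------------------
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("levenshtein_distance", &LevenshteinDistanceCuda,
//         "Edit operations between source/target batches (CUDA)");
//   m.def("generate_deletion_labels", &GenerateDeletionLabelCuda,
//         "Per-token deletion labels from edit operations (CUDA)");
//   m.def("generate_insertion_labels", &GenerateInsertionLabelCuda,
//         "Insertion counts and masks from edit operations (CUDA)");
// }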