| #include "common.cuh" | |
| #include "fattn-common.cuh" | |
| template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size | |
| #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) | |
| __launch_bounds__(D, 1) | |
| #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) | |
| static __global__ void flash_attn_vec_ext_f32( | |
| const char * __restrict__ Q, | |
| const char * __restrict__ K, | |
| const char * __restrict__ V, | |
| const char * __restrict__ mask, | |
| float * __restrict__ dst, | |
| float2 * __restrict__ dst_meta, | |
| const float scale, | |
| const float max_bias, | |
| const float m0, | |
| const float m1, | |
| const uint32_t n_head_log2, | |
| const float logit_softcap, | |
| const int ne00, | |
| const int ne01, | |
| const int ne02, | |
| const int ne03, | |
| const int ne10, | |
| const int ne11, | |
| const int ne12, | |
| const int ne13, | |
| const int ne31, | |
| const int nb31, | |
| const int nb01, | |
| const int nb02, | |
| const int nb03, | |
| const int nb11, | |
| const int nb12, | |
| const int nb13, | |
| const int nb21, | |
| const int nb22, | |
| const int nb23, | |
| const int ne0, | |
| const int ne1, | |
| const int ne2, | |
| const int ne3) { | |
| // Skip unused kernel variants for faster compilation: | |
| if (use_logit_softcap && !(D == 128 || D == 256)) { | |
| NO_DEVICE_CODE; | |
| return; | |
| } | |
    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
    constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);

    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
    const int ip  =  blockIdx.x % parallel_blocks;          // Index in group of blocks running for the same column in parallel.

    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
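
    // Advance the data pointers to this block's head (blockIdx.y) and Q column (ic0).
    // With GQA, gqa_ratio consecutive Q heads share the same K/V head.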
    Q += nb02* blockIdx.y              + nb01*ic0;
    K += nb12*(blockIdx.y / gqa_ratio);
    V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape

    const half * maskh = (const half *) mask + ne11*ic0;

    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);

    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
    constexpr int nwarps = D / WARP_SIZE;
    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
    __builtin_assume(tid < D);
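
    // KQ holds one tile of KQ values per Q column. It is initialized to a very
    // negative value so that expf() of untouched entries evaluates to 0 in the softmax.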
    __shared__ float KQ[ncols*D];
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        KQ[j*D + tid] = -FLT_MAX/2.0f;
    }

    float kqmax[ncols];
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        kqmax[j] = -FLT_MAX/2.0f;
    }
    float kqsum[ncols] = {0.0f};

    __shared__ float kqmax_shared[ncols][WARP_SIZE];
    __shared__ float kqsum_shared[ncols][WARP_SIZE];
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        if (threadIdx.y == 0) {
            kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
            kqsum_shared[j][threadIdx.x] = 0.0f;
        }
    }
    __syncthreads();
    // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
    float2  Q_f2[ncols][D/(2*WARP_SIZE)];
    int    Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
    float2  Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
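
    // For quantized K the KQ dot product is computed on 8-bit integers: Q is
    // quantized to q8_1 (values in Q_i32, per-block scale and sum in Q_ds) so
    // that vec_dot_KQ can work directly on the quantized K data.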
    if (Q_q8_1) {
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            if (j0 + nwarps > ncols && j >= ncols) {
                break;
            }

            // Reuse KQ as temporary storage for converting Q to q8_1:
            int    * tmp_q_i32 = (int    *) &KQ[j*D];
            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));

            // Set memory to zero if out of bounds:
            if (ncols > 2 && ic0 + j >= ne01) {
#pragma unroll
                for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
                    const int i = i0 + threadIdx.x;

                    tmp_q_i32[i] = 0;
                }
                if (threadIdx.x < D/QK8_1) {
                    tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
                }
                continue;
            }

            const float * Q_f = (const float *) (Q + j*nb01);
#pragma unroll
            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
                quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
            }
        }

        __syncthreads();

#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            int    * tmp_q_i32 = (int    *) &KQ[j*D];
            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));

#pragma unroll
            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
                const int i = i0 + threadIdx.x;

                Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
                Q_ds[j][i0/WARP_SIZE]  = tmp_q_ds[i/QI8_1];
            }
        }

        __syncthreads();
    } else {
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);

#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                const int i = i0 + threadIdx.x;

                Q_f2[j][i0/WARP_SIZE]    = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
                Q_f2[j][i0/WARP_SIZE].x *= scale;
                Q_f2[j][i0/WARP_SIZE].y *= scale;
            }
        }
    }
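
    // VKQ accumulates the output; the main loop walks the K/V sequence in tiles
    // of D values. With parallel_blocks > 1 the tiles are strided across the
    // blocks working on the same column and the partial results are merged later.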
    float VKQ[ncols] = {0.0f};

    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
        // Calculate KQ tile and keep track of new maximum KQ values:

        float kqmax_new_arr[ncols];
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            kqmax_new_arr[j] = kqmax[j];
        }

#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
            const int i_KQ = i_KQ_0 + threadIdx.y;

            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
                break;
            }

#pragma unroll
            for (int j = 0; j < ncols; ++j) {
                float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
                sum = warp_reduce_sum(sum);

                if (use_logit_softcap) {
                    sum = logit_softcap*tanhf(sum);
                }

                sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;

                kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);

                if (threadIdx.x == 0) {
                    KQ[j*D + i_KQ] = sum;
                }
            }
        }

#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            float kqmax_new_j = kqmax_new_arr[j];

            kqmax_new_j = warp_reduce_max(kqmax_new_j);
            if (threadIdx.x == 0) {
                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
            }
        }

        __syncthreads();
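
        // Online softmax: rescale the running sum and the VKQ accumulator by
        // expf(old_max - new_max) whenever the running maximum increases, then
        // exponentiate the new KQ tile in place.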
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            float kqmax_new_j = kqmax_shared[j][threadIdx.x];
            kqmax_new_j = warp_reduce_max(kqmax_new_j);

            const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
            kqmax[j] = kqmax_new_j;

            const float val = expf(KQ[j*D + tid] - kqmax[j]);
            kqsum[j] = kqsum[j]*KQ_max_scale + val;
            KQ[j*D + tid] = val;

            VKQ[j] *= KQ_max_scale;
        }

        __syncthreads();
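
        // Accumulate the V contribution: thread tid owns output element tid and
        // sums V[k][tid] weighted by the softmax numerators stored in KQ.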
#pragma unroll
        for (int k = 0; k < D; ++k) {
            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) {
                break;
            }

            const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
#pragma unroll
            for (int j = 0; j < ncols; ++j) {
                VKQ[j] += V_ki*KQ[j*D + k];
            }
        }

        __syncthreads();
    }
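
    // Combine the per-warp softmax denominators via shared memory, then
    // normalize and write out the results.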
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        kqsum[j] = warp_reduce_sum(kqsum[j]);
        if (threadIdx.x == 0) {
            kqsum_shared[j][threadIdx.y] = kqsum[j];
        }
    }

    __syncthreads();

#pragma unroll
    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
            break;
        }

        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);

        float dst_val = VKQ[j_VKQ];
        if (parallel_blocks == 1) {
            dst_val /= kqsum[j_VKQ];
        }
        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
    }
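
    // With parallel_blocks > 1 the result written above is unnormalized; store the
    // per-column (max, sum) pair so that the partial results of all blocks working
    // on the same column can be combined into the final output afterwards.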
    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
    }
}

template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    constexpr int nwarps = D/WARP_SIZE;
    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
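    // K/V only need conversion to f16 for head sizes that lack dedicated
    // dot-product/dequantization code paths for quantized data.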
    constexpr bool need_f16_K = D != 128;
    constexpr bool need_f16_V = D != 128 && D != 64;
    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
}

template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q   = dst->src[0];
    const ggml_tensor * K   = dst->src[1];
    const ggml_tensor * V   = dst->src[2];

    GGML_ASSERT(K->type == type_K);
    GGML_ASSERT(V->type == type_V);
    float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    if (Q->ne[1] == 1) {
        constexpr int cols_per_block  = 1;
        constexpr int parallel_blocks = 4;
        if (logit_softcap == 0.0f) {
            constexpr bool use_logit_softcap = false;
            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
        } else {
            constexpr bool use_logit_softcap = true;
            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
        }
        return;
    }

    if (Q->ne[1] == 2) {
        constexpr int cols_per_block  = 2;
        constexpr int parallel_blocks = 4;
        if (logit_softcap == 0.0f) {
            constexpr bool use_logit_softcap = false;
            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
        } else {
            constexpr bool use_logit_softcap = true;
            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
        }
        return;
    }

    if (Q->ne[1] <= 4) {
        constexpr int cols_per_block  = 4;
        constexpr int parallel_blocks = 4;
        if (logit_softcap == 0.0f) {
            constexpr bool use_logit_softcap = false;
            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
        } else {
            constexpr bool use_logit_softcap = true;
            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
        }
        return;
    }

    if (Q->ne[1] <= 8) {
        constexpr int cols_per_block  = 8;
        constexpr int parallel_blocks = 4;
        if (logit_softcap == 0.0f) {
            constexpr bool use_logit_softcap = false;
            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
        } else {
            constexpr bool use_logit_softcap = true;
            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
        }
        return;
    }
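
    // For batch sizes > 8 there is enough parallelism across columns and heads,
    // so the K/V sequence is not split across parallel blocks.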
    constexpr int cols_per_block  = 8;
    constexpr int parallel_blocks = 1;
    if (logit_softcap == 0.0f) {
        constexpr bool use_logit_softcap = false;
        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
    } else {
        constexpr bool use_logit_softcap = true;
        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
    }
}
#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V)                          \
    template void ggml_cuda_flash_attn_ext_vec_f32_case                     \
    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);

extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0);

extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1);

extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0);

extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1);

extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0);

extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16);

extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);