Delete transpose_kernel.h
Browse files- transpose_kernel.h +0 -120
transpose_kernel.h
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
// Implementation of transpose kernel.
|
| 2 |
-
#pragma once
|
| 3 |
-
|
| 4 |
-
#include <hip/amd_detail/amd_hip_runtime.h>
|
| 5 |
-
#include <hip/amd_detail/amd_warp_functions.h>
|
| 6 |
-
#include "../include/gpu_libs.h"
|
| 7 |
-
#include "../include/gpu_types.h"
|
| 8 |
-
#include "../src/utils/arithmetic.h"
|
| 9 |
-
#include "../include/clangd_workaround.h"
|
| 10 |
-
|
| 11 |
-
DEVICE_CODE_BELOW
|
| 12 |
-
|
| 13 |
-
namespace transpose_kernel {
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
template <typename Elem, int M, int N, int TILE_DIM, int BLOCK_SIZE, int VEC_SIZE>
|
| 18 |
-
__launch_bounds__(BLOCK_SIZE)
|
| 19 |
-
__global__ void transpose_kernel(Elem *odata, const Elem *idata) {
|
| 20 |
-
constexpr auto TBLOCK_X = TILE_DIM / VEC_SIZE;
|
| 21 |
-
constexpr auto TBLOCK_Y = BLOCK_SIZE / TBLOCK_X;
|
| 22 |
-
|
| 23 |
-
// avoid read bank conflict
|
| 24 |
-
// VEC_SIZE * (TILE_DIM + d) * sizeof(Elem) = TBLOCK_Y / (BLOCK_SIZE / WARP_SIZE) * sizeof(Elem) + 128k
|
| 25 |
-
// each warp read row = TILE_DIM (in VEC_SIZE reads), col = TBLOCK_Y / (BLOCK_SIZE / WARP_SIZE)
|
| 26 |
-
// warp 0 warp 1
|
| 27 |
-
// t0 t16 t32 t48 ...
|
| 28 |
-
// ...
|
| 29 |
-
// t1
|
| 30 |
-
// ...
|
| 31 |
-
// t15
|
| 32 |
-
// don't know why padding to d as described above is not working, maybe gpu could merge contigious ds_read_u8 and
|
| 33 |
-
// cause padding to be TBLOCK_Y / (BLOCK_SIZE / WARP_SIZE)
|
| 34 |
-
constexpr auto PADDING = TBLOCK_Y / (BLOCK_SIZE / warpSize);
|
| 35 |
-
__shared__ Elem tile[TILE_DIM][TILE_DIM + PADDING];
|
| 36 |
-
|
| 37 |
-
int x = blockIdx.x * TILE_DIM + threadIdx.x * VEC_SIZE;
|
| 38 |
-
int y = blockIdx.y * TILE_DIM + threadIdx.y;
|
| 39 |
-
|
| 40 |
-
// Load tile
|
| 41 |
-
#pragma unroll
|
| 42 |
-
for (int i = 0; i < TILE_DIM; i += TBLOCK_Y) {
|
| 43 |
-
#pragma unroll
|
| 44 |
-
for (int v = 0; v < VEC_SIZE; v++) {
|
| 45 |
-
tile[threadIdx.y + i][threadIdx.x * VEC_SIZE + v] = idata[(y + i) * N + x + v];
|
| 46 |
-
}
|
| 47 |
-
}
|
| 48 |
-
|
| 49 |
-
__syncthreads();
|
| 50 |
-
|
| 51 |
-
// Transpose indices
|
| 52 |
-
x = blockIdx.y * TILE_DIM + threadIdx.x * VEC_SIZE;
|
| 53 |
-
y = blockIdx.x * TILE_DIM + threadIdx.y;
|
| 54 |
-
|
| 55 |
-
// Write tile
|
| 56 |
-
#pragma unroll
|
| 57 |
-
for (int i = 0; i < TILE_DIM; i += TBLOCK_Y) {
|
| 58 |
-
#pragma unroll
|
| 59 |
-
for (int v = 0; v < VEC_SIZE; v++) {
|
| 60 |
-
odata[(y + i) * M + x + v] = tile[threadIdx.x * VEC_SIZE + v][threadIdx.y + i];
|
| 61 |
-
}
|
| 62 |
-
}
|
| 63 |
-
}
|
| 64 |
-
|
| 65 |
-
template <typename Elem, int M, int N, int TILE_DIM, int BLOCK_SIZE, int VEC_SIZE>
|
| 66 |
-
void launch_transpose(Elem *out, const Elem *in, hipStream_t stream = 0) {
|
| 67 |
-
static_assert(TILE_DIM % VEC_SIZE == 0);
|
| 68 |
-
constexpr auto TBLOCK_X = TILE_DIM / VEC_SIZE;
|
| 69 |
-
static_assert(BLOCK_SIZE % TBLOCK_X == 0);
|
| 70 |
-
constexpr auto TBLOCK_Y = BLOCK_SIZE / TBLOCK_X;
|
| 71 |
-
static_assert(M % TILE_DIM == 0 && N % TILE_DIM == 0);
|
| 72 |
-
hipLaunchKernelGGL(
|
| 73 |
-
HIP_KERNEL_NAME(transpose_kernel<Elem, M, N, TILE_DIM, BLOCK_SIZE, VEC_SIZE>),
|
| 74 |
-
dim3(N / TILE_DIM, M / TILE_DIM), dim3(TBLOCK_X, TBLOCK_Y), 0, stream,
|
| 75 |
-
out, in);
|
| 76 |
-
}
|
| 77 |
-
|
| 78 |
-
#define DISPATCH_TRANSPOSE(DIM_0, DIM_1, TILE_DIM, BLOCK_SIZE, VEC_SIZE) else if constexpr(IN_DIM_0 == DIM_0 && IN_DIM_1 == DIM_1) launch_transpose<__FP8_TYPE, IN_DIM_0, IN_DIM_1, TILE_DIM, BLOCK_SIZE, VEC_SIZE>(out, in, stream)
|
| 79 |
-
|
| 80 |
-
template <int DIM0, int DIM1>
|
| 81 |
-
struct unsupported_config {
|
| 82 |
-
static_assert(DIM0 == -1, "Unsupported transpose configuration - check template parameters");
|
| 83 |
-
};
|
| 84 |
-
|
| 85 |
-
// Selecte best parameters for tranpose kernel.
|
| 86 |
-
template <int IN_DIM_0, int IN_DIM_1>
|
| 87 |
-
void transpose_fp8(__FP8_TYPE *out, const __FP8_TYPE *in, hipStream_t stream = 0) {
|
| 88 |
-
if constexpr (false /* dummy*/ ) static_assert(true);
|
| 89 |
-
DISPATCH_TRANSPOSE( 256, 1024, 64, 256, 4); // Optimized: 2.71 µs (193.46 GB/s)
|
| 90 |
-
DISPATCH_TRANSPOSE( 256, 6144, 64, 256, 4); // Optimized: 2.72 µs (1157.37 GB/s)
|
| 91 |
-
DISPATCH_TRANSPOSE( 256, 7168, 64, 256, 8); // Optimized: 2.99 µs (1225.38 GB/s)
|
| 92 |
-
DISPATCH_TRANSPOSE( 512, 1024, 64, 512, 4); // Optimized: 2.55 µs (411.21 GB/s)
|
| 93 |
-
DISPATCH_TRANSPOSE( 512, 4096, 64, 256, 4); // Optimized: 3.01 µs (1394.85 GB/s)
|
| 94 |
-
DISPATCH_TRANSPOSE( 512, 6144, 64, 512, 4); // Optimized: 3.58 µs (1755.43 GB/s)
|
| 95 |
-
DISPATCH_TRANSPOSE( 1536, 1024, 64, 1024, 4); // Optimized: 2.78 µs (1130.74 GB/s)
|
| 96 |
-
DISPATCH_TRANSPOSE( 1536, 3072, 64, 512, 4); // Optimized: 3.57 µs (2641.99 GB/s)
|
| 97 |
-
DISPATCH_TRANSPOSE( 1536, 6144, 128, 1024, 8); // Optimized: 7.09 µs (2661.36 GB/s)
|
| 98 |
-
DISPATCH_TRANSPOSE( 2048, 1024, 64, 1024, 4); // Optimized: 2.84 µs (1477.91 GB/s)
|
| 99 |
-
DISPATCH_TRANSPOSE( 2048, 6144, 128, 512, 8); // Optimized: 8.94 µs (2816.23 GB/s)
|
| 100 |
-
DISPATCH_TRANSPOSE( 2048, 7168, 128, 512, 8); // Optimized: 9.56 µs (3070.50 GB/s)
|
| 101 |
-
DISPATCH_TRANSPOSE( 2304, 1024, 64, 1024, 4); // Optimized: 3.08 µs (1532.51 GB/s)
|
| 102 |
-
DISPATCH_TRANSPOSE( 2304, 6144, 128, 512, 8); // Optimized: 9.30 µs (3043.93 GB/s)
|
| 103 |
-
DISPATCH_TRANSPOSE( 2304, 7168, 128, 512, 8); // Optimized: 10.39 µs (3179.95 GB/s)
|
| 104 |
-
DISPATCH_TRANSPOSE( 7168, 512, 64, 512, 4); // Optimized: 3.25 µs (2257.78 GB/s)
|
| 105 |
-
DISPATCH_TRANSPOSE( 7168, 576, 64, 512, 4); // Optimized: 3.44 µs (2403.24 GB/s)
|
| 106 |
-
DISPATCH_TRANSPOSE( 7168, 1024, 64, 256, 4); // Optimized: 5.07 µs (2892.62 GB/s)
|
| 107 |
-
DISPATCH_TRANSPOSE( 7168, 1536, 128, 1024, 8); // Optimized: 7.72 µs (2851.97 GB/s)
|
| 108 |
-
DISPATCH_TRANSPOSE( 7168, 4608, 128, 512, 8); // Optimized: 16.87 µs (3915.84 GB/s)
|
| 109 |
-
DISPATCH_TRANSPOSE( 7168, 6144, 128, 256, 8); // Optimized: 21.59 µs (4079.12 GB/s)
|
| 110 |
-
else static_assert(false);
|
| 111 |
-
}
|
| 112 |
-
|
| 113 |
-
} // namespace transpose_kernel
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
#ifndef PARAMETERIZE_LIBRARY
|
| 119 |
-
int main() {}
|
| 120 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|