medmekk HF Staff commited on
Commit
2b92228
·
verified ·
1 Parent(s): f6342f5

Delete transpose_kernel.h

Browse files
Files changed (1) hide show
  1. transpose_kernel.h +0 -120
transpose_kernel.h DELETED
@@ -1,120 +0,0 @@
1
- // Implementation of transpose kernel.
2
- #pragma once
3
-
4
- #include <hip/amd_detail/amd_hip_runtime.h>
5
- #include <hip/amd_detail/amd_warp_functions.h>
6
- #include "../include/gpu_libs.h"
7
- #include "../include/gpu_types.h"
8
- #include "../src/utils/arithmetic.h"
9
- #include "../include/clangd_workaround.h"
10
-
11
- DEVICE_CODE_BELOW
12
-
13
- namespace transpose_kernel {
14
-
15
-
16
-
17
// Tiled out-of-place matrix transpose: reads an M x N matrix `idata` and
// writes its N x M transpose to `odata` (both row-major, device pointers).
//
// Launch contract (enforced by launch_transpose):
//   grid  = (N / TILE_DIM, M / TILE_DIM)
//   block = (TILE_DIM / VEC_SIZE, BLOCK_SIZE / (TILE_DIM / VEC_SIZE))
//   M and N are multiples of TILE_DIM, so no bounds checks are needed here.
// Each thread handles VEC_SIZE consecutive elements per row; each block
// stages one TILE_DIM x TILE_DIM tile in shared memory, then writes it back
// with swapped coordinates so both the global load and the global store are
// contiguous along threadIdx.x.
template <typename Elem, int M, int N, int TILE_DIM, int BLOCK_SIZE, int VEC_SIZE>
__launch_bounds__(BLOCK_SIZE)
__global__ void transpose_kernel(Elem *odata, const Elem *idata) {
    constexpr auto TBLOCK_X = TILE_DIM / VEC_SIZE;
    constexpr auto TBLOCK_Y = BLOCK_SIZE / TBLOCK_X;

    // avoid read bank conflict
    // VEC_SIZE * (TILE_DIM + d) * sizeof(Elem) = TBLOCK_Y / (BLOCK_SIZE / WARP_SIZE) * sizeof(Elem) + 128k
    // each warp read row = TILE_DIM (in VEC_SIZE reads), col = TBLOCK_Y / (BLOCK_SIZE / WARP_SIZE)
    // warp 0        warp 1
    // t0 t16 t32 t48 ...
    // ...
    // t1
    // ...
    // t15
    // don't know why padding to d as described above is not working, maybe gpu could merge contigious ds_read_u8 and
    // cause padding to be TBLOCK_Y / (BLOCK_SIZE / WARP_SIZE)
    // NOTE(review): `warpSize` is not constexpr on every HIP version; this
    // relies on the AMD headers exposing it as a compile-time constant — TODO confirm.
    constexpr auto PADDING = TBLOCK_Y / (BLOCK_SIZE / warpSize);
    // Padded inner dimension breaks the bank-conflict pattern of the
    // column-wise reads in the store phase below.
    __shared__ Elem tile[TILE_DIM][TILE_DIM + PADDING];

    // Global coordinates of this thread's first element in the input tile.
    int x = blockIdx.x * TILE_DIM + threadIdx.x * VEC_SIZE;
    int y = blockIdx.y * TILE_DIM + threadIdx.y;

    // Load tile: each thread copies VEC_SIZE consecutive elements from
    // TILE_DIM / TBLOCK_Y rows (coalesced along x).
#pragma unroll
    for (int i = 0; i < TILE_DIM; i += TBLOCK_Y) {
#pragma unroll
        for (int v = 0; v < VEC_SIZE; v++) {
            tile[threadIdx.y + i][threadIdx.x * VEC_SIZE + v] = idata[(y + i) * N + x + v];
        }
    }

    // All loads must land in shared memory before any thread reads a
    // transposed location written by another thread.
    __syncthreads();

    // Transpose indices: swap the block coordinates so the output write is
    // again contiguous along threadIdx.x, while the tile is read column-wise.
    x = blockIdx.y * TILE_DIM + threadIdx.x * VEC_SIZE;
    y = blockIdx.x * TILE_DIM + threadIdx.y;

    // Write tile (note swapped tile indices relative to the load phase).
#pragma unroll
    for (int i = 0; i < TILE_DIM; i += TBLOCK_Y) {
#pragma unroll
        for (int v = 0; v < VEC_SIZE; v++) {
            odata[(y + i) * M + x + v] = tile[threadIdx.x * VEC_SIZE + v][threadIdx.y + i];
        }
    }
}
64
-
65
- template <typename Elem, int M, int N, int TILE_DIM, int BLOCK_SIZE, int VEC_SIZE>
66
- void launch_transpose(Elem *out, const Elem *in, hipStream_t stream = 0) {
67
- static_assert(TILE_DIM % VEC_SIZE == 0);
68
- constexpr auto TBLOCK_X = TILE_DIM / VEC_SIZE;
69
- static_assert(BLOCK_SIZE % TBLOCK_X == 0);
70
- constexpr auto TBLOCK_Y = BLOCK_SIZE / TBLOCK_X;
71
- static_assert(M % TILE_DIM == 0 && N % TILE_DIM == 0);
72
- hipLaunchKernelGGL(
73
- HIP_KERNEL_NAME(transpose_kernel<Elem, M, N, TILE_DIM, BLOCK_SIZE, VEC_SIZE>),
74
- dim3(N / TILE_DIM, M / TILE_DIM), dim3(TBLOCK_X, TBLOCK_Y), 0, stream,
75
- out, in);
76
- }
77
-
78
// Dispatch-table entry: expands to one `else if constexpr` branch that
// instantiates launch_transpose for a single (DIM_0, DIM_1) input shape with
// the tuned TILE_DIM / BLOCK_SIZE / VEC_SIZE parameters. Relies on an
// `if constexpr` statement immediately preceding the first use, and on
// IN_DIM_0 / IN_DIM_1 / out / in / stream being in scope at the expansion
// site (see transpose_fp8 below).
#define DISPATCH_TRANSPOSE(DIM_0, DIM_1, TILE_DIM, BLOCK_SIZE, VEC_SIZE) else if constexpr(IN_DIM_0 == DIM_0 && IN_DIM_1 == DIM_1) launch_transpose<__FP8_TYPE, IN_DIM_0, IN_DIM_1, TILE_DIM, BLOCK_SIZE, VEC_SIZE>(out, in, stream)
80
// Compile-time error trap: instantiating this struct fires a static_assert
// whose condition depends on the template arguments, so the assertion is
// only evaluated for the shape that actually reaches it (unlike a bare
// `static_assert(false)`, which is ill-formed before C++23 even in a
// discarded `if constexpr` branch).
template <int DIM0, int DIM1>
struct unsupported_config {
    static_assert(DIM0 == -1, "Unsupported transpose configuration - check template parameters");
};
85
- // Selecte best parameters for tranpose kernel.
86
- template <int IN_DIM_0, int IN_DIM_1>
87
- void transpose_fp8(__FP8_TYPE *out, const __FP8_TYPE *in, hipStream_t stream = 0) {
88
- if constexpr (false /* dummy*/ ) static_assert(true);
89
- DISPATCH_TRANSPOSE( 256, 1024, 64, 256, 4); // Optimized: 2.71 µs (193.46 GB/s)
90
- DISPATCH_TRANSPOSE( 256, 6144, 64, 256, 4); // Optimized: 2.72 µs (1157.37 GB/s)
91
- DISPATCH_TRANSPOSE( 256, 7168, 64, 256, 8); // Optimized: 2.99 µs (1225.38 GB/s)
92
- DISPATCH_TRANSPOSE( 512, 1024, 64, 512, 4); // Optimized: 2.55 µs (411.21 GB/s)
93
- DISPATCH_TRANSPOSE( 512, 4096, 64, 256, 4); // Optimized: 3.01 µs (1394.85 GB/s)
94
- DISPATCH_TRANSPOSE( 512, 6144, 64, 512, 4); // Optimized: 3.58 µs (1755.43 GB/s)
95
- DISPATCH_TRANSPOSE( 1536, 1024, 64, 1024, 4); // Optimized: 2.78 µs (1130.74 GB/s)
96
- DISPATCH_TRANSPOSE( 1536, 3072, 64, 512, 4); // Optimized: 3.57 µs (2641.99 GB/s)
97
- DISPATCH_TRANSPOSE( 1536, 6144, 128, 1024, 8); // Optimized: 7.09 µs (2661.36 GB/s)
98
- DISPATCH_TRANSPOSE( 2048, 1024, 64, 1024, 4); // Optimized: 2.84 µs (1477.91 GB/s)
99
- DISPATCH_TRANSPOSE( 2048, 6144, 128, 512, 8); // Optimized: 8.94 µs (2816.23 GB/s)
100
- DISPATCH_TRANSPOSE( 2048, 7168, 128, 512, 8); // Optimized: 9.56 µs (3070.50 GB/s)
101
- DISPATCH_TRANSPOSE( 2304, 1024, 64, 1024, 4); // Optimized: 3.08 µs (1532.51 GB/s)
102
- DISPATCH_TRANSPOSE( 2304, 6144, 128, 512, 8); // Optimized: 9.30 µs (3043.93 GB/s)
103
- DISPATCH_TRANSPOSE( 2304, 7168, 128, 512, 8); // Optimized: 10.39 µs (3179.95 GB/s)
104
- DISPATCH_TRANSPOSE( 7168, 512, 64, 512, 4); // Optimized: 3.25 µs (2257.78 GB/s)
105
- DISPATCH_TRANSPOSE( 7168, 576, 64, 512, 4); // Optimized: 3.44 µs (2403.24 GB/s)
106
- DISPATCH_TRANSPOSE( 7168, 1024, 64, 256, 4); // Optimized: 5.07 µs (2892.62 GB/s)
107
- DISPATCH_TRANSPOSE( 7168, 1536, 128, 1024, 8); // Optimized: 7.72 µs (2851.97 GB/s)
108
- DISPATCH_TRANSPOSE( 7168, 4608, 128, 512, 8); // Optimized: 16.87 µs (3915.84 GB/s)
109
- DISPATCH_TRANSPOSE( 7168, 6144, 128, 256, 8); // Optimized: 21.59 µs (4079.12 GB/s)
110
- else static_assert(false);
111
- }
112
-
113
- } // namespace transpose_kernel
114
-
115
-
116
-
117
-
118
// Stand-alone compile smoke test: when the header is not being consumed as a
// parameterized library, provide an empty entry point so the translation unit
// links on its own.
#ifndef PARAMETERIZE_LIBRARY
int main() {
    return 0;
}
#endif