|
|
|
|
|
#pragma once |
|
|
|
|
|
#include <hip/amd_detail/amd_hip_runtime.h> |
|
|
#include <hip/amd_detail/amd_warp_functions.h> |
|
|
#include "../include/gpu_libs.h" |
|
|
#include "../include/gpu_types.h" |
|
|
#include "../src/utils/arithmetic.h" |
|
|
#include "../include/clangd_workaround.h" |
|
|
|
|
|
DEVICE_CODE_BELOW |
|
|
|
|
|
namespace transpose_kernel { |
|
|
|
|
|
|
|
|
|
|
|
template <typename Elem, int M, int N, int TILE_DIM, int BLOCK_SIZE, int VEC_SIZE> |
|
|
__launch_bounds__(BLOCK_SIZE) |
|
|
__global__ void transpose_kernel(Elem *odata, const Elem *idata) { |
|
|
constexpr auto TBLOCK_X = TILE_DIM / VEC_SIZE; |
|
|
constexpr auto TBLOCK_Y = BLOCK_SIZE / TBLOCK_X; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
constexpr auto PADDING = TBLOCK_Y / (BLOCK_SIZE / warpSize); |
|
|
__shared__ Elem tile[TILE_DIM][TILE_DIM + PADDING]; |
|
|
|
|
|
int x = blockIdx.x * TILE_DIM + threadIdx.x * VEC_SIZE; |
|
|
int y = blockIdx.y * TILE_DIM + threadIdx.y; |
|
|
|
|
|
|
|
|
#pragma unroll |
|
|
for (int i = 0; i < TILE_DIM; i += TBLOCK_Y) { |
|
|
#pragma unroll |
|
|
for (int v = 0; v < VEC_SIZE; v++) { |
|
|
tile[threadIdx.y + i][threadIdx.x * VEC_SIZE + v] = idata[(y + i) * N + x + v]; |
|
|
} |
|
|
} |
|
|
|
|
|
__syncthreads(); |
|
|
|
|
|
|
|
|
x = blockIdx.y * TILE_DIM + threadIdx.x * VEC_SIZE; |
|
|
y = blockIdx.x * TILE_DIM + threadIdx.y; |
|
|
|
|
|
|
|
|
#pragma unroll |
|
|
for (int i = 0; i < TILE_DIM; i += TBLOCK_Y) { |
|
|
#pragma unroll |
|
|
for (int v = 0; v < VEC_SIZE; v++) { |
|
|
odata[(y + i) * M + x + v] = tile[threadIdx.x * VEC_SIZE + v][threadIdx.y + i]; |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
template <typename Elem, int M, int N, int TILE_DIM, int BLOCK_SIZE, int VEC_SIZE> |
|
|
void launch_transpose(Elem *out, const Elem *in, hipStream_t stream = 0) { |
|
|
static_assert(TILE_DIM % VEC_SIZE == 0); |
|
|
constexpr auto TBLOCK_X = TILE_DIM / VEC_SIZE; |
|
|
static_assert(BLOCK_SIZE % TBLOCK_X == 0); |
|
|
constexpr auto TBLOCK_Y = BLOCK_SIZE / TBLOCK_X; |
|
|
static_assert(M % TILE_DIM == 0 && N % TILE_DIM == 0); |
|
|
hipLaunchKernelGGL( |
|
|
HIP_KERNEL_NAME(transpose_kernel<Elem, M, N, TILE_DIM, BLOCK_SIZE, VEC_SIZE>), |
|
|
dim3(N / TILE_DIM, M / TILE_DIM), dim3(TBLOCK_X, TBLOCK_Y), 0, stream, |
|
|
out, in); |
|
|
} |
|
|
|
|
|
// Expands to one `else if constexpr` arm of a compile-time dispatch chain.
// The arm is taken when the enclosing template's IN_DIM_0 / IN_DIM_1 equal
// DIM_0 / DIM_1, and forwards to launch_transpose with the given tile size,
// block size and vector width.
//
// The expansion refers to IN_DIM_0, IN_DIM_1, out, in and stream from the
// scope it is used in, must follow an `if constexpr` statement, and does not
// include the trailing semicolon.
#define DISPATCH_TRANSPOSE(DIM_0, DIM_1, TILE_DIM, BLOCK_SIZE, VEC_SIZE)          \
    else if constexpr (IN_DIM_0 == DIM_0 && IN_DIM_1 == DIM_1)                    \
        launch_transpose<__FP8_TYPE, IN_DIM_0, IN_DIM_1, TILE_DIM, BLOCK_SIZE,    \
                         VEC_SIZE>(out, in, stream)
|
|
|
|
|
/**
 * Compile-time tripwire: instantiating this type raises a diagnostic unless
 * DIM0 is -1. Because the condition depends on a template parameter, the
 * assertion is only evaluated when the type is actually instantiated.
 *
 * @tparam DIM0 first dimension of the rejected configuration
 * @tparam DIM1 second dimension of the rejected configuration (not inspected)
 */
template <int DIM0, int DIM1>
struct unsupported_config {
    static_assert(!(DIM0 != -1),
                  "Unsupported transpose configuration - check template parameters");
};
|
|
|
|
|
|
|
|
template <int IN_DIM_0, int IN_DIM_1> |
|
|
void transpose_fp8(__FP8_TYPE *out, const __FP8_TYPE *in, hipStream_t stream = 0) { |
|
|
if constexpr (false ) static_assert(true); |
|
|
DISPATCH_TRANSPOSE( 256, 1024, 64, 256, 4); |
|
|
DISPATCH_TRANSPOSE( 256, 6144, 64, 256, 4); |
|
|
DISPATCH_TRANSPOSE( 256, 7168, 64, 256, 8); |
|
|
DISPATCH_TRANSPOSE( 512, 1024, 64, 512, 4); |
|
|
DISPATCH_TRANSPOSE( 512, 4096, 64, 256, 4); |
|
|
DISPATCH_TRANSPOSE( 512, 6144, 64, 512, 4); |
|
|
DISPATCH_TRANSPOSE( 1536, 1024, 64, 1024, 4); |
|
|
DISPATCH_TRANSPOSE( 1536, 3072, 64, 512, 4); |
|
|
DISPATCH_TRANSPOSE( 1536, 6144, 128, 1024, 8); |
|
|
DISPATCH_TRANSPOSE( 2048, 1024, 64, 1024, 4); |
|
|
DISPATCH_TRANSPOSE( 2048, 6144, 128, 512, 8); |
|
|
DISPATCH_TRANSPOSE( 2048, 7168, 128, 512, 8); |
|
|
DISPATCH_TRANSPOSE( 2304, 1024, 64, 1024, 4); |
|
|
DISPATCH_TRANSPOSE( 2304, 6144, 128, 512, 8); |
|
|
DISPATCH_TRANSPOSE( 2304, 7168, 128, 512, 8); |
|
|
DISPATCH_TRANSPOSE( 7168, 512, 64, 512, 4); |
|
|
DISPATCH_TRANSPOSE( 7168, 576, 64, 512, 4); |
|
|
DISPATCH_TRANSPOSE( 7168, 1024, 64, 256, 4); |
|
|
DISPATCH_TRANSPOSE( 7168, 1536, 128, 1024, 8); |
|
|
DISPATCH_TRANSPOSE( 7168, 4608, 128, 512, 8); |
|
|
DISPATCH_TRANSPOSE( 7168, 6144, 128, 256, 8); |
|
|
else static_assert(false); |
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// When PARAMETERIZE_LIBRARY is not defined, provide a do-nothing entry point.
// NOTE(review): presumably this lets the file be built on its own as an
// executable -- confirm against the build scripts.
#ifndef PARAMETERIZE_LIBRARY
int main() {
    return 0;
}
#endif  // PARAMETERIZE_LIBRARY