HIP: RDNA3 mma FA, faster AMD transpose, tune AMD (llama/22880)
Adds RDNA3 support to the CUDA mma FA kernel. To make the RDNA3 tensor cores work with the FP16 accumulation for VKQ the tiles they need to be 32 logical units long in direction of the attention head; for head sizes 80 and 112 that are not exactly divided by 32 the regular length of 16 with FP32 accumulation is used instead. The longer tiles also enable more efficient transposition for a warp size of 32 which is why it's also used for RDNA4. However, this scrambles the data layout of the accumulators along the attention head dimension. To prevent accidental misuse I added another entry to ggml_cuda_mma::data_layout. I also tuned the kernel parameters for RDNA3, RDNA4, and CDNA1 in general, during which I discovered that the kernel can be made to work for head sizes up to 256 for CDNA. For RDNA3/4 I was not able to get better performance that the tile kernel for head sizes > 128.
This commit is contained in:
parent
13133ab299
commit
e62d5893f4
|
|
@ -125,61 +125,107 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
|||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 64, 2, 32, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 64, 2, 32, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 64, 2, 32, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 64, 2, 32, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 64, 2, 32, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 64, 2, 32, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 1, true);
|
||||
|
||||
// TODO tune specifically for RDNA
|
||||
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 64, 2, 32, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 64, 2, 32, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 2, 32, 96, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 2, 32, 96, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 64, 96, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 64, 96, 64, 64, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 2, 32, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 2, 32, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 160, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 32, 160, 128, 128, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 128, 3, 64, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 128, 3, 64, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 128, 2, 32, 128, 128, 128, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 128, 3, 64, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 128, 3, 64, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 128, 2, 32, 160, 128, 128, 1, true);
|
||||
|
||||
return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
|
||||
// Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async).
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 1, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 4, 64, 32, 32, 32, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 256, 1, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 256, 1, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 256, 1, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 512, 1, 64, 64, 64, 64, 1, true);
|
||||
|
||||
// Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy
|
||||
// compile-time static_asserts even though the kernel guard prevents runtime execution.
|
||||
// nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility.
|
||||
return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 256, 1, 64, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 256, 1, 64, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 256, 1, 64, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 512, 1, 64, 128, 128, 64, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 256, 1, 64, 160, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 64, 160, 128, 128, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 256, 1, 64, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 256, 1, 64, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 256, 1, 64, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 64, 128, 128, 128, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 256, 1, 64, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 256, 1, 64, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 256, 1, 64, 160, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 64, 160, 128, 128, 1, true);
|
||||
|
||||
return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
|
||||
}
|
||||
|
||||
static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
|
||||
|
|
@ -510,7 +556,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
const int jt,
|
||||
const int kb0,
|
||||
const int k_VKQ_sup) {
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
|
|
@ -712,6 +758,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
#pragma unroll
|
||||
for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) {
|
||||
const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J;
|
||||
|
||||
// The mask is stored as 16 bit half values, loading them as 32 bit half2 values is preferred in terms of speed.
|
||||
// However, this is not possible for RDNA3 where 2 consecutive l indices are not consecutive in the mask memory layout.
|
||||
#ifdef RDNA3
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
const int i = i0 + T_C_KQ::get_j(l);
|
||||
const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l)) / ncols2;
|
||||
|
||||
KQ_C[i00/(np*T_C_KQ::J)].x[l] += __half2float(tile_mask[j*(nbatch_fa + 8) + i]);
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) {
|
||||
const int i = (i0 + T_C_KQ::get_j(l0)) / 2;
|
||||
|
|
@ -721,6 +779,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x;
|
||||
KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y;
|
||||
}
|
||||
#endif // RDNA3
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -827,13 +886,23 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
const half2 KQ_max_scale_h2 = make_half2(
|
||||
KQ_max_scale[0], KQ_max_scale[0]);
|
||||
if constexpr (std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>) {
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
||||
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
||||
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
||||
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
||||
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<decltype(T_C_VKQ::x), float[T_C_VKQ::ne]>, "bad VKQ type");
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV/T_C_VKQ::J; ++i) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
||||
VKQ_C[i].x[l] *= KQ_max_scale[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#else // Volta
|
||||
|
|
@ -901,9 +970,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
|
||||
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += T_A_VKQ::I) {
|
||||
static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size");
|
||||
#pragma unroll
|
||||
for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) {
|
||||
|
|
@ -912,15 +980,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
|
||||
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
if constexpr (T_B_KQ::I == 8) {
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
||||
mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]);
|
||||
} else {
|
||||
// Wide version of VKQ_C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
// AMD matrix C is column-major.
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
||||
mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
|
||||
mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], B[k00/(np*T_A_VKQ::J)], A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
|
|
@ -953,11 +1021,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
tile_Q, tile_K, tile_V, tile_mask,
|
||||
Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
template<int ncols> struct mma_tile_sizes {
|
||||
template<int DV, int ncols> struct mma_tile_sizes {
|
||||
using T_A_KQ = tile<16, 8, half2>; // row-major
|
||||
using T_B_KQ = tile<16, 8, half2>; // column-major
|
||||
using T_C_KQ = tile<16, 16, float>; // column-major
|
||||
|
|
@ -965,7 +1033,7 @@ template<int ncols> struct mma_tile_sizes {
|
|||
using T_B_VKQ = tile<16, 8, half2>; // column-major
|
||||
using T_C_VKQ = tile<16, 8, half2>; // column-major
|
||||
};
|
||||
template<> struct mma_tile_sizes<8> {
|
||||
template<int DV> struct mma_tile_sizes<DV, 8> {
|
||||
using T_A_KQ = tile<16, 8, half2>; // row-major
|
||||
using T_B_KQ = tile< 8, 8, half2>; // column-major
|
||||
using T_C_KQ = tile<16, 8, float>; // row-major
|
||||
|
|
@ -973,8 +1041,60 @@ template<> struct mma_tile_sizes<8> {
|
|||
using T_B_VKQ = tile< 8, 8, half2>; // column-major
|
||||
using T_C_VKQ = tile<16, 4, half2>; // row-major
|
||||
};
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
template<int ncols> struct mma_tile_sizes {
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#ifdef RDNA3
|
||||
template<int DV, int ncols> struct mma_tile_sizes {
|
||||
using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
||||
using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
|
||||
using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
||||
using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
||||
using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
|
||||
using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR>; // column-major
|
||||
};
|
||||
template<int ncols> struct mma_tile_sizes<80, ncols> {
|
||||
using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
||||
using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
|
||||
using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
||||
using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
||||
using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
|
||||
using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
||||
};
|
||||
template<int ncols> struct mma_tile_sizes<112, ncols> {
|
||||
using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
||||
using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
|
||||
using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
||||
using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
||||
using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
|
||||
using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
||||
};
|
||||
#else
|
||||
template<int DV, int ncols> struct mma_tile_sizes {
|
||||
using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major
|
||||
using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major
|
||||
using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
||||
using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major
|
||||
using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major
|
||||
using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED>; // column-major
|
||||
};
|
||||
template<int ncols> struct mma_tile_sizes<80, ncols> {
|
||||
using T_A_KQ = tile<16, 8, half2>; // row-major
|
||||
using T_B_KQ = tile<16, 8, half2>; // column-major
|
||||
using T_C_KQ = tile<16, 16, float>; // column-major
|
||||
using T_A_VKQ = tile<16, 8, half2>; // row-major
|
||||
using T_B_VKQ = tile<16, 8, half2>; // column-major
|
||||
using T_C_VKQ = tile<16, 8, half2>; // column-major
|
||||
};
|
||||
template<int ncols> struct mma_tile_sizes<112, ncols> {
|
||||
using T_A_KQ = tile<16, 8, half2>; // row-major
|
||||
using T_B_KQ = tile<16, 8, half2>; // column-major
|
||||
using T_C_KQ = tile<16, 16, float>; // column-major
|
||||
using T_A_VKQ = tile<16, 8, half2>; // row-major
|
||||
using T_B_VKQ = tile<16, 8, half2>; // column-major
|
||||
using T_C_VKQ = tile<16, 8, half2>; // column-major
|
||||
};
|
||||
#endif // RDNA3
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
template<int DV, int ncols> struct mma_tile_sizes {
|
||||
using T_A_KQ = tile<16, 8, half2>; // row-major
|
||||
using T_B_KQ = tile<16, 8, half2>; // column-major
|
||||
using T_C_KQ = tile<16, 16, float>; // column-major
|
||||
|
|
@ -983,7 +1103,7 @@ template<int ncols> struct mma_tile_sizes {
|
|||
using T_C_VKQ = tile<16, 8, half2>; // column-major
|
||||
};
|
||||
#else // Volta
|
||||
template<int ncols> struct mma_tile_sizes {
|
||||
template<int DV, int ncols> struct mma_tile_sizes {
|
||||
using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
||||
using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
|
||||
using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
||||
|
|
@ -1018,17 +1138,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
const int zt_gqa,
|
||||
const int kb0_start,
|
||||
const int kb0_stop) {
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
|
||||
using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
|
||||
using T_C_KQ = typename mma_tile_sizes<ncols>::T_C_KQ;
|
||||
using T_A_VKQ = typename mma_tile_sizes<ncols>::T_A_VKQ;
|
||||
using T_B_VKQ = typename mma_tile_sizes<ncols>::T_B_VKQ;
|
||||
using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
|
||||
using T_A_KQ = typename mma_tile_sizes<DV, ncols>::T_A_KQ;
|
||||
using T_B_KQ = typename mma_tile_sizes<DV, ncols>::T_B_KQ;
|
||||
using T_C_KQ = typename mma_tile_sizes<DV, ncols>::T_C_KQ;
|
||||
using T_A_VKQ = typename mma_tile_sizes<DV, ncols>::T_A_VKQ;
|
||||
using T_B_VKQ = typename mma_tile_sizes<DV, ncols>::T_B_VKQ;
|
||||
using T_C_VKQ = typename mma_tile_sizes<DV, ncols>::T_C_VKQ;
|
||||
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = get_cols_per_thread();
|
||||
|
|
@ -1061,6 +1181,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
|
||||
#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3)
|
||||
T_C_VKQ VKQ_C[DV % 32 != 0 ? DV/T_C_VKQ::J : DV/(2*T_C_VKQ::J)];
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
|
||||
#else // Volta
|
||||
|
|
@ -1269,12 +1391,23 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
|
||||
if constexpr (std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>) {
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
||||
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
||||
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
||||
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
||||
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<decltype(T_C_VKQ::x), float[T_C_VKQ::ne]>, "bad VKQ type");
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV/T_C_VKQ::J; ++i) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
||||
VKQ_C[i].x[l] *= KQ_max_scale[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#else // Volta
|
||||
|
|
@ -1433,6 +1566,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
#pragma unroll
|
||||
for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
|
||||
if constexpr (cols_per_warp == 8) {
|
||||
static_assert(std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>, "bad VKQ type");
|
||||
const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) {
|
||||
|
|
@ -1447,14 +1581,45 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
}
|
||||
} else {
|
||||
const int j0 = threadIdx.y*cols_per_warp;
|
||||
if constexpr (std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>) {
|
||||
if constexpr (T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR) {
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
|
||||
for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
||||
const int j = j0 + T_C_VKQ::get_i(l);
|
||||
const int k = k1 + T_C_VKQ::get_j(l);
|
||||
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
||||
const int j = j0 + T_C_VKQ::get_i(l);
|
||||
const int k = k1 + T_C_VKQ::get_j(l);
|
||||
|
||||
tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
|
||||
tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR_SCRAMBLED, "bad T_C_VKQ data layout");
|
||||
using T_C_VKQ_us = tile<T_C_VKQ::I, T_C_VKQ::J, half2, DATA_LAYOUT_I_MAJOR>; // us == unscrambled
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
|
||||
const T_C_VKQ_us VKQ_C_us = unscramble(VKQ_C[(k00 + k1)/T_C_VKQ::J]);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_VKQ_us::ne; ++l) {
|
||||
const int j = j0 + T_C_VKQ_us::get_i(l);
|
||||
const int k = k1 + T_C_VKQ_us::get_j(l);
|
||||
|
||||
tile_Q[j*tile_stride + k] = VKQ_C_us.x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<decltype(T_C_VKQ::x), float[T_C_VKQ::ne]>, "bad VKQ type");
|
||||
half * tile_Q_h = (half *) tile_Q;
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J/2) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
||||
const int j = j0 + T_C_VKQ::get_i(l);
|
||||
const int k = 2*k1 + T_C_VKQ::get_j(l);
|
||||
|
||||
tile_Q_h[j*(2*tile_stride) + k] = VKQ_C[(k00 + k1)/(T_C_VKQ::J/2)].x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1532,7 +1697,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
|
||||
jt, kb0_start, kb0_stop);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
|
||||
|
|
@ -1559,7 +1724,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE))
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) {
|
||||
|
|
@ -1585,14 +1750,14 @@ static __global__ void flash_attn_ext_f16(
|
|||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
|
||||
if (ncols1*ncols2 < 16 || ncols2 == 1 || DKQ > 128) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) {
|
||||
if (ncols1*ncols2 < 16 || DKQ > 256) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
|
@ -1715,7 +1880,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE))
|
||||
}
|
||||
|
||||
template <int DKQ, int DV, int ncols1, int ncols2>
|
||||
|
|
|
|||
|
|
@ -19,13 +19,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
|
|||
}
|
||||
|
||||
if constexpr (ncols2 <= 16) {
|
||||
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
|
||||
if (Q->ne[1] <= 16/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
|
||||
if (Q->ne[1] <= 32/ncols2 || (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) ||
|
||||
(GGML_CUDA_CC_IS_AMD(cc) && DKQ > 256)) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
|
@ -477,12 +478,13 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
const int ncols2_max = Q->ne[0] == 320 ? 32 : ((Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8);
|
||||
int gqa_ratio_eff = 1;
|
||||
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
|
||||
gqa_ratio_eff *= 2;
|
||||
}
|
||||
|
||||
if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
int gqa_ratio_eff = 1;
|
||||
const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8;
|
||||
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
|
||||
gqa_ratio_eff *= 2;
|
||||
}
|
||||
if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
|
|
@ -500,41 +502,22 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
return BEST_FATTN_KERNEL_WMMA_F16;
|
||||
}
|
||||
|
||||
if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
if (can_use_vector_kernel) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (Q->ne[1] == 1) {
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
}
|
||||
int gqa_ratio_eff = 1;
|
||||
const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
|
||||
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
|
||||
gqa_ratio_eff *= 2;
|
||||
}
|
||||
if (Q->ne[1] * gqa_ratio_eff <= 8) {
|
||||
return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized.
|
||||
}
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
// Use MFMA flash attention for CDNA (MI100+):
|
||||
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
|
||||
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
|
||||
// MMA vs tile crossover benchmarked on MI300X @ d32768:
|
||||
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)
|
||||
// hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%)
|
||||
if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) {
|
||||
// AMD MFMA needs a certain minimum batch size to outscale the tile kernel for large head sizes.
|
||||
if ((amd_mfma_available(cc) && Q->ne[0] <= 256) && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
if ((Q->ne[0] <= 64 && Q->ne[1] * gqa_ratio_eff > 8)) {
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
// Fall through to tile kernel for small effective batch sizes.
|
||||
if ((Q->ne[0] <= 128 && Q->ne[1] * gqa_ratio_eff > 16)) {
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
if ((Q->ne[0] <= 256 && Q->ne[1] * gqa_ratio_eff > 64)) {
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
}
|
||||
|
||||
// AMD WMMA is always faster than the tile kernel if the full tile width of 16 can be utilized.
|
||||
if ((amd_wmma_available(cc) && gqa_opt_applies && Q->ne[0] <= 128) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[1] * gqa_ratio_eff > 8) {
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
// If there are no tensor cores available, use the generic tile kernel:
|
||||
|
|
|
|||
|
|
@ -80,6 +80,7 @@ namespace ggml_cuda_mma {
|
|||
DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
|
||||
DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
|
||||
DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
|
||||
DATA_LAYOUT_I_MAJOR_SCRAMBLED = 40, // Scrambled matrix C for faster transposition (RDNA4/CDNA), convert to float to unscramble.
|
||||
};
|
||||
// Implemented mma combinations are:
|
||||
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
|
||||
|
|
@ -312,13 +313,19 @@ namespace ggml_cuda_mma {
|
|||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 8) return true;
|
||||
if (I == 16 && J == 8) return true;
|
||||
if (I == 16 && J == 16) return true;
|
||||
if (I == 32 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return threadIdx.x % 16;
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return threadIdx.x % 16;
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return (threadIdx.x % 16) * 2 + l / (ne/2);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
|
|
@ -327,7 +334,15 @@ namespace ggml_cuda_mma {
|
|||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return ne * (threadIdx.x / 16) + l;
|
||||
return (threadIdx.x / 16) * ne + l;
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
#ifdef RDNA3
|
||||
return l*2 + (threadIdx.x / 16);
|
||||
#else
|
||||
return (threadIdx.x / 16) * ne + l;
|
||||
#endif // RDNA3
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return (threadIdx.x / 16) * (ne/2) + l % (ne/2);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
|
|
@ -338,13 +353,19 @@ namespace ggml_cuda_mma {
|
|||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 8) return true;
|
||||
if (I == 16 && J == 8) return true;
|
||||
if (I == 16 && J == 16) return true;
|
||||
if (I == 32 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return threadIdx.x % 16;
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return threadIdx.x % 16;
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return (threadIdx.x % 16) * 2 + l / (ne/2);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
|
|
@ -353,7 +374,11 @@ namespace ggml_cuda_mma {
|
|||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return ne * (threadIdx.x / 16) + l;
|
||||
return (threadIdx.x / 16) * ne + l;
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return (threadIdx.x / 16) * ne + l;
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return (threadIdx.x / 16) * (ne/2) + l % (ne/2);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
|
|
@ -516,12 +541,15 @@ namespace ggml_cuda_mma {
|
|||
if (I == 16 && J == 16) return true;
|
||||
if (I == 16 && J == 8) return true;
|
||||
if (I == 16 && J == 4) return true;
|
||||
if (I == 32 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int /*l*/) {
|
||||
if constexpr (supported()) {
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16) {
|
||||
return threadIdx.x % 16;
|
||||
} else if constexpr (I == 32) {
|
||||
return (threadIdx.x % 16) * 2 + l / (ne/2);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
|
|
@ -529,8 +557,10 @@ namespace ggml_cuda_mma {
|
|||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (supported()) {
|
||||
if constexpr (I == 16) {
|
||||
return l;
|
||||
} else if constexpr (I == 32) {
|
||||
return l % (ne/2);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
|
|
@ -644,6 +674,40 @@ namespace ggml_cuda_mma {
|
|||
}
|
||||
};
|
||||
|
||||
template <int I_, int J_>
|
||||
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_SCRAMBLED;
|
||||
|
||||
static constexpr int ne = I * J / ggml_cuda_get_physical_warp_size();
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 16) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
|
||||
}
|
||||
};
|
||||
|
||||
static __device__ __forceinline__ tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> unscramble(const tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> & t) {
|
||||
#if defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> ret;
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < t.ne/2; ++l0) {
|
||||
ret.x[2*l0 + 0] = __lows2half2(t.x[l0], t.x[l0 + t.ne/2]);
|
||||
ret.x[2*l0 + 1] = __highs2half2(t.x[l0], t.x[l0 + t.ne/2]);
|
||||
}
|
||||
return ret;
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
GGML_UNUSED(t);
|
||||
#endif // defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
}
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
template <int I, int J>
|
||||
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
||||
|
|
@ -660,6 +724,21 @@ namespace ggml_cuda_mma {
|
|||
ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
|
||||
ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
|
||||
|
||||
return ret;
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3)
|
||||
static __device__ __forceinline__ tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> get_half2(
|
||||
const tile<16, 16, float, DATA_LAYOUT_I_MAJOR> & tile_float) {
|
||||
tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> ret;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_float.ne; ++l) {
|
||||
float tmp[2];
|
||||
int i = threadIdx.x / 16;
|
||||
tmp[i] = tile_float.x[l];
|
||||
i ^= 1;
|
||||
tmp[i] = __shfl_xor_sync(0xFFFFFFFF, tile_float.x[l], 16, WARP_SIZE);
|
||||
ret.x[l] = make_half2(tmp[0], tmp[1]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
|
|
@ -802,21 +881,35 @@ namespace ggml_cuda_mma {
|
|||
#endif // defined(VOLTA_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <int I, typename T, data_layout dl>
|
||||
static __device__ __forceinline__ void load_ldmatrix_trans(
|
||||
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
tile<I, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#ifdef TURING_MMA_AVAILABLE
|
||||
static_assert(I == 16, "bad tile width");
|
||||
static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout");
|
||||
int * xi = (int *) t.x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
|
||||
: "l"(xs));
|
||||
#elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
half * xh = (half *) t.x;
|
||||
static_assert(dl == DATA_LAYOUT_I_MAJOR || dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout");
|
||||
if constexpr (I == 32) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < t.ne; ++l) {
|
||||
xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)];
|
||||
xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)];
|
||||
for (int l0 = 0; l0 < t.ne/2; ++l0) {
|
||||
const half2 tmp0 = xs0[(2*t.get_j(l0) + 0)*stride + t.get_i(l0)/2];
|
||||
const half2 tmp1 = xs0[(2*t.get_j(l0) + 1)*stride + t.get_i(l0)/2];
|
||||
|
||||
t.x[l0] = __lows2half2(tmp0, tmp1);
|
||||
t.x[l0 + t.ne/2] = __highs2half2(tmp0, tmp1);
|
||||
}
|
||||
} else {
|
||||
half * xh = (half *) t.x;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < t.ne; ++l) {
|
||||
xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)];
|
||||
xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)];
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(t, xs0, stride);
|
||||
|
|
@ -972,6 +1065,20 @@ namespace ggml_cuda_mma {
|
|||
#endif // TURING_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> & D, const tile<32, 8, half2, DATA_LAYOUT_I_MAJOR> & A,
|
||||
const tile<16, 8, half2, DATA_LAYOUT_I_MAJOR> & B) {
|
||||
#if defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
tile<16, 8, half2> * D16 = (tile<16, 8, half2> *) &D;
|
||||
const tile<16, 8, half2> * A16 = (const tile<16, 8, half2> *) &A;
|
||||
mma(D16[0], A16[0], B);
|
||||
mma(D16[1], A16[1], B);
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)
|
||||
}
|
||||
|
||||
template <data_layout dl_ab, data_layout dl_d>
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
|
||||
|
|
@ -1296,6 +1403,22 @@ namespace ggml_cuda_mma {
|
|||
#endif // defined(VOLTA_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> & D, const tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & A,
|
||||
const tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
|
||||
#if defined(AMD_WMMA_AVAILABLE) && defined(RDNA3)
|
||||
using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
|
||||
halfx16_t * xD = (halfx16_t *) D.x;
|
||||
const halfx16_t * xA = (const halfx16_t *) A.x;
|
||||
const halfx16_t * xB = (const halfx16_t *) B.x;
|
||||
xD[0] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(xA[0], xB[0], xD[0], /*opsel =*/ 0);
|
||||
xD[0] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(xA[1], xB[0], xD[0], /*opsel =*/ 1);
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // TURING_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <data_layout dl_d, data_layout dl_ab>
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue