cuda : fix KQ mask offset integer overflow in fattn MMA kernel (llama/23610)

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
This commit is contained in:
fairydreaming 2026-05-28 10:55:42 +02:00 committed by Georgi Gerganov
parent 8e40325876
commit 60e420ff6a
1 changed files with 4 additions and 4 deletions

View File

@ -472,7 +472,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
const int i = 8 * (threadIdx.x % (nbatch_fa/8));
cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + int64_t(j_vram)*stride_mask + i);
}
} else if constexpr (oob_check) {
#pragma unroll
@ -488,7 +488,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
const int i = i0 + threadIdx.x;
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[int64_t(j_vram)*stride_mask + i] : half(0.0f);
}
}
} else if constexpr (nbatch_fa < 2*warp_size) {
@ -505,7 +505,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
const int i = threadIdx.x % (warp_size/cols_per_warp);
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + int64_t(j_vram)*stride_mask + 2*i);
}
} else {
#pragma unroll
@ -521,7 +521,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
const int i = i0 + 2*threadIdx.x;
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + int64_t(j_vram)*stride_mask + i);
}
}
}