From 60e420ff6ac28ae5bc5af42b4a77bc98dca760e6 Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Thu, 28 May 2026 10:55:42 +0200 Subject: [PATCH] cuda : fix KQ mask offset integer overflow in fattn MMA kernel (llama/23610) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Stanisław Szymczyk --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 4871b90df..3c8b6eaaf 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -472,7 +472,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const int i = 8 * (threadIdx.x % (nbatch_fa/8)); - cp_async_cg_16(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i); + cp_async_cg_16(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + int64_t(j_vram)*stride_mask + i); } } else if constexpr (oob_check) { #pragma unroll @@ -488,7 +488,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) { const int i = i0 + threadIdx.x; - tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); + tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[int64_t(j_vram)*stride_mask + i] : half(0.0f); } } } else if constexpr (nbatch_fa < 2*warp_size) { @@ -505,7 +505,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const int i = threadIdx.x % (warp_size/cols_per_warp); - ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + int64_t(j_vram)*stride_mask + 2*i); } } else { #pragma unroll @@ -521,7 +521,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) { const int i = i0 + 2*threadIdx.x; - ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + int64_t(j_vram)*stride_mask + i); } } }