cuda : fix KQ mask offset integer overflow in fattn MMA kernel (llama/23610)
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
This commit is contained in:
parent
8e40325876
commit
60e420ff6a
|
|
@ -472,7 +472,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|||
|
||||
const int i = 8 * (threadIdx.x % (nbatch_fa/8));
|
||||
|
||||
cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
|
||||
cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + int64_t(j_vram)*stride_mask + i);
|
||||
}
|
||||
} else if constexpr (oob_check) {
|
||||
#pragma unroll
|
||||
|
|
@ -488,7 +488,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|||
for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
|
||||
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[int64_t(j_vram)*stride_mask + i] : half(0.0f);
|
||||
}
|
||||
}
|
||||
} else if constexpr (nbatch_fa < 2*warp_size) {
|
||||
|
|
@ -505,7 +505,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|||
|
||||
const int i = threadIdx.x % (warp_size/cols_per_warp);
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + int64_t(j_vram)*stride_mask + 2*i);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
|
|
@ -521,7 +521,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|||
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
|
||||
const int i = i0 + 2*threadIdx.x;
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + int64_t(j_vram)*stride_mask + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue