From 75b9543856158561584afe59772713ded1e82e95 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Mon, 30 Mar 2026 16:20:00 +0200 Subject: [PATCH] CUDA : Fix CUB's argsort when nrows % block_size == 0 CCCL < 3.1 (llama/21181) * CUDA: Fix CUB's argsort when nrows % block_size == 0 CCCL < 3.1 We wrongly calculated offset_grid as `ceildiv(nrows, block_size)`, while it must be `ceildiv(nrows + 1, block_size)`. As a consequence, we had uninitialized values in `offset_iterator[nrows]` for the case when `nrows % block_size == 0`. Fixes #21162 * Reduce nrows in test case to 256, don't need 768 --- ggml/src/ggml-cuda/argsort.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 4896669c..38fdf367 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -47,9 +47,11 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, #ifdef STRIDED_ITERATOR_AVAILABLE auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols); #else - ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); + // offset_iterator needs to populate nrows + 1 elements, so we also have to ceildiv nrows + 1 by block_size + const int nrows_offset = nrows + 1; + ggml_cuda_pool_alloc offsets_alloc(pool, nrows_offset); int * offset_iterator = offsets_alloc.get(); - const dim3 offset_grid((nrows + block_size - 1) / block_size); + const dim3 offset_grid((nrows_offset + block_size - 1) / block_size); init_offsets<<>>(offset_iterator, ncols, nrows); #endif CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));