From 41a95b8ba798cece458644337c08f82772a8942d Mon Sep 17 00:00:00 2001 From: Aadeshveer Singh Date: Wed, 17 Dec 2025 09:17:01 +0530 Subject: [PATCH] ggml : use WARP_SIZE/2 for argmax reduction offset (llama/18092) --- ggml/src/ggml-cuda/argmax.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu index 5340eedc..51967c66 100644 --- a/ggml/src/ggml-cuda/argmax.cu +++ b/ggml/src/ggml-cuda/argmax.cu @@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest } #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { + for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) { const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); if (val > maxval) { @@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest argmax = shared_argmax[lane_id]; } #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { + for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) { const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); if (val > maxval) {