diff --git a/ggml/src/ggml-cuda/allreduce.cu b/ggml/src/ggml-cuda/allreduce.cu index 03d88968c..434689abd 100644 --- a/ggml/src/ggml-cuda/allreduce.cu +++ b/ggml/src/ggml-cuda/allreduce.cu @@ -105,20 +105,6 @@ static constexpr int GGML_CUDA_AR_KERNEL_BLOCKS = 8; // blocks. Tail elements (the leftover < ELEMS_PER_VEC at the end) are // handled only by block 0 to avoid cross-block writes to the same slots. // --------------------------------------------------------------------------- - -// CUDA 11.8 does not expose operator+ for half/bfloat16 below sm_530, -// so use the explicit intrinsics to avoid ambiguous implicit conversions. -template -static __device__ inline T ar_add(T a, T b) { - if constexpr (std::is_same_v) { - return __hadd(a, b); - } else if constexpr (std::is_same_v) { - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b)); - } else { - return a + b; - } -} - template static __global__ void ggml_cuda_ar_kernel( const T_dst * sendbuf, @@ -198,13 +184,13 @@ static __global__ void ggml_cuda_ar_kernel( #pragma unroll for (int k = 0; k < ELEMS_PER_VEC; ++k) { const T_wire d_low = ggml_cuda_cast(sendbuf[off + k]); - recvbuf[off + k] = ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(wire[k])); + recvbuf[off + k] = ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k]); } } if (bid == 0 && tid < count - tail) { const T_wire d_low = ggml_cuda_cast(sendbuf[tail + tid]); recvbuf[tail + tid] = - ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(host_other[tail + tid])); + ggml_cuda_cast(d_low) + ggml_cuda_cast(host_other[tail + tid]); } } } @@ -224,7 +210,7 @@ static __global__ void ggml_cuda_ar_add_kernel( const int nt = gridDim.x * blockDim.x; for (int i = tid; i < count; i += nt) { const T_src d_low = ggml_cuda_cast(dst[i]); - dst[i] = ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(src[i])); + dst[i] = ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i]); } }