ggml-cuda : add ar_add() to avoid ambiguous operator+ for half/bfloat16 in CUDA 11.8

This commit is contained in:
Daniel Bevenius 2026-05-12 08:30:00 +02:00
parent a2839b4404
commit 5cd228494a
1 changed files with 17 additions and 3 deletions

View File

@ -105,6 +105,20 @@ static constexpr int GGML_CUDA_AR_KERNEL_BLOCKS = 8;
// blocks. Tail elements (the leftover < ELEMS_PER_VEC at the end) are
// handled only by block 0 to avoid cross-block writes to the same slots.
// ---------------------------------------------------------------------------
// CUDA 11.8 does not expose operator+ for half/bfloat16 below sm_530,
// so use the explicit intrinsics to avoid ambiguous implicit conversions.
template<typename T>
static __device__ inline T ar_add(T a, T b) {
if constexpr (std::is_same_v<T, half>) {
return __hadd(a, b);
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b));
} else {
return a + b;
}
}
template <typename T_dst, typename T_wire>
static __global__ void ggml_cuda_ar_kernel(
const T_dst * sendbuf,
@ -184,13 +198,13 @@ static __global__ void ggml_cuda_ar_kernel(
#pragma unroll
for (int k = 0; k < ELEMS_PER_VEC; ++k) {
const T_wire d_low = ggml_cuda_cast<T_wire>(sendbuf[off + k]);
recvbuf[off + k] = ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(wire[k]);
recvbuf[off + k] = ar_add(ggml_cuda_cast<T_dst>(d_low), ggml_cuda_cast<T_dst>(wire[k]));
}
}
if (bid == 0 && tid < count - tail) {
const T_wire d_low = ggml_cuda_cast<T_wire>(sendbuf[tail + tid]);
recvbuf[tail + tid] =
ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(host_other[tail + tid]);
ar_add(ggml_cuda_cast<T_dst>(d_low), ggml_cuda_cast<T_dst>(host_other[tail + tid]));
}
}
}
@ -210,7 +224,7 @@ static __global__ void ggml_cuda_ar_add_kernel(
const int nt = gridDim.x * blockDim.x;
for (int i = tid; i < count; i += nt) {
const T_src d_low = ggml_cuda_cast<T_src>(dst[i]);
dst[i] = ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(src[i]);
dst[i] = ar_add(ggml_cuda_cast<T_dst>(d_low), ggml_cuda_cast<T_dst>(src[i]));
}
}