diff --git a/ggml/src/ggml-cuda/allreduce.cu b/ggml/src/ggml-cuda/allreduce.cu index 434689abd..d56129a22 100644 --- a/ggml/src/ggml-cuda/allreduce.cu +++ b/ggml/src/ggml-cuda/allreduce.cu @@ -184,13 +184,15 @@ static __global__ void ggml_cuda_ar_kernel( #pragma unroll for (int k = 0; k < ELEMS_PER_VEC; ++k) { const T_wire d_low = ggml_cuda_cast(sendbuf[off + k]); - recvbuf[off + k] = ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k]); + recvbuf[off + k] = ggml_cuda_cast( + ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k])); } } if (bid == 0 && tid < count - tail) { const T_wire d_low = ggml_cuda_cast(sendbuf[tail + tid]); - recvbuf[tail + tid] = - ggml_cuda_cast(d_low) + ggml_cuda_cast(host_other[tail + tid]); + recvbuf[tail + tid] = ggml_cuda_cast( + ggml_cuda_cast(d_low) + + ggml_cuda_cast(host_other[tail + tid])); } } } @@ -210,7 +212,8 @@ static __global__ void ggml_cuda_ar_add_kernel( const int nt = gridDim.x * blockDim.x; for (int i = tid; i < count; i += nt) { const T_src d_low = ggml_cuda_cast(dst[i]); - dst[i] = ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i]); + dst[i] = ggml_cuda_cast( + ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i])); } }