diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index adba4d522..8d557092b 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -1,16 +1,18 @@ #include "concat.cuh" +#include + // contiguous kernels -template -static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont(const float * x, - const float * y, - float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne0, - int64_t ne1, - int64_t ne2) { +template +static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_cont(const T * x, + const T * y, + T * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2) { static_assert(dim >= 0 && dim <= 2, "dim must be in [0, 2]"); const int64_t n = ne0 * ne1 * ne2; @@ -50,37 +52,37 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont } } -static void concat_f32_cuda(const float * x, - const float * y, - float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int dim, - cudaStream_t stream) { +template +static void concat_cont_cuda(const T * x, + const T * y, + T * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int dim, + cudaStream_t stream) { const int64_t n = ne0 * ne1 * ne2; const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; if (dim == 0) { const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream); - ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + ggml_cuda_kernel_launch(concat_cont, launch_params, x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } if (dim == 1) { - concat_f32_cont<1> - <<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + concat_cont<<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } - concat_f32_cont<2><<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + concat_cont<<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); } // non-contiguous kernel (slow) -template +template static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) - concat_f32_non_cont( + concat_non_cont( const char * src0, const char * src1, char * dst, @@ -107,61 +109,49 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) uint64_t nb0, uint64_t nb1, uint64_t nb2, - uint64_t nb3){ + uint64_t nb3) { static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]"); const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; const int64_t i1 = blockIdx.x; - const float * x; + const T * x; for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + x = (const T *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); } else { if constexpr (dim == 0) { - x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10); + x = (const T *)(src1 + i3*nb13 + i2*nb12 + i1*nb11 + (i0 - ne00)*nb10); } else if constexpr (dim == 1) { - x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10); + x = (const T *)(src1 + i3*nb13 + i2*nb12 + (i1 - ne01)*nb11 + i0*nb10); } else if constexpr (dim == 2) { - x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10); + x = (const T *)(src1 + i3*nb13 + (i2 - ne02)*nb12 + i1*nb11 + i0*nb10); } else if constexpr (dim == 3) { - x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10); + x = (const T *)(src1 + (i3 - ne03)*nb13 + i2*nb12 + i1*nb11 + i0*nb10); } } - float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + T * y = (T *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); *y = *x; } } - -void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - - cudaStream_t stream = ctx.stream(); - - const int32_t dim = ((int32_t *) dst->op_params)[0]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - +template +static void concat_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, int dim, cudaStream_t stream) { if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - - float * dst_d = (float *)dst->data; + const T * src0_d = (const T *) src0->data; + const T * src1_d = (const T *) src1->data; + T * dst_d = (T *) dst->data; if (dim != 3) { - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_cuda( - src0_d + i3 * (src0->nb[3] / 4), - src1_d + i3 * (src1->nb[3] / 4), - dst_d + i3 * ( dst->nb[3] / 4), + for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) { + concat_cont_cuda( + src0_d + i3*(src0->nb[3] / sizeof(T)), + src1_d + i3*(src1->nb[3] / sizeof(T)), + dst_d + i3*( dst->nb[3] / sizeof(T)), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); } @@ -169,13 +159,13 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const size_t size0 = ggml_nbytes(src0); const size_t size1 = ggml_nbytes(src1); - CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync((char *) dst->data, src0->data, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync((char *) dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream)); } } else { dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); auto launch_kernel = [&](auto dim) { - concat_f32_non_cont<<>>( + concat_non_cont<<>>( (const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], @@ -203,3 +193,35 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } } } + +void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + cudaStream_t stream = ctx.stream(); + + const int32_t dim = ((int32_t *) dst->op_params)[0]; + + GGML_ASSERT(src0->type == src1->type); + GGML_ASSERT(dst->type == src0->type); + GGML_ASSERT(!ggml_is_quantized(src0->type)); + GGML_ASSERT(ggml_blck_size(src0->type) == 1); + + switch (ggml_type_size(src0->type)) { + case 1: + concat_cuda(src0, src1, dst, dim, stream); + break; + case 2: + concat_cuda(src0, src1, dst, dim, stream); + break; + case 4: + concat_cuda(src0, src1, dst, dim, stream); + break; + case 8: + concat_cuda(src0, src1, dst, dim, stream); + break; + default: + GGML_ABORT("Unsupported type size: %zu", ggml_type_size(src0->type)); + break; + } +} diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e779a9be9..61041bdc1 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -5345,7 +5345,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; - return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + ggml_type src1_type = op->src[1]->type; + return src0_type == src1_type && + src0_type == op->type && + !ggml_is_quantized(src0_type) && + ggml_blck_size(src0_type) == 1 && + (ggml_type_size(src0_type) == 1 || + ggml_type_size(src0_type) == 2 || + ggml_type_size(src0_type) == 4 || + ggml_type_size(src0_type) == 8); } break; case GGML_OP_CONV_TRANSPOSE_1D: { diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 05d7f4305..d583bd6ef 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1120,8 +1120,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_VIEW: case GGML_OP_TRANSPOSE: case GGML_OP_PERMUTE: - case GGML_OP_CONCAT: return true; + case GGML_OP_CONCAT: + { + // kernel_concat copies one float-sized value per element. + // Other scalar types need a type-generic copy kernel first. + const enum ggml_type src0_type = op->src[0]->type; + const enum ggml_type src1_type = op->src[1]->type; + return src0_type == src1_type && + src0_type == op->type && + (src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_I32); + } case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: