ggml: support concat for scalar types at cuda backend (llama/24011)

* cuda: support concat for scalar types

* Update concat.cu

* fix metal ci issue
This commit is contained in:
ZihaoMu 2026-06-12 14:32:44 +08:00 committed by Georgi Gerganov
parent 2dcfd49d59
commit 882736f886
3 changed files with 101 additions and 62 deletions

View File

@ -1,16 +1,18 @@
#include "concat.cuh"
#include <stdint.h>
// contiguous kernels
template <int dim>
static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont(const float * x,
const float * y,
float * dst,
int64_t ne00,
int64_t ne01,
int64_t ne02,
int64_t ne0,
int64_t ne1,
int64_t ne2) {
template <typename T, int dim>
static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_cont(const T * x,
const T * y,
T * dst,
int64_t ne00,
int64_t ne01,
int64_t ne02,
int64_t ne0,
int64_t ne1,
int64_t ne2) {
static_assert(dim >= 0 && dim <= 2, "dim must be in [0, 2]");
const int64_t n = ne0 * ne1 * ne2;
@ -50,37 +52,37 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont
}
}
static void concat_f32_cuda(const float * x,
const float * y,
float * dst,
int64_t ne00,
int64_t ne01,
int64_t ne02,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int dim,
cudaStream_t stream) {
template <typename T>
static void concat_cont_cuda(const T * x,
const T * y,
T * dst,
int64_t ne00,
int64_t ne01,
int64_t ne02,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int dim,
cudaStream_t stream) {
const int64_t n = ne0 * ne1 * ne2;
const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
if (dim == 0) {
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
ggml_cuda_kernel_launch(concat_cont<T, 0>, launch_params, x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
return;
}
if (dim == 1) {
concat_f32_cont<1>
<<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
concat_cont<T, 1><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
return;
}
concat_f32_cont<2><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
concat_cont<T, 2><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
}
// non-contiguous kernel (slow)
template <int dim>
template <typename T, int dim>
static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
concat_f32_non_cont(
concat_non_cont(
const char * src0,
const char * src1,
char * dst,
@ -107,61 +109,49 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
uint64_t nb0,
uint64_t nb1,
uint64_t nb2,
uint64_t nb3){
uint64_t nb3) {
static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]");
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
const float * x;
const T * x;
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
x = (const T *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
} else {
if constexpr (dim == 0) {
x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10);
x = (const T *)(src1 + i3*nb13 + i2*nb12 + i1*nb11 + (i0 - ne00)*nb10);
} else if constexpr (dim == 1) {
x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10);
x = (const T *)(src1 + i3*nb13 + i2*nb12 + (i1 - ne01)*nb11 + i0*nb10);
} else if constexpr (dim == 2) {
x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10);
x = (const T *)(src1 + i3*nb13 + (i2 - ne02)*nb12 + i1*nb11 + i0*nb10);
} else if constexpr (dim == 3) {
x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10);
x = (const T *)(src1 + (i3 - ne03)*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
}
}
float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
T * y = (T *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*y = *x;
}
}
void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
cudaStream_t stream = ctx.stream();
const int32_t dim = ((int32_t *) dst->op_params)[0];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
template <typename T>
static void concat_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, int dim, cudaStream_t stream) {
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
const T * src0_d = (const T *) src0->data;
const T * src1_d = (const T *) src1->data;
T * dst_d = (T *) dst->data;
if (dim != 3) {
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_cuda(
src0_d + i3 * (src0->nb[3] / 4),
src1_d + i3 * (src1->nb[3] / 4),
dst_d + i3 * ( dst->nb[3] / 4),
for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) {
concat_cont_cuda(
src0_d + i3*(src0->nb[3] / sizeof(T)),
src1_d + i3*(src1->nb[3] / sizeof(T)),
dst_d + i3*( dst->nb[3] / sizeof(T)),
src0->ne[0], src0->ne[1], src0->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
}
@ -169,13 +159,13 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const size_t size0 = ggml_nbytes(src0);
const size_t size1 = ggml_nbytes(src1);
CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync((char *) dst->data, src0->data, size0, cudaMemcpyDeviceToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync((char *) dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream));
}
} else {
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
auto launch_kernel = [&](auto dim) {
concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
concat_non_cont<T, dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
(const char *) src0->data, (const char *) src1->data, (char *) dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
@ -203,3 +193,35 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
}
}
}
void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
cudaStream_t stream = ctx.stream();
const int32_t dim = ((int32_t *) dst->op_params)[0];
GGML_ASSERT(src0->type == src1->type);
GGML_ASSERT(dst->type == src0->type);
GGML_ASSERT(!ggml_is_quantized(src0->type));
GGML_ASSERT(ggml_blck_size(src0->type) == 1);
switch (ggml_type_size(src0->type)) {
case 1:
concat_cuda<uint8_t>(src0, src1, dst, dim, stream);
break;
case 2:
concat_cuda<uint16_t>(src0, src1, dst, dim, stream);
break;
case 4:
concat_cuda<uint32_t>(src0, src1, dst, dim, stream);
break;
case 8:
concat_cuda<uint64_t>(src0, src1, dst, dim, stream);
break;
default:
GGML_ABORT("Unsupported type size: %zu", ggml_type_size(src0->type));
break;
}
}

View File

@ -5345,7 +5345,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CONCAT:
{
ggml_type src0_type = op->src[0]->type;
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
ggml_type src1_type = op->src[1]->type;
return src0_type == src1_type &&
src0_type == op->type &&
!ggml_is_quantized(src0_type) &&
ggml_blck_size(src0_type) == 1 &&
(ggml_type_size(src0_type) == 1 ||
ggml_type_size(src0_type) == 2 ||
ggml_type_size(src0_type) == 4 ||
ggml_type_size(src0_type) == 8);
} break;
case GGML_OP_CONV_TRANSPOSE_1D:
{

View File

@ -1120,8 +1120,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE:
case GGML_OP_CONCAT:
return true;
case GGML_OP_CONCAT:
{
// kernel_concat copies one float-sized value per element.
// Other scalar types need a type-generic copy kernel first.
const enum ggml_type src0_type = op->src[0]->type;
const enum ggml_type src1_type = op->src[1]->type;
return src0_type == src1_type &&
src0_type == op->type &&
(src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_I32);
}
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL: