CUDA: fuse SSM_CONV + ADD(bias) + SILU (llama/22478)
parent 9f2cec1840
commit aec8e69c2f
ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3556,6 +3556,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
         && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
         const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
         const ggml_tensor * silu = cgraph->nodes[node_idx+1];
+        if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) {
+            return false;
+        }

         if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
             return false;

@@ -3564,6 +3567,31 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
         return true;
     }

+    if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_ADD
+        && ops.begin()[2] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
+        const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
+        const ggml_tensor * add = cgraph->nodes[node_idx+1];
+        const ggml_tensor * silu = cgraph->nodes[node_idx+2];
+        if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) {
+            return false;
+        }
+
+        if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        // ADD must consume ssm_conv's output and broadcast a 1-D channel-wise bias.
+        const ggml_tensor * bias = (add->src[0] == ssm_conv) ? add->src[1] : add->src[0];
+        if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) {
+            return false;
+        }
+        if (ggml_nelements(bias) != ssm_conv->ne[0] || bias->ne[0] != ssm_conv->ne[0]) {
+            return false;
+        }
+
+        return true;
+    }
+
     if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
         && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
         const ggml_tensor * unary = cgraph->nodes[node_idx];
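For orientation: the three nodes this branch matches are exactly what a Mamba-style block emits for its depthwise convolution. A minimal sketch of that subgraph using the public ggml API; the context ctx and the tensors conv_x, conv_w, conv_b are placeholder names, not identifiers from this commit:

    // Builds the GGML_OP_SSM_CONV -> GGML_OP_ADD -> GGML_UNARY_OP_SILU chain
    // that ggml_cuda_can_fuse() recognizes above.
    #include "ggml.h"

    static struct ggml_tensor * ssm_conv_bias_silu(struct ggml_context * ctx,
                                                   struct ggml_tensor * conv_x,   // [d_conv - 1 + n_t, d_inner, n_s], F32
                                                   struct ggml_tensor * conv_w,   // [d_conv, d_inner], F32
                                                   struct ggml_tensor * conv_b) { // [d_inner], F32, contiguous
        struct ggml_tensor * conv   = ggml_ssm_conv(ctx, conv_x, conv_w); // -> [d_inner, n_t, n_s]
        struct ggml_tensor * biased = ggml_add(ctx, conv, conv_b);        // bias broadcast over rows
        return ggml_silu(ctx, biased);                                    // SiLU gate
    }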

@@ -3966,8 +3994,13 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
         return 1;
     }

+    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
+        ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]);
+        return 2;
+    }
+
     if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
-        ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1]);
+        ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, cgraph->nodes[i + 1]);
         return 1;
     }

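The return value is the number of extra graph nodes the fused launch covered, so the caller can skip the ADD and UNARY nodes that the kernel already computed. A hedged sketch of that calling convention, assuming the surrounding evaluation loop looks roughly like this (it is not shown in the commit):

    // Assumed caller shape: skip the nodes a successful fusion consumed.
    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_tensor * node = cgraph->nodes[i];
        const int n_fused = ggml_cuda_try_fuse(cuda_ctx, cgraph, i); // 2 for CONV+ADD+SILU, 1 for CONV+SILU, 0 otherwise
        if (n_fused > 0) {
            i += n_fused; // the fused kernel already wrote the downstream node's output
            continue;
        }
        ggml_cuda_compute_forward(*cuda_ctx, node); // unfused fallback
    }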

ggml/src/ggml-cuda/ssm-conv.cu
@@ -3,6 +3,7 @@

 template <bool apply_silu, size_t split_d_inner, size_t d_conv>
 static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+                                    const float * __restrict__ bias,
                                     const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                     float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
                                     const int64_t n_t) {

@@ -27,6 +28,8 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
         w[j] = w_block[tid * stride_w + j];
     }

+    float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f;
+
     for (int64_t i = 0; i < n_t; i++) {
         float sumf = 0.0f;


@@ -42,12 +45,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
         for (size_t j = 0; j < d_conv; j++) {
             sumf += x[(i + j) % d_conv] * w[j];
         }
+        sumf += b;
         y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
     }
 }

 template <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t>
 static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+                                               const float * __restrict__ bias,
                                                const int src0_nb0, const int src0_nb1, const int src0_nb2,
                                                const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
                                                const int dst_nb1, const int dst_nb2, const int64_t n_t) {
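ggml_cuda_op_silu_single is the backend's elementwise SiLU helper; adding the bias before it is what lets the epilogue compute silu(conv + b) in registers instead of bouncing through two extra global-memory passes. For reference, a stand-in with the behavior the epilogue relies on (the real helper lives elsewhere in the CUDA backend):

    // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
    static __device__ __forceinline__ float silu_single_ref(float x) {
        return x / (1.0f + expf(-x));
    }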

@@ -97,6 +102,8 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
         w[j] = w_block[tid * stride_w + j];
     }

+    float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f;
+
     // Compute from shared memory
     for (int64_t i = 0; i < local_n_t; i++) {
         float sumf = 0.0f;

@@ -104,12 +111,13 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
         for (size_t j = 0; j < d_conv; j++) {
             sumf += smem[tid * n_cols + i + j] * w[j];
         }
+        sumf += b;
         y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
     }
 }

 template <bool apply_silu>
-static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
+static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1,
                               const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
                               const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
                               const int64_t n_s, cudaStream_t stream) {
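Both kernels now compute the same fused expression per channel and timestep. A hedged CPU reference for validating the fused path, assuming ggml's usual layouts (time contiguous in conv_x, d_inner contiguous in the output) and ignoring strides and batching; none of these names come from the commit:

    // y[t][c] = silu(sum_j x[c][t + j] * w[c][j] + b[c]), one sequence, F32.
    #include <cmath>
    #include <vector>

    static std::vector<float> ssm_conv_bias_silu_ref(const std::vector<float> & x, // [d_inner][d_conv - 1 + n_t]
                                                     const std::vector<float> & w, // [d_inner][d_conv]
                                                     const std::vector<float> & b, // [d_inner]
                                                     int d_conv, int d_inner, int n_t) {
        std::vector<float> y((size_t) n_t * d_inner);
        const int n_cols = d_conv - 1 + n_t; // length of the padded time axis
        for (int t = 0; t < n_t; ++t) {
            for (int c = 0; c < d_inner; ++c) {
                float sum = b[c];
                for (int j = 0; j < d_conv; ++j) {
                    sum += x[(size_t) c * n_cols + t + j] * w[(size_t) c * d_conv + j];
                }
                y[(size_t) t * d_inner + c] = sum / (1.0f + std::exp(-sum)); // fused SiLU epilogue
            }
        }
        return y;
    }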

@@ -120,14 +128,14 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
         constexpr int kNC = decltype(NC)::value;
         if (n_t <= 32) {
             const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
-            ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+            ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
                                                                                    dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
             const int64_t split_n_t = 32;
             dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
             const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float);
             ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>(
-                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+                src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         }
     };

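The extra bias argument does not change the dynamic shared-memory footprint, which stays at threads * (kNC - 1 + split_n_t) floats per block. A quick size check with assumed values (threads = 128, kNC = 4; neither constant is visible in this diff):

    // 128 threads * (4 - 1 + 32) columns * 4 bytes = 17920 bytes per block,
    // comfortably below the 48 KiB default dynamic smem limit.
    constexpr int    threads_assumed = 128;
    constexpr int    kNC_assumed     = 4;
    constexpr int    split_n_t       = 32;
    constexpr size_t smem_size       = threads_assumed * (kNC_assumed - 1 + split_n_t) * sizeof(float);
    static_assert(smem_size == 17920, "matches the hand computation above");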

@@ -140,11 +148,18 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
     }
 }

-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) {
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node, ggml_tensor * silu_dst) {
     const struct ggml_tensor * src0 = dst->src[0]; // conv_x
     const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
+    const bool fuse_bias = bias_add_node != nullptr;
     const bool fuse_silu = silu_dst != nullptr;

+    // bias always comes with silu.
+    GGML_ASSERT(!fuse_bias || fuse_silu);
+
+    // The bias (when fused) is the non-conv operand of the ADD node.
+    const struct ggml_tensor * bias = fuse_bias ? (bias_add_node->src[0] == dst ? bias_add_node->src[1] : bias_add_node->src[0]) : nullptr;
+
     // When fusing, write to silu_dst (the node downstream references).
     const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;


@@ -160,16 +175,23 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g

     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
+    const float * bias_d = fuse_bias ? (const float *) bias->data : nullptr;
     float * dst_d = (float *) out->data;
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(out->type == GGML_TYPE_F32);
+    if (fuse_bias) {
+        GGML_ASSERT(bias->type == GGML_TYPE_F32);
+        GGML_ASSERT(ggml_is_contiguous(bias));
+        GGML_ASSERT(ggml_nelements(bias) == nr);
+    }
+
     if (fuse_silu) {
-        ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+        ssm_conv_f32_cuda<true>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
                                 out->nb[2], nc, nr, n_t, n_s, stream);
     } else {
-        ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+        ssm_conv_f32_cuda<false>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
                                  out->nb[2], nc, nr, n_t, n_s, stream);
     }
 }

ggml/src/ggml-cuda/ssm-conv.cuh
@@ -1,3 +1,3 @@
 #include "common.cuh"

-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr);
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node = nullptr, ggml_tensor * silu_dst = nullptr);
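Because both new parameters default to nullptr, existing unfused call sites compile unchanged; only the fused paths spell the arguments out. Illustrative call shapes (node, add_node, and silu_node are placeholder variables, not from the commit):

    // Unfused: plain SSM_CONV.
    ggml_cuda_op_ssm_conv(*cuda_ctx, node);
    // SiLU-only fusion: the bias slot must be skipped explicitly.
    ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, silu_node);
    // Full fusion: conv + bias add + SiLU.
    ggml_cuda_op_ssm_conv(*cuda_ctx, node, add_node, silu_node);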