cuda: fuse snake activation (mul, sin, sqr, mul, add) (llama/22667)

* cuda: fuse snake activation (mul, sin, sqr, mul, add)

Add ggml_cuda_op_snake_fused with F32 / F16 / BF16 templates. The
matcher recognizes the naive 5 op decomposition emitted by audio
decoders (BigVGAN, Vocos) for snake activation
y = x + sin(a*x)^2 * inv_b and rewrites it to a single elementwise
kernel.

Add test_snake_fuse comparing CPU naive vs CUDA fused across
F32 / F16 / BF16.

* cuda: address review feedback from @am17an

Use ggml_cuda_cast for F32/F16/BF16 conversions and rename
kernel_snake to snake_kernel to match upstream conventions.

* cuda: snake fusion fastdiv on T_len, Suggested-by: @am17an

* Update tests/test-backend-ops.cpp

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cuda: snake fusion check add->type matches x->type

Address review feedback from @am17an

* cuda: snake fusion check add->type matches x->type

Moved for readability (equivalent)
Address review feedback from @am17an

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
This commit is contained in:
Pascal 2026-05-08 11:44:09 +02:00 committed by Georgi Gerganov
parent acb484d776
commit a0c421f7ab
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 110 additions and 0 deletions

View File

@ -39,6 +39,7 @@
#include "ggml-cuda/rope.cuh"
#include "ggml-cuda/roll.cuh"
#include "ggml-cuda/scale.cuh"
#include "ggml-cuda/snake.cuh"
#include "ggml-cuda/softcap.cuh"
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
@ -3757,6 +3758,35 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
return 2;
}
// Snake activation: y = x + sin(a*x)^2 * inv_b
// Naive 5-op decomposition emitted by frontends: mul -> sin -> sqr -> mul -> add
if (ggml_can_fuse_subgraph(cgraph, i,
{ GGML_OP_MUL, GGML_OP_SIN, GGML_OP_SQR, GGML_OP_MUL, GGML_OP_ADD },
{ i + 4 })) {
const ggml_tensor * mul0 = cgraph->nodes[i];
const ggml_tensor * sqr = cgraph->nodes[i + 2];
const ggml_tensor * mul1 = cgraph->nodes[i + 3];
ggml_tensor * add = cgraph->nodes[i + 4];
// x carries the full activation shape, a is the broadcast operand
const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];
// mul1 reads sqr and inv_b in either operand order
const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];
// closure check: the trailing add must read the same x as the leading mul
const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];
const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16);
const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1];
if (type_ok && shape_ok && x_in_add == x && add->type == x->type) {
ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
return 4;
}
}
// multi-(add or mul)
if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
int n_fuse = 0;

View File

@ -0,0 +1,72 @@
#include "snake.cuh"
#include "convert.cuh"
// Fused Snake activation: y = x + sin^2(a * x) * inv_b
// x: [T, C] (T contiguous), a: [1, C], inv_b: [1, C]
// Supports F32, F16, BF16 data with F32 compute.
template <typename T>
static __global__ void snake_kernel(
const T * __restrict__ x,
const float * __restrict__ a,
const float * __restrict__ inv_b,
T * __restrict__ dst,
const int total,
const uint3 T_len_fastdiv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total) return;
const int c = (int) fastdiv((uint32_t) idx, T_len_fastdiv);
const float xi = ggml_cuda_cast<float>(x[idx]);
const float s = sinf(a[c] * xi);
dst[idx] = ggml_cuda_cast<T>(xi + s * s * inv_b[c]);
}
// Internal launcher with explicit x/a/inv_b/dst tensors.
// Shared by the public op (reads dst->src) and the fusion path (explicit args).
static void launch_snake(ggml_backend_cuda_context & ctx,
const ggml_tensor * x,
const ggml_tensor * a,
const ggml_tensor * inv_b,
ggml_tensor * dst) {
const float * a_d = (const float *)a->data;
const float * inv_b_d = (const float *)inv_b->data;
const int T = (int)x->ne[0];
const int C = (int)x->ne[1];
const int total = T * C;
const uint3 T_len_fastdiv = init_fastdiv_values((uint64_t) T);
const int block_size = 256;
const int grid_size = (total + block_size - 1) / block_size;
cudaStream_t stream = ctx.stream();
switch (x->type) {
case GGML_TYPE_F32: {
snake_kernel<<<grid_size, block_size, 0, stream>>>(
(const float *)x->data, a_d, inv_b_d, (float *)dst->data, total, T_len_fastdiv);
} break;
case GGML_TYPE_F16: {
snake_kernel<<<grid_size, block_size, 0, stream>>>(
(const half *)x->data, a_d, inv_b_d, (half *)dst->data, total, T_len_fastdiv);
} break;
case GGML_TYPE_BF16: {
snake_kernel<<<grid_size, block_size, 0, stream>>>(
(const nv_bfloat16 *)x->data, a_d, inv_b_d, (nv_bfloat16 *)dst->data, total, T_len_fastdiv);
} break;
default:
GGML_ABORT("snake: unsupported type");
}
}
// Fusion entry: caller supplies x/a/inv_b explicitly from the matched
// mul -> sin -> sqr -> mul -> add pattern. The dst is the trailing add output.
void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx,
const ggml_tensor * x,
const ggml_tensor * a,
const ggml_tensor * inv_b,
ggml_tensor * dst) {
launch_snake(ctx, x, a, inv_b, dst);
}

View File

@ -0,0 +1,8 @@
#include "common.cuh"
// Fusion entry point. Caller supplies x/a/inv_b explicitly.
void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx,
const ggml_tensor * x,
const ggml_tensor * a,
const ggml_tensor * inv_b,
ggml_tensor * dst);