From 58e68df0f91dd16ff56423ee5ef44062ed73bdfc Mon Sep 17 00:00:00 2001 From: Pascal Date: Fri, 8 May 2026 11:44:09 +0200 Subject: [PATCH] cuda: fuse snake activation (mul, sin, sqr, mul, add) (#22667) * cuda: fuse snake activation (mul, sin, sqr, mul, add) Add ggml_cuda_op_snake_fused with F32 / F16 / BF16 templates. The matcher recognizes the naive 5 op decomposition emitted by audio decoders (BigVGAN, Vocos) for snake activation y = x + sin(a*x)^2 * inv_b and rewrites it to a single elementwise kernel. Add test_snake_fuse comparing CPU naive vs CUDA fused across F32 / F16 / BF16. * cuda: address review feedback from @am17an Use ggml_cuda_cast for F32/F16/BF16 conversions and rename kernel_snake to snake_kernel to match upstream conventions. * cuda: snake fusion fastdiv on T_len, Suggested-by: @am17an * Update tests/test-backend-ops.cpp Co-authored-by: Aman Gupta * cuda: snake fusion check add->type matches x->type Address review feedback from @am17an * cuda: snake fusion check add->type matches x->type Moved for readability (equivalent) Address review feedback from @am17an --------- Co-authored-by: Aman Gupta --- ggml/src/ggml-cuda/ggml-cuda.cu | 30 ++++++++++++ ggml/src/ggml-cuda/snake.cu | 72 +++++++++++++++++++++++++++++ ggml/src/ggml-cuda/snake.cuh | 8 ++++ tests/test-backend-ops.cpp | 81 +++++++++++++++++++++++++++++++++ 4 files changed, 191 insertions(+) create mode 100644 ggml/src/ggml-cuda/snake.cu create mode 100644 ggml/src/ggml-cuda/snake.cuh diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 925a9ffe04..4df1b93088 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -39,6 +39,7 @@ #include "ggml-cuda/rope.cuh" #include "ggml-cuda/roll.cuh" #include "ggml-cuda/scale.cuh" +#include "ggml-cuda/snake.cuh" #include "ggml-cuda/softcap.cuh" #include "ggml-cuda/softmax.cuh" #include "ggml-cuda/ssm-conv.cuh" @@ -3757,6 +3758,35 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph return 2; } + // Snake activation: y = x + sin(a*x)^2 * inv_b + // Naive 5-op decomposition emitted by frontends: mul -> sin -> sqr -> mul -> add + if (ggml_can_fuse_subgraph(cgraph, i, + { GGML_OP_MUL, GGML_OP_SIN, GGML_OP_SQR, GGML_OP_MUL, GGML_OP_ADD }, + { i + 4 })) { + const ggml_tensor * mul0 = cgraph->nodes[i]; + const ggml_tensor * sqr = cgraph->nodes[i + 2]; + const ggml_tensor * mul1 = cgraph->nodes[i + 3]; + ggml_tensor * add = cgraph->nodes[i + 4]; + + // x carries the full activation shape, a is the broadcast operand + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + // mul1 reads sqr and inv_b in either operand order + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + + // closure check: the trailing add must read the same x as the leading mul + const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; + + const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16); + const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1]; + + if (type_ok && shape_ok && x_in_add == x && add->type == x->type) { + ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add); + return 4; + } + } + // multi-(add or mul) if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) { int n_fuse = 0; diff --git a/ggml/src/ggml-cuda/snake.cu b/ggml/src/ggml-cuda/snake.cu new file mode 100644 index 0000000000..384638c1f4 --- /dev/null +++ b/ggml/src/ggml-cuda/snake.cu @@ -0,0 +1,72 @@ +#include "snake.cuh" +#include "convert.cuh" + +// Fused Snake activation: y = x + sin^2(a * x) * inv_b +// x: [T, C] (T contiguous), a: [1, C], inv_b: [1, C] +// Supports F32, F16, BF16 data with F32 compute. + +template +static __global__ void snake_kernel( + const T * __restrict__ x, + const float * __restrict__ a, + const float * __restrict__ inv_b, + T * __restrict__ dst, + const int total, + const uint3 T_len_fastdiv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + + const int c = (int) fastdiv((uint32_t) idx, T_len_fastdiv); + + const float xi = ggml_cuda_cast(x[idx]); + const float s = sinf(a[c] * xi); + dst[idx] = ggml_cuda_cast(xi + s * s * inv_b[c]); +} + +// Internal launcher with explicit x/a/inv_b/dst tensors. +// Shared by the public op (reads dst->src) and the fusion path (explicit args). +static void launch_snake(ggml_backend_cuda_context & ctx, + const ggml_tensor * x, + const ggml_tensor * a, + const ggml_tensor * inv_b, + ggml_tensor * dst) { + const float * a_d = (const float *)a->data; + const float * inv_b_d = (const float *)inv_b->data; + + const int T = (int)x->ne[0]; + const int C = (int)x->ne[1]; + const int total = T * C; + const uint3 T_len_fastdiv = init_fastdiv_values((uint64_t) T); + + const int block_size = 256; + const int grid_size = (total + block_size - 1) / block_size; + + cudaStream_t stream = ctx.stream(); + + switch (x->type) { + case GGML_TYPE_F32: { + snake_kernel<<>>( + (const float *)x->data, a_d, inv_b_d, (float *)dst->data, total, T_len_fastdiv); + } break; + case GGML_TYPE_F16: { + snake_kernel<<>>( + (const half *)x->data, a_d, inv_b_d, (half *)dst->data, total, T_len_fastdiv); + } break; + case GGML_TYPE_BF16: { + snake_kernel<<>>( + (const nv_bfloat16 *)x->data, a_d, inv_b_d, (nv_bfloat16 *)dst->data, total, T_len_fastdiv); + } break; + default: + GGML_ABORT("snake: unsupported type"); + } +} + +// Fusion entry: caller supplies x/a/inv_b explicitly from the matched +// mul -> sin -> sqr -> mul -> add pattern. The dst is the trailing add output. +void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx, + const ggml_tensor * x, + const ggml_tensor * a, + const ggml_tensor * inv_b, + ggml_tensor * dst) { + launch_snake(ctx, x, a, inv_b, dst); +} diff --git a/ggml/src/ggml-cuda/snake.cuh b/ggml/src/ggml-cuda/snake.cuh new file mode 100644 index 0000000000..7f6f1cb3b4 --- /dev/null +++ b/ggml/src/ggml-cuda/snake.cuh @@ -0,0 +1,8 @@ +#include "common.cuh" + +// Fusion entry point. Caller supplies x/a/inv_b explicitly. +void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx, + const ggml_tensor * x, + const ggml_tensor * a, + const ggml_tensor * inv_b, + ggml_tensor * dst); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 8db0b6b3d5..a55b5b4c23 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3556,6 +3556,73 @@ struct test_relu_sqr : public test_case { } }; +// SNAKE activation fusion: y = x + sin(a*x)^2 * inv_b +// CUDA backend matches the naive 5-op chain (mul, sin, sqr, mul, add) +// and dispatches a single fused kernel. +struct test_snake_fuse : public test_case { + const ggml_type type; + const std::array ne; // [T, C] + + std::string op_desc(ggml_tensor * t) override { + GGML_UNUSED(t); + return "SNAKE_FUSE"; + } + + bool run_whole_graph() override { return true; } + + double max_nmse_err() override { + // BF16 epsilon ~ 7.8e-3, F16 epsilon ~ 9.7e-4: relax tolerance to match + // the natural roundoff drift between the naive CPU chain and the fused + // CUDA kernel. F32 keeps the default tight bound. + switch (type) { + case GGML_TYPE_BF16: return 5e-3; + case GGML_TYPE_F16: return 5e-5; + default: return 1e-7; + } + } + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_snake_fuse(ggml_type type = GGML_TYPE_F32, + std::array ne = {256, 192}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * x = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); + ggml_set_name(x, "x"); + + ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]); + ggml_set_name(a, "a"); + + ggml_tensor * inv_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]); + ggml_set_name(inv_b, "inv_b"); + + // exact 5-op chain that BigVGAN / Vocos frontends emit + ggml_tensor * ax = ggml_mul(ctx, x, a); + ggml_tensor * sin_ax = ggml_sin(ctx, ax); + ggml_tensor * sin_sq = ggml_sqr(ctx, sin_ax); + ggml_tensor * scaled = ggml_mul(ctx, sin_sq, inv_b); + ggml_tensor * out = ggml_add(ctx, x, scaled); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + // x in [-pi, pi] to exercise sin periodicity, params in default [-1, 1] + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + const std::string name = ggml_get_name(t); + if (name == "x") { + init_tensor_uniform(t, -3.14159f, 3.14159f); + } else { + init_tensor_uniform(t); + } + } + } +}; + // GGML_OP_SSM_CONV struct test_ssm_conv : public test_case { const ggml_type type; @@ -7489,6 +7556,15 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_relu_sqr(type, { 5, 7, 11, 13 })); } + // SNAKE activation fusion: x + sin(a*x)^2 * inv_b + for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) { + test_cases.emplace_back(new test_snake_fuse(type, { 5, 7})); // primes sub-block + test_cases.emplace_back(new test_snake_fuse(type, { 33, 32})); // boundary + test_cases.emplace_back(new test_snake_fuse(type, {1025, 13})); // large prime, grid-stride + test_cases.emplace_back(new test_snake_fuse(type, { 128, 16})); // power-of-two + test_cases.emplace_back(new test_snake_fuse(type, { 256, 192})); // BigVGAN-ish + } + // glu ops for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { for (int v : {0, 1}) { @@ -9014,6 +9090,11 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1})); test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1})); + // SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192) + test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192})); + test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192})); + test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));