cuda: fuse snake activation (mul, sin, sqr, mul, add) (#22667)

* cuda: fuse snake activation (mul, sin, sqr, mul, add)

Add ggml_cuda_op_snake_fused with F32 / F16 / BF16 templates. The
matcher recognizes the naive 5 op decomposition emitted by audio
decoders (BigVGAN, Vocos) for snake activation
y = x + sin(a*x)^2 * inv_b and rewrites it to a single elementwise
kernel.

Add test_snake_fuse comparing CPU naive vs CUDA fused across
F32 / F16 / BF16.

* cuda: address review feedback from @am17an

Use ggml_cuda_cast for F32/F16/BF16 conversions and rename
kernel_snake to snake_kernel to match upstream conventions.

* cuda: snake fusion fastdiv on T_len, Suggested-by: @am17an

* Update tests/test-backend-ops.cpp

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cuda: snake fusion check add->type matches x->type

Address review feedback from @am17an

* cuda: snake fusion check add->type matches x->type

Moved for readability (equivalent)
Address review feedback from @am17an

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
This commit is contained in:
Pascal
2026-05-08 11:44:09 +02:00
committed by GitHub
parent 9b2925e1e0
commit 58e68df0f9
4 changed files with 191 additions and 0 deletions

View File

@@ -39,6 +39,7 @@
#include "ggml-cuda/rope.cuh"
#include "ggml-cuda/roll.cuh"
#include "ggml-cuda/scale.cuh"
#include "ggml-cuda/snake.cuh"
#include "ggml-cuda/softcap.cuh"
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
@@ -3757,6 +3758,35 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
return 2;
}
// Snake activation: y = x + sin(a*x)^2 * inv_b
// Naive 5-op decomposition emitted by frontends: mul -> sin -> sqr -> mul -> add
if (ggml_can_fuse_subgraph(cgraph, i,
{ GGML_OP_MUL, GGML_OP_SIN, GGML_OP_SQR, GGML_OP_MUL, GGML_OP_ADD },
{ i + 4 })) {
const ggml_tensor * mul0 = cgraph->nodes[i];
const ggml_tensor * sqr = cgraph->nodes[i + 2];
const ggml_tensor * mul1 = cgraph->nodes[i + 3];
ggml_tensor * add = cgraph->nodes[i + 4];
// x carries the full activation shape, a is the broadcast operand
const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];
// mul1 reads sqr and inv_b in either operand order
const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];
// closure check: the trailing add must read the same x as the leading mul
const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];
const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16);
const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1];
if (type_ok && shape_ok && x_in_add == x && add->type == x->type) {
ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
return 4;
}
}
// multi-(add or mul)
if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
int n_fuse = 0;

View File

@@ -0,0 +1,72 @@
#include "snake.cuh"
#include "convert.cuh"
// Fused Snake activation: y = x + sin^2(a * x) * inv_b
// x: [T, C] (T contiguous), a: [1, C], inv_b: [1, C]
// Supports F32, F16, BF16 data with F32 compute.
template <typename T>
static __global__ void snake_kernel(
const T * __restrict__ x,
const float * __restrict__ a,
const float * __restrict__ inv_b,
T * __restrict__ dst,
const int total,
const uint3 T_len_fastdiv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total) return;
const int c = (int) fastdiv((uint32_t) idx, T_len_fastdiv);
const float xi = ggml_cuda_cast<float>(x[idx]);
const float s = sinf(a[c] * xi);
dst[idx] = ggml_cuda_cast<T>(xi + s * s * inv_b[c]);
}
// Internal launcher with explicit x/a/inv_b/dst tensors.
// Shared by the public op (reads dst->src) and the fusion path (explicit args).
static void launch_snake(ggml_backend_cuda_context & ctx,
const ggml_tensor * x,
const ggml_tensor * a,
const ggml_tensor * inv_b,
ggml_tensor * dst) {
const float * a_d = (const float *)a->data;
const float * inv_b_d = (const float *)inv_b->data;
const int T = (int)x->ne[0];
const int C = (int)x->ne[1];
const int total = T * C;
const uint3 T_len_fastdiv = init_fastdiv_values((uint64_t) T);
const int block_size = 256;
const int grid_size = (total + block_size - 1) / block_size;
cudaStream_t stream = ctx.stream();
switch (x->type) {
case GGML_TYPE_F32: {
snake_kernel<<<grid_size, block_size, 0, stream>>>(
(const float *)x->data, a_d, inv_b_d, (float *)dst->data, total, T_len_fastdiv);
} break;
case GGML_TYPE_F16: {
snake_kernel<<<grid_size, block_size, 0, stream>>>(
(const half *)x->data, a_d, inv_b_d, (half *)dst->data, total, T_len_fastdiv);
} break;
case GGML_TYPE_BF16: {
snake_kernel<<<grid_size, block_size, 0, stream>>>(
(const nv_bfloat16 *)x->data, a_d, inv_b_d, (nv_bfloat16 *)dst->data, total, T_len_fastdiv);
} break;
default:
GGML_ABORT("snake: unsupported type");
}
}
// Fusion entry: caller supplies x/a/inv_b explicitly from the matched
// mul -> sin -> sqr -> mul -> add pattern. The dst is the trailing add output.
void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx,
const ggml_tensor * x,
const ggml_tensor * a,
const ggml_tensor * inv_b,
ggml_tensor * dst) {
launch_snake(ctx, x, a, inv_b, dst);
}

View File

@@ -0,0 +1,8 @@
#include "common.cuh"
// Fusion entry point. Caller supplies x/a/inv_b explicitly.
void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx,
const ggml_tensor * x,
const ggml_tensor * a,
const ggml_tensor * inv_b,
ggml_tensor * dst);

View File

@@ -3556,6 +3556,73 @@ struct test_relu_sqr : public test_case {
}
};
// SNAKE activation fusion: y = x + sin(a*x)^2 * inv_b
// CUDA backend matches the naive 5-op chain (mul, sin, sqr, mul, add)
// and dispatches a single fused kernel.
struct test_snake_fuse : public test_case {
const ggml_type type;
const std::array<int64_t, 2> ne; // [T, C]
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "SNAKE_FUSE";
}
bool run_whole_graph() override { return true; }
double max_nmse_err() override {
// BF16 epsilon ~ 7.8e-3, F16 epsilon ~ 9.7e-4: relax tolerance to match
// the natural roundoff drift between the naive CPU chain and the fused
// CUDA kernel. F32 keeps the default tight bound.
switch (type) {
case GGML_TYPE_BF16: return 5e-3;
case GGML_TYPE_F16: return 5e-5;
default: return 1e-7;
}
}
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_snake_fuse(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 2> ne = {256, 192})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
ggml_set_name(x, "x");
ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]);
ggml_set_name(a, "a");
ggml_tensor * inv_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]);
ggml_set_name(inv_b, "inv_b");
// exact 5-op chain that BigVGAN / Vocos frontends emit
ggml_tensor * ax = ggml_mul(ctx, x, a);
ggml_tensor * sin_ax = ggml_sin(ctx, ax);
ggml_tensor * sin_sq = ggml_sqr(ctx, sin_ax);
ggml_tensor * scaled = ggml_mul(ctx, sin_sq, inv_b);
ggml_tensor * out = ggml_add(ctx, x, scaled);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
// x in [-pi, pi] to exercise sin periodicity, params in default [-1, 1]
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
const std::string name = ggml_get_name(t);
if (name == "x") {
init_tensor_uniform(t, -3.14159f, 3.14159f);
} else {
init_tensor_uniform(t);
}
}
}
};
// GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case {
const ggml_type type;
@@ -7489,6 +7556,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_relu_sqr(type, { 5, 7, 11, 13 }));
}
// SNAKE activation fusion: x + sin(a*x)^2 * inv_b
for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
test_cases.emplace_back(new test_snake_fuse(type, { 5, 7})); // primes sub-block
test_cases.emplace_back(new test_snake_fuse(type, { 33, 32})); // boundary
test_cases.emplace_back(new test_snake_fuse(type, {1025, 13})); // large prime, grid-stride
test_cases.emplace_back(new test_snake_fuse(type, { 128, 16})); // power-of-two
test_cases.emplace_back(new test_snake_fuse(type, { 256, 192})); // BigVGAN-ish
}
// glu ops
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (int v : {0, 1}) {
@@ -9014,6 +9090,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
// SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192)
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));