mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-12 03:54:06 +00:00
cuda: fuse snake activation (mul, sin, sqr, mul, add) (#22667)
* cuda: fuse snake activation (mul, sin, sqr, mul, add) Add ggml_cuda_op_snake_fused with F32 / F16 / BF16 templates. The matcher recognizes the naive 5 op decomposition emitted by audio decoders (BigVGAN, Vocos) for snake activation y = x + sin(a*x)^2 * inv_b and rewrites it to a single elementwise kernel. Add test_snake_fuse comparing CPU naive vs CUDA fused across F32 / F16 / BF16. * cuda: address review feedback from @am17an Use ggml_cuda_cast for F32/F16/BF16 conversions and rename kernel_snake to snake_kernel to match upstream conventions. * cuda: snake fusion fastdiv on T_len, Suggested-by: @am17an * Update tests/test-backend-ops.cpp Co-authored-by: Aman Gupta <amangupta052@gmail.com> * cuda: snake fusion check add->type matches x->type Address review feedback from @am17an * cuda: snake fusion check add->type matches x->type Moved for readability (equivalent) Address review feedback from @am17an --------- Co-authored-by: Aman Gupta <amangupta052@gmail.com>
This commit is contained in:
@@ -39,6 +39,7 @@
|
||||
#include "ggml-cuda/rope.cuh"
|
||||
#include "ggml-cuda/roll.cuh"
|
||||
#include "ggml-cuda/scale.cuh"
|
||||
#include "ggml-cuda/snake.cuh"
|
||||
#include "ggml-cuda/softcap.cuh"
|
||||
#include "ggml-cuda/softmax.cuh"
|
||||
#include "ggml-cuda/ssm-conv.cuh"
|
||||
@@ -3757,6 +3758,35 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Snake activation: y = x + sin(a*x)^2 * inv_b
// Naive 5-op decomposition emitted by frontends: mul -> sin -> sqr -> mul -> add
if (ggml_can_fuse_subgraph(cgraph, i,
        { GGML_OP_MUL, GGML_OP_SIN, GGML_OP_SQR, GGML_OP_MUL, GGML_OP_ADD },
        { i + 4 })) {
    const ggml_tensor * mul0 = cgraph->nodes[i];
    const ggml_tensor * sqr  = cgraph->nodes[i + 2];
    const ggml_tensor * mul1 = cgraph->nodes[i + 3];
    ggml_tensor *       add  = cgraph->nodes[i + 4];

    // x carries the full activation shape, a is the broadcast operand
    const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
    const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];

    // mul1 reads sqr and inv_b in either operand order
    const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];

    // closure check: the trailing add must read the same x as the leading mul
    const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];

    // the fused kernel casts a->data / inv_b->data to const float * unconditionally,
    // so both per-channel parameters must actually be F32
    const bool type_ok  = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) &&
                          a->type == GGML_TYPE_F32 && inv_b->type == GGML_TYPE_F32;
    // the fused kernel indexes x as a flat contiguous [T, C] buffer: require
    // ne[2] == ne[3] == 1 and contiguity, plus per-channel [1, C] parameters
    const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1] &&
                          x->ne[2] == 1 && x->ne[3] == 1 && ggml_is_contiguous(x);

    if (type_ok && shape_ok && x_in_add == x && add->type == x->type) {
        ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
        return 4;
    }
}
|
||||
|
||||
// multi-(add or mul)
|
||||
if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
|
||||
int n_fuse = 0;
|
||||
|
||||
72
ggml/src/ggml-cuda/snake.cu
Normal file
72
ggml/src/ggml-cuda/snake.cu
Normal file
@@ -0,0 +1,72 @@
|
||||
#include "snake.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
// Fused Snake activation: y = x + sin^2(a * x) * inv_b
// x: [T, C] with T contiguous; a / inv_b hold one F32 value per channel.
// One thread per element of the flat [T * C] buffer; the channel index is
// recovered from the flat index via fastdiv by T. Compute is F32 regardless
// of the storage type T (F32 / F16 / BF16).
template <typename T>
static __global__ void snake_kernel(
        const T     * __restrict__ x,
        const float * __restrict__ a,
        const float * __restrict__ inv_b,
        T           * __restrict__ dst,
        const int                  total,
        const uint3                T_len_fastdiv) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < total) {
        // channel of this element: i / T, via precomputed fastdiv constants
        const int ch = (int) fastdiv((uint32_t) i, T_len_fastdiv);

        const float xf     = ggml_cuda_cast<float>(x[i]);
        const float sin_ax = sinf(a[ch] * xf);

        dst[i] = ggml_cuda_cast<T>(xf + sin_ax * sin_ax * inv_b[ch]);
    }
}
|
||||
|
||||
// Internal launcher: dispatches snake_kernel on the context's stream for the
// matched x/a/inv_b operands, writing into dst. Called from the fusion entry
// point with explicit operands pulled out of the matched op chain.
// Preconditions (established by the fusion matcher, asserted here so a future
// caller cannot silently violate them): dst has the same type as x, and the
// per-channel parameters a / inv_b are F32.
static void launch_snake(ggml_backend_cuda_context & ctx,
                         const ggml_tensor * x,
                         const ggml_tensor * a,
                         const ggml_tensor * inv_b,
                         ggml_tensor * dst) {
    GGML_ASSERT(dst->type   == x->type);
    GGML_ASSERT(a->type     == GGML_TYPE_F32);
    GGML_ASSERT(inv_b->type == GGML_TYPE_F32);

    const float * a_d     = (const float *) a->data;
    const float * inv_b_d = (const float *) inv_b->data;

    const int T     = (int) x->ne[0];
    const int C     = (int) x->ne[1];
    const int total = T * C;
    // fastdiv constants let the kernel derive the channel index (i / T)
    // without a hardware integer division
    const uint3 T_len_fastdiv = init_fastdiv_values((uint64_t) T);

    const int block_size = 256;
    const int grid_size  = (total + block_size - 1) / block_size;

    cudaStream_t stream = ctx.stream();

    switch (x->type) {
        case GGML_TYPE_F32: {
            snake_kernel<<<grid_size, block_size, 0, stream>>>(
                (const float *) x->data, a_d, inv_b_d, (float *) dst->data, total, T_len_fastdiv);
        } break;
        case GGML_TYPE_F16: {
            snake_kernel<<<grid_size, block_size, 0, stream>>>(
                (const half *) x->data, a_d, inv_b_d, (half *) dst->data, total, T_len_fastdiv);
        } break;
        case GGML_TYPE_BF16: {
            snake_kernel<<<grid_size, block_size, 0, stream>>>(
                (const nv_bfloat16 *) x->data, a_d, inv_b_d, (nv_bfloat16 *) dst->data, total, T_len_fastdiv);
        } break;
        default:
            GGML_ABORT("snake: unsupported type");
    }
}
|
||||
|
||||
// Fusion entry point: the caller passes x/a/inv_b extracted from the matched
// mul -> sin -> sqr -> mul -> add chain; dst is the trailing add's output
// tensor. Simply forwards to the internal launcher.
void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx,
                              const ggml_tensor * x, const ggml_tensor * a,
                              const ggml_tensor * inv_b, ggml_tensor * dst) {
    launch_snake(ctx, x, a, inv_b, dst);
}
|
||||
8
ggml/src/ggml-cuda/snake.cuh
Normal file
8
ggml/src/ggml-cuda/snake.cuh
Normal file
@@ -0,0 +1,8 @@
|
||||
#pragma once

#include "common.cuh"

// Fused Snake activation: y = x + sin(a*x)^2 * inv_b.
// Fusion entry point: the caller supplies the matched x/a/inv_b operands
// explicitly; dst is the output tensor of the trailing add node.
void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx,
                              const ggml_tensor * x,
                              const ggml_tensor * a,
                              const ggml_tensor * inv_b,
                              ggml_tensor * dst);
||||
@@ -3556,6 +3556,73 @@ struct test_relu_sqr : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// SNAKE activation fusion: y = x + sin(a*x)^2 * inv_b
|
||||
// CUDA backend matches the naive 5-op chain (mul, sin, sqr, mul, add)
|
||||
// and dispatches a single fused kernel.
|
||||
struct test_snake_fuse : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 2> ne; // [T, C]
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "SNAKE_FUSE";
|
||||
}
|
||||
|
||||
bool run_whole_graph() override { return true; }
|
||||
|
||||
double max_nmse_err() override {
|
||||
// BF16 epsilon ~ 7.8e-3, F16 epsilon ~ 9.7e-4: relax tolerance to match
|
||||
// the natural roundoff drift between the naive CPU chain and the fused
|
||||
// CUDA kernel. F32 keeps the default tight bound.
|
||||
switch (type) {
|
||||
case GGML_TYPE_BF16: return 5e-3;
|
||||
case GGML_TYPE_F16: return 5e-5;
|
||||
default: return 1e-7;
|
||||
}
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
}
|
||||
|
||||
test_snake_fuse(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 2> ne = {256, 192})
|
||||
: type(type), ne(ne) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
|
||||
ggml_set_name(x, "x");
|
||||
|
||||
ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * inv_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]);
|
||||
ggml_set_name(inv_b, "inv_b");
|
||||
|
||||
// exact 5-op chain that BigVGAN / Vocos frontends emit
|
||||
ggml_tensor * ax = ggml_mul(ctx, x, a);
|
||||
ggml_tensor * sin_ax = ggml_sin(ctx, ax);
|
||||
ggml_tensor * sin_sq = ggml_sqr(ctx, sin_ax);
|
||||
ggml_tensor * scaled = ggml_mul(ctx, sin_sq, inv_b);
|
||||
ggml_tensor * out = ggml_add(ctx, x, scaled);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
// x in [-pi, pi] to exercise sin periodicity, params in default [-1, 1]
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
|
||||
const std::string name = ggml_get_name(t);
|
||||
if (name == "x") {
|
||||
init_tensor_uniform(t, -3.14159f, 3.14159f);
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SSM_CONV
|
||||
struct test_ssm_conv : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -7489,6 +7556,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_relu_sqr(type, { 5, 7, 11, 13 }));
|
||||
}
|
||||
|
||||
// SNAKE activation fusion: x + sin(a*x)^2 * inv_b
for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
    test_cases.emplace_back(new test_snake_fuse(type, { 5, 7})); // primes sub-block
    test_cases.emplace_back(new test_snake_fuse(type, { 33, 32})); // boundary
    test_cases.emplace_back(new test_snake_fuse(type, {1025, 13})); // large prime total; exercises the multi-block tail (kernel is one-thread-per-element, not grid-stride)
    test_cases.emplace_back(new test_snake_fuse(type, { 128, 16})); // power-of-two
    test_cases.emplace_back(new test_snake_fuse(type, { 256, 192})); // BigVGAN-ish
}
|
||||
|
||||
// glu ops
|
||||
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
||||
for (int v : {0, 1}) {
|
||||
@@ -9014,6 +9090,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));
|
||||
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
|
||||
|
||||
// SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192);
// one perf case per supported activation storage type
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user