Ggml/cuda snake fusion hardening (#22912)

* cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan)

* cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review)

* cuda: merge type_ok and types_ok into a single types_ok (address am17an review)

* cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16

bin_bcast only dispatches F32/F16 type triplets, mirror the
vulkan filter so unsupported types fall back through cpy
instead of aborting.

* test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases
This commit is contained in:
Pascal
2026-05-11 18:42:08 +02:00
committed by GitHub
parent ef22b3e4ac
commit e936660760
2 changed files with 39 additions and 17 deletions

View File

@@ -3929,10 +3929,25 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
// closure check: the trailing add must read the same x as the leading mul
const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];
const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16);
// Kernel iterates over total = T * C, so x and add must be 2D and
// a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled.
const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) &&
(add->ne[2] == 1 && add->ne[3] == 1) &&
(a->ne[2] == 1 && a->ne[3] == 1);
const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1];
if (type_ok && shape_ok && x_in_add == x && add->type == x->type) {
// x must be in the supported whitelist and every operand / intermediate
// result must share x's type, since launch_snake casts a / inv_b as
// float and templates the kernel on a single T. Mixed precision chains
// fall back to the naive path.
const ggml_tensor * sin1 = cgraph->nodes[i + 1];
const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) &&
(a->type == x->type) && (inv_b->type == x->type) &&
(mul0->type == x->type) && (sin1->type == x->type) &&
(sqr->type == x->type) && (mul1->type == x->type) &&
(add->type == x->type);
if (types_ok && shape_ok && dim_ok && x_in_add == x) {
ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
return 4;
}
@@ -5291,12 +5306,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
@@ -5305,6 +5316,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CLAMP:
case GGML_OP_LOG:
return true;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_SSM_SCAN: {
if (op->src[3]->ne[0] == 1) {
// Mamba2

View File

@@ -3561,7 +3561,7 @@ struct test_relu_sqr : public test_case {
// and dispatches a single fused kernel.
struct test_snake_fuse : public test_case {
const ggml_type type;
const std::array<int64_t, 2> ne; // [T, C]
const std::array<int64_t, 4> ne; // [T, C, D2, D3]
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
@@ -3586,11 +3586,11 @@ struct test_snake_fuse : public test_case {
}
test_snake_fuse(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 2> ne = {256, 192})
std::array<int64_t, 4> ne = {256, 192, 1, 1})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
ggml_tensor * x = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
ggml_set_name(x, "x");
ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]);
@@ -7558,11 +7558,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// SNAKE activation fusion: x + sin(a*x)^2 * inv_b
for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
test_cases.emplace_back(new test_snake_fuse(type, { 5, 7})); // primes sub-block
test_cases.emplace_back(new test_snake_fuse(type, { 33, 32})); // boundary
test_cases.emplace_back(new test_snake_fuse(type, {1025, 13})); // large prime, grid-stride
test_cases.emplace_back(new test_snake_fuse(type, { 128, 16})); // power-of-two
test_cases.emplace_back(new test_snake_fuse(type, { 256, 192})); // BigVGAN-ish
test_cases.emplace_back(new test_snake_fuse(type, { 5, 7, 1, 1})); // primes sub-block
test_cases.emplace_back(new test_snake_fuse(type, { 33, 32, 1, 1})); // boundary
test_cases.emplace_back(new test_snake_fuse(type, {1025, 13, 1, 1})); // large prime, grid-stride
test_cases.emplace_back(new test_snake_fuse(type, { 128, 16, 1, 1})); // power-of-two
test_cases.emplace_back(new test_snake_fuse(type, { 256, 192, 1, 1})); // BigVGAN-ish
// higher-rank shapes: matcher must reject fusion, fallback to naive chain
test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 1})); // ne[2] > 1
test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 1, 2})); // ne[3] > 1
test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 3})); // ne[2] > 1 and ne[3] > 1
}
// glu ops
@@ -9093,9 +9097,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
// SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192)
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192, 1, 1}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192, 1, 1}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192, 1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));