diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b92a208705..e25be3592f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3929,10 +3929,25 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph // closure check: the trailing add must read the same x as the leading mul const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; - const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16); + // Kernel iterates over total = T * C, so x and add must be 2D and + // a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled. + const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) && + (add->ne[2] == 1 && add->ne[3] == 1) && + (a->ne[2] == 1 && a->ne[3] == 1); const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1]; - if (type_ok && shape_ok && x_in_add == x && add->type == x->type) { + // x must be in the supported whitelist and every operand / intermediate + // result must share x's type, since launch_snake casts a / inv_b as + // float and templates the kernel on a single T. Mixed precision chains + // fall back to the naive path. 
+ const ggml_tensor * sin1 = cgraph->nodes[i + 1]; + const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) && + (a->type == x->type) && (inv_b->type == x->type) && + (mul0->type == x->type) && (sin1->type == x->type) && + (sqr->type == x->type) && (mul1->type == x->type) && + (add->type == x->type); + + if (types_ok && shape_ok && dim_ok && x_in_add == x) { ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add); return 4; } @@ -5291,12 +5306,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - case GGML_OP_ADD: case GGML_OP_ADD_ID: case GGML_OP_ADD1: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: @@ -5305,6 +5316,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CLAMP: case GGML_OP_LOG: return true; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SSM_SCAN: { if (op->src[3]->ne[0] == 1) { // Mamba2 diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 922ad493a3..3331194866 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3561,7 +3561,7 @@ struct test_relu_sqr : public test_case { // and dispatches a single fused kernel. 
struct test_snake_fuse : public test_case { const ggml_type type; - const std::array<int64_t, 2> ne; // [T, C] + const std::array<int64_t, 4> ne; // [T, C, D2, D3] std::string op_desc(ggml_tensor * t) override { GGML_UNUSED(t); @@ -3586,11 +3586,11 @@ struct test_snake_fuse : public test_case { } test_snake_fuse(ggml_type type = GGML_TYPE_F32, - std::array<int64_t, 2> ne = {256, 192}) + std::array<int64_t, 4> ne = {256, 192, 1, 1}) : type(type), ne(ne) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * x = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); ggml_set_name(x, "x"); ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]); @@ -7558,11 +7558,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { // SNAKE activation fusion: x + sin(a*x)^2 * inv_b for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) { - test_cases.emplace_back(new test_snake_fuse(type, { 5, 7})); // primes sub-block - test_cases.emplace_back(new test_snake_fuse(type, { 33, 32})); // boundary - test_cases.emplace_back(new test_snake_fuse(type, {1025, 13})); // large prime, grid-stride - test_cases.emplace_back(new test_snake_fuse(type, { 128, 16})); // power-of-two - test_cases.emplace_back(new test_snake_fuse(type, { 256, 192})); // BigVGAN-ish + test_cases.emplace_back(new test_snake_fuse(type, { 5, 7, 1, 1})); // primes sub-block + test_cases.emplace_back(new test_snake_fuse(type, { 33, 32, 1, 1})); // boundary + test_cases.emplace_back(new test_snake_fuse(type, {1025, 13, 1, 1})); // large prime, grid-stride + test_cases.emplace_back(new test_snake_fuse(type, { 128, 16, 1, 1})); // power-of-two + test_cases.emplace_back(new test_snake_fuse(type, { 256, 192, 1, 1})); // BigVGAN-ish + // higher-rank shapes: matcher must reject fusion, fallback to naive chain + test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 1})); // ne[2] > 1 + test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 1, 2})); // ne[3] > 1
+ test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 3})); // ne[2] > 1 and ne[3] > 1 } // glu ops @@ -9093,9 +9097,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() { test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1})); // SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192) - test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192})); - test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192})); - test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192})); + test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192, 1, 1})); + test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192, 1, 1})); + test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192, 1, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));