Ggml/cuda snake fusion hardening (#22912)
* cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan)
* cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review)
* cuda: merge type_ok and types_ok into a single types_ok (address am17an review)
* cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16

  bin_bcast only dispatches F32/F16 type triplets; mirror the vulkan filter so unsupported types fall back through cpy instead of aborting.
* test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases
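The commit message refers to the CUDA "snake fusion": ggml_cuda_try_fuse collapses the five-node chain for the snake activation, x + sin(a*x)^2 * inv_b, into a single kernel launch. A rough sketch of how such a chain is built with the public ggml API is below; snake_naive is a hypothetical helper, not code from this patch, and the variable names simply mirror those used in the matcher.

    // Illustrative only: the unfused snake chain that ggml_cuda_try_fuse looks for.
    // x is [T, C]; a and inv_b carry one value per channel, i.e. shape [1, C].
    static ggml_tensor * snake_naive(ggml_context * ctx, ggml_tensor * x,
                                     ggml_tensor * a, ggml_tensor * inv_b) {
        ggml_tensor * mul0 = ggml_mul(ctx, x, a);       // a * x (a broadcast along T)
        ggml_tensor * sin1 = ggml_sin(ctx, mul0);       // sin(a * x)
        ggml_tensor * sqr  = ggml_sqr(ctx, sin1);       // sin(a * x)^2
        ggml_tensor * mul1 = ggml_mul(ctx, sqr, inv_b); // sin(a * x)^2 * inv_b
        return ggml_add(ctx, x, mul1);                  // x + sin(a * x)^2 * inv_b
    }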
@@ -3929,10 +3929,25 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
     // closure check: the trailing add must read the same x as the leading mul
     const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];
 
-    const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16);
+    // Kernel iterates over total = T * C, so x and add must be 2D and
+    // a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled.
+    const bool dim_ok = (x->ne[2]   == 1 && x->ne[3]   == 1) &&
+                        (add->ne[2] == 1 && add->ne[3] == 1) &&
+                        (a->ne[2]   == 1 && a->ne[3]   == 1);
     const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1];
 
-    if (type_ok && shape_ok && x_in_add == x && add->type == x->type) {
+    // x must be in the supported whitelist and every operand / intermediate
+    // result must share x's type, since launch_snake casts a / inv_b as
+    // float and templates the kernel on a single T. Mixed precision chains
+    // fall back to the naive path.
+    const ggml_tensor * sin1 = cgraph->nodes[i + 1];
+    const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) &&
+                          (a->type    == x->type) && (inv_b->type == x->type) &&
+                          (mul0->type == x->type) && (sin1->type  == x->type) &&
+                          (sqr->type  == x->type) && (mul1->type  == x->type) &&
+                          (add->type  == x->type);
+
+    if (types_ok && shape_ok && dim_ok && x_in_add == x) {
         ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
         return 4;
     }
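The new comments spell out the two guards: the fused kernel walks a flat T * C index space, so anything with ne[2] or ne[3] > 1 would need extra strides, and it is templated on a single element type while a / inv_b are converted to float per element, so a mixed-precision chain cannot be fused. A minimal sketch of a kernel with that shape, assuming nothing about the real launch_snake beyond what the comments state:

    // Illustrative sketch only -- not the actual snake kernel from ggml-cuda.
    // Single template type T for the whole chain; per-channel a / inv_b are
    // read as T and converted to float, matching the types_ok requirement.
    template <typename T>
    __global__ void snake_fused_sketch(const T * x, const T * a, const T * inv_b,
                                       T * dst, const int64_t n_t, const int64_t n_c) {
        const int64_t total = n_t * n_c;    // kernel iterates over total = T * C
        for (int64_t i = blockIdx.x * (int64_t) blockDim.x + threadIdx.x;
             i < total; i += (int64_t) blockDim.x * gridDim.x) {   // grid-stride loop
            const int64_t c  = i / n_t;     // channel index; ne[0] = T is contiguous
            const float   xf = (float) x[i];
            const float   s  = sinf((float) a[c] * xf);
            dst[i] = (T) (xf + s * s * (float) inv_b[c]);  // x + sin(a*x)^2 * inv_b
        }
    }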
@@ -5291,12 +5306,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
-        case GGML_OP_ADD:
         case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
-        case GGML_OP_SUB:
-        case GGML_OP_MUL:
-        case GGML_OP_DIV:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
@@ -5305,6 +5316,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CLAMP:
         case GGML_OP_LOG:
             return true;
+        case GGML_OP_ADD:
+        case GGML_OP_SUB:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                   (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
+                   (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
         case GGML_OP_SSM_SCAN: {
             if (op->src[3]->ne[0] == 1) {
                 // Mamba2
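With this filter, a binary op whose type triplet is outside F32/F16 is reported as unsupported, so the scheduler routes it to another backend instead of hitting an abort in the bin_bcast dispatcher. A small sketch of probing this from application code; it assumes a CUDA device is registered, and the helper name is illustrative, not part of the patch:

    #include "ggml.h"
    #include "ggml-backend.h"

    // Illustrative check: does the GPU device accept a BF16 elementwise add?
    // With the filter above, the CUDA backend now answers false for this triplet.
    static bool gpu_supports_bf16_add(void) {
        ggml_init_params params = { /*.mem_size   =*/ ggml_tensor_overhead() * 8,
                                    /*.mem_buffer =*/ NULL,
                                    /*.no_alloc   =*/ true };
        ggml_context * ctx = ggml_init(params);

        ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_BF16, 16);
        ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_BF16, 16);
        ggml_tensor * c = ggml_add(ctx, a, b);

        ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
        const bool ok = dev != NULL && ggml_backend_dev_supports_op(dev, c);

        ggml_free(ctx);
        return ok;
    }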
@@ -3561,7 +3561,7 @@ struct test_relu_sqr : public test_case {
 // and dispatches a single fused kernel.
 struct test_snake_fuse : public test_case {
     const ggml_type type;
-    const std::array<int64_t, 2> ne; // [T, C]
+    const std::array<int64_t, 4> ne; // [T, C, D2, D3]
 
     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
@@ -3586,11 +3586,11 @@ struct test_snake_fuse : public test_case {
     }
 
     test_snake_fuse(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 2> ne = {256, 192})
+            std::array<int64_t, 4> ne = {256, 192, 1, 1})
         : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * x = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
+        ggml_tensor * x = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
         ggml_set_name(x, "x");
 
         ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]);
@@ -7558,11 +7558,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
     // SNAKE activation fusion: x + sin(a*x)^2 * inv_b
     for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
-        test_cases.emplace_back(new test_snake_fuse(type, {   5,   7}));       // primes sub-block
-        test_cases.emplace_back(new test_snake_fuse(type, {  33,  32}));       // boundary
-        test_cases.emplace_back(new test_snake_fuse(type, {1025,  13}));       // large prime, grid-stride
-        test_cases.emplace_back(new test_snake_fuse(type, { 128,  16}));       // power-of-two
-        test_cases.emplace_back(new test_snake_fuse(type, { 256, 192}));       // BigVGAN-ish
+        test_cases.emplace_back(new test_snake_fuse(type, {   5,   7, 1, 1})); // primes sub-block
+        test_cases.emplace_back(new test_snake_fuse(type, {  33,  32, 1, 1})); // boundary
+        test_cases.emplace_back(new test_snake_fuse(type, {1025,  13, 1, 1})); // large prime, grid-stride
+        test_cases.emplace_back(new test_snake_fuse(type, { 128,  16, 1, 1})); // power-of-two
+        test_cases.emplace_back(new test_snake_fuse(type, { 256, 192, 1, 1})); // BigVGAN-ish
+        // higher-rank shapes: matcher must reject fusion, fallback to naive chain
+        test_cases.emplace_back(new test_snake_fuse(type, {  64,  32, 2, 1})); // ne[2] > 1
+        test_cases.emplace_back(new test_snake_fuse(type, {  64,  32, 1, 2})); // ne[3] > 1
+        test_cases.emplace_back(new test_snake_fuse(type, {  64,  32, 2, 3})); // ne[2] > 1 and ne[3] > 1
     }
 
     // glu ops
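The higher-rank cases exercise the rejection path: the matcher refuses to fuse them, so the graph runs the naive five-node chain, which must still match the CPU reference. The value under test is elementwise and well defined for any rank; a small sketch of that reference on a contiguous [T, C, D2, D3] tensor (illustrative, not code from test-backend-ops):

    #include <cmath>
    #include <cstdint>

    // Illustrative reference only: y = x + sin(a*x)^2 * inv_b on a contiguous
    // [T, C, D2, D3] tensor, with one a / inv_b value per channel c.
    static void snake_ref(const float * x, const float * a, const float * inv_b, float * y,
                          int64_t T, int64_t C, int64_t D2, int64_t D3) {
        for (int64_t i3 = 0; i3 < D3; ++i3)
        for (int64_t i2 = 0; i2 < D2; ++i2)
        for (int64_t c  = 0; c  < C;  ++c)
        for (int64_t t  = 0; t  < T;  ++t) {
            const int64_t i = ((i3 * D2 + i2) * C + c) * T + t; // ggml layout: ne[0] = T fastest
            const float   s = std::sin(a[c] * x[i]);
            y[i] = x[i] + s * s * inv_b[c];
        }
    }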
@@ -9093,9 +9097,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
 
     // SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192)
-    test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32,  {7680, 192}));
-    test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16,  {7680, 192}));
-    test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192}));
+    test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32,  {7680, 192, 1, 1}));
+    test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16,  {7680, 192, 1, 1}));
+    test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192, 1, 1}));
 
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));