diff --git a/ggml/src/ggml-cpu/binary-ops.cpp b/ggml/src/ggml-cpu/binary-ops.cpp index 14f5b43ae0..75e3829001 100644 --- a/ggml/src/ggml-cpu/binary-ops.cpp +++ b/ggml/src/ggml-cpu/binary-ops.cpp @@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds GGML_ASSERT(nb00 == sizeof(src0_t)); const auto [ir0, ir1] = get_thread_range(params, src0); - const bool is_src1_contiguous = (nb10 == sizeof(src1_t)); - - if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - } + const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1); #ifdef GGML_USE_ACCELERATE vDSP_fn_t vDSP_op = nullptr; @@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - if (is_src1_contiguous) { + if (is_src1_contiguous_rows) { // src1 is broadcastable across src0 and dst in i1, i2, i3 const int64_t nr0 = ne00 / ne10; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 82e8928e96..5852f53d19 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2964,11 +2964,12 @@ struct test_bin_bcast : public test_case { const std::array ne; const std::array nr; int nf; // number of fused ops, nf == 1 -> single op (no fusion) + bool perm1; // permute src1? - bool run_whole_graph() override { return nf > 1; } + bool run_whole_graph() override { return nf > 1 || perm1; } std::string vars() override { - return VARS_TO_STR4(type, ne, nr, nf); + return VARS_TO_STR5(type, ne, nr, nf, perm1); } size_t op_size(ggml_tensor * t) override { @@ -2978,8 +2979,9 @@ struct test_bin_bcast : public test_case { test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 1, 1}, std::array nr = {1, 2, 1, 1}, - int nf = 1) - : op(op), type(type), ne(ne), nr(nr), nf(nf) {} + int nf = 1, + bool perm1 = false) + : op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1) {} ggml_tensor * build_graph(ggml_context * ctx) override { GGML_ASSERT(nf <= 16); @@ -2989,12 +2991,24 @@ struct test_bin_bcast : public test_case { ggml_tensor * b[16]; for (int i = 0; i < nf; ++i) { - b[i] = ggml_new_tensor(ctx, type, 4, ne.data()); + if (perm1) { + const int p[4] = { 1, 2, 0, 3 }; // hardcoded for now + + int64_t ne_b[4]; + ne_b[0] = ne[p[0]]; + ne_b[1] = ne[p[1]]; + ne_b[2] = ne[p[2]]; + ne_b[3] = ne[p[3]]; + b[i] = ggml_new_tensor_4d(ctx, type, ne_b[0], ne_b[1], ne_b[2], ne_b[3]); + b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]); + } else { + b[i] = ggml_new_tensor(ctx, type, 4, ne.data()); + } ggml_set_name(b[i], (std::string("b") + std::to_string(i)).c_str()); } // The backward pass supports broadcasting only for GGML_ADD: - const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1; + const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1 && !perm1; if (grad_supported) { ggml_set_param(a); ggml_set_param(b[0]); @@ -7477,25 +7491,27 @@ static std::vector> make_test_cases_eval() { } } - auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr) { + auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr, bool perm1 = false) { for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) { - test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr)); + test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1)); } }; for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { - add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1}); - add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}); + for (bool perm1 : {false, true}) { + add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}, perm1); + } // test case for k_bin_bcast_unravel in CUDA backend add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});