From 029c30fda49e1941b95decd6e196eafbd679b974 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 10 Feb 2026 12:44:50 +0200 Subject: [PATCH] cont : extend bin support --- ggml/src/ggml-cuda/binbcast.cu | 64 ++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 0e6d777b1e..00cddff974 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0, const uint3 ne11, const uint3 ne12, const uint3 ne13, - /*int s0, */ const int s1, + const int s0, + const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, + const int s00, + const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, + const int s10, + const int s11, const int s12, const int s13, src1_ptrs... src1s) { @@ -72,14 +75,14 @@ static __global__ void k_bin_bcast(const src0_t * src0, for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { const uint32_t i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0] : 0.0f; + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10]); + result = bin_op(result, (float)src1[i_src1 + i10*s10]); } - dst_row[i0] = (dst_t) result; + dst_row[i0*s0] = (dst_t) result; } } @@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const uint3 ne11, const uint3 ne12, const uint3 ne13, - /*int s0, */ const int s1, + const int s0, + const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, + const int s00, + const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, + const int s10, + const int s11, const int s12, const int s13, src1_ptrs... src1s) { @@ -135,14 +141,14 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const int i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0] : 0.0f; + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10]); + result = bin_op(result, (float)src1[i_src1 + i10*s10]); } - dst_row[i0] = (dst_t) result; + dst_row[i0*s0] = (dst_t) result; } template @@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * cnb[3] *= cne[3]; }; - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { for (int i = 0; i < 4; i++) { if (nr[i] != 1) { break; @@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s00 == 1); - GGML_ASSERT(s10 == 1); - const int block_size = 128; int64_t hne0 = std::max(ne0 / 2LL, 1LL); @@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * k_bin_bcast_unravel<<>>( src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + s0, s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { k_bin_bcast_unravel <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13); + s0, s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13); } } else { const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); if constexpr (sizeof...(I) > 0) { k_bin_bcast<<>>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + s0, s1, s2, s3, + s00,s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { k_bin_bcast<<>>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13); + s0, s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13); } } }