CUDA: batch out_prod inner loop with cublasSgemmStridedBatched (#22651)

* CUDA: batch out_prod inner loop with cublasSgemmStridedBatched

* CUDA: batch out_prod inner loop with cublasSgemmStridedBatched

* CUDA: add cublasSgemmStridedBatched mapping for HIP and MUSA backends
This commit is contained in:
leonardHONG
2026-05-08 03:59:29 +08:00
committed by GitHub
parent aaf4a4d5e0
commit 05ff59cb57
4 changed files with 31 additions and 7 deletions

View File

@@ -54,15 +54,31 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t dps2 = ne2 / ne02;
     const int64_t dps3 = ne3 / ne03;
-    // TODO batched matrix multiplication
-    for (int64_t i3 = 0; i3 < ne3; ++i3) {
-        for (int64_t i2 = 0; i2 < ne2; ++i2) {
+    if (dps2 == 1 && ne2 > 1) {
+        // src0 has uniform stride s02 along dim 2; batch the inner loop with a strided GEMM
+        GGML_ASSERT(ne2 <= std::numeric_limits<int>::max());
+        const int batch_count = (int) ne2;
+        for (int64_t i3 = 0; i3 < ne3; ++i3) {
             CUBLAS_CHECK(
-                cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+                cublasSgemmStridedBatched(handle, CUBLAS_OP_N, src1_cublas_op,
                         ne0, ne1, ne01,
-                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
-                                src1_d + i3 *s13 + i2 *s12, ldb,
-                        &beta,  dst_d + i3 *s3 + i2 *s2, ldc));
+                        &alpha, src0_d + (i3/dps3)*s03, lda, s02,
+                                src1_d + i3 *s13, ldb, s12,
+                        &beta,  dst_d + i3 *s3, ldc, s2,
+                        batch_count));
+        }
+    } else {
+        // Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2
+        // with non-uniform stride; would need cublasSgemmBatched with pointer arrays).
+        for (int64_t i3 = 0; i3 < ne3; ++i3) {
+            for (int64_t i2 = 0; i2 < ne2; ++i2) {
+                CUBLAS_CHECK(
+                    cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+                            ne0, ne1, ne01,
+                            &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
+                                    src1_d + i3 *s13 + i2 *s12, ldb,
+                            &beta,  dst_d + i3 *s3 + i2 *s2, ldc));
+            }
+        }
     }
 }

View File

@@ -48,6 +48,7 @@
 #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
 #define cublasSetStream hipblasSetStream
 #define cublasSgemm hipblasSgemm
+#define cublasSgemmStridedBatched hipblasSgemmStridedBatched
 #define cublasStatus_t hipblasStatus_t
 #define cublasOperation_t hipblasOperation_t
 #define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch

View File

@@ -32,6 +32,7 @@
 #define cublasSetMathMode mublasSetMathMode
 #define cublasSetStream mublasSetStream
 #define cublasSgemm mublasSgemm
+#define cublasSgemmStridedBatched mublasSgemmStridedBatched
 #define cublasStatus_t mublasStatus_t
 #define cublasOperation_t mublasOperation_t
 #define cublasGetStatusString mublasGetStatusString

View File

@@ -8385,6 +8385,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
+    // ne2 sweep to cover the cublasSgemmStridedBatched path (dps2 == 1, ne2 > 1)
+    for (int64_t ne2 : {1, 8, 16, 32}) {
+        test_cases.emplace_back(new test_out_prod(GGML_TYPE_F32, GGML_TYPE_F32,
+                                                  256, 16, 16, {ne2, 1}, {1, 1}));
+    }
     // add_id
     for (ggml_type type_a : {GGML_TYPE_F32}) {
         for (ggml_type type_b : {GGML_TYPE_F32}) {