mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-11 11:34:10 +00:00
CUDA: batch out_prod inner loop with cublasSgemmStridedBatched (#22651)
* CUDA: batch out_prod inner loop with cublasSgemmStridedBatched
* CUDA: batch out_prod inner loop with cublasSgemmStridedBatched
* CUDA: add cublasSgemmStridedBatched mapping for HIP and MUSA backends
This commit is contained in:
@@ -54,15 +54,31 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int64_t dps2 = ne2 / ne02;
|
||||
const int64_t dps3 = ne3 / ne03;
|
||||
|
||||
// TODO batched matrix multiplication
|
||||
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
||||
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
||||
if (dps2 == 1 && ne2 > 1) {
|
||||
// src0 has uniform stride s02 along dim 2; batch the inner loop with a strided GEMM
|
||||
GGML_ASSERT(ne2 <= std::numeric_limits<int>::max());
|
||||
const int batch_count = (int) ne2;
|
||||
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
cublasSgemmStridedBatched(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
||||
src1_d + i3 *s13 + i2 *s12, ldb,
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
||||
&alpha, src0_d + (i3/dps3)*s03, lda, s02,
|
||||
src1_d + i3 *s13, ldb, s12,
|
||||
&beta, dst_d + i3 *s3, ldc, s2,
|
||||
batch_count));
|
||||
}
|
||||
} else {
|
||||
// Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2
|
||||
// with non-uniform stride; would need cublasSgemmBatched with pointer arrays).
|
||||
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
||||
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
||||
src1_d + i3 *s13 + i2 *s12, ldb,
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
1
ggml/src/ggml-cuda/vendors/hip.h
vendored
1
ggml/src/ggml-cuda/vendors/hip.h
vendored
@@ -48,6 +48,7 @@
|
||||
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
||||
#define cublasSetStream hipblasSetStream
|
||||
#define cublasSgemm hipblasSgemm
|
||||
#define cublasSgemmStridedBatched hipblasSgemmStridedBatched
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define cublasOperation_t hipblasOperation_t
|
||||
#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
|
||||
|
||||
1
ggml/src/ggml-cuda/vendors/musa.h
vendored
1
ggml/src/ggml-cuda/vendors/musa.h
vendored
@@ -32,6 +32,7 @@
|
||||
#define cublasSetMathMode mublasSetMathMode
|
||||
#define cublasSetStream mublasSetStream
|
||||
#define cublasSgemm mublasSgemm
|
||||
#define cublasSgemmStridedBatched mublasSgemmStridedBatched
|
||||
#define cublasStatus_t mublasStatus_t
|
||||
#define cublasOperation_t mublasOperation_t
|
||||
#define cublasGetStatusString mublasGetStatusString
|
||||
|
||||
@@ -8385,6 +8385,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
|
||||
// ne2 sweep to cover the cublasSgemmStridedBatched path (dps2 == 1, ne2 > 1)
|
||||
for (int64_t ne2 : {1, 8, 16, 32}) {
|
||||
test_cases.emplace_back(new test_out_prod(GGML_TYPE_F32, GGML_TYPE_F32,
|
||||
256, 16, 16, {ne2, 1}, {1, 1}));
|
||||
}
|
||||
|
||||
// add_id
|
||||
for (ggml_type type_a : {GGML_TYPE_F32}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
|
||||
Reference in New Issue
Block a user