unify matmul_id shader selection

This commit is contained in:
Ruben Ortlam
2026-03-12 14:55:12 +01:00
parent 664dfc7730
commit 7ded1269ab

View File

@@ -7069,27 +7069,21 @@ static void ggml_vk_matmul(
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
if (ctx->device->coopmat2) {
// Use large shader when the N dimension is greater than the medium shader's tile size
uint32_t crossover_large = mmp->m->wg_denoms[1];
if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
return aligned ? mmp->a_l : mmp->l;
}
// Use medium shader when the N dimension is greater than the small shader's tile size
uint32_t crossover_medium = mmp->s->wg_denoms[1];
if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_s : mmp->s;
}
const bool mm_l = ctx->device->mul_mat_id_l[src0_type];
const bool mm_m = ctx->device->mul_mat_id_m[src0_type];
const bool mm_s = ctx->device->mul_mat_id_s[src0_type];
if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
return aligned ? mmp->a_s : mmp->s;
// Use large shader when the N dimension is greater than the medium shader's tile size
const uint32_t crossover_large = mm_m ? mmp->m->wg_denoms[1] : (mm_s ? mmp->s->wg_denoms[1] : 0);
if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s)) {
return aligned ? mmp->a_l : mmp->l;
}
if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
// Use medium shader when the N dimension is greater than the small shader's tile size
const uint32_t crossover_medium = mm_s ? mmp->s->wg_denoms[1] : 0;
if ((mm_m && (n > crossover_medium)) || !mm_s) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_l : mmp->l;
return aligned ? mmp->a_s : mmp->s;
}
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {