Compare commits

...

7 Commits

Author SHA1 Message Date
0cc4m
dba1cbceb3 tune for RDNA3 2025-11-16 20:21:22 +01:00
0cc4m
94e2c4d2b3 fix warptile 2025-11-16 20:06:04 +01:00
0cc4m
c19b3c378c device tuning 2025-11-16 18:37:04 +00:00
0cc4m
7e8eb9ba0a vulkan: allow MMQ bk_step tuning 2025-11-16 17:00:00 +01:00
Georgi Gerganov
416e7c7f47 metal : remove obosolete asserts (#17295) 2025-11-16 09:50:26 +02:00
Georgi Gerganov
5b2093becc server : handle context overflow during decode (#17267)
* server : handle context overflow during decode

* server : minor refactor
2025-11-16 09:23:37 +02:00
lhez
52e5d421f1 opencl: fix rms_norm_mul (#17250)
* opencl: use subgrroup reduce for reduction in rms_norm_mul

* opencl: add comment about workgroup size
2025-11-15 17:40:14 -08:00
6 changed files with 298 additions and 97 deletions

View File

@@ -2191,8 +2191,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
if (has_mask) {
@@ -2222,8 +2220,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
}
if (need_sync) {
@@ -2363,8 +2359,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
if (need_sync) {

View File

@@ -5705,7 +5705,7 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs, NULL));
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}

View File

@@ -134,6 +134,15 @@ kernel void kernel_rms_norm_mul(
src1 = src1 + offset1;
dst = dst + offsetd;
// The size of sum is sizeof(float)*subgroup_size.
// Each subgroup writes its partial sum into one element of this array,
// so the number of subgroups per workgroup for this kernel must not exceed the subgroup size.
// This generally holds: e.g. with subgroup size 64 the workgroup size would have to
// exceed 64*64 = 4096 to violate it, while the hardware maximum is usually 1024.
if (get_sub_group_id() == 0) {
sum[get_sub_group_local_id()] = 0.0f;
}
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
@@ -148,24 +157,30 @@ kernel void kernel_rms_norm_mul(
sumf += dot(x[i00], x[i00]);
}
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if (get_sub_group_local_id() == 0) {
sum[get_sub_group_id()] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
if (get_local_id(0) < i) {
sum[get_local_id(0)] += sum[get_local_id(0) + i];
}
}
if (get_local_id(0) == 0) {
sum[0] /= ne00;
}
//for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
// if (get_local_id(0) < i) {
// sum[get_local_id(0)] += sum[get_local_id(0) + i];
// }
//}
//if (get_local_id(0) == 0) {
// sum[0] /= ne00;
//}
barrier(CLK_LOCAL_MEM_FENCE);
//barrier(CLK_LOCAL_MEM_FENCE);
float mean = sum[0];
sumf = sum[get_sub_group_local_id()];
sumf = sub_group_reduce_add(sumf);
float mean = sumf / ne00;
float scale = 1.0f/sqrt(mean + eps);
global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);

View File

@@ -38,6 +38,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
#include <mutex>
#include <future>
#include <thread>
#include <bitset>
#if defined(_MSC_VER)
# define NOMINMAX 1
@@ -561,12 +562,19 @@ struct vk_device_struct {
size_t idx;
bool mul_mat_l[GGML_TYPE_COUNT];
bool mul_mat_m[GGML_TYPE_COUNT];
bool mul_mat_s[GGML_TYPE_COUNT];
bool mul_mat_id_l[GGML_TYPE_COUNT];
bool mul_mat_id_m[GGML_TYPE_COUNT];
bool mul_mat_id_s[GGML_TYPE_COUNT];
std::bitset<GGML_TYPE_COUNT> mul_mat_l;
std::bitset<GGML_TYPE_COUNT> mul_mat_m;
std::bitset<GGML_TYPE_COUNT> mul_mat_s;
std::bitset<GGML_TYPE_COUNT> mul_mat_id_l;
std::bitset<GGML_TYPE_COUNT> mul_mat_id_m;
std::bitset<GGML_TYPE_COUNT> mul_mat_id_s;
std::bitset<GGML_TYPE_COUNT> mul_mat_int_l;
std::bitset<GGML_TYPE_COUNT> mul_mat_int_m;
std::bitset<GGML_TYPE_COUNT> mul_mat_int_s;
std::bitset<GGML_TYPE_COUNT> mul_mat_int_id_l;
std::bitset<GGML_TYPE_COUNT> mul_mat_int_id_m;
std::bitset<GGML_TYPE_COUNT> mul_mat_int_id_s;
vk::DescriptorSetLayout dsl;
@@ -2526,7 +2534,38 @@ static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
}
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
static uint32_t mmq_shmem_struct_size(const vk_device& device, ggml_type type) {
const uint32_t float_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
constexpr uint32_t int32_size = sizeof(uint32_t);
switch (type) {
case GGML_TYPE_Q4_0:
return 4 * int32_size + float_size;
case GGML_TYPE_Q4_1:
return 4 * int32_size + 2 * float_size;
case GGML_TYPE_Q5_0:
return 5 * int32_size + float_size;
case GGML_TYPE_Q5_1:
return 5 * int32_size + 2 * float_size;
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
return 8 * int32_size + float_size;
case GGML_TYPE_Q8_1:
return 8 * int32_size + 2 * float_size;
case GGML_TYPE_Q2_K:
return 2 * int32_size + 2 + 2 * float_size;
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
return 4 * int32_size + 2 * float_size;
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return 8 * int32_size + 2 * float_size;
default:
return 0;
}
}
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type, ggml_type src1_type, uint32_t bk_step) {
uint32_t lut_size = 0;
switch (src0_type) {
@@ -2559,11 +2598,18 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
}
// Needs to be kept up to date on shader changes
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
const uint32_t warps = warptile[0] / warptile[10];
uint32_t load_bufs;
if (src1_type != GGML_TYPE_Q8_1) {
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
} else {
load_bufs = (warptile[1] + warptile[2]) * bk_step * mmq_shmem_struct_size(device, src0_type);
}
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
@@ -2644,6 +2690,70 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
return 0; // If no matching configuration is found
}
static bool is_k_quant(ggml_type type) {
switch (type) {
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return true;
default:
return false;
}
}
static uint32_t get_default_bk_step(const vk_device& device, ggml_type src0_type, bool mul_mat_id) {
const uint32_t bk_struct_size = mmq_shmem_struct_size(device, src0_type);
const uint32_t q5_0_struct_size = mmq_shmem_struct_size(device, GGML_TYPE_Q5_0);
const bool kq = is_k_quant(src0_type);
if (device->architecture == vk_device_architecture::AMD_GCN) {
if (mul_mat_id) {
return kq ? 1 : 4;
}
return 4;
} else if (device->vendor_id == VK_VENDOR_ID_AMD) {
if (mul_mat_id) {
return src0_type == GGML_TYPE_Q8_0 ? 1 : 2;
}
if (kq) {
if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K) {
return 4;
}
return 2;
}
return 4;
}
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
if (mul_mat_id) {
if (kq) {
return 1;
}
return src0_type != GGML_TYPE_Q8_0 ? 4 : 1;
}
if (kq) {
return src0_type == GGML_TYPE_Q4_K ? 4 : 1;
}
return src0_type != GGML_TYPE_Q8_0 ? 4 : 1;
}
// Nvidia/Generic case
if (!mul_mat_id && !kq) {
return 1;
}
return 4;
}
static void ggml_vk_load_shaders(vk_device& device) {
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
@@ -2676,6 +2786,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_s;
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_m;
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_l;
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_id_s;
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_id_m;
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_id_l;
uint32_t l_align, m_align, s_align;
if (device->coopmat2) {
// spec constants and tile sizes for non-quant matmul/matmul_id
@@ -2734,14 +2851,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
// Integer MMQ has a smaller shared memory profile, but heavier register use
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
l_warptile_mmq_int = { 128, 128, 128, 0, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 0, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 0, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
// K-quants use even more registers, mitigate by setting WMITER to 1
l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
l_warptile_mmq_int_k = { 128, 128, 128, 0, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int_k = { 128, 64, 64, 0, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 0, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
@@ -2751,13 +2868,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
l_warptile_mmqid_int = { 128, 128, 128, 0, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
m_warptile_mmqid_int = { 128, 64, 64, 0, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 0, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
l_warptile_mmqid_int_k = { 128, 128, 128, 0, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
m_warptile_mmqid_int_k = { 128, 64, 64, 0, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 0, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
// chip specific tuning
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
@@ -2773,31 +2890,81 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_align = 32;
for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
ggml_type t = (ggml_type)i;
const ggml_type t = (ggml_type)i;
// Disable medium and large matrix multiplication if not enough shared memory is available
// Check mmq warptiles as the largest configuration
// Throw an error if not enough for any matrix multiplication is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t, GGML_TYPE_F32, 1)) {
std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
throw std::runtime_error("Shared memory size too small for matrix multiplication.");
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t, GGML_TYPE_F32, 1)) {
device->mul_mat_m[i] = false;
device->mul_mat_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t, GGML_TYPE_F32, 1)) {
device->mul_mat_l[i] = false;
}
// Disable mul_mat_id if not enough shared memory is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t, GGML_TYPE_F32, 1)) {
device->mul_mat_id_s[i] = false;
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t, GGML_TYPE_F32, 1)) {
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t, GGML_TYPE_F32, 1)) {
device->mul_mat_id_l[i] = false;
}
// Integer dot matmul has different shared memory requirements
// bk_step is how many blocks in the k-direction are stored in shared memory at once.
// Each value starts at a device-specific default (see get_default_bk_step), then is
// halved until the shared memory requirement fits, reaching 0 if it is not supported at all.
mul_mat_int_bk_step_s[i] = device->mul_mat_int_s[i] ? get_default_bk_step(device, t, false) : 0;
mul_mat_int_bk_step_m[i] = device->mul_mat_int_m[i] ? get_default_bk_step(device, t, false) : 0;
mul_mat_int_bk_step_l[i] = device->mul_mat_int_l[i] ? get_default_bk_step(device, t, false) : 0;
mul_mat_int_bk_step_id_s[i] = device->mul_mat_int_id_s[i] ? get_default_bk_step(device, t, true) : 0;
mul_mat_int_bk_step_id_m[i] = device->mul_mat_int_id_m[i] ? get_default_bk_step(device, t, true) : 0;
mul_mat_int_bk_step_id_l[i] = device->mul_mat_int_id_l[i] ? get_default_bk_step(device, t, true) : 0;
for (uint32_t bk_step : { 4, 2, 1 }) {
if (mul_mat_int_bk_step_s[i] == bk_step && !ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t, GGML_TYPE_Q8_1, bk_step)) {
mul_mat_int_bk_step_s[i] >>= 1;
mul_mat_int_bk_step_m[i] >>= 1;
mul_mat_int_bk_step_l[i] >>= 1;
} else if (mul_mat_int_bk_step_m[i] == bk_step && !ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t, GGML_TYPE_Q8_1, bk_step)) {
mul_mat_int_bk_step_m[i] >>= 1;
mul_mat_int_bk_step_l[i] >>= 1;
} else if (mul_mat_int_bk_step_l[i] == bk_step && !ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t, GGML_TYPE_Q8_1, bk_step)) {
mul_mat_int_bk_step_l[i] >>= 1;
}
if (mul_mat_int_bk_step_id_s[i] == bk_step && !ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t, GGML_TYPE_Q8_1, bk_step)) {
mul_mat_int_bk_step_id_s[i] >>= 1;
mul_mat_int_bk_step_id_m[i] >>= 1;
mul_mat_int_bk_step_id_l[i] >>= 1;
} else if (mul_mat_int_bk_step_id_m[i] == bk_step && !ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t, GGML_TYPE_Q8_1, bk_step)) {
mul_mat_int_bk_step_id_m[i] >>= 1;
mul_mat_int_bk_step_id_l[i] >>= 1;
} else if (mul_mat_int_bk_step_id_l[i] == bk_step && !ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t, GGML_TYPE_Q8_1, bk_step)) {
mul_mat_int_bk_step_id_l[i] >>= 1;
}
}
// std::cerr << "ggml_vulkan: Info: Integer dot-product matmul support for type " << ggml_type_name(t) << ": "
// << "small bk_step=" << (int)mul_mat_int_bk_step_s[i] << ", "
// << "medium bk_step=" << (int)mul_mat_int_bk_step_m[i] << ", "
// << "large bk_step=" << (int)mul_mat_int_bk_step_l[i] << "; "
// << "matmul_id small bk_step=" << (int)mul_mat_int_bk_step_id_s[i] << ", "
// << "medium bk_step=" << (int)mul_mat_int_bk_step_id_m[i] << ", "
// << "large bk_step=" << (int)mul_mat_int_bk_step_id_l[i] << std::endl;
device->mul_mat_int_s[i] = mul_mat_int_bk_step_s[i] > 0;
device->mul_mat_int_m[i] = mul_mat_int_bk_step_m[i] > 0;
device->mul_mat_int_l[i] = mul_mat_int_bk_step_l[i] > 0;
device->mul_mat_int_id_s[i] = mul_mat_int_bk_step_id_s[i] > 0;
device->mul_mat_int_id_m[i] = mul_mat_int_bk_step_id_m[i] > 0;
device->mul_mat_int_id_l[i] = mul_mat_int_bk_step_id_l[i] > 0;
}
}
@@ -2865,6 +3032,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
align, disable_robustness, require_full_subgroups, required_subgroup_size);
};
auto const &s_mmq_warptile_bk_step = [&](const std::vector<uint32_t>& warptile, ggml_type type) -> std::vector<uint32_t> {
std::vector<uint32_t> warptile_copy = warptile;
warptile_copy[3] = mul_mat_int_bk_step_s[type];
return warptile_copy;
};
auto const &m_mmq_warptile_bk_step = [&](const std::vector<uint32_t>& warptile, ggml_type type) -> std::vector<uint32_t> {
std::vector<uint32_t> warptile_copy = warptile;
warptile_copy[3] = mul_mat_int_bk_step_m[type];
return warptile_copy;
};
auto const &l_mmq_warptile_bk_step = [&](const std::vector<uint32_t>& warptile, ggml_type type) -> std::vector<uint32_t> {
std::vector<uint32_t> warptile_copy = warptile;
warptile_copy[3] = mul_mat_int_bk_step_l[type];
return warptile_copy;
};
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
};
@@ -3150,14 +3333,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat_int ## ID ## _l[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_mmq_warptile_bk_step(l_ ## WARPTILE, TYPE), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
if (device->mul_mat ## ID ## _m[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat_int ## ID ## _m[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_mmq_warptile_bk_step(m_ ## WARPTILE, TYPE), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
if (device->mul_mat ## ID ## _s[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat_int ## ID ## _s[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_mmq_warptile_bk_step(s_ ## WARPTILE, TYPE), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
// Create 2 variants, {f16,f32} accumulator
@@ -3321,12 +3504,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
if (device->mul_mat_int ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_mmq_warptile_bk_step(l_ ## WARPTILE, TYPE), 1); \
if (device->mul_mat_int ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_mmq_warptile_bk_step(m_ ## WARPTILE, TYPE), 1); \
if (device->mul_mat_int ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_mmq_warptile_bk_step(s_ ## WARPTILE, TYPE), 1); \
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
@@ -4693,8 +4876,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->mul_mat_id_s[i] = true;
break;
}
}
device->mul_mat_int_l[i] = device->mul_mat_l[i];
device->mul_mat_int_m[i] = device->mul_mat_m[i];
device->mul_mat_int_s[i] = device->mul_mat_s[i];
device->mul_mat_int_id_l[i] = device->mul_mat_id_l[i];
device->mul_mat_int_id_m[i] = device->mul_mat_id_m[i];
device->mul_mat_int_id_s[i] = device->mul_mat_id_s[i];
}
std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
@@ -6130,6 +6319,16 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
return aligned ? mmp->a_s : mmp->s;
}
if (src1_type == GGML_TYPE_Q8_1) {
if ((ctx->device->mul_mat_int_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_int_m[src0_type] && !ctx->device->mul_mat_int_l[src0_type])) {
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_int_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_int_l[src0_type]) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_l : mmp->l;
}
if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
return aligned ? mmp->a_s : mmp->s;
}

View File

@@ -67,7 +67,7 @@ layout (push_constant) uniform parameter
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
// layout (constant_id = 3) const uint BK = 32;
layout (constant_id = 3) const uint BK_STEP = 1; // Amount of quant blocks stored in shared memory
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
@@ -82,14 +82,6 @@ layout (constant_id = 10) const uint WARP = 32;
#include "mul_mmq_shmem_types.glsl"
#ifdef MUL_MAT_ID
#define BK_STEP 1
#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
#endif
// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
shared block_b_cache buf_b[BN * BK_STEP];

View File

@@ -1686,14 +1686,13 @@ struct server_slot {
llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
}
void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
bool res = prompt_cache.load(prompt, tokens, ctx, id);
if (!res) {
SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
prompt.tokens.clear();
}
return res;
}
std::vector<common_adapter_lora_info> lora;
@@ -2339,7 +2338,6 @@ struct server_context {
llama_batch batch {};
bool clean_kv_cache = true;
bool add_bos_token = true;
int32_t n_ctx; // total context for all clients / slots
@@ -2702,7 +2700,10 @@ struct server_context {
const int64_t t_start = ggml_time_us();
ret->prompt_save(*prompt_cache);
ret->prompt_load(*prompt_cache, task.tokens);
if (!ret->prompt_load(*prompt_cache, task.tokens)) {
clear_slot(*ret);
}
prompt_cache->update();
@@ -2713,12 +2714,21 @@ struct server_context {
return ret;
}
// return true if at least one slot has been purged
void clear_slot(server_slot & slot) const {
GGML_ASSERT(!slot.is_processing());
SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
slot.prompt.tokens.clear();
}
// return true if at least one slot has been cleared
// TODO: improve logic
// - smarter decision which slot to purge (LRU or longest prompt?)
// - smarter decision which slot to clear (LRU or longest prompt?)
// - move slot to level 2 cache instead of removing?
// - instead of purging, try to store and resume later?
bool try_purge_idle_slots() {
bool try_clear_idle_slots() {
bool res = false;
if (!params_base.kv_unified) {
@@ -2733,12 +2743,11 @@ struct server_context {
if (slot.prompt.n_tokens() > 0) {
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
slot.prompt.tokens.clear();
clear_slot(slot);
res = true;
// purge slots one by one
// clear slots one by one
break;
}
}
@@ -2848,14 +2857,6 @@ struct server_context {
return true;
}
void kv_cache_clear() {
SRV_DBG("%s", "clearing KV cache\n");
// clear the entire KV cache
llama_memory_clear(llama_get_memory(ctx), true);
clean_kv_cache = false;
}
bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = result.text_to_send;
@@ -3443,8 +3444,8 @@ struct server_context {
// Erase token cache
const size_t n_erased = slot->prompt.tokens.size();
llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1);
slot->prompt.tokens.clear();
clear_slot(*slot);
auto res = std::make_unique<server_task_result_slot_erase>();
res->id = task.id;
@@ -3477,9 +3478,6 @@ struct server_context {
if (all_idle) {
SRV_INF("%s", "all slots are idle\n");
if (clean_kv_cache) {
kv_cache_clear();
}
return;
}
@@ -3873,12 +3871,11 @@ struct server_context {
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
clear_slot(slot);
// there is no common part left
slot.n_prompt_tokens_cache = 0;
slot.prompt.tokens.clear();
}
// check if we should process the image
@@ -4108,6 +4105,10 @@ struct server_context {
if (slot.is_processing()) {
send_error(slot, err);
slot.release();
// note: it's complicated to keep track of how much of the current batch has been
// processed before the error occurred, so we simply clear the entire context
clear_slot(slot);
}
}
@@ -4116,7 +4117,7 @@ struct server_context {
}
// retry with half the batch size to try to find a free slot in the KV cache
if (!try_purge_idle_slots()) {
if (!try_clear_idle_slots()) {
n_batch /= 2;
}