Compare commits

...

5 Commits
b6504 ... b6509

Author SHA1 Message Date
Sigbjørn Skjæret
ad6bd9083b cuda : add missing F32<->I32 entries in ggml_cuda_cpy_fn (#16060) 2025-09-18 13:28:22 +02:00
Radoslav Gerganov
2b6b55a59f server : include usage statistics only when user request them (#16052)
* server : include usage statistics only when user request them

When serving the OpenAI compatible API, we should check if
{"stream_options": {"include_usage": true} is set in the request when
deciding whether we should send usage statistics

closes: #16048

* add unit test
2025-09-18 10:36:57 +00:00
Georgi Gerganov
e58174cecb llama : bump max seq limit from 64 to 256 (#15916)
ggml-ci
2025-09-18 12:47:56 +03:00
Georgi Gerganov
b213fce89b metal : improve F32, F16 and BF16 mat-vec multiplication (#16057)
* metal : improve F32, F16 and BF16 mat-vec multiplication

ggml-ci

* metal : make the NSG a function constant in mul_mv kernels

ggml-ci
2025-09-18 12:33:45 +03:00
Jhen-Jie Hong
e00f3fd8ff metal : avoid call free for non-owned buffer (#16067) 2025-09-18 10:06:48 +03:00
10 changed files with 401 additions and 316 deletions

View File

@@ -441,6 +441,10 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));

View File

@@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
}
void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
if (!ppls) {
return;
}
for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
ggml_metal_pipeline_free(it->second);
}
@@ -467,37 +471,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
// use custom matrix x vector kernel
switch (tsrc0) {
case GGML_TYPE_F32:
{
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
nsg = 1;
nr0 = 1;
nr1 = 4;
if (ne00 == 4) {
nr0 = 32;
suffix = "_c4";
}
} break;
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
nsg = 1;
nr0 = 1;
if (op->src[1]->type == GGML_TYPE_F32) {
if (ne00 == 4) {
nr0 = 32;
nr1 = 4;
suffix = "_c4";
} else if (ne11 * ne12 < 4) {
suffix = "_1row";
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
suffix = "_l4";
nr1 = ne11;
} else {
nr1 = 4;
}
} else {
if (ne00 == 4) {
nsg = 1;
nr0 = 32;
nr1 = 4;
suffix = "_c4";
} else if (ne00 % 4 == 0) {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
suffix = "_4";
} else {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
}
} break;
case GGML_TYPE_Q4_0:
@@ -623,7 +615,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
ggml_metal_pipeline_set_nr0 (res, nr0);
ggml_metal_pipeline_set_nr1 (res, nr1);
@@ -689,25 +687,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
const ggml_type tsrc0 = op->src[0]->type;
const ggml_type tsrc1 = op->src[1]->type;
const char * suffix = "";
// use custom matrix x vector kernel
switch (tsrc0) {
case GGML_TYPE_F32:
{
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
nsg = 1;
nr0 = 1;
} break;
case GGML_TYPE_F16:
{
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
nsg = 1;
nr0 = 1;
} break;
case GGML_TYPE_BF16:
{
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
nsg = 1;
nr0 = 1;
if (ne00 % 4 == 0) {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
suffix = "_4";
} else {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
}
} break;
case GGML_TYPE_Q4_0:
{
@@ -824,7 +823,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
}
};
snprintf(base, 256, "kernel_mul_mv_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
@@ -832,7 +831,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
ggml_metal_pipeline_set_nr0 (res, nr0);
ggml_metal_pipeline_set_nr1 (res, nr1);

View File

@@ -22,6 +22,7 @@ typedef struct ggml_metal_cv * ggml_metal_cv_t;
ggml_metal_cv_t ggml_metal_cv_init(void);
void ggml_metal_cv_free(ggml_metal_cv_t cv);
void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx);
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);
void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx);

View File

@@ -51,6 +51,10 @@ void ggml_metal_cv_free(ggml_metal_cv_t cv) {
free(cv);
}
void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) {
[cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx];
}
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) {
[cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx];
}
@@ -824,6 +828,7 @@ struct ggml_metal_buffer {
// if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
bool is_shared;
bool owned;
// multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
int n_buffers;
@@ -956,6 +961,7 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
if (shared) {
res->all_data = ggml_metal_host_malloc(size_aligned);
res->is_shared = true;
res->owned = true;
} else {
// dummy, non-NULL value - we'll populate this after creating the Metal buffer below
res->all_data = (void *) 0x000000400ULL;
@@ -1014,6 +1020,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
res->all_size = size;
res->is_shared = true;
res->owned = false;
res->n_buffers = 0;
@@ -1107,7 +1114,7 @@ void ggml_metal_buffer_free(ggml_metal_buffer_t buf) {
ggml_metal_buffer_rset_free(buf);
if (buf->is_shared) {
if (buf->is_shared && buf->owned) {
#if TARGET_OS_OSX
vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)buf->all_data, buf->all_size);
#else

View File

@@ -8,6 +8,9 @@
//
// TODO: for optimal performance, become function of the device and work size
#define N_R0_F 2
#define N_SG_F 4
#define N_R0_Q4_0 4
#define N_SG_Q4_0 2
@@ -72,6 +75,7 @@
#define FC_FLASH_ATTN_EXT 100
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
#define FC_MUL_MV 400
// kernel argument structs
//

View File

@@ -1564,7 +1564,10 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
if (op->src[0]->type == GGML_TYPE_Q8_0) {
if (op->src[0]->type == GGML_TYPE_F32 ||
op->src[0]->type == GGML_TYPE_F16 ||
op->src[0]->type == GGML_TYPE_BF16 ||
op->src[0]->type == GGML_TYPE_Q8_0) {
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
} else {
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
@@ -1772,7 +1775,10 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
if (op->src[0]->type == GGML_TYPE_Q8_0) {
if (op->src[0]->type == GGML_TYPE_F32 ||
op->src[0]->type == GGML_TYPE_F16 ||
op->src[0]->type == GGML_TYPE_BF16 ||
op->src[0]->type == GGML_TYPE_Q8_0) {
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
} else {
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);

View File

@@ -2883,7 +2883,9 @@ static inline void helper_mv_reduce_and_write(
}
}
template<typename block_q_type, short NR0, short NSG, short NW, typename args_t>
constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
template<typename block_q_type, short NR0, short NW, typename args_t>
void mul_vec_q_n_f32_impl(
args_t args,
device const char * src0,
@@ -2893,6 +2895,8 @@ void mul_vec_q_n_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
constexpr short NQ = 16;
const int nb = args.ne00/QK4_0;
@@ -2977,7 +2981,7 @@ kernel void kernel_mul_mv_q4_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
@@ -2989,7 +2993,7 @@ kernel void kernel_mul_mv_q4_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
@@ -3001,7 +3005,7 @@ kernel void kernel_mul_mv_q5_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
@@ -3013,10 +3017,10 @@ kernel void kernel_mul_mv_q5_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<short NR0, short NSG, short NW, typename args_t>
template<short NR0, short NW, typename args_t>
void kernel_mul_mv_q8_0_f32_impl(
args_t args,
device const char * src0,
@@ -3026,6 +3030,8 @@ void kernel_mul_mv_q8_0_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
constexpr short NQ = 8;
const int nb = args.ne00/QK8_0;
@@ -3097,7 +3103,7 @@ kernel void kernel_mul_mv_q8_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
// mat-vec kernel processing in chunks of float4
@@ -3404,104 +3410,215 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
#define N_MV_T_T 4
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
void kernel_mul_mv_impl(
template<typename T0, typename T1, short NR0, short NW, typename args_t>
void kernel_mul_mv_t_t_impl(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg) {
const int r0 = tgpig.x;
const int rb = tgpig.y*N_MV_T_T;
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
constexpr short NB = 32;
constexpr short NF = 8;
const int nb = args.ne00/NB;
const int r0 = tgpig.x*NR0;
const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
//const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T0 * x = (device const T0 *) (src0 + offset0);
//device const T0 * x = (device const T0 *) (src0 + offset0);
device const T1 * y = (device const T1 *) (src1 + offset1);
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
// pointers to src0 rows
device const T0 * ax [NR0];
FOR_UNROLL (short row = 0; row < NR0; ++row) {
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
if (args.ne00 < 128) {
for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row;
if (r1 >= args.ne11) {
break;
}
ax[row] = (device const T0 *) ((device char *) src0 + offset0);
}
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
float sumf[NR0] = { 0.f };
device const T1 * y = (device const T1 *) (src1 + offset1);
const short ix = tiisg/(NW/NF);
const short il = tiisg%(NW/NF);
float sumf = 0;
for (int i = tiisg; i < args.ne00; i += 32) {
sumf += (T0) x[i] * (T1) y[i];
}
const int ib0 = sgitg*NF + ix;
float sum_all = simd_sum(sumf);
if (tiisg == 0) {
dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
}
T1 yl[NF];
device const T1 * yb = y + (ib0*NB + il*NF);
for (int ib = ib0; ib < nb; ib += NSG*NF) {
for (short i = 0; i < NF; ++i) {
yl[i] = yb[i];
}
} else {
device const T04 * x4 = (device const T04 *) x;
for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row;
if (r1 >= args.ne11) {
break;
for (short row = 0; row < NR0; row++) {
device const T0 * xb = ax[row] + (ib*NB + il*NF);
float sumq = 0.f;
FOR_UNROLL (short i = 0; i < NF; ++i) {
sumq += xb[i] * yl[i];
}
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
sumf[row] += sumq;
}
device const T1 * y = (device const T1 *) (src1 + offset1);
device const T14 * y4 = (device const T14 *) y;
yb += NSG*NF*NW;
}
float sumf = 0;
for (int i = tiisg; i < args.ne00/4; i += 32) {
sumf += dot((float4) x4[i], (float4) y4[i]);
}
float sum_all = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
}
for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
for (short row = 0; row < NR0; row++) {
sumf[row] += ax[row][i] * y[i];
}
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
template<typename T0, typename T04, typename T1, typename T14>
kernel void kernel_mul_mv(
template<typename T0, typename T1, short NR0, short NW>
kernel void kernel_mul_mv_t_t(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
args,
src0,
src1,
dst,
tgpig,
tiisg);
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_t_t_impl<T0, T1, NR0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t;
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half, N_R0_F, N_SIMDWIDTH>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float, float4>;
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F, N_SIMDWIDTH>;
#endif
template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW, typename args_t>
void kernel_mul_mv_t_t_4_impl(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
constexpr short NB = 32;
constexpr short NF = 16;
constexpr short NF4 = NF/4;
const int nb = args.ne00/NB;
const int r0 = tgpig.x*NR0;
const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
//const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T1 * y = (device const T1 *) (src1 + offset1);
device const T14 * y4 = (device const T14 *) (src1 + offset1);
// pointers to src0 rows
device const T0 * ax [NR0];
device const T04 * ax4[NR0];
FOR_UNROLL (short row = 0; row < NR0; ++row) {
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax [row] = (device const T0 *) ((device char *) src0 + offset0);
ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
}
float sumf[NR0] = { 0.f };
const short ix = tiisg/(NW/NF);
const short il = tiisg%(NW/NF);
const int ib0 = sgitg*NF + ix;
T14 yl4[NF4];
device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4;
for (int ib = ib0; ib < nb; ib += NSG*NF) {
for (short i = 0; i < NF4; ++i) {
yl4[i] = yb4[i];
}
for (short row = 0; row < NR0; row++) {
device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4;
float sumq = 0.f;
FOR_UNROLL (short i = 0; i < NF4; ++i) {
sumq += dot(float4(xb4[i]), float4(yl4[i]));
}
sumf[row] += sumq;
}
yb4 += NSG*NF*NW/4;
}
for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
for (short row = 0; row < NR0; row++) {
sumf[row] += ax[row][i] * y[i];
}
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW>
kernel void kernel_mul_mv_t_t_4(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t_4;
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SIMDWIDTH>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F, N_SIMDWIDTH>;
#endif
#define N_MV_T_T 4
template<typename T04, typename T14, typename args_t>
void kernel_mul_mv_c4_impl(
args_t args,
@@ -3562,112 +3679,10 @@ typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
template [[host_name("kernel_mul_mv_f16_f16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, half4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
#endif
template<typename T, typename T4>
kernel void kernel_mul_mv_1row(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T * x = (device const T *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
float sumf = 0;
if (args.ne00 < 128) {
for (int i = tiisg; i < args.ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
}
float sum_all = simd_sum(sumf);
if (tiisg == 0) {
dst_f32[r0] = sum_all;
}
} else {
device const T4 * x4 = (device const T4 *) x;
device const float4 * y4 = (device const float4 *) y;
for (int i = tiisg; i < args.ne00/4; i += 32) {
sumf += dot((float4) x4[i], y4[i]);
}
float sum_all = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
dst_f32[r0] = sum_all;
}
}
}
typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
#endif
// Assumes row size (ne00) is a multiple of 4
template<typename T, typename T4>
kernel void kernel_mul_mv_l4(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
const int nrows = args.ne11;
const int r0 = tgpig.x;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
device const T4 * x4 = (device const T4 *) (src0 + offset0);
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
for (int r1 = 0; r1 < nrows; ++r1) {
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const float4 * y4 = (device const float4 *) (src1 + offset1);
float sumf = 0;
for (int i = tiisg; i < args.ne00/4; i += 32) {
sumf += dot((float4) x4[i], y4[i]);
}
float sum_all = simd_sum(sumf);
if (tiisg == 0) {
dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
}
}
}
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, bfloat4>;
#endif
static float rope_yarn_ramp(const float low, const float high, const int i0) {
@@ -5951,7 +5966,7 @@ kernel void kernel_concat(
}
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_q2_K_f32_impl(
args_t args,
device const char * src0,
@@ -5961,13 +5976,15 @@ void kernel_mul_mv_q2_K_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -6051,10 +6068,10 @@ kernel void kernel_mul_mv_q2_K_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_q3_K_f32_impl(
args_t args,
device const char * src0,
@@ -6064,6 +6081,7 @@ void kernel_mul_mv_q3_K_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
@@ -6071,7 +6089,7 @@ void kernel_mul_mv_q3_K_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -6215,10 +6233,10 @@ kernel void kernel_mul_mv_q3_K_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_q4_K_f32_impl(
args_t args,
device const char * src0,
@@ -6228,6 +6246,8 @@ void kernel_mul_mv_q4_K_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
@@ -6243,7 +6263,7 @@ void kernel_mul_mv_q4_K_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -6337,10 +6357,10 @@ kernel void kernel_mul_mv_q4_K_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_q5_K_f32_impl(
args_t args,
device const char * src0,
@@ -6350,6 +6370,7 @@ void kernel_mul_mv_q5_K_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
@@ -6357,7 +6378,7 @@ void kernel_mul_mv_q5_K_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -6468,10 +6489,10 @@ kernel void kernel_mul_mv_q5_K_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_q6_K_f32_impl(
args_t args,
device const char * src0,
@@ -6481,6 +6502,7 @@ void kernel_mul_mv_q6_K_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const uint8_t kmask1 = 0x03;
const uint8_t kmask2 = 0x0C;
@@ -6493,7 +6515,7 @@ void kernel_mul_mv_q6_K_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -6577,12 +6599,12 @@ kernel void kernel_mul_mv_q6_K_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
// ======================= "True" 2-bit
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq2_xxs_f32_impl(
args_t args,
device const char * src0,
@@ -6592,13 +6614,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -6685,10 +6709,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq2_xs_f32_impl(
args_t args,
device const char * src0,
@@ -6698,13 +6722,15 @@ void kernel_mul_mv_iq2_xs_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -6802,10 +6828,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq3_xxs_f32_impl(
args_t args,
device const char * src0,
@@ -6815,13 +6841,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -6912,10 +6940,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq3_s_f32_impl(
args_t args,
device const char * src0,
@@ -6925,13 +6953,15 @@ void kernel_mul_mv_iq3_s_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -7022,10 +7052,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq2_s_f32_impl(
args_t args,
device const char * src0,
@@ -7035,13 +7065,15 @@ void kernel_mul_mv_iq2_s_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -7133,10 +7165,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq1_s_f32_impl(
args_t args,
device const char * src0,
@@ -7146,13 +7178,15 @@ void kernel_mul_mv_iq1_s_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -7230,10 +7264,10 @@ kernel void kernel_mul_mv_iq1_s_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq1_m_f32_impl(
args_t args,
device const char * src0,
@@ -7243,6 +7277,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK_K;
@@ -7250,7 +7285,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -7338,10 +7373,10 @@ kernel void kernel_mul_mv_iq1_m_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
args_t args,
device const char * src0,
@@ -7351,6 +7386,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK4_NL;
@@ -7359,7 +7395,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -7444,10 +7480,10 @@ kernel void kernel_mul_mv_iq4_nl_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
args_t args,
device const char * src0,
@@ -7457,13 +7493,15 @@ void kernel_mul_mv_iq4_xs_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -7549,10 +7587,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
template<int nr0, int nw, typename args_t>
void kernel_mul_mv_mxfp4_f32_impl(
args_t args,
device const char * src0,
@@ -7562,6 +7600,7 @@ void kernel_mul_mv_mxfp4_f32_impl(
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_MXFP4;
@@ -7570,7 +7609,7 @@ void kernel_mul_mv_mxfp4_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -7638,7 +7677,7 @@ kernel void kernel_mul_mv_mxfp4_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -8314,7 +8353,7 @@ void mmv_fn(
impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, N_SIMDWIDTH, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
template<mul_mv_impl_fn_t impl_fn>
kernel void kernel_mul_mv_id(
@@ -8379,36 +8418,44 @@ kernel void kernel_mul_mv_id(
sgitg);
}
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_t;
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_4_t;
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F, N_SIMDWIDTH>>>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F, N_SIMDWIDTH>>>;
#endif
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
#endif
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SIMDWIDTH>>>;
kernel void kernel_pool_2d_max_f32(
constant ggml_metal_kargs_pool_2d & args,

View File

@@ -4,7 +4,7 @@
#include <cstdint>
#define LLAMA_MAX_SEQ 64
#define LLAMA_MAX_SEQ 256
struct llama_cparams {
uint32_t n_ctx; // context size used during inference

View File

@@ -111,6 +111,7 @@ static bool server_task_type_need_logits(server_task_type task_type) {
struct slot_params {
bool stream = true;
bool include_usage = false;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool return_tokens = false;
bool return_progress = false;
@@ -310,17 +311,19 @@ struct server_task {
params.verbose = params_base.verbosity > 9;
params.timings_per_token = json_value(data, "timings_per_token", false);
params.stream = json_value(data, "stream", false);
params.cache_prompt = json_value(data, "cache_prompt", true);
params.return_tokens = json_value(data, "return_tokens", false);
params.return_progress = json_value(data, "return_progress", false);
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
params.stream = json_value(data, "stream", false);
auto stream_opt = json_value(data, "stream_options", json::object());
params.include_usage = json_value(stream_opt, "include_usage", false);
params.cache_prompt = json_value(data, "cache_prompt", true);
params.return_tokens = json_value(data, "return_tokens", false);
params.return_progress = json_value(data, "return_progress", false);
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
@@ -775,6 +778,7 @@ struct server_task_result_cmpl_final : server_task_result {
llama_tokens tokens;
bool stream;
bool include_usage;
result_timings timings;
std::string prompt;
@@ -982,21 +986,23 @@ struct server_task_result_cmpl_final : server_task_result {
{"object", "chat.completion.chunk"},
});
// OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
// https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
deltas.push_back({
{"choices", json::array()},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}},
});
if (include_usage) {
// OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
// https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
deltas.push_back({
{"choices", json::array()},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}},
});
}
if (timings.prompt_n >= 0) {
deltas.back().push_back({"timings", timings.to_json()});
@@ -2815,6 +2821,7 @@ struct server_context {
res->verbose = slot.params.verbose;
res->stream = slot.params.stream;
res->include_usage = slot.params.include_usage;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;

View File

@@ -271,8 +271,10 @@ def test_chat_completion_with_timings_per_token():
"max_tokens": 10,
"messages": [{"role": "user", "content": "test"}],
"stream": True,
"stream_options": {"include_usage": True},
"timings_per_token": True,
})
stats_received = False
for i, data in enumerate(res):
if i == 0:
# Check first role message for stream=True
@@ -288,6 +290,8 @@ def test_chat_completion_with_timings_per_token():
assert "predicted_per_second" in data["timings"]
assert "predicted_n" in data["timings"]
assert data["timings"]["predicted_n"] <= 10
stats_received = True
assert stats_received
def test_logprobs():