mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-08 01:54:10 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d8359f5fde | ||
|
|
6a2c6145a0 |
@@ -438,21 +438,35 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_libr
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const ggml_type tsrc0 = op->src[0]->type;
|
||||
const ggml_type tsrc1 = op->src[1]->type;
|
||||
|
||||
const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
|
||||
const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_pipeline_set_smem(res, 8192);
|
||||
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
|
||||
ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
// when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
|
||||
ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -659,19 +673,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const ggml_type tsrc0 = op->src[0]->type;
|
||||
const ggml_type tsrc1 = op->src[1]->type;
|
||||
|
||||
const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_bci=%d", base, bc_inp);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
ggml_metal_pipeline_set_smem(res, 8192);
|
||||
|
||||
|
||||
@@ -115,10 +115,10 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
@@ -717,8 +717,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return true;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return has_simdgroup_reduction &&
|
||||
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CONT:
|
||||
|
||||
@@ -76,6 +76,7 @@
|
||||
#define FC_FLASH_ATTN_EXT_VEC 200
|
||||
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
|
||||
#define FC_MUL_MV 400
|
||||
#define FC_MUL_MM 500
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
|
||||
@@ -1520,22 +1520,20 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
!ggml_is_transposed(op->src[1]) &&
|
||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||
props_dev->has_simdgroup_mm &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||
props_dev->has_simdgroup_mm && ne00 >= 64 &&
|
||||
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
|
||||
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
default: break;
|
||||
}
|
||||
//switch (op->src[0]->type) {
|
||||
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
// default: break;
|
||||
//}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op->src[0]->type, op->src[1]->type);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
|
||||
|
||||
ggml_metal_kargs_mul_mm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1655,8 +1653,6 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(!ggml_is_transposed(op->src[0]));
|
||||
GGML_ASSERT(!ggml_is_transposed(op->src[1]));
|
||||
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ne03 == 1);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
|
||||
@@ -1674,19 +1670,15 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
// ne21 = n_rows (batch size)
|
||||
const int ne21_mm_id_min = 32;
|
||||
|
||||
if (props_dev->has_simdgroup_mm &&
|
||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||
(ne21 >= ne21_mm_id_min)) {
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
|
||||
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
default: break;
|
||||
}
|
||||
//switch (op->src[0]->type) {
|
||||
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
// default: break;
|
||||
//}
|
||||
|
||||
// extra buffers for intermediate id mapping
|
||||
ggml_metal_buffer_id bid_tpe = bid_dst;
|
||||
@@ -1730,7 +1722,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
{
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op->src[0]->type, GGML_TYPE_F16);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
|
||||
@@ -33,6 +33,7 @@ using namespace metal;
|
||||
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
typedef matrix<bfloat, 4, 4> bfloat4x4;
|
||||
typedef matrix<bfloat, 2, 4> bfloat2x4;
|
||||
#endif
|
||||
|
||||
constexpr constant static float kvalues_iq4nl_f[16] = {
|
||||
@@ -7856,6 +7857,9 @@ kernel void kernel_set_rows_f(
|
||||
}
|
||||
}
|
||||
|
||||
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
|
||||
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
|
||||
|
||||
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
||||
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
||||
#define BLOCK_SIZE_K 32
|
||||
@@ -7868,7 +7872,7 @@ kernel void kernel_set_rows_f(
|
||||
#define SG_MAT_ROW 8
|
||||
|
||||
// each block_q contains 16*nl weights
|
||||
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm(
|
||||
constant ggml_metal_kargs_mul_mm & args,
|
||||
device const char * src0,
|
||||
@@ -7879,8 +7883,8 @@ kernel void kernel_mul_mm(
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup T * sa = (threadgroup T *)(shmem);
|
||||
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
const int r0 = tgpig.y;
|
||||
const int r1 = tgpig.x;
|
||||
@@ -7894,8 +7898,9 @@ kernel void kernel_mul_mm(
|
||||
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
|
||||
simdgroup_T8x8 ma[4];
|
||||
simdgroup_float8x8 mb[2];
|
||||
S0_8x8 ma[4];
|
||||
S1_8x8 mb[2];
|
||||
|
||||
simdgroup_float8x8 mc[8];
|
||||
|
||||
for (short i = 0; i < 8; i++){
|
||||
@@ -7913,27 +7918,45 @@ kernel void kernel_mul_mm(
|
||||
device const block_q * x = (device const block_q *)(src0
|
||||
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
||||
|
||||
device const float * y = (device const float *)(src1
|
||||
const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
|
||||
|
||||
device const T1 * y = (device const T1 *)(src1
|
||||
+ args.nb13*i13
|
||||
+ args.nb12*i12
|
||||
+ args.nb11*(r1*BLOCK_SIZE_N + thread_col)
|
||||
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
+ args.nb10*iy);
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
||||
// load data and store to threadgroup memory
|
||||
T4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// no need for dequantization
|
||||
for (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
|
||||
}
|
||||
} else {
|
||||
S0_4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
#pragma unroll(16)
|
||||
for (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
*(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
|
||||
if (FC_mul_mm_bc_inp) {
|
||||
for (short i = 0; i < 8; ++i) {
|
||||
sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0;
|
||||
}
|
||||
} else {
|
||||
*(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
@@ -7942,8 +7965,8 @@ kernel void kernel_mul_mm(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
||||
threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
||||
threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
||||
threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
||||
|
||||
#pragma unroll(4)
|
||||
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
||||
@@ -7971,7 +7994,8 @@ kernel void kernel_mul_mm(
|
||||
}
|
||||
}
|
||||
|
||||
if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
|
||||
if (!FC_mul_mm_bc_out || ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1)) {
|
||||
// if no bounds checks on the output are needed, we can directly write to device memory
|
||||
device float * C = (device float *) dst +
|
||||
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
|
||||
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
|
||||
@@ -8076,7 +8100,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
|
||||
|
||||
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm_id(
|
||||
constant ggml_metal_kargs_mul_mm_id & args,
|
||||
device const char * src0,
|
||||
@@ -8090,8 +8114,8 @@ kernel void kernel_mul_mm_id(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup T * sa = (threadgroup T *)(shmem);
|
||||
threadgroup half * sb = (threadgroup half *)(shmem + 4096);
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
const int r0 = tgpig.y;
|
||||
const int r1 = tgpig.x;
|
||||
@@ -8114,8 +8138,9 @@ kernel void kernel_mul_mm_id(
|
||||
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
|
||||
simdgroup_T8x8 ma[4];
|
||||
simdgroup_half8x8 mb[2];
|
||||
S0_8x8 ma[4];
|
||||
S1_8x8 mb[2];
|
||||
|
||||
simdgroup_float8x8 mc[8];
|
||||
|
||||
for (short i = 0; i < 8; i++){
|
||||
@@ -8136,27 +8161,45 @@ kernel void kernel_mul_mm_id(
|
||||
device const block_q * x = (device const block_q *)(src0
|
||||
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
||||
|
||||
device const float * y = (device const float *)(src1
|
||||
const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
|
||||
|
||||
device const T1 * y = (device const T1 *)(src1
|
||||
+ args.nb13*i13
|
||||
+ args.nb12*i12
|
||||
+ args.nb11*i11
|
||||
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
+ args.nb10*iy);
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
||||
// load data and store to threadgroup memory
|
||||
T4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// no need for dequantization
|
||||
for (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
|
||||
}
|
||||
} else {
|
||||
S0_4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
#pragma unroll(16)
|
||||
for (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
|
||||
if (FC_mul_mm_bc_inp) {
|
||||
for (short i = 0; i < 8; ++i) {
|
||||
sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0;
|
||||
}
|
||||
} else {
|
||||
*(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
@@ -8165,8 +8208,8 @@ kernel void kernel_mul_mm_id(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
||||
threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
||||
threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
||||
threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
||||
|
||||
#pragma unroll(4)
|
||||
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
||||
@@ -8299,66 +8342,117 @@ template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kerne
|
||||
// matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
|
||||
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
|
||||
|
||||
//
|
||||
// indirect matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;
|
||||
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
|
||||
|
||||
//
|
||||
// matrix-vector multiplication
|
||||
|
||||
@@ -408,6 +408,8 @@ struct vk_device_struct {
|
||||
bool subgroup_ballot;
|
||||
bool subgroup_clustered;
|
||||
bool multi_add;
|
||||
bool shader_int64;
|
||||
bool buffer_device_address;
|
||||
|
||||
bool add_rms_fusion;
|
||||
uint32_t partials_binding_alignment;
|
||||
@@ -655,6 +657,7 @@ struct vk_buffer_struct {
|
||||
vk::MemoryPropertyFlags memory_property_flags;
|
||||
void * ptr;
|
||||
size_t size = 0;
|
||||
vk::DeviceAddress bda_addr {};
|
||||
|
||||
vk_device device;
|
||||
|
||||
@@ -987,6 +990,7 @@ struct vk_op_argsort_push_constants {
|
||||
};
|
||||
|
||||
struct vk_op_im2col_push_constants {
|
||||
uint64_t dst_addr;
|
||||
uint32_t batch_offset; uint32_t offset_delta;
|
||||
uint32_t IC;
|
||||
uint32_t IW; uint32_t IH;
|
||||
@@ -1000,6 +1004,7 @@ struct vk_op_im2col_push_constants {
|
||||
};
|
||||
|
||||
struct vk_op_im2col_3d_push_constants {
|
||||
uint64_t dst_addr;
|
||||
uint32_t nb10;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
@@ -2012,10 +2017,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
||||
return buf;
|
||||
}
|
||||
|
||||
vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
|
||||
vk::MemoryAllocateFlags mem_flags {};
|
||||
if (device->buffer_device_address) {
|
||||
usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
|
||||
mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
|
||||
}
|
||||
|
||||
vk::BufferCreateInfo buffer_create_info{
|
||||
vk::BufferCreateFlags(),
|
||||
size,
|
||||
vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst,
|
||||
usage_flags,
|
||||
vk::SharingMode::eExclusive,
|
||||
0,
|
||||
nullptr,
|
||||
@@ -2027,6 +2039,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
||||
|
||||
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
|
||||
|
||||
const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
|
||||
|
||||
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
|
||||
const auto & req_flags = *it;
|
||||
|
||||
@@ -2038,7 +2052,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
||||
buf->memory_property_flags = req_flags;
|
||||
|
||||
try {
|
||||
buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
|
||||
buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
|
||||
break;
|
||||
} catch (const vk::SystemError& e) {
|
||||
// loop and retry
|
||||
@@ -2066,6 +2080,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
||||
buf->device = device;
|
||||
buf->size = size;
|
||||
|
||||
if (device->buffer_device_address) {
|
||||
const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
|
||||
buf->bda_addr = device->device.getBufferAddress(addressInfo);
|
||||
}
|
||||
|
||||
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
||||
device->memory_logger->log_allocation(buf, size);
|
||||
#endif
|
||||
@@ -3532,14 +3551,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
||||
#define IM2COL(bda) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
||||
if (device->float_controls_rte_fp16) { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
||||
}
|
||||
if (device->shader_int64 && device->buffer_device_address) {
|
||||
IM2COL(_bda)
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
||||
IM2COL()
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
||||
@@ -4017,6 +4042,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
device->vendor_id != VK_VENDOR_ID_INTEL &&
|
||||
getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
|
||||
|
||||
device->shader_int64 = device_features2.features.shaderInt64;
|
||||
device->buffer_device_address = vk12_features.bufferDeviceAddress;
|
||||
|
||||
if (device->subgroup_size_control) {
|
||||
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
|
||||
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
|
||||
@@ -8635,6 +8663,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
|
||||
if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
|
||||
// buffer device address path doesn't use dst buffer
|
||||
d_sz = 1;
|
||||
}
|
||||
// im2col uses only src1 and dst buffers
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (op == GGML_OP_COUNT_EQUAL) {
|
||||
@@ -9486,7 +9518,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
|
||||
const uint32_t pelements = OW * KW * KH;
|
||||
|
||||
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
|
||||
|
||||
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
|
||||
|
||||
ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
|
||||
dst_addr,
|
||||
batch_offset, offset_delta,
|
||||
IC, IW, IH, OW, OH, KW, KH,
|
||||
pelements,
|
||||
@@ -9522,8 +9560,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
const int64_t OH = ne2;
|
||||
const int64_t OW = ne1;
|
||||
|
||||
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
|
||||
|
||||
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
|
||||
|
||||
vk_op_im2col_3d_push_constants pc {};
|
||||
|
||||
pc.dst_addr = dst_addr;
|
||||
pc.nb10 = nb10 / ggml_type_size(src1->type);
|
||||
pc.nb11 = nb11 / ggml_type_size(src1->type);
|
||||
pc.nb12 = nb12 / ggml_type_size(src1->type);
|
||||
|
||||
@@ -5,8 +5,11 @@
|
||||
|
||||
#include "rte.comp"
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
BDA_STORAGE_T dst_addr;
|
||||
uint batch_offset; uint offset_delta;
|
||||
uint IC;
|
||||
uint IW; uint IH;
|
||||
@@ -19,8 +22,6 @@ layout (push_constant) uniform parameter
|
||||
int d0; int d1;
|
||||
} p;
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
const uint NUM_ITER = 512 / BLOCK_SIZE;
|
||||
@@ -30,6 +31,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
#if BDA
|
||||
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
const uint gidx = gl_GlobalInvocationID.x;
|
||||
|
||||
@@ -38,7 +43,7 @@ void main() {
|
||||
const uint ic = gl_GlobalInvocationID.z % p.IC;
|
||||
|
||||
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
|
||||
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
|
||||
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
|
||||
const int oh_s1 = int(oh) * p.s1;
|
||||
const uint ksize = p.OW * p.KH;
|
||||
|
||||
@@ -50,7 +55,7 @@ void main() {
|
||||
uint current_ix = rem % p.OW;
|
||||
|
||||
A_TYPE values[NUM_ITER];
|
||||
uint offset_dst[NUM_ITER];
|
||||
BDA_OFFSET_T offset_dst[NUM_ITER];
|
||||
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
|
||||
values[idx] = A_TYPE(0);
|
||||
}
|
||||
@@ -66,7 +71,7 @@ void main() {
|
||||
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
|
||||
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
|
||||
|
||||
offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
|
||||
offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
|
||||
|
||||
if ((iih < p.IH) && (iiw < p.IW)) {
|
||||
values[idx] = data_a[src_base + iih * p.IW + iiw];
|
||||
@@ -89,7 +94,11 @@ void main() {
|
||||
continue;
|
||||
}
|
||||
|
||||
#if BDA
|
||||
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
|
||||
dst_addr.d = D_TYPE(values[idx]);
|
||||
#else
|
||||
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
|
||||
#endif
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -6,8 +6,11 @@
|
||||
|
||||
#include "rte.comp"
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
BDA_STORAGE_T dst_addr;
|
||||
uint32_t nb10;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
@@ -38,8 +41,6 @@ layout (push_constant) uniform parameter
|
||||
uint32_t misalign_offsets;
|
||||
} p;
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
||||
|
||||
@@ -50,6 +51,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
#if BDA
|
||||
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
const uint32_t i = gl_GlobalInvocationID.x;
|
||||
|
||||
@@ -100,13 +105,22 @@ void main() {
|
||||
const uint32_t iih = ioh * s1 + ikh * d1 - p1;
|
||||
const uint32_t iid = iod * s2 + ikd * d2 - p2;
|
||||
|
||||
const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
|
||||
const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
|
||||
|
||||
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
|
||||
#if BDA
|
||||
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst);
|
||||
if (iih >= IH || iiw >= IW || iid >= ID) {
|
||||
dst_addr.d = D_TYPE(0.0f);
|
||||
} else {
|
||||
dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]);
|
||||
}
|
||||
#else
|
||||
if (iih >= IH || iiw >= IW || iid >= ID) {
|
||||
data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
|
||||
} else {
|
||||
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
|
||||
data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1447,4 +1447,19 @@ float e8m0_to_fp32(uint8_t x) {
|
||||
return uintBitsToFloat(bits);
|
||||
}
|
||||
|
||||
#if BDA
|
||||
|
||||
#extension GL_EXT_buffer_reference : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : enable
|
||||
|
||||
#define BDA_STORAGE_T uint64_t
|
||||
#define BDA_OFFSET_T uint64_t
|
||||
|
||||
#else
|
||||
|
||||
#define BDA_STORAGE_T uvec2
|
||||
#define BDA_OFFSET_T uint
|
||||
|
||||
#endif
|
||||
|
||||
#endif // !defined(GGML_TYPES_COMP)
|
||||
|
||||
@@ -775,13 +775,15 @@ void process_shaders() {
|
||||
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
|
||||
|
||||
string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
|
||||
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
|
||||
|
||||
string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
|
||||
string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
|
||||
for (std::string dim_str : {"", "_3d"}) {
|
||||
for (bool bda : {false, true}) {
|
||||
std::string bda_str = bda ? "_bda" : "";
|
||||
std::string bda_def = bda ? "1" : "0";
|
||||
string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
|
||||
string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
|
||||
string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
|
||||
}
|
||||
}
|
||||
|
||||
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
|
||||
@@ -5753,6 +5753,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
// >4GB im2col destination. Too slow to run by default.
|
||||
// Test cases taken from Wan2.1 T2V 1.3B.
|
||||
test_cases.emplace_back(new test_im2col (GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {832, 480, 192, 4}, {3, 3, 192, 96}, 1, 1, 1, 1, 1, 1, true));
|
||||
test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {834, 482, 6, 96}, {3, 3,3, 9216}, 96, 1, 1, 1, 0, 0, 0, 1, 1, 1, false));
|
||||
#endif
|
||||
|
||||
// im2col 1D
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
|
||||
@@ -6301,6 +6308,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, true, 3));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));
|
||||
|
||||
#if 0
|
||||
// test the mat-mat path for Metal
|
||||
for (int k = 1; k < 512; ++k) {
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 127, k, {12,1}, {1,1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 127, k, {12,1}, {1,1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 77, k, {12,1}, {1,1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, k, {12,1}, {1,1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 128, k, {12,1}, {1,1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 128, k, {12,1}, {1,1}));
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 50, 200, k));
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, true, 50, 200, k));
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F32, GGML_TYPE_F32, 16, 16, false, 50, 200, k));
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F32, GGML_TYPE_F32, 16, 16, true, 50, 200, k));
|
||||
}
|
||||
#endif
|
||||
|
||||
for (auto bs2 : {1,3}) {
|
||||
for (auto bs : {1,2,4,8}) {
|
||||
for (auto nr : {1,4}) {
|
||||
|
||||
Reference in New Issue
Block a user