mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-03-17 16:44:07 +00:00
use scalar sums
This commit is contained in:
@@ -31,6 +31,7 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
|
||||
#endif
|
||||
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
layout (binding = 2) writeonly buffer D4 {D_TYPE_VEC4 data_dv4[];};
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||
@@ -94,7 +95,7 @@ shared float16_t buf_b_d[BN];
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
shared ivec4 coopmat_stage[TM * TN * NUM_WARPS / 4];
|
||||
|
||||
#include "mul_mm_id_funcs.glsl"
|
||||
#include "mul_mmq_cm1_funcs.glsl"
|
||||
@@ -204,17 +205,17 @@ void main() {
|
||||
|
||||
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
|
||||
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
||||
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> int_result;
|
||||
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result[cms_per_row * cms_per_col];
|
||||
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> scales_a;
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> scales_b[cms_per_col];
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> scales[cms_per_row * cms_per_col];
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
||||
const uint accs_per_thread = (WM * WN) / WARP / 4;
|
||||
ACC_TYPE_VEC4 sums[accs_per_thread];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
||||
[[unroll]] for (uint i = 0; i < accs_per_thread; i++) {
|
||||
sums[i] = ACC_TYPE_VEC4(0.0f);
|
||||
}
|
||||
|
||||
const uint chunks_per_thread_per_tile = (TM * TN) / (WARP * 4);
|
||||
|
||||
for (uint block = start_k; block < end_k; block += BK) {
|
||||
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
|
||||
const uint buf_ib = loadc_a + l;
|
||||
@@ -242,16 +243,8 @@ void main() {
|
||||
pos_a_ib += 1;
|
||||
pos_b_ib += 1;
|
||||
|
||||
// Precompute scales
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopMatLoad(scales_b[cm_col], buf_b_d, warp_c*WN + cm_col*TN, 0, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
coopMatLoad(scales_a, buf_a_d, warp_r*WM + cm_row*TM, 0, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
scales[cm_col * cms_per_row + cm_row] = coopMatMulAdd(scales_a, scales_b[cm_col], coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0));
|
||||
}
|
||||
[[unroll]] for (uint idx = 0; idx < cms_per_row * cms_per_col; idx++) {
|
||||
cm_result[idx] = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
|
||||
}
|
||||
|
||||
// Calculate quants
|
||||
@@ -262,8 +255,37 @@ void main() {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopMatLoad(cache_b, buf_b_qs, (warp_c * WN + cm_col * TN) * shmem_stride + i / 4, shmem_stride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
int_result = coopMatMulAdd(cache_a, cache_b, coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0));
|
||||
sums[cm_col * cms_per_row + cm_row] += scales[cm_col * cms_per_row + cm_row] * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(int_result);
|
||||
cm_result[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, cm_result[cm_col * cms_per_row + cm_row]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store to shmem
|
||||
const uint subgroup_vec_stride = (TM * TN) / 4;
|
||||
const uint subgroup_offset = warp_i * subgroup_vec_stride;
|
||||
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const uint tile_idx = cm_col * cms_per_row + cm_row;
|
||||
coopMatStore(cm_result[tile_idx], coopmat_stage, subgroup_offset, TM / 4, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||
|
||||
// Each thread grabs chunks and applies the scales
|
||||
[[unroll]] for (uint chunk = 0; chunk < chunks_per_thread_per_tile; chunk++) {
|
||||
const uint local_chunk = chunk * WARP + tiw;
|
||||
const uint col_local = local_chunk / (TM / 4);
|
||||
const uint row_group = local_chunk % (TM / 4);
|
||||
const uint row0_local = row_group * 4;
|
||||
const ivec4 qs = coopmat_stage[subgroup_offset + col_local * (TM / 4) + row_group];
|
||||
|
||||
const uint a_row0 = warp_r * WM + cm_row * TM + row0_local;
|
||||
const uint b_col = warp_c * WN + cm_col * TN + col_local;
|
||||
|
||||
const ACC_TYPE_VEC4 da = ACC_TYPE_VEC4(buf_a_d[a_row0], buf_a_d[a_row0+1], buf_a_d[a_row0+2], buf_a_d[a_row0+3]);
|
||||
const ACC_TYPE db = ACC_TYPE(buf_b_d[b_col]);
|
||||
|
||||
sums[tile_idx * chunks_per_thread_per_tile + chunk] += ACC_TYPE_VEC4(qs) * da * db;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -274,50 +296,92 @@ void main() {
|
||||
const uint dr = ir * BM + warp_r * WM;
|
||||
const uint dc = ic * BN + warp_c * WN;
|
||||
|
||||
const bool is_aligned = p.stride_d % 4 == 0;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
const uint tile_idx = cm_col * cms_per_row + cm_row;
|
||||
[[unroll]] for (uint chunk = 0; chunk < chunks_per_thread_per_tile; chunk++) {
|
||||
const uint local_chunk = chunk * WARP + tiw;
|
||||
const uint col_local = local_chunk / (TM / 4);
|
||||
const uint row_group = local_chunk % (TM / 4);
|
||||
const uint row0_local = row_group * 4;
|
||||
|
||||
const uint row_i = dc + cm_col * TN + col_local;
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
const uint row_i = dc + cm_col * TN + col + store_c;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const uint row0_g = dr + cm_row * TM + row0_local;
|
||||
const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
||||
const uint store_offset = row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + row0_g;
|
||||
const uint acc_idx = tile_idx * chunks_per_thread_per_tile + chunk;
|
||||
|
||||
if (dr + cm_row * TM + store_r < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
if (row0_g + 3 < p.M && is_aligned && (store_offset % 4) == 0) {
|
||||
data_dv4[store_offset / 4] = D_TYPE_VEC4(sums[acc_idx]);
|
||||
} else if (row0_g + 3 < p.M) {
|
||||
const ACC_TYPE_VEC4 vals = sums[acc_idx];
|
||||
data_d[store_offset ] = D_TYPE(vals.x);
|
||||
data_d[store_offset + 1] = D_TYPE(vals.y);
|
||||
data_d[store_offset + 2] = D_TYPE(vals.z);
|
||||
data_d[store_offset + 3] = D_TYPE(vals.w);
|
||||
} else if (row0_g + 2 < p.M) {
|
||||
const ACC_TYPE_VEC4 vals = sums[acc_idx];
|
||||
data_d[store_offset ] = D_TYPE(vals.x);
|
||||
data_d[store_offset + 1] = D_TYPE(vals.y);
|
||||
data_d[store_offset + 2] = D_TYPE(vals.z);
|
||||
} else if (row0_g + 1 < p.M) {
|
||||
const ACC_TYPE_VEC4 vals = sums[acc_idx];
|
||||
data_d[store_offset ] = D_TYPE(vals.x);
|
||||
data_d[store_offset + 1] = D_TYPE(vals.y);
|
||||
} else if (row0_g < p.M) {
|
||||
const ACC_TYPE_VEC4 vals = sums[acc_idx];
|
||||
data_d[store_offset] = D_TYPE(vals.x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
|
||||
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
|
||||
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
|
||||
const uint tile_idx = cm_col * cms_per_row + cm_row;
|
||||
[[unroll]] for (uint chunk = 0; chunk < chunks_per_thread_per_tile; chunk++) {
|
||||
const uint local_chunk = chunk * WARP + tiw;
|
||||
const uint col_local = local_chunk / (TM / 4);
|
||||
const uint row_group = local_chunk % (TM / 4);
|
||||
const uint row0_local = row_group * 4;
|
||||
|
||||
if (is_aligned && is_in_bounds) {
|
||||
// Full coopMat is within bounds and stride_d is aligned with 16B
|
||||
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
|
||||
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
} else if (is_in_bounds) {
|
||||
// Full coopMat is within bounds, but stride_d is not aligned
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
const uint col_g = dc + cm_col * TN + col_local;
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
|
||||
// Partial coopMat is within bounds
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
if (col_g >= p.N) break;
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
const uint row0_g = dr + cm_row * TM + row0_local;
|
||||
|
||||
const uint store_offset = offsets + col_g * p.stride_d + row0_g;
|
||||
const uint acc_idx = tile_idx * chunks_per_thread_per_tile + chunk;
|
||||
|
||||
if (row0_g + 3 < p.M && is_aligned && (store_offset % 4) == 0) {
|
||||
data_dv4[store_offset / 4] = D_TYPE_VEC4(sums[acc_idx]);
|
||||
} else if (row0_g + 3 < p.M) {
|
||||
const ACC_TYPE_VEC4 vals = sums[acc_idx];
|
||||
data_d[store_offset ] = D_TYPE(vals.x);
|
||||
data_d[store_offset + 1] = D_TYPE(vals.y);
|
||||
data_d[store_offset + 2] = D_TYPE(vals.z);
|
||||
data_d[store_offset + 3] = D_TYPE(vals.w);
|
||||
} else if (row0_g + 2 < p.M) {
|
||||
const ACC_TYPE_VEC4 vals = sums[acc_idx];
|
||||
data_d[store_offset ] = D_TYPE(vals.x);
|
||||
data_d[store_offset + 1] = D_TYPE(vals.y);
|
||||
data_d[store_offset + 2] = D_TYPE(vals.z);
|
||||
} else if (row0_g + 1 < p.M) {
|
||||
const ACC_TYPE_VEC4 vals = sums[acc_idx];
|
||||
data_d[store_offset ] = D_TYPE(vals.x);
|
||||
data_d[store_offset + 1] = D_TYPE(vals.y);
|
||||
} else if (row0_g < p.M) {
|
||||
const ACC_TYPE_VEC4 vals = sums[acc_idx];
|
||||
data_d[store_offset] = D_TYPE(vals.x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
|
||||
if (iqs == 0) {
|
||||
// Divide by TK for matmul scale application
|
||||
buf_b_d[buf_ib] = data_b[ib_outer].ds[ib_inner].x / float16_t(TK);
|
||||
buf_b_d[buf_ib] = data_b[ib_outer].ds[ib_inner].x;
|
||||
}
|
||||
|
||||
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
|
||||
|
||||
@@ -447,6 +447,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
|
||||
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
|
||||
base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
|
||||
base_dict["ACC_TYPE_VEC4"] = f16acc ? "f16vec4" : "vec4";
|
||||
if (f16acc) {
|
||||
base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
}
|
||||
@@ -593,7 +594,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
#endif
|
||||
|
||||
if (coopmat && tname == "q4_0") {
|
||||
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq_cm1.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq_cm1.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"}, {"D_TYPE_VEC4", "vec4"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user