mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-13 20:44:09 +00:00
ggml-webgpu: Enables running gpt-oss-20b (#22906)
* Enable to run gpt-oss-20b and refactor mulmat-q * disable test-backend-ops in ubuntu-24-webgpu
This commit is contained in:
committed by
GitHub
parent
239a497e5f
commit
927dada6c9
3
.github/workflows/build.yml
vendored
3
.github/workflows/build.yml
vendored
@@ -456,7 +456,8 @@ jobs:
|
||||
run: |
|
||||
cd build
|
||||
# This is using llvmpipe and runs slower than other backends
|
||||
ctest -L main --verbose --timeout 900
|
||||
# test-backend-ops is too slow on llvmpipe, skip it
|
||||
ctest -L main -E test-backend-ops --verbose --timeout 900
|
||||
|
||||
ubuntu-24-webgpu-wasm:
|
||||
runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
|
||||
@@ -18,7 +18,7 @@ Legend:
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
@@ -71,7 +71,7 @@ Legend:
|
||||
| MUL_MAT_HADAMARD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
|
||||
@@ -118,5 +118,5 @@ Legend:
|
||||
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
|
||||
11465
docs/ops/WebGPU.csv
11465
docs/ops/WebGPU.csv
File diff suppressed because it is too large
Load Diff
@@ -495,6 +495,22 @@ struct ggml_webgpu_binary_pipeline_key_hash {
|
||||
}
|
||||
};
|
||||
|
||||
/* Add_Id */
|
||||
|
||||
struct ggml_webgpu_add_id_pipeline_key {
|
||||
bool inplace;
|
||||
|
||||
bool operator==(const ggml_webgpu_add_id_pipeline_key & other) const { return inplace == other.inplace; }
|
||||
};
|
||||
|
||||
struct ggml_webgpu_add_id_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_add_id_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** Unary **/
|
||||
|
||||
struct ggml_webgpu_unary_pipeline_key {
|
||||
@@ -1058,7 +1074,9 @@ class ggml_webgpu_shader_lib {
|
||||
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
|
||||
pad_pipelines; // circular/non-circular
|
||||
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
|
||||
binary_pipelines; // type/op/inplace/overlap
|
||||
binary_pipelines; // type/op/inplace/overlap/src_overlap
|
||||
std::unordered_map<ggml_webgpu_add_id_pipeline_key, webgpu_pipeline, ggml_webgpu_add_id_pipeline_key_hash>
|
||||
add_id_pipelines; // inplace
|
||||
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
|
||||
concat_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
|
||||
@@ -1433,6 +1451,7 @@ class ggml_webgpu_shader_lib {
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
{
|
||||
// Quantized types using u32 buffers for portability.
|
||||
defines.push_back("SRC_TYPE=u32");
|
||||
@@ -1451,6 +1470,7 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back(type_upper + "_SCALE_MIN");
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
defines.push_back(type_upper + "_LUT");
|
||||
|
||||
variant += "_";
|
||||
variant += type_str;
|
||||
@@ -1460,7 +1480,7 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.src_type == GGML_TYPE_Q1_0) {
|
||||
defines.push_back("BLOCK_SIZE=128u");
|
||||
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
||||
key.src_type == GGML_TYPE_IQ4_NL) {
|
||||
key.src_type == GGML_TYPE_IQ4_NL || key.src_type == GGML_TYPE_MXFP4) {
|
||||
defines.push_back("BLOCK_SIZE=32u");
|
||||
} else if (key.src_type >= GGML_TYPE_Q2_K) {
|
||||
defines.push_back("BLOCK_SIZE=256u");
|
||||
@@ -1774,6 +1794,9 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
defines.push_back(type_upper + "_LUT");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -1908,6 +1931,9 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
defines.push_back(type_upper + "_LUT");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -2042,6 +2068,7 @@ class ggml_webgpu_shader_lib {
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
{
|
||||
// Quantized types using u32 buffers for portability.
|
||||
defines.push_back("SRC0_TYPE=u32");
|
||||
@@ -2169,6 +2196,9 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
defines.push_back(type_upper + "_LUT");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -2286,6 +2316,9 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
defines.push_back(type_upper + "_LUT");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -2503,6 +2536,37 @@ class ggml_webgpu_shader_lib {
|
||||
return binary_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_add_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_add_id_pipeline_key key = {};
|
||||
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
||||
|
||||
auto it = add_id_pipelines.find(key);
|
||||
if (it != add_id_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "add_id";
|
||||
const char * shader_src = wgsl_add_id;
|
||||
|
||||
if (key.inplace) {
|
||||
defines.push_back("INPLACE");
|
||||
variant += "_inplace";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(shader_src, defines);
|
||||
auto pipeline_decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
pipeline_decisions->wg_size = context.max_wg_size;
|
||||
pipeline_decisions->inplace = key.inplace;
|
||||
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = pipeline_decisions;
|
||||
add_id_pipelines[key] = pipeline;
|
||||
return pipeline;
|
||||
}
|
||||
|
||||
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_concat_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
|
||||
@@ -1411,8 +1411,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q1_0:
|
||||
use_fast = true;
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
@@ -1422,6 +1420,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_MXFP4:
|
||||
use_fast = true;
|
||||
break;
|
||||
default:
|
||||
@@ -2145,6 +2144,56 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * src2,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
shader_lib_ctx.src2 = src2;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_add_id_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src2->nb[0] / ggml_type_size(src2->type)),
|
||||
(uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
|
||||
(uint32_t) dst->ne[0],
|
||||
(uint32_t) dst->ne[1],
|
||||
(uint32_t) dst->ne[2],
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2));
|
||||
|
||||
if (!decisions->inplace) {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, dst));
|
||||
}
|
||||
|
||||
uint32_t wg_x = 1;
|
||||
uint32_t wg_y = 1;
|
||||
uint32_t total_wg = ggml_nrows(dst);
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
@@ -2918,6 +2967,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
return ggml_webgpu_binary_op(ctx, src0, src1, node);
|
||||
case GGML_OP_ADD_ID:
|
||||
return ggml_webgpu_add_id(ctx, src0, src1, src2, node);
|
||||
case GGML_OP_CONCAT:
|
||||
return ggml_webgpu_concat(ctx, src0, src1, node);
|
||||
case GGML_OP_REPEAT:
|
||||
@@ -3867,6 +3918,7 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_MXFP4:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -3905,6 +3957,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
|
||||
(src1->type == op->type);
|
||||
break;
|
||||
case GGML_OP_ADD_ID:
|
||||
supports_op = src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
|
||||
break;
|
||||
@@ -3962,6 +4017,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_MXFP4:
|
||||
supports_op = true;
|
||||
break;
|
||||
default:
|
||||
@@ -4001,6 +4057,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_MXFP4:
|
||||
supports_op = true;
|
||||
break;
|
||||
default:
|
||||
|
||||
64
ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl
Normal file
64
ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl
Normal file
@@ -0,0 +1,64 @@
|
||||
struct Params {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_ids: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
nb01: u32,
|
||||
nb02: u32,
|
||||
nb11: u32,
|
||||
nb20: u32,
|
||||
nb21: u32,
|
||||
|
||||
ne0: u32,
|
||||
ne1: u32,
|
||||
ne2: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<f32>; // [n_embd, n_experts_used, n_token]
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<f32>; // [n_embd, n_experts]
|
||||
@group(0) @binding(2) var<storage, read_write> ids: array<i32>; // [n_experts_used, n_token]
|
||||
|
||||
#ifdef INPLACE
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#else
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<storage, read_write> dst: array<f32>;
|
||||
|
||||
@group(0) @binding(4)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#endif
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>) {
|
||||
|
||||
let wg_linear = wg_id.x + wg_id.y * num_wg.x;
|
||||
|
||||
if (wg_linear < params.ne1 * params.ne2) {
|
||||
let thread_id = local_id.x;
|
||||
let i2 = wg_linear / params.ne1;
|
||||
let i1 = wg_linear % params.ne1;
|
||||
|
||||
let i11 = u32(ids[params.offset_ids + i1 * params.nb20 + i2 * params.nb21]);
|
||||
|
||||
let src0_row = params.offset_src0 + i1 * params.nb01 + i2 * params.nb02;
|
||||
let src1_row = params.offset_src1 + i11 * params.nb11;
|
||||
let dst_row = params.offset_dst + i1 * params.ne0 + i2 * (params.ne0 * params.ne1);
|
||||
|
||||
for (var i = thread_id;i < params.ne0; i += WG_SIZE) {
|
||||
#ifdef INPLACE
|
||||
src0[src0_row + i] = src0[src0_row + i] + src1[src1_row + i];
|
||||
#else
|
||||
dst[dst_row + i] = src0[src0_row + i] + src1[src1_row + i];
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -896,3 +896,10 @@ const kvalues_iq4nl = array<i32, 16>(
|
||||
);
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef MXFP4_LUT
|
||||
const kvalues_mxfp4 = array<i32, 16>(
|
||||
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12
|
||||
);
|
||||
#endif
|
||||
|
||||
|
||||
@@ -652,6 +652,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MXFP4
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_byte_base = (src_base + offset) * 17;
|
||||
let eu8 = get_byte(load_u32_at_src(block_byte_base), 0);
|
||||
let d = ldexp(1.0, i32(eu8) - 128);
|
||||
for (var j: u32 = 0u; j < 4; j++) {
|
||||
let q_byte_offset = block_byte_base + 1 + j * 4;
|
||||
let q_packed = load_u32_at_src(q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * d;
|
||||
let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * d;
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
dst[dst_offset] = q_lo;
|
||||
dst[dst_offset + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<SRC_TYPE>;
|
||||
|
||||
|
||||
@@ -100,34 +100,37 @@ const BLOCK_SIZE_BYTES = 18u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
// store NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
|
||||
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
shmem[shmem_idx + j * 2 + k] = q_lo;
|
||||
shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -141,35 +144,38 @@ const BLOCK_SIZE_BYTES = 20u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let m = load_f16_at_src0(block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
// store NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
|
||||
let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_lo = f16(q_byte & 0xF) * d + m;
|
||||
let q_hi = f16((q_byte >> 4) & 0xF) * d + m;
|
||||
shmem[shmem_idx + j * 2 + k] = q_lo;
|
||||
shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -178,52 +184,49 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
#endif // INIT_SRC0_SHMEM_Q4_1
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_0
|
||||
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 22u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
// tile_k is defined as 32u, so blocks_k ends up being 1 always
|
||||
override BLOCKS_K = TILE_K / BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
|
||||
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let qh_packed = load_u32_at_src0(block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
|
||||
// store NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let byte_idx = block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP + k;
|
||||
let qh_hi = (qh_packed >> (byte_idx + 12u)) & 0x10;
|
||||
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let qh_lo = ((qh_packed >> byte_idx) << 4) & 0x10;
|
||||
let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
|
||||
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
|
||||
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -232,54 +235,49 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
#endif // INIT_SRC0_SHMEM_Q5_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_1
|
||||
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 24u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
// tile_k is defined as 32u, so blocks_k ends up being 1 always
|
||||
override BLOCKS_K = TILE_K / BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
|
||||
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let m = load_f16_at_src0(block_byte_base + 2u);
|
||||
let qh_packed = load_u32_at_src0(block_byte_base + 4u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
|
||||
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
|
||||
// store NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;
|
||||
|
||||
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
|
||||
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
|
||||
let byte_idx = block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP + k;
|
||||
let qh_hi = (qh_packed >> (byte_idx + 12u)) & 0x10;
|
||||
let q_hi = f16(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
|
||||
let qh_lo = ((qh_packed >> byte_idx) << 4) & 0x10;
|
||||
let q_lo = f16((q_byte & 0xF) | qh_lo) * d + m;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -293,33 +291,34 @@ const BLOCK_SIZE_BYTES = 34u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
|
||||
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
// store NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
let q_val = f16(q_byte) * d;
|
||||
shmem[shmem_idx + j * 2 + k] = q_val;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -333,34 +332,35 @@ const BLOCK_SIZE_BYTES = 36u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
|
||||
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let m = load_f16_at_src0(block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
// store NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
let q_val = f16(q_byte) * d + m;
|
||||
shmem[shmem_idx + j * 2 + k] = q_val;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1163,3 +1163,48 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ3_S
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_MXFP4
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 17u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const BYTES_PER_THREAD = 8u; // NQ(16) weights uses 8 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0);
|
||||
let e = ldexp(1.0, i32(eu8) - 128);
|
||||
|
||||
// store NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
|
||||
let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e;
|
||||
let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo);
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_MXFP4
|
||||
|
||||
@@ -1389,3 +1389,45 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_MXFP4
|
||||
#define BLOCK_SIZE 32
|
||||
#define BLOCK_SIZE_BYTES 17
|
||||
#define THREADS_PER_BLOCK 4
|
||||
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
|
||||
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
let thread_within_block = thread_id % 4;
|
||||
for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
|
||||
var x_block: array<f32, ELEMS_PER_THREAD>;
|
||||
for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
x_block[i + 4] = f32(src1[x_base + i + 16]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0);
|
||||
let e = ldexp(1.0, i32(eu8) - 128);
|
||||
var row_sum = 0.0;
|
||||
let q_packed = load_u32_at_src0(block_byte_base + 1u + 4u * thread_within_block);
|
||||
for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
|
||||
let q_byte = get_byte(q_packed, byte_idx);
|
||||
let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e;
|
||||
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e;
|
||||
row_sum += q_lo * x_block[byte_idx];
|
||||
row_sum += q_hi * x_block[byte_idx + 4u];
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user