Formatting

This commit is contained in:
Reese Levine
2026-04-16 08:39:10 -07:00
parent fe744e031d
commit ec783c1513
2 changed files with 27 additions and 34 deletions

View File

@@ -1083,8 +1083,7 @@ class ggml_webgpu_shader_lib {
std::string type_upper = type_str;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
switch (key.src_type)
{
switch (key.src_type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
@@ -1104,9 +1103,9 @@ class ggml_webgpu_shader_lib {
break;
}
default:
{
defines.push_back(std::string("SRC_TYPE=") + type_str);
}
{
defines.push_back(std::string("SRC_TYPE=") + type_str);
}
}
defines.push_back("BYTE_HELPERS");
@@ -1586,8 +1585,7 @@ class ggml_webgpu_shader_lib {
std::string type_upper = src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
switch (context.src0->type)
{
switch (context.src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
@@ -1607,9 +1605,9 @@ class ggml_webgpu_shader_lib {
break;
}
default:
{
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
}
{
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
}
}
defines.push_back("BYTE_HELPERS");

View File

@@ -596,19 +596,17 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context &
#ifdef GGML_WEBGPU_GPU_PROFILE
for (size_t i = 0; i < dispatches.size(); i++) {
GGML_ASSERT(ctx->profile_timestamp_query_count + 2 <= WEBGPU_MAX_PROFILE_QUERY_COUNT);
const uint32_t query_begin = ctx->profile_timestamp_query_count++;
const uint32_t query_end = ctx->profile_timestamp_query_count++;
const uint32_t query_begin = ctx->profile_timestamp_query_count++;
const uint32_t query_end = ctx->profile_timestamp_query_count++;
wgpu::PassTimestampWrites ts_writes = {};
ts_writes.querySet = ctx->profile_timestamp_query_set;
ts_writes.beginningOfPassWriteIndex = query_begin;
ts_writes.endOfPassWriteIndex = query_end;
wgpu::ComputePassDescriptor pass_desc = {};
pass_desc.timestampWrites = &ts_writes;
wgpu::PassTimestampWrites ts_writes = {};
ts_writes.querySet = ctx->profile_timestamp_query_set;
ts_writes.beginningOfPassWriteIndex = query_begin;
ts_writes.endOfPassWriteIndex = query_end;
wgpu::ComputePassDescriptor pass_desc = {};
pass_desc.timestampWrites = &ts_writes;
wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc);
wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc);
pass.SetPipeline(dispatches[i].pipeline.pipeline);
pass.SetBindGroup(0, bind_groups[i]);
@@ -1644,20 +1642,17 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
if (has_mask) {
dispatches.push_back({
blk_pipeline,
std::move(blk_params),
std::move(blk_entries),
{ blk_nblk0, blk_nblk1 * blk_batch_count }
});
}
dispatches.push_back({
pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u }
blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count }
});
if (use_vec_reduce) {
dispatches.push_back({
reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u }
});
}
}
dispatches.push_back({
pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u }
});
if (use_vec_reduce) {
dispatches.push_back({
reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u }
});
}
return ggml_backend_webgpu_build_multi(ctx, dispatches);
}