diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c823742eb2..7d9a4403fa 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1083,8 +1083,7 @@ class ggml_webgpu_shader_lib { std::string type_upper = type_str; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - switch (key.src_type) - { + switch (key.src_type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: @@ -1104,9 +1103,9 @@ class ggml_webgpu_shader_lib { break; } default: - { - defines.push_back(std::string("SRC_TYPE=") + type_str); - } + { + defines.push_back(std::string("SRC_TYPE=") + type_str); + } } defines.push_back("BYTE_HELPERS"); @@ -1586,8 +1585,7 @@ class ggml_webgpu_shader_lib { std::string type_upper = src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - switch (context.src0->type) - { + switch (context.src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: @@ -1607,9 +1605,9 @@ class ggml_webgpu_shader_lib { break; } default: - { - defines.push_back(std::string("SRC0_TYPE=") + src0_name); - } + { + defines.push_back(std::string("SRC0_TYPE=") + src0_name); + } } defines.push_back("BYTE_HELPERS"); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index d631846edc..e7bda817a2 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -596,19 +596,17 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & #ifdef GGML_WEBGPU_GPU_PROFILE for (size_t i = 0; i < dispatches.size(); i++) { GGML_ASSERT(ctx->profile_timestamp_query_count + 2 <= WEBGPU_MAX_PROFILE_QUERY_COUNT); - const uint32_t query_begin = ctx->profile_timestamp_query_count++; - const uint32_t query_end = ctx->profile_timestamp_query_count++; + const uint32_t query_begin = ctx->profile_timestamp_query_count++; + const uint32_t query_end = ctx->profile_timestamp_query_count++; - wgpu::PassTimestampWrites ts_writes = {}; - ts_writes.querySet = ctx->profile_timestamp_query_set; - ts_writes.beginningOfPassWriteIndex = query_begin; - ts_writes.endOfPassWriteIndex = query_end; - wgpu::ComputePassDescriptor pass_desc = {}; - pass_desc.timestampWrites = &ts_writes; + wgpu::PassTimestampWrites ts_writes = {}; + ts_writes.querySet = ctx->profile_timestamp_query_set; + ts_writes.beginningOfPassWriteIndex = query_begin; + ts_writes.endOfPassWriteIndex = query_end; + wgpu::ComputePassDescriptor pass_desc = {}; + pass_desc.timestampWrites = &ts_writes; - - - wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); pass.SetPipeline(dispatches[i].pipeline.pipeline); pass.SetBindGroup(0, bind_groups[i]); @@ -1644,20 +1642,17 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, if (has_mask) { dispatches.push_back({ - blk_pipeline, - std::move(blk_params), - std::move(blk_entries), - { blk_nblk0, blk_nblk1 * blk_batch_count } - }); - } - dispatches.push_back({ - pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count } }); - if (use_vec_reduce) { - dispatches.push_back({ - reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } - }); - } + } + dispatches.push_back({ + pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + }); + if (use_vec_reduce) { + dispatches.push_back({ + reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } + }); + } return ggml_backend_webgpu_build_multi(ctx, dispatches); }