Check batch_compute_passes before sending passes when not doing GPU profiling (llama/23457)

* Only run webgpu CI on my fork

* Add webgpu only workflow

* refactor batch_compute_passes to a per-thread variable, and submit individual passes when it is set to false and no GPU profiling is enabled

* restore build.yml
This commit is contained in:
Nikhil Jain 2026-05-25 20:32:49 -07:00 committed by Georgi Gerganov
parent 2307712d32
commit bc77933c2d
1 changed files with 22 additions and 13 deletions

View File

@ -259,6 +259,7 @@ struct webgpu_context_struct {
wgpu::Buffer set_rows_host_error_buf;
wgpu::CommandEncoder active_command_encoder;
wgpu::ComputePassEncoder active_compute_pass;
bool batch_compute_passes = true;
size_t memset_bytes_per_thread;
@ -590,9 +591,18 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context &
}
#else
for (size_t i = 0; i < dispatches.size(); i++) {
ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline);
ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]);
ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1);
if (ctx->batch_compute_passes) {
ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline);
ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]);
ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second,
1);
} else {
wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass();
pass.SetPipeline(dispatches[i].pipeline.pipeline);
pass.SetBindGroup(0, bind_groups[i]);
pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1);
pass.End();
}
}
#endif
@ -1956,10 +1966,10 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
std::vector<wgpu::BindGroupEntry> reduce_entries;
if (use_vec_reduce) {
const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size;
const uint32_t reduce_wg_size =
std::max(reduce_sg_size, (uint32_t) std::min<uint64_t>(
(uint64_t) nwg * reduce_sg_size,
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
const uint32_t reduce_wg_size = std::max(
reduce_sg_size,
(uint32_t) std::min<uint64_t>((uint64_t) nwg * reduce_sg_size,
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx;
reduce_shader_ctx.max_wg_size = reduce_wg_size;
reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
@ -3110,18 +3120,16 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
uint32_t num_batched_kernels = 0;
uint32_t num_inflight_batches = 0;
bool contains_set_rows = false;
bool batch_compute_passes = true;
int num_encoded_ops = 1;
int node_idx = 0;
#ifdef GGML_WEBGPU_GPU_PROFILE
ctx->profile_timestamp_query_count = 0;
batch_compute_passes = false;
std::vector<std::string> profile_pipeline_names;
#endif
ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder();
if (batch_compute_passes) {
if (ctx->batch_compute_passes) {
ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass();
}
@ -3148,7 +3156,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
// reset state for next batch
ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder();
if (batch_compute_passes) {
if (ctx->batch_compute_passes) {
ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass();
}
ctx->param_arena.reset();
@ -3548,8 +3556,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
const uint32_t kv_tile = decisions.kv_tile;
const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
uint32_t nwg = 1u;
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
uint32_t nwg = 1u;
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
nwg <<= 1;
}
@ -3839,6 +3847,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
#ifdef GGML_WEBGPU_GPU_PROFILE
webgpu_ctx->batch_compute_passes = false;
ggml_webgpu_create_buffer(
webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_dev_buf, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, "profile_timestamp_dev_buf");