diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index c6cfb0bbb..94a108dfa 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -621,10 +621,11 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, uint32_t value, size_t offset, size_t size) { - std::vector params = { (uint32_t) offset, (uint32_t) size, value }; - std::vector entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; - size_t bytes_per_wg = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread; - uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); + std::vector params = { (uint32_t) offset, (uint32_t) size, value }; + std::vector entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; + size_t bytes_per_wg = + ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread; + uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); @@ -1362,7 +1363,7 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, shader_lib_ctx.src0 = src; shader_lib_ctx.src1 = nullptr; shader_lib_ctx.dst = dst; - shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2169,8 +2170,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, @@ -2244,8 +2247,10 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, } } - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx, @@ -2673,8 +2678,10 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, @@ -3751,7 +3758,8 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { // we use the maximum workgroup size for the memset pipeline - size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * + ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled ctx->capabilities.memset_bytes_per_thread = CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl index 605de7aa7..f262c4a8f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -130,10 +130,13 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) { } @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - let src0_i = params.offset_src0 + src0_index(gid.x); - let src1_i = params.offset_src1 + src1_index(gid.x); - update(params.offset_dst + gid.x, src0_i, src1_i); +fn main(@builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3) { + let threads_per_group = u32(WG_SIZE); + let i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (i < params.ne) { + let src0_i = params.offset_src0 + src0_index(i); + let src1_i = params.offset_src1 + src1_index(i); + update(params.offset_dst + i, src0_i, src1_i); } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl index 3b70a876d..6c76ed69e 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl @@ -43,12 +43,14 @@ struct Params { var src: array; @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3) { + let threads_per_group = u32(WG_SIZE); + var i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (i >= params.ne) { return; } - - var i = gid.x; let i3 = i / (params.ne2 * params.ne1 * params.ne0); i = i % (params.ne2 * params.ne1 * params.ne0); let i2 = i / (params.ne1 * params.ne0); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index 8e34e1c9c..cb342c472 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -66,11 +66,14 @@ fn erf_approx(x: TYPE) -> TYPE { } @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { +fn main(@builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3) { + let threads_per_group = u32(WG_SIZE); + let flat_i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (flat_i >= params.ne) { return; } - var i = gid.x; + var i = flat_i; let ne2 = params.ne2; #ifdef DIAG let ne1 = params.ne0; @@ -205,6 +208,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { #ifdef INPLACE src[params.offset_src + src_idx] = res; #else - dst[params.offset_dst + gid.x] = res; + dst[params.offset_dst + flat_i] = res; #endif }