From 13133ab299e94a413fed015841a424adec149b1c Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Thu, 14 May 2026 09:31:36 -0700 Subject: [PATCH] ggml-webgpu: makes the flash attn vec path subgroup-aware (llama/23040) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-webgpu: makes the flash attn vec path compile and size its split/reduce work from the device’s reported subgroup range instead of assuming 32 subgroup size. * ggml-webgpu: remove the extra max_wg_size >= max_subgroup_size guard. Remove hardcoded 32 when determine the value of reduce_wg_size and vec_nwg_cap --- ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 13 +++++++++---- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 12 +++++++----- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 62a523365..4c4eda1cb 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -770,9 +770,14 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(K->type); + const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 && + context.src2->ne[0] % kv_vec_head_align == 0; + // Compile with enough invocations to cover the largest reported subgroup. + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && + kv_vec_head_dims_aligned && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (context.src2->type == K->type); const bool tile_can_dispatch_all_q_rows = context.max_subgroup_size > 0 && @@ -808,7 +813,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.q_tile = 1u; decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + decisions.wg_size = context.max_subgroup_size; if (decisions.kv_direct) { decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 401c75c12..78cb02be0 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1832,7 +1832,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_nblk1 = 0; uint32_t blk_batch_count = 0; - const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size; uint32_t nwg = 1u; const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { @@ -1953,8 +1953,11 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector reduce_params; std::vector reduce_entries; if (use_vec_reduce) { - const uint32_t reduce_wg_size = std::max( - 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size; + const uint32_t reduce_wg_size = + std::max(reduce_sg_size, (uint32_t) std::min( + (uint64_t) nwg * reduce_sg_size, + ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; reduce_shader_ctx.max_wg_size = reduce_wg_size; reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); @@ -3542,8 +3545,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { const uint32_t kv_tile = decisions.kv_tile; - const uint32_t vec_nwg_cap = std::max( - 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); + const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; uint32_t nwg = 1u; const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {