ggml-webgpu: makes the flash attn vec path subgroup-aware (llama/23040)

* ggml-webgpu: makes the flash attn vec path compile and size its split/reduce work from the device’s reported subgroup range instead of assuming 32 subgroup size.

* ggml-webgpu: remove the extra max_wg_size >= max_subgroup_size guard. Remove hardcoded 32 when determine the value of reduce_wg_size and vec_nwg_cap
This commit is contained in:
Zheyuan Chen 2026-05-14 09:31:36 -07:00 committed by Georgi Gerganov
parent 592a8cd15d
commit 13133ab299
2 changed files with 16 additions and 9 deletions

View File

@ -770,9 +770,14 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
(v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
const bool kv_vec_type_supported =
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
(uint32_t) ggml_blck_size(K->type);
const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 &&
context.src2->ne[0] % kv_vec_head_align == 0;
// Compile with enough invocations to cover the largest reported subgroup.
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) &&
kv_vec_head_dims_aligned && kv_vec_type_supported &&
(K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
(context.src2->type == K->type);
const bool tile_can_dispatch_all_q_rows =
context.max_subgroup_size > 0 &&
@ -808,7 +813,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
decisions.q_tile = 1u;
decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile));
decisions.kv_tile = (decisions.kv_tile / 8u) * 8u;
decisions.wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
decisions.wg_size = context.max_subgroup_size;
if (decisions.kv_direct) {
decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {

View File

@ -1832,7 +1832,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
uint32_t blk_nblk1 = 0;
uint32_t blk_batch_count = 0;
const uint32_t vec_nwg_cap = std::max(1u, std::min<uint32_t>(32u, ctx->global_ctx->capabilities.max_subgroup_size));
const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size;
uint32_t nwg = 1u;
const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile);
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
@ -1953,8 +1953,11 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
std::vector<uint32_t> reduce_params;
std::vector<wgpu::BindGroupEntry> reduce_entries;
if (use_vec_reduce) {
const uint32_t reduce_wg_size = std::max(
32u, std::min<uint32_t>(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size;
const uint32_t reduce_wg_size =
std::max(reduce_sg_size, (uint32_t) std::min<uint64_t>(
(uint64_t) nwg * reduce_sg_size,
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx;
reduce_shader_ctx.max_wg_size = reduce_wg_size;
reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
@ -3542,8 +3545,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
const uint32_t kv_tile = decisions.kv_tile;
const uint32_t vec_nwg_cap = std::max(
1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
uint32_t nwg = 1u;
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {