ggml-webgpu: makes the flash attn vec path subgroup-aware (llama/23040)
* ggml-webgpu: makes the flash attn vec path compile and size its split/reduce work from the device’s reported subgroup range instead of assuming 32 subgroup size. * ggml-webgpu: remove the extra max_wg_size >= max_subgroup_size guard. Remove hardcoded 32 when determine the value of reduce_wg_size and vec_nwg_cap
This commit is contained in:
parent
592a8cd15d
commit
13133ab299
|
|
@ -770,9 +770,14 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
|||
(v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
|
||||
const bool kv_vec_type_supported =
|
||||
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
|
||||
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
|
||||
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
|
||||
const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
|
||||
(uint32_t) ggml_blck_size(K->type);
|
||||
const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 &&
|
||||
context.src2->ne[0] % kv_vec_head_align == 0;
|
||||
// Compile with enough invocations to cover the largest reported subgroup.
|
||||
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) &&
|
||||
kv_vec_head_dims_aligned && kv_vec_type_supported &&
|
||||
(K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
|
||||
(context.src2->type == K->type);
|
||||
const bool tile_can_dispatch_all_q_rows =
|
||||
context.max_subgroup_size > 0 &&
|
||||
|
|
@ -808,7 +813,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
|||
decisions.q_tile = 1u;
|
||||
decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile));
|
||||
decisions.kv_tile = (decisions.kv_tile / 8u) * 8u;
|
||||
decisions.wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
|
||||
decisions.wg_size = context.max_subgroup_size;
|
||||
if (decisions.kv_direct) {
|
||||
decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
||||
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
|
||||
|
|
|
|||
|
|
@ -1832,7 +1832,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
|||
uint32_t blk_nblk1 = 0;
|
||||
uint32_t blk_batch_count = 0;
|
||||
|
||||
const uint32_t vec_nwg_cap = std::max(1u, std::min<uint32_t>(32u, ctx->global_ctx->capabilities.max_subgroup_size));
|
||||
const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size;
|
||||
uint32_t nwg = 1u;
|
||||
const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile);
|
||||
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
|
||||
|
|
@ -1953,8 +1953,11 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
|||
std::vector<uint32_t> reduce_params;
|
||||
std::vector<wgpu::BindGroupEntry> reduce_entries;
|
||||
if (use_vec_reduce) {
|
||||
const uint32_t reduce_wg_size = std::max(
|
||||
32u, std::min<uint32_t>(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
|
||||
const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
||||
const uint32_t reduce_wg_size =
|
||||
std::max(reduce_sg_size, (uint32_t) std::min<uint64_t>(
|
||||
(uint64_t) nwg * reduce_sg_size,
|
||||
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
|
||||
ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx;
|
||||
reduce_shader_ctx.max_wg_size = reduce_wg_size;
|
||||
reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
|
||||
|
|
@ -3542,8 +3545,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
|||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
const uint32_t kv_tile = decisions.kv_tile;
|
||||
|
||||
const uint32_t vec_nwg_cap = std::max(
|
||||
1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
|
||||
const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
|
||||
uint32_t nwg = 1u;
|
||||
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
|
||||
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue