ggml-webgpu: only use subgroup-matrix path when head dims are divisible by sg_mat_k / sg_mat_n (llama/23020)

This commit is contained in:
Zheyuan Chen 2026-05-13 15:12:40 -07:00 committed by Georgi Gerganov
parent 97371e9285
commit e4ce42e55f
1 changed files with 5 additions and 2 deletions

View File

@ -777,7 +777,10 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
const bool tile_can_dispatch_all_q_rows =
context.max_subgroup_size > 0 &&
context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size;
const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
const bool use_subgroup_matrix =
context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 &&
context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0;
const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 &&
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
@ -785,7 +788,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {