ggml-webgpu: only use subgroup-matrix path when head dims are divisible by sg_mat_k / sg_mat_n (llama/23020)
This commit is contained in:
parent
97371e9285
commit
e4ce42e55f
|
|
@@ -777,7 +777,10 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
|||
const bool tile_can_dispatch_all_q_rows =
|
||||
context.max_subgroup_size > 0 &&
|
||||
context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size;
|
||||
- const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
|
||||
+ const bool use_subgroup_matrix =
|
||||
+ context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 &&
|
||||
+ context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0;
|
||||
+ const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 &&
|
||||
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
|
||||
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||
|
|
@@ -785,7 +788,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
|||
|
||||
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
|
||||
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
|
||||
- context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
|
||||
+ use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
|
||||
GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
|
||||
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue