diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c75a98a8d..6f877f15c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -644,7 +644,8 @@ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) { const uint32_t offset_elems = - (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / ggml_type_size(K->type)); + (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / + ggml_type_size(K->type)); return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u; } @@ -655,8 +656,10 @@ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); } -inline bool ggml_webgpu_flash_attn_kv_direct( - const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, uint32_t kv_direct_align) { +inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V, + uint32_t kv_direct_align) { return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); } @@ -671,10 +674,10 @@ inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_co key.dst_type = context.dst->type; key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; - key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align); - key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); - key.has_mask = context.src3 != nullptr; - key.has_sinks = context.src4 != nullptr; + key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align); + key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); + key.has_mask = context.src3 != nullptr; + key.has_sinks = context.src4 != nullptr; key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; return key; } @@ -1727,7 +1730,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 79d513802..538e587bb 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -4253,9 +4253,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const const uint32_t q_tile = use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u; - const bool kv_direct = use_subgroup_matrix ? - ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) : - false; + const bool kv_direct = use_subgroup_matrix ? + ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) : + false; const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct);