From ec0c6619500e86d2c2f290d1c95a3f022397fbcf Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Mon, 1 Jun 2026 14:53:53 +0800 Subject: [PATCH] Support Q4_1, Q5_0, Q5_1 in Flash-attention (llama/23812) * support Q4_1, Q5_0, Q5_1 * update ut case --- ggml/src/ggml-sycl/common.hpp | 1 + ggml/src/ggml-sycl/fattn-common.hpp | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 31e26ff48..d8bb3638d 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -45,6 +45,7 @@ namespace syclexp = sycl::ext::oneapi::experimental; #define GGML_COMMON_IMPL_SYCL #define SYCL_FLASH_ATTN //remove it to disable FLASH_ATTENTION in building. #define SYCL_FAST_FP16 //don't change. remove it will break fattn-tile.hpp building +#define GGML_SYCL_FA_ALL_QUANTS //define it to enable all quantization types in flash attention. undefine it to only support F16, Q4_0 and Q8_0 in flash attention. /* suppress warning spam */ #pragma clang diagnostic push diff --git a/ggml/src/ggml-sycl/fattn-common.hpp b/ggml/src/ggml-sycl/fattn-common.hpp index 03f0c2623..c6cc13cfb 100644 --- a/ggml/src/ggml-sycl/fattn-common.hpp +++ b/ggml/src/ggml-sycl/fattn-common.hpp @@ -1031,7 +1031,7 @@ void launch_fattn( auto KV_max_ptr_ct1 = KV_max.ptr; cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { GGML_UNUSED(item_ct1); flash_attn_mask_to_KV_max( mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33, @@ -1149,7 +1149,7 @@ void launch_fattn( auto K_ne_ct6 = K->ne[2]; cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { GGML_UNUSED(item_ct1); flash_attn_stream_k_fixup(KQV_data_ct0, dst_tmp_meta_ptr_ct1, Q_ne_ct2, Q_ne_ct3, Q_ne_ct4, @@ -1169,7 +1169,7 @@ void launch_fattn( auto KQV_data_ct2 = (float *) KQV->data; cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { GGML_UNUSED(item_ct1); flash_attn_combine_results( dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks,