Support Q4_1, Q5_0, Q5_1 in Flash-attention (llama/23812)

* support Q4_1, Q5_0, Q5_1

* update ut case
This commit is contained in:
Neo Zhang 2026-06-01 14:53:53 +08:00 committed by Georgi Gerganov
parent 20323e48c4
commit ec0c661950
2 changed files with 4 additions and 3 deletions

View File

@ -45,6 +45,7 @@ namespace syclexp = sycl::ext::oneapi::experimental;
#define GGML_COMMON_IMPL_SYCL
#define SYCL_FLASH_ATTN //remove it to disable FLASH_ATTENTION in building.
#define SYCL_FAST_FP16 //don't change. remove it will break fattn-tile.hpp building
#define GGML_SYCL_FA_ALL_QUANTS //define it to enable all quantization types in flash attention. undefine it to only support F16, Q4_0 and Q8_0 in flash attention.
/* suppress warning spam */
#pragma clang diagnostic push

View File

@ -1031,7 +1031,7 @@ void launch_fattn(
auto KV_max_ptr_ct1 = KV_max.ptr;
cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max),
[=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
GGML_UNUSED(item_ct1);
flash_attn_mask_to_KV_max<ncols1, warp_size>(
mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33,
@ -1149,7 +1149,7 @@ void launch_fattn(
auto K_ne_ct6 = K->ne[2];
cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
[=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
GGML_UNUSED(item_ct1);
flash_attn_stream_k_fixup<DV, ncols1, ncols2>(KQV_data_ct0, dst_tmp_meta_ptr_ct1,
Q_ne_ct2, Q_ne_ct3, Q_ne_ct4,
@ -1169,7 +1169,7 @@ void launch_fattn(
auto KQV_data_ct2 = (float *) KQV->data;
cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
[=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
GGML_UNUSED(item_ct1);
flash_attn_combine_results<DV>(
dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks,