From 5a1feed8ca57b70d002ca0df2abd3db9328a1daa Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Fri, 5 Jun 2026 19:44:40 +0200 Subject: [PATCH] vulkan: add fwht support for Intel with shmem reduction (llama/23964) * vulkan: add fwht support for Intel with shmem reduction * don't use N as workgroup size * disable subgroup shuffle on MoltenVK AMD * disable fwht shader on Intel Windows due to driver bug --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 ++++ ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp | 78 +++++++++++++++---- .../vulkan-shaders/vulkan-shaders-gen.cpp | 1 + 3 files changed, 76 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e7d04634b..df410368a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5084,6 +5084,14 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } ++idx; } + } else if (device->driver_id != vk::DriverId::eIntelProprietaryWindows) { + // Disabled on Intel Windows due to a driver bug: https://github.com/ggml-org/llama.cpp/pull/23964#issuecomment-4598226147 + int idx = 0; + for (uint32_t n : {64, 128, 256, 512}) { + const uint32_t block_size = std::min(device->subgroup_size, n); + ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_shmem_f32", fwht_shmem_f32_len, fwht_shmem_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { block_size, n }, 1); + ++idx; + } } const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4; @@ -5630,6 +5638,11 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); +#ifdef __APPLE__ + if (device->vendor_id == VK_VENDOR_ID_AMD) { + device->subgroup_shuffle = false; + } +#endif device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp index 72059d4af..a2069964a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp @@ -1,14 +1,16 @@ #version 450 #extension GL_EXT_control_flow_attributes : require +#ifndef FWHT_SHMEM #extension GL_KHR_shader_subgroup_basic : enable #extension GL_KHR_shader_subgroup_shuffle : enable +#endif + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(constant_id = 1) const uint N = 128; layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; -layout(constant_id = 0) const uint WARP_SIZE = 32; -layout(constant_id = 1) const uint N = 128; - layout(push_constant) uniform parameter { uint n_rows; @@ -20,35 +22,72 @@ layout(push_constant) uniform parameter layout(binding = 0, std430) readonly buffer A { float data_a[]; }; layout(binding = 1, std430) writeonly buffer D { float data_d[]; }; -const uint EL_W = N / WARP_SIZE; +const uint EL_W = N / BLOCK_SIZE; + +#ifdef FWHT_SHMEM +shared float shmem[4 * N]; +#endif void main() { - const uint lane = gl_SubgroupInvocationID; - for (uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID; - row < n_rows; - row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) { +#ifdef FWHT_SHMEM + const uint tid = gl_LocalInvocationID.x; + const uint shmem_base = gl_LocalInvocationID.y * N; + const uint row_id = gl_LocalInvocationID.y; +#else + const uint tid = gl_SubgroupInvocationID; + const uint row_id = gl_SubgroupID; +#endif + + for (uint base_row = gl_WorkGroupID.x * gl_WorkGroupSize.y; + base_row < n_rows; + base_row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) { + const uint row = base_row + row_id; const uint row_offset = row * N; +#ifndef FWHT_SHMEM + if (row >= n_rows) { + continue; + } +#endif + float reg[EL_W]; [[unroll]] for (uint i = 0; i < EL_W; ++i) { - reg[i] = data_a[src_offset + row_offset + i * WARP_SIZE + lane] * scale; + reg[i] = row < n_rows ? data_a[src_offset + row_offset + i * BLOCK_SIZE + tid] * scale : 0.0; } +#ifdef FWHT_SHMEM [[unroll]] - for (uint h = 1; h < WARP_SIZE; h <<= 1) { + for (uint h = 1; h < BLOCK_SIZE; h <<= 1) { + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + shmem[shmem_base + i * BLOCK_SIZE + tid] = reg[i]; + } + barrier(); + [[unroll]] + for (uint j = 0; j < EL_W; ++j) { + const float val = reg[j]; + const float other = shmem[shmem_base + j * BLOCK_SIZE + (tid ^ h)]; + reg[j] = (tid & h) == 0 ? val + other : other - val; + } + barrier(); + } +#else + [[unroll]] + for (uint h = 1; h < BLOCK_SIZE; h <<= 1) { [[unroll]] for (uint j = 0; j < EL_W; ++j) { const float val = reg[j]; const float val2 = subgroupShuffleXor(val, h); - reg[j] = (lane & h) == 0 ? val + val2 : val2 - val; + reg[j] = (tid & h) == 0 ? val + val2 : val2 - val; } } +#endif [[unroll]] - for (uint h = WARP_SIZE; h < N; h <<= 1) { - const uint step = h / WARP_SIZE; + for (uint h = BLOCK_SIZE; h < N; h <<= 1) { + const uint step = h / BLOCK_SIZE; [[unroll]] for (uint j = 0; j < EL_W; j += 2 * step) { [[unroll]] @@ -61,9 +100,16 @@ void main() { } } - [[unroll]] - for (uint i = 0; i < EL_W; ++i) { - data_d[dst_offset + row_offset + i * WARP_SIZE + lane] = reg[i]; +#ifdef FWHT_SHMEM + if (row < n_rows) { +#endif + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + data_d[dst_offset + row_offset + i * BLOCK_SIZE + tid] = reg[i]; + } +#ifdef FWHT_SHMEM } + barrier(); +#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index de7dbec2c..d65cd12b2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -957,6 +957,7 @@ void process_shaders() { string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("fwht_f32", "fwht.comp", {}); + string_to_spv("fwht_shmem_f32", "fwht.comp", {{"FWHT_SHMEM", "1"}}); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));