vulkan: add fwht support for Intel with shmem reduction (llama/23964)
* vulkan: add fwht support for Intel with shmem reduction * don't use N as workgroup size * disable subgroup shuffle on MoltenVK AMD * disable fwht shader on Intel Windows due to driver bug
This commit is contained in:
parent
facb02c4c3
commit
5a1feed8ca
|
|
@ -5084,6 +5084,14 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
|||
}
|
||||
++idx;
|
||||
}
|
||||
} else if (device->driver_id != vk::DriverId::eIntelProprietaryWindows) {
|
||||
// Disabled on Intel Windows due to a driver bug: https://github.com/ggml-org/llama.cpp/pull/23964#issuecomment-4598226147
|
||||
int idx = 0;
|
||||
for (uint32_t n : {64, 128, 256, 512}) {
|
||||
const uint32_t block_size = std::min(device->subgroup_size, n);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_shmem_f32", fwht_shmem_f32_len, fwht_shmem_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { block_size, n }, 1);
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
|
||||
|
|
@ -5630,6 +5638,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
#endif
|
||||
device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
||||
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
|
||||
#ifdef __APPLE__
|
||||
if (device->vendor_id == VK_VENDOR_ID_AMD) {
|
||||
device->subgroup_shuffle = false;
|
||||
}
|
||||
#endif
|
||||
device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
||||
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#ifndef FWHT_SHMEM
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
#endif
|
||||
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout(constant_id = 1) const uint N = 128;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
||||
|
||||
layout(constant_id = 0) const uint WARP_SIZE = 32;
|
||||
layout(constant_id = 1) const uint N = 128;
|
||||
|
||||
layout(push_constant) uniform parameter
|
||||
{
|
||||
uint n_rows;
|
||||
|
|
@ -20,35 +22,72 @@ layout(push_constant) uniform parameter
|
|||
layout(binding = 0, std430) readonly buffer A { float data_a[]; };
|
||||
layout(binding = 1, std430) writeonly buffer D { float data_d[]; };
|
||||
|
||||
const uint EL_W = N / WARP_SIZE;
|
||||
const uint EL_W = N / BLOCK_SIZE;
|
||||
|
||||
#ifdef FWHT_SHMEM
|
||||
shared float shmem[4 * N];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
const uint lane = gl_SubgroupInvocationID;
|
||||
for (uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
|
||||
row < n_rows;
|
||||
row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) {
|
||||
#ifdef FWHT_SHMEM
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint shmem_base = gl_LocalInvocationID.y * N;
|
||||
const uint row_id = gl_LocalInvocationID.y;
|
||||
#else
|
||||
const uint tid = gl_SubgroupInvocationID;
|
||||
const uint row_id = gl_SubgroupID;
|
||||
#endif
|
||||
|
||||
for (uint base_row = gl_WorkGroupID.x * gl_WorkGroupSize.y;
|
||||
base_row < n_rows;
|
||||
base_row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) {
|
||||
const uint row = base_row + row_id;
|
||||
const uint row_offset = row * N;
|
||||
|
||||
#ifndef FWHT_SHMEM
|
||||
if (row >= n_rows) {
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
|
||||
float reg[EL_W];
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < EL_W; ++i) {
|
||||
reg[i] = data_a[src_offset + row_offset + i * WARP_SIZE + lane] * scale;
|
||||
reg[i] = row < n_rows ? data_a[src_offset + row_offset + i * BLOCK_SIZE + tid] * scale : 0.0;
|
||||
}
|
||||
|
||||
#ifdef FWHT_SHMEM
|
||||
[[unroll]]
|
||||
for (uint h = 1; h < WARP_SIZE; h <<= 1) {
|
||||
for (uint h = 1; h < BLOCK_SIZE; h <<= 1) {
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < EL_W; ++i) {
|
||||
shmem[shmem_base + i * BLOCK_SIZE + tid] = reg[i];
|
||||
}
|
||||
barrier();
|
||||
[[unroll]]
|
||||
for (uint j = 0; j < EL_W; ++j) {
|
||||
const float val = reg[j];
|
||||
const float other = shmem[shmem_base + j * BLOCK_SIZE + (tid ^ h)];
|
||||
reg[j] = (tid & h) == 0 ? val + other : other - val;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#else
|
||||
[[unroll]]
|
||||
for (uint h = 1; h < BLOCK_SIZE; h <<= 1) {
|
||||
[[unroll]]
|
||||
for (uint j = 0; j < EL_W; ++j) {
|
||||
const float val = reg[j];
|
||||
const float val2 = subgroupShuffleXor(val, h);
|
||||
reg[j] = (lane & h) == 0 ? val + val2 : val2 - val;
|
||||
reg[j] = (tid & h) == 0 ? val + val2 : val2 - val;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
[[unroll]]
|
||||
for (uint h = WARP_SIZE; h < N; h <<= 1) {
|
||||
const uint step = h / WARP_SIZE;
|
||||
for (uint h = BLOCK_SIZE; h < N; h <<= 1) {
|
||||
const uint step = h / BLOCK_SIZE;
|
||||
[[unroll]]
|
||||
for (uint j = 0; j < EL_W; j += 2 * step) {
|
||||
[[unroll]]
|
||||
|
|
@ -61,9 +100,16 @@ void main() {
|
|||
}
|
||||
}
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < EL_W; ++i) {
|
||||
data_d[dst_offset + row_offset + i * WARP_SIZE + lane] = reg[i];
|
||||
#ifdef FWHT_SHMEM
|
||||
if (row < n_rows) {
|
||||
#endif
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < EL_W; ++i) {
|
||||
data_d[dst_offset + row_offset + i * BLOCK_SIZE + tid] = reg[i];
|
||||
}
|
||||
#ifdef FWHT_SHMEM
|
||||
}
|
||||
barrier();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -957,6 +957,7 @@ void process_shaders() {
|
|||
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
|
||||
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("fwht_f32", "fwht.comp", {});
|
||||
string_to_spv("fwht_shmem_f32", "fwht.comp", {{"FWHT_SHMEM", "1"}});
|
||||
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
|
||||
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
|
|
|||
Loading…
Reference in New Issue