vulkan: Support gated_delta_net with S_v=16 (llama/24581)

This commit is contained in:
Jeff Bolz 2026-06-16 02:26:57 -05:00 committed by Georgi Gerganov
parent fa204c82f4
commit 93c02083bd
1 changed files with 32 additions and 12 deletions

View File

@ -911,8 +911,8 @@ struct vk_device_struct {
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
// [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
vk_pipeline pipeline_gated_delta_net[3][2];
// [size_idx][kda] where size_idx: 0=d16, 1=d32, 2=d64, 3=d128
vk_pipeline pipeline_gated_delta_net[4][2];
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
@ -5231,14 +5231,14 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
{
const uint32_t gdn_sizes[] = {32, 64, 128};
const uint32_t gdn_sizes[] = {16, 32, 64, 128};
const char * gdn_names[][2] = {
{"gated_delta_net_f32_d16", "gated_delta_net_f32_d16_kda"},
{"gated_delta_net_f32_d32", "gated_delta_net_f32_d32_kda"},
{"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
{"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
};
const bool use_subgroup_reduce = device->subgroup_arithmetic;
for (uint32_t si = 0; si < 3; si++) {
for (uint32_t si = 0; si < 4; si++) {
const uint32_t S_V = gdn_sizes[si];
GGML_ASSERT(is_pow2(S_V));
@ -5252,10 +5252,29 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
lanes_per_column = std::min(S_V, device->subgroup_size);
}
const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
// gated_delta_net.comp relies on S_V % COLS_PER_WG == 0 and
// S_V % LANES_PER_COLUMN == 0 to avoid bounds checks.
while (lanes_per_column > 1u) {
const bool valid_lanes = (device->subgroup_size % lanes_per_column) == 0 &&
(S_V % lanes_per_column) == 0;
const uint32_t cols_per_wg = valid_lanes ? device->subgroup_size / lanes_per_column : 0;
if (valid_lanes && cols_per_wg > 0 && (S_V % cols_per_wg) == 0) {
break;
}
lanes_per_column >>= 1u;
}
GGML_ASSERT((device->subgroup_size % lanes_per_column) == 0);
GGML_ASSERT((S_V % lanes_per_column) == 0);
GGML_ASSERT((S_V % (device->subgroup_size / lanes_per_column)) == 0);
const bool need_partial_subgroup_reduce = lanes_per_column != 1u && lanes_per_column < device->subgroup_size;
const bool use_clustered_reduce = device->subgroup_arithmetic && device->subgroup_clustered && need_partial_subgroup_reduce;
const bool use_subgroup_reduce = device->subgroup_arithmetic && !need_partial_subgroup_reduce;
const bool use_subgroup_ops = use_clustered_reduce || use_subgroup_reduce;
size_t gdn_len;
const void * gdn_data;
if (use_subgroup_reduce && need_clustered_shader) {
if (use_clustered_reduce) {
gdn_len = gated_delta_net_f32_len;
gdn_data = (const void *)gated_delta_net_f32_data;
} else if (use_subgroup_reduce) {
@ -5272,7 +5291,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
for (uint32_t kda = 0; kda < 2; kda++) {
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size);
wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_ops, device->subgroup_size);
}
}
}
@ -10746,9 +10765,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0;
uint32_t si;
switch (S_v) {
case 32: si = 0; break;
case 64: si = 1; break;
case 128: si = 2; break;
case 16: si = 0; break;
case 32: si = 1; break;
case 64: si = 2; break;
case 128: si = 3; break;
default: return nullptr;
}
return ctx->device->pipeline_gated_delta_net[si][kda];
@ -17193,7 +17213,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_GATED_DELTA_NET:
{
const uint32_t S_v = op->src[2]->ne[0];
if (S_v != 32 && S_v != 64 && S_v != 128) {
if (S_v != 16 && S_v != 32 && S_v != 64 && S_v != 128) {
return false;
}
for (int i = 0; i < 6; i++) {