diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 72a686951..59582e4f0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -911,8 +911,8 @@ struct vk_device_struct { vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; - // [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128 - vk_pipeline pipeline_gated_delta_net[3][2]; + // [size_idx][kda] where size_idx: 0=d16, 1=d32, 2=d64, 3=d128 + vk_pipeline pipeline_gated_delta_net[4][2]; vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; @@ -5231,14 +5231,14 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); { - const uint32_t gdn_sizes[] = {32, 64, 128}; + const uint32_t gdn_sizes[] = {16, 32, 64, 128}; const char * gdn_names[][2] = { + {"gated_delta_net_f32_d16", "gated_delta_net_f32_d16_kda"}, {"gated_delta_net_f32_d32", "gated_delta_net_f32_d32_kda"}, {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"}, {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"}, }; - const bool use_subgroup_reduce = device->subgroup_arithmetic; - for (uint32_t si = 0; si < 3; si++) { + for (uint32_t si = 0; si < 4; si++) { const uint32_t S_V = gdn_sizes[si]; GGML_ASSERT(is_pow2(S_V)); @@ -5252,10 +5252,29 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { lanes_per_column = std::min(S_V, device->subgroup_size); } - const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size); + // gated_delta_net.comp relies on S_V % COLS_PER_WG == 0 and + // S_V % LANES_PER_COLUMN == 0 to avoid bounds checks. + while (lanes_per_column > 1u) { + const bool valid_lanes = (device->subgroup_size % lanes_per_column) == 0 && + (S_V % lanes_per_column) == 0; + const uint32_t cols_per_wg = valid_lanes ? device->subgroup_size / lanes_per_column : 0; + if (valid_lanes && cols_per_wg > 0 && (S_V % cols_per_wg) == 0) { + break; + } + lanes_per_column >>= 1u; + } + + GGML_ASSERT((device->subgroup_size % lanes_per_column) == 0); + GGML_ASSERT((S_V % lanes_per_column) == 0); + GGML_ASSERT((S_V % (device->subgroup_size / lanes_per_column)) == 0); + + const bool need_partial_subgroup_reduce = lanes_per_column != 1u && lanes_per_column < device->subgroup_size; + const bool use_clustered_reduce = device->subgroup_arithmetic && device->subgroup_clustered && need_partial_subgroup_reduce; + const bool use_subgroup_reduce = device->subgroup_arithmetic && !need_partial_subgroup_reduce; + const bool use_subgroup_ops = use_clustered_reduce || use_subgroup_reduce; size_t gdn_len; const void * gdn_data; - if (use_subgroup_reduce && need_clustered_shader) { + if (use_clustered_reduce) { gdn_len = gated_delta_net_f32_len; gdn_data = (const void *)gated_delta_net_f32_data; } else if (use_subgroup_reduce) { @@ -5272,7 +5291,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { for (uint32_t kda = 0; kda < 2; kda++) { ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda], gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), - wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size); + wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_ops, device->subgroup_size); } } } @@ -10746,9 +10765,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0; uint32_t si; switch (S_v) { - case 32: si = 0; break; - case 64: si = 1; break; - case 128: si = 2; break; + case 16: si = 0; break; + case 32: si = 1; break; + case 64: si = 2; break; + case 128: si = 3; break; default: return nullptr; } return ctx->device->pipeline_gated_delta_net[si][kda]; @@ -17193,7 +17213,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_GATED_DELTA_NET: { const uint32_t S_v = op->src[2]->ne[0]; - if (S_v != 32 && S_v != 64 && S_v != 128) { + if (S_v != 16 && S_v != 32 && S_v != 64 && S_v != 128) { return false; } for (int i = 0; i < 6; i++) {