vulkan: improve topk perf for large k, fix overflow in unit tests (llama/17582)

This commit is contained in:
Jeff Bolz 2025-11-29 01:39:57 -06:00 committed by Georgi Gerganov
parent 463003e76c
commit dbf8766ffa
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed files with 3 additions and 1 deletions

View File

@ -10239,7 +10239,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
// But if K is larger, then we need a larger workgroup
uint32_t max_pipeline = num_topk_pipelines - 3;
uint32_t max_pipeline = num_topk_pipelines - 1;
uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2);
max_pipeline = std::min(preferred_pipeline, max_pipeline);
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
// require full subgroup
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);