vulkan: don't hold the device mutex while compiling pipelines (llama/23641)

* vulkan: don't hold the device mutex while compiling pipelines

We need to hold a lock while we traverse all pipelines and lazily initialize
them, but we don't need to hold it while a pipeline is being compiled, and
it doesn't need to be the same lock as the device mutex. We call load_shaders
each time a pipeline is needed, so each call only needs to compile that one
pipeline (and should not, for example, end up compiling a pipeline that
another thread is responsible for compiling).

* remove 'needed'
This commit is contained in:
Jeff Bolz 2026-06-01 07:04:01 -05:00 committed by Georgi Gerganov
parent c471bcce1b
commit 71d80aa49e
1 changed files with 99 additions and 45 deletions

View File

@ -65,6 +65,7 @@ typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV {
#include <shared_mutex>
#include <mutex>
#include <future>
#include <condition_variable>
#include <thread>
#if defined(_MSC_VER)
@ -159,8 +160,9 @@ struct vk_pipeline_struct {
uint32_t align;
// true if fields have been set by ggml_vk_create_pipeline
bool initialized {};
// set to true to request the pipeline is compiled
std::atomic<bool> needed {};
// true while a compile is in flight, used to dedupe concurrent claims.
// Protected by device->compile_mutex.
bool compile_pending {};
// set to true when the shader has been compiled
std::atomic<bool> compiled {};
// number of registers used, extracted from pipeline executable properties
@ -621,6 +623,13 @@ struct vk_device_struct {
std::recursive_mutex mutex;
mutable std::shared_mutex pinned_memory_mutex;
// Guards compile_pending, all_pipelines, and the dynamic pipeline maps
// (flash_attn, fa_mask_opt, solve_tri, conv2d, etc). The actual compile
// runs with no lock held, so different pipelines can compile in parallel.
// Lock order is device->mutex -> compile_mutex, never the reverse.
std::mutex compile_mutex;
std::condition_variable compile_cv;
vk::PhysicalDevice physical_device;
vk::PhysicalDeviceProperties properties;
std::string name;
@ -1729,7 +1738,7 @@ struct ggml_vk_garbage_collector {
};
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);
static void ggml_vk_load_shaders(vk_device& device);
static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested = nullptr);
static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);
static bool vk_memory_logger_enabled = false;
@ -2196,11 +2205,6 @@ static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
ctx->device->device.resetFences({ ctx->fence });
}
// variables to track number of compiles in progress
static uint32_t compile_count = 0;
static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;
static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367;
static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447;
static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4;
@ -2495,7 +2499,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
throw e;
}
pipeline->compiled = true;
if (vk_instance.debug_utils_support) {
vk::DebugUtilsObjectNameInfoEXT duoni;
@ -2544,14 +2547,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
}
}
device->all_pipelines.push_back(pipeline);
{
std::lock_guard<std::mutex> guard(compile_count_mutex);
assert(compile_count > 0);
compile_count--;
std::lock_guard<std::mutex> guard(device->compile_mutex);
device->all_pipelines.push_back(pipeline);
pipeline->compiled = true;
pipeline->compile_pending = false;
}
compile_count_cond.notify_all();
device->compile_cv.notify_all();
}
static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
@ -2567,8 +2569,7 @@ static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx,
VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
ctx->pipeline_descriptor_set_requirements += n;
if (!pipeline->compiled) {
pipeline->needed = true;
ggml_vk_load_shaders(ctx->device);
ggml_vk_load_shaders(ctx->device, pipeline);
}
ggml_pipeline_allocate_descriptor_sets(ctx);
}
@ -3567,10 +3568,26 @@ static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type
#endif
}
static void ggml_vk_load_shaders(vk_device& device) {
// load_shaders walks the pipeline list under compile_mutex and either claims
// the requested pipeline for compilation or, if another thread is already
// compiling it, drops the lock and waits on compile_cv. Compiles themselves
// run unlocked.
// Snapshot of the arguments for a single ggml_vk_create_pipeline_func call.
// ggml_vk_load_shaders fills one of these in while it holds compile_mutex and
// performs the actual compile after the lock has been released.
//
// Every member carries a default initializer (matching the `bool x {};` style
// used by the pipeline/device structs) so that a default-constructed task is
// in a well-defined, empty state even without `CompileTask t {};`.
struct CompileTask {
    // pipeline object that the compile will populate
    vk_pipeline pipeline {};
    // size and address of the SPIR-V blob; the pointer is stored as-is,
    // the blob itself is not copied
    size_t spv_size {};
    const void * spv_data {};
    // owned copy of the shader entry point name
    std::string entrypoint {};
    uint32_t parameter_count {};
    std::array<uint32_t, 3> wg_denoms {};
    // owned copy of the specialization constants
    std::vector<uint32_t> specialization_constants {};
    bool disable_robustness {};
    bool require_full_subgroups {};
    uint32_t required_subgroup_size {};
};
static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
std::lock_guard<std::recursive_mutex> guard(device->mutex);
// some shaders have a minimum subgroup size
const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
@ -3600,6 +3617,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
uint32_t l_align, m_align, s_align;
vk_pipeline wait_pipeline;
CompileTask claimed_task {};
bool has_claimed_task = false;
// The rest of the walk reads and writes shared device state, so hold the
// lock until we're done deciding what to compile.
std::unique_lock<std::mutex> compile_lock(device->compile_mutex);
if (device->coopmat2) {
// spec constants and tile sizes for non-quant matmul/matmul_id
l_warptile = { 256, 128, 256, 64, 1 };
@ -3785,7 +3811,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
}
std::vector<std::future<void>> compiles;
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
@ -3819,23 +3844,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
}
if (!pipeline->needed || pipeline->compiled) {
// We only care about the pipeline this call asked for; the rest
// (including the 64-bit indexing variant) are handled by their
// own request_descriptor_sets / load_shaders calls.
if (pipeline.get() != requested.get()) {
continue;
}
// TODO: We're no longer benefitting from the async compiles (shaders are
// compiled individually, as needed) and this complexity can be removed.
{
// wait until fewer than N compiles are in progress
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
std::unique_lock<std::mutex> guard(compile_count_mutex);
while (compile_count >= N) {
compile_count_cond.wait(guard);
}
compile_count++;
if (pipeline->compiled) {
continue;
}
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
wait_pipeline = pipeline;
if (!pipeline->compile_pending) {
pipeline->compile_pending = true;
claimed_task.pipeline = pipeline;
claimed_task.spv_size = spv_size;
claimed_task.spv_data = spv_data;
claimed_task.entrypoint = entrypoint;
claimed_task.parameter_count = parameter_count;
claimed_task.wg_denoms = wg_denoms;
claimed_task.specialization_constants = specialization_constants;
claimed_task.disable_robustness = disable_robustness;
claimed_task.require_full_subgroups = require_full_subgroups;
claimed_task.required_subgroup_size = required_subgroup_size;
has_claimed_task = true;
}
}
};
@ -5332,8 +5367,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
}
for (auto &c : compiles) {
c.wait();
// Drop compile_mutex so other threads can walk while we compile.
compile_lock.unlock();
// Compile what we claimed; create_pipeline_func reacquires compile_mutex
// at the end to flip compile_pending/compiled and notify waiters.
if (has_claimed_task) {
auto & task = claimed_task;
ggml_vk_create_pipeline_func(device, task.pipeline, task.spv_size, task.spv_data,
task.entrypoint, task.parameter_count, task.wg_denoms,
task.specialization_constants, task.disable_robustness,
task.require_full_subgroups, task.required_subgroup_size);
}
// Another thread may be compiling the pipeline we need; block on it here.
if (wait_pipeline) {
std::unique_lock<std::mutex> wait_lock(device->compile_mutex);
device->compile_cv.wait(wait_lock, [&] {
return wait_pipeline->compiled.load();
});
}
}
@ -9722,7 +9774,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_pipeline pipeline = nullptr;
{
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16;
auto it = pipelines.find(fa_pipeline_state);
if (it != pipelines.end()) {
@ -9786,13 +9838,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_pipeline pipeline_fa_mask_opt = nullptr;
if (use_mask_opt) {
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
auto &pipelines = ctx->device->pipeline_fa_mask_opt;
auto it = pipelines.find({Br, Bc});
if (it != pipelines.end()) {
pipeline_fa_mask_opt = it->second;
} else {
pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
{
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto &pipelines = ctx->device->pipeline_fa_mask_opt;
auto it = pipelines.find({Br, Bc});
if (it != pipelines.end()) {
pipeline_fa_mask_opt = it->second;
} else {
pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
}
}
assert(pipeline_fa_mask_opt);
ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
@ -10326,7 +10380,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
vk_pipeline pipeline = nullptr;
{
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
if (it != ctx->device->pipeline_solve_tri_f32.end()) {
pipeline = it->second;
@ -10485,7 +10539,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
vk_pipeline pipeline = nullptr;
{
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto it = pipelines->find(conv2d_pipeline_state);
if (it != pipelines->end()) {
pipeline = it->second;