From c613b87ad4f73454143b88693de63c674b56848d Mon Sep 17 00:00:00 2001
From: Thomas Guillem
Date: Wed, 21 Jan 2026 06:29:31 +0100
Subject: [PATCH] vulkan: use pfn_vkGetBufferDeviceAddress on Vulkan 1.1

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index a7a3f208..38b4c3d9 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -585,6 +585,8 @@ struct vk_device_struct {
     bool multi_add;
     bool shader_int64;
     bool buffer_device_address;
+    // Not needed for Vulkan 1.2+ where it's a core function
+    PFN_vkGetBufferDeviceAddressKHR pfn_vkGetBufferDeviceAddress = nullptr;
     bool vulkan_memory_model;
     bool add_rms_fusion;
 
@@ -2585,8 +2587,13 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
     buf->size = size;
 
     if (device->buffer_device_address) {
-        const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
-        buf->bda_addr = device->device.getBufferAddress(addressInfo);
+        if (device->pfn_vkGetBufferDeviceAddress){
+            vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
+            buf->bda_addr = device->pfn_vkGetBufferDeviceAddress(device->device, &static_cast<VkBufferDeviceAddressInfo&>(addressInfo));
+        } else {
+            const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
+            buf->bda_addr = device->device.getBufferAddress(addressInfo);
+        }
     }
 
     device->memory_logger->log_allocation(buf, size);
@@ -5268,6 +5275,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device_create_info.setPNext(&device_features2);
         device->device = device->physical_device.createDevice(device_create_info);
 
+        if (!device_is_vulkan_12 && device->buffer_device_address) {
+            device->pfn_vkGetBufferDeviceAddress = (PFN_vkGetBufferDeviceAddressKHR)
+                vkGetDeviceProcAddr(device->device, "vkGetBufferDeviceAddressKHR");
+            if (!device->pfn_vkGetBufferDeviceAddress) {
+                throw std::runtime_error("Failed to load vkGetBufferDeviceAddressKHR");
+            }
+        }
+
         // Queues
         ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
 