diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a7a3f208..38b4c3d9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -585,6 +585,8 @@ struct vk_device_struct { bool multi_add; bool shader_int64; bool buffer_device_address; + // Not needed for Vulkan 1.2+ where it's a core function + PFN_vkGetBufferDeviceAddressKHR pfn_vkGetBufferDeviceAddress = nullptr; bool vulkan_memory_model; bool add_rms_fusion; @@ -2585,8 +2587,13 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std buf->size = size; if (device->buffer_device_address) { - const vk::BufferDeviceAddressInfo addressInfo(buf->buffer); - buf->bda_addr = device->device.getBufferAddress(addressInfo); + if (device->pfn_vkGetBufferDeviceAddress){ + vk::BufferDeviceAddressInfo addressInfo(buf->buffer); + buf->bda_addr = device->pfn_vkGetBufferDeviceAddress(device->device, &static_cast(addressInfo)); + } else { + const vk::BufferDeviceAddressInfo addressInfo(buf->buffer); + buf->bda_addr = device->device.getBufferAddress(addressInfo); + } } device->memory_logger->log_allocation(buf, size); @@ -5268,6 +5275,15 @@ static vk_device ggml_vk_get_device(size_t idx) { device_create_info.setPNext(&device_features2); device->device = device->physical_device.createDevice(device_create_info); + if (!device_is_vulkan_12 && device->buffer_device_address) { + device->pfn_vkGetBufferDeviceAddress = (PFN_vkGetBufferDeviceAddressKHR) + vkGetDeviceProcAddr(device->device, "vkGetBufferDeviceAddressKHR"); + + if (!device->pfn_vkGetBufferDeviceAddress) { + throw std::runtime_error("Failed to load vkGetBufferDeviceAddressKHR"); + } + } + // Queues ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);