diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ec6611c2..49b268ea 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4483,6 +4483,9 @@ static vk_device ggml_vk_get_device(size_t idx) { device->physical_device = physical_devices[dev_num]; const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties(); + vk::PhysicalDeviceProperties device_props = device->physical_device.getProperties(); + const bool device_is_vulkan_12 = device_props.apiVersion >= VK_API_VERSION_1_2; + device->architecture = get_device_architecture(device->physical_device); const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY"); @@ -4509,6 +4512,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = false; device->shader_64b_indexing = false; bool bfloat16_support = false; + bool shader_float_controls_khr = false; for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { @@ -4559,6 +4563,8 @@ static vk_device ggml_vk_get_device(size_t idx) { } else if (strcmp("VK_EXT_shader_64bit_indexing", properties.extensionName) == 0) { device->shader_64b_indexing = true; #endif + } else if (strcmp("VK_KHR_shader_float_controls", properties.extensionName) == 0) { + shader_float_controls_khr = true; } } @@ -4571,6 +4577,7 @@ static vk_device ggml_vk_get_device(size_t idx) { vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; vk::PhysicalDeviceVulkan11Properties vk11_props; vk::PhysicalDeviceVulkan12Properties vk12_props; + vk::PhysicalDeviceFloatControlsProperties float_controls_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props; @@ -4578,10 +4585,21 @@ static vk_device ggml_vk_get_device(size_t idx) { props2.pNext = &props3; props3.pNext = &subgroup_props; subgroup_props.pNext = &driver_props; - driver_props.pNext = &vk11_props; - vk11_props.pNext = &vk12_props; - VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; + VkBaseOutStructure * last_struct; + + if (device_is_vulkan_12) { + driver_props.pNext = &vk11_props; + vk11_props.pNext = &vk12_props; + last_struct = (VkBaseOutStructure *)&vk12_props; + } else { + if (shader_float_controls_khr) { + driver_props.pNext = &float_controls_props; + last_struct = (VkBaseOutStructure *)&float_controls_props; + } else { + last_struct = (VkBaseOutStructure *)&driver_props; + } + } if (maintenance4_support) { last_struct->pNext = (VkBaseOutStructure *)&props4; @@ -4679,7 +4697,11 @@ static vk_device ggml_vk_get_device(size_t idx) { } else { device->shader_core_count = 0; } - device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + if (device_is_vulkan_12) { + device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + } else { + device->float_controls_rte_fp16 = shader_float_controls_khr ? float_controls_props.shaderRoundingModeRTEFloat16 : false; + } device->subgroup_basic = (subgroup_props.supportedStages & vk::ShaderStageFlagBits::eCompute) && (subgroup_props.supportedOperations & vk::SubgroupFeatureFlagBits::eBasic);