vulkan: add `v_dot2_f32_f16` support in matrix-matrix multiplication and Flash Attention (llama/24123)
* vulkan: add support for valve fp16 dot2 extension * use macro for dot2 path choice * properly check for the feature * add dot_product abstraction to reduce preprocessor branching
This commit is contained in:
parent
28c7ed3db7
commit
686bc802d1
|
|
@ -113,6 +113,21 @@ typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
|
|||
} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
|
||||
#endif
|
||||
|
||||
#if !defined(VK_VALVE_shader_mixed_float_dot_product)
|
||||
#define VK_VALVE_shader_mixed_float_dot_product 1
|
||||
#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_SPEC_VERSION 1
|
||||
#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_EXTENSION_NAME "VK_VALVE_shader_mixed_float_dot_product"
|
||||
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE ((VkStructureType)1000673000)
|
||||
typedef struct VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE {
|
||||
VkStructureType sType;
|
||||
void* pNext;
|
||||
VkBool32 shaderMixedFloatDotProductFloat16AccFloat32;
|
||||
VkBool32 shaderMixedFloatDotProductFloat16AccFloat16;
|
||||
VkBool32 shaderMixedFloatDotProductBFloat16Acc;
|
||||
VkBool32 shaderMixedFloatDotProductFloat8AccFloat32;
|
||||
} VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE;
|
||||
#endif
|
||||
|
||||
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
|
||||
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
||||
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
||||
|
|
@ -705,6 +720,8 @@ struct vk_device_struct {
|
|||
bool coopmat2_bf16_support {};
|
||||
bool coopmat2_decode_vector;
|
||||
|
||||
bool dot2_f16 {};
|
||||
|
||||
bool pipeline_executable_properties_support {};
|
||||
|
||||
size_t idx;
|
||||
|
|
@ -3920,8 +3937,13 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
|||
name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
|
||||
} else {
|
||||
if (device->fp16) {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; }
|
||||
if (device->dot2_f16) {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_dot2_data; spv_size = flash_attn_f32_f16_dot2_len; }
|
||||
else { spv_data = flash_attn_f32_f16_dot2_f16acc_data; spv_size = flash_attn_f32_f16_dot2_f16acc_len; }
|
||||
} else {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; }
|
||||
}
|
||||
} else {
|
||||
spv_data = flash_attn_f32_f16_fp32_data;
|
||||
spv_size = flash_attn_f32_f16_fp32_len;
|
||||
|
|
@ -4215,7 +4237,23 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
|||
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->fp16) {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
// bf16 scalar path promotes to f32, no dot2 variant
|
||||
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
|
|
@ -4250,7 +4288,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
|||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
|
||||
|
||||
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
|
||||
|
||||
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
|
|
@ -4258,7 +4296,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
|||
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
|
||||
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
||||
|
|
@ -4298,8 +4335,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
|||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
|
||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
|
||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
|
||||
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
|
||||
|
||||
CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
|
||||
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
||||
|
|
@ -4344,8 +4380,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
|||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
|
||||
CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
||||
|
|
@ -4390,6 +4425,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
|||
#undef CREATE_MM2
|
||||
#undef CREATE_MMQ
|
||||
#undef CREATE_MM
|
||||
#undef CREATE_MM_NODOT2
|
||||
} else {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
|
|
@ -5453,6 +5489,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
device->integer_dot_product = false;
|
||||
device->shader_64b_indexing = false;
|
||||
bool bfloat16_support = false;
|
||||
bool dot2_f16_support = false;
|
||||
|
||||
for (const auto& properties : ext_props) {
|
||||
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
||||
|
|
@ -5495,6 +5532,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
||||
bfloat16_support = true;
|
||||
#endif
|
||||
} else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_DOT2")) {
|
||||
dot2_f16_support = true;
|
||||
} else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
|
||||
pipeline_executable_properties_support = true;
|
||||
} else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
|
||||
|
|
@ -5802,6 +5842,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
device_extensions.push_back("VK_KHR_shader_integer_dot_product");
|
||||
}
|
||||
|
||||
VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {};
|
||||
dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE;
|
||||
if (dot2_f16_support) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&dot2_features;
|
||||
last_struct = (VkBaseOutStructure *)&dot2_features;
|
||||
device_extensions.push_back("VK_VALVE_shader_mixed_float_dot_product");
|
||||
}
|
||||
|
||||
VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {};
|
||||
pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR;
|
||||
if (pipeline_executable_properties_support) {
|
||||
|
|
@ -5836,6 +5884,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
device->bf16 = false;
|
||||
#endif
|
||||
|
||||
device->dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32;
|
||||
|
||||
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
|
||||
|
||||
device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
|
||||
|
|
@ -6250,6 +6300,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|||
bool coopmat2_decode_vector_support = false;
|
||||
bool integer_dot_product = false;
|
||||
bool bfloat16_support = false;
|
||||
bool dot2_f16_support = false;
|
||||
|
||||
for (auto properties : ext_props) {
|
||||
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
|
||||
|
|
@ -6279,6 +6330,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|||
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
||||
bfloat16_support = true;
|
||||
#endif
|
||||
} else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_DOT2")) {
|
||||
dot2_f16_support = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -6369,6 +6423,13 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|||
last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
||||
}
|
||||
|
||||
VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {};
|
||||
dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE;
|
||||
if (dot2_f16_support) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&dot2_features;
|
||||
last_struct = (VkBaseOutStructure *)&dot2_features;
|
||||
}
|
||||
|
||||
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
|
||||
|
||||
fp16 = fp16 && vk12_features.shaderFloat16;
|
||||
|
|
@ -6415,9 +6476,12 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|||
: coopmat_support ? "KHR_coopmat"
|
||||
: "none";
|
||||
|
||||
bool dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32;
|
||||
const char *fp16_str = fp16 ? (dot2_f16 ? "dot2" : "1") : "0";
|
||||
|
||||
std::string device_name = props2.properties.deviceName.data();
|
||||
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
|
||||
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
|
||||
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %s | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
|
||||
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16_str, bf16, subgroup_size,
|
||||
props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
|
||||
|
||||
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
#ifdef DOT2_F16
|
||||
#extension GL_EXT_spirv_intrinsics : require
|
||||
|
||||
spirv_instruction(extensions = ["SPV_VALVE_mixed_float_dot_product"],
|
||||
capabilities = [6912], id = 6916)
|
||||
float v_dot2_f32_f16(f16vec2 a, f16vec2 b, float acc);
|
||||
|
||||
ACC_TYPE dot_product(f16vec4 a, f16vec4 b, ACC_TYPE acc) {
|
||||
return ACC_TYPE(v_dot2_f32_f16(a.zw, b.zw, v_dot2_f32_f16(a.xy, b.xy, float(acc))));
|
||||
}
|
||||
|
||||
ACC_TYPE dot_product(f16vec2 a, f16vec2 b, ACC_TYPE acc) {
|
||||
return ACC_TYPE(v_dot2_f32_f16(a, b, float(acc)));
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
ACC_TYPE dot_product(FLOAT_TYPEV4 a, FLOAT_TYPEV4 b, ACC_TYPE acc) {
|
||||
return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y),
|
||||
fma(ACC_TYPE(a.z), ACC_TYPE(b.z), fma(ACC_TYPE(a.w), ACC_TYPE(b.w), acc))));
|
||||
}
|
||||
|
||||
ACC_TYPE dot_product(FLOAT_TYPEV2 a, FLOAT_TYPEV2 b, ACC_TYPE acc) {
|
||||
return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y), acc));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
@ -21,6 +21,7 @@
|
|||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
|
||||
#include "types.glsl"
|
||||
#include "dot_product_funcs.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
#include "flash_attn_dequant.glsl"
|
||||
|
||||
|
|
@ -318,7 +319,7 @@ void main() {
|
|||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
|
||||
Sf[r][c] = dot_product(Q_cache[r], K_Tf, Sf[r][c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -341,7 +342,7 @@ void main() {
|
|||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
|
||||
Sf[r][c] = dot_product(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf, Sf[r][c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@
|
|||
#endif
|
||||
|
||||
#include "types.glsl"
|
||||
#include "dot_product_funcs.glsl"
|
||||
|
||||
#ifndef LOAD_VEC_A
|
||||
#define LOAD_VEC_A 1
|
||||
|
|
@ -329,15 +330,8 @@ void main() {
|
|||
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
|
||||
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
|
||||
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y),
|
||||
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x))));
|
||||
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y),
|
||||
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y))));
|
||||
#else
|
||||
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
|
||||
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
|
||||
#endif
|
||||
sums[sums_idx].x = dot_product(cache_a[wsir * TM + 2 * cr ], cache_b, sums[sums_idx].x);
|
||||
sums[sums_idx].y = dot_product(cache_a[wsir * TM + 2 * cr + 1], cache_b, sums[sums_idx].y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -336,7 +336,8 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
|
|||
// disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734
|
||||
// disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
|
||||
// disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
|
||||
if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) {
|
||||
// disable spirv-opt for dot2 shaders (spirv-opt doesn't recognize SPV_VALVE_mixed_float_dot_product capability)
|
||||
if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos && name.find("_dot2") == std::string::npos) {
|
||||
cmd.push_back("-O");
|
||||
}
|
||||
|
||||
|
|
@ -427,10 +428,11 @@ void string_to_spv(std::string name, const std::string& source, const std::map<s
|
|||
generate_dep_file = false;
|
||||
}
|
||||
|
||||
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
|
||||
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc, bool dot2 = false) {
|
||||
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
|
||||
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
|
||||
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
||||
std::string dot2_sfx = dot2 ? "_dot2" : "";
|
||||
|
||||
std::map<std::string, std::string> base_dict;
|
||||
std::string shader_name = "matmul";
|
||||
|
|
@ -463,6 +465,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
|||
}
|
||||
#endif
|
||||
|
||||
if (dot2) {
|
||||
base_dict["DOT2_F16"] = "1";
|
||||
}
|
||||
|
||||
const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
|
||||
|
||||
auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string {
|
||||
|
|
@ -528,11 +534,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
|||
};
|
||||
|
||||
// Shaders with f16 B_TYPE
|
||||
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
|
||||
string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
|
||||
// bf16
|
||||
{
|
||||
|
|
@ -553,8 +559,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
|||
if (!(coopmat || coopmat2))
|
||||
#endif
|
||||
{
|
||||
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
if (!dot2) {
|
||||
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -584,18 +592,18 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
|||
|
||||
// don't generate f32 variants for coopmat2
|
||||
if (!coopmat2) {
|
||||
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
if (tname != "f16" && tname != "f32") {
|
||||
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
// Integer dot mmq performs better with f32 accumulators
|
||||
if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
|
||||
// Integer dot mmq performs better with f32 accumulators (different shader, skip for dot2)
|
||||
if (!f16acc && !coopmat && !coopmat2 && !dot2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
|
||||
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
#endif
|
||||
|
|
@ -613,6 +621,10 @@ void process_shaders() {
|
|||
matmul_shaders(true, matmul_id_type, false, false, false);
|
||||
matmul_shaders(true, matmul_id_type, false, false, true);
|
||||
|
||||
// dot2 variants (scalar fp16 only)
|
||||
matmul_shaders(true, matmul_id_type, false, false, false, true);
|
||||
matmul_shaders(true, matmul_id_type, false, false, true, true);
|
||||
|
||||
if (matmul_id_type != MatMulIdType::DEFAULT) {
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
// Coopmat, fp32acc and fp16acc
|
||||
|
|
@ -660,6 +672,12 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
|
||||
|
||||
if (fp16) {
|
||||
string_to_spv("flash_attn_f32_f16_dot2", "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DOT2_F16", "1"}}), fp16, false, false, f16acc);
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8");
|
||||
|
|
|
|||
Loading…
Reference in New Issue