vulkan: Support asymmetric FA in coopmat2 path (llama/21753)

* vulkan: Support asymmetric FA in coopmat2 path

There has been some recent interest/experimentation with mixed quantization
types for FA. I had originally designed the cm2 FA shader with this in mind
(because I didn't realize it wasn't supported at the time!), this change
adds the missing pieces and enables it.

Also support Q1_0 since people have been trying that out (seems crazy, but
who knows).

We should be able to do similar things in the coopmat1/scalar path, but
there's another change open against the scalar path and I don't want to
conflict.

* reorder cases
This commit is contained in:
Jeff Bolz 2026-05-01 15:28:32 +02:00 committed by Georgi Gerganov
parent 35cb684129
commit 95053f68e4
4 changed files with 185 additions and 87 deletions

View File

@ -440,10 +440,12 @@ struct vk_fa_pipeline_state {
bool f32acc;
uint32_t flags;
uint32_t limit_occupancy_shmem;
ggml_type k_type;
ggml_type v_type;
bool operator<(const vk_fa_pipeline_state &b) const {
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem, k_type, v_type) <
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem, b.k_type, b.v_type);
}
};
@ -3041,7 +3043,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device
return result;
}
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
GGML_UNUSED(n_kv);
GGML_UNUSED(f32acc);
@ -3055,7 +3057,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
if (small_rows) {
result.block_rows = 32;
result.block_cols = 32;
} else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
} else if (ggml_is_quantized(k_type) || ggml_is_quantized(v_type) || hsk >= 256 || hsv >= 256) {
result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
result.block_cols = 32;
} else {
@ -3069,7 +3071,13 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
return result;
}
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
// Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1.
if (k_type != v_type) {
GGML_ASSERT(device->coopmat2);
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
}
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
@ -3081,7 +3089,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
if (path == FA_COOPMAT1) {
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
(!f32acc && device->coopmat_support_16x16x16_f16acc);
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
if (!shape_ok || !shmem_ok) {
@ -3094,20 +3102,25 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
path = FA_SCALAR;
}
// Q1_0 K/V is only implemented on coopmat2 (flash_attn_cm2); there is no scalar FA shader for it.
if ((k_type == GGML_TYPE_Q1_0 || v_type == GGML_TYPE_Q1_0) && device->coopmat2) {
path = FA_COOPMAT2;
}
switch (path) {
case FA_SCALAR:
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
case FA_COOPMAT1:
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
case FA_COOPMAT2:
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
default:
throw std::runtime_error("unsupported FaCodePath");
}
}
static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type) {
const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
(device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);
@ -3118,12 +3131,32 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const
const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem, k_type, v_type};
}
static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
const auto fa_block_bytes = [](ggml_type t) -> uint32_t {
// decodeBufF32 uses a block of vec4s for a better memory access pattern.
return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t);
};
return {
/* 0 WorkGroupSize */ state.workgroup_size,
/* 1 Br */ state.Br,
/* 2 Bc */ state.Bc,
/* 3 HSK */ state.HSK,
/* 4 HSV */ state.HSV,
/* 5 Clamp */ static_cast<uint32_t>(!state.aligned),
/* 6 D_split */ state.D_split,
/* 7 row_split */ state.row_split,
/* 8 SubGroupSize */ state.subgroup_size,
/* 9 SHMEM_STAGING */ state.shmem_staging ? 1u : 0u,
/*10 Flags */ state.flags,
/*11 LIMIT_OCCUPANCY_SHMEM */ state.limit_occupancy_shmem,
/*12 FaTypeK */ static_cast<uint32_t>(state.k_type),
/*13 FaTypeV */ static_cast<uint32_t>(state.v_type),
/*14 FaBlockBytesK */ fa_block_bytes(state.k_type),
/*15 FaBlockBytesV */ fa_block_bytes(state.v_type),
};
}
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@ -3578,16 +3611,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
// Create the type-generic "mixed" coopmat2 FA pipelines. Unlike the per-type
// pipelines built by CREATE_FA, the mixed shader is compiled once and reads the
// K/V ggml_type from specialization constants (FaTypeK/FaTypeV, set in
// get_fa_spec_constants), so a single SPIR-V module covers every K/V type
// combination that asymmetric FA needs.
// Four variants are selected by the pipeline-state key:
//   aligned vs. unaligned (aligned passes Bc as the align value, else 1) and
//   f32acc vs. f16acc (distinct SPIR-V blobs).
// NOTE(review): this walks every FA_COOPMAT2 entry, including ones whose
// k_type == v_type that CREATE_FA above may already have populated with a
// type-specialized shader — confirm those land in distinct map entries, or
// this will overwrite them with the slower generic pipeline.
#define CREATE_FA_CM2_MIXED() \
for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \
for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \
FaCodePath path = fa.first.path; \
uint32_t Br = fa.first.Br; \
uint32_t Bc = fa.first.Bc; \
bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \
if (path == FA_COOPMAT2) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
} \
} \
} \
} \
}
if (device->coopmat2) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
CREATE_FA_CM2_MIXED();
}
#undef CREATE_FA_CM2_MIXED
#endif
#undef CREATE_FA
@ -9042,8 +9094,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
assert(dst->type == GGML_TYPE_F32);
assert(q->type == GGML_TYPE_F32);
assert(k->type == v->type);
uint32_t gqa_ratio = 1;
uint32_t qk_ratio = neq2 / nek2;
uint32_t workgroups_x = (uint32_t)neq1;
@ -9054,7 +9104,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc);
const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
@ -9067,7 +9117,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
workgroups_y /= gqa_ratio;
}
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc);
if (tuning_params.path != FA_COOPMAT2) {
GGML_ASSERT(k->type == v->type);
}
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
@ -9106,7 +9160,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
mask != nullptr, use_mask_opt, logit_softcap != 0);
mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type);
vk_pipeline pipeline = nullptr;
@ -15590,38 +15644,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false;
}
// It's straightforward to support different K/V dequant, but would
// significantly increase the number of pipelines
if (op->src[1]->type != op->src[2]->type) {
// mismatching K/V type is currently supported for coopmat2 only.
if (op->src[1]->type != op->src[2]->type && !coopmat2) {
return false;
}
switch (op->src[1]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_IQ4_NL:
// supported in scalar and coopmat2 paths
break;
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
//case GGML_TYPE_Q2_K:
//case GGML_TYPE_Q3_K:
//case GGML_TYPE_Q4_K:
//case GGML_TYPE_Q5_K:
//case GGML_TYPE_Q6_K:
//case GGML_TYPE_IQ1_S:
//case GGML_TYPE_IQ1_M:
//case GGML_TYPE_IQ2_XXS:
//case GGML_TYPE_IQ2_XS:
//case GGML_TYPE_IQ2_S:
//case GGML_TYPE_IQ3_XXS:
//case GGML_TYPE_IQ3_S:
//case GGML_TYPE_IQ4_XS:
default:
auto fa_kv_ok = [coopmat2](ggml_type t) {
switch (t) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_0:
return true;
case GGML_TYPE_Q1_0:
return coopmat2;
default:
return false;
}
};
if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) {
return false;
}
if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {

View File

@ -13,6 +13,12 @@ layout (constant_id = 8) const uint32_t SubGroupSize = 32;
layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0;
layout (constant_id = 10) const uint32_t Flags = 0;
layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
// ggml_type enumerant for K/V
layout (constant_id = 12) const uint32_t FaTypeK = 0;
layout (constant_id = 13) const uint32_t FaTypeV = 0;
// sizeof(decode buffer): quants -> ggml block size; F32 -> 16 (decodeBufF32 vec4).
layout (constant_id = 14) const uint32_t FaBlockBytesK = 2;
layout (constant_id = 15) const uint32_t FaBlockBytesV = 2;
const bool USE_MASK_OPT = (Flags & 1) != 0;
const bool MASK_ENABLE = (Flags & 2) != 0;

View File

@ -17,8 +17,57 @@
#extension GL_EXT_null_initializer : enable
#include "types.glsl"
#include "dequant_funcs_cm2.glsl"
#include "flash_attn_base.glsl"
#include "dequant_funcs_cm2.glsl"
// buffer_reference stride = sizeof(struct) = FaBlockBytesK/V.
layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K {
uint8_t raw[FaBlockBytesK];
};
layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_V {
uint8_t raw[FaBlockBytesV];
};
// Elements per decode block for a ggml_type enumerant; mirrors the host-side
// FaBlockBytesK/V computation in get_fa_spec_constants. F32 is grouped as
// vec4 blocks to match decodeBufF32/dequantFuncF32. Types without a decode
// case here (they are rejected by supports_op for mixed K/V) fall back to 1.
uint fa_block_elems(uint ty) {
    if (ty == 0u)  { return 4u; }                  // GGML_TYPE_F32: vec4 block
    if (ty == 1u)  { return 1u; }                  // GGML_TYPE_F16: scalar
    if (ty == 2u)  { return uint(QUANT_K_Q4_0); }  // GGML_TYPE_Q4_0
    if (ty == 3u)  { return uint(QUANT_K_Q4_1); }  // GGML_TYPE_Q4_1
    if (ty == 6u)  { return uint(QUANT_K_Q5_0); }  // GGML_TYPE_Q5_0
    if (ty == 7u)  { return uint(QUANT_K_Q5_1); }  // GGML_TYPE_Q5_1
    if (ty == 8u)  { return uint(QUANT_K_Q8_0); }  // GGML_TYPE_Q8_0
    if (ty == 41u) { return uint(QUANT_K_Q1_0); }  // GGML_TYPE_Q1_0
    return 1u;
}
// Decode one element of K for coopMatLoadTensorNV. FaTypeK is a specialization
// constant, so each compiled pipeline folds this dispatch down to the single
// matching branch. F16 (type 1) never reaches here: it loads directly without
// a decode callback (bs_k == 1). Unsupported types decode as zero.
float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
    if (FaTypeK == 0u)  return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock);
    if (FaTypeK == 2u)  return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
    if (FaTypeK == 3u)  return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
    if (FaTypeK == 6u)  return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
    if (FaTypeK == 7u)  return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
    if (FaTypeK == 8u)  return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
    if (FaTypeK == 41u) return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
    return float16_t(0);
}
// Decode one element of V for coopMatLoadTensorNV; mirrors faDecodeK but
// dispatches on FaTypeV so K and V may use different quantization types.
// F16 (type 2) uses the direct-load path (bs_v == 1) and never calls this.
float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
    if (FaTypeV == 0u)  return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock);
    if (FaTypeV == 2u)  return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
    if (FaTypeV == 3u)  return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
    if (FaTypeV == 6u)  return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
    if (FaTypeV == 7u)  return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
    if (FaTypeV == 8u)  return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
    if (FaTypeV == 41u) return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
    return float16_t(0);
}
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
@ -55,12 +104,6 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
return max(elem0, elem1);
}
#if BLOCK_SIZE > 1
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
#endif
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
@ -95,10 +138,6 @@ ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c
}
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
init_indices();
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
@ -107,10 +146,10 @@ void main() {
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if BLOCK_SIZE > 1
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif
const uint bs_k = fa_block_elems(FaTypeK);
const uint bs_v = fa_block_elems(FaTypeV);
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, bs_k);
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, bs_v);
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
@ -120,10 +159,12 @@ void main() {
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
{
q_stride &= ~7;
#if BLOCK_SIZE == 1
k_stride &= ~7;
v_stride &= ~7;
#endif
if (bs_k == 1u) {
k_stride &= ~7;
}
if (bs_v == 1u) {
v_stride &= ~7;
}
m_stride &= ~7;
}
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
@ -230,7 +271,13 @@ void main() {
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
// F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128.
const bool k_use_decode = (bs_k > 1u);
if (k_use_decode) {
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK);
} else {
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose);
}
S = coopMatMulAdd(Qf16, K_T, S);
if (LOGIT_SOFTCAP) {
@ -291,7 +338,12 @@ void main() {
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
const bool v_use_decode = (bs_v > 1u);
if (v_use_decode) {
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV);
} else {
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad));
}
L = eM*L + rowsum;

View File

@ -641,20 +641,17 @@ void process_shaders() {
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
}
if (fp16) {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
#endif
}
for (const auto& tname : type_names) {
if (tname == "bf16") continue;
if (fp16) {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
} else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",