vulkan: Support asymmetric FA in coopmat2 path (llama/21753)
* vulkan: Support asymmetric FA in coopmat2 path

  There has been some recent interest/experimentation with mixed quantization types for FA. I had originally designed the cm2 FA shader with this in mind (because I didn't realize it wasn't supported at the time!); this change adds the missing pieces and enables it.

  Also support Q1_0, since people have been trying that out (seems crazy, but who knows).

  We should be able to do similar things in the coopmat1/scalar path, but there's another change open against the scalar path and I don't want to conflict.

* reorder cases
This commit is contained in:
parent
35cb684129
commit
95053f68e4
|
|
@ -440,10 +440,12 @@ struct vk_fa_pipeline_state {
|
|||
bool f32acc;
|
||||
uint32_t flags;
|
||||
uint32_t limit_occupancy_shmem;
|
||||
ggml_type k_type;
|
||||
ggml_type v_type;
|
||||
|
||||
bool operator<(const vk_fa_pipeline_state &b) const {
|
||||
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
|
||||
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
|
||||
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem, k_type, v_type) <
|
||||
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem, b.k_type, b.v_type);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -3041,7 +3043,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device
|
|||
return result;
|
||||
}
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
||||
GGML_UNUSED(n_kv);
|
||||
GGML_UNUSED(f32acc);
|
||||
|
||||
|
|
@ -3055,7 +3057,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
|
|||
if (small_rows) {
|
||||
result.block_rows = 32;
|
||||
result.block_cols = 32;
|
||||
} else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
|
||||
} else if (ggml_is_quantized(k_type) || ggml_is_quantized(v_type) || hsk >= 256 || hsv >= 256) {
|
||||
result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
|
||||
result.block_cols = 32;
|
||||
} else {
|
||||
|
|
@ -3069,7 +3071,13 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
|
|||
return result;
|
||||
}
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
||||
// Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1.
|
||||
if (k_type != v_type) {
|
||||
GGML_ASSERT(device->coopmat2);
|
||||
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
}
|
||||
|
||||
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
|
||||
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
||||
|
||||
|
|
@ -3081,7 +3089,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
|||
if (path == FA_COOPMAT1) {
|
||||
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
|
||||
(!f32acc && device->coopmat_support_16x16x16_f16acc);
|
||||
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
||||
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
|
||||
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
|
||||
|
||||
if (!shape_ok || !shmem_ok) {
|
||||
|
|
@ -3094,20 +3102,25 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
|||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
// Q1_0 K/V is only implemented on coopmat2 (flash_attn_cm2); there is no scalar FA shader for it.
|
||||
if ((k_type == GGML_TYPE_Q1_0 || v_type == GGML_TYPE_Q1_0) && device->coopmat2) {
|
||||
path = FA_COOPMAT2;
|
||||
}
|
||||
|
||||
switch (path) {
|
||||
case FA_SCALAR:
|
||||
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
||||
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
|
||||
case FA_COOPMAT1:
|
||||
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
||||
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
|
||||
case FA_COOPMAT2:
|
||||
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
||||
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
default:
|
||||
throw std::runtime_error("unsupported FaCodePath");
|
||||
}
|
||||
}
|
||||
|
||||
static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
|
||||
bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
|
||||
bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type) {
|
||||
const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
|
||||
(device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);
|
||||
|
||||
|
|
@ -3118,12 +3131,32 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const
|
|||
|
||||
const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
|
||||
|
||||
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
|
||||
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem, k_type, v_type};
|
||||
}
|
||||
|
||||
// Builds the ordered list of specialization-constant values for a flash-attention
// pipeline from its vk_fa_pipeline_state. In the labelled list below, the index
// of each entry matches the `/* N Name */` tag next to it; entries 12-15 carry
// the K and V ggml_type values and a per-type block byte size.
//
// NOTE(review): this listing is a diff rendering. The single-statement
// `return {state.workgroup_size, ...}` directly below appears to be the
// pre-change body shown next to its replacement — confirm against the tree
// before treating both returns as live code.
static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
|
||||
return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
|
||||
state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
|
||||
// Byte size of one decode block for type t: 16 for F32, ggml_type_size(t) otherwise.
const auto fa_block_bytes = [](ggml_type t) -> uint32_t {
|
||||
// decodeBufF32 uses a block of vec4s for a better memory access pattern.
|
||||
return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t);
|
||||
};
|
||||
// Order matters: each value is consumed by position, as tagged per entry.
return {
|
||||
/* 0 WorkGroupSize */ state.workgroup_size,
|
||||
/* 1 Br */ state.Br,
|
||||
/* 2 Bc */ state.Bc,
|
||||
/* 3 HSK */ state.HSK,
|
||||
/* 4 HSV */ state.HSV,
|
||||
/* 5 Clamp */ static_cast<uint32_t>(!state.aligned),
|
||||
/* 6 D_split */ state.D_split,
|
||||
/* 7 row_split */ state.row_split,
|
||||
/* 8 SubGroupSize */ state.subgroup_size,
|
||||
/* 9 SHMEM_STAGING */ state.shmem_staging ? 1u : 0u,
|
||||
/*10 Flags */ state.flags,
|
||||
/*11 LIMIT_OCCUPANCY_SHMEM */ state.limit_occupancy_shmem,
|
||||
/*12 FaTypeK */ static_cast<uint32_t>(state.k_type),
|
||||
/*13 FaTypeV */ static_cast<uint32_t>(state.v_type),
|
||||
/*14 FaBlockBytesK */ fa_block_bytes(state.k_type),
|
||||
/*15 FaBlockBytesV */ fa_block_bytes(state.v_type),
|
||||
};
|
||||
}
|
||||
|
||||
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
|
||||
|
|
@ -3578,16 +3611,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
}
|
||||
#endif
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
// CREATE_FA_CM2_MIXED: walks every per-type map in
// device->pipeline_flash_attn_f32_f16 (index 0 .. GGML_TYPE_COUNT-1) and, for
// each entry whose key has path == FA_COOPMAT2, calls ggml_vk_create_pipeline
// on that entry's pipeline.
//   - f32acc selects the flash_attn_f32_f16_mixed_cm2 binary, otherwise the
//     flash_attn_f32_f16_mixed_f16acc_cm2 binary is used.
//   - aligned selects the "..._aligned_..." pipeline name and passes Bc as the
//     argument after the specialization constants; unaligned passes 1 there
//     (presumably the alignment — confirm against ggml_vk_create_pipeline).
//   - Specialization constants always come from get_fa_spec_constants(fa.first).
#define CREATE_FA_CM2_MIXED() \
|
||||
for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \
|
||||
FaCodePath path = fa.first.path; \
|
||||
uint32_t Br = fa.first.Br; \
|
||||
uint32_t Bc = fa.first.Bc; \
|
||||
bool aligned = fa.first.aligned; \
|
||||
bool f32acc = fa.first.f32acc; \
|
||||
if (path == FA_COOPMAT2) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
if (device->coopmat2) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA_CM2_MIXED();
|
||||
}
|
||||
#undef CREATE_FA_CM2_MIXED
|
||||
#endif
|
||||
#undef CREATE_FA
|
||||
|
||||
|
|
@ -9042,8 +9094,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
|
||||
assert(dst->type == GGML_TYPE_F32);
|
||||
assert(q->type == GGML_TYPE_F32);
|
||||
assert(k->type == v->type);
|
||||
|
||||
uint32_t gqa_ratio = 1;
|
||||
uint32_t qk_ratio = neq2 / nek2;
|
||||
uint32_t workgroups_x = (uint32_t)neq1;
|
||||
|
|
@ -9054,7 +9104,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
|
||||
// For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
|
||||
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
||||
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
|
||||
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc);
|
||||
const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
|
||||
|
||||
if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
||||
|
|
@ -9067,7 +9117,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
workgroups_y /= gqa_ratio;
|
||||
}
|
||||
|
||||
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
|
||||
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc);
|
||||
|
||||
if (tuning_params.path != FA_COOPMAT2) {
|
||||
GGML_ASSERT(k->type == v->type);
|
||||
}
|
||||
|
||||
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
||||
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||
|
|
@ -9106,7 +9160,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
|
||||
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
|
||||
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
|
||||
mask != nullptr, use_mask_opt, logit_softcap != 0);
|
||||
mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type);
|
||||
|
||||
vk_pipeline pipeline = nullptr;
|
||||
|
||||
|
|
@ -15590,38 +15644,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
// It's straightforward to support different K/V dequant, but would
|
||||
// significantly increase the number of pipelines
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
// mismatching K/V type is currently supported for coopmat2 only.
|
||||
if (op->src[1]->type != op->src[2]->type && !coopmat2) {
|
||||
return false;
|
||||
}
|
||||
switch (op->src[1]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
// supported in scalar and coopmat2 paths
|
||||
break;
|
||||
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
|
||||
//case GGML_TYPE_Q2_K:
|
||||
//case GGML_TYPE_Q3_K:
|
||||
//case GGML_TYPE_Q4_K:
|
||||
//case GGML_TYPE_Q5_K:
|
||||
//case GGML_TYPE_Q6_K:
|
||||
//case GGML_TYPE_IQ1_S:
|
||||
//case GGML_TYPE_IQ1_M:
|
||||
//case GGML_TYPE_IQ2_XXS:
|
||||
//case GGML_TYPE_IQ2_XS:
|
||||
//case GGML_TYPE_IQ2_S:
|
||||
//case GGML_TYPE_IQ3_XXS:
|
||||
//case GGML_TYPE_IQ3_S:
|
||||
//case GGML_TYPE_IQ4_XS:
|
||||
|
||||
default:
|
||||
auto fa_kv_ok = [coopmat2](ggml_type t) {
|
||||
switch (t) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
case GGML_TYPE_Q1_0:
|
||||
return coopmat2;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
};
|
||||
if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) {
|
||||
return false;
|
||||
}
|
||||
if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,12 @@ layout (constant_id = 8) const uint32_t SubGroupSize = 32;
|
|||
layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0;
|
||||
layout (constant_id = 10) const uint32_t Flags = 0;
|
||||
layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
|
||||
// ggml_type enumerant for K/V
|
||||
layout (constant_id = 12) const uint32_t FaTypeK = 0;
|
||||
layout (constant_id = 13) const uint32_t FaTypeV = 0;
|
||||
// sizeof(decode buffer): quants -> ggml block size; F32 -> 16 (decodeBufF32 vec4).
|
||||
layout (constant_id = 14) const uint32_t FaBlockBytesK = 2;
|
||||
layout (constant_id = 15) const uint32_t FaBlockBytesV = 2;
|
||||
|
||||
const bool USE_MASK_OPT = (Flags & 1) != 0;
|
||||
const bool MASK_ENABLE = (Flags & 2) != 0;
|
||||
|
|
|
|||
|
|
@ -17,8 +17,57 @@
|
|||
#extension GL_EXT_null_initializer : enable
|
||||
|
||||
#include "types.glsl"
|
||||
#include "dequant_funcs_cm2.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
#include "dequant_funcs_cm2.glsl"
|
||||
|
||||
// Untyped buffer_reference views over K and V data, taken as the first
// parameter of faDecodeK / faDecodeV. Each struct is a plain byte array whose
// length is the FaBlockBytesK / FaBlockBytesV specialization constant, so
// buffer_reference stride = sizeof(struct) = FaBlockBytesK/V.
|
||||
layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K {
|
||||
uint8_t raw[FaBlockBytesK];
|
||||
};
|
||||
layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_V {
|
||||
uint8_t raw[FaBlockBytesV];
|
||||
};
|
||||
|
||||
// Elements per block for a K/V tensor whose ggml_type enumerant is `ty`.
// Returns 1 for GGML_TYPE_F16 (enumerant 1) and for any enumerant not listed.
uint fa_block_elems(uint ty) {
    uint elems = 1u; // GGML_TYPE_F16 and unlisted types
    switch (ty) {
        case 0u:  elems = 4u;                 break; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32)
        case 2u:  elems = uint(QUANT_K_Q4_0); break;
        case 3u:  elems = uint(QUANT_K_Q4_1); break;
        case 6u:  elems = uint(QUANT_K_Q5_0); break;
        case 7u:  elems = uint(QUANT_K_Q5_1); break;
        case 8u:  elems = uint(QUANT_K_Q8_0); break;
        case 41u: elems = uint(QUANT_K_Q1_0); break;
        default:                              break;
    }
    return elems;
}
|
||||
|
||||
// Decode function for K: dispatches on the FaTypeK specialization constant,
// reinterprets the opaque block reference as the matching decodeBuf* type and
// forwards to that type's dequantFunc*. Enumerants without a case here
// yield float16_t(0).
float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
    float16_t decoded = float16_t(0);
    switch (FaTypeK) {
        case 0u:  decoded = dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);   break;
        case 2u:  decoded = dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); break;
        case 3u:  decoded = dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); break;
        case 6u:  decoded = dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); break;
        case 7u:  decoded = dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); break;
        case 8u:  decoded = dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); break;
        case 41u: decoded = dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); break;
        default:                                                                              break;
    }
    return decoded;
}
|
||||
|
||||
// Decode function for V: dispatches on the FaTypeV specialization constant,
// reinterprets the opaque block reference as the matching decodeBuf* type and
// forwards to that type's dequantFunc*. Enumerants without a case here
// yield float16_t(0).
float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
    float16_t decoded = float16_t(0);
    switch (FaTypeV) {
        case 0u:  decoded = dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);   break;
        case 2u:  decoded = dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); break;
        case 3u:  decoded = dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); break;
        case 6u:  decoded = dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); break;
        case 7u:  decoded = dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); break;
        case 8u:  decoded = dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); break;
        case 41u: decoded = dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); break;
        default:                                                                              break;
    }
    return decoded;
}
|
||||
|
||||
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
||||
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
|
||||
|
|
@ -55,12 +104,6 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
|
|||
return max(elem0, elem1);
|
||||
}
|
||||
|
||||
#if BLOCK_SIZE > 1
|
||||
#define DECODEFUNC , DEQUANTFUNC
|
||||
#else
|
||||
#define DECODEFUNC
|
||||
#endif
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
|
|
@ -95,10 +138,6 @@ ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c
|
|||
}
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
init_indices();
|
||||
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
|
|
@ -107,10 +146,10 @@ void main() {
|
|||
|
||||
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
|
||||
|
||||
#if BLOCK_SIZE > 1
|
||||
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
|
||||
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
|
||||
#endif
|
||||
const uint bs_k = fa_block_elems(FaTypeK);
|
||||
const uint bs_v = fa_block_elems(FaTypeV);
|
||||
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, bs_k);
|
||||
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, bs_v);
|
||||
|
||||
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
|
||||
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
|
||||
|
|
@ -120,10 +159,12 @@ void main() {
|
|||
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
||||
{
|
||||
q_stride &= ~7;
|
||||
#if BLOCK_SIZE == 1
|
||||
k_stride &= ~7;
|
||||
v_stride &= ~7;
|
||||
#endif
|
||||
if (bs_k == 1u) {
|
||||
k_stride &= ~7;
|
||||
}
|
||||
if (bs_v == 1u) {
|
||||
v_stride &= ~7;
|
||||
}
|
||||
m_stride &= ~7;
|
||||
}
|
||||
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
|
||||
|
|
@ -230,7 +271,13 @@ void main() {
|
|||
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
|
||||
|
||||
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
|
||||
// F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128.
|
||||
const bool k_use_decode = (bs_k > 1u);
|
||||
if (k_use_decode) {
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK);
|
||||
} else {
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose);
|
||||
}
|
||||
S = coopMatMulAdd(Qf16, K_T, S);
|
||||
|
||||
if (LOGIT_SOFTCAP) {
|
||||
|
|
@ -291,7 +338,12 @@ void main() {
|
|||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
|
||||
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
|
||||
const bool v_use_decode = (bs_v > 1u);
|
||||
if (v_use_decode) {
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV);
|
||||
} else {
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad));
|
||||
}
|
||||
|
||||
L = eM*L + rowsum;
|
||||
|
||||
|
|
|
|||
|
|
@ -641,20 +641,17 @@ void process_shaders() {
|
|||
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
}
|
||||
|
||||
if (fp16) {
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "bf16") continue;
|
||||
|
||||
if (fp16) {
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
|
||||
} else {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
|
||||
}
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
|
|
|
|||
Loading…
Reference in New Issue