ggml-webgpu: Add MMVQ path for Q4/Q8/Q2_K/Q4_K and clean up legacy MUL_MAT pipeline (llama/23594)
* ggml-webgpu: Add MMVQ path for Q4/Q8/Q2_K/Q4_K * Fix to editorconfig checking pass * Remove mul-mat-legacy pipeline * Fix to use vendor name as is and add dot_product/vendor to shader_lib_ctx
This commit is contained in:
parent
bc77933c2d
commit
00a5110b19
|
|
@ -52,7 +52,7 @@
|
|||
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
|
||||
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4
|
||||
|
||||
// default size for legacy matrix multiplication
|
||||
// default size for reg-tile matrix multiplication
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE 256
|
||||
|
||||
// Same hash combine function as in boost
|
||||
|
|
@ -93,6 +93,8 @@ struct ggml_webgpu_shader_lib_context {
|
|||
uint32_t sg_mat_k = 0;
|
||||
uint32_t min_subgroup_size = 0;
|
||||
uint32_t max_subgroup_size = 0;
|
||||
bool supports_dot_product = false;
|
||||
std::string vendor;
|
||||
};
|
||||
|
||||
struct webgpu_pipeline {
|
||||
|
|
@ -850,31 +852,15 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
|||
|
||||
/** Matrix Multiplication **/
|
||||
|
||||
struct ggml_webgpu_legacy_mul_mat_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
|
||||
bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_mul_mat_vec_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
int vectorized;
|
||||
bool use_mmvq;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized;
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
|
||||
use_mmvq == other.use_mmvq;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -884,6 +870,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
|
|||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
ggml_webgpu_hash_combine(seed, key.use_mmvq);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
|
@ -894,6 +881,20 @@ struct ggml_webgpu_mul_mat_vec_shader_decisions {
|
|||
uint32_t vec_size;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_quantize_q8_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
|
||||
bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & other) const { return src0_type == other.src0_type; }
|
||||
};
|
||||
|
||||
struct ggml_webgpu_quantize_q8_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_mul_mat_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
|
|
@ -1051,6 +1052,36 @@ struct ggml_webgpu_soft_max_pipeline_key_hash {
|
|||
}
|
||||
};
|
||||
|
||||
/** MMVQ **/
|
||||
|
||||
inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
|
||||
const ggml_tensor * src1,
|
||||
bool supports_dot_product,
|
||||
const std::string & vendor) {
|
||||
if (src1->ne[1] == 1) {
|
||||
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
|
||||
if (supports_dp4a && supports_dot_product) {
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
return src0->ne[0] % 4 == 0;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
class ggml_webgpu_shader_lib {
|
||||
wgpu::Device device;
|
||||
pre_wgsl::Preprocessor preprocessor;
|
||||
|
|
@ -1099,14 +1130,12 @@ class ggml_webgpu_shader_lib {
|
|||
webgpu_pipeline,
|
||||
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
|
||||
flash_attn_blk_pipelines;
|
||||
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
|
||||
mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec)
|
||||
std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
|
||||
mul_mat_vec_pipelines; // fast mat-vec (n==1)
|
||||
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
|
||||
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
|
||||
std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash>
|
||||
quantize_q8_pipelines;
|
||||
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
|
||||
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
|
||||
mul_mat_id_pipelines; // src0_type/src1_type
|
||||
|
|
@ -1631,7 +1660,7 @@ class ggml_webgpu_shader_lib {
|
|||
key.type = context.dst->type;
|
||||
key.d_state = (int) context.src0->ne[0];
|
||||
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
|
||||
ggml_webgpu_tensor_overlap(context.src1, context.src5);
|
||||
ggml_webgpu_tensor_overlap(context.src1, context.src5);
|
||||
|
||||
auto it = ssm_scan_pipelines.find(key);
|
||||
if (it != ssm_scan_pipelines.end()) {
|
||||
|
|
@ -1744,6 +1773,44 @@ class ggml_webgpu_shader_lib {
|
|||
return pad_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_quantize_q8_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
|
||||
auto it = quantize_q8_pipelines.find(key);
|
||||
if (it != quantize_q8_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
const char * shader_src = wgsl_quantize_q8;
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "quantize_q8";
|
||||
|
||||
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
||||
|
||||
defines.push_back("SRC1_INNER_TYPE=f32");
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
|
||||
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
||||
std::string src0_name = src0_traits->type_name;
|
||||
std::string type_upper = src0_name;
|
||||
variant += "_" + src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
defines.push_back("Q8_1_T");
|
||||
|
||||
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
|
||||
variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
|
||||
|
||||
auto processed = preprocessor.preprocess(shader_src, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = wg_size;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
quantize_q8_pipelines[key] = pipeline;
|
||||
return quantize_q8_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
|
|
@ -1752,6 +1819,8 @@ class ggml_webgpu_shader_lib {
|
|||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.use_mmvq =
|
||||
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
|
||||
|
||||
auto it = mul_mat_vec_pipelines.find(key);
|
||||
if (it != mul_mat_vec_pipelines.end()) {
|
||||
|
|
@ -1788,6 +1857,19 @@ class ggml_webgpu_shader_lib {
|
|||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
defines.push_back("SRC0_INNER_TYPE=u32");
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
if (key.use_mmvq) {
|
||||
defines.push_back("LEGACY_QUANTS");
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
if (key.use_mmvq) {
|
||||
defines.push_back("K_QUANTS");
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
|
|
@ -1840,6 +1922,11 @@ class ggml_webgpu_shader_lib {
|
|||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
||||
}
|
||||
|
||||
if (key.use_mmvq) {
|
||||
defines.push_back("MMVQ");
|
||||
defines.push_back("Q8_1_T");
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
|
||||
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
|
||||
|
|
@ -2018,100 +2105,6 @@ class ggml_webgpu_shader_lib {
|
|||
return mul_mat_fast_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_legacy_mul_mat_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
|
||||
auto it = mul_mat_legacy_pipelines.find(key);
|
||||
if (it != mul_mat_legacy_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "mul_mat";
|
||||
|
||||
switch (context.src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC1_TYPE=f32");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC1_TYPE=f16");
|
||||
variant += "_f16";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported src1 type for mul_mat legacy shader");
|
||||
}
|
||||
|
||||
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
||||
const char * src0_name = src0_traits->type_name;
|
||||
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC0_TYPE=f32");
|
||||
defines.push_back("FLOAT");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC0_TYPE=f16");
|
||||
defines.push_back("FLOAT");
|
||||
variant += "_f16";
|
||||
break;
|
||||
default:
|
||||
{
|
||||
std::string type_upper = src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
{
|
||||
// Quantized types using u32 buffers for portability.
|
||||
defines.push_back("SRC0_TYPE=u32");
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back(type_upper + "_T");
|
||||
defines.push_back(type_upper);
|
||||
defines.push_back(type_upper + "_SCALE_MIN");
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
|
||||
variant += std::string("_") + src0_name;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_mul_mat, defines);
|
||||
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE;
|
||||
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
mul_mat_legacy_pipelines[key] = pipeline;
|
||||
return mul_mat_legacy_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
auto it = mul_mat_id_gather_pipelines.find(1);
|
||||
if (it != mul_mat_id_gather_pipelines.end()) {
|
||||
|
|
|
|||
|
|
@ -181,6 +181,7 @@ struct webgpu_capabilities {
|
|||
wgpu::Limits limits;
|
||||
bool supports_subgroups = false;
|
||||
bool supports_subgroup_matrix = false;
|
||||
bool supports_dot_product = false;
|
||||
|
||||
uint32_t sg_mat_m = 0;
|
||||
uint32_t sg_mat_n = 0;
|
||||
|
|
@ -210,6 +211,8 @@ struct webgpu_global_context_struct {
|
|||
wgpu::Buffer memset_params_buf;
|
||||
webgpu_pipeline memset_pipeline;
|
||||
|
||||
std::string vendor;
|
||||
|
||||
// TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050
|
||||
#ifdef GGML_WEBGPU_CPU_PROFILE
|
||||
// Profiling: labeled CPU time in ms (total)
|
||||
|
|
@ -1394,6 +1397,58 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx,
|
|||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
std::vector<webgpu_dispatch_desc> & dispatches) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
|
||||
|
||||
webgpu_pipeline qq8_pipeline = ctx->shader_lib->get_quantize_q8_pipeline(shader_lib_ctx);
|
||||
|
||||
// quantize_q8 pipeline
|
||||
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
|
||||
const size_t q8_src1_align_offset = ROUNDUP_POW2(
|
||||
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t q8_src1_binding_size =
|
||||
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
|
||||
std::vector<uint32_t> q8_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[2],
|
||||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> q8_entries = {
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1),
|
||||
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), q8_src1_align_offset, q8_src1_binding_size)
|
||||
};
|
||||
|
||||
auto q8_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(qq8_pipeline.context.get());
|
||||
|
||||
uint32_t q8_wg_size = q8_decisions->wg_size;
|
||||
uint32_t q8_wg_x = 1;
|
||||
uint32_t q8_wg_y = 1;
|
||||
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
|
||||
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
|
||||
|
||||
dispatches.push_back({
|
||||
qq8_pipeline, std::move(q8_params), std::move(q8_entries), { q8_wg_x, q8_wg_y }
|
||||
});
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
|
|
@ -1401,47 +1456,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
|||
// Determine if this is a mat-vec operation
|
||||
bool is_vec = (dst->ne[1] == 1);
|
||||
|
||||
// Determine if we should use fast path
|
||||
bool use_fast = false;
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16:
|
||||
use_fast = (src0->type == GGML_TYPE_F16);
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
// TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_MXFP4:
|
||||
use_fast = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
// use MMVQ path for mat-vec
|
||||
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
|
||||
ctx->global_ctx->vendor);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
|
||||
|
|
@ -1456,16 +1473,20 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
|||
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
||||
shader_lib_ctx.supports_dot_product = ctx->global_ctx->capabilities.supports_dot_product;
|
||||
shader_lib_ctx.vendor = ctx->global_ctx->vendor;
|
||||
|
||||
// Get or create pipeline
|
||||
webgpu_pipeline pipeline;
|
||||
webgpu_pipeline pipeline;
|
||||
std::vector<webgpu_dispatch_desc> dispatches;
|
||||
|
||||
if (use_fast && is_vec) {
|
||||
if (is_vec) {
|
||||
if (use_mmvq) {
|
||||
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
|
||||
}
|
||||
pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);
|
||||
} else if (use_fast) {
|
||||
pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
|
||||
} else {
|
||||
pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);
|
||||
pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
|
||||
}
|
||||
|
||||
// Build params
|
||||
|
|
@ -1489,25 +1510,31 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
|||
};
|
||||
|
||||
// Build bind group entries
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
|
||||
};
|
||||
std::vector<wgpu::BindGroupEntry> entries = {};
|
||||
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0));
|
||||
if (use_mmvq) {
|
||||
auto & mmvq_qq8_entry = dispatches[0].bind_group_entries[1];
|
||||
entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), mmvq_qq8_entry.offset,
|
||||
mmvq_qq8_entry.size));
|
||||
} else {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
|
||||
}
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
|
||||
|
||||
// Calculate workgroup dimensions
|
||||
uint32_t wg_x = 1;
|
||||
uint32_t wg_y = 1;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (use_fast && is_vec) {
|
||||
if (is_vec) {
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
|
||||
uint32_t total_wg = output_groups * batches;
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
} else if (use_fast) {
|
||||
} else {
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
// Fast-path tiled/subgroup calculations
|
||||
|
|
@ -1528,15 +1555,13 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
|||
}
|
||||
uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
|
||||
} else { // legacy
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_size = decisions->wg_size;
|
||||
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
}
|
||||
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
dispatches.push_back({
|
||||
pipeline, std::move(params), std::move(entries), { wg_x, wg_y }
|
||||
});
|
||||
|
||||
return ggml_backend_webgpu_build_multi(ctx, dispatches);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx,
|
||||
|
|
@ -3590,6 +3615,22 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
|||
}
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
const ggml_tensor * src0 = tensor->src[0];
|
||||
const ggml_tensor * src1 = tensor->src[1];
|
||||
bool use_mmvq =
|
||||
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
|
||||
ctx->webgpu_global_ctx->vendor);
|
||||
if (use_mmvq) {
|
||||
const size_t q8_src1_size =
|
||||
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
res = ROUNDUP_POW2(res + q8_src1_size +
|
||||
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
const ggml_tensor * src0 = tensor->src[0];
|
||||
|
|
@ -3715,12 +3756,16 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||
ctx->webgpu_global_ctx->adapter.GetInfo(&info);
|
||||
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size();
|
||||
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches();
|
||||
ctx->webgpu_global_ctx->vendor = info.vendor;
|
||||
wgpu::SupportedFeatures features;
|
||||
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
|
||||
// we require f16 support
|
||||
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
||||
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
|
||||
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
|
||||
// for dot4I8packed
|
||||
ctx->webgpu_global_ctx->capabilities.supports_dot_product = ctx->webgpu_global_ctx->instance.HasWGSLLanguageFeature(
|
||||
wgpu::WGSLLanguageFeatureName::Packed4x8IntegerDotProduct);
|
||||
|
||||
bool valid_subgroup_matrix_config = false;
|
||||
#ifndef __EMSCRIPTEN__
|
||||
|
|
|
|||
|
|
@ -95,11 +95,10 @@ struct q5_1 {
|
|||
};
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef Q8_1_T
|
||||
struct q8_1 {
|
||||
d: f16,
|
||||
m: f16,
|
||||
s: f16, // d * sum(qs[i])
|
||||
qs: array<u32, 8>
|
||||
};
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -1,747 +0,0 @@
|
|||
enable f16;
|
||||
|
||||
#define DECLARE_BYTE_LOADERS_SRC0
|
||||
#include "common_decls.tmpl"
|
||||
|
||||
|
||||
#ifdef FLOAT
|
||||
const BLOCK_SIZE = 1u;
|
||||
|
||||
#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
|
||||
const BLOCK_SIZE = 32u;
|
||||
|
||||
#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
|
||||
const BLOCK_SIZE = 256u;
|
||||
#endif
|
||||
|
||||
#ifdef FLOAT
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q4_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
|
||||
let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
sum += q_lo * f32(src1[src1_offset]);
|
||||
sum += q_hi * f32(src1[src1_offset + 16]);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q4_1
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q4_1 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q4_1.d);
|
||||
let m = f32(block_q4_1.m);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = block_q4_1.qs[j];
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
|
||||
let q_lo = f32(q_byte & 0xF) * d + m;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
sum += q_lo * f32(src1[src1_offset]);
|
||||
sum += q_hi * f32(src1[src1_offset + 16]);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q5_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
let qh_packed = load_u32_at_src0(block_byte_base + 2);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_byte_offset = block_byte_base + 6 + j * 4;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
sum += q_lo * f32(src1[src1_offset]);
|
||||
sum += q_hi * f32(src1[src1_offset + 16]);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q5_1
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q5_1 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q5_1.d);
|
||||
let m = f32(block_q5_1.m);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = block_q5_1.qs[j];
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
|
||||
let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
|
||||
let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
sum += q_lo * f32(src1[src1_offset]);
|
||||
sum += q_hi * f32(src1[src1_offset + 16]);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q8_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k: u32 = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
sum += q_val * f32(src1[src1_offset]);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q8_1
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q8_1 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q8_1.d);
|
||||
let m = f32(block_q8_1.m);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let q_packed = block_q8_1.qs[j];
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d + m;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
sum += q_val * f32(src1[src1_offset]);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q2_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let m = f32(block.dmin);
|
||||
var sum = 0.0;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
// 2 halves of the block (128 elements each)
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
|
||||
// 4 groups (each group has 2 blocks of 16 elements)
|
||||
for (var shift: u32 = 0; shift < 8; shift += 2) {
|
||||
// 2 blocks
|
||||
for (var k: u32 = 0; k < 32; k += 16) {
|
||||
let sc = get_byte(block.scales[is / 4], is % 4);
|
||||
is++;
|
||||
let dl = d * f32(sc & 0xF);
|
||||
let ml = m * f32(sc >> 4);
|
||||
for (var l: u32 = 0u; l < 16; l++) {
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
|
||||
let qs_val = (q_byte >> shift) & 3;
|
||||
sum += (f32(qs_val) * dl - ml) * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q3_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
|
||||
|
||||
// Bytes 108-109: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base + 108);
|
||||
|
||||
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
|
||||
// and 2-bits from the last 4 bytes
|
||||
// Bytes 96-107: 12 bytes of scales (3 u32s)
|
||||
let kmask1: u32 = 0x03030303;
|
||||
let kmask2: u32 = 0x0f0f0f0f;
|
||||
var scale_vals: array<u32, 4>;
|
||||
scale_vals[0] = load_u32_at_src0(block_byte_base + 96);
|
||||
scale_vals[1] = load_u32_at_src0(block_byte_base + 100);
|
||||
scale_vals[2] = load_u32_at_src0(block_byte_base + 104);
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
|
||||
|
||||
// Bytes 0-31: 32 bytes of hmask (8 u32s)
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 32-95: 64 bytes of qs (16 u32s)
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0u; i < 16; i++) {
|
||||
qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4);
|
||||
}
|
||||
|
||||
var sum = 0.0;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
var m: u32 = 1;
|
||||
// 2 halves of the block (128 elements each)
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
|
||||
// 4 groups (each group has 2 blocks of 16 elements)
|
||||
for (var shift: u32 = 0; shift < 8; shift += 2) {
|
||||
// 2 blocks
|
||||
for (var k: u32 = 0; k < 32; k += 16) {
|
||||
let sc = get_byte(scale_vals[is / 4], is % 4);
|
||||
is++;
|
||||
let dl = d * (f32(sc) - 32.0);
|
||||
for (var l: u32 = 0u; l < 16u; l++) {
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let hm_idx = k + l;
|
||||
let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
|
||||
let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
|
||||
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
|
||||
let qs_val = (q_byte >> shift) & 3;
|
||||
sum += ((f32(qs_val) - hm) * dl) * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
m <<= 1;
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q4_K
|
||||
// 8 blocks of 32 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let m = f32(block.dmin);
|
||||
var sum = 0.0;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
// 2 blocks each iteration
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
|
||||
for (var shift: u32 = 0; shift < 8; shift += 4) {
|
||||
let scale_min = get_scale_min(is, block.scales);
|
||||
is++;
|
||||
let dl = d * scale_min.x;
|
||||
let ml = m * scale_min.y;
|
||||
for (var l: u32 = 0; l < 32; l++) {
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
|
||||
let qs_val = (q_byte >> shift) & 0xF;
|
||||
sum += (f32(qs_val) * dl - ml) * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q5_K
|
||||
// 8 blocks of 32 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let m = f32(block.dmin);
|
||||
var sum = 0.0;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
var u: u32 = 1;
|
||||
// 2 blocks each iteration
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
|
||||
for (var shift: u32 = 0; shift < 8; shift += 4) {
|
||||
let scale_min = get_scale_min(is, block.scales);
|
||||
is++;
|
||||
let dl = d * scale_min.x;
|
||||
let ml = m * scale_min.y;
|
||||
for (var l: u32 = 0; l < 32; l++) {
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
|
||||
let qh_byte = get_byte(block.qh[l / 4], l % 4);
|
||||
let qs_val = (q_byte >> shift) & 0xF;
|
||||
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
|
||||
sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
u <<= 1;
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q6_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes
|
||||
|
||||
// Bytes 208-209: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base + 208);
|
||||
|
||||
// Bytes 0-127: 128 bytes of ql (32 u32s)
|
||||
var ql_vals: array<u32, 32>;
|
||||
for (var i: u32 = 0; i < 32; i++) {
|
||||
ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 128-191: 64 bytes of qh (16 u32s)
|
||||
var qh_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 192-207: 16 bytes of scales (4 u32s)
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4);
|
||||
}
|
||||
|
||||
var sum = 0.0;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var qh_b_idx: u32 = 0;
|
||||
var sc_b_idx: u32 = 0;
|
||||
for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
|
||||
for (var l: u32 = 0; l < 32; l++) {
|
||||
let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
|
||||
let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
|
||||
let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);
|
||||
|
||||
let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
|
||||
let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
|
||||
let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
|
||||
let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;
|
||||
|
||||
let is = l/16;
|
||||
let is1 = sc_b_idx + is;
|
||||
let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
|
||||
let is2 = sc_b_idx + is + 2;
|
||||
let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
|
||||
let is3 = sc_b_idx + is + 4;
|
||||
let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
|
||||
let is4 = sc_b_idx + is + 6;
|
||||
let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);
|
||||
|
||||
sum += d * f32(sc1) * q1 * src1[src1_i + l];
|
||||
sum += d * f32(sc2) * q2 * src1[src1_i + l + 32];
|
||||
sum += d * f32(sc3) * q3 * src1[src1_i + l + 64];
|
||||
sum += d * f32(sc4) * q4 * src1[src1_i + l + 96];
|
||||
}
|
||||
src1_i += 128;
|
||||
qh_b_idx += 32;
|
||||
sc_b_idx += 8;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_XXS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let aux0_offset = block_byte_base + 2 + ib * 2;
|
||||
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
|
||||
let aux0 = load_u32_at_src0(aux0_offset);
|
||||
let aux1 = load_u32_at_src0(aux1_offset);
|
||||
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = get_byte(aux0, l) * 8;
|
||||
let is = (aux1 >> (7 * l)) & 127;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
|
||||
sum += db * f32(g) * m * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_XS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var scale_vals = array<u32, 2>(
|
||||
load_u32_at_src0(block_byte_base + 66),
|
||||
load_u32_at_src0(block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
|
||||
let db = array<f32, 2>(
|
||||
d * (0.5 + f32(s & 0xF)) * 0.25,
|
||||
d * (0.5 + f32(s >> 4)) * 0.25
|
||||
);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
|
||||
let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF;
|
||||
let ig = (qs_val & 511) * 8;
|
||||
let is = qs_val >> 9;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
let dl = db[l/2];
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
|
||||
sum += dl * f32(g) * m * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var qs_vals : array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
|
||||
}
|
||||
|
||||
var qh_vals: array<u32, 2>;
|
||||
qh_vals[0] = load_u32_at_src0(block_byte_base + 66);
|
||||
qh_vals[1] = load_u32_at_src0(block_byte_base + 70);
|
||||
|
||||
var scale_vals: array<u32, 2>;
|
||||
scale_vals[0] = load_u32_at_src0(block_byte_base + 74);
|
||||
scale_vals[1] = load_u32_at_src0(block_byte_base + 78);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib ++) {
|
||||
let s = get_byte(scale_vals[ib / 4], ib % 4);
|
||||
let db = array<f32, 2>(
|
||||
d * (0.5 + f32(s & 0xF)) * 0.25,
|
||||
d * (0.5 + f32(s >> 4)) * 0.25
|
||||
);
|
||||
let qs_w = qs_vals[ib];
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
|
||||
let ig = (get_byte(qs_w, l) | qh_b) * 8;
|
||||
let signs = get_byte(qs_vals[ib + 8], l);
|
||||
let dl = db[l/2];
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
|
||||
sum += dl * f32(g) * m * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ3_XXS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 16; ib += 2) {
|
||||
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
|
||||
let sc_sign = load_u32_at_src0(sc_sign_offset);
|
||||
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let is = (sc_sign >> (7 * l)) & 127;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0);
|
||||
let ig2 = get_byte(ig_val, 1);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let g1 = get_byte(iq3xxs_grid[ig1], j);
|
||||
let g2 = get_byte(iq3xxs_grid[ig2], j);
|
||||
let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
|
||||
let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
|
||||
sum += db * f32(g1) * m1 * src1[src1_i];
|
||||
sum += db * f32(g2) * m2 * src1[src1_i + 4];
|
||||
src1_i++;
|
||||
}
|
||||
src1_i += 4;
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ3_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var qh_vals = array<u32, 2>(
|
||||
load_u32_at_src0(block_byte_base + 66),
|
||||
load_u32_at_src0(block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sign_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4);
|
||||
}
|
||||
|
||||
var scale_vals = load_u32_at_src0(block_byte_base + 106);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 4; ib++) {
|
||||
let s = get_byte(scale_vals, ib);
|
||||
let db = array<f32, 2>(
|
||||
d * (1.0 + 2.0 * f32(s & 0xF)),
|
||||
d * (1.0 + 2.0 * f32(s >> 4))
|
||||
);
|
||||
for (var k: u32 = 0; k < 2; k++) {
|
||||
let dl = db[k];
|
||||
let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
|
||||
let sign_w = sign_vals[ib * 2 + k];
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let signs = get_byte(sign_w, l);
|
||||
let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
|
||||
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let g1 = get_byte(iq3s_grid[ig1], j);
|
||||
let g2 = get_byte(iq3s_grid[ig2], j);
|
||||
let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
|
||||
let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
|
||||
sum += dl * f32(g1) * m1 * src1[src1_i];
|
||||
sum += dl * f32(g2) * m2 * src1[src1_i + 4];
|
||||
src1_i++;
|
||||
}
|
||||
src1_i += 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ1_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF;
|
||||
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
|
||||
let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let gw = iq1_grid[(ig + j) / 16];
|
||||
let g = (gw >> (((ig + j) % 16) * 2)) & 3;
|
||||
let gs = bitcast<i32>(g << 30) >> 30;
|
||||
sum += dl * (f32(gs) + delta) * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef IQ1_M
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
|
||||
let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
|
||||
let d = f32(bitcast<vec2<f16>>(scale).x);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
|
||||
let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
|
||||
let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
|
||||
var dl = array<f32, 2>(
|
||||
d * f32(2 * s1 + 1),
|
||||
d * f32(2 * s2 + 1)
|
||||
);
|
||||
|
||||
let qh = block.qh[ib / 2] >> (16 * (ib % 2));
|
||||
var idx = array<u32, 4>(
|
||||
get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
|
||||
get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
|
||||
get_byte(block.qs[ib], 2) | ((qh) & 0x700),
|
||||
get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
|
||||
);
|
||||
var delta = array<f32, 4>(
|
||||
select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
|
||||
select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
|
||||
select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
|
||||
select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
|
||||
);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = idx[l] * 8;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let gw = iq1_grid[(ig + j) / 16];
|
||||
let g = (gw >> (((ig + j) % 16) * 2)) & 3;
|
||||
let gs = bitcast<i32>(g << 30) >> 30;
|
||||
sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i];
|
||||
src1_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ4_NL
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 32;
|
||||
var sum = 0.0;
|
||||
var qs: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
|
||||
}
|
||||
for (var j: u32 = 0; j < 16; j++) {
|
||||
let qsb = get_byte(qs[j / 4], j % 4);
|
||||
sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
|
||||
sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
|
||||
src1_i++;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef IQ4_XS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = unpack2x16float(block.d_scales_h)[0];
|
||||
let scales_h = block.d_scales_h >> 16;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
|
||||
let dl = d * (f32(ls) - 32.0);
|
||||
for (var j: u32 = 0; j < 16; j++) {
|
||||
let iqs = ib * 16 + j;
|
||||
let qsb = get_byte(block.qs[iqs / 4], iqs % 4);
|
||||
sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
|
||||
sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
|
||||
src1_i++;
|
||||
}
|
||||
src1_i += 16;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32, // in elements/blocks
|
||||
offset_src1: u32, // in elements/blocks
|
||||
offset_dst: u32, // in elements/blocks
|
||||
m: u32,
|
||||
n: u32,
|
||||
k: u32,
|
||||
// all strides are in elements/blocks
|
||||
stride_01: u32,
|
||||
stride_11: u32,
|
||||
stride_02: u32,
|
||||
stride_12: u32,
|
||||
stride_03: u32,
|
||||
stride_13: u32,
|
||||
|
||||
bs02: u32,
|
||||
bs03: u32,
|
||||
broadcast2: u32,
|
||||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
@compute @workgroup_size(256)
|
||||
fn main(@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>) {
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
let global_idx = wg_linear * 256u + local_id.x;
|
||||
|
||||
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
|
||||
if (global_idx >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
let dst2_stride = params.m * params.n;
|
||||
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
|
||||
|
||||
let dst3_idx = global_idx / dst3_stride;
|
||||
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
|
||||
let src13_idx = dst3_idx; // src1 is not broadcast
|
||||
let dst3_rem = global_idx % dst3_stride;
|
||||
|
||||
let dst2_idx = dst3_rem / dst2_stride;
|
||||
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
|
||||
let src12_idx = dst2_idx; // src1 is not broadcast
|
||||
|
||||
let dst2_rem = dst3_rem % dst2_stride;
|
||||
|
||||
let row = dst2_rem / params.m; // output row
|
||||
let col = dst2_rem % params.m; // output column
|
||||
|
||||
let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
|
||||
|
||||
var sum = 0.0;
|
||||
for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
|
||||
sum += multiply_add(src0_idx_base, src1_idx_base, i);
|
||||
}
|
||||
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
|
||||
}
|
||||
|
|
@ -3,10 +3,18 @@ enable subgroups;
|
|||
#endif
|
||||
enable f16;
|
||||
|
||||
#ifdef MMVQ
|
||||
requires packed_4x8_integer_dot_product;
|
||||
#endif
|
||||
|
||||
#define DECLARE_BYTE_LOADERS_SRC0
|
||||
#include "common_decls.tmpl"
|
||||
|
||||
#ifdef MMVQ
|
||||
#include "mul_mat_vec_q_acc.tmpl"
|
||||
#else
|
||||
#include "mul_mat_vec_acc.tmpl"
|
||||
#endif
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32,
|
||||
|
|
@ -28,9 +36,14 @@ struct MulMatParams {
|
|||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>;
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>;
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<f32>;
|
||||
|
||||
#ifdef MMVQ
|
||||
@group(0) @binding(1) var<storage, read_write> src1q: array<q8_1>;
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>;
|
||||
#endif
|
||||
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<f32>;
|
||||
// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
|
|
@ -75,10 +88,15 @@ fn main(
|
|||
let src12_idx = dst2_idx;
|
||||
|
||||
let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
|
||||
|
||||
#ifdef MMVQ
|
||||
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
|
||||
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
|
||||
#else
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
|
||||
#endif
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
|
|
|
|||
|
|
@ -436,7 +436,6 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src
|
|||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef MUL_ACC_Q3_K
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 110
|
||||
|
|
|
|||
|
|
@ -0,0 +1,303 @@
|
|||
#ifdef U32_DEQUANT_HELPERS
|
||||
#define SRC0_TYPE u32
|
||||
|
||||
fn byte_of(v: u32, b: u32) -> u32 {
|
||||
return (v >> (b * 8u)) & 0xFFu;
|
||||
}
|
||||
|
||||
fn sbyte_of(v: u32, b: u32) -> i32 {
|
||||
let raw = i32((v >> (b * 8u)) & 0xFFu);
|
||||
return select(raw, raw - 256, raw >= 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define SRC0_TYPE SRC0_INNER_TYPE
|
||||
#define SRC1_TYPE SRC1_INNER_TYPE
|
||||
|
||||
#ifdef LEGACY_QUANTS
|
||||
#define BLOCK_SIZE 32
|
||||
#define THREADS_PER_BLOCK 4
|
||||
#elif K_QUANTS
|
||||
#define BLOCK_SIZE 256
|
||||
#define THREADS_PER_BLOCK 16
|
||||
#endif
|
||||
|
||||
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
|
||||
#define Q8_BLOCK_SIZE 32
|
||||
|
||||
#ifdef MUL_ACC_Q4_0
|
||||
#define BLOCK_SIZE_BYTES 18
|
||||
#define B_DS_TYPE vec2<f32>
|
||||
fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> {
|
||||
let qs_packed = load_u32_at_src0(block_byte_base + 2u + 4u * inner_id);
|
||||
|
||||
return vec2<u32>(
|
||||
qs_packed & 0x0F0F0F0Fu,
|
||||
(qs_packed >> 4u) & 0x0F0F0F0Fu
|
||||
);
|
||||
}
|
||||
fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> {
|
||||
return vec2<u32>(
|
||||
src1q[block].qs[inner_id],
|
||||
src1q[block].qs[inner_id + 4u],
|
||||
);
|
||||
}
|
||||
fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
return B_DS_TYPE(
|
||||
f32(src1q[block].d),
|
||||
f32(src1q[block].s)
|
||||
);
|
||||
}
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q4_1
|
||||
#define BLOCK_SIZE_BYTES 20
|
||||
#define B_DS_TYPE vec2<f32>
|
||||
fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> {
|
||||
let qs_packed = load_u32_at_src0(block_byte_base + 4u + 4u * inner_id);
|
||||
|
||||
return vec2<u32>(
|
||||
qs_packed & 0x0F0F0F0Fu,
|
||||
(qs_packed >> 4u) & 0x0F0F0F0Fu
|
||||
);
|
||||
}
|
||||
fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> {
|
||||
return vec2<u32>(
|
||||
src1q[block].qs[inner_id],
|
||||
src1q[block].qs[inner_id + 4u],
|
||||
);
|
||||
}
|
||||
fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
return B_DS_TYPE(
|
||||
f32(src1q[block].d),
|
||||
f32(src1q[block].s)
|
||||
);
|
||||
}
|
||||
fn get_dm(block_byte_base: u32) -> vec2<f32> {
|
||||
return vec2<f32>(
|
||||
f32(load_f16_at_src0(block_byte_base)),
|
||||
f32(load_f16_at_src0(block_byte_base + 2u))
|
||||
);
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q8_0
|
||||
#define BLOCK_SIZE_BYTES 34
|
||||
#define B_DS_TYPE f32
|
||||
fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> {
|
||||
return vec2<u32>(
|
||||
load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u)),
|
||||
load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u + 1))
|
||||
);
|
||||
}
|
||||
fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> {
|
||||
return vec2<u32>(
|
||||
src1q[block].qs[inner_id * 2u],
|
||||
src1q[block].qs[inner_id * 2u + 1],
|
||||
);
|
||||
}
|
||||
fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
return B_DS_TYPE(src1q[block].d);
|
||||
}
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef LEGACY_QUANTS
|
||||
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
var row_sum = 0;
|
||||
let a_repacked = repack_a(a_byte_base, b_inner_id);
|
||||
|
||||
row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]);
|
||||
row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]);
|
||||
|
||||
return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds);
|
||||
}
|
||||
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let b_inner_id = thread_id % THREADS_PER_BLOCK;
|
||||
let b_block_idx = src1q_idx_base + block;
|
||||
|
||||
let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
|
||||
let b_ds = repack_b_dm(b_block_idx);
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q2_K
|
||||
#define BLOCK_SIZE_BYTES 84
|
||||
#define B_DS_TYPE f32
|
||||
fn repack_a(block_byte_base: u32, tid: u32) -> vec4<u32> {
|
||||
let ih2 = tid / 8u;
|
||||
let phase = tid % 2u;
|
||||
let iq4_idx = 2u * ih2 + phase;
|
||||
let qs_byte_base = block_byte_base + 16u + 16u * iq4_idx;
|
||||
let qs_shift = tid & 6u;
|
||||
return vec4<u32>(
|
||||
(load_u32_at_src0_aligned(qs_byte_base) >> qs_shift) & 0x03030303u,
|
||||
(load_u32_at_src0_aligned(qs_byte_base + 4u) >> qs_shift) & 0x03030303u,
|
||||
(load_u32_at_src0_aligned(qs_byte_base + 8u) >> qs_shift) & 0x03030303u,
|
||||
(load_u32_at_src0_aligned(qs_byte_base + 12u) >> qs_shift) & 0x03030303u,
|
||||
);
|
||||
}
|
||||
fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4<u32> {
|
||||
let phase = tid % 2u;
|
||||
return vec4<u32>(
|
||||
src1q[q8_block_idx].qs[4u * phase],
|
||||
src1q[q8_block_idx].qs[4u * phase + 1u],
|
||||
src1q[q8_block_idx].qs[4u * phase + 2u],
|
||||
src1q[q8_block_idx].qs[4u * phase + 3u],
|
||||
);
|
||||
}
|
||||
fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE {
|
||||
return B_DS_TYPE(src1q[q8_block_idx].d);
|
||||
}
|
||||
fn get_dm(block_byte_base: u32) -> vec2<f32> {
|
||||
return vec2<f32>(
|
||||
f32(load_f16_at_src0(block_byte_base + 80u)),
|
||||
f32(load_f16_at_src0(block_byte_base + 82u)),
|
||||
);
|
||||
}
|
||||
fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
let scale_byte = block_byte_base + tid;
|
||||
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
|
||||
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
|
||||
}
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let scale_q = i32(scale_min.x);
|
||||
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
|
||||
|
||||
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
|
||||
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
|
||||
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
|
||||
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
|
||||
|
||||
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q4_K
|
||||
#define BLOCK_SIZE_BYTES 144
|
||||
#define B_DS_TYPE vec2<f32>
|
||||
fn repack_a(block_byte_base: u32, tid: u32) -> vec4<u32> {
|
||||
let iq4 = tid / 4u;
|
||||
let phase = tid % 2u;
|
||||
let nibble = (tid >> 1u) % 2u;
|
||||
let q_qs_byte_base = block_byte_base + 16u + 32u * iq4 + 16u * phase;
|
||||
let qs_shift = 4u * nibble;
|
||||
return vec4<u32>(
|
||||
(load_u32_at_src0_aligned(q_qs_byte_base) >> qs_shift) & 0x0F0F0F0Fu,
|
||||
(load_u32_at_src0_aligned(q_qs_byte_base + 4u) >> qs_shift) & 0x0F0F0F0Fu,
|
||||
(load_u32_at_src0_aligned(q_qs_byte_base + 8u) >> qs_shift) & 0x0F0F0F0Fu,
|
||||
(load_u32_at_src0_aligned(q_qs_byte_base + 12u) >> qs_shift) & 0x0F0F0F0Fu,
|
||||
);
|
||||
}
|
||||
fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4<u32> {
|
||||
let phase = tid % 2u;
|
||||
return vec4<u32>(
|
||||
src1q[q8_block_idx].qs[4u * phase],
|
||||
src1q[q8_block_idx].qs[4u * phase + 1u],
|
||||
src1q[q8_block_idx].qs[4u * phase + 2u],
|
||||
src1q[q8_block_idx].qs[4u * phase + 3u],
|
||||
);
|
||||
}
|
||||
fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE {
|
||||
return B_DS_TYPE(
|
||||
f32(src1q[q8_block_idx].d),
|
||||
f32(src1q[q8_block_idx].s),
|
||||
);
|
||||
}
|
||||
fn get_dm(block_byte_base: u32) -> vec2<f32> {
|
||||
return vec2<f32>(
|
||||
f32(load_f16_at_src0(block_byte_base + 0u)),
|
||||
f32(load_f16_at_src0(block_byte_base + 2u)),
|
||||
);
|
||||
}
|
||||
fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
let sc_m_idx = tid / 2u;
|
||||
let scales_byte_base = block_byte_base + 4u;
|
||||
let scales0_3 = load_u32_at_src0_aligned(scales_byte_base);
|
||||
let scales4_7 = load_u32_at_src0_aligned(scales_byte_base + 4u);
|
||||
let scales8_11 = load_u32_at_src0_aligned(scales_byte_base + 8u);
|
||||
|
||||
let byte_idx = sc_m_idx & 3u;
|
||||
let is_high = sc_m_idx >= 4u;
|
||||
|
||||
let sc_low = byte_of(scales0_3, byte_idx) & 0x3Fu;
|
||||
let sc_high = (byte_of(scales8_11, byte_idx) & 0x0Fu) | ((byte_of(scales0_3, byte_idx) & 0xC0u) >> 2u);
|
||||
let scale = f32(select(sc_low, sc_high, is_high));
|
||||
|
||||
let mn_low = byte_of(scales4_7, byte_idx) & 0x3Fu;
|
||||
let mn_high = (byte_of(scales8_11, byte_idx) >> 4u) | ((byte_of(scales4_7, byte_idx) & 0xC0u) >> 2u);
|
||||
let min_val = f32(select(mn_low, mn_high, is_high));
|
||||
|
||||
return vec2<f32>(scale, min_val);
|
||||
}
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
|
||||
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
|
||||
|
||||
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
|
||||
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef K_QUANTS
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
|
||||
let b_repacked = repack_b_qs(src1q_idx, tid);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
|
|
@ -0,0 +1,173 @@
|
|||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
enable subgroups;
|
||||
#endif
|
||||
enable f16;
|
||||
|
||||
requires packed_4x8_integer_dot_product;
|
||||
|
||||
#include "common_decls.tmpl"
|
||||
|
||||
struct Params {
|
||||
offset_src1: u32,
|
||||
stride_12: u32,
|
||||
stride_13: u32,
|
||||
ne0: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
};
|
||||
|
||||
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src1: array<SRC1_TYPE>;
|
||||
@group(0) @binding(1) var<storage, read_write> src1q: array<q8_1>;
|
||||
|
||||
@group(0) @binding(2) var<uniform> params: Params;
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
fn cluster_max_8(v: f32) -> f32 {
|
||||
var r = v;
|
||||
r = max(r, subgroupShuffleXor(r, 1u));
|
||||
r = max(r, subgroupShuffleXor(r, 2u));
|
||||
r = max(r, subgroupShuffleXor(r, 4u));
|
||||
return r;
|
||||
}
|
||||
|
||||
#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K)
|
||||
fn cluster_add_i4x8(v: i32) -> i32 {
|
||||
var r= v;
|
||||
r += subgroupShuffleXor(r, 1u);
|
||||
r += subgroupShuffleXor(r, 2u);
|
||||
r += subgroupShuffleXor(r, 4u);
|
||||
return r;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
#define CLUSTER_SIZE 8
|
||||
|
||||
var<workgroup> partial_amaxs: array<array<f32, CLUSTER_SIZE>, WG_SIZE / CLUSTER_SIZE>;
|
||||
var<workgroup> partial_sums: array<array<i32, CLUSTER_SIZE>, WG_SIZE / CLUSTER_SIZE>;
|
||||
#endif
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>
|
||||
) {
|
||||
let thread_id = local_id.x;
|
||||
let num_vec4 = params.ne0 / 4u;
|
||||
|
||||
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
|
||||
let total_batches = wg_per_vec * params.ne2 * params.ne3;
|
||||
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
if (wg_linear >= total_batches) {
|
||||
return;
|
||||
}
|
||||
|
||||
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
|
||||
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
|
||||
let src11_wg_idx = wg_linear % wg_per_vec;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let src1_idx_vec4_base = src1_idx_base / 4u;
|
||||
|
||||
let blocks_per_row = params.ne0 / 32u;
|
||||
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
|
||||
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
|
||||
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
|
||||
let qs_idx = thread_id % 8u;
|
||||
|
||||
// reduction
|
||||
var q4 = vec4<f32>(0.0);
|
||||
var q4_quants = 0u;
|
||||
var thread_amax = 0.0;
|
||||
|
||||
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
|
||||
let is_valid = src11_vec4_idx < num_vec4;
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
|
||||
var d = 0.0;
|
||||
|
||||
if (is_valid) {
|
||||
q4 = src1[src1_idx_vec4_base + src11_vec4_idx];
|
||||
let abs_q4 = abs(q4);
|
||||
thread_amax = max(max(abs_q4[0u], abs_q4[1u]), max(abs_q4[2], abs_q4[3]));
|
||||
}
|
||||
|
||||
d = cluster_max_8(thread_amax) / 127.0;
|
||||
|
||||
if (is_valid) {
|
||||
let id = select(0.0, 1.0 / d, d > 0.0);
|
||||
q4_quants = pack4xI8(vec4<i32>(round(q4 * id)));
|
||||
if (qs_idx == 0u) {
|
||||
src1q[src1q_idx].d = f16(d);
|
||||
}
|
||||
src1q[src1q_idx].qs[qs_idx] = q4_quants;
|
||||
}
|
||||
|
||||
#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K)
|
||||
let q4_quants_sum = dot4I8Packed(q4_quants, 0x01010101u);
|
||||
let s = f16(d * f32(cluster_add_i4x8(q4_quants_sum)));
|
||||
|
||||
if (is_valid) {
|
||||
if (qs_idx == 0u) {
|
||||
src1q[src1q_idx].s = s;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
|
||||
var d = 0.0;
|
||||
let cluster_id = thread_id / 8u;
|
||||
|
||||
if (is_valid) {
|
||||
q4 = src1[src1_idx_vec4_base + src11_vec4_idx];
|
||||
let abs_q4 = abs(q4);
|
||||
thread_amax = max(max(abs_q4[0], abs_q4[1]), max(abs_q4[2], abs_q4[3]));
|
||||
partial_amaxs[cluster_id][qs_idx] = thread_amax;
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (is_valid) {
|
||||
let amax = max(
|
||||
max(
|
||||
max(partial_amaxs[cluster_id][0], partial_amaxs[cluster_id][1]), max(partial_amaxs[cluster_id][2], partial_amaxs[cluster_id][3])),
|
||||
max(
|
||||
max(partial_amaxs[cluster_id][4], partial_amaxs[cluster_id][5]), max(partial_amaxs[cluster_id][6], partial_amaxs[cluster_id][7]))
|
||||
);
|
||||
|
||||
d = amax / 127.0;
|
||||
let id = select(0.0f, 1.0f / d, d > 0.0f);
|
||||
|
||||
q4_quants = pack4xI8(vec4<i32>(round(q4 * id)));
|
||||
src1q[src1q_idx].qs[qs_idx] = q4_quants;
|
||||
|
||||
if (qs_idx == 0u) {
|
||||
src1q[src1q_idx].d = f16(d);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K)
|
||||
|
||||
partial_sums[cluster_id][qs_idx] = dot4I8Packed(q4_quants, 0x01010101u);
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (is_valid) {
|
||||
if (qs_idx == 0u) {
|
||||
let s = d * f32(partial_sums[cluster_id][0] + partial_sums[cluster_id][1] + partial_sums[cluster_id][2] + partial_sums[cluster_id][3]
|
||||
+ partial_sums[cluster_id][4] + partial_sums[cluster_id][5] + partial_sums[cluster_id][6] + partial_sums[cluster_id][7]);
|
||||
src1q[src1q_idx].s = f16(s);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
}
|
||||
Loading…
Reference in New Issue