opencl: add MoE support for q4_k, q5_k, q6_k on Adreno (llama/23303)

* opencl: add q4_k moe support

* opencl: add q5_k moe support

* opencl: add q6_k moe support

* opencl: adjust format

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
This commit is contained in:
shaofeiqi 2026-05-19 14:29:00 -07:00 committed by Georgi Gerganov
parent aca63e7638
commit 37f17208c2
9 changed files with 2600 additions and 7 deletions

View File

@ -110,6 +110,12 @@ set(GGML_OPENCL_KERNELS
gemv_moe_q5_0_f32_ns
gemm_moe_q5_1_f32_ns
gemv_moe_q5_1_f32_ns
gemm_moe_q4_k_f32_ns
gemv_moe_q4_k_f32_ns
gemm_moe_q5_k_f32_ns
gemv_moe_q5_k_f32_ns
gemm_moe_q6_k_f32_ns
gemv_moe_q6_k_f32_ns
gemm_moe_mxfp4_f32
gemv_moe_mxfp4_f32
gemm_moe_mxfp4_f32_ns

File diff suppressed because it is too large Load Diff

View File

@ -664,6 +664,391 @@ kernel void kernel_restore_block_q5_1_trans4_ns(
((__global ushort8 *)(&(b->qs[0])))[0] = pre_block;
}
// Split q4_K super-blocks into separate planar buffers.
// One work-item handles one super-block: dim 0 = row, dim 1 = super-block
// index along K, dim 2 = batch.
//   dst_d / dst_dm : the block's d and dm, indexed [batch][k-block][row]
//   dst_q          : 32 uint words per block, word w stored at stride ne01
//   dst_s          : the K_SCALE_SIZE scale bytes, indexed [batch][row][k-block]
// Inside every 32-byte group of q, the low nibbles of a byte pair are packed
// into bytes 0..15 of the output group and the high nibbles into bytes 16..31.
kernel void kernel_convert_block_q4_k_trans4_ns(
    __global struct block_q4_K * src0,
    __global uint * dst_q,
    __global half * dst_d,
    __global half * dst_dm,
    __global uchar * dst_s,
    uint ne00,
    uint ne01,
    uchar mask_0F,
    uchar mask_F0
) {
    const uint row   = get_global_id(0);
    const uint kblk  = get_global_id(1);
    const uint batch = get_global_id(2);
    const uint blks_per_row = ne00 / QK_K;

    const uint block_idx  = kblk + row * blks_per_row + batch * blks_per_row * ne01;
    const uint planar_idx = row + kblk * ne01 + batch * blks_per_row * ne01;
    __global struct block_q4_K * src_blk = src0 + block_idx;

    dst_d [planar_idx] = src_blk->d;
    dst_dm[planar_idx] = src_blk->dm;

    // Stage the regrouped nibbles in private memory, viewed as bytes and as words.
    uint words[32];
    uchar * bytes = (uchar *)words;
    for (int g = 0; g < QK_K / 64; ++g) {
        const int off = g * 32;
        for (int t = 0; t < 16; ++t) {
            const uchar even = src_blk->q[off + 2*t];
            const uchar odd  = src_blk->q[off + 2*t + 1];
            bytes[off + t]      = convert_uchar(even & mask_0F) | convert_uchar((odd & mask_0F) << 4);
            bytes[off + t + 16] = convert_uchar((even & mask_F0) >> 4) | convert_uchar(odd & mask_F0);
        }
    }

    const uint q_base = batch * blks_per_row * ne01 * 32 + kblk * ne01 * 32 + row;
    #pragma unroll
    for (int w = 0; w < 32; ++w) {
        dst_q[q_base + w * ne01] = words[w];
    }

    __global uchar * s_out = dst_s + (batch * ne01 + row) * blks_per_row * K_SCALE_SIZE + kblk * K_SCALE_SIZE;
    #pragma unroll
    for (int k = 0; k < K_SCALE_SIZE; ++k) {
        s_out[k] = src_blk->s[k];
    }
}
// Rebuild q4_K super-blocks from the planar buffers written by
// kernel_convert_block_q4_k_trans4_ns (same indexing, opposite direction).
// One work-item handles one super-block: dim 0 = row, dim 1 = super-block
// index along K, dim 2 = batch.
kernel void kernel_restore_block_q4_k_trans4_ns(
    __global uint * src_q,
    __global half * src_d,
    __global half * src_dm,
    __global uchar * src_s,
    __global struct block_q4_K * dst0,
    uint ne00,
    uint ne01,
    uchar mask_0F,
    uchar mask_F0
) {
    const uint row   = get_global_id(0);
    const uint kblk  = get_global_id(1);
    const uint batch = get_global_id(2);
    const uint blks_per_row = ne00 / QK_K;

    const uint planar_idx = row + kblk * ne01 + batch * blks_per_row * ne01;
    const uint block_idx  = kblk + row * blks_per_row + batch * blks_per_row * ne01;
    __global struct block_q4_K * dst_blk = dst0 + block_idx;

    dst_blk->d  = src_d[planar_idx];
    dst_blk->dm = src_dm[planar_idx];

    __global uchar * s_in = src_s + (batch * ne01 + row) * blks_per_row * K_SCALE_SIZE + kblk * K_SCALE_SIZE;
    for (int k = 0; k < K_SCALE_SIZE; ++k) {
        dst_blk->s[k] = s_in[k];
    }

    // Gather the block's 32 words (stride ne01) into private memory.
    const uint q_base = batch * blks_per_row * ne01 * 32 + kblk * ne01 * 32 + row;
    uint words[32];
    for (int w = 0; w < 32; ++w) {
        words[w] = src_q[q_base + w * ne01];
    }

    // Undo the nibble regrouping: bytes 0..15 of a group carry the original
    // low nibbles, bytes 16..31 the original high nibbles.
    uchar * bytes = (uchar *)words;
    for (int g = 0; g < QK_K / 64; ++g) {
        const int off = g * 32;
        for (int t = 0; t < 16; ++t) {
            const uchar from_low  = bytes[off + t];
            const uchar from_high = bytes[off + t + 16];
            dst_blk->q[off + 2*t]     = convert_uchar((from_low & mask_0F) | ((from_high & mask_0F) << 4));
            dst_blk->q[off + 2*t + 1] = convert_uchar(((from_low & mask_F0) >> 4) | (from_high & mask_F0));
        }
    }
}
// Split q5_K super-blocks into separate planar buffers.
// One work-item handles one super-block: dim 0 = row, dim 1 = super-block
// index along K, dim 2 = batch.
//   dst_d / dst_dm : the block's d and dm, indexed [batch][k-block][row]
//   dst_qh         : 8 uint words per block; word k holds bit k of qh[0..31],
//                    with qh[l] contributing bit l of the word
//   dst_qs         : 32 uint words per block, word w stored at stride ne01
//   dst_s          : the K_SCALE_SIZE scale bytes, indexed [batch][row][k-block]
kernel void kernel_convert_block_q5_k_trans4_ns(
    __global struct block_q5_K * src0,
    __global uint * dst_qs,
    __global uint * dst_qh,
    __global half * dst_d,
    __global half * dst_dm,
    __global uchar * dst_s,
    uint ne00,
    uint ne01,
    uchar mask_0F,
    uchar mask_F0
) {
    const uint row   = get_global_id(0);
    const uint kblk  = get_global_id(1);
    const uint batch = get_global_id(2);
    const uint blks_per_row = ne00 / QK_K;

    const uint block_idx  = kblk + row * blks_per_row + batch * blks_per_row * ne01;
    const uint planar_idx = row + kblk * ne01 + batch * blks_per_row * ne01;
    __global struct block_q5_K * src_blk = src0 + block_idx;

    dst_d [planar_idx] = src_blk->d;
    dst_dm[planar_idx] = src_blk->dm;

    // Bit-plane transpose of qh: one 32-bit word per bit position.
    for (int plane = 0; plane < 8; plane++) {
        uint word = 0;
        for (int l = 0; l < 32; l++) {
            word |= ((uint)((src_blk->qh[l] >> plane) & 1)) << l;
        }
        dst_qh[row + (kblk * 8 + plane) * ne01 + batch * blks_per_row * 8 * ne01] = word;
    }

    // Regroup the 4-bit parts: low nibbles of a byte pair go to bytes 0..15
    // of each 32-byte group, high nibbles to bytes 16..31.
    uint words[32];
    uchar * bytes = (uchar *)words;
    for (int g = 0; g < QK_K / 64; ++g) {
        const int off = g * 32;
        for (int t = 0; t < 16; ++t) {
            const uchar even = src_blk->qs[off + 2*t];
            const uchar odd  = src_blk->qs[off + 2*t + 1];
            bytes[off + t]      = convert_uchar(even & mask_0F) | convert_uchar((odd & mask_0F) << 4);
            bytes[off + t + 16] = convert_uchar((even & mask_F0) >> 4) | convert_uchar(odd & mask_F0);
        }
    }

    const uint qs_base = batch * blks_per_row * ne01 * 32 + kblk * ne01 * 32 + row;
    #pragma unroll
    for (int w = 0; w < 32; ++w) {
        dst_qs[qs_base + w * ne01] = words[w];
    }

    __global uchar * s_out = dst_s + (batch * ne01 + row) * blks_per_row * K_SCALE_SIZE + kblk * K_SCALE_SIZE;
    #pragma unroll
    for (int k = 0; k < K_SCALE_SIZE; ++k) {
        s_out[k] = src_blk->s[k];
    }
}
// Rebuild q5_K super-blocks from the planar buffers written by
// kernel_convert_block_q5_k_trans4_ns (same indexing, opposite direction).
// One work-item handles one super-block: dim 0 = row, dim 1 = super-block
// index along K, dim 2 = batch.
kernel void kernel_restore_block_q5_k_trans4_ns(
    __global uint * src_qs,
    __global uint * src_qh,
    __global half * src_d,
    __global half * src_dm,
    __global uchar * src_s,
    __global struct block_q5_K * dst0,
    uint ne00,
    uint ne01,
    uchar mask_0F,
    uchar mask_F0
) {
    const uint row   = get_global_id(0);
    const uint kblk  = get_global_id(1);
    const uint batch = get_global_id(2);
    const uint blks_per_row = ne00 / QK_K;

    const uint planar_idx = row + kblk * ne01 + batch * blks_per_row * ne01;
    const uint block_idx  = kblk + row * blks_per_row + batch * blks_per_row * ne01;
    __global struct block_q5_K * dst_blk = dst0 + block_idx;

    dst_blk->d  = src_d[planar_idx];
    dst_blk->dm = src_dm[planar_idx];

    // Fetch the 8 bit-plane words, then reassemble each qh byte:
    // bit k of qh[l] is bit l of plane word k.
    uint planes[8];
    for (int k = 0; k < 8; k++) {
        planes[k] = src_qh[row + (kblk * 8 + k) * ne01 + batch * blks_per_row * 8 * ne01];
    }
    for (int l = 0; l < 32; l++) {
        uchar high_bits = 0;
        for (int k = 0; k < 8; k++) {
            high_bits |= (uchar)(((planes[k] >> l) & 1) << k);
        }
        dst_blk->qh[l] = high_bits;
    }

    __global uchar * s_in = src_s + (batch * ne01 + row) * blks_per_row * K_SCALE_SIZE + kblk * K_SCALE_SIZE;
    for (int k = 0; k < K_SCALE_SIZE; ++k) {
        dst_blk->s[k] = s_in[k];
    }

    // Gather the block's 32 words (stride ne01) into private memory.
    const uint qs_base = batch * blks_per_row * ne01 * 32 + kblk * ne01 * 32 + row;
    uint words[32];
    for (int w = 0; w < 32; ++w) {
        words[w] = src_qs[qs_base + w * ne01];
    }

    // Undo the nibble regrouping done by the convert kernel.
    uchar * bytes = (uchar *)words;
    for (int g = 0; g < QK_K / 64; ++g) {
        const int off = g * 32;
        for (int t = 0; t < 16; ++t) {
            const uchar from_low  = bytes[off + t];
            const uchar from_high = bytes[off + t + 16];
            dst_blk->qs[off + 2*t]     = convert_uchar((from_low & mask_0F) | ((from_high & mask_0F) << 4));
            dst_blk->qs[off + 2*t + 1] = convert_uchar(((from_low & mask_F0) >> 4) | (from_high & mask_F0));
        }
    }
}
// Split q6_K super-blocks into separate planar buffers.
// One work-item handles one super-block: dim 0 = row, dim 1 = super-block
// index along K, dim 2 = batch.
//   dst_d  : the block's d, indexed [batch][k-block][row]
//   dst_ql : 32 uint words per block (regrouped 4-bit parts), stride ne01
//   dst_qh : 16 uint words per block (regrouped 2-bit parts), stride ne01
//   dst_s  : the 16 scale bytes, indexed [batch][row][k-block]
kernel void kernel_convert_block_q6_k_trans4_ns(
    __global struct block_q6_K * src0,
    __global uint * dst_ql,
    __global uint * dst_qh,
    __global half * dst_d,
    __global char * dst_s,
    uint ne00,
    uint ne01,
    uchar mask_0F,
    uchar mask_F0
) {
    const uint row   = get_global_id(0);
    const uint kblk  = get_global_id(1);
    const uint batch = get_global_id(2);
    const uint blks_per_row = ne00 / QK_K;

    const uint block_idx  = kblk + row * blks_per_row + batch * blks_per_row * ne01;
    const uint planar_idx = row + kblk * ne01 + batch * blks_per_row * ne01;
    __global struct block_q6_K * src_blk = src0 + block_idx;

    dst_d[planar_idx] = src_blk->d;

    // ql: each 64-byte part is regrouped into four 16-byte runs:
    // low nibbles of bytes 0..31, low nibbles of bytes 32..63,
    // high nibbles of bytes 0..31, high nibbles of bytes 32..63.
    uint ql_words[32];
    uchar * ql_bytes = (uchar *)ql_words;
    for (int part = 0; part < 2; ++part) {
        const int off = part * 64;
        for (int t = 0; t < 16; ++t) {
            const uchar a0 = src_blk->ql[off + 2*t];
            const uchar a1 = src_blk->ql[off + 2*t + 1];
            const uchar c0 = src_blk->ql[off + 32 + 2*t];
            const uchar c1 = src_blk->ql[off + 32 + 2*t + 1];
            ql_bytes[off + t]      = convert_uchar(a0 & mask_0F) | convert_uchar((a1 & mask_0F) << 4);
            ql_bytes[off + t + 16] = convert_uchar(c0 & mask_0F) | convert_uchar((c1 & mask_0F) << 4);
            ql_bytes[off + t + 32] = convert_uchar((a0 & mask_F0) >> 4) | convert_uchar(a1 & mask_F0);
            ql_bytes[off + t + 48] = convert_uchar((c0 & mask_F0) >> 4) | convert_uchar(c1 & mask_F0);
        }
    }

    const uint ql_base = batch * blks_per_row * ne01 * 32 + kblk * ne01 * 32 + row;
    #pragma unroll
    for (int w = 0; w < 32; ++w) {
        dst_ql[ql_base + w * ne01] = ql_words[w];
    }

    // qh: the four 2-bit fields of byte (part*32 + u*16 + t) are scattered to
    // words (part*4 + field)*2 + u, at bit offset 2*t.
    uint qh_words[16] = {0};
    for (int part = 0; part < 2; ++part) {
        for (int u = 0; u < 2; ++u) {
            for (int t = 0; t < 16; ++t) {
                const uchar h = src_blk->qh[part*32 + u*16 + t];
                const int shift = 2 * t;
                for (int field = 0; field < 4; ++field) {
                    qh_words[(part*4 + field)*2 + u] |= ((uint)((h >> (2*field)) & 0x03)) << shift;
                }
            }
        }
    }

    const uint qh_base = batch * blks_per_row * ne01 * 16 + kblk * ne01 * 16 + row;
    for (int w = 0; w < 16; ++w) {
        dst_qh[qh_base + w * ne01] = qh_words[w];
    }

    __global char * s_out = dst_s + (batch * ne01 + row) * blks_per_row * 16 + kblk * 16;
    #pragma unroll
    for (int k = 0; k < 16; ++k) {
        s_out[k] = src_blk->scales[k];
    }
}
// Rebuild q6_K super-blocks from the planar buffers written by
// kernel_convert_block_q6_k_trans4_ns (same indexing, opposite direction).
// One work-item handles one super-block: dim 0 = row, dim 1 = super-block
// index along K, dim 2 = batch.
kernel void kernel_restore_block_q6_k_trans4_ns(
    __global uint * src_ql,
    __global uint * src_qh,
    __global half * src_d,
    __global char * src_s,
    __global struct block_q6_K * dst0,
    uint ne00,
    uint ne01,
    uchar mask_0F,
    uchar mask_F0
) {
    const uint row   = get_global_id(0);
    const uint kblk  = get_global_id(1);
    const uint batch = get_global_id(2);
    const uint blks_per_row = ne00 / QK_K;

    const uint planar_idx = row + kblk * ne01 + batch * blks_per_row * ne01;
    const uint block_idx  = kblk + row * blks_per_row + batch * blks_per_row * ne01;
    __global struct block_q6_K * dst_blk = dst0 + block_idx;

    dst_blk->d = src_d[planar_idx];

    // Gather the 32 ql words (stride ne01) into private memory.
    const uint ql_base = batch * blks_per_row * ne01 * 32 + kblk * ne01 * 32 + row;
    uint ql_words[32];
    for (int w = 0; w < 32; ++w) {
        ql_words[w] = src_ql[ql_base + w * ne01];
    }

    // Each 64-byte part holds four 16-byte runs; runs 0 and 2 rebuild
    // ql bytes 0..31 of the part, runs 1 and 3 rebuild ql bytes 32..63.
    uchar * ql_bytes = (uchar *)ql_words;
    for (int part = 0; part < 2; ++part) {
        const int off = part * 64;
        for (int t = 0; t < 16; ++t) {
            const uchar run0 = ql_bytes[off + t];
            const uchar run1 = ql_bytes[off + t + 16];
            const uchar run2 = ql_bytes[off + t + 32];
            const uchar run3 = ql_bytes[off + t + 48];
            dst_blk->ql[off + 2*t]          = convert_uchar((run0 & mask_0F) | ((run2 & mask_0F) << 4));
            dst_blk->ql[off + 2*t + 1]      = convert_uchar(((run0 & mask_F0) >> 4) | (run2 & mask_F0));
            dst_blk->ql[off + 32 + 2*t]     = convert_uchar((run1 & mask_0F) | ((run3 & mask_0F) << 4));
            dst_blk->ql[off + 32 + 2*t + 1] = convert_uchar(((run1 & mask_F0) >> 4) | (run3 & mask_F0));
        }
    }

    // Gather the 16 qh words and reassemble every qh byte from its four
    // 2-bit fields (field f comes from word (part*4 + f)*2 + u).
    const uint qh_base = batch * blks_per_row * ne01 * 16 + kblk * ne01 * 16 + row;
    uint qh_words[16];
    for (int w = 0; w < 16; ++w) {
        qh_words[w] = src_qh[qh_base + w * ne01];
    }
    for (int part = 0; part < 2; ++part) {
        for (int u = 0; u < 2; ++u) {
            for (int t = 0; t < 16; ++t) {
                const int shift = 2 * t;
                uchar packed = 0;
                for (int field = 0; field < 4; ++field) {
                    packed |= (uchar)(((qh_words[(part*4 + field)*2 + u] >> shift) & 0x03) << (2*field));
                }
                dst_blk->qh[part*32 + u*16 + t] = packed;
            }
        }
    }

    __global char * s_in = src_s + (batch * ne01 + row) * blks_per_row * 16 + kblk * 16;
    for (int k = 0; k < 16; ++k) {
        dst_blk->scales[k] = s_in[k];
    }
}
//------------------------------------------------------------------------------
// block_mxfp4
//------------------------------------------------------------------------------

View File

@ -0,0 +1,279 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
// K elements per half-iteration: the kernel's main loop advances by
// TILESIZE_K * 2 and processes two TILESIZE_K-wide halves per pass.
#define TILESIZE_K 16
// Rows covered by one m-tile (row = block_id_m * TILESIZE_M).
#define TILESIZE_M 64
// Columns covered by one n-tile (col = block_id_n * TILESIZE_N); also the
// length of the kernel's out_idx table.
#define TILESIZE_N 32
// Elements per super-block; ne00 / QK_K is the number of super-blocks per row.
#define QK_K 256
// Bytes of packed scale/min data kept per super-block.
#define K_SCALE_SIZE 12
// Unpack the 6-bit scale (*d) and 6-bit min (*m) of sub-block j from the
// packed scale bytes q.
//   j < 4 : both values are the low 6 bits of q[j] and q[j+4].
//   else  : low 4 bits come from the nibbles of q[j+4]; the top 2 bits come
//           from bits 6..7 of q[j-4] (scale) and q[j] (min).
inline void get_scale_min_k4(
    int j,
    global const uchar * q,
    uchar * d,
    uchar * m
) {
    if (j >= 4) {
        const uchar nibbles = q[j+4];
        *d = (nibbles & 0x0F) | ((q[j-4] & 0xC0) >> 2);
        *m = ((nibbles >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2);
        return;
    }
    *d = q[j] & 63;
    *m = q[j+4] & 63;
}
// Dequantize 16 4-bit values into the 16 lanes of a_f16.
//   q4     : expression with ushort components .s0..s3; each component holds
//            four nibbles, lowest nibble first
//   a_f16  : destination with lanes .s0...sf
//   scale  : float multiplier applied to each nibble
//   minv   : float offset subtracted after scaling
// Lane i = (half)(nibble_i * scale - minv).
// The last line deliberately has NO line continuation, so whatever follows
// the macro in the file is never spliced into its body. scale and minv are
// parenthesized so expression arguments expand with the intended precedence.
#define dequantize_q4_k(q4, a_f16, scale, minv) \
    a_f16.s0 = (half)((float)(q4.s0 & 0x000F) * (scale) - (minv)); \
    a_f16.s1 = (half)((float)((q4.s0 & 0x00F0) >> 4) * (scale) - (minv)); \
    a_f16.s2 = (half)((float)((q4.s0 & 0x0F00) >> 8) * (scale) - (minv)); \
    a_f16.s3 = (half)((float)((q4.s0 & 0xF000) >> 12) * (scale) - (minv)); \
    a_f16.s4 = (half)((float)(q4.s1 & 0x000F) * (scale) - (minv)); \
    a_f16.s5 = (half)((float)((q4.s1 & 0x00F0) >> 4) * (scale) - (minv)); \
    a_f16.s6 = (half)((float)((q4.s1 & 0x0F00) >> 8) * (scale) - (minv)); \
    a_f16.s7 = (half)((float)((q4.s1 & 0xF000) >> 12) * (scale) - (minv)); \
    a_f16.s8 = (half)((float)(q4.s2 & 0x000F) * (scale) - (minv)); \
    a_f16.s9 = (half)((float)((q4.s2 & 0x00F0) >> 4) * (scale) - (minv)); \
    a_f16.sa = (half)((float)((q4.s2 & 0x0F00) >> 8) * (scale) - (minv)); \
    a_f16.sb = (half)((float)((q4.s2 & 0xF000) >> 12) * (scale) - (minv)); \
    a_f16.sc = (half)((float)(q4.s3 & 0x000F) * (scale) - (minv)); \
    a_f16.sd = (half)((float)((q4.s3 & 0x00F0) >> 4) * (scale) - (minv)); \
    a_f16.se = (half)((float)((q4.s3 & 0x0F00) >> 8) * (scale) - (minv)); \
    a_f16.sf = (half)((float)((q4.s3 & 0xF000) >> 12) * (scale) - (minv));
// Accumulate 16 dot products of the 16-lane vector a_reg against rows of b_lm.
//   a_reg     : 16-lane vector, consumed four lanes at a time
//   b_lm      : array of 4-lane vectors; entries lm_offset + {0,32,64,96} + n
//               are paired with a_reg lanes 0-3 / 4-7 / 8-b / c-f for output n
//   c_reg     : 16-lane float accumulator, updated in place
//   lm_offset : index of the first b_lm entry to use
// The macro uses a 16-lane variable named `acc` that must already be declared
// in the calling scope. Partial sums over 8 elements are built in acc and then
// added to c_reg as floats, twice per invocation.
// The last line deliberately has NO line continuation, so whatever follows
// the macro in the file is never spliced into its body.
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
    acc.s0 = dot(a_reg.s0123, b_lm[(lm_offset) + 0]); \
    acc.s1 = dot(a_reg.s0123, b_lm[(lm_offset) + 1]); \
    acc.s2 = dot(a_reg.s0123, b_lm[(lm_offset) + 2]); \
    acc.s3 = dot(a_reg.s0123, b_lm[(lm_offset) + 3]); \
    acc.s4 = dot(a_reg.s0123, b_lm[(lm_offset) + 4]); \
    acc.s5 = dot(a_reg.s0123, b_lm[(lm_offset) + 5]); \
    acc.s6 = dot(a_reg.s0123, b_lm[(lm_offset) + 6]); \
    acc.s7 = dot(a_reg.s0123, b_lm[(lm_offset) + 7]); \
    acc.s8 = dot(a_reg.s0123, b_lm[(lm_offset) + 8]); \
    acc.s9 = dot(a_reg.s0123, b_lm[(lm_offset) + 9]); \
    acc.sa = dot(a_reg.s0123, b_lm[(lm_offset) + 10]); \
    acc.sb = dot(a_reg.s0123, b_lm[(lm_offset) + 11]); \
    acc.sc = dot(a_reg.s0123, b_lm[(lm_offset) + 12]); \
    acc.sd = dot(a_reg.s0123, b_lm[(lm_offset) + 13]); \
    acc.se = dot(a_reg.s0123, b_lm[(lm_offset) + 14]); \
    acc.sf = dot(a_reg.s0123, b_lm[(lm_offset) + 15]); \
    acc.s0 += dot(a_reg.s4567, b_lm[(lm_offset) + 32]); \
    acc.s1 += dot(a_reg.s4567, b_lm[(lm_offset) + 33]); \
    acc.s2 += dot(a_reg.s4567, b_lm[(lm_offset) + 34]); \
    acc.s3 += dot(a_reg.s4567, b_lm[(lm_offset) + 35]); \
    acc.s4 += dot(a_reg.s4567, b_lm[(lm_offset) + 36]); \
    acc.s5 += dot(a_reg.s4567, b_lm[(lm_offset) + 37]); \
    acc.s6 += dot(a_reg.s4567, b_lm[(lm_offset) + 38]); \
    acc.s7 += dot(a_reg.s4567, b_lm[(lm_offset) + 39]); \
    acc.s8 += dot(a_reg.s4567, b_lm[(lm_offset) + 40]); \
    acc.s9 += dot(a_reg.s4567, b_lm[(lm_offset) + 41]); \
    acc.sa += dot(a_reg.s4567, b_lm[(lm_offset) + 42]); \
    acc.sb += dot(a_reg.s4567, b_lm[(lm_offset) + 43]); \
    acc.sc += dot(a_reg.s4567, b_lm[(lm_offset) + 44]); \
    acc.sd += dot(a_reg.s4567, b_lm[(lm_offset) + 45]); \
    acc.se += dot(a_reg.s4567, b_lm[(lm_offset) + 46]); \
    acc.sf += dot(a_reg.s4567, b_lm[(lm_offset) + 47]); \
    c_reg.lo += convert_float8(acc.lo); \
    c_reg.hi += convert_float8(acc.hi); \
    acc.s0 = dot(a_reg.s89ab, b_lm[(lm_offset) + 64]); \
    acc.s1 = dot(a_reg.s89ab, b_lm[(lm_offset) + 65]); \
    acc.s2 = dot(a_reg.s89ab, b_lm[(lm_offset) + 66]); \
    acc.s3 = dot(a_reg.s89ab, b_lm[(lm_offset) + 67]); \
    acc.s4 = dot(a_reg.s89ab, b_lm[(lm_offset) + 68]); \
    acc.s5 = dot(a_reg.s89ab, b_lm[(lm_offset) + 69]); \
    acc.s6 = dot(a_reg.s89ab, b_lm[(lm_offset) + 70]); \
    acc.s7 = dot(a_reg.s89ab, b_lm[(lm_offset) + 71]); \
    acc.s8 = dot(a_reg.s89ab, b_lm[(lm_offset) + 72]); \
    acc.s9 = dot(a_reg.s89ab, b_lm[(lm_offset) + 73]); \
    acc.sa = dot(a_reg.s89ab, b_lm[(lm_offset) + 74]); \
    acc.sb = dot(a_reg.s89ab, b_lm[(lm_offset) + 75]); \
    acc.sc = dot(a_reg.s89ab, b_lm[(lm_offset) + 76]); \
    acc.sd = dot(a_reg.s89ab, b_lm[(lm_offset) + 77]); \
    acc.se = dot(a_reg.s89ab, b_lm[(lm_offset) + 78]); \
    acc.sf = dot(a_reg.s89ab, b_lm[(lm_offset) + 79]); \
    acc.s0 += dot(a_reg.scdef, b_lm[(lm_offset) + 96]); \
    acc.s1 += dot(a_reg.scdef, b_lm[(lm_offset) + 97]); \
    acc.s2 += dot(a_reg.scdef, b_lm[(lm_offset) + 98]); \
    acc.s3 += dot(a_reg.scdef, b_lm[(lm_offset) + 99]); \
    acc.s4 += dot(a_reg.scdef, b_lm[(lm_offset) + 100]); \
    acc.s5 += dot(a_reg.scdef, b_lm[(lm_offset) + 101]); \
    acc.s6 += dot(a_reg.scdef, b_lm[(lm_offset) + 102]); \
    acc.s7 += dot(a_reg.scdef, b_lm[(lm_offset) + 103]); \
    acc.s8 += dot(a_reg.scdef, b_lm[(lm_offset) + 104]); \
    acc.s9 += dot(a_reg.scdef, b_lm[(lm_offset) + 105]); \
    acc.sa += dot(a_reg.scdef, b_lm[(lm_offset) + 106]); \
    acc.sb += dot(a_reg.scdef, b_lm[(lm_offset) + 107]); \
    acc.sc += dot(a_reg.scdef, b_lm[(lm_offset) + 108]); \
    acc.sd += dot(a_reg.scdef, b_lm[(lm_offset) + 109]); \
    acc.se += dot(a_reg.scdef, b_lm[(lm_offset) + 110]); \
    acc.sf += dot(a_reg.scdef, b_lm[(lm_offset) + 111]); \
    c_reg.lo += convert_float8(acc.lo); \
    c_reg.hi += convert_float8(acc.hi);
// Mixture-of-experts GEMM over q4_K weights stored in split planar buffers.
// Each work-item accumulates 32 float results (reg_c) for one weight row
// (row + id in dim 0) against the TILESIZE_N columns of n-tile block_id_n,
// then scatters them through the src2 index table.
//   src0_q         : packed 4-bit quants, fetched with read_imageui
//   src0_d/src0_dm : per-super-block half d / dm values
//   src0_s         : K_SCALE_SIZE packed scale/min bytes per super-block
//   src1           : second operand, fetched with read_imagef (4 floats per texel)
//   src2           : output position for each column of each n-tile;
//                    0xFFFFFFFF entries are redirected to the tile's first entry
//   src2_emap      : expert id for each n-tile
//   dst            : output image, one float written per (column, row)
//   total_tiles    : total_tiles[0] is the number of n-tiles to process
//   ne00, ne01     : K length and row count used for all index math
// NOTE(review): the index math mixes get_global_id(0) and get_local_id(0) for
// the same row offset — presumably dim 0 holds a single work-group; confirm
// against the host-side launch.
// NOTE(review): work-items can return before the barrier() calls below; this
// looks safe only if a whole work-group always takes the same branch — confirm.
__attribute__((qcom_wave_pair_mode(1)))
kernel void kernel_gemm_moe_q4_k_f32_ns(
__read_only image1d_buffer_t src0_q,
__global half * src0_d,
__global half * src0_dm,
__global uchar * src0_s,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global ushort * src2_emap,
__write_only image1d_buffer_t dst,
__global int * total_tiles,
uint ne00,
uint ne01
) {
uint block_id_m = get_global_id(1); // m_tile
uint block_id_n = get_global_id(2); // n_tile
// Boundary check: skip rows at or beyond ne01 and n-tiles at or beyond total_tiles[0]
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
return;
}
// reg_a: 16 dequantized weights; reg_c: 32 float accumulators, one per tile column
__private half16 reg_a;
__private float32 reg_c = (float32)(0);
// Staging area for the src1 values converted to half
__local half4 shared_b[128];
const ushort expert_id = src2_emap[block_id_n];
const uint row = block_id_m * TILESIZE_M;
const uint col = block_id_n * TILESIZE_N;
uint sub_block_id_m = get_local_id(0);
// Each work-item loads two float4 groups from src1: K offset (id & 3) * 4,
// column id >> 2 for .x and that column + 16 for .y
uint2 b_global_offset;
b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
b_global_offset.y = b_global_offset.x + (16 * ne00);
// Matching slots in shared_b: 32 entries per K group, one entry per column
uint2 b_local_offset;
b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
b_local_offset.y = b_local_offset.x + 16;
uint num_superblocks = ne00 / QK_K;
uint scales_per_row = num_superblocks * K_SCALE_SIZE;
uint row_idx = row + get_global_id(0);
// Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16
for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
// sub: 32-element sub-block index; sb: super-block index; j: sub-block within the super-block
uint sub = step / 32;
uint sb = sub / 8;
uint j = sub % 8;
// Load d and dm for super-block
uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0);
half d_val = src0_d[d_offset];
half dm_val = src0_dm[d_offset];
// Load sub-block scale and min
global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE;
uchar sv, mn;
get_scale_min_k4(j, sc, &sv, &mn);
// Effective multiplier and offset used by dequantize_q4_k (value = q * scale - minv)
float scale = (float)d_val * (float)sv;
float minv = (float)dm_val * (float)mn;
// First sub-block (16 elements)
uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
uint b_sub_offset = col * ne00 + step;
// Load 16 q (64-bits) in transposed layout
uint2 q4x16;
q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
// Load 16x32 floats from matrix B
float8 bx8_f32;
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
half8 bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
// Dequantization
dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv);
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
// acc is the scratch accumulator that dotx16_reduce8 expects in scope
half16 acc;
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
// NOTE(review): shared_b is overwritten below with no barrier after the reads
// above — presumably relies on sub-group lockstep execution; confirm.
// Second half (next 16 elements, same sub-block scale)
uint half_step = step + TILESIZE_K;
q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
b_sub_offset = col * ne00 + half_step;
q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv);
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
// Load the tile's output indices from src2 and share them in local memory
__local uint out_idx[TILESIZE_N];
if (get_local_id(0) < TILESIZE_N) {
uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
// Entries equal to 0xFFFFFFFF are replaced by the tile's first entry
if (idx == 0xFFFFFFFF) {
idx = src2[block_id_n * TILESIZE_N + 0];
}
// Pre-multiply by ne01 so out_idx holds the base offset of the output row
out_idx[get_local_id(0)] = idx * ne01;
}
barrier(CLK_LOCAL_MEM_FENCE);
// Scatter results back to original position in output grid
uint m_offset = row + get_local_id(0);
write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));
// Redirected (0xFFFFFFFF) columns were written to the first column's slot above;
// column 0 is written last, after a barrier, so its value is the one that remains
barrier(CLK_GLOBAL_MEM_FENCE);
write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
}

View File

@ -0,0 +1,284 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
// K elements per half-iteration: the kernel's main loop advances by
// TILESIZE_K * 2 and processes two TILESIZE_K-wide halves per pass.
#define TILESIZE_K 16
// Rows covered by one m-tile (row = block_id_m * TILESIZE_M).
#define TILESIZE_M 64
// Columns covered by one n-tile (col = block_id_n * TILESIZE_N); also the
// length of the kernel's out_idx table.
#define TILESIZE_N 32
// Elements per super-block; ne00 / QK_K is the number of super-blocks per row.
#define QK_K 256
// Bytes of packed scale/min data kept per super-block.
#define K_SCALE_SIZE 12
// Unpack the 6-bit scale (*d) and 6-bit min (*m) of sub-block j from the
// packed scale bytes q.
//   j < 4 : both values are the low 6 bits of q[j] and q[j+4].
//   else  : low 4 bits come from the nibbles of q[j+4]; the top 2 bits come
//           from bits 6..7 of q[j-4] (scale) and q[j] (min).
inline void get_scale_min_k4(
    int j,
    global const uchar * q,
    uchar * d,
    uchar * m
) {
    if (j >= 4) {
        const uchar nibbles = q[j+4];
        *d = (nibbles & 0x0F) | ((q[j-4] & 0xC0) >> 2);
        *m = ((nibbles >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2);
        return;
    }
    *d = q[j] & 63;
    *m = q[j+4] & 63;
}
// Dequantize 16 5-bit values into the 16 lanes of a_f16.
//   qs5x16 : expression with ushort components .s0..s3; each component holds
//            four low nibbles, lowest nibble first
//   qh5x16 : expression with byte components .s0/.s1; bit i of .s0 is the
//            fifth bit of lanes 0..7, bit i of .s1 that of lanes 8..f
//   a_f16  : destination with lanes .s0...sf
//   scale  : float multiplier applied to each 5-bit value
//   m      : float offset added after scaling
// Lane i = (half)((nibble_i | (high_bit_i << 4)) * scale + m).
// Every lane uses the same cast placement. The last line deliberately has NO
// line continuation, so whatever follows the macro in the file is never
// spliced into its body. scale and m are parenthesized so expression
// arguments expand with the intended precedence.
#define dequantize_q5_k(qs5x16, qh5x16, a_f16, scale, m) \
    a_f16.s0 = (half)((float)(( qs5x16.s0 & 0x000F)        | (( qh5x16.s0       & 0x01) << 4)) * (scale) + (m)); \
    a_f16.s1 = (half)((float)(((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.s2 = (half)((float)(((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.s3 = (half)((float)(((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.s4 = (half)((float)(( qs5x16.s1 & 0x000F)        | (((qh5x16.s0 >> 4) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.s5 = (half)((float)(((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.s6 = (half)((float)(((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.s7 = (half)((float)(((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.s8 = (half)((float)(( qs5x16.s2 & 0x000F)        | (( qh5x16.s1       & 0x01) << 4)) * (scale) + (m)); \
    a_f16.s9 = (half)((float)(((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.sa = (half)((float)(((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.sb = (half)((float)(((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.sc = (half)((float)(( qs5x16.s3 & 0x000F)        | (((qh5x16.s1 >> 4) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.sd = (half)((float)(((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.se = (half)((float)(((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) * (scale) + (m)); \
    a_f16.sf = (half)((float)(((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) * (scale) + (m));
// Accumulate 16 dot products of the 16-lane vector a_reg against rows of b_lm.
//   a_reg     : 16-lane vector, consumed four lanes at a time
//   b_lm      : array of 4-lane vectors; entries lm_offset + {0,32,64,96} + n
//               are paired with a_reg lanes 0-3 / 4-7 / 8-b / c-f for output n
//   c_reg     : 16-lane float accumulator, updated in place
//   lm_offset : index of the first b_lm entry to use
// The macro uses a 16-lane variable named `acc` that must already be declared
// in the calling scope. Partial sums over 8 elements are built in acc and then
// added to c_reg as floats, twice per invocation.
// The last line deliberately has NO line continuation, so whatever follows
// the macro in the file is never spliced into its body.
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
    acc.s0 = dot(a_reg.s0123, b_lm[(lm_offset) + 0]); \
    acc.s1 = dot(a_reg.s0123, b_lm[(lm_offset) + 1]); \
    acc.s2 = dot(a_reg.s0123, b_lm[(lm_offset) + 2]); \
    acc.s3 = dot(a_reg.s0123, b_lm[(lm_offset) + 3]); \
    acc.s4 = dot(a_reg.s0123, b_lm[(lm_offset) + 4]); \
    acc.s5 = dot(a_reg.s0123, b_lm[(lm_offset) + 5]); \
    acc.s6 = dot(a_reg.s0123, b_lm[(lm_offset) + 6]); \
    acc.s7 = dot(a_reg.s0123, b_lm[(lm_offset) + 7]); \
    acc.s8 = dot(a_reg.s0123, b_lm[(lm_offset) + 8]); \
    acc.s9 = dot(a_reg.s0123, b_lm[(lm_offset) + 9]); \
    acc.sa = dot(a_reg.s0123, b_lm[(lm_offset) + 10]); \
    acc.sb = dot(a_reg.s0123, b_lm[(lm_offset) + 11]); \
    acc.sc = dot(a_reg.s0123, b_lm[(lm_offset) + 12]); \
    acc.sd = dot(a_reg.s0123, b_lm[(lm_offset) + 13]); \
    acc.se = dot(a_reg.s0123, b_lm[(lm_offset) + 14]); \
    acc.sf = dot(a_reg.s0123, b_lm[(lm_offset) + 15]); \
    acc.s0 += dot(a_reg.s4567, b_lm[(lm_offset) + 32]); \
    acc.s1 += dot(a_reg.s4567, b_lm[(lm_offset) + 33]); \
    acc.s2 += dot(a_reg.s4567, b_lm[(lm_offset) + 34]); \
    acc.s3 += dot(a_reg.s4567, b_lm[(lm_offset) + 35]); \
    acc.s4 += dot(a_reg.s4567, b_lm[(lm_offset) + 36]); \
    acc.s5 += dot(a_reg.s4567, b_lm[(lm_offset) + 37]); \
    acc.s6 += dot(a_reg.s4567, b_lm[(lm_offset) + 38]); \
    acc.s7 += dot(a_reg.s4567, b_lm[(lm_offset) + 39]); \
    acc.s8 += dot(a_reg.s4567, b_lm[(lm_offset) + 40]); \
    acc.s9 += dot(a_reg.s4567, b_lm[(lm_offset) + 41]); \
    acc.sa += dot(a_reg.s4567, b_lm[(lm_offset) + 42]); \
    acc.sb += dot(a_reg.s4567, b_lm[(lm_offset) + 43]); \
    acc.sc += dot(a_reg.s4567, b_lm[(lm_offset) + 44]); \
    acc.sd += dot(a_reg.s4567, b_lm[(lm_offset) + 45]); \
    acc.se += dot(a_reg.s4567, b_lm[(lm_offset) + 46]); \
    acc.sf += dot(a_reg.s4567, b_lm[(lm_offset) + 47]); \
    c_reg.lo += convert_float8(acc.lo); \
    c_reg.hi += convert_float8(acc.hi); \
    acc.s0 = dot(a_reg.s89ab, b_lm[(lm_offset) + 64]); \
    acc.s1 = dot(a_reg.s89ab, b_lm[(lm_offset) + 65]); \
    acc.s2 = dot(a_reg.s89ab, b_lm[(lm_offset) + 66]); \
    acc.s3 = dot(a_reg.s89ab, b_lm[(lm_offset) + 67]); \
    acc.s4 = dot(a_reg.s89ab, b_lm[(lm_offset) + 68]); \
    acc.s5 = dot(a_reg.s89ab, b_lm[(lm_offset) + 69]); \
    acc.s6 = dot(a_reg.s89ab, b_lm[(lm_offset) + 70]); \
    acc.s7 = dot(a_reg.s89ab, b_lm[(lm_offset) + 71]); \
    acc.s8 = dot(a_reg.s89ab, b_lm[(lm_offset) + 72]); \
    acc.s9 = dot(a_reg.s89ab, b_lm[(lm_offset) + 73]); \
    acc.sa = dot(a_reg.s89ab, b_lm[(lm_offset) + 74]); \
    acc.sb = dot(a_reg.s89ab, b_lm[(lm_offset) + 75]); \
    acc.sc = dot(a_reg.s89ab, b_lm[(lm_offset) + 76]); \
    acc.sd = dot(a_reg.s89ab, b_lm[(lm_offset) + 77]); \
    acc.se = dot(a_reg.s89ab, b_lm[(lm_offset) + 78]); \
    acc.sf = dot(a_reg.s89ab, b_lm[(lm_offset) + 79]); \
    acc.s0 += dot(a_reg.scdef, b_lm[(lm_offset) + 96]); \
    acc.s1 += dot(a_reg.scdef, b_lm[(lm_offset) + 97]); \
    acc.s2 += dot(a_reg.scdef, b_lm[(lm_offset) + 98]); \
    acc.s3 += dot(a_reg.scdef, b_lm[(lm_offset) + 99]); \
    acc.s4 += dot(a_reg.scdef, b_lm[(lm_offset) + 100]); \
    acc.s5 += dot(a_reg.scdef, b_lm[(lm_offset) + 101]); \
    acc.s6 += dot(a_reg.scdef, b_lm[(lm_offset) + 102]); \
    acc.s7 += dot(a_reg.scdef, b_lm[(lm_offset) + 103]); \
    acc.s8 += dot(a_reg.scdef, b_lm[(lm_offset) + 104]); \
    acc.s9 += dot(a_reg.scdef, b_lm[(lm_offset) + 105]); \
    acc.sa += dot(a_reg.scdef, b_lm[(lm_offset) + 106]); \
    acc.sb += dot(a_reg.scdef, b_lm[(lm_offset) + 107]); \
    acc.sc += dot(a_reg.scdef, b_lm[(lm_offset) + 108]); \
    acc.sd += dot(a_reg.scdef, b_lm[(lm_offset) + 109]); \
    acc.se += dot(a_reg.scdef, b_lm[(lm_offset) + 110]); \
    acc.sf += dot(a_reg.scdef, b_lm[(lm_offset) + 111]); \
    c_reg.lo += convert_float8(acc.lo); \
    c_reg.hi += convert_float8(acc.hi);
__attribute__((qcom_wave_pair_mode(1)))
kernel void kernel_gemm_moe_q5_k_f32_ns(
__read_only image1d_buffer_t src0_q,
__global uint * src0_qh,
__global uchar * src0_s,
__global half * src0_d,
__global half * src0_dm,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global ushort * src2_emap,
__write_only image1d_buffer_t dst,
__global int * total_tiles,
uint ne00,
uint ne01
) {
uint block_id_m = get_global_id(1); // m_tile
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
return;
}
__private half16 reg_a;
__private float32 reg_c = (float32)(0);
__local half4 shared_b[128];
const ushort expert_id = src2_emap[block_id_n];
const uint row = block_id_m * TILESIZE_M;
const uint col = block_id_n * TILESIZE_N;
uint sub_block_id_m = get_local_id(0);
uint2 b_global_offset;
b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
b_global_offset.y = b_global_offset.x + (16 * ne00);
uint2 b_local_offset;
b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
b_local_offset.y = b_local_offset.x + 16;
uint num_superblocks = ne00 / QK_K;
uint scales_per_row = num_superblocks * K_SCALE_SIZE;
uint row_idx = row + get_global_id(0);
// Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16
for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
uint sub = step / 32;
uint sb = sub / 8;
uint j = sub % 8;
// Load d and dm for super-block
uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0);
half d_val = src0_d[d_offset];
half dm_val = src0_dm[d_offset];
// Load sub-block scale and min
global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE;
uchar sv, mn;
get_scale_min_k4(j, sc, &sv, &mn);
float scale = (float)d_val * (float)sv;
float minv = -(float)dm_val * (float)mn;
// qh is stored at sub-block granularity
uint qh_offset = row + sub * ne01 + expert_id * num_superblocks * 8 * ne01 + get_global_id(0);
uchar4 qhx32 = as_uchar4(src0_qh[qh_offset]);
// First sub-block (16 elements)
uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
uint b_sub_offset = col * ne00 + step;
// Load 16 q (64-bits) in transposed layout
uint2 q4x16;
q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
// Load 16x32 floats from matrix B
float8 bx8_f32;
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
half8 bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
// Dequantization
dequantize_q5_k(as_ushort4(q4x16), qhx32.lo, reg_a, scale, minv);
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
half16 acc;
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
// Second half
uint half_step = step + TILESIZE_K;
q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
b_sub_offset = col * ne00 + half_step;
q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
dequantize_q5_k(as_ushort4(q4x16), qhx32.hi, reg_a, scale, minv);
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
// Load post router and share in LM
__local uint out_idx[TILESIZE_N];
if (get_local_id(0) < TILESIZE_N) {
uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
if (idx == 0xFFFFFFFF) {
idx = src2[block_id_n * TILESIZE_N + 0];
}
out_idx[get_local_id(0)] = idx * ne01;
}
barrier(CLK_LOCAL_MEM_FENCE);
// Scatter results back to original position in output grid
uint m_offset = row + get_local_id(0);
write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));
// Store zero padding parts to the index of first output in tile
barrier(CLK_GLOBAL_MEM_FENCE);
write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
}

View File

@ -0,0 +1,263 @@
// Extensions enabled for this kernel file: half-precision types (cl_khr_fp16)
// and sub-group built-ins (cl_khr_subgroups).
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
// NOTE(review): the cl_qcom_* extensions are vendor specific; presumably they
// back the wave-pair attribute and the 32-wide float vector used by the kernel
// in this file - confirm against the Adreno OpenCL documentation.
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
// K elements consumed by each half of one main-loop iteration (the loop steps by TILESIZE_K * 2).
#define TILESIZE_K 16
// Weight rows covered by one M tile (row = block_id_m * TILESIZE_M).
#define TILESIZE_M 64
// Output columns covered by one N tile; also the length of the out_idx table.
#define TILESIZE_N 32
// Elements per quantization super-block (a row holds ne00 / QK_K super-blocks).
#define QK_K 256
// Dequantize 16 Q6_K values into the half16 lvalue `a_f16`.
//   qs16  - ushort4 expression holding 16 low nibbles; element i lives in
//           component i / 4 at bit offset 4 * (i % 4)
//   qh16  - integer expression (at least 32 bits) holding the upper 2 bits of
//           each element; element i lives at bit offset 2 * i
//   a_f16 - half16 lvalue that receives the 16 results
//   scale - multiplier applied after the -32 bias
// Every output is ((low4 | high2 << 4) - 32) * scale, evaluated in float and
// then narrowed to half.
// The body is a single do { } while (0) statement and its last line carries no
// line-continuation backslash, so the macro cannot absorb whatever line follows
// it in the file and it is safe inside unbraced if/else bodies.
#define dequantize_q6_k(qs16, qh16, a_f16, scale) \
    do { \
        a_f16.s0 = (half)(((float)(( (qs16).s0        & 0x000F) | ((uint)(((qh16)      ) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.s1 = (half)(((float)((((qs16).s0 >>  4) & 0x000F) | ((uint)(((qh16) >>  2) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.s2 = (half)(((float)((((qs16).s0 >>  8) & 0x000F) | ((uint)(((qh16) >>  4) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.s3 = (half)(((float)((((qs16).s0 >> 12) & 0x000F) | ((uint)(((qh16) >>  6) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.s4 = (half)(((float)(( (qs16).s1        & 0x000F) | ((uint)(((qh16) >>  8) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.s5 = (half)(((float)((((qs16).s1 >>  4) & 0x000F) | ((uint)(((qh16) >> 10) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.s6 = (half)(((float)((((qs16).s1 >>  8) & 0x000F) | ((uint)(((qh16) >> 12) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.s7 = (half)(((float)((((qs16).s1 >> 12) & 0x000F) | ((uint)(((qh16) >> 14) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.s8 = (half)(((float)(( (qs16).s2        & 0x000F) | ((uint)(((qh16) >> 16) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.s9 = (half)(((float)((((qs16).s2 >>  4) & 0x000F) | ((uint)(((qh16) >> 18) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.sa = (half)(((float)((((qs16).s2 >>  8) & 0x000F) | ((uint)(((qh16) >> 20) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.sb = (half)(((float)((((qs16).s2 >> 12) & 0x000F) | ((uint)(((qh16) >> 22) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.sc = (half)(((float)(( (qs16).s3        & 0x000F) | ((uint)(((qh16) >> 24) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.sd = (half)(((float)((((qs16).s3 >>  4) & 0x000F) | ((uint)(((qh16) >> 26) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.se = (half)(((float)((((qs16).s3 >>  8) & 0x000F) | ((uint)(((qh16) >> 28) & 0x3) << 4)) - 32.f) * (scale)); \
        a_f16.sf = (half)(((float)((((qs16).s3 >> 12) & 0x000F) | ((uint)(((qh16) >> 30) & 0x3) << 4)) - 32.f) * (scale)); \
    } while (0)
// Accumulate 16 column results of a 16-element dot product into `c_reg`.
//   a_reg     - half16 holding 16 dequantized weights (four half4 groups)
//   b_lm      - array of half4; entry (lm_offset + 32 * g + n) is paired with
//               weight group g (0..3) for output column n (0..15)
//   c_reg     - float16 lvalue accumulator, updated in place
//   lm_offset - first column index inside b_lm
// Groups 0 and 1 are summed in half precision, converted to float and added to
// c_reg; groups 2 and 3 are then handled the same way.
// The macro writes to a half16 variable named `acc` that must already be
// declared in the enclosing scope.
// The body is a single do { } while (0) statement and its last line carries no
// line-continuation backslash, so the macro cannot absorb whatever line follows
// it in the file.
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
    do { \
        acc.s0  = dot(a_reg.s0123, b_lm[(lm_offset) +   0]); \
        acc.s1  = dot(a_reg.s0123, b_lm[(lm_offset) +   1]); \
        acc.s2  = dot(a_reg.s0123, b_lm[(lm_offset) +   2]); \
        acc.s3  = dot(a_reg.s0123, b_lm[(lm_offset) +   3]); \
        acc.s4  = dot(a_reg.s0123, b_lm[(lm_offset) +   4]); \
        acc.s5  = dot(a_reg.s0123, b_lm[(lm_offset) +   5]); \
        acc.s6  = dot(a_reg.s0123, b_lm[(lm_offset) +   6]); \
        acc.s7  = dot(a_reg.s0123, b_lm[(lm_offset) +   7]); \
        acc.s8  = dot(a_reg.s0123, b_lm[(lm_offset) +   8]); \
        acc.s9  = dot(a_reg.s0123, b_lm[(lm_offset) +   9]); \
        acc.sa  = dot(a_reg.s0123, b_lm[(lm_offset) +  10]); \
        acc.sb  = dot(a_reg.s0123, b_lm[(lm_offset) +  11]); \
        acc.sc  = dot(a_reg.s0123, b_lm[(lm_offset) +  12]); \
        acc.sd  = dot(a_reg.s0123, b_lm[(lm_offset) +  13]); \
        acc.se  = dot(a_reg.s0123, b_lm[(lm_offset) +  14]); \
        acc.sf  = dot(a_reg.s0123, b_lm[(lm_offset) +  15]); \
        acc.s0 += dot(a_reg.s4567, b_lm[(lm_offset) +  32]); \
        acc.s1 += dot(a_reg.s4567, b_lm[(lm_offset) +  33]); \
        acc.s2 += dot(a_reg.s4567, b_lm[(lm_offset) +  34]); \
        acc.s3 += dot(a_reg.s4567, b_lm[(lm_offset) +  35]); \
        acc.s4 += dot(a_reg.s4567, b_lm[(lm_offset) +  36]); \
        acc.s5 += dot(a_reg.s4567, b_lm[(lm_offset) +  37]); \
        acc.s6 += dot(a_reg.s4567, b_lm[(lm_offset) +  38]); \
        acc.s7 += dot(a_reg.s4567, b_lm[(lm_offset) +  39]); \
        acc.s8 += dot(a_reg.s4567, b_lm[(lm_offset) +  40]); \
        acc.s9 += dot(a_reg.s4567, b_lm[(lm_offset) +  41]); \
        acc.sa += dot(a_reg.s4567, b_lm[(lm_offset) +  42]); \
        acc.sb += dot(a_reg.s4567, b_lm[(lm_offset) +  43]); \
        acc.sc += dot(a_reg.s4567, b_lm[(lm_offset) +  44]); \
        acc.sd += dot(a_reg.s4567, b_lm[(lm_offset) +  45]); \
        acc.se += dot(a_reg.s4567, b_lm[(lm_offset) +  46]); \
        acc.sf += dot(a_reg.s4567, b_lm[(lm_offset) +  47]); \
        c_reg.lo += convert_float8(acc.lo); \
        c_reg.hi += convert_float8(acc.hi); \
        acc.s0  = dot(a_reg.s89ab, b_lm[(lm_offset) +  64]); \
        acc.s1  = dot(a_reg.s89ab, b_lm[(lm_offset) +  65]); \
        acc.s2  = dot(a_reg.s89ab, b_lm[(lm_offset) +  66]); \
        acc.s3  = dot(a_reg.s89ab, b_lm[(lm_offset) +  67]); \
        acc.s4  = dot(a_reg.s89ab, b_lm[(lm_offset) +  68]); \
        acc.s5  = dot(a_reg.s89ab, b_lm[(lm_offset) +  69]); \
        acc.s6  = dot(a_reg.s89ab, b_lm[(lm_offset) +  70]); \
        acc.s7  = dot(a_reg.s89ab, b_lm[(lm_offset) +  71]); \
        acc.s8  = dot(a_reg.s89ab, b_lm[(lm_offset) +  72]); \
        acc.s9  = dot(a_reg.s89ab, b_lm[(lm_offset) +  73]); \
        acc.sa  = dot(a_reg.s89ab, b_lm[(lm_offset) +  74]); \
        acc.sb  = dot(a_reg.s89ab, b_lm[(lm_offset) +  75]); \
        acc.sc  = dot(a_reg.s89ab, b_lm[(lm_offset) +  76]); \
        acc.sd  = dot(a_reg.s89ab, b_lm[(lm_offset) +  77]); \
        acc.se  = dot(a_reg.s89ab, b_lm[(lm_offset) +  78]); \
        acc.sf  = dot(a_reg.s89ab, b_lm[(lm_offset) +  79]); \
        acc.s0 += dot(a_reg.scdef, b_lm[(lm_offset) +  96]); \
        acc.s1 += dot(a_reg.scdef, b_lm[(lm_offset) +  97]); \
        acc.s2 += dot(a_reg.scdef, b_lm[(lm_offset) +  98]); \
        acc.s3 += dot(a_reg.scdef, b_lm[(lm_offset) +  99]); \
        acc.s4 += dot(a_reg.scdef, b_lm[(lm_offset) + 100]); \
        acc.s5 += dot(a_reg.scdef, b_lm[(lm_offset) + 101]); \
        acc.s6 += dot(a_reg.scdef, b_lm[(lm_offset) + 102]); \
        acc.s7 += dot(a_reg.scdef, b_lm[(lm_offset) + 103]); \
        acc.s8 += dot(a_reg.scdef, b_lm[(lm_offset) + 104]); \
        acc.s9 += dot(a_reg.scdef, b_lm[(lm_offset) + 105]); \
        acc.sa += dot(a_reg.scdef, b_lm[(lm_offset) + 106]); \
        acc.sb += dot(a_reg.scdef, b_lm[(lm_offset) + 107]); \
        acc.sc += dot(a_reg.scdef, b_lm[(lm_offset) + 108]); \
        acc.sd += dot(a_reg.scdef, b_lm[(lm_offset) + 109]); \
        acc.se += dot(a_reg.scdef, b_lm[(lm_offset) + 110]); \
        acc.sf += dot(a_reg.scdef, b_lm[(lm_offset) + 111]); \
        c_reg.lo += convert_float8(acc.lo); \
        c_reg.hi += convert_float8(acc.hi); \
    } while (0)
// MoE GEMM for Q6_K weights and f32 activations.
// Each work-item owns one weight row of an M tile and accumulates TILESIZE_N
// column results for it in reg_c; the results are scattered to dst through the
// per-tile index table in src2.
// Parameters:
//   src0_ql     - image with the 4-bit low parts of the quants; one uint (.x) is read per fetch
//   src0_qh     - upper 2 bits of the quants, one uint per 16 elements
//   src0_s      - signed 8-bit sub-block scales, 16 per super-block
//   src0_d      - half scale per super-block
//   src1        - activations, fetched as float4 texels
//   src2        - TILESIZE_N output indices per N tile; 0xFFFFFFFF marks a padded column
//   src2_emap   - expert id for each N tile
//   dst         - output image, written one float per texel index
//   total_tiles - total_tiles[0] is the number of valid N tiles
//   ne00        - length of the K dimension (loop bound)
//   ne01        - number of weight rows per expert (row bound and row stride)
// NOTE(review): work-items that take the early return below never reach the
// later barriers; this presumably relies on the host launching only full
// M tiles - confirm against the dispatch code.
__attribute__((qcom_wave_pair_mode(1)))
kernel void kernel_gemm_moe_q6_k_f32_ns(
__read_only image1d_buffer_t src0_ql,
__global uint * src0_qh,
__global char * src0_s,
__global half * src0_d,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global ushort * src2_emap,
__write_only image1d_buffer_t dst,
__global int * total_tiles,
uint ne00,
uint ne01
) {
uint block_id_m = get_global_id(1); // m_tile
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
return;
}
// reg_a: 16 dequantized weights; reg_c: 32 float accumulators, one per tile column
__private half16 reg_a;
__private float32 reg_c = (float32)(0);
// Staging area for the activation tile: 128 half4 entries
__local half4 shared_b[128];
const ushort expert_id = src2_emap[block_id_n];
const uint row = block_id_m * TILESIZE_M;
const uint col = block_id_n * TILESIZE_N;
uint sub_block_id_m = get_local_id(0);
// Per-work-item source offsets into src1 (in floats, relative to the tile origin);
// .y addresses the column 16 rows of ne00 further on
uint2 b_global_offset;
b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
b_global_offset.y = b_global_offset.x + (16 * ne00);
// Matching destination slots inside shared_b
uint2 b_local_offset;
b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
b_local_offset.y = b_local_offset.x + 16;
uint num_superblocks = ne00 / QK_K;
uint scales_per_row = num_superblocks * 16;
uint row_idx = row + get_global_id(0);
// Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16
for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
uint sub = step / 32; // 32-element group index
uint sb = sub / 8; // super-block index
uint j = sub % 8; // group within super-block
// Load d for super-block
uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0);
half d_val = src0_d[d_offset];
// Load sub-block scales
global const char * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * 16;
// One scale per 16 elements: scale0 for the first half, scale1 for the second
float scale0 = (float)d_val * (float)sc[j * 2];
float scale1 = (float)d_val * (float)sc[j * 2 + 1];
// Two qh words per 32-element group, stored ne01 apart
uint qh_base = row + (sub * 2) * ne01 + expert_id * (num_superblocks * 16) * ne01 + get_global_id(0);
uint qh_first16 = src0_qh[qh_base];
uint qh_second16 = src0_qh[qh_base + ne01];
// First half (16 elements)
uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
uint b_sub_offset = col * ne00 + step;
// Load 16 ql nibbles (2 uints) from image
uint2 q4x16;
q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x;
q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x;
// Load 16x32 floats from matrix B
float8 bx8_f32;
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
half8 bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
// Dequantize first 16 elements (scale0)
dequantize_q6_k(as_ushort4(q4x16), qh_first16, reg_a, scale0);
// NOTE(review): only a sub-group barrier orders the shared_b stores against the
// reads below, and nothing orders these reads against the second-half stores;
// presumably safe under the wave-pair execution mode - confirm.
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
// Scratch accumulator written by dotx16_reduce8
half16 acc;
// Columns 0..15 go to reg_c.lo, columns 16..31 to reg_c.hi
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
// Second half
uint half_step = step + TILESIZE_K;
q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
b_sub_offset = col * ne00 + half_step;
q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x;
q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x;
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
bx8_f16 = convert_half8(bx8_f32);
shared_b[b_local_offset.x] = bx8_f16.lo;
shared_b[b_local_offset.y] = bx8_f16.hi;
// Dequantize second 16 elements (scale1)
dequantize_q6_k(as_ushort4(q4x16), qh_second16, reg_a, scale1);
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
// Load post router and share in LM
__local uint out_idx[TILESIZE_N];
if (get_local_id(0) < TILESIZE_N) {
uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
// Padded columns are redirected to the slot of the tile's first column
if (idx == 0xFFFFFFFF) {
idx = src2[block_id_n * TILESIZE_N + 0];
}
// Pre-multiply by the row count so out_idx holds a base offset into dst
out_idx[get_local_id(0)] = idx * ne01;
}
barrier(CLK_LOCAL_MEM_FENCE);
// Scatter results back to original position in output grid
uint m_offset = row + get_local_id(0);
write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));
// Store zero padding parts to the index of first output in tile
// Column 0 is written last, after the barrier, so its value lands after any
// padded-column writes that were redirected to the same slot.
barrier(CLK_GLOBAL_MEM_FENCE);
write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
}

View File

@ -0,0 +1,151 @@
// Extensions enabled for this kernel file: half-precision types, sub-group
// built-ins, and the Qualcomm attribute that fixes the sub-group size.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
// Elements per quantization super-block.
#define QK_K 256
// Bytes of packed 6-bit scales and mins stored per super-block.
#define K_SCALE_SIZE 12
// Stride of the sub-block loop; the final reduction assumes this many sub-groups per work-group.
#define N_SIMDGROUP 4
// Slots reserved per sub-group in the local-memory reduction buffer.
#define SIMDGROUP_WIDTH 64
// Unpack the 6-bit scale (*d) and 6-bit min (*m) of sub-block j from the
// packed scale bytes q of one super-block.
//   j - sub-block index inside the super-block
//   q - packed scale/min bytes; indices up to j + 4 are read
//   d - receives the scale
//   m - receives the min
inline void get_scale_min_k4(
        int j,
        global const uchar * q,
        uchar * d,
        uchar * m
) {
    if (j >= 4) {
        // Upper entries: the low 4 bits of each value share the byte q[j+4];
        // the 2 high bits are borrowed from the top of q[j-4] (scale) and q[j] (min).
        const uchar nibbles = q[j + 4];
        *d = (nibbles & 0x0F) | ((q[j - 4] >> 6) << 4);
        *m = (nibbles >> 4)   | ((q[j]     >> 6) << 4);
        return;
    }
    // Lower entries: plain 6-bit fields.
    *d = q[j]     & 0x3F;
    *m = q[j + 4] & 0x3F;
}
// Expand eight packed 4-bit quants into floats: out[i] = q[i] * scale - minv.
//   q4x8  - two 16-bit words; nibble k of word w is element 4 * w + k
//   scale - multiplier for the raw quant
//   minv  - offset subtracted after scaling
static inline float8 q4_k_to_fp32_packed8(ushort2 q4x8, float scale, float minv) {
    const ushort w0 = q4x8.s0;
    const ushort w1 = q4x8.s1;

    float8 out;
    out.s0 = ( w0        & 0xF) * scale - minv;
    out.s1 = ((w0 >>  4) & 0xF) * scale - minv;
    out.s2 = ((w0 >>  8) & 0xF) * scale - minv;
    out.s3 = ((w0 >> 12) & 0xF) * scale - minv;
    out.s4 = ( w1        & 0xF) * scale - minv;
    out.s5 = ((w1 >>  4) & 0xF) * scale - minv;
    out.s6 = ((w1 >>  8) & 0xF) * scale - minv;
    out.s7 = ((w1 >> 12) & 0xF) * scale - minv;
    return out;
}
// MoE GEMV for Q4_K weights and f32 activations.
// Every work-item produces a partial dot product for weight row get_global_id(0)
// of the expert selected by src2[get_global_id(2)]; the partials of the
// N_SIMDGROUP sub-groups are then summed through local memory and sub-group 0
// writes the result.
//   src0_q   - packed 4-bit quants, four uints per 32-element sub-block, stored ne01 apart
//   src0_d   - half scale per super-block
//   src0_dm  - half min multiplier per super-block
//   src0_s   - packed 6-bit sub-block scales/mins, K_SCALE_SIZE bytes per super-block
//   src1     - activations, fetched as float4 texels
//   src2     - expert id per output slot
//   dst      - output buffer; offsetd is a byte offset applied to it
//   ne00     - length of the K dimension
//   ne01     - weight rows per expert
//   ne11     - modulus used to pick the activation row for an output slot
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_q4_k_f32_ns(
        __global uint * src0_q,
        __global half * src0_d,
        __global half * src0_dm,
        __global uchar * src0_s,
        __read_only image1d_buffer_t src1,
        __global uint * src2,
        __global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne11
) {
    const uint row  = get_global_id(0);
    const uint slot = get_global_id(2);
    const uint wave = get_local_id(1);
    const uint lane = get_sub_group_local_id();

    const uint act_row = slot % ne11;
    const uint expert  = src2[slot];

    const int n_super       = ne00 / QK_K;
    const int n_sub         = ne00 / 32;
    const int scales_stride = n_super * K_SCALE_SIZE;

    // Per-expert bases in the transposed noshuffle layout
    const uint q_expert_base = expert * (ne00 / 8) * ne01;
    const uint d_expert_base = expert * n_super * ne01;

    float partial = 0.0f;

    // Each sub-group walks the 32-element sub-blocks with stride N_SIMDGROUP
    for (uint blk = wave; blk < n_sub; blk += N_SIMDGROUP) {
        const uint super_idx = blk / 8;
        const uint sub_idx   = blk % 8;

        // Super-block scale and min multiplier
        const uint d_idx  = d_expert_base + super_idx * ne01 + row;
        const half d_val  = src0_d[d_idx];
        const half dm_val = src0_dm[d_idx];

        // Sub-block scale and min
        global const uchar * packed_scales =
            src0_s + (expert * ne01 + row) * scales_stride + super_idx * K_SCALE_SIZE;
        uchar sc6, mn6;
        get_scale_min_k4(sub_idx, packed_scales, &sc6, &mn6);
        const float scale = (float)d_val  * (float)sc6;
        const float minv  = (float)dm_val * (float)mn6;

        // Four uints of quants = 32 nibbles = 32 elements
        const uint q_idx = q_expert_base + blk * ne01 * 4 + row;
        const uint q0 = src0_q[q_idx];
        const uint q1 = src0_q[q_idx + ne01];
        const uint q2 = src0_q[q_idx + ne01 * 2];
        const uint q3 = src0_q[q_idx + ne01 * 3];

        // 32 activations = 8 float4 texels
        const uint y_idx = act_row * ne00 / 4 + blk * 8;

        float8 w = q4_k_to_fp32_packed8(as_ushort2(q0), scale, minv);
        float4 y = read_imagef(src1, (y_idx + 0));
        float4 acc = y * w.lo;
        y = read_imagef(src1, (y_idx + 1));
        acc += y * w.hi;

        w = q4_k_to_fp32_packed8(as_ushort2(q1), scale, minv);
        y = read_imagef(src1, (y_idx + 2));
        acc += y * w.lo;
        y = read_imagef(src1, (y_idx + 3));
        acc += y * w.hi;

        w = q4_k_to_fp32_packed8(as_ushort2(q2), scale, minv);
        y = read_imagef(src1, (y_idx + 4));
        acc += y * w.lo;
        y = read_imagef(src1, (y_idx + 5));
        acc += y * w.hi;

        w = q4_k_to_fp32_packed8(as_ushort2(q3), scale, minv);
        y = read_imagef(src1, (y_idx + 6));
        acc += y * w.lo;
        y = read_imagef(src1, (y_idx + 7));
        acc += y * w.hi;

        partial += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
    }

    // Cross-sub-group reduction through local memory (assumes 4 sub-groups):
    // sub-groups 1..3 publish their partials, sub-group 0 adds them in order.
    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
    if (wave >= 1 && wave <= 3) {
        reduceLM[SIMDGROUP_WIDTH * (wave - 1) + lane] = partial;
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    if (wave == 0) {
        partial += reduceLM[SIMDGROUP_WIDTH * 0 + lane];
        partial += reduceLM[SIMDGROUP_WIDTH * 1 + lane];
        partial += reduceLM[SIMDGROUP_WIDTH * 2 + lane];

        // One output per work-item of sub-group 0
        dst = dst + (offsetd >> 2);
        dst[row + slot * ne01] = partial;
    }
}

View File

@ -0,0 +1,156 @@
// Extensions enabled for this kernel file: half-precision types, sub-group
// built-ins, and the Qualcomm attribute that fixes the sub-group size.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
// Elements per quantization super-block.
#define QK_K 256
// Bytes of packed 6-bit scales and mins stored per super-block.
#define K_SCALE_SIZE 12
// Stride of the sub-block loop; the final reduction assumes this many sub-groups per work-group.
#define N_SIMDGROUP 4
// Slots reserved per sub-group in the local-memory reduction buffer.
#define SIMDGROUP_WIDTH 64
// Decode sub-block j's 6-bit scale into *d and its 6-bit min into *m from the
// packed scale bytes q of a single super-block.
//   j - sub-block index inside the super-block
//   q - packed scale/min bytes; indices up to j + 4 are read
//   d - output: scale
//   m - output: min
inline void get_scale_min_k4(
        int j,
        global const uchar * q,
        uchar * d,
        uchar * m
) {
    if (j >= 4) {
        // Entries 4 and above keep their low nibbles together in q[j+4] and take
        // the remaining 2 bits from the top of q[j-4] (scale) and q[j] (min).
        const uchar shared_byte = q[j + 4];
        *d = (shared_byte & 0x0F) | ((q[j - 4] >> 6) << 4);
        *m = (shared_byte >> 4)   | ((q[j]     >> 6) << 4);
        return;
    }
    // Entries below 4 are stored as plain 6-bit values.
    *d = q[j]     & 0x3F;
    *m = q[j + 4] & 0x3F;
}
// Expand eight packed 5-bit quants into floats: out[i] = q[i] * s + m.
//   qs5x8 - two 16-bit words holding the low nibbles; nibble k of word w is element 4 * w + k
//   qh5x8 - one byte holding the fifth bit of each element; bit i belongs to element i
//   s     - scale applied to the raw 5-bit value
//   m     - offset added after scaling (the caller passes the negated min)
// s and m are taken as float and the arithmetic is done in float. The caller
// computes both in float; accepting them as half narrowed them on entry and
// rounded every product to half precision.
static inline float8 q5_k_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8, float s, float m) {
    float8 fp32x8;
    fp32x8.s0 = (float)(( qs5x8.s0 & 0x000F)        | (( qh5x8       & 0x01) << 4)) * s + m;
    fp32x8.s1 = (float)(((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) * s + m;
    fp32x8.s2 = (float)(((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) * s + m;
    fp32x8.s3 = (float)(((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) * s + m;
    fp32x8.s4 = (float)(( qs5x8.s1 & 0x000F)        | (((qh5x8 >> 4) & 0x01) << 4)) * s + m;
    fp32x8.s5 = (float)(((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) * s + m;
    fp32x8.s6 = (float)(((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) * s + m;
    fp32x8.s7 = (float)(((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) * s + m;
    return fp32x8;
}
// MoE GEMV for Q5_K weights and f32 activations.
// Every work-item produces a partial dot product for weight row get_global_id(0)
// of the expert selected by src2[get_global_id(2)]; the partials of the
// N_SIMDGROUP sub-groups are then summed through local memory and sub-group 0
// writes the result.
//   src0_q   - low 4 bits of the quants, four uints per 32-element sub-block, stored ne01 apart
//   src0_qh  - fifth bit of the quants, one uint per 32-element sub-block
//   src0_d   - half scale per super-block
//   src0_dm  - half min multiplier per super-block
//   src0_s   - packed 6-bit sub-block scales/mins, K_SCALE_SIZE bytes per super-block
//   src1     - activations, fetched as float4 texels
//   src2     - expert id per output slot
//   dst      - output buffer; offsetd is a byte offset applied to it
//   ne00     - length of the K dimension
//   ne01     - weight rows per expert
//   ne11     - modulus used to pick the activation row for an output slot
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_q5_k_f32_ns(
        __global uint * src0_q,
        __global uint * src0_qh,
        __global half * src0_d,
        __global half * src0_dm,
        __global uchar * src0_s,
        __read_only image1d_buffer_t src1,
        __global uint * src2,
        __global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne11
) {
    const uint row  = get_global_id(0);
    const uint slot = get_global_id(2);
    const uint wave = get_local_id(1);
    const uint lane = get_sub_group_local_id();

    const uint act_row = slot % ne11;
    const uint expert  = src2[slot];

    const int n_super       = ne00 / QK_K;
    const int n_sub         = ne00 / 32;
    const int scales_stride = n_super * K_SCALE_SIZE;

    // Per-expert bases in the transposed noshuffle layout
    const uint q_expert_base  = expert * (ne00 / 8) * ne01;
    const uint d_expert_base  = expert * n_super * ne01;
    const uint qh_expert_base = expert * n_super * 8 * ne01;

    float partial = 0.0f;

    // Each sub-group walks the 32-element sub-blocks with stride N_SIMDGROUP
    for (uint blk = wave; blk < n_sub; blk += N_SIMDGROUP) {
        const uint super_idx = blk / 8;
        const uint sub_idx   = blk % 8;

        // Super-block scale and min multiplier
        const uint d_idx  = d_expert_base + super_idx * ne01 + row;
        const half d_val  = src0_d[d_idx];
        const half dm_val = src0_dm[d_idx];

        // High bits: one uint per sub-block, i.e. one byte per group of 8 elements
        const uchar4 hi_bits =
            as_uchar4(src0_qh[qh_expert_base + (super_idx * 8 + sub_idx) * ne01 + row]);

        // Sub-block scale and min (the min is negated so the helper can add it)
        global const uchar * packed_scales =
            src0_s + (expert * ne01 + row) * scales_stride + super_idx * K_SCALE_SIZE;
        uchar sc6, mn6;
        get_scale_min_k4(sub_idx, packed_scales, &sc6, &mn6);
        const float scale = (float)d_val * (float)sc6;
        const float minv  = -(float)dm_val * (float)mn6;

        // Four uints of quants = 32 nibbles = 32 elements
        const uint q_idx = q_expert_base + blk * ne01 * 4 + row;
        const uint q0 = src0_q[q_idx];
        const uint q1 = src0_q[q_idx + ne01];
        const uint q2 = src0_q[q_idx + ne01 * 2];
        const uint q3 = src0_q[q_idx + ne01 * 3];

        // 32 activations = 8 float4 texels
        const uint y_idx = act_row * ne00 / 4 + blk * 8;

        float8 w = q5_k_to_fp32_packed8(as_ushort2(q0), hi_bits.s0, scale, minv);
        float4 y = read_imagef(src1, (y_idx + 0));
        float4 acc = y * w.lo;
        y = read_imagef(src1, (y_idx + 1));
        acc += y * w.hi;

        w = q5_k_to_fp32_packed8(as_ushort2(q1), hi_bits.s1, scale, minv);
        y = read_imagef(src1, (y_idx + 2));
        acc += y * w.lo;
        y = read_imagef(src1, (y_idx + 3));
        acc += y * w.hi;

        w = q5_k_to_fp32_packed8(as_ushort2(q2), hi_bits.s2, scale, minv);
        y = read_imagef(src1, (y_idx + 4));
        acc += y * w.lo;
        y = read_imagef(src1, (y_idx + 5));
        acc += y * w.hi;

        w = q5_k_to_fp32_packed8(as_ushort2(q3), hi_bits.s3, scale, minv);
        y = read_imagef(src1, (y_idx + 6));
        acc += y * w.lo;
        y = read_imagef(src1, (y_idx + 7));
        acc += y * w.hi;

        partial += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
    }

    // Cross-sub-group reduction through local memory (assumes 4 sub-groups):
    // sub-groups 1..3 publish their partials, sub-group 0 adds them in order.
    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
    if (wave >= 1 && wave <= 3) {
        reduceLM[SIMDGROUP_WIDTH * (wave - 1) + lane] = partial;
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    if (wave == 0) {
        partial += reduceLM[SIMDGROUP_WIDTH * 0 + lane];
        partial += reduceLM[SIMDGROUP_WIDTH * 1 + lane];
        partial += reduceLM[SIMDGROUP_WIDTH * 2 + lane];

        // One output per work-item of sub-group 0
        dst = dst + (offsetd >> 2);
        dst[row + slot * ne01] = partial;
    }
}

View File

@ -0,0 +1,137 @@
// Extensions enabled for this kernel file: half-precision types, sub-group
// built-ins, and the Qualcomm attribute that fixes the sub-group size.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
// Elements per quantization super-block.
#define QK_K 256
// Stride of the sub-block loop; the final reduction assumes this many sub-groups per work-group.
#define N_SIMDGROUP 4
// Slots reserved per sub-group in the local-memory reduction buffer.
#define SIMDGROUP_WIDTH 64
// Expand eight packed 6-bit quants into floats: out[i] = (q[i] - 32) * d_scale.
//   ql8     - two 16-bit words with the low nibbles; nibble k of word w is element 4 * w + k
//   qh8     - 16 bits with the upper 2 bits of each element; element i sits at bit 2 * i
//   d_scale - multiplier applied after the -32 bias
static inline float8 q6_k_to_fp32_packed8(ushort2 ql8, ushort qh8, float d_scale) {
    // Fuse both nibble words so element i's low bits sit at bit 4 * i.
    const uint low4  = (uint)ql8.s0 | ((uint)ql8.s1 << 16);
    const uint high2 = (uint)qh8;

    float8 out;
    out.s0 = ((float)(( low4        & 0xFu) | (( high2        & 0x3u) << 4)) - 32.f) * d_scale;
    out.s1 = ((float)(((low4 >>  4) & 0xFu) | (((high2 >>  2) & 0x3u) << 4)) - 32.f) * d_scale;
    out.s2 = ((float)(((low4 >>  8) & 0xFu) | (((high2 >>  4) & 0x3u) << 4)) - 32.f) * d_scale;
    out.s3 = ((float)(((low4 >> 12) & 0xFu) | (((high2 >>  6) & 0x3u) << 4)) - 32.f) * d_scale;
    out.s4 = ((float)(((low4 >> 16) & 0xFu) | (((high2 >>  8) & 0x3u) << 4)) - 32.f) * d_scale;
    out.s5 = ((float)(((low4 >> 20) & 0xFu) | (((high2 >> 10) & 0x3u) << 4)) - 32.f) * d_scale;
    out.s6 = ((float)(((low4 >> 24) & 0xFu) | (((high2 >> 12) & 0x3u) << 4)) - 32.f) * d_scale;
    out.s7 = ((float)(((low4 >> 28) & 0xFu) | (((high2 >> 14) & 0x3u) << 4)) - 32.f) * d_scale;
    return out;
}
// MoE GEMV for Q6_K weights and f32 activations.
// Every work-item produces a partial dot product for weight row get_global_id(0)
// of the expert selected by src2[get_global_id(2)]; the partials of the
// N_SIMDGROUP sub-groups are then summed through local memory and sub-group 0
// writes the result.
//   src0_ql  - low 4 bits of the quants, four uints per 32-element group, stored ne01 apart
//   src0_qh  - upper 2 bits of the quants, two uints per 32-element group
//   src0_s   - signed 8-bit scales, 16 per super-block (one per 16 elements)
//   src0_d   - half scale per super-block
//   src1     - activations, fetched as float4 texels
//   src2     - expert id per output slot
//   dst      - output buffer; offsetd is a byte offset applied to it
//   ne00     - length of the K dimension
//   ne01     - weight rows per expert
//   ne11     - modulus used to pick the activation row for an output slot
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_q6_k_f32_ns(
        __global uint * src0_ql,
        __global uint * src0_qh,
        __global char * src0_s,
        __global half * src0_d,
        __read_only image1d_buffer_t src1,
        __global uint * src2,
        __global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne11
) {
    const uint row  = get_global_id(0);
    const uint slot = get_global_id(2);
    const uint wave = get_local_id(1);
    const uint lane = get_sub_group_local_id();

    const uint act_row = slot % ne11;
    const uint expert  = src2[slot];

    const int n_super       = ne00 / QK_K;
    const int n_sub         = ne00 / 32;      // 8 groups of 32 per super-block
    const int scales_stride = n_super * 16;

    // Per-expert bases in the transposed noshuffle layout
    const uint ql_expert_base = expert * (ne00 / 8) * ne01;   // 32 uints per super-block
    const uint qh_expert_base = expert * (ne00 / 16) * ne01;  // 16 uints per super-block
    const uint d_expert_base  = expert * n_super * ne01;

    float partial = 0.0f;

    // Each sub-group walks the 32-element groups with stride N_SIMDGROUP
    for (uint blk = wave; blk < n_sub; blk += N_SIMDGROUP) {
        const uint super_idx = blk / 8;   // super-block index
        const uint grp       = blk % 8;   // 32-element group inside the super-block

        // Super-block scale
        const half d_val = src0_d[d_expert_base + super_idx * ne01 + row];

        // Two 16-element scales for this group
        global const char * scales =
            src0_s + (expert * ne01 + row) * scales_stride + super_idx * 16;
        const float scale_a = (float)d_val * (float)scales[grp * 2];
        const float scale_b = (float)d_val * (float)scales[grp * 2 + 1];

        // Four uints of low nibbles
        const uint ql_idx = ql_expert_base + (blk * 4) * ne01 + row;
        const uint ql0 = src0_ql[ql_idx];
        const uint ql1 = src0_ql[ql_idx + ne01];
        const uint ql2 = src0_ql[ql_idx + ne01 * 2];
        const uint ql3 = src0_ql[ql_idx + ne01 * 3];

        // Two uints of high bits
        const uint qh_idx = qh_expert_base + (blk * 2) * ne01 + row;
        const uint qh0 = src0_qh[qh_idx];
        const uint qh1 = src0_qh[qh_idx + ne01];

        // 32 activations = 8 float4 texels
        const uint y_idx = act_row * ne00 / 4 + blk * 8;

        float8 w = q6_k_to_fp32_packed8(as_ushort2(ql0), (ushort)(qh0 & 0xFFFF), scale_a);
        float4 y = read_imagef(src1, (y_idx + 0));
        float4 acc = y * w.lo;
        y = read_imagef(src1, (y_idx + 1));
        acc += y * w.hi;

        w = q6_k_to_fp32_packed8(as_ushort2(ql1), (ushort)(qh0 >> 16), scale_a);
        y = read_imagef(src1, (y_idx + 2));
        acc += y * w.lo;
        y = read_imagef(src1, (y_idx + 3));
        acc += y * w.hi;

        w = q6_k_to_fp32_packed8(as_ushort2(ql2), (ushort)(qh1 & 0xFFFF), scale_b);
        y = read_imagef(src1, (y_idx + 4));
        acc += y * w.lo;
        y = read_imagef(src1, (y_idx + 5));
        acc += y * w.hi;

        w = q6_k_to_fp32_packed8(as_ushort2(ql3), (ushort)(qh1 >> 16), scale_b);
        y = read_imagef(src1, (y_idx + 6));
        acc += y * w.lo;
        y = read_imagef(src1, (y_idx + 7));
        acc += y * w.hi;

        partial += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
    }

    // Cross-sub-group reduction through local memory (assumes 4 sub-groups):
    // sub-groups 1..3 publish their partials, sub-group 0 adds them in order.
    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
    if (wave >= 1 && wave <= 3) {
        reduceLM[SIMDGROUP_WIDTH * (wave - 1) + lane] = partial;
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    if (wave == 0) {
        partial += reduceLM[SIMDGROUP_WIDTH * 0 + lane];
        partial += reduceLM[SIMDGROUP_WIDTH * 1 + lane];
        partial += reduceLM[SIMDGROUP_WIDTH * 2 + lane];

        // One output per work-item of sub-group 0
        dst = dst + (offsetd >> 2);
        dst[row + slot * ne01] = partial;
    }
}