opencl: add MoE support for q4_k, q5_k, q6_k on Adreno (llama/23303)
* opencl: add q4_k moe support * opencl: add q5_k moe support * opencl: add q6_k moe support * opencl: adjust format --------- Co-authored-by: Li He <lih@qti.qualcomm.com>
This commit is contained in:
parent
aca63e7638
commit
37f17208c2
|
|
@ -110,6 +110,12 @@ set(GGML_OPENCL_KERNELS
|
|||
gemv_moe_q5_0_f32_ns
|
||||
gemm_moe_q5_1_f32_ns
|
||||
gemv_moe_q5_1_f32_ns
|
||||
gemm_moe_q4_k_f32_ns
|
||||
gemv_moe_q4_k_f32_ns
|
||||
gemm_moe_q5_k_f32_ns
|
||||
gemv_moe_q5_k_f32_ns
|
||||
gemm_moe_q6_k_f32_ns
|
||||
gemv_moe_q6_k_f32_ns
|
||||
gemm_moe_mxfp4_f32
|
||||
gemv_moe_mxfp4_f32
|
||||
gemm_moe_mxfp4_f32_ns
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -664,6 +664,391 @@ kernel void kernel_restore_block_q5_1_trans4_ns(
|
|||
((__global ushort8 *)(&(b->qs[0])))[0] = pre_block;
|
||||
}
|
||||
|
||||
kernel void kernel_convert_block_q4_k_trans4_ns(
|
||||
__global struct block_q4_K * src0,
|
||||
__global uint * dst_q,
|
||||
__global half * dst_d,
|
||||
__global half * dst_dm,
|
||||
__global uchar * dst_s,
|
||||
uint ne00,
|
||||
uint ne01,
|
||||
uchar mask_0F,
|
||||
uchar mask_F0
|
||||
) {
|
||||
uint i00 = get_global_id(1);
|
||||
uint i01 = get_global_id(0);
|
||||
uint i02 = get_global_id(2);
|
||||
|
||||
uint ne00_blk = ne00 / QK_K;
|
||||
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
|
||||
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
|
||||
|
||||
__global struct block_q4_K * b = src0 + src_blk_offset;
|
||||
|
||||
dst_d [dst_blk_offset] = b->d;
|
||||
dst_dm[dst_blk_offset] = b->dm;
|
||||
|
||||
uint4 qv[8];
|
||||
uchar * qv_bytes = (uchar *)qv;
|
||||
for (int i = 0; i < QK_K / 64; ++i) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
uchar x0 = b->q[i*32 + 2*j];
|
||||
uchar x1 = b->q[i*32 + 2*j + 1];
|
||||
|
||||
qv_bytes[i*32 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4);
|
||||
qv_bytes[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0);
|
||||
}
|
||||
}
|
||||
|
||||
uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01;
|
||||
#pragma unroll
|
||||
for (int p = 0; p < 8; ++p) {
|
||||
uint4 v = qv[p];
|
||||
dst_q[base + (p * 4 + 0) * ne01] = v.x;
|
||||
dst_q[base + (p * 4 + 1) * ne01] = v.y;
|
||||
dst_q[base + (p * 4 + 2) * ne01] = v.z;
|
||||
dst_q[base + (p * 4 + 3) * ne01] = v.w;
|
||||
}
|
||||
|
||||
__global uchar * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < K_SCALE_SIZE; ++i) {
|
||||
s_dst[i] = b->s[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_q4_k_trans4_ns(
|
||||
__global uint * src_q,
|
||||
__global half * src_d,
|
||||
__global half * src_dm,
|
||||
__global uchar * src_s,
|
||||
__global struct block_q4_K * dst0,
|
||||
uint ne00,
|
||||
uint ne01,
|
||||
uchar mask_0F,
|
||||
uchar mask_F0
|
||||
) {
|
||||
uint i00 = get_global_id(1); // block index along K
|
||||
uint i01 = get_global_id(0); // row index
|
||||
uint i02 = get_global_id(2); // batch index
|
||||
|
||||
uint ne00_blk = ne00 / QK_K;
|
||||
|
||||
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
|
||||
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
|
||||
|
||||
__global struct block_q4_K * b = dst0 + dst_blk_offset;
|
||||
|
||||
b->d = src_d[src_blk_offset];
|
||||
b->dm = src_dm[src_blk_offset];
|
||||
|
||||
__global uchar * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE;
|
||||
for (int i = 0; i < K_SCALE_SIZE; ++i) {
|
||||
b->s[i] = s_src[i];
|
||||
}
|
||||
|
||||
uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01;
|
||||
|
||||
uint4 qv[8];
|
||||
for (int p = 0; p < 8; ++p) {
|
||||
qv[p].x = src_q[base + (p * 4 + 0) * ne01];
|
||||
qv[p].y = src_q[base + (p * 4 + 1) * ne01];
|
||||
qv[p].z = src_q[base + (p * 4 + 2) * ne01];
|
||||
qv[p].w = src_q[base + (p * 4 + 3) * ne01];
|
||||
}
|
||||
|
||||
uchar * qv_bytes = (uchar *)qv;
|
||||
for (int i = 0; i < QK_K / 64; ++i) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
uchar lo = qv_bytes[i*32 + j];
|
||||
uchar hi = qv_bytes[i*32 + j + 16];
|
||||
b->q[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4));
|
||||
b->q[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_convert_block_q5_k_trans4_ns(
|
||||
__global struct block_q5_K * src0,
|
||||
__global uint * dst_qs,
|
||||
__global uint * dst_qh,
|
||||
__global half * dst_d,
|
||||
__global half * dst_dm,
|
||||
__global uchar * dst_s,
|
||||
uint ne00,
|
||||
uint ne01,
|
||||
uchar mask_0F,
|
||||
uchar mask_F0
|
||||
) {
|
||||
uint i00 = get_global_id(1);
|
||||
uint i01 = get_global_id(0);
|
||||
uint i02 = get_global_id(2);
|
||||
|
||||
uint ne00_blk = ne00 / QK_K;
|
||||
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
|
||||
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
|
||||
|
||||
__global struct block_q5_K * b = src0 + src_blk_offset;
|
||||
|
||||
dst_d [dst_blk_offset] = b->d;
|
||||
dst_dm[dst_blk_offset] = b->dm;
|
||||
|
||||
for (int k = 0; k < 8; k++) {
|
||||
uchar b0 = 0, b1 = 0, b2 = 0, b3 = 0;
|
||||
for (int bit = 0; bit < 8; bit++) {
|
||||
b0 |= (uchar)(((b->qh[bit] >> k) & 1) << bit);
|
||||
b1 |= (uchar)(((b->qh[8 + bit] >> k) & 1) << bit);
|
||||
b2 |= (uchar)(((b->qh[16 + bit] >> k) & 1) << bit);
|
||||
b3 |= (uchar)(((b->qh[24 + bit] >> k) & 1) << bit);
|
||||
}
|
||||
uint packed = (uint)b0 | ((uint)b1 << 8) | ((uint)b2 << 16) | ((uint)b3 << 24);
|
||||
dst_qh[i01 + (i00 * 8 + k) * ne01 + i02 * ne00_blk * 8 * ne01] = packed;
|
||||
}
|
||||
|
||||
uint4 qv[8];
|
||||
uchar * qv_bytes = (uchar *)qv;
|
||||
for (int i = 0; i < QK_K / 64; ++i) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
uchar x0 = b->qs[i*32 + 2*j];
|
||||
uchar x1 = b->qs[i*32 + 2*j + 1];
|
||||
|
||||
qv_bytes[i*32 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4);
|
||||
qv_bytes[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0);
|
||||
}
|
||||
}
|
||||
|
||||
uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01;
|
||||
#pragma unroll
|
||||
for (int p = 0; p < 8; ++p) {
|
||||
uint4 v = qv[p];
|
||||
dst_qs[base + (p * 4 + 0) * ne01] = v.x;
|
||||
dst_qs[base + (p * 4 + 1) * ne01] = v.y;
|
||||
dst_qs[base + (p * 4 + 2) * ne01] = v.z;
|
||||
dst_qs[base + (p * 4 + 3) * ne01] = v.w;
|
||||
}
|
||||
|
||||
__global uchar * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < K_SCALE_SIZE; ++i) {
|
||||
s_dst[i] = b->s[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_q5_k_trans4_ns(
|
||||
__global uint * src_qs,
|
||||
__global uint * src_qh,
|
||||
__global half * src_d,
|
||||
__global half * src_dm,
|
||||
__global uchar * src_s,
|
||||
__global struct block_q5_K * dst0,
|
||||
uint ne00,
|
||||
uint ne01,
|
||||
uchar mask_0F,
|
||||
uchar mask_F0
|
||||
) {
|
||||
uint i00 = get_global_id(1); // block index along K
|
||||
uint i01 = get_global_id(0); // row index
|
||||
uint i02 = get_global_id(2); // batch index
|
||||
|
||||
uint ne00_blk = ne00 / QK_K;
|
||||
|
||||
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
|
||||
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
|
||||
|
||||
__global struct block_q5_K * b = dst0 + dst_blk_offset;
|
||||
|
||||
b->d = src_d[src_blk_offset];
|
||||
b->dm = src_dm[src_blk_offset];
|
||||
|
||||
for (int j = 0; j < 32; j++) b->qh[j] = 0;
|
||||
for (int k = 0; k < 8; k++) {
|
||||
uint packed = src_qh[i01 + (i00 * 8 + k) * ne01 + i02 * ne00_blk * 8 * ne01];
|
||||
uchar b0 = (uchar)(packed & 0xFF);
|
||||
uchar b1 = (uchar)((packed >> 8) & 0xFF);
|
||||
uchar b2 = (uchar)((packed >> 16) & 0xFF);
|
||||
uchar b3 = (uchar)((packed >> 24) & 0xFF);
|
||||
for (int bit = 0; bit < 8; bit++) {
|
||||
b->qh[bit] |= (uchar)(((b0 >> bit) & 1) << k);
|
||||
b->qh[8 + bit] |= (uchar)(((b1 >> bit) & 1) << k);
|
||||
b->qh[16 + bit] |= (uchar)(((b2 >> bit) & 1) << k);
|
||||
b->qh[24 + bit] |= (uchar)(((b3 >> bit) & 1) << k);
|
||||
}
|
||||
}
|
||||
|
||||
__global uchar * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE;
|
||||
for (int i = 0; i < K_SCALE_SIZE; ++i) {
|
||||
b->s[i] = s_src[i];
|
||||
}
|
||||
|
||||
uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01;
|
||||
|
||||
uint4 qv[8];
|
||||
for (int p = 0; p < 8; ++p) {
|
||||
qv[p].x = src_qs[base + (p * 4 + 0) * ne01];
|
||||
qv[p].y = src_qs[base + (p * 4 + 1) * ne01];
|
||||
qv[p].z = src_qs[base + (p * 4 + 2) * ne01];
|
||||
qv[p].w = src_qs[base + (p * 4 + 3) * ne01];
|
||||
}
|
||||
|
||||
uchar * qv_bytes = (uchar *)qv;
|
||||
for (int i = 0; i < QK_K / 64; ++i) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
uchar lo = qv_bytes[i*32 + j];
|
||||
uchar hi = qv_bytes[i*32 + j + 16];
|
||||
b->qs[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4));
|
||||
b->qs[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_convert_block_q6_k_trans4_ns(
|
||||
__global struct block_q6_K * src0,
|
||||
__global uint * dst_ql,
|
||||
__global uint * dst_qh,
|
||||
__global half * dst_d,
|
||||
__global char * dst_s,
|
||||
uint ne00,
|
||||
uint ne01,
|
||||
uchar mask_0F,
|
||||
uchar mask_F0
|
||||
) {
|
||||
uint i00 = get_global_id(1);
|
||||
uint i01 = get_global_id(0);
|
||||
uint i02 = get_global_id(2);
|
||||
|
||||
uint ne00_blk = ne00 / QK_K;
|
||||
|
||||
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
|
||||
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
|
||||
|
||||
__global struct block_q6_K * b = src0 + src_blk_offset;
|
||||
|
||||
dst_d[dst_blk_offset] = b->d;
|
||||
|
||||
uint4 qlv[8];
|
||||
uchar * qlv_bytes = (uchar *)qlv;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
uchar x0 = b->ql[i*64 + 2*j];
|
||||
uchar x1 = b->ql[i*64 + 2*j + 1];
|
||||
uchar x2 = b->ql[i*64 + 32 + 2*j];
|
||||
uchar x3 = b->ql[i*64 + 32 + 2*j + 1];
|
||||
qlv_bytes[i*64 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4);
|
||||
qlv_bytes[i*64 + j + 16] = convert_uchar(x2 & mask_0F) | convert_uchar((x3 & mask_0F) << 4);
|
||||
qlv_bytes[i*64 + j + 32] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0);
|
||||
qlv_bytes[i*64 + j + 48] = convert_uchar((x2 & mask_F0) >> 4) | convert_uchar(x3 & mask_F0);
|
||||
}
|
||||
}
|
||||
|
||||
uint ql_base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01;
|
||||
|
||||
#pragma unroll
|
||||
for (int p = 0; p < 8; ++p) {
|
||||
uint4 v = qlv[p];
|
||||
dst_ql[ql_base + (p * 4 + 0) * ne01] = v.x;
|
||||
dst_ql[ql_base + (p * 4 + 1) * ne01] = v.y;
|
||||
dst_ql[ql_base + (p * 4 + 2) * ne01] = v.z;
|
||||
dst_ql[ql_base + (p * 4 + 3) * ne01] = v.w;
|
||||
}
|
||||
|
||||
uint qhv[16] = {0};
|
||||
|
||||
for (int n = 0; n < 2; ++n) {
|
||||
for (int l = 0; l < 32; ++l) {
|
||||
uchar h = b->qh[n*32 + l];
|
||||
int u = l / 16;
|
||||
int bit_pos = (l % 16) * 2;
|
||||
qhv[(n*4 + 0)*2 + u] |= ((uint)((h >> 0) & 0x03)) << bit_pos;
|
||||
qhv[(n*4 + 1)*2 + u] |= ((uint)((h >> 2) & 0x03)) << bit_pos;
|
||||
qhv[(n*4 + 2)*2 + u] |= ((uint)((h >> 4) & 0x03)) << bit_pos;
|
||||
qhv[(n*4 + 3)*2 + u] |= ((uint)((h >> 6) & 0x03)) << bit_pos;
|
||||
}
|
||||
}
|
||||
|
||||
uint qh_base = i02 * ne00_blk * ne01 * 16 + i00 * ne01 * 16 + i01;
|
||||
|
||||
for (int p = 0; p < 16; ++p) {
|
||||
dst_qh[qh_base + p * ne01] = qhv[p];
|
||||
}
|
||||
|
||||
__global char * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * 16 + i00 * 16;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
s_dst[i] = b->scales[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_q6_k_trans4_ns(
|
||||
__global uint * src_ql,
|
||||
__global uint * src_qh,
|
||||
__global half * src_d,
|
||||
__global char * src_s,
|
||||
__global struct block_q6_K * dst0,
|
||||
uint ne00,
|
||||
uint ne01,
|
||||
uchar mask_0F,
|
||||
uchar mask_F0
|
||||
) {
|
||||
uint i00 = get_global_id(1); // block index along K
|
||||
uint i01 = get_global_id(0); // row index
|
||||
uint i02 = get_global_id(2); // batch index
|
||||
|
||||
uint ne00_blk = ne00 / QK_K;
|
||||
|
||||
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
|
||||
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
|
||||
|
||||
__global struct block_q6_K * b = dst0 + dst_blk_offset;
|
||||
|
||||
b->d = src_d[src_blk_offset];
|
||||
|
||||
uint ql_base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01;
|
||||
uint4 qlv[8];
|
||||
for (int p = 0; p < 8; ++p) {
|
||||
qlv[p].x = src_ql[ql_base + (p * 4 + 0) * ne01];
|
||||
qlv[p].y = src_ql[ql_base + (p * 4 + 1) * ne01];
|
||||
qlv[p].z = src_ql[ql_base + (p * 4 + 2) * ne01];
|
||||
qlv[p].w = src_ql[ql_base + (p * 4 + 3) * ne01];
|
||||
}
|
||||
|
||||
uchar * qlv_bytes = (uchar *)qlv;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
uchar lo_02 = qlv_bytes[i*64 + j];
|
||||
uchar lo_13 = qlv_bytes[i*64 + j + 16];
|
||||
uchar hi_02 = qlv_bytes[i*64 + j + 32];
|
||||
uchar hi_13 = qlv_bytes[i*64 + j + 48];
|
||||
b->ql[i*64 + 2*j] = convert_uchar((lo_02 & mask_0F) | ((hi_02 & mask_0F) << 4));
|
||||
b->ql[i*64 + 2*j + 1] = convert_uchar(((lo_02 & mask_F0) >> 4) | (hi_02 & mask_F0));
|
||||
b->ql[i*64 + 32 + 2*j] = convert_uchar((lo_13 & mask_0F) | ((hi_13 & mask_0F) << 4));
|
||||
b->ql[i*64 + 32 + 2*j + 1] = convert_uchar(((lo_13 & mask_F0) >> 4) | (hi_13 & mask_F0));
|
||||
}
|
||||
}
|
||||
|
||||
uint qh_base = i02 * ne00_blk * ne01 * 16 + i00 * ne01 * 16 + i01;
|
||||
uint qhv[16];
|
||||
for (int p = 0; p < 16; ++p) {
|
||||
qhv[p] = src_qh[qh_base + p * ne01];
|
||||
}
|
||||
|
||||
for (int n = 0; n < 2; ++n) {
|
||||
for (int l = 0; l < 32; ++l) {
|
||||
int u = l / 16;
|
||||
int bit_pos = (l % 16) * 2;
|
||||
uchar v0 = (uchar)((qhv[(n*4 + 0)*2 + u] >> bit_pos) & 0x03);
|
||||
uchar v1 = (uchar)((qhv[(n*4 + 1)*2 + u] >> bit_pos) & 0x03);
|
||||
uchar v2 = (uchar)((qhv[(n*4 + 2)*2 + u] >> bit_pos) & 0x03);
|
||||
uchar v3 = (uchar)((qhv[(n*4 + 3)*2 + u] >> bit_pos) & 0x03);
|
||||
b->qh[n*32 + l] = v0 | (v1 << 2) | (v2 << 4) | (v3 << 6);
|
||||
}
|
||||
}
|
||||
|
||||
__global char * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * 16 + i00 * 16;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
b->scales[i] = s_src[i];
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_mxfp4
|
||||
//------------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -0,0 +1,279 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
|
||||
|
||||
#define TILESIZE_K 16
|
||||
#define TILESIZE_M 64
|
||||
#define TILESIZE_N 32
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
|
||||
inline void get_scale_min_k4(
|
||||
int j,
|
||||
global const uchar * q,
|
||||
uchar * d,
|
||||
uchar * m
|
||||
) {
|
||||
if (j < 4) {
|
||||
*d = q[j] & 63;
|
||||
*m = q[j+4] & 63;
|
||||
} else {
|
||||
*d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2);
|
||||
*m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2);
|
||||
}
|
||||
}
|
||||
|
||||
#define dequantize_q4_k(q4, a_f16, scale, minv) \
|
||||
a_f16.s0 = (half)((float)(q4.s0 & 0x000F) * scale - minv); \
|
||||
a_f16.s1 = (half)((float)((q4.s0 & 0x00F0) >> 4) * scale - minv); \
|
||||
a_f16.s2 = (half)((float)((q4.s0 & 0x0F00) >> 8) * scale - minv); \
|
||||
a_f16.s3 = (half)((float)((q4.s0 & 0xF000) >> 12) * scale - minv); \
|
||||
a_f16.s4 = (half)((float)(q4.s1 & 0x000F) * scale - minv); \
|
||||
a_f16.s5 = (half)((float)((q4.s1 & 0x00F0) >> 4) * scale - minv); \
|
||||
a_f16.s6 = (half)((float)((q4.s1 & 0x0F00) >> 8) * scale - minv); \
|
||||
a_f16.s7 = (half)((float)((q4.s1 & 0xF000) >> 12) * scale - minv); \
|
||||
a_f16.s8 = (half)((float)(q4.s2 & 0x000F) * scale - minv); \
|
||||
a_f16.s9 = (half)((float)((q4.s2 & 0x00F0) >> 4) * scale - minv); \
|
||||
a_f16.sa = (half)((float)((q4.s2 & 0x0F00) >> 8) * scale - minv); \
|
||||
a_f16.sb = (half)((float)((q4.s2 & 0xF000) >> 12) * scale - minv); \
|
||||
a_f16.sc = (half)((float)(q4.s3 & 0x000F) * scale - minv); \
|
||||
a_f16.sd = (half)((float)((q4.s3 & 0x00F0) >> 4) * scale - minv); \
|
||||
a_f16.se = (half)((float)((q4.s3 & 0x0F00) >> 8) * scale - minv); \
|
||||
a_f16.sf = (half)((float)((q4.s3 & 0xF000) >> 12) * scale - minv); \
|
||||
|
||||
|
||||
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
|
||||
acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \
|
||||
acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \
|
||||
acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \
|
||||
acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \
|
||||
acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \
|
||||
acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \
|
||||
acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \
|
||||
acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \
|
||||
acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \
|
||||
acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \
|
||||
acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \
|
||||
acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \
|
||||
acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \
|
||||
acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \
|
||||
acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \
|
||||
acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \
|
||||
acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \
|
||||
acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \
|
||||
acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \
|
||||
acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \
|
||||
acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \
|
||||
acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \
|
||||
acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \
|
||||
acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \
|
||||
acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \
|
||||
acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \
|
||||
acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \
|
||||
acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \
|
||||
acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \
|
||||
acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \
|
||||
acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \
|
||||
acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \
|
||||
c_reg.lo += convert_float8(acc.lo); \
|
||||
c_reg.hi += convert_float8(acc.hi); \
|
||||
acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \
|
||||
acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \
|
||||
acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \
|
||||
acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \
|
||||
acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \
|
||||
acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \
|
||||
acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \
|
||||
acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \
|
||||
acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \
|
||||
acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \
|
||||
acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \
|
||||
acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \
|
||||
acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \
|
||||
acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \
|
||||
acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \
|
||||
acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \
|
||||
acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \
|
||||
acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \
|
||||
acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \
|
||||
acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \
|
||||
acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \
|
||||
acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \
|
||||
acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \
|
||||
acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \
|
||||
acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \
|
||||
acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \
|
||||
acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \
|
||||
acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \
|
||||
acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \
|
||||
acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \
|
||||
acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \
|
||||
acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \
|
||||
c_reg.lo += convert_float8(acc.lo); \
|
||||
c_reg.hi += convert_float8(acc.hi); \
|
||||
|
||||
|
||||
__attribute__((qcom_wave_pair_mode(1)))
|
||||
kernel void kernel_gemm_moe_q4_k_f32_ns(
|
||||
__read_only image1d_buffer_t src0_q,
|
||||
__global half * src0_d,
|
||||
__global half * src0_dm,
|
||||
__global uchar * src0_s,
|
||||
__read_only image1d_buffer_t src1,
|
||||
__global uint * src2,
|
||||
__global ushort * src2_emap,
|
||||
__write_only image1d_buffer_t dst,
|
||||
__global int * total_tiles,
|
||||
uint ne00,
|
||||
uint ne01
|
||||
) {
|
||||
uint block_id_m = get_global_id(1); // m_tile
|
||||
uint block_id_n = get_global_id(2); // n_tile
|
||||
|
||||
// Boundary check
|
||||
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
|
||||
return;
|
||||
}
|
||||
|
||||
__private half16 reg_a;
|
||||
__private float32 reg_c = (float32)(0);
|
||||
__local half4 shared_b[128];
|
||||
|
||||
const ushort expert_id = src2_emap[block_id_n];
|
||||
|
||||
const uint row = block_id_m * TILESIZE_M;
|
||||
const uint col = block_id_n * TILESIZE_N;
|
||||
|
||||
uint sub_block_id_m = get_local_id(0);
|
||||
uint2 b_global_offset;
|
||||
b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
|
||||
b_global_offset.y = b_global_offset.x + (16 * ne00);
|
||||
uint2 b_local_offset;
|
||||
b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
|
||||
b_local_offset.y = b_local_offset.x + 16;
|
||||
|
||||
uint num_superblocks = ne00 / QK_K;
|
||||
uint scales_per_row = num_superblocks * K_SCALE_SIZE;
|
||||
uint row_idx = row + get_global_id(0);
|
||||
|
||||
// Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16
|
||||
for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
|
||||
uint sub = step / 32;
|
||||
uint sb = sub / 8;
|
||||
uint j = sub % 8;
|
||||
|
||||
// Load d and dm for super-block
|
||||
uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0);
|
||||
half d_val = src0_d[d_offset];
|
||||
half dm_val = src0_dm[d_offset];
|
||||
|
||||
// Load sub-block scale and min
|
||||
global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE;
|
||||
uchar sv, mn;
|
||||
get_scale_min_k4(j, sc, &sv, &mn);
|
||||
|
||||
float scale = (float)d_val * (float)sv;
|
||||
float minv = (float)dm_val * (float)mn;
|
||||
|
||||
// First sub-block (16 elements)
|
||||
uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
|
||||
uint b_sub_offset = col * ne00 + step;
|
||||
|
||||
// Load 16 q (64-bits) in transposed layout
|
||||
uint2 q4x16;
|
||||
q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
|
||||
q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
|
||||
|
||||
// Load 16x32 floats from matrix B
|
||||
float8 bx8_f32;
|
||||
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
|
||||
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
|
||||
half8 bx8_f16 = convert_half8(bx8_f32);
|
||||
shared_b[b_local_offset.x] = bx8_f16.lo;
|
||||
shared_b[b_local_offset.y] = bx8_f16.hi;
|
||||
|
||||
// Dequantization
|
||||
dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv);
|
||||
|
||||
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
half16 acc;
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
|
||||
|
||||
// Second half (next 16 elements, same sub-block scale)
|
||||
uint half_step = step + TILESIZE_K;
|
||||
q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
|
||||
b_sub_offset = col * ne00 + half_step;
|
||||
|
||||
q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
|
||||
q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
|
||||
|
||||
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
|
||||
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
|
||||
bx8_f16 = convert_half8(bx8_f32);
|
||||
shared_b[b_local_offset.x] = bx8_f16.lo;
|
||||
shared_b[b_local_offset.y] = bx8_f16.hi;
|
||||
|
||||
dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv);
|
||||
|
||||
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
|
||||
}
|
||||
|
||||
// Load post router and share in LM
|
||||
__local uint out_idx[TILESIZE_N];
|
||||
|
||||
if (get_local_id(0) < TILESIZE_N) {
|
||||
uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
|
||||
if (idx == 0xFFFFFFFF) {
|
||||
idx = src2[block_id_n * TILESIZE_N + 0];
|
||||
}
|
||||
out_idx[get_local_id(0)] = idx * ne01;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Scatter results back to original position in output grid
|
||||
uint m_offset = row + get_local_id(0);
|
||||
|
||||
write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
|
||||
write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
|
||||
write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
|
||||
write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
|
||||
write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
|
||||
write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
|
||||
write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
|
||||
write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
|
||||
write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
|
||||
write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
|
||||
write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
|
||||
write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
|
||||
write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
|
||||
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
|
||||
write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
|
||||
write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
|
||||
write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
|
||||
write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
|
||||
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
|
||||
write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
|
||||
write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
|
||||
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
|
||||
write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
|
||||
write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
|
||||
write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
|
||||
write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
|
||||
write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
|
||||
write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
|
||||
write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
|
||||
write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
|
||||
write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));
|
||||
|
||||
// Store zero padding parts to the index of first output in tile
|
||||
barrier(CLK_GLOBAL_MEM_FENCE);
|
||||
write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
|
||||
}
|
||||
|
|
@ -0,0 +1,284 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
|
||||
|
||||
#define TILESIZE_K 16
|
||||
#define TILESIZE_M 64
|
||||
#define TILESIZE_N 32
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
|
||||
inline void get_scale_min_k4(
|
||||
int j,
|
||||
global const uchar * q,
|
||||
uchar * d,
|
||||
uchar * m
|
||||
) {
|
||||
if (j < 4) {
|
||||
*d = q[j] & 63;
|
||||
*m = q[j+4] & 63;
|
||||
} else {
|
||||
*d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2);
|
||||
*m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2);
|
||||
}
|
||||
}
|
||||
|
||||
#define dequantize_q5_k(qs5x16, qh5x16, a_f16, scale, m) \
|
||||
a_f16.s0 = (half)((float)(( qs5x16.s0 & 0x000F) | (( qh5x16.s0 & 0x01) << 4)) * scale + m); \
|
||||
a_f16.s1 = (half)((float)((((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.s2 = (half)((float)((((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.s3 = (half)((float)((((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.s4 = (half)((float)((( qs5x16.s1 & 0x000F) | (((qh5x16.s0 >> 4) & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.s5 = (half)((float)((((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.s6 = (half)((float)(((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) * scale + m); \
|
||||
a_f16.s7 = (half)((float)((((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.s8 = (half)((float)((( qs5x16.s2 & 0x000F) | (( qh5x16.s1 & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.s9 = (half)((float)((((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.sa = (half)((float)((((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.sb = (half)((float)((((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.sc = (half)((float)((( qs5x16.s3 & 0x000F) | (((qh5x16.s1 >> 4) & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.sd = (half)((float)((((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.se = (half)((float)((((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) * scale + m)); \
|
||||
a_f16.sf = (half)((float)((((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) * scale + m)); \
|
||||
|
||||
|
||||
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
|
||||
acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \
|
||||
acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \
|
||||
acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \
|
||||
acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \
|
||||
acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \
|
||||
acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \
|
||||
acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \
|
||||
acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \
|
||||
acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \
|
||||
acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \
|
||||
acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \
|
||||
acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \
|
||||
acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \
|
||||
acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \
|
||||
acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \
|
||||
acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \
|
||||
acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \
|
||||
acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \
|
||||
acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \
|
||||
acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \
|
||||
acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \
|
||||
acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \
|
||||
acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \
|
||||
acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \
|
||||
acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \
|
||||
acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \
|
||||
acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \
|
||||
acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \
|
||||
acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \
|
||||
acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \
|
||||
acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \
|
||||
acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \
|
||||
c_reg.lo += convert_float8(acc.lo); \
|
||||
c_reg.hi += convert_float8(acc.hi); \
|
||||
acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \
|
||||
acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \
|
||||
acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \
|
||||
acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \
|
||||
acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \
|
||||
acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \
|
||||
acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \
|
||||
acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \
|
||||
acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \
|
||||
acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \
|
||||
acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \
|
||||
acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \
|
||||
acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \
|
||||
acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \
|
||||
acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \
|
||||
acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \
|
||||
acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \
|
||||
acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \
|
||||
acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \
|
||||
acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \
|
||||
acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \
|
||||
acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \
|
||||
acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \
|
||||
acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \
|
||||
acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \
|
||||
acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \
|
||||
acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \
|
||||
acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \
|
||||
acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \
|
||||
acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \
|
||||
acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \
|
||||
acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \
|
||||
c_reg.lo += convert_float8(acc.lo); \
|
||||
c_reg.hi += convert_float8(acc.hi); \
|
||||
|
||||
|
||||
__attribute__((qcom_wave_pair_mode(1)))
|
||||
kernel void kernel_gemm_moe_q5_k_f32_ns(
|
||||
__read_only image1d_buffer_t src0_q,
|
||||
__global uint * src0_qh,
|
||||
__global uchar * src0_s,
|
||||
__global half * src0_d,
|
||||
__global half * src0_dm,
|
||||
__read_only image1d_buffer_t src1,
|
||||
__global uint * src2,
|
||||
__global ushort * src2_emap,
|
||||
__write_only image1d_buffer_t dst,
|
||||
__global int * total_tiles,
|
||||
uint ne00,
|
||||
uint ne01
|
||||
) {
|
||||
uint block_id_m = get_global_id(1); // m_tile
|
||||
uint block_id_n = get_global_id(2); // n_tile
|
||||
|
||||
// Boundary check
|
||||
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
|
||||
return;
|
||||
}
|
||||
|
||||
__private half16 reg_a;
|
||||
__private float32 reg_c = (float32)(0);
|
||||
__local half4 shared_b[128];
|
||||
|
||||
const ushort expert_id = src2_emap[block_id_n];
|
||||
|
||||
const uint row = block_id_m * TILESIZE_M;
|
||||
const uint col = block_id_n * TILESIZE_N;
|
||||
|
||||
uint sub_block_id_m = get_local_id(0);
|
||||
uint2 b_global_offset;
|
||||
b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
|
||||
b_global_offset.y = b_global_offset.x + (16 * ne00);
|
||||
uint2 b_local_offset;
|
||||
b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
|
||||
b_local_offset.y = b_local_offset.x + 16;
|
||||
|
||||
uint num_superblocks = ne00 / QK_K;
|
||||
uint scales_per_row = num_superblocks * K_SCALE_SIZE;
|
||||
uint row_idx = row + get_global_id(0);
|
||||
|
||||
// Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16
|
||||
for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
|
||||
uint sub = step / 32;
|
||||
uint sb = sub / 8;
|
||||
uint j = sub % 8;
|
||||
|
||||
// Load d and dm for super-block
|
||||
uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0);
|
||||
half d_val = src0_d[d_offset];
|
||||
half dm_val = src0_dm[d_offset];
|
||||
|
||||
// Load sub-block scale and min
|
||||
global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE;
|
||||
uchar sv, mn;
|
||||
get_scale_min_k4(j, sc, &sv, &mn);
|
||||
|
||||
float scale = (float)d_val * (float)sv;
|
||||
float minv = -(float)dm_val * (float)mn;
|
||||
|
||||
// qh is stored at sub-block granularity
|
||||
uint qh_offset = row + sub * ne01 + expert_id * num_superblocks * 8 * ne01 + get_global_id(0);
|
||||
uchar4 qhx32 = as_uchar4(src0_qh[qh_offset]);
|
||||
|
||||
// First sub-block (16 elements)
|
||||
uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
|
||||
uint b_sub_offset = col * ne00 + step;
|
||||
|
||||
// Load 16 q (64-bits) in transposed layout
|
||||
uint2 q4x16;
|
||||
q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
|
||||
q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
|
||||
|
||||
// Load 16x32 floats from matrix B
|
||||
float8 bx8_f32;
|
||||
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
|
||||
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
|
||||
half8 bx8_f16 = convert_half8(bx8_f32);
|
||||
shared_b[b_local_offset.x] = bx8_f16.lo;
|
||||
shared_b[b_local_offset.y] = bx8_f16.hi;
|
||||
|
||||
// Dequantization
|
||||
dequantize_q5_k(as_ushort4(q4x16), qhx32.lo, reg_a, scale, minv);
|
||||
|
||||
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
half16 acc;
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
|
||||
|
||||
// Second half
|
||||
uint half_step = step + TILESIZE_K;
|
||||
q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
|
||||
b_sub_offset = col * ne00 + half_step;
|
||||
|
||||
q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
|
||||
q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
|
||||
|
||||
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
|
||||
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
|
||||
bx8_f16 = convert_half8(bx8_f32);
|
||||
shared_b[b_local_offset.x] = bx8_f16.lo;
|
||||
shared_b[b_local_offset.y] = bx8_f16.hi;
|
||||
|
||||
dequantize_q5_k(as_ushort4(q4x16), qhx32.hi, reg_a, scale, minv);
|
||||
|
||||
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
|
||||
}
|
||||
|
||||
// Load post router and share in LM
|
||||
__local uint out_idx[TILESIZE_N];
|
||||
|
||||
if (get_local_id(0) < TILESIZE_N) {
|
||||
uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
|
||||
if (idx == 0xFFFFFFFF) {
|
||||
idx = src2[block_id_n * TILESIZE_N + 0];
|
||||
}
|
||||
out_idx[get_local_id(0)] = idx * ne01;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Scatter results back to original position in output grid
|
||||
uint m_offset = row + get_local_id(0);
|
||||
|
||||
write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
|
||||
write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
|
||||
write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
|
||||
write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
|
||||
write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
|
||||
write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
|
||||
write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
|
||||
write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
|
||||
write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
|
||||
write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
|
||||
write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
|
||||
write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
|
||||
write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
|
||||
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
|
||||
write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
|
||||
write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
|
||||
write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
|
||||
write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
|
||||
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
|
||||
write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
|
||||
write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
|
||||
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
|
||||
write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
|
||||
write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
|
||||
write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
|
||||
write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
|
||||
write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
|
||||
write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
|
||||
write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
|
||||
write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
|
||||
write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));
|
||||
|
||||
// Store zero padding parts to the index of first output in tile
|
||||
barrier(CLK_GLOBAL_MEM_FENCE);
|
||||
write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
|
||||
}
|
||||
|
|
@ -0,0 +1,263 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
|
||||
|
||||
#define TILESIZE_K 16
|
||||
#define TILESIZE_M 64
|
||||
#define TILESIZE_N 32
|
||||
#define QK_K 256
|
||||
|
||||
#define dequantize_q6_k(qs16, qh16, a_f16, scale) \
|
||||
a_f16.s0 = (half)(((float)(( qs16.s0 & 0x000F) | ((uint)(( qh16 ) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.s1 = (half)(((float)((( qs16.s0 >> 4) & 0x000F) | ((uint)(( qh16 >> 2) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.s2 = (half)(((float)((( qs16.s0 >> 8) & 0x000F) | ((uint)(( qh16 >> 4) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.s3 = (half)(((float)((( qs16.s0 >>12) & 0x000F) | ((uint)(( qh16 >> 6) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.s4 = (half)(((float)(( qs16.s1 & 0x000F) | ((uint)(( qh16 >> 8) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.s5 = (half)(((float)((( qs16.s1 >> 4) & 0x000F) | ((uint)(( qh16 >> 10) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.s6 = (half)(((float)((( qs16.s1 >> 8) & 0x000F) | ((uint)(( qh16 >> 12) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.s7 = (half)(((float)((( qs16.s1 >>12) & 0x000F) | ((uint)(( qh16 >> 14) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.s8 = (half)(((float)(( qs16.s2 & 0x000F) | ((uint)(( qh16 >> 16) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.s9 = (half)(((float)((( qs16.s2 >> 4) & 0x000F) | ((uint)(( qh16 >> 18) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.sa = (half)(((float)((( qs16.s2 >> 8) & 0x000F) | ((uint)(( qh16 >> 20) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.sb = (half)(((float)((( qs16.s2 >>12) & 0x000F) | ((uint)(( qh16 >> 22) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.sc = (half)(((float)(( qs16.s3 & 0x000F) | ((uint)(( qh16 >> 24) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.sd = (half)(((float)((( qs16.s3 >> 4) & 0x000F) | ((uint)(( qh16 >> 26) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.se = (half)(((float)((( qs16.s3 >> 8) & 0x000F) | ((uint)(( qh16 >> 28) & 0x3) << 4)) - 32.f) * scale); \
|
||||
a_f16.sf = (half)(((float)((( qs16.s3 >>12) & 0x000F) | ((uint)(( qh16 >> 30) & 0x3) << 4)) - 32.f) * scale); \
|
||||
|
||||
|
||||
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
|
||||
acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \
|
||||
acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \
|
||||
acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \
|
||||
acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \
|
||||
acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \
|
||||
acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \
|
||||
acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \
|
||||
acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \
|
||||
acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \
|
||||
acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \
|
||||
acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \
|
||||
acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \
|
||||
acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \
|
||||
acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \
|
||||
acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \
|
||||
acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \
|
||||
acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \
|
||||
acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \
|
||||
acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \
|
||||
acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \
|
||||
acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \
|
||||
acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \
|
||||
acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \
|
||||
acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \
|
||||
acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \
|
||||
acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \
|
||||
acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \
|
||||
acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \
|
||||
acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \
|
||||
acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \
|
||||
acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \
|
||||
acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \
|
||||
c_reg.lo += convert_float8(acc.lo); \
|
||||
c_reg.hi += convert_float8(acc.hi); \
|
||||
acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \
|
||||
acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \
|
||||
acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \
|
||||
acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \
|
||||
acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \
|
||||
acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \
|
||||
acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \
|
||||
acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \
|
||||
acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \
|
||||
acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \
|
||||
acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \
|
||||
acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \
|
||||
acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \
|
||||
acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \
|
||||
acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \
|
||||
acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \
|
||||
acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \
|
||||
acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \
|
||||
acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \
|
||||
acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \
|
||||
acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \
|
||||
acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \
|
||||
acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \
|
||||
acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \
|
||||
acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \
|
||||
acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \
|
||||
acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \
|
||||
acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \
|
||||
acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \
|
||||
acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \
|
||||
acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \
|
||||
acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \
|
||||
c_reg.lo += convert_float8(acc.lo); \
|
||||
c_reg.hi += convert_float8(acc.hi); \
|
||||
|
||||
|
||||
__attribute__((qcom_wave_pair_mode(1)))
|
||||
kernel void kernel_gemm_moe_q6_k_f32_ns(
|
||||
__read_only image1d_buffer_t src0_ql,
|
||||
__global uint * src0_qh,
|
||||
__global char * src0_s,
|
||||
__global half * src0_d,
|
||||
__read_only image1d_buffer_t src1,
|
||||
__global uint * src2,
|
||||
__global ushort * src2_emap,
|
||||
__write_only image1d_buffer_t dst,
|
||||
__global int * total_tiles,
|
||||
uint ne00,
|
||||
uint ne01
|
||||
) {
|
||||
uint block_id_m = get_global_id(1); // m_tile
|
||||
uint block_id_n = get_global_id(2); // n_tile
|
||||
|
||||
// Boundary check
|
||||
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
|
||||
return;
|
||||
}
|
||||
|
||||
__private half16 reg_a;
|
||||
__private float32 reg_c = (float32)(0);
|
||||
__local half4 shared_b[128];
|
||||
|
||||
const ushort expert_id = src2_emap[block_id_n];
|
||||
|
||||
const uint row = block_id_m * TILESIZE_M;
|
||||
const uint col = block_id_n * TILESIZE_N;
|
||||
|
||||
uint sub_block_id_m = get_local_id(0);
|
||||
uint2 b_global_offset;
|
||||
b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
|
||||
b_global_offset.y = b_global_offset.x + (16 * ne00);
|
||||
uint2 b_local_offset;
|
||||
b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
|
||||
b_local_offset.y = b_local_offset.x + 16;
|
||||
|
||||
uint num_superblocks = ne00 / QK_K;
|
||||
uint scales_per_row = num_superblocks * 16;
|
||||
uint row_idx = row + get_global_id(0);
|
||||
|
||||
// Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16
|
||||
for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
|
||||
uint sub = step / 32; // 32-element group index
|
||||
uint sb = sub / 8; // super-block index
|
||||
uint j = sub % 8; // group within super-block
|
||||
|
||||
// Load d for super-block
|
||||
uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0);
|
||||
half d_val = src0_d[d_offset];
|
||||
|
||||
// Load sub-block scales
|
||||
global const char * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * 16;
|
||||
float scale0 = (float)d_val * (float)sc[j * 2];
|
||||
float scale1 = (float)d_val * (float)sc[j * 2 + 1];
|
||||
|
||||
uint qh_base = row + (sub * 2) * ne01 + expert_id * (num_superblocks * 16) * ne01 + get_global_id(0);
|
||||
uint qh_first16 = src0_qh[qh_base];
|
||||
uint qh_second16 = src0_qh[qh_base + ne01];
|
||||
|
||||
// First half (16 elements)
|
||||
uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
|
||||
uint b_sub_offset = col * ne00 + step;
|
||||
|
||||
// Load 16 ql nibbles (2 uints) from image
|
||||
uint2 q4x16;
|
||||
q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x;
|
||||
q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x;
|
||||
|
||||
// Load 16x32 floats from matrix B
|
||||
float8 bx8_f32;
|
||||
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
|
||||
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
|
||||
half8 bx8_f16 = convert_half8(bx8_f32);
|
||||
shared_b[b_local_offset.x] = bx8_f16.lo;
|
||||
shared_b[b_local_offset.y] = bx8_f16.hi;
|
||||
|
||||
// Dequantize first 16 elements (scale0)
|
||||
dequantize_q6_k(as_ushort4(q4x16), qh_first16, reg_a, scale0);
|
||||
|
||||
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
half16 acc;
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
|
||||
|
||||
// Second half
|
||||
uint half_step = step + TILESIZE_K;
|
||||
q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
|
||||
b_sub_offset = col * ne00 + half_step;
|
||||
|
||||
q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x;
|
||||
q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x;
|
||||
|
||||
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
|
||||
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
|
||||
bx8_f16 = convert_half8(bx8_f32);
|
||||
shared_b[b_local_offset.x] = bx8_f16.lo;
|
||||
shared_b[b_local_offset.y] = bx8_f16.hi;
|
||||
|
||||
dequantize_q6_k(as_ushort4(q4x16), qh_second16, reg_a, scale1);
|
||||
|
||||
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
|
||||
}
|
||||
|
||||
// Load post router and share in LM
|
||||
__local uint out_idx[TILESIZE_N];
|
||||
|
||||
if (get_local_id(0) < TILESIZE_N) {
|
||||
uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
|
||||
if (idx == 0xFFFFFFFF) {
|
||||
idx = src2[block_id_n * TILESIZE_N + 0];
|
||||
}
|
||||
out_idx[get_local_id(0)] = idx * ne01;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Scatter results back to original position in output grid
|
||||
uint m_offset = row + get_local_id(0);
|
||||
|
||||
write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
|
||||
write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
|
||||
write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
|
||||
write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
|
||||
write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
|
||||
write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
|
||||
write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
|
||||
write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
|
||||
write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
|
||||
write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
|
||||
write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
|
||||
write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
|
||||
write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
|
||||
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
|
||||
write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
|
||||
write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
|
||||
write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
|
||||
write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
|
||||
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
|
||||
write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
|
||||
write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
|
||||
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
|
||||
write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
|
||||
write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
|
||||
write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
|
||||
write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
|
||||
write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
|
||||
write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
|
||||
write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
|
||||
write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
|
||||
write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));
|
||||
|
||||
// Store zero padding parts to the index of first output in tile
|
||||
barrier(CLK_GLOBAL_MEM_FENCE);
|
||||
write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
|
||||
}
|
||||
|
|
@ -0,0 +1,151 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
#define N_SIMDGROUP 4
|
||||
#define SIMDGROUP_WIDTH 64
|
||||
|
||||
inline void get_scale_min_k4(
|
||||
int j,
|
||||
global const uchar * q,
|
||||
uchar * d,
|
||||
uchar * m
|
||||
) {
|
||||
if (j < 4) {
|
||||
*d = q[j] & 63;
|
||||
*m = q[j+4] & 63;
|
||||
} else {
|
||||
*d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2);
|
||||
*m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2);
|
||||
}
|
||||
}
|
||||
|
||||
static inline float8 q4_k_to_fp32_packed8(ushort2 q4x8, float scale, float minv) {
|
||||
float8 fp32x8;
|
||||
fp32x8.s0 = (q4x8.s0 & 0x000F) * scale - minv;
|
||||
fp32x8.s1 = ((q4x8.s0 & 0x00F0) >> 4) * scale - minv;
|
||||
fp32x8.s2 = ((q4x8.s0 & 0x0F00) >> 8) * scale - minv;
|
||||
fp32x8.s3 = ((q4x8.s0 & 0xF000) >> 12) * scale - minv;
|
||||
fp32x8.s4 = (q4x8.s1 & 0x000F) * scale - minv;
|
||||
fp32x8.s5 = ((q4x8.s1 & 0x00F0) >> 4) * scale - minv;
|
||||
fp32x8.s6 = ((q4x8.s1 & 0x0F00) >> 8) * scale - minv;
|
||||
fp32x8.s7 = ((q4x8.s1 & 0xF000) >> 12) * scale - minv;
|
||||
return fp32x8;
|
||||
}
|
||||
|
||||
__attribute__((qcom_reqd_sub_group_size("half")))
|
||||
__kernel void kernel_gemv_moe_q4_k_f32_ns(
|
||||
__global uint * src0_q,
|
||||
__global half * src0_d,
|
||||
__global half * src0_dm,
|
||||
__global uchar * src0_s,
|
||||
__read_only image1d_buffer_t src1,
|
||||
__global uint * src2,
|
||||
__global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne11
|
||||
) {
|
||||
uint i01 = get_global_id(0);
|
||||
uint i20 = get_global_id(2);
|
||||
uint sgid = get_local_id(1);
|
||||
uint slid = get_sub_group_local_id();
|
||||
|
||||
uint i11 = i20 % ne11;
|
||||
|
||||
uint expert_id = src2[i20];
|
||||
|
||||
int num_superblocks = ne00 / QK_K;
|
||||
int num_subblocks = ne00 / 32;
|
||||
int scales_per_row = num_superblocks * K_SCALE_SIZE;
|
||||
|
||||
// Expert offsets in the transposed noshuffle layout
|
||||
uint expert_q_offset = expert_id * (ne00 / 8) * ne01;
|
||||
uint expert_d_offset = expert_id * num_superblocks * ne01;
|
||||
|
||||
__private float sum = 0.0f;
|
||||
|
||||
// Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter
|
||||
for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) {
|
||||
uint sb = ib / 8;
|
||||
uint j = ib % 8;
|
||||
|
||||
// Load d and dmin for this super-block
|
||||
half d_val = src0_d[expert_d_offset + sb * ne01 + i01];
|
||||
half dm_val = src0_dm[expert_d_offset + sb * ne01 + i01];
|
||||
|
||||
// Load sub-block scale and min
|
||||
global const uchar * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * K_SCALE_SIZE;
|
||||
uchar sv, mn;
|
||||
get_scale_min_k4(j, sc, &sv, &mn);
|
||||
|
||||
float scale = (float)d_val * (float)sv;
|
||||
float minv = (float)dm_val * (float)mn;
|
||||
|
||||
// Load 4 uints of quants (32 nibbles = 32 elements)
|
||||
uint q_base = expert_q_offset + ib * ne01 * 4 + i01;
|
||||
|
||||
uint4 regQ;
|
||||
regQ.s0 = src0_q[q_base];
|
||||
regQ.s1 = src0_q[q_base + ne01];
|
||||
regQ.s2 = src0_q[q_base + ne01 * 2];
|
||||
regQ.s3 = src0_q[q_base + ne01 * 3];
|
||||
|
||||
// Load activations: 32 floats = 8 float4s
|
||||
uint y_offset = i11 * ne00 / 4 + ib * 8;
|
||||
|
||||
float8 fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s0), scale, minv);
|
||||
|
||||
float4 shared_y4;
|
||||
shared_y4 = read_imagef(src1, (y_offset + 0));
|
||||
float4 acc = shared_y4 * fp32x8.lo;
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 1));
|
||||
acc += shared_y4 * fp32x8.hi;
|
||||
|
||||
fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s1), scale, minv);
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 2));
|
||||
acc += shared_y4 * fp32x8.lo;
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 3));
|
||||
acc += shared_y4 * fp32x8.hi;
|
||||
|
||||
fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s2), scale, minv);
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 4));
|
||||
acc += shared_y4 * fp32x8.lo;
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 5));
|
||||
acc += shared_y4 * fp32x8.hi;
|
||||
|
||||
fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s3), scale, minv);
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 6));
|
||||
acc += shared_y4 * fp32x8.lo;
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 7));
|
||||
acc += shared_y4 * fp32x8.hi;
|
||||
|
||||
sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #subgroups=4
|
||||
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
|
||||
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
|
||||
if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
|
||||
if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
|
||||
|
||||
// 1 output per thread in subgroup 0
|
||||
if (sgid == 0) {
|
||||
dst = dst + (offsetd >> 2);
|
||||
dst[i01 + i20 * ne01] = sum;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,156 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
#define N_SIMDGROUP 4
|
||||
#define SIMDGROUP_WIDTH 64
|
||||
|
||||
inline void get_scale_min_k4(
|
||||
int j,
|
||||
global const uchar * q,
|
||||
uchar * d,
|
||||
uchar * m
|
||||
) {
|
||||
if (j < 4) {
|
||||
*d = q[j] & 63;
|
||||
*m = q[j+4] & 63;
|
||||
} else {
|
||||
*d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2);
|
||||
*m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2);
|
||||
}
|
||||
}
|
||||
|
||||
static inline float8 q5_k_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8, half s, half m) {
|
||||
float8 fp32x8;
|
||||
fp32x8.s0 = (float)((( qs5x8.s0 & 0x000F) | (( qh5x8 & 0x01) << 4)) * s + m);
|
||||
fp32x8.s1 = (float)((((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) * s + m);
|
||||
fp32x8.s2 = (float)((((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) * s + m);
|
||||
fp32x8.s3 = (float)((((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) * s + m);
|
||||
fp32x8.s4 = (float)((( qs5x8.s1 & 0x000F) | (((qh5x8 >> 4) & 0x01) << 4)) * s + m);
|
||||
fp32x8.s5 = (float)((((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) * s + m);
|
||||
fp32x8.s6 = (float)((((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) * s + m);
|
||||
fp32x8.s7 = (float)((((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) * s + m);
|
||||
return fp32x8;
|
||||
}
|
||||
|
||||
__attribute__((qcom_reqd_sub_group_size("half")))
|
||||
__kernel void kernel_gemv_moe_q5_k_f32_ns(
|
||||
__global uint * src0_q,
|
||||
__global uint * src0_qh,
|
||||
__global half * src0_d,
|
||||
__global half * src0_dm,
|
||||
__global uchar * src0_s,
|
||||
__read_only image1d_buffer_t src1,
|
||||
__global uint * src2,
|
||||
__global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne11
|
||||
) {
|
||||
uint i01 = get_global_id(0);
|
||||
uint i20 = get_global_id(2);
|
||||
uint sgid = get_local_id(1);
|
||||
uint slid = get_sub_group_local_id();
|
||||
|
||||
uint i11 = i20 % ne11;
|
||||
|
||||
uint expert_id = src2[i20];
|
||||
|
||||
int num_superblocks = ne00 / QK_K;
|
||||
int num_subblocks = ne00 / 32;
|
||||
int scales_per_row = num_superblocks * K_SCALE_SIZE;
|
||||
|
||||
// Expert offsets in the transposed noshuffle layout
|
||||
uint expert_q_offset = expert_id * (ne00 / 8) * ne01;
|
||||
uint expert_d_offset = expert_id * num_superblocks * ne01;
|
||||
|
||||
__private float sum = 0.0f;
|
||||
|
||||
// Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter
|
||||
for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) {
|
||||
uint sb = ib / 8;
|
||||
uint j = ib % 8;
|
||||
|
||||
// Load d and dmin for this super-block
|
||||
half d_val = src0_d[expert_d_offset + sb * ne01 + i01];
|
||||
half dm_val = src0_dm[expert_d_offset + sb * ne01 + i01];
|
||||
|
||||
// sub_block index = sb * 8 + j
|
||||
uint expert_qh_offset = expert_id * num_superblocks * 8 * ne01;
|
||||
uchar4 regQh = as_uchar4(src0_qh[expert_qh_offset + (sb * 8 + j) * ne01 + i01]);
|
||||
|
||||
// Load sub-block scale and min
|
||||
global const uchar * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * K_SCALE_SIZE;
|
||||
uchar sv, mn;
|
||||
get_scale_min_k4(j, sc, &sv, &mn);
|
||||
|
||||
float scale = (float)d_val * (float)sv;
|
||||
float minv = -(float)dm_val * (float)mn;
|
||||
|
||||
// Load 4 uints of quants (32 nibbles = 32 elements)
|
||||
uint q_base = expert_q_offset + ib * ne01 * 4 + i01;
|
||||
|
||||
uint4 regQ;
|
||||
regQ.s0 = src0_q[q_base];
|
||||
regQ.s1 = src0_q[q_base + ne01];
|
||||
regQ.s2 = src0_q[q_base + ne01 * 2];
|
||||
regQ.s3 = src0_q[q_base + ne01 * 3];
|
||||
|
||||
// Load activations: 32 floats = 8 float4s
|
||||
uint y_offset = i11 * ne00 / 4 + ib * 8;
|
||||
|
||||
float8 fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s0), regQh.s0, scale, minv);
|
||||
|
||||
float4 shared_y4;
|
||||
shared_y4 = read_imagef(src1, (y_offset + 0));
|
||||
float4 acc = shared_y4 * fp32x8.lo;
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 1));
|
||||
acc += shared_y4 * fp32x8.hi;
|
||||
|
||||
fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s1), regQh.s1, scale, minv);
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 2));
|
||||
acc += shared_y4 * fp32x8.lo;
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 3));
|
||||
acc += shared_y4 * fp32x8.hi;
|
||||
|
||||
fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s2), regQh.s2, scale, minv);
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 4));
|
||||
acc += shared_y4 * fp32x8.lo;
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 5));
|
||||
acc += shared_y4 * fp32x8.hi;
|
||||
|
||||
fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s3), regQh.s3, scale, minv);
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 6));
|
||||
acc += shared_y4 * fp32x8.lo;
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 7));
|
||||
acc += shared_y4 * fp32x8.hi;
|
||||
|
||||
sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #subgroups=4
|
||||
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
|
||||
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
|
||||
if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
|
||||
if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
|
||||
|
||||
// 1 output per thread in subgroup 0
|
||||
if (sgid == 0) {
|
||||
dst = dst + (offsetd >> 2);
|
||||
dst[i01 + i20 * ne01] = sum;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
|
||||
#define QK_K 256
|
||||
#define N_SIMDGROUP 4
|
||||
#define SIMDGROUP_WIDTH 64
|
||||
|
||||
static inline float8 q6_k_to_fp32_packed8(ushort2 ql8, ushort qh8, float d_scale) {
|
||||
float8 fp32x8;
|
||||
fp32x8.s0 = ((float)(( ql8.s0 & 0x000F) | ((uint)((qh8 ) & 0x3) << 4)) - 32.f) * d_scale;
|
||||
fp32x8.s1 = ((float)((( ql8.s0 >> 4) & 0x000F) | ((uint)((qh8 >> 2) & 0x3) << 4)) - 32.f) * d_scale;
|
||||
fp32x8.s2 = ((float)((( ql8.s0 >> 8) & 0x000F) | ((uint)((qh8 >> 4) & 0x3) << 4)) - 32.f) * d_scale;
|
||||
fp32x8.s3 = ((float)((( ql8.s0 >> 12)& 0x000F) | ((uint)((qh8 >> 6) & 0x3) << 4)) - 32.f) * d_scale;
|
||||
fp32x8.s4 = ((float)(( ql8.s1 & 0x000F) | ((uint)((qh8 >> 8) & 0x3) << 4)) - 32.f) * d_scale;
|
||||
fp32x8.s5 = ((float)((( ql8.s1 >> 4) & 0x000F) | ((uint)((qh8 >>10) & 0x3) << 4)) - 32.f) * d_scale;
|
||||
fp32x8.s6 = ((float)((( ql8.s1 >> 8) & 0x000F) | ((uint)((qh8 >>12) & 0x3) << 4)) - 32.f) * d_scale;
|
||||
fp32x8.s7 = ((float)((( ql8.s1 >> 12)& 0x000F) | ((uint)((qh8 >>14) & 0x3) << 4)) - 32.f) * d_scale;
|
||||
return fp32x8;
|
||||
}
|
||||
|
||||
__attribute__((qcom_reqd_sub_group_size("half")))
|
||||
__kernel void kernel_gemv_moe_q6_k_f32_ns(
|
||||
__global uint * src0_ql,
|
||||
__global uint * src0_qh,
|
||||
__global char * src0_s,
|
||||
__global half * src0_d,
|
||||
__read_only image1d_buffer_t src1,
|
||||
__global uint * src2,
|
||||
__global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne11
|
||||
) {
|
||||
uint i01 = get_global_id(0);
|
||||
uint i20 = get_global_id(2);
|
||||
uint sgid = get_local_id(1);
|
||||
uint slid = get_sub_group_local_id();
|
||||
|
||||
uint i11 = i20 % ne11;
|
||||
|
||||
uint expert_id = src2[i20];
|
||||
|
||||
int num_superblocks = ne00 / QK_K;
|
||||
int num_subblocks = ne00 / 32; // 8 sub-blocks of 32 per super-block
|
||||
int scales_per_row = num_superblocks * 16;
|
||||
|
||||
// Expert offsets in the transposed noshuffle layout
|
||||
uint expert_ql_offset = expert_id * (ne00 / 8) * ne01; // 32 uints per super-block
|
||||
uint expert_qh_offset = expert_id * (ne00 / 16) * ne01; // 16 uints per super-block
|
||||
uint expert_d_offset = expert_id * num_superblocks * ne01;
|
||||
|
||||
__private float sum = 0.0f;
|
||||
|
||||
// Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter
|
||||
for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) {
|
||||
uint sb = ib / 8; // super-block index
|
||||
uint j = ib % 8; // 32-element group within super-block
|
||||
|
||||
// Load d for this super-block
|
||||
half d_val = src0_d[expert_d_offset + sb * ne01 + i01];
|
||||
|
||||
// Load 2 sub-block scales
|
||||
global const char * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * 16;
|
||||
float scale0 = (float)d_val * (float)sc[j * 2];
|
||||
float scale1 = (float)d_val * (float)sc[j * 2 + 1];
|
||||
|
||||
// Load 4 uints of ql
|
||||
uint ql_base = expert_ql_offset + (ib * 4) * ne01 + i01;
|
||||
uint4 regQL;
|
||||
regQL.s0 = src0_ql[ql_base];
|
||||
regQL.s1 = src0_ql[ql_base + ne01];
|
||||
regQL.s2 = src0_ql[ql_base + ne01 * 2];
|
||||
regQL.s3 = src0_ql[ql_base + ne01 * 3];
|
||||
|
||||
// Load 2 uints of qh
|
||||
uint qh_base = expert_qh_offset + (ib * 2) * ne01 + i01;
|
||||
uint2 regQH;
|
||||
regQH.s0 = src0_qh[qh_base];
|
||||
regQH.s1 = src0_qh[qh_base + ne01];
|
||||
|
||||
// Load activations: 32 floats = 8 float4s
|
||||
uint y_offset = i11 * ne00 / 4 + ib * 8;
|
||||
|
||||
float8 fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s0), (ushort)(regQH.s0 & 0xFFFF), scale0);
|
||||
|
||||
float4 shared_y4;
|
||||
shared_y4 = read_imagef(src1, (y_offset + 0));
|
||||
float4 acc = shared_y4 * fp32x8.lo;
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 1));
|
||||
acc += shared_y4 * fp32x8.hi;
|
||||
|
||||
fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s1), (ushort)(regQH.s0 >> 16), scale0);
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 2));
|
||||
acc += shared_y4 * fp32x8.lo;
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 3));
|
||||
acc += shared_y4 * fp32x8.hi;
|
||||
|
||||
fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s2), (ushort)(regQH.s1 & 0xFFFF), scale1);
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 4));
|
||||
acc += shared_y4 * fp32x8.lo;
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 5));
|
||||
acc += shared_y4 * fp32x8.hi;
|
||||
|
||||
fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s3), (ushort)(regQH.s1 >> 16), scale1);
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 6));
|
||||
acc += shared_y4 * fp32x8.lo;
|
||||
|
||||
shared_y4 = read_imagef(src1, (y_offset + 7));
|
||||
acc += shared_y4 * fp32x8.hi;
|
||||
|
||||
sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #subgroups=4
|
||||
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
|
||||
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
|
||||
if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
|
||||
if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
|
||||
|
||||
// 1 output per thread in subgroup 0
|
||||
if (sgid == 0) {
|
||||
dst = dst + (offsetd >> 2);
|
||||
dst[i01 + i20 * ne01] = sum;
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue