hexagon: add support for Q4_1 in MUL_MAT and MUL_MAT_ID (llama/23647)

* hex-mm: add support for Q4_1 matmul/matvec, hvx-only for now

* hmx-mm: add support for Q4_1

* hex-mm: use Q8_1 dynamic quantization to avoid having to compute sums in the vec_dot

* hexagon: fix repack scratch buffer overflow

* hex-mm: fix Q4_1 repack buffer sizing

* hexagon: flip the build order for mm and fa (seems to help LTO)

* hex-mm: add vec_dot 4x1s and minor HMX cleanup after adding Q4_1

* hex-mm: fix fp16 vec_dot fallback to 2x1 and another issue that could cause incorrect output

* hexagon: resurrect early-wake and add support for polling for op-batch completions

With Q4_1, ggml-hexagon now claims pretty much the entire graph, which gives the CPU more time to chillax.
This is a good thing! But it does add extra latency for the pure benchmark runs.
Early wakeup helps recover the latency a bit in normal runs, and op-batch polling is just for benchmarking.

---------

Co-authored-by: Todor Boinovski <todorb@qti.qualcomm.com>
This commit is contained in:
Max Krasnyansky 2026-05-27 10:46:11 -07:00 committed by Georgi Gerganov
parent a52bd385d6
commit 3bbe93378c
6 changed files with 1876 additions and 72 deletions

View File

@ -68,6 +68,7 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C }
static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE;
static int opt_opbatch = 1024; // max number of ops in a batch
static int opt_opqueue = 16; // max number of pending batches
static int opt_oppoll = 0; // polling for batch completions
static std::regex* opt_opfilter = NULL; // regex of ops to not claim
@ -550,7 +551,7 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size)
size_t row_size = ggml_row_size(t->type, t->ne[0]);
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales)
// Ensure we don't try to read more data than is available in the source buffer 'data'
// or write more than the tensor can hold.
@ -611,7 +612,7 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
size_t row_size = ggml_row_size(t->type, t->ne[0]);
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales)
// Ensure we don't try to copy more data than the tensor actually contains.
const size_t total_tensor_size = (size_t)nrows * row_size;
@ -660,6 +661,239 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
ggml_aligned_free(buf_rp, row_size_rp);
}
// Expand the packed nibbles of one Q4_1 block into one-value-per-byte form.
//
// qs : destination array of unpacked quants; this block's values are written into
//      slot `bi`, i.e. qs[bi * QK4_1 .. bi * QK4_1 + QK4_1 - 1]
// x  : source block
// bi : slot index of this block inside qs
//
// Low nibbles go to the first half of the slot, high nibbles to the second half.
static void unpack_q4_1_quants(uint8_t * qs, const block_q4_1 * x, unsigned int bi) {
    const unsigned int half = QK4_1 / 2;

    uint8_t * lo = qs + bi * QK4_1;
    uint8_t * hi = lo + half;

    for (unsigned int n = 0; n < half; n++) {
        const uint8_t packed = x->qs[n];
        lo[n] = (uint8_t) (packed & 0x0F);
        hi[n] = (uint8_t) (packed >> 4);
    }
}
// Inverse of unpack_q4_1_quants: fold slot `bi` of the one-value-per-byte array `qs`
// back into the packed nibbles of block `x`.
//
// The first half of the slot supplies the low nibbles, the second half the high nibbles.
// Only x->qs is written; the block's d / m fields are left untouched.
static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi) {
    const unsigned int half = QK4_1 / 2;

    const uint8_t * lo = qs + bi * QK4_1;
    const uint8_t * hi = lo + half;

    for (unsigned int n = 0; n < half; n++) {
        x->qs[n] = (uint8_t) (lo[n] | (hi[n] << 4));
    }
}
static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) {
static const int qk = QK_Q4_0x4x2;
const int nb = (k + qk - 1) / qk; // number of blocks (padded)
const int nloe = k % qk; // leftovers
const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes
const int qblk_size = qk / 2; // int4 = 128 bytes
const int qrow_size = k / 2; // int4 (not padded to blocks)
uint8_t * y_q = y + 0; // quants first
uint8_t * y_d = y + qrow_size; // then scales/offsets
// Repack the quants
for (int i = 0; i < nb; i++) {
uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
unpack_q4_1_quants(qs, &x[i * 8 + 0], 0);
unpack_q4_1_quants(qs, &x[i * 8 + 1], 1);
unpack_q4_1_quants(qs, &x[i * 8 + 2], 2);
unpack_q4_1_quants(qs, &x[i * 8 + 3], 3);
unpack_q4_1_quants(qs, &x[i * 8 + 4], 4);
unpack_q4_1_quants(qs, &x[i * 8 + 5], 5);
unpack_q4_1_quants(qs, &x[i * 8 + 6], 6);
unpack_q4_1_quants(qs, &x[i * 8 + 7], 7);
bool partial = (nloe && i == nb-1);
uint8_t * q = y_q + (i * qblk_size);
for (int j = 0; j < qk / 2; j++) {
q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000];
}
}
// Repack the scales and offsets
for (int i = 0; i < nb; i++) {
ggml_half * d_m = (ggml_half *) (y_d + i * dblk_size);
for (int j = 0; j < 8; j++) {
d_m[j * 2 + 0] = x[i * 8 + j].d;
d_m[j * 2 + 1] = x[i * 8 + j].m;
}
}
}
// Convert one row from the x4x2 layout back into block_q4_1 blocks.
//
// x : destination blocks; 8 blocks are written per super-block of QK_Q4_0x4x2 elements,
//     including the padded last super-block, so the caller must provide a padded row
// y : source row; packed quants first, (d, m) pairs starting at byte offset k / 2
// k : number of elements in the row (not padded)
//
// Mirrors repack_row_q4_1x4x2: a partial last super-block stores neighbouring values in
// one byte, full super-blocks store value j together with value j + half.
static void unpack_row_q4_1x4x2(block_q4_1 * x, const uint8_t * y, int64_t k) {
    static const int qk = QK_Q4_0x4x2;

    const int n_sub      = 8;                            // blocks per super-block
    const int half       = qk / 2;                       // packed bytes per super-block
    const int dm_bytes   = n_sub * 4;                    // 8x (d, m) halves per super-block
    const int n_super    = (int) ((k + qk - 1) / qk);    // super-blocks, padded
    const int tail_elems = (int) (k % qk);               // leftovers in the last super-block
    const int qrow_bytes = (int) (k / 2);                // packed quants, not padded

    const uint8_t * const quants = y;
    const uint8_t * const dm     = y + qrow_bytes;

    // Pass 1: quants
    for (int sb = 0; sb < n_super; sb++) {
        const uint8_t * in = quants + sb * half;

        uint8_t nib[QK_Q4_0x4x2];  // one value per byte, fully written by either branch

        if (tail_elems != 0 && sb == n_super - 1) {
            for (int j = 0; j < half; j++) {
                nib[2 * j + 0] = (uint8_t) (in[j] & 0x0F);
                nib[2 * j + 1] = (uint8_t) (in[j] >> 4);
            }
        } else {
            for (int j = 0; j < half; j++) {
                nib[j]        = (uint8_t) (in[j] & 0x0F);
                nib[j + half] = (uint8_t) (in[j] >> 4);
            }
        }

        block_q4_1 * blk = x + sb * n_sub;
        for (unsigned int b = 0; b < (unsigned int) n_sub; b++) {
            pack_q4_1_quants(blk + b, nib, b);
        }
    }

    // Pass 2: scales and offsets, stored interleaved as d0, m0, d1, m1, ...
    for (int sb = 0; sb < n_super; sb++) {
        const ggml_half * in  = (const ggml_half *) (dm + sb * dm_bytes);
        block_q4_1 *      blk = x + sb * n_sub;

        for (int b = 0; b < n_sub; b++) {
            blk[b].d = in[2 * b + 0];
            blk[b].m = in[2 * b + 1];
        }
    }
}
// Zero-initialize a padded row of block_q4_1 blocks.
//
// x : row buffer; must hold 8 blocks for every (padded) super-block of
//     QK_Q4_0x4x2 elements
// k : number of elements in the row (not padded)
//
// Clears the packed quants as well as the d and m fields of every block, so that the
// padding past the real row data reads as zeros.
static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) {
    static const int qk = QK_Q4_0x4x2;

    const int n_super  = (int) ((k + qk - 1) / qk);  // super-blocks, padded
    const int n_blocks = n_super * 8;

    for (int b = 0; b < n_blocks; b++) {
        memset(x[b].qs, 0, QK4_1 / 2);
        x[b].d = 0;
        x[b].m = 0;
    }
}
// Repack Q4_1 data from the regular block_q4_1 row layout in `data` into the x4x2 row
// layout stored in t->data.
//
// t    : destination tensor; rows are written to t->data at a stride of row_size bytes
// data : source buffer holding block_q4_1 rows
// size : number of valid bytes in `data`; the copy is clamped to the tensor size
//
// Every row is staged through two scratch buffers sized for a row padded up to a
// multiple of QK_Q4_0x4x2 elements, because the row repacker always processes whole
// super-blocks. A trailing partial row (size not a multiple of row_size) is repacked
// from zero-padded input and only its first n_rem_bytes are stored.
static void repack_q4_1_q4x4x2(ggml_tensor * t, const void * data, size_t size) {
    int64_t nrows = ggml_nrows(t);

    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2));  // extra elements for the pad
    size_t row_size_rp = row_size_pd;  // scratch must hold one full padded tile (qblk_size/2 quants + scales)

    // Ensure we don't try to read more data than is available in the source buffer 'data'
    // or write more than the tensor can hold.
    const size_t total_tensor_size = (size_t) nrows * row_size;
    const size_t n_bytes_to_copy   = size < total_tensor_size ? size : total_tensor_size;

    // Number of complete rows, and the size of a trailing partial row (if any)
    const int64_t n_full_rows = n_bytes_to_copy / row_size;
    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;

    void * buf_pd = ggml_aligned_malloc(row_size_pd);
    GGML_ASSERT(buf_pd != NULL);

    void * buf_rp = ggml_aligned_malloc(row_size_rp);
    GGML_ASSERT(buf_rp != NULL);

    // int64_t is not `long` on every target, so print the dims as long long
    HEX_VERBOSE("ggml-hex: repack-q4_1-q4x4x2 %s : data %p size %zu dims %lldx%lld row-size %zu\n", t->name, data, size,
                (long long) t->ne[0], (long long) nrows, row_size);

    // Zero the padded staging row once; the per-row memcpy below only overwrites the
    // first row_size bytes, so the pad stays zero for all full rows.
    init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]);

    for (int64_t i = 0; i < n_full_rows; i++) {
        const uint8_t * src = (const uint8_t *) data + (i * row_size);
        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);

        memcpy(buf_pd, src, row_size);
        repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]);
        memcpy(dst, buf_rp, row_size);
    }

    if (n_rem_bytes > 0) {
        const int64_t   i   = n_full_rows;
        const uint8_t * src = (const uint8_t *) data + (i * row_size);
        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);

        // Re-zero the staging row so the part not covered by the partial copy is zeros
        init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]);
        memcpy(buf_pd, src, n_rem_bytes);
        repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]);
        memcpy(dst, buf_rp, n_rem_bytes);
    }

    ggml_aligned_free(buf_pd, row_size_pd);
    ggml_aligned_free(buf_rp, row_size_rp);
}
// Convert Q4_1 tensor data from the x4x2 row layout stored in t->data back into the
// regular block_q4_1 row layout, writing the result to `data`.
//
// data : destination buffer for block_q4_1 rows
// t    : source tensor; rows are read from t->data at a stride of row_size bytes
// size : capacity of `data` in bytes; the copy is clamped to the tensor size
//
// Every row is staged through two scratch buffers sized for a row padded up to a
// multiple of QK_Q4_0x4x2 elements, because the row unpacker always processes whole
// super-blocks.
static void repack_q4x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) {
    int64_t nrows = ggml_nrows(t);

    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2));  // extra elements for the pad
    size_t row_size_rp = row_size_pd;  // scratch must hold one full padded tile (qblk_size/2 quants + scales)

    // Ensure we don't try to copy more data than the tensor actually contains.
    const size_t total_tensor_size = (size_t) nrows * row_size;
    const size_t n_bytes_to_copy   = size < total_tensor_size ? size : total_tensor_size;

    // Number of complete rows, and the size of a trailing partial row (if any)
    const int64_t n_full_rows = n_bytes_to_copy / row_size;
    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;

    void * buf_pd = ggml_aligned_malloc(row_size_pd);
    GGML_ASSERT(buf_pd != NULL);

    void * buf_rp = ggml_aligned_malloc(row_size_rp);
    GGML_ASSERT(buf_rp != NULL);

    // int64_t is not `long` on every target, so print the dims as long long
    HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_1 %s : data %p size %zu dims %lldx%lld row-size %zu\n", t->name, data, size,
                (long long) t->ne[0], (long long) nrows, row_size);

    memset(buf_rp, 0, row_size_rp);  // clear-out padded buffer to make sure the tail is all zeros

    for (int64_t i = 0; i < n_full_rows; i++) {
        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
        uint8_t *       dst = (uint8_t *) data + (i * row_size);

        memcpy(buf_rp, src, row_size);
        unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]);
        memcpy(dst, buf_pd, row_size);
    }

    if (n_rem_bytes > 0) {
        const int64_t   i   = n_full_rows;
        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
        uint8_t *       dst = (uint8_t *) data + (i * row_size);

        // We still need to read and unpack the entire source row because quantization is block-based.
        memcpy(buf_rp, src, row_size);
        unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]);
        memcpy(dst, buf_pd, n_rem_bytes);
    }

    ggml_aligned_free(buf_pd, row_size_pd);
    ggml_aligned_free(buf_rp, row_size_rp);
}
// ======== Q8x4x2 ====================
static void dump_block_q8_0(const block_q8_0 * b, int i) {
HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2],
@ -876,7 +1110,7 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size)
size_t row_size = ggml_row_size(t->type, t->ne[0]);
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales)
// Ensure we don't try to read more data than is available in the source buffer 'data'
// or write more than the tensor can hold.
@ -937,7 +1171,7 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size)
size_t row_size = ggml_row_size(t->type, t->ne[0]);
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales)
// Ensure we don't try to copy more data than the tensor actually contains.
const size_t total_tensor_size = (size_t)nrows * row_size;
@ -1238,7 +1472,7 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si
size_t row_size = ggml_row_size(t->type, t->ne[0]);
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales)
// Ensure we don't try to read more data than is available in the source buffer 'data'
// or write more than the tensor can hold.
@ -1299,7 +1533,7 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si
size_t row_size = ggml_row_size(t->type, t->ne[0]);
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales)
// Ensure we don't try to copy more data than the tensor actually contains.
const size_t total_tensor_size = (size_t)nrows * row_size;
@ -1365,6 +1599,12 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
repack_q4_0_q4x4x2(tensor, data, size);
break;
case GGML_TYPE_Q4_1:
GGML_ASSERT(offset == 0);
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
repack_q4_1_q4x4x2(tensor, data, size);
break;
case GGML_TYPE_Q8_0:
GGML_ASSERT(offset == 0);
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
@ -1407,6 +1647,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
repack_q4x4x2_q4_0(data, tensor, size);
break;
case GGML_TYPE_Q4_1:
GGML_ASSERT(offset == 0);
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
repack_q4x4x2_q4_1(data, tensor, size);
break;
case GGML_TYPE_Q8_0:
GGML_ASSERT(offset == 0);
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
@ -1886,7 +2132,8 @@ void ggml_hexagon_session::flush_pending(bool all) {
uint32_t n_dbufs;
// Read response packet from queue
int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, DSPQUEUE_TIMEOUT);
const uint32_t timeo = opt_oppoll ? 0 : DSPQUEUE_TIMEOUT;
int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, timeo);
if (err == AEE_EEXPIRED) {
continue;
}
@ -2327,6 +2574,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
switch (src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
@ -2377,6 +2625,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
switch (src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
@ -3622,6 +3871,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
// Basic sanity checks to make sure definitions match
static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0,
"please update hexagon_type to match ggml_type");
static_assert((unsigned int) HTP_TYPE_Q4_1 == (unsigned int) GGML_TYPE_Q4_1,
"please update hexagon_type to match ggml_type");
static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0,
"please update hexagon_type to match ggml_type");
static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
@ -3634,6 +3885,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE");
const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH");
const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE");
const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL");
const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER");
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
const char * str_etm = getenv("GGML_HEXAGON_ETM");
@ -3671,6 +3923,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage;
opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch;
opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue;
opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll;
opt_profile = str_profile ? atoi(str_profile) : 0;
opt_etm = str_etm ? atoi(str_etm) : 0;
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;

View File

@ -59,14 +59,14 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-queue.c
hmx-matmul-ops.c
hmx-flash-attn-ops.c
hmx-matmul-ops.c
)
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
set_source_files_properties(
hmx-matmul-ops.c
hmx-flash-attn-ops.c
hmx-matmul-ops.c
PROPERTIES COMPILE_OPTIONS "-mhmx"
)

View File

@ -34,6 +34,10 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
-8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0,
};
static const __fp16 q4_1_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0,
};
// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value
// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6
static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
@ -62,6 +66,8 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) {
case HTP_TYPE_Q4_0:
case HTP_TYPE_IQ4_NL:
return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
case HTP_TYPE_Q4_1:
return (size_t) nb * (QK_Q4_0x4x2 / 2 + 32); // 160 * nb
case HTP_TYPE_Q8_0:
return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
case HTP_TYPE_MXFP4:
@ -233,6 +239,54 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx(
return r;
}
// Dequantize one x4x2 Q4_1 group to FP16: result = lut[nibble] * scale + offset.
//   packed_32     : packed 4-bit quants of the group
//   upper_nibbles : false -> use the low nibble of each byte, true -> use the high nibble
//   scale_offset  : points at the group's fp16 scale, immediately followed by its fp16 offset
//   vlut_cvt      : 4-bit index -> fp16 conversion table used by the vlut16 lookup
// NOTE(review): both loads fetch a whole HVX vector starting at the given pointer, so the
// memory behind packed_32 / scale_offset is assumed to be readable -- confirm against callers.
static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) {
HVX_Vector vq = hvx_vmemu(packed_32);
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
// Scale comes from the first fp16 of v_dm; rotating by 2 bytes brings the second fp16
// (the offset) into the same position before it is replicated.
// NOTE(review): assumes hvx_vec_repl_f16 broadcasts the first fp16 lane -- helper not visible here.
HVX_Vector v_dm = hvx_vmemu(scale_offset);
HVX_Vector v_scales = hvx_vec_repl_f16(v_dm);
HVX_Vector v_offsets = hvx_vec_repl_f16(Q6_V_vror_VR(v_dm, 2));
// Shift by 0 or 4 bits and mask to keep the selected nibble of every byte
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
v_quants = Q6_Vb_vshuff_Vb(v_quants);
// Table lookup converts the 4-bit indices to fp16; only the low vector of the pair is used
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
HVX_Vector v_hf = Q6_V_lo_W(vp);
// value * scale + offset, computed in qf16 and converted back to hf
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets));
}
// Dequantize several x4x2 Q4_1 groups at once to FP16: result = lut[nibble] * scale + offset.
//   packed_128       : packed 4-bit quants (one full HVX vector is loaded)
//   upper_nibbles    : false -> use the low nibble of each byte, true -> use the high nibble
//   scales_offsets_4 : interleaved fp16 (scale, offset) pairs for the groups being processed
//   vlut_cvt         : 4-bit index -> fp16 conversion table used by the vlut16 lookup
// Returns the low and high vectors of the lookup result, each scaled and offset.
// NOTE(review): the names suggest v[0] carries groups 0-1 and v[1] groups 2-3 -- confirm
// against the scatter code in the caller.
static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx(
const uint8_t *packed_128, bool upper_nibbles,
const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) {
HVX_Vector vq = hvx_vmemu(packed_128);
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
// Shift by 0 or 4 bits and mask to keep the selected nibble of every byte
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
v_quants = Q6_Vb_vshuff_Vb(v_quants);
// Table lookup converts the 4-bit indices to fp16; both halves of the pair are used
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
HVX_Vector v_lo = Q6_V_lo_W(vp);
HVX_Vector v_hi = Q6_V_hi_W(vp);
// NOTE(review): the vdeal with -2 presumably separates the interleaved 2-byte
// (scale, offset) elements into a scales vector (lo) and an offsets vector (hi) -- confirm.
HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4);
HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2);
HVX_Vector vd = Q6_V_lo_W(dm_deal);
HVX_Vector vm = Q6_V_hi_W(dm_deal);
// The *01 vectors are built from the first fp16 values, the *23 vectors from the data
// rotated by 4 bytes (two fp16 values further on).
// NOTE(review): exact replication pattern of hvx_vec_repl_2x_f16 is not visible here.
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vd);
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vd, 4));
HVX_Vector v_os01 = hvx_vec_repl_2x_f16(vm);
HVX_Vector v_os23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vm, 4));
// value * scale + offset, computed in qf16 and converted back to hf
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01), v_os01));
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23), v_os23));
HVX_Vector_x2 r = { v_lo, v_hi };
return r;
}
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) {
HVX_Vector vq = hvx_vmemu(quants_32);
@ -331,11 +385,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
int start_tile, int end_tile) {
const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS;
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL);
const bool is_q4_1 = (weight_type == HTP_TYPE_Q4_1);
const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block;
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
(weight_type == HTP_TYPE_Q4_1) ? hvx_vmem(q4_1_to_fp16_lut) :
hvx_vmem(q4_0_to_fp16_lut);
// vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
@ -356,8 +412,10 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
bool upper = (sub_blk_base >= 4);
unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE
+ sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales
unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE;
unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16);
unsigned scale_off = qrow_size + blk_idx * dblk_size
+ sub_blk_base * scale_step;
__fp16 *tile_bases[4];
for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }
@ -367,20 +425,38 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
if (is_q4_1) {
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_1_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
} else {
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
}
for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); }
@ -446,26 +522,43 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
bool upper = (sub_blk >= 4);
unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);
unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE;
unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16);
unsigned scale_off = qrow_size + blk_idx * dblk_size + sub_blk * scale_step;
HVX_Vector v_off = v_scat_base; // reset to column 0
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
if (is_q4_1) {
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(
r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
HVX_Vector v1 = (row1 < n_cols)
? dequantize_x4x2_q4_0_group_hvx(
r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
: Q6_V_vzero();
HVX_Vector v0 = dequantize_x4x2_q4_1_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
HVX_Vector v1 = (row1 < n_cols)
? dequantize_x4x2_q4_1_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
: Q6_V_vzero();
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
} else {
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
HVX_Vector v1 = (row1 < n_cols)
? dequantize_x4x2_q4_0_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
: Q6_V_vzero();
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
}
(void) *(volatile HVX_Vector *)(tile_base);
} else if (weight_type == HTP_TYPE_MXFP4) {
@ -593,6 +686,8 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
// --- End x4x2 dequantizers ---
#pragma clang diagnostic ignored "-Wbackend-plugin" // spurious warning for hmx intrinsics
// requires external HMX lock
static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales,
int n_row_tiles, int n_col_tiles, int n_dot_tiles) {

View File

@ -20,6 +20,7 @@ enum htp_data_type {
HTP_TYPE_F32 = 0,
HTP_TYPE_F16 = 1,
HTP_TYPE_Q4_0 = 2,
HTP_TYPE_Q4_1 = 3,
HTP_TYPE_Q8_0 = 8,
HTP_TYPE_IQ4_NL = 20,
HTP_TYPE_I32 = 26,
@ -28,6 +29,7 @@ enum htp_data_type {
// types used internally for repack, dyn.quant, etc
HTP_TYPE_Q4_0x4x2 = 200,
HTP_TYPE_Q4_1x4x2,
HTP_TYPE_Q8_0x4x2,
HTP_TYPE_MXFP4x4x2,

View File

@ -853,6 +853,11 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
for (uint32_t i=0; i < n_ops; i++) {
struct profile_data prof;
if (i == (n_ops-1)) {
// wake up the host before starting the last op
dspqueue_write_early_wakeup_noblock(queue, 0, 0);
}
profile_start(ctx->profiler, &prof);
proc_op_req(octx, tens, i, &ops[i]);
@ -869,8 +874,6 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
}
}
// dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0);
struct htp_opbatch_rsp rsp;
rsp.id = req.id;
rsp.status = HTP_STATUS_OK;

File diff suppressed because it is too large Load Diff