hexagon: minor refresh for HMX FA and MM (llama/23796)

* hex-fa: clean up qf32/fp32 handling and stride handling

* hex-fa: fix corner case fp NAN issues that were cause bad output from gemma4 on v79

* hex-fa: vectorize leftover handling

* hex-fa: avoid HVX fallback during token gen HMX has more FP16 compute capacity

* hmx-mm: remove dead code

* hmx-mm: use fastdiv in x4x2 dequant

* hmx-mm: sandwich dequant and scatter to improve perf

* hmx-mm: fixed rebase conflicts

* hmx-mm: further improve weight dequant by doing early type dispatch and precomputing fastdiv

* hmx-mm: an even earlier dispatch for per-type dequant

* hmx-mm: dequant linear types like q4_0 and q4_1 without the LUTs

This is a bit faster than LUT.

* hex-cmake: one more tweak for lto

---------

Co-authored-by: Trivikram Reddy <tamarnat@qti.qualcomm.com>
This commit is contained in:
Max Krasnyansky 2026-05-28 04:49:11 -07:00 committed by Georgi Gerganov
parent b896e91f18
commit 1b241b879c
4 changed files with 371 additions and 287 deletions

View File

@ -58,15 +58,16 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-queue.c
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
)
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
set_source_files_properties(
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
PROPERTIES COMPILE_OPTIONS "-mhmx"
)

View File

@ -22,6 +22,16 @@
// Must be multiple of 32
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
// This is a bit of a hack because the compiler is strugling to properly inline
// the default hvx_vec_f32_to_f16 with output into the local array.
static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
@ -54,8 +64,8 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}
HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)));
HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p));
rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
hvx_vec_store_u(r, 4, rsum);
}
@ -105,10 +115,10 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
}
HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)));
HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)));
HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p));
HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p));
HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p));
HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p));
HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
return hvx_vec_reduce_sum_f32x4(rsum0123);
@ -123,7 +133,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
const size_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector sums; // initialize at j = 0
HVX_Vector sums = Q6_V_vzero();
const size_t stride_x_4 = stride_x * 4;
for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
@ -132,8 +142,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
x += stride_x_4;
}
sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums);
return Q6_Vsf_equals_Vqf32(sums);
return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums);
}
// MAD: y (F32) += x (F16) * s (F16)
@ -268,11 +277,10 @@ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t *
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; ++i) {
vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs);
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs));
}
}
@ -438,25 +446,44 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
// Process in sub-blocks of 32 (VLEN_FP32)
HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32];
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
for (uint32_t iv = 0; ic < current_block_size; ic += VLEN_FP32, ++iv) {
// 1. Compute scores
HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale);
// 2. Softcap
if (factx->logit_softcap != 0.0f) {
scores = hvx_vec_tanh_f32(scores);
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
scores = Q6_Vsf_equals_Vqf32(scores);
scores = HVX_OP_MUL_F32(scores, logit_cap);
}
// 3. Mask
if (mask) {
const __fp16 * mp = m_base + ic;
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
scores = Q6_Vsf_equals_Vqf32(scores);
// Multiplying -INFINITY (0xFC00) by a slope in VhfVhf instructions can incorrectly produce NaN on v79.
// Clamp -INFINITY to the max negative fp16 finite value (-65504.0f).
HVX_Vector vinf = Q6_Vh_vsplat_R(0xFC00);
HVX_Vector vmin = Q6_Vh_vsplat_R(0xFBFF);
HVX_VectorPred is_inf = Q6_Q_vcmp_eq_VhVh(m_vals_f16, vinf);
m_vals_f16 = Q6_V_vmux_QVV(is_inf, vmin, m_vals_f16);
#if __HVX_ARCH__ >= 79
HVX_VectorPair m_vals_f32_pair = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
scores = Q6_Vsf_vadd_VsfVsf(add_val, scores);
#else
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
scores = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores));
#endif
}
// Mask out invalid lanes for leftover handling
uint32_t valid_lanes = current_block_size - ic;
if (valid_lanes < VLEN_FP32) {
HVX_VectorPred valid_pred = Q6_Q_vsetq_R(valid_lanes * 4); // 4 bytes per fp32 lane
scores = Q6_V_vmux_QVV(valid_pred, scores, hvx_vec_splat_f32(-INFINITY));
}
sb_scores[iv] = scores;
@ -466,78 +493,55 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
{
// 4. Online Softmax Update
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec));
HVX_Vector diff_vec = HVX_OP_SUB_F32(M_vec, M_new_vec);
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
M_vec = M_new_vec;
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
for (uint32_t ic2 = 0, iv = 0; ic2 < current_block_size; ic2 += VLEN_FP32, ++iv) {
HVX_Vector scores = sb_scores[iv];
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
HVX_Vector scores_shifted = HVX_OP_SUB_F32(scores, M_vec);
HVX_Vector P = hvx_vec_exp_f32(scores_shifted);
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
p_sum_vec = HVX_OP_ADD_F32(p_sum_vec, P);
// 5. Accumulate V
__fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16];
hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0));
float __attribute__((aligned(128))) P_arr[VLEN_FP32];
hvx_vec_store_a(P_arr, 128, P);
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic2 + j;
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
const uint32_t cur_ic = ic2 + j;
if (cur_ic >= current_block_size) {
break;
}
if (cur_ic + 1 == current_block_size) {
// Odd leftover, process single row
if (P_arr[j] != 0.0f) {
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, (p_arr + j), DV);
}
break;
}
// Avoid NaN * 0.0 = NaN for uninitialized V cache rows.
// Check the f32 values to safely avoid strict aliasing violations.
if (P_arr[j] == 0.0f && P_arr[j + 1] == 0.0f) {
continue;
}
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV);
}
}
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
}
if (ic < current_block_size) {
// Sync scalars for leftover/next block if needed
float M = hvx_vec_get_f32(M_vec);
float S = hvx_vec_get_f32(S_vec);
// Leftover
for (; ic < current_block_size; ++ic) {
float s_val;
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
if (factx->logit_softcap != 0.0f) {
s_val = factx->logit_softcap * tanhf(s_val);
}
if (mask) {
const float m_val = m_base[ic];
s_val += slope * m_val;
}
const float Mold = M;
__fp16 vs = 1.0f;
if (s_val > M) {
M = s_val;
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
float ms = hvx_vec_get_f32(ms_vec);
S = S * ms + vs;
} else {
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
S += vs;
}
const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV);
}
M_vec = hvx_vec_splat_f32(M);
S_vec = hvx_vec_splat_f32(S);
S_vec = HVX_OP_ADD_F32(HVX_OP_MUL_F32(S_vec, ms_vec), p_sum_vec);
}
// Issue DMA for next+1 block (if exists)
@ -599,8 +603,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
const int i2 = iq2;
const int i3 = iq3;
// dst is permuted
uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
// dst is permuted: [DV, n_heads, n_tokens, n_seq]
// head stride is nb[1], token stride is nb[2], batch stride is nb[3]
uint8_t * dst_ptr = (uint8_t *) dst->data + i2 * dst->nb[1] + i1 * dst->nb[2] + i3 * dst->nb[3];
if (dst->type == HTP_TYPE_F32) {
hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
@ -623,8 +628,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
}
#ifdef HTP_HAS_HMX
// HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) {
// HMX path: head_dim multiple of 32, F16 KV
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0) {
int ret = hmx_flash_attn_ext(octx);
if (ret == HTP_STATUS_OK) {
return ret;

View File

@ -1248,9 +1248,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
if (DK % 32 != 0 || DV % 32 != 0) {
return HTP_STATUS_NO_SUPPORT;
}
if (neq1 < 32) {
return HTP_STATUS_NO_SUPPORT;
}
// GQA factor
const uint32_t n_kv_heads = k->ne[2];

View File

@ -16,6 +16,7 @@
#include "ggml-common.h"
#include "hex-dma.h"
#include "hex-fastdiv.h"
#include "worker-pool.h"
#include "hvx-utils.h"
@ -187,45 +188,44 @@ next_nc:
// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles
// of the same 32 packed bytes.
static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) {
(void)vlut_cvt;
HVX_Vector vq = hvx_vmemu(packed_32);
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
// q4x4x2 stores two int4 values per byte. Keep only the selected nibble.
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
// Shuffle before LUT
v_quants = Q6_Vb_vshuff_Vb(v_quants);
// Use standard vlut16 (not _nomatch) to avoid stale-register NaN.
// _nomatch retains the previous destination-register value for colliding
// indices, but the C intrinsic doesn't model the implicit read so the
// compiler may allocate a register containing garbage/NaN.
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
HVX_Vector v_hf = Q6_V_lo_W(vp);
HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8);
HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_int8));
HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
}
// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using
// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls.
// full HVX vector width.
// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes.
static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx(
const uint8_t *packed_128, bool upper_nibbles,
const __fp16 *scales_4, const HVX_Vector vlut_cvt) {
// Load all 128 packed bytes (4 contiguous 32-byte groups)
(void)vlut_cvt;
HVX_Vector vq = hvx_vmemu(packed_128);
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
// Shuffle before LUT
v_quants = Q6_Vb_vshuff_Vb(v_quants);
HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8);
// Full-width vlut16: 128 byte lookups -> 128 fp16 results in a VectorPair
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
HVX_Vector v_lo = Q6_V_lo_W(vp); // [group0: 32 fp16 | group1: 32 fp16]
HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16]
HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_int8);
HVX_Vector v_lo = Q6_V_lo_W(vp_int16);
HVX_Vector v_hi = Q6_V_hi_W(vp_int16);
v_lo = Q6_Vhf_equals_Vh(v_lo);
v_hi = Q6_Vhf_equals_Vh(v_hi);
// Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b
HVX_Vector vscale = hvx_vmemu(scales_4);
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale);
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4));
@ -233,13 +233,12 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx(
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
// Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter
HVX_Vector_x2 r = { v_lo,/* group1 already in [0:63] */
v_hi /* group2 already in [0:63] */ };
HVX_Vector_x2 r = { v_lo, v_hi };
return r;
}
static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) {
(void)vlut_cvt;
HVX_Vector vq = hvx_vmemu(packed_32);
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector v_dm = hvx_vmemu(scale_offset);
@ -248,9 +247,9 @@ static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
v_quants = Q6_Vb_vshuff_Vb(v_quants);
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
HVX_Vector v_hf = Q6_V_lo_W(vp);
HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_quants));
HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets));
}
@ -258,16 +257,18 @@ static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32
static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx(
const uint8_t *packed_128, bool upper_nibbles,
const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) {
(void)vlut_cvt;
HVX_Vector vq = hvx_vmemu(packed_128);
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
v_quants = Q6_Vb_vshuff_Vb(v_quants);
HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_quants);
HVX_Vector v_lo = Q6_V_lo_W(vp_int16);
HVX_Vector v_hi = Q6_V_hi_W(vp_int16);
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
HVX_Vector v_lo = Q6_V_lo_W(vp);
HVX_Vector v_hi = Q6_V_hi_W(vp);
v_lo = Q6_Vhf_equals_Vh(v_lo);
v_hi = Q6_Vhf_equals_Vh(v_hi);
HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4);
HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2);
@ -287,6 +288,45 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx(
return r;
}
// LUT-based dequantizers for non-linear IQ4_NL format.
static inline HVX_Vector dequantize_x4x2_iq4_nl_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) {
HVX_Vector vq = hvx_vmemu(packed_32);
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
v_quants = Q6_Vb_vshuff_Vb(v_quants);
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
HVX_Vector v_hf = Q6_V_lo_W(vp);
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
}
static inline HVX_Vector_x2 dequantize_x4x2_iq4_nl_x4groups_hvx(
const uint8_t *packed_128, bool upper_nibbles,
const __fp16 *scales_4, const HVX_Vector vlut_cvt) {
HVX_Vector vq = hvx_vmemu(packed_128);
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
v_quants = Q6_Vb_vshuff_Vb(v_quants);
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
HVX_Vector v_lo = Q6_V_lo_W(vp);
HVX_Vector v_hi = Q6_V_hi_W(vp);
HVX_Vector vscale = hvx_vmemu(scales_4);
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale);
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4));
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
HVX_Vector_x2 r = { v_lo, v_hi };
return r;
}
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) {
HVX_Vector vq = hvx_vmemu(quants_32);
@ -374,122 +414,176 @@ static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t *
return r;
}
typedef struct {
__fp16 *dst;
const uint8_t *src;
int n_cols;
int k_block;
size_t row_stride;
int weight_type;
int n_tot_tiles;
int n_tiles_per_task;
int n_tasks;
int n_k_tiles;
struct fastdiv_values n_k_tiles_div;
} x4x2_dequantize_state_t;
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes.
// Output: vtcm_dst in tile-major FP16 layout.
static void dequantize_x4x2_weight_to_fp16_tiles_task(
__fp16 *restrict vtcm_dst,
const uint8_t *restrict vtcm_src,
int n_cols, int k_block,
size_t row_stride, int weight_type,
#define DEFINE_DEQUANTIZE_Q4_TASK(suffix, lut_name, helper_prefix, dblk_size, scale_step) \
static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix( \
const x4x2_dequantize_state_t *state, \
int start_tile, int end_tile) { \
\
const int n_k_tiles = state->n_k_tiles; \
const int qrow_size = (unsigned)state->k_block / 2; \
const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; \
const HVX_Vector vlut_cvt = hvx_vmem(lut_name); \
\
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); \
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); \
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); \
\
unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); \
unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); \
\
for (unsigned t = start_tile; t < (unsigned)end_tile; ) { \
if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } \
\
if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { \
unsigned blk_idx = ((kt * 32) / QK_Q4_0x4x2); \
unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; \
bool upper = (sub_blk_base >= 4); \
unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); \
unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk_base * (scale_step); \
\
__fp16 *tile_bases[4]; \
for (unsigned g = 0; g < 4; g++) { \
tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; \
} \
\
HVX_Vector v_off = v_scat_base; \
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \
\
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { \
const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \
const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \
\
HVX_Vector_x2 dv0 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \
r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); \
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); \
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \
\
HVX_Vector_x2 dv1 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \
r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); \
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); \
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); \
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \
} \
\
for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } \
t += 4; kt += 4; \
continue; \
} \
\
__fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; \
{ \
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; \
unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; \
bool upper = (sub_blk >= 4); \
unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; \
unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk * (scale_step); \
\
HVX_Vector v_off = v_scat_base; \
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; \
\
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { \
const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \
const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \
\
HVX_Vector v0 = dequantize_x4x2_##helper_prefix##_group_hvx( \
r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \
HVX_Vector v1 = (row1 < (unsigned)state->n_cols) \
? dequantize_x4x2_##helper_prefix##_group_hvx( \
r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) \
: Q6_V_vzero(); \
\
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); \
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); \
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \
} \
(void) *(volatile HVX_Vector *)(tile_base); \
} \
++t; ++kt; \
} \
\
if (start_tile < end_tile) { \
(void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); \
} \
} \
\
static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \
int start = task_id * state->n_tiles_per_task; \
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \
dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \
} \
}
DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16))
DEFINE_DEQUANTIZE_Q4_TASK(q4_1, q4_1_to_fp16_lut, q4_1, 32, 4)
DEFINE_DEQUANTIZE_Q4_TASK(iq4_nl, iq4_nl_to_fp16_lut, iq4_nl, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16))
static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(
const x4x2_dequantize_state_t *state,
int start_tile, int end_tile) {
const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS;
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL);
const bool is_q4_1 = (weight_type == HTP_TYPE_Q4_1);
const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block;
const int n_k_tiles = state->n_k_tiles;
const int qrow_size = state->k_block;
const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div;
const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut);
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
(weight_type == HTP_TYPE_Q4_1) ? hvx_vmem(q4_1_to_fp16_lut) :
hvx_vmem(q4_0_to_fp16_lut);
// vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
// Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128
// maps to K-rows 2i and 2i+1. Column offset (n*4) added per row.
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes)
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index
unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index
for (unsigned t = start_tile; t < end_tile; ) {
if (kt >= n_k_tiles) { kt = 0; ct++; }
unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div);
unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div);
// --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row ---
if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
bool upper = (sub_blk_base >= 4);
unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE;
unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16);
unsigned scale_off = qrow_size + blk_idx * dblk_size
+ sub_blk_base * scale_step;
for (unsigned t = start_tile; t < (unsigned)end_tile; ) {
if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; }
__fp16 *tile_bases[4];
for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }
HVX_Vector v_off = v_scat_base;
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
if (is_q4_1) {
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_1_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
} else {
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
}
for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); }
t += 4; kt += 4;
continue;
}
// --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales ---
if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
// Batch-4 fast path for MXFP4
if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) {
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4
int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32;
bool upper = (sub_blk_base >= 4);
int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales
int packed_off = blk_idx * (QK_MXFP4x4x2 / 2);
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE;
__fp16 * tile_bases[4];
for (int g = 0; g < 4; g++) {
tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS;
tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS;
}
HVX_Vector v_off = v_scat_base;
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
int row1 = row0 + 1;
const uint8_t * r0 = vtcm_src + row0 * row_stride;
const uint8_t * r1 = vtcm_src + row1 * row_stride;
const uint8_t * r0 = state->src + row0 * state->row_stride;
const uint8_t * r1 = state->src + row1 * state->row_stride;
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
HVX_Vector_x4 dv0, dv1;
dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8);
if (row1 < n_cols) {
if (row1 < state->n_cols) {
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8);
} else {
@ -510,58 +604,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
(void) *(volatile HVX_Vector *) (tile_bases[g]);
}
t += 4;
t += 4; kt += 4;
continue;
}
// --- Single-tile fallback ---
__fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;
if (is_q4) {
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
bool upper = (sub_blk >= 4);
unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE;
unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16);
unsigned scale_off = qrow_size + blk_idx * dblk_size + sub_blk * scale_step;
HVX_Vector v_off = v_scat_base; // reset to column 0
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
if (is_q4_1) {
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
HVX_Vector v0 = dequantize_x4x2_q4_1_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
HVX_Vector v1 = (row1 < n_cols)
? dequantize_x4x2_q4_1_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
: Q6_V_vzero();
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
} else {
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
HVX_Vector v1 = (row1 < n_cols)
? dequantize_x4x2_q4_0_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
: Q6_V_vzero();
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
}
(void) *(volatile HVX_Vector *)(tile_base);
} else if (weight_type == HTP_TYPE_MXFP4) {
// Single-tile fallback
__fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS;
{
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32;
bool upper = (sub_blk >= 4);
@ -573,15 +622,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
int row1 = row0 + 1;
const uint8_t * r0 = vtcm_src + row0 * row_stride;
const uint8_t * r1 = vtcm_src + row1 * row_stride;
const uint8_t * r0 = state->src + row0 * state->row_stride;
const uint8_t * r1 = state->src + row1 * state->row_stride;
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8);
HVX_Vector v1;
if (row1 < n_cols) {
if (row1 < state->n_cols) {
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8);
} else {
@ -594,23 +642,59 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
(void) *(volatile HVX_Vector *) (tile_base);
} else {
// Q8_0
}
++t; ++kt;
}
if (start_tile < end_tile) {
(void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS);
}
}
static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end);
}
}
static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
const x4x2_dequantize_state_t *state,
int start_tile, int end_tile) {
const int n_k_tiles = state->n_k_tiles;
const int qrow_size = state->k_block;
const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div;
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div);
unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div);
for (unsigned t = start_tile; t < (unsigned)end_tile; ) {
if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; }
__fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS;
{
int blk_idx = (kt * 32) / QK_Q8_0x4x2;
int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32;
int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32;
int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);
HVX_Vector v_off = v_scat_base; // reset to column 0
HVX_Vector v_off = v_scat_base;
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
int row1 = row0 + 1;
const uint8_t *r0 = vtcm_src + row0 * row_stride;
const uint8_t *r1 = vtcm_src + row1 * row_stride;
const uint8_t *r0 = state->src + row0 * state->row_stride;
const uint8_t *r1 = state->src + row1 * state->row_stride;
HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off));
HVX_Vector v1 = (row1 < n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero();
HVX_Vector v1 = (row1 < state->n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero();
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
@ -622,50 +706,31 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
++t; ++kt;
}
// Drain HVX scatter write buffer: a vmem load on the same HW thread retires
// all pending scatter entries to VTCM. Without this, the main thread's HMX
// reads may see stale data because atomic_fetch_sub (release) only orders
// regular stores, not the HVX scatter buffer.
if (start_tile < end_tile) {
(void) *(volatile HVX_Vector *)(vtcm_dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS);
(void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS);
}
}
typedef struct {
__fp16 *dst;
const uint8_t *src;
int n_cols;
int k_block;
size_t row_stride;
int weight_type;
int n_tot_tiles;
int n_tiles_per_task;
int n_tasks;
} x4x2_dequantize_state_t;
static void dequantize_x4x2_worker_loop(unsigned int n, unsigned int i, void *data) {
static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
dequantize_x4x2_weight_to_fp16_tiles_task(
state->dst, state->src, state->n_cols, state->k_block,
state->row_stride, state->weight_type, start, end);
dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end);
}
}
static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
struct htp_context *ctx, __fp16 *vtcm_dst,
const void *vtcm_src, int n_cols, int k_block,
size_t row_stride, int weight_type) {
size_t row_stride, int weight_type,
int n_k_tiles, struct fastdiv_values n_k_tiles_div,
worker_callback_t dequant_worker_fn) {
assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
assert(k_block % HMX_FP16_TILE_N_COLS == 0);
size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;
size_t n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
size_t n_tot_tiles = n_col_tiles * n_k_tiles;
size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads);
@ -680,8 +745,10 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
state.k_block = k_block;
state.row_stride = row_stride;
state.weight_type = weight_type;
state.n_k_tiles = n_k_tiles;
state.n_k_tiles_div = n_k_tiles_div;
worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads);
worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, ctx->n_threads);
}
// --- End x4x2 dequantizers ---
@ -978,6 +1045,20 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *
return -1;
}
worker_callback_t dequant_worker_fn = NULL;
switch (weight_type) {
case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break;
case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break;
case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break;
case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break;
case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break;
default:
return -1;
}
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles);
// --- Dynamic VTCM layout ---
const size_t vec_dot_size = k * sizeof(__fp16);
const size_t vtcm_budget = ctx->vtcm_size;
@ -1070,7 +1151,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *
{
// B0: wait for DMA, dequant weight chunk 0
dma_queue_pop(ctx->dma[0]);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn);
// A1: issue DMA for weight chunk 1
const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
@ -1089,7 +1170,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *
// B1: DMA pop + dequant (runs in parallel with C0 on HMX worker)
if (1 < n_chunk_cnt) {
dma_queue_pop(ctx->dma[0]);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn);
}
}
@ -1131,7 +1212,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *
// B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1})
if (i + 2 < n_chunk_cnt) {
dma_queue_pop(ctx->dma[0]);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn);
}
}
}