hexagon: minor refresh for HMX FA and MM (llama/23796)
* hex-fa: clean up qf32/fp32 handling and stride handling * hex-fa: fix corner case fp NAN issues that were cause bad output from gemma4 on v79 * hex-fa: vectorize leftover handling * hex-fa: avoid HVX fallback during token gen HMX has more FP16 compute capacity * hmx-mm: remove dead code * hmx-mm: use fastdiv in x4x2 dequant * hmx-mm: sandwich dequant and scatter to improve perf * hmx-mm: fixed rebase conflicts * hmx-mm: further improve weight dequant by doing early type dispatch and precomputing fastdiv * hmx-mm: an even earlier dispatch for per-type dequant * hmx-mm: dequant linear types like q4_0 and q4_1 without the LUTs This is a bit faster than LUT. * hex-cmake: one more tweak for lto --------- Co-authored-by: Trivikram Reddy <tamarnat@qti.qualcomm.com>
This commit is contained in:
parent
b896e91f18
commit
1b241b879c
|
|
@ -58,15 +58,16 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
|||
|
||||
if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-queue.c
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-queue.c
|
||||
)
|
||||
|
||||
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
|
||||
set_source_files_properties(
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-queue.c
|
||||
PROPERTIES COMPILE_OPTIONS "-mhmx"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,16 @@
|
|||
// Must be multiple of 32
|
||||
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
|
||||
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
|
||||
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
// This is a bit of a hack because the compiler is strugling to properly inline
|
||||
// the default hvx_vec_f32_to_f16 with output into the local array.
|
||||
static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
|
||||
|
|
@ -54,8 +64,8 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
|||
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
|
||||
}
|
||||
|
||||
HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
|
||||
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)));
|
||||
HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p));
|
||||
rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
|
||||
|
||||
|
|
@ -105,10 +115,10 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
|
|||
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
|
||||
}
|
||||
|
||||
HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
|
||||
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
|
||||
HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)));
|
||||
HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)));
|
||||
HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p));
|
||||
HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p));
|
||||
HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p));
|
||||
HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p));
|
||||
|
||||
HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
|
||||
return hvx_vec_reduce_sum_f32x4(rsum0123);
|
||||
|
|
@ -123,7 +133,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
|
|||
const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
const size_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector sums; // initialize at j = 0
|
||||
HVX_Vector sums = Q6_V_vzero();
|
||||
const size_t stride_x_4 = stride_x * 4;
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
|
||||
HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
|
||||
|
|
@ -132,8 +142,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
|
|||
x += stride_x_4;
|
||||
}
|
||||
|
||||
sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums);
|
||||
return Q6_Vsf_equals_Vqf32(sums);
|
||||
return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums);
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * s (F16)
|
||||
|
|
@ -268,11 +277,10 @@ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t *
|
|||
uint32_t i = 0;
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; ++i) {
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
|
||||
vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs);
|
||||
}
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
|
||||
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -438,25 +446,44 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
|||
// Process in sub-blocks of 32 (VLEN_FP32)
|
||||
HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32];
|
||||
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
|
||||
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
|
||||
for (uint32_t iv = 0; ic < current_block_size; ic += VLEN_FP32, ++iv) {
|
||||
// 1. Compute scores
|
||||
HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale);
|
||||
|
||||
// 2. Softcap
|
||||
if (factx->logit_softcap != 0.0f) {
|
||||
scores = hvx_vec_tanh_f32(scores);
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
scores = HVX_OP_MUL_F32(scores, logit_cap);
|
||||
}
|
||||
|
||||
// 3. Mask
|
||||
if (mask) {
|
||||
const __fp16 * mp = m_base + ic;
|
||||
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
|
||||
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
|
||||
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
|
||||
// Multiplying -INFINITY (0xFC00) by a slope in VhfVhf instructions can incorrectly produce NaN on v79.
|
||||
// Clamp -INFINITY to the max negative fp16 finite value (-65504.0f).
|
||||
HVX_Vector vinf = Q6_Vh_vsplat_R(0xFC00);
|
||||
HVX_Vector vmin = Q6_Vh_vsplat_R(0xFBFF);
|
||||
HVX_VectorPred is_inf = Q6_Q_vcmp_eq_VhVh(m_vals_f16, vinf);
|
||||
m_vals_f16 = Q6_V_vmux_QVV(is_inf, vmin, m_vals_f16);
|
||||
|
||||
#if __HVX_ARCH__ >= 79
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
|
||||
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
|
||||
scores = Q6_Vsf_vadd_VsfVsf(add_val, scores);
|
||||
#else
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
|
||||
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
|
||||
scores = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores));
|
||||
#endif
|
||||
}
|
||||
|
||||
// Mask out invalid lanes for leftover handling
|
||||
uint32_t valid_lanes = current_block_size - ic;
|
||||
if (valid_lanes < VLEN_FP32) {
|
||||
HVX_VectorPred valid_pred = Q6_Q_vsetq_R(valid_lanes * 4); // 4 bytes per fp32 lane
|
||||
scores = Q6_V_vmux_QVV(valid_pred, scores, hvx_vec_splat_f32(-INFINITY));
|
||||
}
|
||||
|
||||
sb_scores[iv] = scores;
|
||||
|
|
@ -466,78 +493,55 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
|||
{
|
||||
// 4. Online Softmax Update
|
||||
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
|
||||
HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec));
|
||||
HVX_Vector diff_vec = HVX_OP_SUB_F32(M_vec, M_new_vec);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
||||
M_vec = M_new_vec;
|
||||
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
|
||||
for (uint32_t ic2 = 0, iv = 0; ic2 < current_block_size; ic2 += VLEN_FP32, ++iv) {
|
||||
HVX_Vector scores = sb_scores[iv];
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
HVX_Vector scores_shifted = HVX_OP_SUB_F32(scores, M_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(scores_shifted);
|
||||
|
||||
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
|
||||
p_sum_vec = HVX_OP_ADD_F32(p_sum_vec, P);
|
||||
|
||||
// 5. Accumulate V
|
||||
__fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16];
|
||||
hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0));
|
||||
|
||||
float __attribute__((aligned(128))) P_arr[VLEN_FP32];
|
||||
hvx_vec_store_a(P_arr, 128, P);
|
||||
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
if (cur_ic >= current_block_size) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (cur_ic + 1 == current_block_size) {
|
||||
// Odd leftover, process single row
|
||||
if (P_arr[j] != 0.0f) {
|
||||
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, (p_arr + j), DV);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Avoid NaN * 0.0 = NaN for uninitialized V cache rows.
|
||||
// Check the f32 values to safely avoid strict aliasing violations.
|
||||
if (P_arr[j] == 0.0f && P_arr[j + 1] == 0.0f) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV);
|
||||
}
|
||||
}
|
||||
|
||||
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
|
||||
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
|
||||
}
|
||||
|
||||
if (ic < current_block_size) {
|
||||
// Sync scalars for leftover/next block if needed
|
||||
float M = hvx_vec_get_f32(M_vec);
|
||||
float S = hvx_vec_get_f32(S_vec);
|
||||
|
||||
// Leftover
|
||||
for (; ic < current_block_size; ++ic) {
|
||||
float s_val;
|
||||
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
|
||||
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
|
||||
if (factx->logit_softcap != 0.0f) {
|
||||
s_val = factx->logit_softcap * tanhf(s_val);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
const float m_val = m_base[ic];
|
||||
s_val += slope * m_val;
|
||||
}
|
||||
|
||||
const float Mold = M;
|
||||
__fp16 vs = 1.0f;
|
||||
|
||||
if (s_val > M) {
|
||||
M = s_val;
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
float ms = hvx_vec_get_f32(ms_vec);
|
||||
S = S * ms + vs;
|
||||
} else {
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
|
||||
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
|
||||
S += vs;
|
||||
}
|
||||
|
||||
const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
|
||||
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV);
|
||||
}
|
||||
|
||||
M_vec = hvx_vec_splat_f32(M);
|
||||
S_vec = hvx_vec_splat_f32(S);
|
||||
S_vec = HVX_OP_ADD_F32(HVX_OP_MUL_F32(S_vec, ms_vec), p_sum_vec);
|
||||
}
|
||||
|
||||
// Issue DMA for next+1 block (if exists)
|
||||
|
|
@ -599,8 +603,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
|||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
// dst is permuted
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
|
||||
// dst is permuted: [DV, n_heads, n_tokens, n_seq]
|
||||
// head stride is nb[1], token stride is nb[2], batch stride is nb[3]
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + i2 * dst->nb[1] + i1 * dst->nb[2] + i3 * dst->nb[3];
|
||||
|
||||
if (dst->type == HTP_TYPE_F32) {
|
||||
hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
||||
|
|
@ -623,8 +628,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
|||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
// HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV
|
||||
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) {
|
||||
// HMX path: head_dim multiple of 32, F16 KV
|
||||
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0) {
|
||||
int ret = hmx_flash_attn_ext(octx);
|
||||
if (ret == HTP_STATUS_OK) {
|
||||
return ret;
|
||||
|
|
|
|||
|
|
@ -1248,9 +1248,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
|||
if (DK % 32 != 0 || DV % 32 != 0) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
if (neq1 < 32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
// GQA factor
|
||||
const uint32_t n_kv_heads = k->ne[2];
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
#include "ggml-common.h"
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hex-fastdiv.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#include "hvx-utils.h"
|
||||
|
|
@ -187,45 +188,44 @@ next_nc:
|
|||
// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles
|
||||
// of the same 32 packed bytes.
|
||||
static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) {
|
||||
(void)vlut_cvt;
|
||||
HVX_Vector vq = hvx_vmemu(packed_32);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
|
||||
HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
|
||||
// q4x4x2 stores two int4 values per byte. Keep only the selected nibble.
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
// Shuffle before LUT
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
// Use standard vlut16 (not _nomatch) to avoid stale-register NaN.
|
||||
// _nomatch retains the previous destination-register value for colliding
|
||||
// indices, but the C intrinsic doesn't model the implicit read so the
|
||||
// compiler may allocate a register containing garbage/NaN.
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_hf = Q6_V_lo_W(vp);
|
||||
|
||||
HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8);
|
||||
HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_int8));
|
||||
HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
|
||||
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
|
||||
}
|
||||
|
||||
// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using
|
||||
// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls.
|
||||
// full HVX vector width.
|
||||
// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes.
|
||||
static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx(
|
||||
const uint8_t *packed_128, bool upper_nibbles,
|
||||
const __fp16 *scales_4, const HVX_Vector vlut_cvt) {
|
||||
// Load all 128 packed bytes (4 contiguous 32-byte groups)
|
||||
(void)vlut_cvt;
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
// Shuffle before LUT
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8);
|
||||
|
||||
// Full-width vlut16: 128 byte lookups -> 128 fp16 results in a VectorPair
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp); // [group0: 32 fp16 | group1: 32 fp16]
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16]
|
||||
HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_int8);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp_int16);
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp_int16);
|
||||
|
||||
v_lo = Q6_Vhf_equals_Vh(v_lo);
|
||||
v_hi = Q6_Vhf_equals_Vh(v_hi);
|
||||
|
||||
// Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b
|
||||
HVX_Vector vscale = hvx_vmemu(scales_4);
|
||||
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale);
|
||||
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4));
|
||||
|
|
@ -233,13 +233,12 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx(
|
|||
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
|
||||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
|
||||
|
||||
// Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter
|
||||
HVX_Vector_x2 r = { v_lo,/* group1 already in [0:63] */
|
||||
v_hi /* group2 already in [0:63] */ };
|
||||
HVX_Vector_x2 r = { v_lo, v_hi };
|
||||
return r;
|
||||
}
|
||||
|
||||
static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) {
|
||||
(void)vlut_cvt;
|
||||
HVX_Vector vq = hvx_vmemu(packed_32);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_dm = hvx_vmemu(scale_offset);
|
||||
|
|
@ -248,9 +247,9 @@ static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32
|
|||
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_hf = Q6_V_lo_W(vp);
|
||||
|
||||
HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_quants));
|
||||
HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
|
||||
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets));
|
||||
}
|
||||
|
|
@ -258,16 +257,18 @@ static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32
|
|||
static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx(
|
||||
const uint8_t *packed_128, bool upper_nibbles,
|
||||
const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) {
|
||||
(void)vlut_cvt;
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_quants);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp_int16);
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp_int16);
|
||||
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp);
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp);
|
||||
v_lo = Q6_Vhf_equals_Vh(v_lo);
|
||||
v_hi = Q6_Vhf_equals_Vh(v_hi);
|
||||
|
||||
HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4);
|
||||
HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2);
|
||||
|
|
@ -287,6 +288,45 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx(
|
|||
return r;
|
||||
}
|
||||
|
||||
// LUT-based dequantizers for non-linear IQ4_NL format.
|
||||
static inline HVX_Vector dequantize_x4x2_iq4_nl_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_32);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_hf = Q6_V_lo_W(vp);
|
||||
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
|
||||
}
|
||||
|
||||
static inline HVX_Vector_x2 dequantize_x4x2_iq4_nl_x4groups_hvx(
|
||||
const uint8_t *packed_128, bool upper_nibbles,
|
||||
const __fp16 *scales_4, const HVX_Vector vlut_cvt) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp);
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp);
|
||||
|
||||
HVX_Vector vscale = hvx_vmemu(scales_4);
|
||||
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale);
|
||||
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4));
|
||||
|
||||
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
|
||||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
|
||||
|
||||
HVX_Vector_x2 r = { v_lo, v_hi };
|
||||
return r;
|
||||
}
|
||||
|
||||
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
|
||||
static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) {
|
||||
HVX_Vector vq = hvx_vmemu(quants_32);
|
||||
|
|
@ -374,122 +414,176 @@ static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t *
|
|||
return r;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
__fp16 *dst;
|
||||
const uint8_t *src;
|
||||
int n_cols;
|
||||
int k_block;
|
||||
size_t row_stride;
|
||||
int weight_type;
|
||||
int n_tot_tiles;
|
||||
int n_tiles_per_task;
|
||||
int n_tasks;
|
||||
int n_k_tiles;
|
||||
struct fastdiv_values n_k_tiles_div;
|
||||
} x4x2_dequantize_state_t;
|
||||
|
||||
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
|
||||
// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes.
|
||||
// Output: vtcm_dst in tile-major FP16 layout.
|
||||
static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
__fp16 *restrict vtcm_dst,
|
||||
const uint8_t *restrict vtcm_src,
|
||||
int n_cols, int k_block,
|
||||
size_t row_stride, int weight_type,
|
||||
|
||||
#define DEFINE_DEQUANTIZE_Q4_TASK(suffix, lut_name, helper_prefix, dblk_size, scale_step) \
|
||||
static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix( \
|
||||
const x4x2_dequantize_state_t *state, \
|
||||
int start_tile, int end_tile) { \
|
||||
\
|
||||
const int n_k_tiles = state->n_k_tiles; \
|
||||
const int qrow_size = (unsigned)state->k_block / 2; \
|
||||
const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; \
|
||||
const HVX_Vector vlut_cvt = hvx_vmem(lut_name); \
|
||||
\
|
||||
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); \
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); \
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); \
|
||||
\
|
||||
unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); \
|
||||
unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); \
|
||||
\
|
||||
for (unsigned t = start_tile; t < (unsigned)end_tile; ) { \
|
||||
if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } \
|
||||
\
|
||||
if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { \
|
||||
unsigned blk_idx = ((kt * 32) / QK_Q4_0x4x2); \
|
||||
unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; \
|
||||
bool upper = (sub_blk_base >= 4); \
|
||||
unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); \
|
||||
unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk_base * (scale_step); \
|
||||
\
|
||||
__fp16 *tile_bases[4]; \
|
||||
for (unsigned g = 0; g < 4; g++) { \
|
||||
tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; \
|
||||
} \
|
||||
\
|
||||
HVX_Vector v_off = v_scat_base; \
|
||||
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \
|
||||
\
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { \
|
||||
const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \
|
||||
const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \
|
||||
\
|
||||
HVX_Vector_x2 dv0 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \
|
||||
r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); \
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); \
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \
|
||||
\
|
||||
HVX_Vector_x2 dv1 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \
|
||||
r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); \
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); \
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); \
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \
|
||||
} \
|
||||
\
|
||||
for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } \
|
||||
t += 4; kt += 4; \
|
||||
continue; \
|
||||
} \
|
||||
\
|
||||
__fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; \
|
||||
{ \
|
||||
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; \
|
||||
unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; \
|
||||
bool upper = (sub_blk >= 4); \
|
||||
unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; \
|
||||
unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk * (scale_step); \
|
||||
\
|
||||
HVX_Vector v_off = v_scat_base; \
|
||||
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \
|
||||
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; \
|
||||
\
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { \
|
||||
const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \
|
||||
const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \
|
||||
\
|
||||
HVX_Vector v0 = dequantize_x4x2_##helper_prefix##_group_hvx( \
|
||||
r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \
|
||||
HVX_Vector v1 = (row1 < (unsigned)state->n_cols) \
|
||||
? dequantize_x4x2_##helper_prefix##_group_hvx( \
|
||||
r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) \
|
||||
: Q6_V_vzero(); \
|
||||
\
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); \
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); \
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \
|
||||
} \
|
||||
(void) *(volatile HVX_Vector *)(tile_base); \
|
||||
} \
|
||||
++t; ++kt; \
|
||||
} \
|
||||
\
|
||||
if (start_tile < end_tile) { \
|
||||
(void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \
|
||||
int start = task_id * state->n_tiles_per_task; \
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \
|
||||
} \
|
||||
}
|
||||
|
||||
DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16))
|
||||
DEFINE_DEQUANTIZE_Q4_TASK(q4_1, q4_1_to_fp16_lut, q4_1, 32, 4)
|
||||
DEFINE_DEQUANTIZE_Q4_TASK(iq4_nl, iq4_nl_to_fp16_lut, iq4_nl, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16))
|
||||
|
||||
static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(
|
||||
const x4x2_dequantize_state_t *state,
|
||||
int start_tile, int end_tile) {
|
||||
|
||||
const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS;
|
||||
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL);
|
||||
const bool is_q4_1 = (weight_type == HTP_TYPE_Q4_1);
|
||||
const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block;
|
||||
const int n_k_tiles = state->n_k_tiles;
|
||||
const int qrow_size = state->k_block;
|
||||
const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div;
|
||||
const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut);
|
||||
|
||||
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
|
||||
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
|
||||
(weight_type == HTP_TYPE_Q4_1) ? hvx_vmem(q4_1_to_fp16_lut) :
|
||||
hvx_vmem(q4_0_to_fp16_lut);
|
||||
|
||||
// vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
|
||||
// Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128
|
||||
// maps to K-rows 2i and 2i+1. Column offset (n*4) added per row.
|
||||
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes)
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
|
||||
|
||||
unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index
|
||||
unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index
|
||||
for (unsigned t = start_tile; t < end_tile; ) {
|
||||
if (kt >= n_k_tiles) { kt = 0; ct++; }
|
||||
unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div);
|
||||
unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div);
|
||||
|
||||
// --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row ---
|
||||
if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
|
||||
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
|
||||
unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE;
|
||||
unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16);
|
||||
unsigned scale_off = qrow_size + blk_idx * dblk_size
|
||||
+ sub_blk_base * scale_step;
|
||||
for (unsigned t = start_tile; t < (unsigned)end_tile; ) {
|
||||
if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; }
|
||||
|
||||
__fp16 *tile_bases[4];
|
||||
for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }
|
||||
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
|
||||
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
|
||||
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
|
||||
|
||||
if (is_q4_1) {
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_1_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
} else {
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); }
|
||||
t += 4; kt += 4;
|
||||
continue;
|
||||
}
|
||||
|
||||
// --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales ---
|
||||
if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
|
||||
// Batch-4 fast path for MXFP4
|
||||
if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) {
|
||||
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
|
||||
int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4
|
||||
int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32;
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes
|
||||
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales
|
||||
int packed_off = blk_idx * (QK_MXFP4x4x2 / 2);
|
||||
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE;
|
||||
|
||||
__fp16 * tile_bases[4];
|
||||
for (int g = 0; g < 4; g++) {
|
||||
tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS;
|
||||
tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
const uint8_t * r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t * r1 = vtcm_src + row1 * row_stride;
|
||||
const uint8_t * r0 = state->src + row0 * state->row_stride;
|
||||
const uint8_t * r1 = state->src + row1 * state->row_stride;
|
||||
|
||||
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
|
||||
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
|
||||
|
||||
HVX_Vector_x4 dv0, dv1;
|
||||
dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8);
|
||||
if (row1 < n_cols) {
|
||||
if (row1 < state->n_cols) {
|
||||
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
|
||||
dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8);
|
||||
} else {
|
||||
|
|
@ -510,58 +604,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
|||
(void) *(volatile HVX_Vector *) (tile_bases[g]);
|
||||
}
|
||||
|
||||
t += 4;
|
||||
t += 4; kt += 4;
|
||||
continue;
|
||||
}
|
||||
|
||||
// --- Single-tile fallback ---
|
||||
__fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
if (is_q4) {
|
||||
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
|
||||
unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE;
|
||||
unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16);
|
||||
unsigned scale_off = qrow_size + blk_idx * dblk_size + sub_blk * scale_step;
|
||||
|
||||
HVX_Vector v_off = v_scat_base; // reset to column 0
|
||||
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
|
||||
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
|
||||
if (is_q4_1) {
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_q4_1_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector v1 = (row1 < n_cols)
|
||||
? dequantize_x4x2_q4_1_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
|
||||
: Q6_V_vzero();
|
||||
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
} else {
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector v1 = (row1 < n_cols)
|
||||
? dequantize_x4x2_q4_0_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
|
||||
: Q6_V_vzero();
|
||||
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
}
|
||||
(void) *(volatile HVX_Vector *)(tile_base);
|
||||
} else if (weight_type == HTP_TYPE_MXFP4) {
|
||||
// Single-tile fallback
|
||||
__fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS;
|
||||
{
|
||||
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
|
||||
int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
|
|
@ -573,15 +622,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
|||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
|
||||
const uint8_t * r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t * r1 = vtcm_src + row1 * row_stride;
|
||||
const uint8_t * r0 = state->src + row0 * state->row_stride;
|
||||
const uint8_t * r1 = state->src + row1 * state->row_stride;
|
||||
|
||||
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
|
||||
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8);
|
||||
HVX_Vector v1;
|
||||
if (row1 < n_cols) {
|
||||
if (row1 < state->n_cols) {
|
||||
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
|
||||
v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8);
|
||||
} else {
|
||||
|
|
@ -594,23 +642,59 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
|||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
(void) *(volatile HVX_Vector *) (tile_base);
|
||||
} else {
|
||||
// Q8_0
|
||||
}
|
||||
++t; ++kt;
|
||||
}
|
||||
|
||||
if (start_tile < end_tile) {
|
||||
(void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS);
|
||||
}
|
||||
}
|
||||
|
||||
static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) {
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||||
int start = task_id * state->n_tiles_per_task;
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end);
|
||||
}
|
||||
}
|
||||
|
||||
static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
|
||||
const x4x2_dequantize_state_t *state,
|
||||
int start_tile, int end_tile) {
|
||||
|
||||
const int n_k_tiles = state->n_k_tiles;
|
||||
const int qrow_size = state->k_block;
|
||||
const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div;
|
||||
|
||||
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
|
||||
|
||||
unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div);
|
||||
unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div);
|
||||
|
||||
for (unsigned t = start_tile; t < (unsigned)end_tile; ) {
|
||||
if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; }
|
||||
|
||||
__fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS;
|
||||
{
|
||||
int blk_idx = (kt * 32) / QK_Q8_0x4x2;
|
||||
int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32;
|
||||
int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32;
|
||||
int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);
|
||||
|
||||
HVX_Vector v_off = v_scat_base; // reset to column 0
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
|
||||
const uint8_t *r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row1 * row_stride;
|
||||
const uint8_t *r0 = state->src + row0 * state->row_stride;
|
||||
const uint8_t *r1 = state->src + row1 * state->row_stride;
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off));
|
||||
HVX_Vector v1 = (row1 < n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero();
|
||||
HVX_Vector v1 = (row1 < state->n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero();
|
||||
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
|
|
@ -622,50 +706,31 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
|||
++t; ++kt;
|
||||
}
|
||||
|
||||
// Drain HVX scatter write buffer: a vmem load on the same HW thread retires
|
||||
// all pending scatter entries to VTCM. Without this, the main thread's HMX
|
||||
// reads may see stale data because atomic_fetch_sub (release) only orders
|
||||
// regular stores, not the HVX scatter buffer.
|
||||
if (start_tile < end_tile) {
|
||||
(void) *(volatile HVX_Vector *)(vtcm_dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS);
|
||||
(void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS);
|
||||
}
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
__fp16 *dst;
|
||||
const uint8_t *src;
|
||||
int n_cols;
|
||||
int k_block;
|
||||
size_t row_stride;
|
||||
int weight_type;
|
||||
int n_tot_tiles;
|
||||
int n_tiles_per_task;
|
||||
int n_tasks;
|
||||
} x4x2_dequantize_state_t;
|
||||
|
||||
static void dequantize_x4x2_worker_loop(unsigned int n, unsigned int i, void *data) {
|
||||
static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) {
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||||
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||||
int start = task_id * state->n_tiles_per_task;
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||||
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
state->dst, state->src, state->n_cols, state->k_block,
|
||||
state->row_stride, state->weight_type, start, end);
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end);
|
||||
}
|
||||
}
|
||||
|
||||
static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
|
||||
struct htp_context *ctx, __fp16 *vtcm_dst,
|
||||
const void *vtcm_src, int n_cols, int k_block,
|
||||
size_t row_stride, int weight_type) {
|
||||
size_t row_stride, int weight_type,
|
||||
int n_k_tiles, struct fastdiv_values n_k_tiles_div,
|
||||
worker_callback_t dequant_worker_fn) {
|
||||
|
||||
assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
|
||||
assert(k_block % HMX_FP16_TILE_N_COLS == 0);
|
||||
|
||||
size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;
|
||||
size_t n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
|
||||
size_t n_tot_tiles = n_col_tiles * n_k_tiles;
|
||||
|
||||
size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads);
|
||||
|
|
@ -680,8 +745,10 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
|
|||
state.k_block = k_block;
|
||||
state.row_stride = row_stride;
|
||||
state.weight_type = weight_type;
|
||||
state.n_k_tiles = n_k_tiles;
|
||||
state.n_k_tiles_div = n_k_tiles_div;
|
||||
|
||||
worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads);
|
||||
worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, ctx->n_threads);
|
||||
}
|
||||
|
||||
// --- End x4x2 dequantizers ---
|
||||
|
|
@ -978,6 +1045,20 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *
|
|||
return -1;
|
||||
}
|
||||
|
||||
worker_callback_t dequant_worker_fn = NULL;
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break;
|
||||
case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break;
|
||||
case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break;
|
||||
case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break;
|
||||
case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
|
||||
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
|
||||
const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles);
|
||||
|
||||
// --- Dynamic VTCM layout ---
|
||||
const size_t vec_dot_size = k * sizeof(__fp16);
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
|
|
@ -1070,7 +1151,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *
|
|||
{
|
||||
// B0: wait for DMA, dequant weight chunk 0
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn);
|
||||
|
||||
// A1: issue DMA for weight chunk 1
|
||||
const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
|
||||
|
|
@ -1089,7 +1170,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *
|
|||
// B1: DMA pop + dequant (runs in parallel with C0 on HMX worker)
|
||||
if (1 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1131,7 +1212,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *
|
|||
// B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1})
|
||||
if (i + 2 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue