hexagon: HMX quantized matmul rework (llama/23368)
* hmx-mm: update debug logging in hmx-mm * hmx-mm: update dequant logic to use HVX_Vector_x2/x4 * hmx-mm: remove non-pipelined version of the quantized matmul It seems that we don't really need the non-pipelined version * hmx-mm: use activation depth mode and update naming Co-authored-by: Kim-Chyan Gan <kgan@qti.qualcomm.com> * hex-mm: minor hmx matmul naming updates * hmx-mm: remove unused vars * snapdragon: scripts bump default ubatch-size to 1K * hexagon: combine HMX and power and clock settings into a single set_power call * hmx-mm: remove leftover of the scale repl helper * hexagon: fix editconf error --------- Co-authored-by: Kim-Chyan Gan <kgan@qti.qualcomm.com>
This commit is contained in:
parent
3fa19558f2
commit
b93a5ba605
|
|
@ -201,11 +201,10 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32
|
|||
|
||||
// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using
|
||||
// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls.
|
||||
// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes.
|
||||
static inline void dequantize_x4x2_q4_0_x4groups_hvx(
|
||||
// Output: HVX_Vector_x2; each of its two vectors holds 32 FP16 values in the first 64 bytes.
|
||||
static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx(
|
||||
const uint8_t *packed_128, bool upper_nibbles,
|
||||
const __fp16 *scales_4, const HVX_Vector vlut_cvt,
|
||||
HVX_Vector out[4]) {
|
||||
const __fp16 *scales_4, const HVX_Vector vlut_cvt) {
|
||||
// Load all 128 packed bytes (4 contiguous 32-byte groups)
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
|
|
@ -221,8 +220,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
|
|||
HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16]
|
||||
|
||||
// Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b
|
||||
volatile HVX_Vector vscale = hvx_vmemu(scales_4);
|
||||
|
||||
HVX_Vector vscale = hvx_vmemu(scales_4);
|
||||
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale);
|
||||
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4));
|
||||
|
||||
|
|
@ -230,8 +228,9 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
|
|||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
|
||||
|
||||
// Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter
|
||||
out[0] = v_lo; // group0 already in [0:63]
|
||||
out[1] = v_hi; // group2 already in [0:63]
|
||||
HVX_Vector_x2 r = { v_lo,/* group0 already in [0:63] */
|
||||
v_hi /* group2 already in [0:63] */ };
|
||||
return r;
|
||||
}
|
||||
|
||||
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
|
||||
|
|
@ -292,12 +291,11 @@ static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed
|
|||
}
|
||||
|
||||
// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes).
|
||||
static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128,
|
||||
static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128,
|
||||
bool upper_nibbles,
|
||||
int sub_blk_base,
|
||||
const HVX_Vector vlut_cvt,
|
||||
mxfp4_scales_t scales,
|
||||
HVX_Vector out[4]) {
|
||||
mxfp4_scales_t scales) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
|
||||
|
|
@ -318,10 +316,8 @@ static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_12
|
|||
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
|
||||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
|
||||
|
||||
out[0] = v_lo;
|
||||
out[1] = Q6_V_vror_VR(v_lo, 64);
|
||||
out[2] = v_hi;
|
||||
out[3] = Q6_V_vror_VR(v_hi, 64);
|
||||
HVX_Vector_x4 r = { v_lo, Q6_V_vror_VR(v_lo, 64), v_hi, Q6_V_vror_VR(v_hi, 64) };
|
||||
return r;
|
||||
}
|
||||
|
||||
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
|
||||
|
|
@ -372,18 +368,18 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
|||
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
|
||||
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
HVX_Vector v0[2];
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]);
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
|
||||
|
||||
r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
|
||||
|
|
@ -415,21 +411,21 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
|||
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
|
||||
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
|
||||
|
||||
HVX_Vector v0[4], v1[4];
|
||||
dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0);
|
||||
HVX_Vector_x4 dv0, dv1;
|
||||
dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8);
|
||||
if (row1 < n_cols) {
|
||||
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
|
||||
dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1);
|
||||
dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8);
|
||||
} else {
|
||||
v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
|
||||
dv1.v[0] = dv1.v[1] = dv1.v[2] = dv1.v[3] = Q6_V_vzero();
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) {
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[g]);
|
||||
}
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
for (int g = 0; g < 4; g++) {
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[g]);
|
||||
}
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
|
|
@ -612,11 +608,13 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict
|
|||
const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
for (int k = 0; k < n_dot_tiles; ++k) {
|
||||
Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047);
|
||||
row_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
col_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
for (int k = 0, k_block; k < n_dot_tiles; k += k_block) {
|
||||
k_block = hex_smin(n_dot_tiles - k, 32);
|
||||
const uint32_t range = 2048u * (uint32_t)k_block - 1;
|
||||
Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range);
|
||||
row_tiles += k_block * HMX_FP16_TILE_N_ELMS;
|
||||
col_tiles += k_block * HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
__fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS;
|
||||
|
|
@ -832,10 +830,6 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *
|
|||
worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#define FALLBACK_TO_STANDARD 1
|
||||
|
||||
// C += AB
|
||||
static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b,
|
||||
const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile,
|
||||
|
|
@ -861,314 +855,80 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co
|
|||
Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047);
|
||||
}
|
||||
|
||||
for (int k = 0; k < n_dot_tiles; ++k) {
|
||||
Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047);
|
||||
row_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
col_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
for (int k = 0, k_block; k < n_dot_tiles; k += k_block) {
|
||||
k_block = hex_smin(n_dot_tiles - k, 32);
|
||||
const uint32_t range = 2048u * (uint32_t)k_block - 1;
|
||||
Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range);
|
||||
row_tiles += k_block * HMX_FP16_TILE_N_ELMS;
|
||||
col_tiles += k_block * HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
Q6_mxmem_AR_after_hf(accum_tile, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx,
|
||||
float *restrict out, const float *restrict x, const uint8_t *restrict w,
|
||||
int m, int k, int n, int weight_type) {
|
||||
// assume k % 32 == 0 && n % 32 == 0
|
||||
const size_t row_stride = get_x4x2_row_stride(weight_type, k);
|
||||
if (row_stride == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
|
||||
const size_t K_BLOCK_SIZE = 1024;
|
||||
|
||||
// Fallback: if k doesn't need K-blocking, out-stationary has no advantage
|
||||
const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE;
|
||||
if (k_iters_check <= 1) {
|
||||
FARF(HIGH, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k);
|
||||
return FALLBACK_TO_STANDARD;
|
||||
}
|
||||
|
||||
// Dynamic M,N search via hmx_compute_chunks
|
||||
const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE);
|
||||
const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32)
|
||||
+ K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles)
|
||||
const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant)
|
||||
+ K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles)
|
||||
const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary)
|
||||
|
||||
// Alignment margin: hex_align_up can add up to 2047 bytes per buffer;
|
||||
// scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin
|
||||
const size_t align_margin = 4 * HMX_FP16_TILE_SIZE;
|
||||
const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment
|
||||
|
||||
size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used;
|
||||
// Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost.
|
||||
// From profiling: wt_dequant per element ≈ 1.5× activation load per element.
|
||||
// m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive).
|
||||
// n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper).
|
||||
const size_t m_block_cost = (size_t) n * 3;
|
||||
const size_t n_block_cost = (size_t) m * 2;
|
||||
if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn,
|
||||
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
|
||||
m_block_cost, n_block_cost, &M_BLOCK_SIZE,
|
||||
&N_BLOCK_SIZE, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Compute precise buffer sizes from searched M,N and fixed K
|
||||
const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE);
|
||||
const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE);
|
||||
|
||||
const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256;
|
||||
if (total_vtcm > vtcm_budget) {
|
||||
FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm,
|
||||
vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE);
|
||||
return -1;
|
||||
}
|
||||
|
||||
uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base;
|
||||
__fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size);
|
||||
__fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size);
|
||||
__fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size);
|
||||
uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz);
|
||||
uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz);
|
||||
__fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE);
|
||||
__fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
|
||||
assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget);
|
||||
|
||||
FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type,
|
||||
M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);
|
||||
|
||||
// initialize eye tile (32x32 identity matrix)
|
||||
{
|
||||
HVX_Vector v;
|
||||
v = Q6_V_vzero();
|
||||
v = Q6_Vw_vinsert_VwR(v, 0x3c000000);
|
||||
v = Q6_V_vror_VR(v, VLEN - 4);
|
||||
v = Q6_Vw_vinsert_VwR(v, 0x00003c00);
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
((HVX_Vector *) vtcm_eye_tile)[i] = v;
|
||||
v = Q6_V_vror_VR(v, VLEN - 8);
|
||||
}
|
||||
}
|
||||
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
|
||||
|
||||
TIMER_DEFINE(fetch);
|
||||
TIMER_DEFINE(act_load);
|
||||
TIMER_DEFINE(wt_dequant);
|
||||
TIMER_DEFINE(core);
|
||||
|
||||
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
|
||||
|
||||
for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) {
|
||||
size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE);
|
||||
for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) {
|
||||
size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE);
|
||||
|
||||
const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS);
|
||||
const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS);
|
||||
|
||||
for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) {
|
||||
const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE);
|
||||
|
||||
TIMER_START(fetch);
|
||||
// fetch activation block into VTCM
|
||||
{
|
||||
const float *activation_block = x + mr * k + kk;
|
||||
|
||||
dma_queue_push(ctx->dma[0],
|
||||
dma_make_ptr(vtcm_scratch1, activation_block),
|
||||
k_blk_sz * sizeof(float),
|
||||
k * sizeof(float),
|
||||
k_blk_sz * sizeof(float),
|
||||
m_blk_sz);
|
||||
}
|
||||
|
||||
// fetch weight block into VTCM (x4x2 sub-block: quants + scales)
|
||||
const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
|
||||
{
|
||||
const int blk_start = kk / QK_Q4_0x4x2;
|
||||
const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;
|
||||
const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2);
|
||||
const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE;
|
||||
uint8_t *dst = vtcm_scratch0;
|
||||
const uint8_t *src = w + nc * row_stride;
|
||||
const size_t n_rows = n_blk_sz;
|
||||
const size_t src_stride = row_stride;
|
||||
const size_t dst_stride = sub_row_stride;
|
||||
const size_t quant_off = (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2));
|
||||
const size_t quant_width = (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2));
|
||||
const size_t scale_off = full_qrow + blk_start * scale_blk_size;
|
||||
const size_t scale_width = nb_sub * scale_blk_size;
|
||||
|
||||
// 2D DMA: quants sub-range
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(dst, src + quant_off), dst_stride, src_stride, quant_width, n_rows);
|
||||
// 2D DMA: scales sub-range
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width, src + scale_off), dst_stride, src_stride, scale_width, n_rows);
|
||||
}
|
||||
TIMER_STOP(fetch);
|
||||
|
||||
TIMER_START(act_load);
|
||||
// load activation block
|
||||
{
|
||||
dma_queue_pop(ctx->dma[0]); // wait for act DMA
|
||||
transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz);
|
||||
}
|
||||
TIMER_STOP(act_load);
|
||||
|
||||
TIMER_START(wt_dequant);
|
||||
// dequantize weight block
|
||||
{
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
// vtcm_scratch0 is used to store the qweight chunk
|
||||
// worker_pool_run_func already returned, so fetch is done
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0,
|
||||
n_blk_sz, k_blk_sz, sub_row_stride, weight_type);
|
||||
}
|
||||
TIMER_STOP(wt_dequant);
|
||||
|
||||
// core mma
|
||||
TIMER_START(core);
|
||||
{
|
||||
core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles,
|
||||
n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0);
|
||||
}
|
||||
TIMER_STOP(core);
|
||||
}
|
||||
|
||||
// store output block
|
||||
{
|
||||
float *output_block = out + (mr * n + nc);
|
||||
transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us",
|
||||
TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core));
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
|
||||
int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
|
||||
const uint8_t *restrict permuted_weight, int m, int k, int n,
|
||||
int weight_type) {
|
||||
if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; }
|
||||
if (k % 32 != 0 || n % 32 != 0) { return -1; }
|
||||
|
||||
if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// for large m, k (e.g. prefill FFN Down), use out-stationary version
|
||||
if (m >= 128 && k > n && n > 1024) {
|
||||
int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type);
|
||||
if (rc != FALLBACK_TO_STANDARD) {
|
||||
return rc; // 0 success, -1 error
|
||||
}
|
||||
FARF(HIGH, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n);
|
||||
// fall through to standard path
|
||||
}
|
||||
|
||||
size_t row_stride = get_x4x2_row_stride(weight_type, k);
|
||||
if (row_stride == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
FARF(HIGH, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type);
|
||||
|
||||
// --- Dynamic VTCM layout ---
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
const size_t vec_dot_size = k * sizeof(__fp16);
|
||||
const size_t vec_dot_size = k * sizeof(__fp16);
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
size_t vtcm_used = 0;
|
||||
|
||||
// Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap.
|
||||
// Only pays off when the chunker yields >=2 n-chunks, so the main loop can
|
||||
// overlap HMX (C) with HVX (B/D); with a single n-chunk the extra VTCM for
|
||||
// double-buffered output and the worker-dispatch overhead are pure loss.
|
||||
// Try pipeline costs first; fall back to sequential if the layout collapses
|
||||
// to one n-chunk. m >= 128 floor keeps HMX utilization reasonable.
|
||||
const size_t pipe_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs)
|
||||
const size_t pipe_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer)
|
||||
const size_t seq_per_n = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs)
|
||||
const size_t seq_per_mn = sizeof(__fp16); // O x 1
|
||||
const size_t size_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs)
|
||||
const size_t size_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer)
|
||||
|
||||
size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
|
||||
bool use_pipeline = false;
|
||||
|
||||
if (m >= 128) {
|
||||
size_t mc = 0, nc = 0, used = 0;
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn,
|
||||
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
|
||||
/*m_block_cost=*/(size_t) n * 3,
|
||||
/*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 &&
|
||||
hmx_ceil_div((size_t) n, nc) >= 2) {
|
||||
m_chunk_n_rows = mc;
|
||||
n_chunk_n_cols = nc;
|
||||
vtcm_used = used;
|
||||
use_pipeline = true;
|
||||
}
|
||||
size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0;
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn,
|
||||
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
|
||||
/*m_block_cost=*/(size_t) n * 3,
|
||||
/*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) {
|
||||
FARF(HIGH, "hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!use_pipeline) {
|
||||
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn,
|
||||
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
|
||||
/*m_block_cost=*/(size_t) n * 3,
|
||||
/*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute precise buffer sizes per execution path
|
||||
const size_t weight_area_size = hex_align_up(
|
||||
n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE);
|
||||
const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE);
|
||||
const size_t output_area_size = hex_align_up(
|
||||
m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE);
|
||||
const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE);
|
||||
const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE);
|
||||
|
||||
size_t scratch0_size, scratch1_size, scratch2_size;
|
||||
if (use_pipeline) {
|
||||
scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0
|
||||
scratch1_size = scratch0_size; // dequant buf 1
|
||||
scratch2_size = output_area_size; // output buf 1
|
||||
} else {
|
||||
scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0
|
||||
scratch1_size = scratch0_size; // x4x2 DMA buf 1
|
||||
scratch2_size = 0; // unused
|
||||
}
|
||||
scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0
|
||||
scratch1_size = scratch0_size; // dequant buf 1
|
||||
scratch2_size = output_area_size; // output buf 1
|
||||
|
||||
uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base;
|
||||
__fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size);
|
||||
__fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size);
|
||||
__fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size);
|
||||
__fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size);
|
||||
void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size);
|
||||
void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size);
|
||||
void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL;
|
||||
__fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
|
||||
if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) {
|
||||
FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__,
|
||||
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
|
||||
|
||||
vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base;
|
||||
if (vtcm_used > vtcm_budget) {
|
||||
FARF(ERROR, "hmx-mm-q: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget);
|
||||
return -1;
|
||||
}
|
||||
|
||||
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
|
||||
|
||||
FARF(HIGH, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu",
|
||||
__func__, m, k, n, weight_type, use_pipeline,
|
||||
m_chunk_n_rows, n_chunk_n_cols,
|
||||
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
|
||||
FARF(HIGH, "hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu",
|
||||
m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget);
|
||||
|
||||
TIMER_DEFINE(activation_load);
|
||||
TIMER_DEFINE(weight_load);
|
||||
|
|
@ -1178,184 +938,115 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
|||
TIMER_DEFINE(total);
|
||||
TIMER_START(total);
|
||||
|
||||
FARF(HIGH, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu",
|
||||
use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols,
|
||||
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
|
||||
// 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D)
|
||||
// HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D).
|
||||
|
||||
if (!use_pipeline) {
|
||||
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
|
||||
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
|
||||
// transfer activation matrix chunk into VTCM
|
||||
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
|
||||
const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
|
||||
// A --> B: vtcm_qweight, 1 buffer
|
||||
// B --> C: vtcm_weight0/vtcm_weight1, 2 buffers
|
||||
// C --> D: vtcm_output0/vtcm_output1, 2 buffers
|
||||
|
||||
TIMER_START(activation_load);
|
||||
{
|
||||
const float *activation_chunk = activation + mr * k;
|
||||
transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k);
|
||||
}
|
||||
TIMER_STOP(activation_load);
|
||||
// Async timeline (C overlaps B+D):
|
||||
// main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2]
|
||||
// HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████]
|
||||
|
||||
void *buf_curr = vtcm_scratch0;
|
||||
void *buf_next = vtcm_scratch1;
|
||||
int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
|
||||
hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors
|
||||
|
||||
{
|
||||
const size_t n_cols_first = hex_smin(n, n_chunk_n_cols);
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first);
|
||||
}
|
||||
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
|
||||
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
|
||||
|
||||
for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
|
||||
const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
|
||||
const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
|
||||
void *vtcm_qweight = vtcm_weight;
|
||||
void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 };
|
||||
void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 };
|
||||
|
||||
TIMER_START(weight_load);
|
||||
{
|
||||
dma_queue_pop(ctx->dma[0]); // wait until the current weight chunk becomes ready
|
||||
|
||||
const size_t nc_next = nc + n_chunk_n_cols;
|
||||
if (nc_next < n) {
|
||||
const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols);
|
||||
|
||||
const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride;
|
||||
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next);
|
||||
}
|
||||
|
||||
// Dequant + vscatter writes directly to [K, N] transposed tiles.
|
||||
// HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight.
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type);
|
||||
|
||||
hex_swap_ptr(&buf_curr, &buf_next);
|
||||
}
|
||||
TIMER_STOP(weight_load);
|
||||
|
||||
TIMER_START(hmx_core);
|
||||
{
|
||||
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32);
|
||||
}
|
||||
TIMER_STOP(hmx_core);
|
||||
|
||||
TIMER_START(output_store);
|
||||
{
|
||||
float *output = dst + (mr * n + nc);
|
||||
transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n);
|
||||
}
|
||||
TIMER_STOP(output_store);
|
||||
}
|
||||
// prologue: A0
|
||||
const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols);
|
||||
{
|
||||
const uint8_t *qweight_chunk_A0 = permuted_weight;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0);
|
||||
}
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
} else {
|
||||
// 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D)
|
||||
// HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D).
|
||||
|
||||
// A --> B: vtcm_qweight, 1 buffer
|
||||
// B --> C: vtcm_weight0/vtcm_weight1, 2 buffers
|
||||
// C --> D: vtcm_output0/vtcm_output1, 2 buffers
|
||||
{
|
||||
const float *activation_chunk = activation + mr * k;
|
||||
transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k);
|
||||
}
|
||||
|
||||
// Async timeline (C overlaps B+D):
|
||||
// main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2]
|
||||
// HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████]
|
||||
// prologue: B0, A1, submit C0 (async), B1 (overlaps C0)
|
||||
{
|
||||
// B0: wait for DMA, dequant weight chunk 0
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type);
|
||||
|
||||
int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
|
||||
hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors
|
||||
|
||||
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
|
||||
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
|
||||
|
||||
void *vtcm_qweight = vtcm_weight;
|
||||
void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 };
|
||||
void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 };
|
||||
|
||||
// prologue: A0
|
||||
const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols);
|
||||
{
|
||||
// Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow.
|
||||
const uint8_t *qweight_chunk_A0 = permuted_weight;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0);
|
||||
// A1: issue DMA for weight chunk 1
|
||||
const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
|
||||
if (1 < n_chunk_cnt) {
|
||||
const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1);
|
||||
}
|
||||
|
||||
{
|
||||
const float *activation_chunk = activation + mr * k;
|
||||
transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k);
|
||||
}
|
||||
// submit C0 (non-blocking — HMX worker executes in parallel)
|
||||
hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation,
|
||||
(__fp16 *) vtcm_weight_bufs[0], vtcm_scales,
|
||||
hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
|
||||
hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
|
||||
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0]));
|
||||
|
||||
// prologue: B0, A1, submit C0 (async), B1 (overlaps C0)
|
||||
{
|
||||
// B0: wait for DMA, dequant weight chunk 0
|
||||
// B1: DMA pop + dequant (runs in parallel with C0 on HMX worker)
|
||||
if (1 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type);
|
||||
|
||||
// A1: issue DMA for weight chunk 1
|
||||
const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
|
||||
if (1 < n_chunk_cnt) {
|
||||
const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1);
|
||||
}
|
||||
|
||||
// submit C0 (non-blocking — HMX worker executes in parallel)
|
||||
hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation,
|
||||
(__fp16 *) vtcm_weight_bufs[0], vtcm_scales,
|
||||
hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
|
||||
hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
|
||||
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0]));
|
||||
|
||||
// B1: DMA pop + dequant (runs in parallel with C0 on HMX worker)
|
||||
if (1 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type);
|
||||
}
|
||||
}
|
||||
|
||||
// main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1})
|
||||
for (int i = 0; i < n_chunk_cnt; ++i) {
|
||||
const size_t nc = i * n_chunk_n_cols;
|
||||
const size_t nc_p1 = nc + 1 * n_chunk_n_cols;
|
||||
const size_t nc_p2 = nc + 2 * n_chunk_n_cols;
|
||||
|
||||
const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
|
||||
const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols);
|
||||
const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols);
|
||||
|
||||
// issue A_{i+2}: DMA push (non-blocking)
|
||||
if (i + 2 < n_chunk_cnt) {
|
||||
const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2);
|
||||
}
|
||||
|
||||
// wait C_i: block until prologue/previous C completes
|
||||
hmx_queue_pop(ctx->hmx_queue);
|
||||
|
||||
// submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below)
|
||||
// job_slots[(i+1)%2] is safe to reuse here: it was last used by C_{i-1},
|
||||
// which completed before C_i was submitted, and C_i (in slot i%2) has
|
||||
// just been popped above, so no in-flight job references this slot.
|
||||
if (i + 1 < n_chunk_cnt) {
|
||||
hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2],
|
||||
(__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2],
|
||||
vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
|
||||
hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
|
||||
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2]));
|
||||
}
|
||||
|
||||
// D_i: store output (multi-thread HVX, parallel with C_{i+1})
|
||||
float *output_chunk = dst + (mr * n + nc);
|
||||
transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n);
|
||||
|
||||
// B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1})
|
||||
if (i + 2 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type);
|
||||
}
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type);
|
||||
}
|
||||
}
|
||||
|
||||
hmx_queue_suspend(ctx->hmx_queue);
|
||||
// main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1})
|
||||
for (int i = 0; i < n_chunk_cnt; ++i) {
|
||||
const size_t nc = i * n_chunk_n_cols;
|
||||
const size_t nc_p1 = nc + 1 * n_chunk_n_cols;
|
||||
const size_t nc_p2 = nc + 2 * n_chunk_n_cols;
|
||||
|
||||
const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
|
||||
const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols);
|
||||
const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols);
|
||||
|
||||
// issue A_{i+2}: DMA push (non-blocking)
|
||||
if (i + 2 < n_chunk_cnt) {
|
||||
const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride;
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2);
|
||||
}
|
||||
|
||||
// wait C_i: block until prologue/previous C completes
|
||||
hmx_queue_pop(ctx->hmx_queue);
|
||||
|
||||
// submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below)
|
||||
// job_slots[(i+1)%2] is safe to reuse here: it was last used by C_{i-1},
|
||||
// which completed before C_i was submitted, and C_i (in slot i%2) has
|
||||
// just been popped above, so no in-flight job references this slot.
|
||||
if (i + 1 < n_chunk_cnt) {
|
||||
hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2],
|
||||
(__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2],
|
||||
vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
|
||||
hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
|
||||
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2]));
|
||||
}
|
||||
|
||||
// D_i: store output (multi-thread HVX, parallel with C_{i+1})
|
||||
float *output_chunk = dst + (mr * n + nc);
|
||||
transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n);
|
||||
|
||||
// B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1})
|
||||
if (i + 2 < n_chunk_cnt) {
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hmx_queue_suspend(ctx->hmx_queue);
|
||||
|
||||
TIMER_STOP(total);
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d pipeline=%d", __func__, TIMER_US(total), m, k, n, use_pipeline);
|
||||
FARF(HIGH, "hex-mm-q: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n);
|
||||
if (!use_pipeline) {
|
||||
FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
|
||||
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
|
||||
|
|
@ -1370,15 +1061,15 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
|||
|
||||
//
|
||||
|
||||
static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) {
|
||||
static inline int hmx_matmul_batch_r2(const hmx_matmul_f16_f32_batched_params_t *params) {
|
||||
return params->ne02 > 0 ? params->ne12 / params->ne02 : 1;
|
||||
}
|
||||
|
||||
static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) {
|
||||
static inline int hmx_matmul_batch_r3(const hmx_matmul_f16_f32_batched_params_t *params) {
|
||||
return params->ne03 > 0 ? params->ne13 / params->ne03 : 1;
|
||||
}
|
||||
|
||||
static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params,
|
||||
static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params,
|
||||
int dst_b2, int dst_b3) {
|
||||
const int r2 = hmx_matmul_batch_r2(params);
|
||||
const int r3 = hmx_matmul_batch_r3(params);
|
||||
|
|
@ -1387,37 +1078,36 @@ static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_
|
|||
(size_t) (dst_b3 / r3) * params->src0_nb3);
|
||||
}
|
||||
|
||||
static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params,
|
||||
static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params,
|
||||
int dst_b2, int dst_b3) {
|
||||
return (const float *) ((const uint8_t *) params->activation +
|
||||
(size_t) dst_b2 * params->src1_nb2 +
|
||||
(size_t) dst_b3 * params->src1_nb3);
|
||||
}
|
||||
|
||||
static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params,
|
||||
static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params,
|
||||
int dst_b2, int dst_b3) {
|
||||
return (float *) ((uint8_t *) params->dst +
|
||||
(size_t) dst_b2 * params->dst_nb2 +
|
||||
(size_t) dst_b3 * params->dst_nb3);
|
||||
}
|
||||
|
||||
static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx,
|
||||
const hmx_matmul_w16a32_batched_params_t *params) {
|
||||
static int hmx_matmul_f16_f32_batched_legacy(struct htp_context *ctx,
|
||||
const hmx_matmul_f16_f32_batched_params_t *params) {
|
||||
int ret = 0;
|
||||
for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) {
|
||||
for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) {
|
||||
ret = hmx_mat_mul_permuted_w16a32(ctx,
|
||||
hmx_matmul_dst_batch_ptr(params, b2, b3),
|
||||
hmx_matmul_activation_batch_ptr(params, b2, b3),
|
||||
hmx_matmul_weight_batch_ptr(params, b2, b3),
|
||||
params->m, params->k, params->n,
|
||||
params->act_stride, params->weight_stride);
|
||||
ret = hmx_matmul_f16_f32(ctx, hmx_matmul_dst_batch_ptr(params, b2, b3),
|
||||
hmx_matmul_activation_batch_ptr(params, b2, b3),
|
||||
hmx_matmul_weight_batch_ptr(params, b2, b3),
|
||||
params->m, params->k, params->n,
|
||||
params->act_stride, params->weight_stride);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) {
|
||||
int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params) {
|
||||
if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; }
|
||||
if (!params->m || !params->k || !params->n) { return -1; }
|
||||
if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; }
|
||||
|
|
@ -1435,7 +1125,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
|
|||
|
||||
if (group_size <= 1) {
|
||||
FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size);
|
||||
return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
|
||||
return hmx_matmul_f16_f32_batched_legacy(ctx, params);
|
||||
}
|
||||
|
||||
// Grouped path: reuse interleaved weight across all q_heads sharing a
|
||||
|
|
@ -1464,7 +1154,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
|
|||
/*m_block_cost=*/(size_t) params->n,
|
||||
/*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
|
||||
FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__);
|
||||
return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
|
||||
return hmx_matmul_f16_f32_batched_legacy(ctx, params);
|
||||
}
|
||||
|
||||
const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads
|
||||
|
|
@ -1486,7 +1176,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
|
|||
|
||||
if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) {
|
||||
FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__);
|
||||
return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
|
||||
return hmx_matmul_f16_f32_batched_legacy(ctx, params);
|
||||
}
|
||||
|
||||
hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
|
||||
|
|
@ -1614,7 +1304,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
|
|||
|
||||
//
|
||||
|
||||
int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
|
||||
int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
|
||||
const __fp16 *restrict permuted_weight, int m, int k, int n,
|
||||
int act_stride, int weight_stride) {
|
||||
if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; }
|
||||
|
|
|
|||
|
|
@ -33,14 +33,14 @@ typedef struct {
|
|||
size_t src1_nb3;
|
||||
size_t dst_nb2;
|
||||
size_t dst_nb3;
|
||||
} hmx_matmul_w16a32_batched_params_t;
|
||||
} hmx_matmul_f16_f32_batched_params_t;
|
||||
|
||||
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
|
||||
// act_stride: activation row stride in elements (= k for contiguous, or
|
||||
// nb[1]/sizeof(float) for permuted tensors like attention Q).
|
||||
// weight_stride: weight row stride in elements (= k for compact weights, or
|
||||
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
|
||||
int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx,
|
||||
int hmx_matmul_f16_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const __fp16 *permuted_weight,
|
||||
|
|
@ -48,13 +48,12 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx,
|
|||
int act_stride,
|
||||
int weight_stride);
|
||||
|
||||
// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32.
|
||||
// Batched F16 wrapper over hmx_matmul_f16_f32.
|
||||
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
|
||||
int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx,
|
||||
const hmx_matmul_w16a32_batched_params_t *params);
|
||||
int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params);
|
||||
|
||||
// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL)
|
||||
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
|
||||
// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4)
|
||||
int hmx_matmul_q_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const uint8_t *permuted_weight,
|
||||
|
|
|
|||
|
|
@ -87,6 +87,27 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
|
|||
}
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ >= 75
|
||||
{
|
||||
// Power on HMX and set HMX clock
|
||||
HAP_power_request_t request;
|
||||
memset(&request, 0, sizeof(HAP_power_request_t));
|
||||
request.type = HAP_power_set_HMX_v2;
|
||||
request.hmx_v2.set_power = TRUE;
|
||||
request.hmx_v2.power_up = TRUE;
|
||||
request.hmx_v2.set_clock = TRUE;
|
||||
request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX;
|
||||
request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX;
|
||||
request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX;
|
||||
request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH;
|
||||
FARF(ALWAYS, "Setting HMX clock\n");
|
||||
err = HAP_power_set((void *) ctx, &request);
|
||||
if (err != AEE_SUCCESS) {
|
||||
FARF(ERROR, "Error setting HMX clock.");
|
||||
return err;
|
||||
}
|
||||
}
|
||||
#else
|
||||
{
|
||||
// Power on HMX
|
||||
HAP_power_request_t request;
|
||||
|
|
@ -94,31 +115,12 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
|
|||
request.type = HAP_power_set_HMX;
|
||||
request.hmx.power_up = TRUE;
|
||||
FARF(ALWAYS, "Powering HMX on\n");
|
||||
err = HAP_power_set((void *) &ctx, &request);
|
||||
err = HAP_power_set((void *) ctx, &request);
|
||||
if (err != AEE_SUCCESS) {
|
||||
FARF(ERROR, "Error powering on HMX.");
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ >= 75
|
||||
{
|
||||
// Set HMX clock
|
||||
HAP_power_request_t request;
|
||||
memset(&request, 0, sizeof(HAP_power_request_t));
|
||||
request.type = HAP_power_set_HMX_v2;
|
||||
request.hmx_v2.set_clock = TRUE;
|
||||
request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX;
|
||||
request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX;
|
||||
request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX;
|
||||
request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH;
|
||||
FARF(ALWAYS, "Setting HMX clock\n");
|
||||
err = HAP_power_set((void *) &ctx, &request);
|
||||
if (err != AEE_SUCCESS) {
|
||||
FARF(ERROR, "Error setting HMX clock.");
|
||||
return err;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return AEE_SUCCESS;
|
||||
|
|
|
|||
|
|
@ -2995,7 +2995,6 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
// is handled by HMX itself; when M < 32 fall back to HVX.
|
||||
const int m_total = (int) src1->ne[1];
|
||||
const int m_hmx = m_total & ~31; // 0 when M < 32
|
||||
|
||||
if (m_hmx == 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
|
@ -3020,7 +3019,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
|
||||
if (src0->type == HTP_TYPE_F16) {
|
||||
if (is_batched) {
|
||||
hmx_matmul_w16a32_batched_params_t batch_params = {
|
||||
hmx_matmul_f16_f32_batched_params_t batch_params = {
|
||||
.dst = (float *) dst->data,
|
||||
.activation = (float *) src1->data,
|
||||
.permuted_weight = (const __fp16 *) src0->data,
|
||||
|
|
@ -3041,15 +3040,14 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
.dst_nb2 = dst->nb[2],
|
||||
.dst_nb3 = dst->nb[3],
|
||||
};
|
||||
ret = hmx_mat_mul_permuted_w16a32_batched(octx->ctx, &batch_params);
|
||||
ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params);
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_w16a32(octx->ctx,
|
||||
ret = hmx_matmul_f16_f32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data,
|
||||
m_total, k, n, act_stride, wgt_stride);
|
||||
}
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
|
||||
ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
|
||||
m_total, k, n, (int) src0->type);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue