From 8d61a9edf0b3b429671f1a7037d67243d941b517 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Mon, 1 Jun 2026 23:40:08 -0700 Subject: [PATCH] hexagon: MUL_MAT, MUL_MAT_ID, FLASH_ATTN and GDN cleanup and optimizations for latest models (llama/23989) * hex-mm: initial support for F32 * F32 -> F32 matmuls * hex-rms-norm: fix src1 stride use in fused rms_norm_mul * hex-ops: clear spad pointers in the ops that clober it This fixes an odd case where fused rms-norm-mul was failing but only in qwen3.5-2B and only at searth op-bath sizes. * hmx-mm: add support for F32 * F32 -> F32 matmul_2d on HMX Decided to use Q4_0 * F32 -> F32 matmul for this. Q4_0 gets dequantized and tiled into F16, and here we quantize and tile F32 into F16. Super simple and pretty efficient. * hmx-mm: route f16 2D matmuls through the same kernel used for all other types * hmx-mm: re-introduce pipelined vs non-pipelined mode that we used to have but is much more generic way This update futher improves matmul performance and at the same time removes most of the redudant logic we had in different paths. * hmx-fa: slighlty improved pipeline simimar to matmul updates * hmx-mm: initial version of MAT_MUL_ID support for HMX * hmx-mm: fixed mxfp4 handling for MUL_MAT_ID * hex-gdn: optimize GATED_DELTA_NET DMA prefetch/double-buff, vectorize everything with HVX, in other words -- the usual :) * hmx-mm: missed one more case where we can use fastmod * hexagon: update DCVS settings for a slight perf bump * hmx-fa: use fastdiv in hmx-flash-attn * hmx-fa: precompute slope values to avoid disrupting the inner loop * hvx-utils/fa: new HVX helpers for powf and logf and using those to speed up FA alibi * hex-ops: fixed a bug in fusion logic that was messing up the order of the src tensors when some srcs are empty * hex-fa: correctly fallback to HVX if we have sinks or the dims are not quite right --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 22 + ggml/src/ggml-hexagon/htp-opnode.h | 50 +- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 71 +- ggml/src/ggml-hexagon/htp/argsort-ops.c | 1 + ggml/src/ggml-hexagon/htp/concat-ops.c | 2 + ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 15 +- .../ggml-hexagon/htp/gated-delta-net-ops.c | 676 ++++++++----- .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 118 ++- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 907 +++++++++++++----- ggml/src/ggml-hexagon/htp/hmx-ops.c | 6 + ggml/src/ggml-hexagon/htp/hmx-ops.h | 22 +- ggml/src/ggml-hexagon/htp/htp-ctx.h | 4 + ggml/src/ggml-hexagon/htp/hvx-flash-attn.h | 47 + ggml/src/ggml-hexagon/htp/hvx-log.h | 65 ++ ggml/src/ggml-hexagon/htp/hvx-pow.h | 42 + ggml/src/ggml-hexagon/htp/hvx-utils.h | 2 + ggml/src/ggml-hexagon/htp/main.c | 26 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 390 +++++++- ggml/src/ggml-hexagon/htp/pad-ops.c | 2 + ggml/src/ggml-hexagon/htp/unary-ops.c | 17 +- 20 files changed, 1859 insertions(+), 626 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hmx-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/hvx-flash-attn.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-log.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-pow.h diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 920829f6a..d550841a2 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1927,6 +1927,7 @@ struct ggml_hexagon_opbatch { size_t extra_tens = 0; auto fit_tensor = [&](const ggml_tensor *t) { + if (!t) return; if (!t_map.count(t)) { extra_tens++; @@ -2602,6 +2603,27 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n"); return false; } + if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); + return false; + } + if (ggml_nrows(src1) > 1024) { + return false; // no huge batches (for now) + } + break; + + case GGML_TYPE_F32: + if (src1->type != GGML_TYPE_F32) { + return false; + } + if (src0->nb[1] < src0->nb[0]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F32 src0 not supported\n"); + return false; + } + if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); + return false; + } if (ggml_nrows(src1) > 1024) { return false; // no huge batches (for now) } diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h index 14b232240..8a1228ccd 100644 --- a/ggml/src/ggml-hexagon/htp-opnode.h +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -56,7 +56,7 @@ struct htp_opnode { } std::vector get_inputs() const { - std::vector inputs; + std::vector inputs(GGML_MAX_SRC, nullptr); std::vector outputs; outputs.push_back(node); for (const auto * f : fused) { @@ -70,20 +70,38 @@ struct htp_opnode { return false; }; + int count = 0; auto add_input = [&](const ggml_tensor * t) { if (t && !contains(outputs, t) && !contains(inputs, t)) { - inputs.push_back(t); + if (count < (int)inputs.size()) { + inputs[count++] = t; + } else { + inputs.push_back(t); + } } }; - for (int i = 0; i < GGML_MAX_SRC && node->src[i]; i++) { - add_input(node->src[i]); - } - for (const auto * f : fused) { - for (int i = 0; i < GGML_MAX_SRC && f->src[i]; i++) { - add_input(f->src[i]); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (fused.empty()) { + inputs[i] = node->src[i]; + } else { + if (node->src[i]) { + add_input(node->src[i]); + } } } + for (const auto * f : fused) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (f->src[i]) { + add_input(f->src[i]); + } + } + } + + if (!fused.empty()) { + inputs.resize(count); + } + return inputs; } @@ -108,6 +126,9 @@ struct htp_opformat { char names[64 * GGML_MAX_SRC]; int format_tensor_dims(char * str, const struct ggml_tensor * t) { + if (!t) { + return sprintf(str, "NONE"); + } if (t->ne[2] == 1 && t->ne[3] == 1) { return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); } else { @@ -136,6 +157,9 @@ struct htp_opformat { } int format_tensor_strides(char * str, const struct ggml_tensor * t) { + if (!t) { + return sprintf(str, "NONE"); + } const char * c = ggml_is_contiguous(t) ? "" : "!"; if (t->ne[2] == 1 && t->ne[3] == 1) { @@ -170,11 +194,11 @@ struct htp_opformat { auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", ggml_type_name(inputs[0]->type)); + p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"); for (size_t i = 1; i < inputs.size(); i++) { p += sprintf(p, " x "); - p += sprintf(p, "%s", ggml_type_name(inputs[i]->type)); + p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"); } p += sprintf(p, " -> "); @@ -184,7 +208,7 @@ struct htp_opformat { } const char * tensor_buff_name(const struct ggml_tensor * t) { - if (t->buffer) { + if (t && t->buffer) { return ggml_backend_buffer_name(t->buffer); } return "NONE"; @@ -213,11 +237,11 @@ struct htp_opformat { auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", inputs[0]->name); + p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE"); for (size_t i = 1; i < inputs.size(); i++) { p += sprintf(p, " x "); - p += sprintf(p, "%s", inputs[i]->name); + p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE"); } p += sprintf(p, " -> "); diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index ff3fc0804..f4b44fe1a 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -19,6 +19,43 @@ add_library(${HTP_LIB} SHARED htp_iface_skel.c worker-pool.c hex-dma.c +) + +target_compile_definitions(${HTP_LIB} PRIVATE + $,HTP_DEBUG=1,NDEBUG=1> + $,FARF_HIGH=1,> + FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) + +if (GGML_HEXAGON_FA_EXP2_HF) + message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") + target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) +endif() + +# HMX acceleration: available on v73+ architectures +set(HTP_HMX_VERSIONS v73 v75 v79 v81) +list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) + +if (_hmx_idx GREATER_EQUAL 0) + target_sources(${HTP_LIB} PRIVATE + hmx-matmul-ops.c + hmx-flash-attn-ops.c + hmx-queue.c + ) + + # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) + set_source_files_properties( + hmx-flash-attn-ops.c + hmx-matmul-ops.c + hmx-queue.c + PROPERTIES COMPILE_OPTIONS "-mhmx" + ) + + target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1) +endif() + +build_idl(htp_iface.idl ${HTP_LIB}) + +target_sources(${HTP_LIB} PRIVATE matmul-ops.c binary-ops.c unary-ops.c @@ -42,40 +79,6 @@ add_library(${HTP_LIB} SHARED pad-ops.c ) -target_compile_definitions(${HTP_LIB} PRIVATE - $,HTP_DEBUG=1,NDEBUG=1> - $,FARF_HIGH=1,> - FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) - -if (GGML_HEXAGON_FA_EXP2_HF) - message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") - target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) -endif() - -# HMX acceleration: available on v73+ architectures -set(HTP_HMX_VERSIONS v73 v75 v79 v81) -list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) - -if (_hmx_idx GREATER_EQUAL 0) - target_sources(${HTP_LIB} PRIVATE - hmx-flash-attn-ops.c - hmx-matmul-ops.c - hmx-queue.c - ) - - # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) - set_source_files_properties( - hmx-flash-attn-ops.c - hmx-matmul-ops.c - hmx-queue.c - PROPERTIES COMPILE_OPTIONS "-mhmx" - ) - - target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1) -endif() - -build_idl(htp_iface.idl ${HTP_LIB}) - set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) install(TARGETS ${HTP_LIB}) diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index bdd062361..73af38a35 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -276,6 +276,7 @@ int op_argsort(struct htp_ops_context * octx) { octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.size = total_spad_size; octx->src0_spad.size_per_thread = spad_per_thread; + octx->src0_spad.src = NULL; FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3], diff --git a/ggml/src/ggml-hexagon/htp/concat-ops.c b/ggml/src/ggml-hexagon/htp/concat-ops.c index 61580f2c0..f2a381313 100644 --- a/ggml/src/ggml-hexagon/htp/concat-ops.c +++ b/ggml/src/ggml-hexagon/htp/concat-ops.c @@ -262,6 +262,8 @@ int op_concat(struct htp_ops_context * octx) { octx->src0_spad.data = octx->ctx->vtcm_base; octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; if (type_size == 4) { worker_func = concat_2d_f32_transposed; diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 1bd8c1407..e99621469 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -11,6 +11,7 @@ #include "hex-dma.h" #include "hvx-utils.h" #include "hvx-dump.h" +#include "hvx-flash-attn.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -245,6 +246,7 @@ struct htp_fa_context { uint32_t n_head_log2; float m0; float m1; + float slopes[512]; uint32_t n_blocks; @@ -412,7 +414,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * } const uint32_t h = iq2; // head index - const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; + const float slope = factx->slopes[h]; HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); @@ -628,8 +630,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { } #ifdef HTP_HAS_HMX - // HMX path: head_dim multiple of 32, F16 KV - if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0) { + // HMX path: head_dim multiple of 64, F16 KV, and no sinks + if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) { int ret = hmx_flash_attn_ext(octx); if (ret == HTP_STATUS_OK) { return ret; @@ -689,6 +691,13 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2); factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + if (n_head > 512) { + return HTP_STATUS_NO_SUPPORT; + } + for (uint32_t h = 0; h < n_head; ++h) { + factx.slopes[h] = (max_bias > 0.0f) ? alibi_slope(h, factx.n_head_log2, factx.m0, factx.m1) : 1.0f; + } + // total rows in q const uint32_t neq0 = q->ne[0]; const uint32_t neq1 = q->ne[1]; diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c index c4d08bb21..3b092d744 100644 --- a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -3,6 +3,7 @@ #include #include "hvx-utils.h" +#include "hex-fastdiv.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -14,106 +15,103 @@ #define HTP_GDN_MAX_SV 128 + struct htp_gdn_context { struct htp_ops_context * octx; uint32_t rows_per_thread; - size_t state_bytes; - bool use_vtcm; - uint8_t * vtcm_state_base; - size_t vtcm_state_per_thread; + size_t state_bytes; + uint8_t * vtcm_base; + size_t vtcm_per_thread; }; -static inline float gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, - const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } -static inline float gdn_mul_scalar_dot_f32(float * restrict dst, float mul, - const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_mul_scalar_dot_f32(float * restrict dst, float mul, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); const HVX_Vector vmul = hvx_vec_splat_f32(mul); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vd = hvx_vmemu(dst + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } -static inline float gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, - float scale, const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, + HVX_Vector vscale, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); - const HVX_Vector vscale = hvx_vec_splat_f32(scale); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1, @@ -126,7 +124,7 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -147,11 +145,11 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); @@ -159,10 +157,10 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -185,7 +183,7 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -205,10 +203,10 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); @@ -216,10 +214,10 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -246,7 +244,7 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -267,11 +265,11 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); @@ -279,10 +277,10 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -310,7 +308,7 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -343,11 +341,11 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); @@ -359,14 +357,14 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm); HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -400,7 +398,7 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -432,10 +430,10 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); @@ -447,14 +445,14 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul); HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -496,7 +494,7 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -529,11 +527,11 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); @@ -545,14 +543,14 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6)); HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7)); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -605,26 +603,65 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_sums[4] __attribute__((aligned(128))); + float local_sums[32] __attribute__((aligned(128))); + + dma_queue * dma = octx->ctx->dma[ith]; + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + float * s_work[2]; + s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith); + s_work[1] = s_work[0] + state_aligned / sizeof(float); + + struct fastdiv_values fd_H = init_fastdiv_values(H); + struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]); + struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]); + struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); + struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); const uint64_t state_seq_stride = state->nb[2] / sizeof(float); const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; const int64_t shift = (int64_t) n_tokens - (int64_t) K; - for (uint32_t ir = ith; ir < total_rows; ir += nth) { - const uint32_t iv1 = ir % H; - const uint32_t iv3 = ir / H; + uint32_t ir_prefetch = ith; + int spad_idx = 0; - const uint32_t iq1 = iv1 % q->ne[1]; - const uint32_t ik1 = iv1 % k->ne[1]; - const uint32_t iq3 = iv3 / rq3; - const uint32_t ik3 = iv3 / rk3; + // Prefetch preamble (up to 2 steps) + for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + + // Push dummy write-back + dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), 0); + + // Push fetch + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + int curr_spad_idx = 0; + for (uint32_t ir = ith; ir < total_rows; ir += nth) { + dma_queue_pop(dma); + dma_queue_pop(dma); + + float * s_work_curr = s_work[curr_spad_idx]; + + const uint32_t iv1 = fastmodulo(ir, H, &fd_H); + const uint32_t iv3 = fastdiv(ir, &fd_H); + + const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1); + const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1); + const uint32_t iq3 = fastdiv(iv3, &fd_rq3); + const uint32_t ik3 = fastdiv(iv3, &fd_rk3); float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; - - memcpy(s_out, s_in, gctx->state_bytes); - float * s_work = s_out; float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v; @@ -640,57 +677,117 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]); - memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); - memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v); + hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v); if (kda) { hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); uint32_t j = 0; - for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } - gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); + } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } else { const float gate = expf(g_t[0]); uint32_t j = 0; - for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } - gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); + } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } @@ -698,17 +795,40 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const int64_t target_slot = (int64_t) t - shift; if (target_slot >= 0 && target_slot < (int64_t) K) { float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - if (curr_state_o != s_work) { - memcpy(curr_state_o, s_work, gctx->state_bytes); + if (curr_state_o != s_out) { + hvx_copy_f32_uu((uint8_t *) curr_state_o, (const uint8_t *) s_work_curr, S_v * S_v); } } } attn_data += (uint64_t) S_v * H; } + + // Push real write-back + dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + // Prefetch next block (if any) + if (ir_prefetch < total_rows) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + curr_spad_idx ^= 1; } + dma_queue_flush(dma); } + static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; struct htp_ops_context * octx = gctx->octx; @@ -743,41 +863,64 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_sums[8] __attribute__((aligned(128))); + float local_sums[32] __attribute__((aligned(128))); dma_queue * dma = octx->ctx->dma[ith]; + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + float * s_work[2]; + s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith); + s_work[1] = s_work[0] + state_aligned / sizeof(float); - uint8_t * spad = NULL; - if (gctx->use_vtcm) { - spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith; - } + struct fastdiv_values fd_H = init_fastdiv_values(H); + struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]); + struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]); + struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); + struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); const uint64_t state_seq_stride = state->nb[2] / sizeof(float); const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; - for (uint32_t ir = ith; ir < total_rows; ir += nth) { - const uint32_t iv1 = ir % H; - const uint32_t iv3 = ir / H; + uint32_t ir_prefetch = ith; + int spad_idx = 0; - const uint32_t iq1 = iv1 % q->ne[1]; - const uint32_t ik1 = iv1 % k->ne[1]; - const uint32_t iq3 = iv3 / rq3; - const uint32_t ik3 = iv3 / rk3; + // Prefetch preamble (up to 2 steps) + for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + + // Push dummy write-back + dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), 0); + + // Push fetch + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + int curr_spad_idx = 0; + for (uint32_t ir = ith; ir < total_rows; ir += nth) { + dma_queue_pop(dma); + dma_queue_pop(dma); + + float * s_work_curr = s_work[curr_spad_idx]; + + const uint32_t iv1 = fastmodulo(ir, H, &fd_H); + const uint32_t iv3 = fastdiv(ir, &fd_H); + + const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1); + const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1); + const uint32_t iq3 = fastdiv(iv3, &fd_rq3); + const uint32_t ik3 = fastdiv(iv3, &fd_rk3); float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; - float * s_work; - - if (spad) { - dma_queue_push(dma, dma_make_ptr(spad, s_in), - S_v * sizeof(float), S_v * sizeof(float), - S_v * sizeof(float), S_v); - dma_queue_pop(dma); - s_work = (float *) spad; - } else { - s_work = s_out; - memcpy(s_work, s_in, gctx->state_bytes); - } float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v; @@ -792,111 +935,145 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]); - memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); - memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v); + hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v); if (kda) { hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); uint32_t j = 0; for (; j + 8 <= S_v; j += 8) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - float * row4 = s_work + (uint64_t) (j + 4) * S_v; - float * row5 = s_work + (uint64_t) (j + 5) * S_v; - float * row6 = s_work + (uint64_t) (j + 6) * S_v; - float * row7 = s_work + (uint64_t) (j + 7) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_gate, local_k, S_v, local_sums); - float local_delta_b[8] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 8; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 8; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } else { const float gate = expf(g_t[0]); uint32_t j = 0; for (; j + 8 <= S_v; j += 8) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - float * row4 = s_work + (uint64_t) (j + 4) * S_v; - float * row5 = s_work + (uint64_t) (j + 5) * S_v; - float * row6 = s_work + (uint64_t) (j + 6) * S_v; - float * row7 = s_work + (uint64_t) (j + 7) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, gate, local_k, S_v, local_sums); - float local_delta_b[8] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 8; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 8; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } - if (spad) { - dma_queue_push(dma, dma_make_ptr(s_out, spad), + // Push real write-back + dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + // Prefetch next block (if any) + if (ir_prefetch < total_rows) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), S_v * sizeof(float), S_v * sizeof(float), S_v * sizeof(float), S_v); - dma_queue_pop(dma); + + ir_prefetch += nth; + spad_idx ^= 1; } + + curr_spad_idx ^= 1; } + dma_queue_flush(dma); } + int op_gated_delta_net(struct htp_ops_context * octx) { const struct htp_tensor * q = octx->src[0]; const struct htp_tensor * k = octx->src[1]; @@ -952,18 +1129,11 @@ int op_gated_delta_net(struct htp_ops_context * octx) { size_t state_aligned = (size_t) S_v * S_v * sizeof(float); state_aligned = (state_aligned + 127) & ~(size_t)127; - gctx.use_vtcm = false; - gctx.vtcm_state_base = NULL; - gctx.vtcm_state_per_thread = 0; + assert(octx->ctx->vtcm_base != NULL); + assert(octx->ctx->vtcm_size >= 2 * state_aligned * octx->n_threads); - if (n_tokens == 1 && octx->ctx->vtcm_base) { - size_t vtcm_total = state_aligned * octx->n_threads; - if (octx->ctx->vtcm_size >= vtcm_total) { - gctx.use_vtcm = true; - gctx.vtcm_state_base = octx->ctx->vtcm_base; - gctx.vtcm_state_per_thread = state_aligned; - } - } + gctx.vtcm_base = octx->ctx->vtcm_base; + gctx.vtcm_per_thread = 2 * state_aligned; if (n_tokens == 1) { worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads); diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index f132c0850..2796564fb 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -17,14 +17,17 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "hex-dma.h" +#include "hex-fastdiv.h" #include "hmx-profile.h" #include "hmx-queue.h" #include "hmx-utils.h" #include "htp-ctx.h" #include "htp-ops.h" #include "hvx-dump.h" +#include "hvx-copy.h" #include "hvx-reduce.h" #include "hvx-utils.h" +#include "hvx-flash-attn.h" #include "vtcm-utils.h" #include "worker-pool.h" @@ -46,7 +49,7 @@ // g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. // Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales // Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. -static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads) { +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) { const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong @@ -67,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, + k_dma_size * 2 // K DMA x2 + v_dma_size * 2 // V DMA x2 + k_tile_size * 1 // K tiles - + v_tile_size * 1 // V tiles + + v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining) + s_tile_size * 2 // S + P + d_tile_size * 1 // D (diagonal matrix) + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum @@ -144,12 +147,13 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, // See .cursor/todos/hmx-flash-attn-bc-search-space.md for the perf trade-off. const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64 const size_t fp16 = sizeof(__fp16); + const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); // Approximate per-unit VTCM costs (without per-buffer alignment padding). const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16; // Q + O×2 + 4 col vectors const size_t per_gbr2 = fp16; // D diagonal matrix const size_t per_bc = - 3 * (DK + DV) * fp16 + 2 * n_threads * fp16; // K_dma×2 + V_dma×2 + K_tile + V_tile + row bufs + 3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs const size_t per_gbr_bc = 2 * fp16; // S + P const size_t overhead = 256 * 2 + 13 * 4096; @@ -164,7 +168,6 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, // Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS. // Only relax when kv_len is too short to form enough blocks. - const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) : (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit); // Cost coefficients calibrated from profiling @@ -200,7 +203,7 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, } // Exact VTCM verification (alignment padding may push over budget) - while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads) > vtcm_budget) { + while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) { Bc -= bc_unit; } if (Bc < bc_unit) { @@ -303,6 +306,7 @@ struct hmx_fa_context { uint32_t n_kv_heads; // number of KV heads uint32_t n_heads; // number of Q heads uint32_t G; // GQA factor = n_heads / n_kv_heads + struct fastdiv_values div_G; uint32_t n_kv_blocks; uint32_t neq1; // Q token count @@ -321,7 +325,7 @@ struct hmx_fa_context { __fp16 * vtcm_k_fp16[2]; // K DMA double-buffer [Bc, D] __fp16 * vtcm_v_fp16[2]; // V DMA double-buffer [Bc, D] __fp16 * vtcm_k_tiles; // K tiles (transposed) - __fp16 * vtcm_v_tiles; // V tiles (column-major) + __fp16 * vtcm_v_tiles[2]; // V tiles (column-major, double-buffered) __fp16 * vtcm_s_tiles; // S = QK^T [g_br, Bc] __fp16 * vtcm_p_tiles; // P = softmax(S) [g_br, Bc] __fp16 * vtcm_d_tiles; // Diagonal rescale [g_br, g_br] @@ -402,7 +406,9 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) return; } - hmx_interleave_cols_to_tiles(factx->vtcm_v_tiles, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, + __fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0]; + + hmx_interleave_cols_to_tiles(v_tiles_dest, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, (int) args->src_stride, (int) args->n_col_tiles, start, end); } @@ -464,10 +470,10 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) { for (size_t r = start; r < end; r += 2) { const bool next_row_valid = (r + 1) < n_rows_g; - const size_t q_idx0 = (r + 0) / G; - const size_t h_idx0 = (r + 0) % G; - const size_t q_idx1 = (r + 1) / G; - const size_t h_idx1 = (r + 1) % G; + const size_t q_idx0 = fastdiv(r + 0, &factx->div_G); + const size_t h_idx0 = fastmodulo(r + 0, G, &factx->div_G); + const size_t q_idx1 = fastdiv(r + 1, &factx->div_G); + const size_t h_idx1 = fastmodulo(r + 1, G, &factx->div_G); const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] + (kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3]; @@ -567,8 +573,8 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) { const uint32_t ib3 = args->ib3; for (size_t r = start; r < end; ++r) { - const size_t q_idx = r / G; - const size_t h_idx = r % G; + const size_t q_idx = fastdiv(r, &factx->div_G); + const size_t h_idx = fastmodulo(r, G, &factx->div_G); // FIX(dst-indexing): ggml_flash_attn_ext() creates dst as permute(0,2,1,3) -> // [DV, n_heads, n_tokens, n_seq], so head stride is nb[1] and token stride is nb[2]. @@ -780,11 +786,11 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { if (args->mask_vtcm) { // Read mask from VTCM buffer (DMA'd per KV block). // GQA dedup (scheme B): skip load when qi unchanged. - const size_t qi0 = (r + 0) / G; + const size_t qi0 = fastdiv(r + 0, &factx->div_G); v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c); v_mask1 = v_neg_inf; if (r + 1 < (int) n_rows_g) { - const size_t qi1 = (r + 1) / G; + const size_t qi1 = fastdiv(r + 1, &factx->div_G); if (qi1 == qi0) { v_mask1 = v_mask0; // scheme B: reuse — same mask row } else { @@ -794,8 +800,8 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { } else { // Fallback: read mask directly from DDR (when mask->ne[2] > 1). const struct htp_tensor * mask = args->mask; - const size_t q_idx0 = args->q_start + ((r + 0) / G); - const size_t h_idx0 = args->kv_head * G + (r + 0) % G; + const size_t q_idx0 = args->q_start + fastdiv(r + 0, &factx->div_G); + const size_t h_idx0 = args->kv_head * G + fastmodulo(r + 0, G, &factx->div_G); const uint32_t im2_0 = h_idx0 % mask->ne[2]; const uint32_t im3_0 = args->ib3 % mask->ne[3]; @@ -805,12 +811,12 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { v_mask1 = v_neg_inf; if (r + 1 < (int) n_rows_g) { - const size_t q_idx1 = args->q_start + ((r + 1) / G); + const size_t q_idx1 = args->q_start + fastdiv(r + 1, &factx->div_G); if (q_idx1 == q_idx0) { // scheme B: same mask row in DDR path v_mask1 = v_mask0; } else { - const size_t h_idx1 = args->kv_head * G + (r + 1) % G; + const size_t h_idx1 = args->kv_head * G + fastmodulo(r + 1, G, &factx->div_G); const uint32_t im2_1 = h_idx1 % mask->ne[2]; const uint32_t im3_1 = args->ib3 % mask->ne[3]; const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] + @@ -1191,14 +1197,13 @@ static void hmx_fa_o_norm_worker(void * data) { // Row r in the GQA-merged block maps to Q head h = kv_head * G + r % G. // slope(h) = m0^(h+1) when h < n_head_log2, else m1^(2*(h-n_head_log2)+1). // When max_bias == 0, all slopes are 1.0 (no ALiBi). -static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sargs, +static __attribute__((noinline)) void fa_compute_slopes( const struct hmx_fa_context * factx, uint32_t kv_head, size_t n_rows_g) { + __fp16 * slopes = factx->vtcm_slopes; if (factx->max_bias == 0.0f) { - for (size_t r = 0; r < n_rows_g; ++r) { - sargs->slopes[r] = 1.0f; - } + hvx_splat_f16_a(slopes, 1.0f, n_rows_g); return; } @@ -1207,10 +1212,32 @@ static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sarg const float m0 = factx->m0; const float m1 = factx->m1; - for (size_t r = 0; r < n_rows_g; ++r) { - const uint32_t h = kv_head * G + r % G; - sargs->slopes[r] = (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); + __fp16 temp_slopes[512] __attribute__((aligned(128))); + if (G <= 32) { + // Fast path: Compute G unique slope values in vector registers + HVX_Vector v_val = hvx_alibi_slopes(kv_head, G, n_head_log2, m0, m1); + + __fp16 temp_slopes_aligned[64] __attribute__((aligned(128))); + hvx_vmem(temp_slopes_aligned) = hvx_vec_f32_to_f16(v_val, Q6_V_vzero()); + + for (uint32_t i = 0; i < G; ++i) { + temp_slopes[i] = temp_slopes_aligned[i]; + } + } else { + // Fallback path: G > 32 (rare configurations) + for (uint32_t i = 0; i < G; ++i) { + temp_slopes[i] = (__fp16)alibi_slope(kv_head * G + i, n_head_log2, m0, m1); + } } + + // Allocate stack buffer to avoid scalar writes to VTCM (which generates L2 misses) + __fp16 local_slopes[n_rows_g] __attribute__((aligned(128))); + for (size_t r = 0; r < n_rows_g; ++r) { + local_slopes[r] = temp_slopes[fastmodulo(r, G, &factx->div_G)]; + } + + // Copy to VTCM slopes using HVX block copy (both are aligned to 128 bytes) + hvx_copy_f16_aa((uint8_t *)slopes, (const uint8_t *)local_slopes, n_rows_g); } // ============================================================================ @@ -1254,19 +1281,22 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const uint32_t G = neq2 / n_kv_heads; // Thread count for multi-thread HVX phases - const uint32_t n_threads = octx->n_threads; + const uint32_t n_threads_init = octx->n_threads; // Compute dynamic block sizes (GQA-aware, accounting for per-thread row bufs) size_t Br, Bc; const size_t vtcm_budget = ctx->vtcm_size; - if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads) != 0) { + if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads_init) != 0) { return HTP_STATUS_VTCM_TOO_SMALL; } const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; - const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2); + const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2); + + // Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1 + const uint32_t n_threads = use_pipeline ? n_threads_init : 1; FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); @@ -1282,6 +1312,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.n_kv_heads = n_kv_heads; factx.n_heads = neq2; factx.G = G; + factx.div_G = init_fastdiv_values(G); factx.neq1 = neq1; factx.Br = (uint32_t) Br; factx.Bc = (uint32_t) Bc; @@ -1354,7 +1385,12 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); - factx.vtcm_v_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + if (use_pipeline) { + factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + } else { + factx.vtcm_v_tiles[1] = NULL; + } factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes); @@ -1457,6 +1493,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { // ---- KV block loop with DMA double-buffering ---- size_t buf_idx = 0; + fa_compute_slopes(&factx, kv_head, n_rows_g); + // Prefetch first KV block if (factx.n_kv_blocks > 0) { const uint32_t kv_rows0 = hex_smin(Bc, nek1); @@ -1535,7 +1573,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { ou_job.o_curr = o_tile_curr; ou_job.o_prev = o_tile_prev; ou_job.p_tiles = factx.vtcm_p_tiles; - ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx]; ou_job.d_tiles = factx.vtcm_d_tiles; ou_job.hmx_scales = factx.vtcm_hmx_scales_id; ou_job.n_row_tiles = n_row_tiles; @@ -1550,11 +1588,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); TIMER_STOP(k_interleave); - if (kv_blk > 0) { - hmx_queue_pop(hmx_q); - hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); - } - // ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ---- qk_job.q_tiles = factx.vtcm_q_tiles; qk_job.k_tiles = factx.vtcm_k_tiles; @@ -1574,6 +1607,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); TIMER_STOP(v_interleave); + // Pop and swap previous block's output update (deferred HMX pop) + if (kv_blk > 0) { + hmx_queue_pop(hmx_q); + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + // Pop current block's dot product job hmx_queue_pop(hmx_q); TIMER_STOP(qk_dot); @@ -1601,7 +1641,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; sargs.slopes = factx.vtcm_slopes; - fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); TIMER_START(softmax); fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); @@ -1617,7 +1656,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { ou_job.o_curr = o_tile_curr; ou_job.o_prev = o_tile_prev; ou_job.p_tiles = factx.vtcm_p_tiles; - ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx]; ou_job.d_tiles = factx.vtcm_d_tiles; ou_job.hmx_scales = factx.vtcm_hmx_scales_id; ou_job.n_row_tiles = n_row_tiles; @@ -1712,7 +1751,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; sargs.slopes = factx.vtcm_slopes; - fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); TIMER_START(softmax); fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); @@ -1732,7 +1770,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const size_t DV_tiles = (size_t) (DV / 32); const __fp16 * restrict d_base = factx.vtcm_d_tiles; const __fp16 * restrict p_base = factx.vtcm_p_tiles; - const __fp16 * restrict v_base = factx.vtcm_v_tiles; + const __fp16 * restrict v_base = factx.vtcm_v_tiles[0]; const __fp16 * restrict op_base = o_tile_prev; __fp16 * restrict oc_base = o_tile_curr; __builtin_assume(n_row_tiles > 0); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 083d12588..dab605210 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -73,6 +73,10 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) { return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb case HTP_TYPE_MXFP4: return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb + case HTP_TYPE_F16: + return (size_t) k * sizeof(__fp16); + case HTP_TYPE_F32: + return (size_t) k * sizeof(float); default: return 0; } @@ -545,7 +549,7 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4( int start_tile, int end_tile) { const int n_k_tiles = state->n_k_tiles; - const int qrow_size = state->k_block; + const int qrow_size = (unsigned)state->k_block / 2; const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); @@ -720,12 +724,129 @@ static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, voi } } +static void convert_f16_weight_to_fp16_tiles_task( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int byte_off = kt * 32 * sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0 = hvx_vmemu((const __fp16 *)(r0 + byte_off)); + HVX_Vector v1 = (row1 < state->n_cols) ? hvx_vmemu((const __fp16 *)(r1 + byte_off)) : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + convert_f16_weight_to_fp16_tiles_task(state, start, end); + } +} + +static void quantize_f32_weight_to_fp16_tiles_task( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int byte_off = kt * 32 * sizeof(float); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0_f32 = hvx_vmemu((const float *)(r0 + byte_off)); + HVX_Vector v1_f32 = (row1 < state->n_cols) ? hvx_vmemu((const float *)(r1 + byte_off)) : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16(v0_f32, v1_f32); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + HVX_Vector v_out_hi = Q6_V_vror_VR(v_out, 64); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out_hi); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + quantize_f32_weight_to_fp16_tiles_task(state, start, end); + } +} + + static void dequantize_x4x2_weight_chunk_to_fp16_tiles( struct htp_context *ctx, __fp16 *vtcm_dst, const void *vtcm_src, int n_cols, int k_block, size_t row_stride, int weight_type, int n_k_tiles, struct fastdiv_values n_k_tiles_div, - worker_callback_t dequant_worker_fn) { + worker_callback_t dequant_worker_fn, int n_threads) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); assert(k_block % HMX_FP16_TILE_N_COLS == 0); @@ -733,7 +854,7 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; size_t n_tot_tiles = n_col_tiles * n_k_tiles; - size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads); + size_t n_tiles_per_task = (n_threads == 1) ? n_tot_tiles : hmx_ceil_div(n_tot_tiles, n_threads); x4x2_dequantize_state_t state; state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; @@ -748,7 +869,11 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( state.n_k_tiles = n_k_tiles; state.n_k_tiles_div = n_k_tiles_div; - worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + dequant_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, n_threads); + } } // --- End x4x2 dequantizers --- @@ -876,11 +1001,11 @@ static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void } static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, - int n_rows, int n_cols, int n) { + int n_rows, int n_cols, int n, int n_threads) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) output_transfer_task_state_t state; state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; @@ -891,7 +1016,11 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, state.n_cols = n_cols; state.n = n; - worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + transfer_output_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, n_threads); + } } // activations : fp32 -> fp16 @@ -973,12 +1102,12 @@ static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, } } -static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { +static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) { assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); assert(VLEN == 32 * sizeof(float)); size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : 32; // must be multiple of 32 to ensure correct destination address activation_transfer_task_state_t state; state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; @@ -989,7 +1118,11 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 * state.k_block = k_block; state.k_stride = k_stride; - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + transfer_activation_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, n_threads); + } } // C += AB @@ -1031,9 +1164,9 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co } } -int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, - int weight_type) { + int act_stride, int weight_stride, int weight_type) { if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { @@ -1052,6 +1185,8 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; default: return -1; } @@ -1059,21 +1194,25 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + // --- Dynamic Mode Configuration --- + const bool use_pipeline = (m > 32); + const int num_threads = (m <= 32) ? 1 : ctx->n_threads; + // --- Dynamic VTCM layout --- const size_t vec_dot_size = k * sizeof(__fp16); const size_t vtcm_budget = ctx->vtcm_size; size_t vtcm_used = 0; // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. - const size_t size_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - const size_t size_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) + const size_t size_per_n = row_stride + (use_pipeline ? 2 * vec_dot_size : vec_dot_size); // Q + S0 + S1 (dequant bufs) + const size_t size_per_mn = (use_pipeline ? 2 : 1) * sizeof(__fp16); // O x 2 (output double buffer) size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, /*m_block_cost=*/(size_t) n * 3, /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { - FARF(HIGH, "hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); + FARF(HIGH, "hmx-mm-2d: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); return -1; } @@ -1083,27 +1222,27 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * size_t scratch0_size, scratch1_size, scratch2_size; scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = scratch0_size; // dequant buf 1 - scratch2_size = output_area_size; // output buf 1 + scratch1_size = use_pipeline ? scratch0_size : 0; // dequant buf 1 + scratch2_size = use_pipeline ? output_area_size : 0; // output buf 1 uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); + void *vtcm_scratch1 = scratch1_size ? vtcm_seq_alloc(&vtcm_ptr, scratch1_size) : NULL; void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; if (vtcm_used > vtcm_budget) { - FARF(ERROR, "hmx-mm-q: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + FARF(ERROR, "hmx-mm-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); return -1; } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); TIMER_DEFINE(activation_load); @@ -1114,115 +1253,137 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * TIMER_DEFINE(total); TIMER_START(total); - // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) - // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - - // A --> B: vtcm_qweight, 1 buffer - // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers - // C --> D: vtcm_output0/vtcm_output1, 2 buffers - - // Async timeline (C overlaps B+D): - // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] - // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + if (use_pipeline) { + // --- Asynchronous Pipelined Loop (Current implementation) --- + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); - } + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } - - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); - - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, weight_stride, row_stride, n_cols_A0); } - // submit C0 (non-blocking — HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + { + const float *activation_chunk = activation + mr * act_stride; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); + } - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) + { + // B0: wait for DMA, dequant weight chunk 0 dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + // A1: issue DMA for weight chunk 1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, weight_stride, row_stride, n_cols_A1); + } + + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) + if (1 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + } + } + + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + + // issue A_{i+2}: DMA push (non-blocking) + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, weight_stride, row_stride, n_cols_p2); + } + + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); + + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } + + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n, num_threads); + + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + } } } + hmx_queue_suspend(ctx->hmx_queue); + } else { + // --- Synchronous Loop (Optimized for small/non-pipelined cases) --- + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + // Load Activation + const float *activation_chunk = activation + mr * act_stride; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); - } + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's - // counterpart — and (i+1)%2 was last used by C_{i-1} which completed - // before C_i was submitted. - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); - } - - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); - - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { + // A: DMA Load Weight + const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); + + // B: Dequantize / Convert Weight + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + // C: HMX Compute (Synchronous) + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); + + // D: Output Store + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output, n_rows, n_cols, n, num_threads); } } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); } - hmx_queue_suspend(ctx->hmx_queue); - TIMER_STOP(total); #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "hex-mm-q: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); + FARF(HIGH, "hex-mm-2d: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); if (!use_pipeline) { FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); @@ -1401,11 +1562,11 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 dma_queue_pop(ctx->dma[0]); transfer_activation_chunk_threaded(ctx, vtcm_act_g, vtcm_f32_act, (int) n_rows, - params->k, params->k); + params->k, params->k, ctx->n_threads); } else { transfer_activation_chunk_threaded(ctx, vtcm_act_g, activation_chunk, (int) n_rows, - params->k, params->act_stride); + params->k, params->act_stride, ctx->n_threads); } } TIMER_STOP(activation_load); @@ -1455,7 +1616,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 TIMER_START(output_store); { float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; - transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads); } TIMER_STOP(output_store); } @@ -1475,177 +1636,431 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); #endif - return 0; + return 0; } -// - int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const __fp16 *restrict permuted_weight, int m, int k, int n, int act_stride, int weight_stride) { if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - if (act_stride < k || weight_stride < k) { return -1; } + return hmx_matmul_2d_f32(ctx, dst, activation, (const uint8_t *)permuted_weight, m, k, n, + act_stride, weight_stride * (int)sizeof(__fp16), HTP_TYPE_F16); +} + +struct mmid_row_mapping { + uint32_t i1; + uint32_t i2; +}; + +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + const struct mmid_row_mapping *matrix_rows; + int cur_a; + int mapping_stride; + int ne11; + struct fastdiv_values ne11_div; + size_t nb11; + size_t nb12; + int start_row; + int cne1; +} activation_transfer_gathered_task_state_t; + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int n_cols; + const struct mmid_row_mapping *matrix_rows; + int cur_a; + int mapping_stride; + size_t dst_nb1; + size_t dst_nb2; + int start_row; + int cne1; +} output_transfer_scattered_task_state_t; + +static void transfer_activation_chunk_fp32_to_fp16_gathered( + __fp16 *restrict vtcm_dst, + const float *restrict src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + const struct fastdiv_values * ne11_div, + size_t nb11, + size_t nb12, + int cne1) { + const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; + + int r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + int r_idx0 = start_row + r + 0; + int r_idx1 = start_row + r + 1; + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + + int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + + const float *row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + const float *row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + + const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; + const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; + + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = *pv_in1++; + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = (start_row + r + 0) < cne1; + const bool row1_valid = (start_row + r + 1) < cne1; + + const float *row0_ptr = NULL; + const float *row1_ptr = NULL; + + if (row0_valid) { + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; + int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + } + if (row1_valid) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; + int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + } + + const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; + const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; + + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); + HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_gathered_task_state_t *st = data; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + __fp16 *dst = st->dst + (size_t)(start_row - st->start_row) * st->k_block; + transfer_activation_chunk_fp32_to_fp16_gathered( + dst, st->src, start_row, n_rows, st->k_block, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1); + } +} + +static void transfer_activation_chunk_gathered_threaded( + struct htp_context *ctx, + __fp16 *dst, + const float *src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + size_t nb11, + size_t nb12, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + activation_transfer_gathered_task_state_t state = { + .dst = dst, + .src = src, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .k_block = k_block, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .ne11 = ne11, + .ne11_div = init_fastdiv_values(ne11), + .nb11 = nb11, + .nb12 = nb12, + .start_row = start_row, + .cne1 = cne1, + }; + + if (actual_threads <= 1) { + transfer_activation_chunk_gathered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_gathered_worker_fn, &state, actual_threads); + } +} + +static void transfer_output_chunk_fp16_to_fp32_scattered( + float *restrict dst, + const __fp16 *restrict vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r0 = r / HMX_FP16_TILE_N_ROWS; + const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + + int r_idx0 = start_row + (int)r + 0; + int r_idx1 = start_row + (int)r + 1; + + if (r_idx0 >= cne1) break; + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + float *output_row0 = (float *) ((uint8_t *) dst + mapping0.i1 * dst_nb1 + mapping0.i2 * dst_nb2); + + float *output_row1 = NULL; + if (r_idx1 < cne1) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + output_row1 = (float *) ((uint8_t *) dst + mapping1.i1 * dst_nb1 + mapping1.i2 * dst_nb2); + } + + #pragma unroll(4) + for (size_t c = 0; c < (size_t)n_cols; c += HMX_FP16_TILE_N_COLS) { + const size_t c0 = c / HMX_FP16_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row0 + c); + volatile HVX_Vector *pv_out1 = output_row1 ? (volatile HVX_Vector *) (output_row1 + c) : NULL; + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (pv_out1) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + } +} + +static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_scattered_task_state_t *st = data; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + const __fp16 *src = st->vtcm_src + (size_t)(start_row - st->start_row) * st->n_cols; + transfer_output_chunk_fp16_to_fp32_scattered( + st->dst, src, start_row, n_rows, st->n_cols, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->dst_nb1, st->dst_nb2, st->cne1); + } +} + +static void transfer_output_chunk_scattered_threaded( + struct htp_context *ctx, + float *dst, + const __fp16 *vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + output_transfer_scattered_task_state_t state = { + .vtcm_src = vtcm_src, + .dst = dst, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .n_cols = n_cols, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .dst_nb1 = dst_nb1, + .dst_nb2 = dst_nb2, + .start_row = start_row, + .cne1 = cne1, + }; + + if (actual_threads <= 1) { + transfer_output_chunk_scattered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_scattered_worker_fn, &state, actual_threads); + } +} + +int hmx_matmul_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride) { + const int cne1 = m; + const int m_padded = hex_align_up(m, 32); + if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; - } - - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); - - // DMA-based activation gather for strided tensors (see batched path comment). - const bool use_dma_activation = (act_stride > k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, - /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, // W + S0 + S1 - /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch - /*per_mn=*/sizeof(__fp16), // O - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n, - /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); return -1; } - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + + const int num_threads = ctx->n_threads; + + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; + + const size_t size_per_n = row_stride + vec_dot_size; + const size_t size_per_mn = sizeof(__fp16); + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, + m_padded, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m_padded * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(HIGH, "hmx-mm-id-2d: VTCM too small : m %d k %d n %d budget %zu", m_padded, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + + size_t scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-id-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); return -1; } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - FARF(HIGH, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - - TIMER_DEFINE(activation_load); - TIMER_DEFINE(weight_load); - TIMER_DEFINE(hmx_core); - TIMER_DEFINE(output_store); - - TIMER_DEFINE(total); - TIMER_START(total); + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + for (size_t mr = 0; mr < (size_t) m_padded; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m_padded - mr, m_chunk_n_rows); const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * act_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) k * sizeof(float); - const size_t stride_bytes = (size_t) act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_activation, - vtcm_f32_act, n_rows, k, k); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_activation, - activation_chunk, n_rows, k, act_stride); - } - } - TIMER_STOP(activation_load); + transfer_activation_chunk_gathered_threaded( + ctx, vtcm_activation, activation, (int) mr, (int) n_rows, k, + matrix_rows, cur_a, mapping_stride, ne11, act_nb1, act_nb2, cne1, num_threads); - const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); - - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; - - // issue async DMA for the first weight chunk - // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. - // The source rows can be strided (e.g. KV-cache K after ggml_permute). - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } - - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + for (size_t nc = 0; nc < (size_t) n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) n - nc, n_chunk_n_cols); const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready + const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); + dma_queue_pop(ctx->dma[0]); - // issue async DMA for the next weight chunk (double buffering) - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); - } + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); - // interleave row-major fp16 from scratch into tile-major in vtcm_weight - hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, k, k, 0, n_cols); - - hex_swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); - - TIMER_START(hmx_core); - { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); - } - TIMER_STOP(hmx_core); - - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); + transfer_output_chunk_scattered_threaded( + ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols, + matrix_rows, cur_a, mapping_stride, dst_nb1, dst_nb2, cne1, num_threads); } - } HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - - TIMER_STOP(total); - -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); - FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", - TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); - { - size_t weight_size = (size_t)k * n * sizeof(__fp16); - float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); - FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); - } -#endif - return 0; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.c b/ggml/src/ggml-hexagon/htp/hmx-ops.c new file mode 100644 index 000000000..114d8c148 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.c @@ -0,0 +1,6 @@ +// HMX operations compiled as a single translation unit. +// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO. + +#include "hmx-queue.c" +#include "hmx-matmul-ops.c" +#include "hmx-flash-attn-ops.c" diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index f114edb82..a67842f3f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -52,14 +52,32 @@ int hmx_matmul_f16_f32(struct htp_context *ctx, // Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); -// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4) -int hmx_matmul_q_f32(struct htp_context *ctx, +// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4) +int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const uint8_t *permuted_weight, int m, int k, int n, + int act_stride, + int weight_stride, int weight_type); +struct mmid_row_mapping; + +int hmx_matmul_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride); + // HMX flash attention int hmx_flash_attn_ext(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 51f9243ce..0f1676f07 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -79,6 +79,10 @@ struct htp_context { uint64_t max_vmem; + // Persistent DDR scratchpad for MUL_MAT_ID mappings + void * ddr_spad_base; + size_t ddr_spad_size; + struct htp_ops_context octx; #ifdef HTP_HAS_HMX diff --git a/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h b/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h new file mode 100644 index 000000000..f1f2e49e4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h @@ -0,0 +1,47 @@ +#ifndef HVX_FLASH_ATTN_H +#define HVX_FLASH_ATTN_H + +#include +#include "hvx-utils.h" + +// Scalar helper to compute a single ALiBi slope. +static inline float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) { + return (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); +} + +// Vectorized helper to compute 32 ALiBi slopes starting from (kv_head * G). +static inline HVX_Vector hvx_alibi_slopes( + uint32_t kv_head, + uint32_t G, + uint32_t n_head_log2, + float m0, + float m1 +) { + static const float ramp_32[32] __attribute__((aligned(128))) = { + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, + 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f + }; + HVX_Vector v_ramp = hvx_vmem(ramp_32); + HVX_Vector v_h_base = hvx_vec_splat_f32((float)(kv_head * G)); + HVX_Vector v_h = hvx_vec_add_f32_f32(v_h_base, v_ramp); + + // Compute exponent_m0: h + 1 + HVX_Vector v_exp_m0 = hvx_vec_add_f32_f32(v_h, hvx_vec_splat_f32(1.0f)); + + // Compute exponent_m1: 2 * (h - n_head_log2) + 1 + HVX_Vector v_n_head_log2 = hvx_vec_splat_f32((float)n_head_log2); + HVX_Vector v_h_minus = hvx_vec_sub_f32_f32(v_h, v_n_head_log2); + HVX_Vector v_exp_m1 = hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(hvx_vec_splat_f32(2.0f), v_h_minus), hvx_vec_splat_f32(1.0f)); + + // Compute powers + HVX_Vector v_pow_m0 = hvx_vec_pow_const_base_f32(m0, v_exp_m0); + HVX_Vector v_pow_m1 = hvx_vec_pow_const_base_f32(m1, v_exp_m1); + + // Select based on h < n_head_log2 + HVX_VectorPred p_cond = Q6_Q_vcmp_gt_VsfVsf(v_n_head_log2, v_h); // v_n_head_log2 > v_h <=> h < n_head_log2 + return Q6_V_vmux_QVV(p_cond, v_pow_m0, v_pow_m1); +} + +#endif /* HVX_FLASH_ATTN_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-log.h b/ggml/src/ggml-hexagon/htp/hvx-log.h new file mode 100644 index 000000000..7013dae78 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-log.h @@ -0,0 +1,65 @@ +#ifndef HVX_LOG_H +#define HVX_LOG_H + +#include "hvx-base.h" + +// Approximates ln(x) element-wise for float vectors. +// x must contain positive float elements. +// Uses Abramowitz & Stegun polynomial approximation 4.1.44 for ln(1+y) over [0, 1]. +static inline HVX_Vector hvx_vec_log_f32(HVX_Vector x) { + // x = m * 2^e, where m in [1, 2) + HVX_Vector biased_e = Q6_Vuw_vlsr_VuwR(x, 23); + HVX_Vector e_int = Q6_Vw_vsub_VwVw(biased_e, Q6_V_vsplat_R(127)); + HVX_Vector e_float = Q6_Vsf_equals_Vw(e_int); + + // Extract mantissa and set exponent to 127 (which represents float value in [1.0, 2.0)) + HVX_Vector mant_mask = Q6_V_vsplat_R(0x007FFFFF); + HVX_Vector exp_127 = Q6_V_vsplat_R(0x3F800000); + HVX_Vector m = Q6_V_vor_VV(Q6_V_vand_VV(x, mant_mask), exp_127); + + // y = m - 1.0f, y in [0, 1) + HVX_Vector y = hvx_vec_sub_f32_f32(m, hvx_vec_splat_f32(1.0f)); + + // Abramowitz & Stegun 4.1.44 polynomial approximation of ln(1+y) + HVX_Vector c; + HVX_Vector res; + + c = hvx_vec_splat_f32(-0.0064535442f); + res = hvx_vec_mul_f32_f32(y, c); + + c = hvx_vec_splat_f32(0.0360884937f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.0953293897f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.1676540711f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.2407338084f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.3317990258f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.4998741238f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.9999964239f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + // ln(x) = e * ln(2) + ln(1+y) + HVX_Vector ln2 = hvx_vec_splat_f32(0.69314718056f); + HVX_Vector term_e = hvx_vec_mul_f32_f32(e_float, ln2); + + return hvx_vec_add_f32_f32(term_e, res); +} + +#endif /* HVX_LOG_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-pow.h b/ggml/src/ggml-hexagon/htp/hvx-pow.h new file mode 100644 index 000000000..48fe0e8ea --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-pow.h @@ -0,0 +1,42 @@ +#ifndef HVX_POW_H +#define HVX_POW_H + +#include +#include "hvx-base.h" +#include "hvx-exp.h" +#include "hvx-log.h" + +// Approximates base^exponent element-wise for float vectors. +// base must be a positive constant. exponent is an HVX f32 vector. +// Uses base^x = exp(x * ln(base)). +static inline HVX_Vector hvx_vec_pow_const_base_f32(float base, HVX_Vector exponent) { + float ln_base = logf(base); + HVX_Vector ln_base_v = hvx_vec_splat_f32(ln_base); + HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base_v); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + return hvx_vec_exp_f32_guard(x, max_exp, inf); +} + +// Approximates base^exponent element-wise for float vectors. +// base and exponent are HVX f32 vectors. base elements must be positive. +// Uses base^exponent = exp(exponent * ln(base)). +static inline HVX_Vector hvx_vec_pow_f32(HVX_Vector base, HVX_Vector exponent) { + HVX_Vector ln_base = hvx_vec_log_f32(base); + HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + return hvx_vec_exp_f32_guard(x, max_exp, inf); +} + +#endif /* HVX_POW_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 0a760cd34..23373f73a 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -17,5 +17,7 @@ #include "hvx-floor.h" #include "hvx-sin-cos.h" #include "hvx-base.h" +#include "hvx-pow.h" +#include "hvx-log.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 623008be4..3715227d2 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -63,8 +64,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.type = HAP_power_set_DCVS_v3; request.dcvs_v3.set_dcvs_enable = TRUE; - request.dcvs_v3.dcvs_enable = TRUE; - request.dcvs_v3.dcvs_option = HAP_DCVS_V2_PERFORMANCE_MODE; + request.dcvs_v3.dcvs_enable = FALSE; request.dcvs_v3.set_bus_params = TRUE; request.dcvs_v3.bus_params.min_corner = HAP_DCVS_VCORNER_MAX; request.dcvs_v3.bus_params.max_corner = HAP_DCVS_VCORNER_MAX; @@ -75,6 +75,10 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX; request.dcvs_v3.set_sleep_disable = TRUE; request.dcvs_v3.sleep_disable = TRUE; + +#if (__HEXAGON_ARCH__ >= 79) + HAP_set_dcvs_v3_protected_bus_corners(&request, 1); +#endif if ((err = HAP_power_set((void *) ctx, &request)) != 0) { return err; } @@ -103,7 +107,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { FARF(ALWAYS, "Setting HMX clock\n"); err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error setting HMX clock."); + FARF(ERROR, "ggml-hex: error setting HMX clock."); return err; } } @@ -117,7 +121,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { FARF(ALWAYS, "Powering HMX on\n"); err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error powering on HMX."); + FARF(ERROR, "ggml-hex: error powering on HMX."); return err; } } @@ -423,10 +427,18 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->dma[i] = dma_queue_create(256); // queue depth } + ctx->ddr_spad_size = 512 * 1024; // 512 KB + ctx->ddr_spad_base = memalign(128, ctx->ddr_spad_size); + // init worker pool err = worker_pool_init(&ctx->worker_pool, n_hvx); if (err != AEE_SUCCESS) { FARF(ERROR, "Unable to create worker pool"); + if (ctx->ddr_spad_base) { + free(ctx->ddr_spad_base); + ctx->ddr_spad_base = NULL; + ctx->ddr_spad_size = 0; + } return err; } @@ -474,6 +486,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) { vtcm_free(ctx); + if (ctx->ddr_spad_base) { + free(ctx->ddr_spad_base); + ctx->ddr_spad_base = NULL; + ctx->ddr_spad_size = 0; + } + return AEE_SUCCESS; } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 7036c491b..5121c6f9b 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -53,6 +53,11 @@ struct htp_matmul_context { struct fastdiv_values mm_div_ne1; struct fastdiv_values mm_div_r2; struct fastdiv_values mm_div_r3; + + // Fields for scattered mapping & HMX support in MUL_MAT_ID + const uint32_t * matrix_row_counts; + const struct mmid_row_mapping * matrix_rows; + bool hmx_eligible; }; // vdelta control to expand first 32 e8m0 values into 32 uint32 elements @@ -2913,6 +2918,176 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } +#if __HVX_ARCH__ < 79 +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_Vector * restrict x = (const HVX_Vector *) vx; + const HVX_Vector * restrict y = (const HVX_Vector *) vy; + + uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors + uint32_t nloe = n % VLEN_FP32; // leftover elements + + HVX_Vector rsum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]); + rsum = HVX_OP_ADD_F32(rsum, prod); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]); + HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf); + rsum = HVX_OP_ADD_F32(rsum, prod); + } + + *s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum)); +} + +static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y = (const HVX_Vector *) vy0; + + uint32_t nvec = n / VLEN_FP32; + uint32_t nloe = n % VLEN_FP32; + + HVX_Vector rsum0 = Q6_V_vzero(); + HVX_Vector rsum1 = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector y_sf = y[i]; + HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf); + HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf); + rsum0 = HVX_OP_ADD_F32(rsum0, prod0); + rsum1 = HVX_OP_ADD_F32(rsum1, prod1); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]); + HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf); + HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf); + rsum0 = HVX_OP_ADD_F32(rsum0, prod0); + rsum1 = HVX_OP_ADD_F32(rsum1, prod1); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1); + HVX_VectorAlias va; + va.v = rsum; + s0[0] = va.fp32[0]; + s0[1] = va.fp32[1]; +} + +static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0; + const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1; + + uint32_t nvec = n / VLEN_FP32; + uint32_t nloe = n % VLEN_FP32; + + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector r0_sf = x0[i]; + HVX_Vector r1_sf = x1[i]; + HVX_Vector c0_sf = y0[i]; + HVX_Vector c1_sf = y1[i]; + + r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf)); + r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf)); + r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf)); + r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf)); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + + HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]); + HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]); + + r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf)); + r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf)); + r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf)); + r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + HVX_VectorAlias va0, va1; + va0.v = r0_r1_c0_sum; + va1.v = r0_r1_c1_sum; + s0[0] = va0.fp32[0]; + s0[1] = va0.fp32[1]; + s1[0] = va1.fp32[0]; + s1[1] = va1.fp32[1]; +} + +static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; + const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; + + uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors + uint32_t nloe = n % VLEN_FP32; // leftover elements + + HVX_Vector rsum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector x_sf = vx[i]; + HVX_Vector y_sf = vy[i]; + + rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf)); + } + + if (nloe) { + HVX_Vector x_sf = vx[i]; + HVX_Vector y_sf = vy[i]; + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + x_sf = Q6_V_vand_QV(bmask, x_sf); + y_sf = Q6_V_vand_QV(bmask, y_sf); + + rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf)); + } + + rsum = hvx_vec_reduce_sum_f32(rsum); + hvx_vec_store_u(&s[0], 4, rsum); +} + static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -3331,7 +3506,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const int is0 = (ir0 - src0_start_row); + const int is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -3466,7 +3641,7 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { const uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -3516,11 +3691,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { const uint32_t n_ids = ids->ne[0]; // n_expert_used const uint32_t n_as = ne02; // n_expert - const size_t matrix_row_counts_size = n_as * sizeof(uint32_t); - const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); - - const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0; - const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size; + const uint32_t * matrix_row_counts = mmctx->matrix_row_counts; + const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows; const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; @@ -3542,6 +3714,10 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { continue; } + if (mmctx->hmx_eligible) { + continue; + } + const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0); // Prefill spad with src0 rows @@ -3583,7 +3759,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -3685,7 +3861,7 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -4086,6 +4262,47 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; @@ -4328,6 +4545,60 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + need_quant = false; + } + } else if (src0->type == HTP_TYPE_F32) { + // Try optimized f32-f32 path first (src1 in VTCM) + const size_t f32_src1_row_size = hex_round_up(ne10 * 4, 128); + const size_t f32_src1_spad_size = hex_round_up(f32_src1_row_size * src1_nrows, 256); + const size_t f32_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f32_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; + + const size_t f32_total_size = f32_src1_spad_size + f32_src0_spad_size + f32_dst_spad_size; + + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); + + if (!is_batched && !is_permuted && f32_total_size <= octx->ctx->vtcm_size) { + // Optimized path + quant_job_func = quantize_f32_f32; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2; + + src1_row_size = f32_src1_row_size; + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + } else { + // Fallback to DDR / broadcasting + quant_job_func = NULL; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1; + matmul_job_func = matmul_4d; + + src1_row_size = nb11; + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + // Init fastdiv for matmul_4d (supports broadcasting) + mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + need_quant = false; } } else { @@ -4405,20 +4676,20 @@ int op_matmul(struct htp_ops_context * octx) { return op_matmul_hvx(octx); } - // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. + // HMX supports F16, F32, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. // Other types fall back to HVX. uint32_t wtype = src0->type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { return op_matmul_hvx(octx); } // Quantised HMX path requires K aligned to 256 (x4x2 super-block). - // F16 HMX path requires K aligned to 32 (tile width). - if (wtype != HTP_TYPE_F16 && src0->ne[0] % 256 != 0) { + // F16 and F32 HMX paths require K aligned to 32 (tile width). + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && src0->ne[0] % 256 != 0) { return op_matmul_hvx(octx); } - if (wtype == HTP_TYPE_F16 && src0->ne[0] % 32 != 0) { + if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && src0->ne[0] % 32 != 0) { return op_matmul_hvx(octx); } @@ -4463,8 +4734,8 @@ int op_matmul(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - if (src0->type == HTP_TYPE_F16) { - if (is_batched) { + if (is_batched) { + if (src0->type == HTP_TYPE_F16) { hmx_matmul_f16_f32_batched_params_t batch_params = { .dst = (float *) dst->data, .activation = (float *) src1->data, @@ -4488,13 +4759,11 @@ int op_matmul(struct htp_ops_context * octx) { }; ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); } else { - ret = hmx_matmul_f16_f32(octx->ctx, - (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, - m_total, k, n, act_stride, wgt_stride); + return op_matmul_hvx(octx); } } else { - ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, - m_total, k, n, (int) src0->type); + ret = hmx_matmul_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type); } if (ret != 0) { @@ -4539,8 +4808,30 @@ int op_matmul_id(struct htp_ops_context * octx) { size_t matrix_row_counts_size = n_as * sizeof(uint32_t); size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); + const size_t total_map_size = matrix_row_counts_size + matrix_row_map_size; + + void * mapping_buf = NULL; + bool must_free_mapping = false; + + if (octx->ctx->ddr_spad_base && total_map_size <= octx->ctx->ddr_spad_size) { + mapping_buf = octx->ctx->ddr_spad_base; + } else { + mapping_buf = memalign(128, total_map_size); + if (mapping_buf) { + must_free_mapping = true; + } else { + return HTP_STATUS_INTERNAL_ERR; + } + } + + uint32_t * matrix_row_counts = (uint32_t *) mapping_buf; + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) ((uint8_t *) mapping_buf + matrix_row_counts_size); + + mmctx->matrix_row_counts = matrix_row_counts; + mmctx->matrix_rows = matrix_rows; if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_NO_SUPPORT; } @@ -4552,7 +4843,7 @@ int op_matmul_id(struct htp_ops_context * octx) { src1_row_size = q8x4x2_row_size(ne10); } - const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + const size_t src2_spad_size_per_thread = 0; // We moved the mapping to DDR! htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; @@ -4568,6 +4859,7 @@ int op_matmul_id(struct htp_ops_context * octx) { // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -4587,9 +4879,6 @@ int op_matmul_id(struct htp_ops_context * octx) { if (src1_nrows > 1) { // initialize matrix_row_counts and map - uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0; - struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size; - memset(matrix_row_counts, 0, n_as * sizeof(uint32_t)); // group rows by src0 matrix @@ -4599,14 +4888,60 @@ int op_matmul_id(struct htp_ops_context * octx) { assert(i02 >= 0 && i02 < n_as); - MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 }; + matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 }; matrix_row_counts[i02] += 1; } } } - if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; + } + + bool hmx_eligible = false; +#ifdef HTP_HAS_HMX + if (octx->ctx->hmx_enabled && src1_nrows > 1) { + uint32_t wtype = src0->type; + if (ne01 % 32 == 0 && + (wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32 || wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || wtype == HTP_TYPE_MXFP4)) { + if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && ne00 % 32 == 0) { + hmx_eligible = true; + } else if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && ne00 % 256 == 0) { + hmx_eligible = true; + } + } + } +#endif + + mmctx->hmx_eligible = hmx_eligible; + + if (hmx_eligible) { + for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int32_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) continue; + + int ret = hmx_matmul_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, + (const uint8_t *) src0->data + cur_a * nb02, + cne1, ne00, ne01, + ne11, + nb11, nb12, + nb1, nb2, + (int) src0->nb[1], (int) src0->type, + matrix_rows, cur_a, n_ids * ids->ne[1]); + if (ret != 0) { + FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_NO_SUPPORT; + } + } + + // HMX has overwritten VTCM, so force dynamic quantization cache to clear + octx->src1_spad.src = NULL; + + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; + } if (octx->src1_spad.src != src1) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); @@ -4618,5 +4953,6 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t n_matmul_jobs = octx->n_threads; worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/pad-ops.c b/ggml/src/ggml-hexagon/htp/pad-ops.c index 3abc3c2ea..aaa72b315 100644 --- a/ggml/src/ggml-hexagon/htp/pad-ops.c +++ b/ggml/src/ggml-hexagon/htp/pad-ops.c @@ -511,6 +511,8 @@ int op_pad(struct htp_ops_context * octx) { octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; octx->src0_spad.data = octx->ctx->vtcm_base; octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->dst_spad.src = NULL; } struct htp_pad_context pctx = { diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 770a66732..71fab2cdb 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -692,6 +692,11 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * const uint8_t * restrict data_src1 = uctx->data_src1; uint8_t * restrict data_dst = uctx->data_dst; + const struct htp_tensor * src1 = (htp_op == HTP_OP_RMS_NORM_MUL) ? octx->src[1] : NULL; + const uint32_t nb11 = src1 ? src1->nb[1] : 0; + const uint32_t nb12 = src1 ? src1->nb[2] : 0; + const uint32_t nb13 = src1 ? src1->nb[3] : 0; + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); @@ -738,10 +743,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * src0_row_size_aligned, nb01, src0_data_row_size, block_size); if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { - const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03); + const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb11, nb12, nb13); dma_queue_push(dma_queue, dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + src1_off), - uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, block_size); + uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, block_size); } ir += block_size; @@ -823,10 +828,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { - const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); + const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb11, nb12, nb13); dma_queue_push(dma_queue, dma_make_ptr(src1_spad, data_src1 + src1_pref_off), - uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, pref_block_size); + uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, pref_block_size); } } } @@ -977,6 +982,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; } + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->dst_spad.src = NULL; + FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);