diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 920829f6a..d550841a2 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1927,6 +1927,7 @@ struct ggml_hexagon_opbatch { size_t extra_tens = 0; auto fit_tensor = [&](const ggml_tensor *t) { + if (!t) return; if (!t_map.count(t)) { extra_tens++; @@ -2602,6 +2603,27 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n"); return false; } + if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); + return false; + } + if (ggml_nrows(src1) > 1024) { + return false; // no huge batches (for now) + } + break; + + case GGML_TYPE_F32: + if (src1->type != GGML_TYPE_F32) { + return false; + } + if (src0->nb[1] < src0->nb[0]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F32 src0 not supported\n"); + return false; + } + if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); + return false; + } if (ggml_nrows(src1) > 1024) { return false; // no huge batches (for now) } diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h index 14b232240..8a1228ccd 100644 --- a/ggml/src/ggml-hexagon/htp-opnode.h +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -56,7 +56,7 @@ struct htp_opnode { } std::vector get_inputs() const { - std::vector inputs; + std::vector inputs(GGML_MAX_SRC, nullptr); std::vector outputs; outputs.push_back(node); for (const auto * f : fused) { @@ -70,20 +70,38 @@ struct htp_opnode { return false; }; + int count = 0; auto add_input = [&](const ggml_tensor * t) { if (t && !contains(outputs, t) && !contains(inputs, t)) { - inputs.push_back(t); + if (count < (int)inputs.size()) { + inputs[count++] = t; + } else { + inputs.push_back(t); + } } }; - for (int i = 0; i < GGML_MAX_SRC && node->src[i]; i++) { - add_input(node->src[i]); - } - for (const auto * f : fused) { - for (int i = 0; i < GGML_MAX_SRC && f->src[i]; i++) { - add_input(f->src[i]); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (fused.empty()) { + inputs[i] = node->src[i]; + } else { + if (node->src[i]) { + add_input(node->src[i]); + } } } + for (const auto * f : fused) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (f->src[i]) { + add_input(f->src[i]); + } + } + } + + if (!fused.empty()) { + inputs.resize(count); + } + return inputs; } @@ -108,6 +126,9 @@ struct htp_opformat { char names[64 * GGML_MAX_SRC]; int format_tensor_dims(char * str, const struct ggml_tensor * t) { + if (!t) { + return sprintf(str, "NONE"); + } if (t->ne[2] == 1 && t->ne[3] == 1) { return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); } else { @@ -136,6 +157,9 @@ struct htp_opformat { } int format_tensor_strides(char * str, const struct ggml_tensor * t) { + if (!t) { + return sprintf(str, "NONE"); + } const char * c = ggml_is_contiguous(t) ? "" : "!"; if (t->ne[2] == 1 && t->ne[3] == 1) { @@ -170,11 +194,11 @@ struct htp_opformat { auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", ggml_type_name(inputs[0]->type)); + p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"); for (size_t i = 1; i < inputs.size(); i++) { p += sprintf(p, " x "); - p += sprintf(p, "%s", ggml_type_name(inputs[i]->type)); + p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"); } p += sprintf(p, " -> "); @@ -184,7 +208,7 @@ struct htp_opformat { } const char * tensor_buff_name(const struct ggml_tensor * t) { - if (t->buffer) { + if (t && t->buffer) { return ggml_backend_buffer_name(t->buffer); } return "NONE"; @@ -213,11 +237,11 @@ struct htp_opformat { auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", inputs[0]->name); + p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE"); for (size_t i = 1; i < inputs.size(); i++) { p += sprintf(p, " x "); - p += sprintf(p, "%s", inputs[i]->name); + p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE"); } p += sprintf(p, " -> "); diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index ff3fc0804..f4b44fe1a 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -19,6 +19,43 @@ add_library(${HTP_LIB} SHARED htp_iface_skel.c worker-pool.c hex-dma.c +) + +target_compile_definitions(${HTP_LIB} PRIVATE + $,HTP_DEBUG=1,NDEBUG=1> + $,FARF_HIGH=1,> + FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) + +if (GGML_HEXAGON_FA_EXP2_HF) + message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") + target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) +endif() + +# HMX acceleration: available on v73+ architectures +set(HTP_HMX_VERSIONS v73 v75 v79 v81) +list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) + +if (_hmx_idx GREATER_EQUAL 0) + target_sources(${HTP_LIB} PRIVATE + hmx-matmul-ops.c + hmx-flash-attn-ops.c + hmx-queue.c + ) + + # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) + set_source_files_properties( + hmx-flash-attn-ops.c + hmx-matmul-ops.c + hmx-queue.c + PROPERTIES COMPILE_OPTIONS "-mhmx" + ) + + target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1) +endif() + +build_idl(htp_iface.idl ${HTP_LIB}) + +target_sources(${HTP_LIB} PRIVATE matmul-ops.c binary-ops.c unary-ops.c @@ -42,40 +79,6 @@ add_library(${HTP_LIB} SHARED pad-ops.c ) -target_compile_definitions(${HTP_LIB} PRIVATE - $,HTP_DEBUG=1,NDEBUG=1> - $,FARF_HIGH=1,> - FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) - -if (GGML_HEXAGON_FA_EXP2_HF) - message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") - target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) -endif() - -# HMX acceleration: available on v73+ architectures -set(HTP_HMX_VERSIONS v73 v75 v79 v81) -list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) - -if (_hmx_idx GREATER_EQUAL 0) - target_sources(${HTP_LIB} PRIVATE - hmx-flash-attn-ops.c - hmx-matmul-ops.c - hmx-queue.c - ) - - # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) - set_source_files_properties( - hmx-flash-attn-ops.c - hmx-matmul-ops.c - hmx-queue.c - PROPERTIES COMPILE_OPTIONS "-mhmx" - ) - - target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1) -endif() - -build_idl(htp_iface.idl ${HTP_LIB}) - set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) install(TARGETS ${HTP_LIB}) diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index bdd062361..73af38a35 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -276,6 +276,7 @@ int op_argsort(struct htp_ops_context * octx) { octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.size = total_spad_size; octx->src0_spad.size_per_thread = spad_per_thread; + octx->src0_spad.src = NULL; FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3], diff --git a/ggml/src/ggml-hexagon/htp/concat-ops.c b/ggml/src/ggml-hexagon/htp/concat-ops.c index 61580f2c0..f2a381313 100644 --- a/ggml/src/ggml-hexagon/htp/concat-ops.c +++ b/ggml/src/ggml-hexagon/htp/concat-ops.c @@ -262,6 +262,8 @@ int op_concat(struct htp_ops_context * octx) { octx->src0_spad.data = octx->ctx->vtcm_base; octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; if (type_size == 4) { worker_func = concat_2d_f32_transposed; diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 1bd8c1407..e99621469 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -11,6 +11,7 @@ #include "hex-dma.h" #include "hvx-utils.h" #include "hvx-dump.h" +#include "hvx-flash-attn.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -245,6 +246,7 @@ struct htp_fa_context { uint32_t n_head_log2; float m0; float m1; + float slopes[512]; uint32_t n_blocks; @@ -412,7 +414,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * } const uint32_t h = iq2; // head index - const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; + const float slope = factx->slopes[h]; HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); @@ -628,8 +630,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { } #ifdef HTP_HAS_HMX - // HMX path: head_dim multiple of 32, F16 KV - if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0) { + // HMX path: head_dim multiple of 64, F16 KV, and no sinks + if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) { int ret = hmx_flash_attn_ext(octx); if (ret == HTP_STATUS_OK) { return ret; @@ -689,6 +691,13 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2); factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + if (n_head > 512) { + return HTP_STATUS_NO_SUPPORT; + } + for (uint32_t h = 0; h < n_head; ++h) { + factx.slopes[h] = (max_bias > 0.0f) ? alibi_slope(h, factx.n_head_log2, factx.m0, factx.m1) : 1.0f; + } + // total rows in q const uint32_t neq0 = q->ne[0]; const uint32_t neq1 = q->ne[1]; diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c index c4d08bb21..3b092d744 100644 --- a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -3,6 +3,7 @@ #include #include "hvx-utils.h" +#include "hex-fastdiv.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -14,106 +15,103 @@ #define HTP_GDN_MAX_SV 128 + struct htp_gdn_context { struct htp_ops_context * octx; uint32_t rows_per_thread; - size_t state_bytes; - bool use_vtcm; - uint8_t * vtcm_state_base; - size_t vtcm_state_per_thread; + size_t state_bytes; + uint8_t * vtcm_base; + size_t vtcm_per_thread; }; -static inline float gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, - const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } -static inline float gdn_mul_scalar_dot_f32(float * restrict dst, float mul, - const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_mul_scalar_dot_f32(float * restrict dst, float mul, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); const HVX_Vector vmul = hvx_vec_splat_f32(mul); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vd = hvx_vmemu(dst + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } -static inline float gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, - float scale, const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, + HVX_Vector vscale, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); - const HVX_Vector vscale = hvx_vec_splat_f32(scale); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1, @@ -126,7 +124,7 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -147,11 +145,11 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); @@ -159,10 +157,10 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -185,7 +183,7 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -205,10 +203,10 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); @@ -216,10 +214,10 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -246,7 +244,7 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -267,11 +265,11 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); @@ -279,10 +277,10 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -310,7 +308,7 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -343,11 +341,11 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); @@ -359,14 +357,14 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm); HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -400,7 +398,7 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -432,10 +430,10 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); @@ -447,14 +445,14 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul); HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -496,7 +494,7 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -529,11 +527,11 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); @@ -545,14 +543,14 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6)); HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7)); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -605,26 +603,65 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_sums[4] __attribute__((aligned(128))); + float local_sums[32] __attribute__((aligned(128))); + + dma_queue * dma = octx->ctx->dma[ith]; + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + float * s_work[2]; + s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith); + s_work[1] = s_work[0] + state_aligned / sizeof(float); + + struct fastdiv_values fd_H = init_fastdiv_values(H); + struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]); + struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]); + struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); + struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); const uint64_t state_seq_stride = state->nb[2] / sizeof(float); const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; const int64_t shift = (int64_t) n_tokens - (int64_t) K; - for (uint32_t ir = ith; ir < total_rows; ir += nth) { - const uint32_t iv1 = ir % H; - const uint32_t iv3 = ir / H; + uint32_t ir_prefetch = ith; + int spad_idx = 0; - const uint32_t iq1 = iv1 % q->ne[1]; - const uint32_t ik1 = iv1 % k->ne[1]; - const uint32_t iq3 = iv3 / rq3; - const uint32_t ik3 = iv3 / rk3; + // Prefetch preamble (up to 2 steps) + for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + + // Push dummy write-back + dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), 0); + + // Push fetch + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + int curr_spad_idx = 0; + for (uint32_t ir = ith; ir < total_rows; ir += nth) { + dma_queue_pop(dma); + dma_queue_pop(dma); + + float * s_work_curr = s_work[curr_spad_idx]; + + const uint32_t iv1 = fastmodulo(ir, H, &fd_H); + const uint32_t iv3 = fastdiv(ir, &fd_H); + + const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1); + const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1); + const uint32_t iq3 = fastdiv(iv3, &fd_rq3); + const uint32_t ik3 = fastdiv(iv3, &fd_rk3); float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; - - memcpy(s_out, s_in, gctx->state_bytes); - float * s_work = s_out; float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v; @@ -640,57 +677,117 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]); - memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); - memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v); + hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v); if (kda) { hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); uint32_t j = 0; - for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } - gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); + } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } else { const float gate = expf(g_t[0]); uint32_t j = 0; - for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } - gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); + } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } @@ -698,17 +795,40 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const int64_t target_slot = (int64_t) t - shift; if (target_slot >= 0 && target_slot < (int64_t) K) { float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - if (curr_state_o != s_work) { - memcpy(curr_state_o, s_work, gctx->state_bytes); + if (curr_state_o != s_out) { + hvx_copy_f32_uu((uint8_t *) curr_state_o, (const uint8_t *) s_work_curr, S_v * S_v); } } } attn_data += (uint64_t) S_v * H; } + + // Push real write-back + dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + // Prefetch next block (if any) + if (ir_prefetch < total_rows) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + curr_spad_idx ^= 1; } + dma_queue_flush(dma); } + static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; struct htp_ops_context * octx = gctx->octx; @@ -743,41 +863,64 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_sums[8] __attribute__((aligned(128))); + float local_sums[32] __attribute__((aligned(128))); dma_queue * dma = octx->ctx->dma[ith]; + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + float * s_work[2]; + s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith); + s_work[1] = s_work[0] + state_aligned / sizeof(float); - uint8_t * spad = NULL; - if (gctx->use_vtcm) { - spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith; - } + struct fastdiv_values fd_H = init_fastdiv_values(H); + struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]); + struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]); + struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); + struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); const uint64_t state_seq_stride = state->nb[2] / sizeof(float); const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; - for (uint32_t ir = ith; ir < total_rows; ir += nth) { - const uint32_t iv1 = ir % H; - const uint32_t iv3 = ir / H; + uint32_t ir_prefetch = ith; + int spad_idx = 0; - const uint32_t iq1 = iv1 % q->ne[1]; - const uint32_t ik1 = iv1 % k->ne[1]; - const uint32_t iq3 = iv3 / rq3; - const uint32_t ik3 = iv3 / rk3; + // Prefetch preamble (up to 2 steps) + for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + + // Push dummy write-back + dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), 0); + + // Push fetch + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + int curr_spad_idx = 0; + for (uint32_t ir = ith; ir < total_rows; ir += nth) { + dma_queue_pop(dma); + dma_queue_pop(dma); + + float * s_work_curr = s_work[curr_spad_idx]; + + const uint32_t iv1 = fastmodulo(ir, H, &fd_H); + const uint32_t iv3 = fastdiv(ir, &fd_H); + + const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1); + const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1); + const uint32_t iq3 = fastdiv(iv3, &fd_rq3); + const uint32_t ik3 = fastdiv(iv3, &fd_rk3); float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; - float * s_work; - - if (spad) { - dma_queue_push(dma, dma_make_ptr(spad, s_in), - S_v * sizeof(float), S_v * sizeof(float), - S_v * sizeof(float), S_v); - dma_queue_pop(dma); - s_work = (float *) spad; - } else { - s_work = s_out; - memcpy(s_work, s_in, gctx->state_bytes); - } float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v; @@ -792,111 +935,145 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]); - memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); - memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v); + hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v); if (kda) { hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); uint32_t j = 0; for (; j + 8 <= S_v; j += 8) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - float * row4 = s_work + (uint64_t) (j + 4) * S_v; - float * row5 = s_work + (uint64_t) (j + 5) * S_v; - float * row6 = s_work + (uint64_t) (j + 6) * S_v; - float * row7 = s_work + (uint64_t) (j + 7) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_gate, local_k, S_v, local_sums); - float local_delta_b[8] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 8; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 8; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } else { const float gate = expf(g_t[0]); uint32_t j = 0; for (; j + 8 <= S_v; j += 8) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - float * row4 = s_work + (uint64_t) (j + 4) * S_v; - float * row5 = s_work + (uint64_t) (j + 5) * S_v; - float * row6 = s_work + (uint64_t) (j + 6) * S_v; - float * row7 = s_work + (uint64_t) (j + 7) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, gate, local_k, S_v, local_sums); - float local_delta_b[8] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 8; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 8; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } - if (spad) { - dma_queue_push(dma, dma_make_ptr(s_out, spad), + // Push real write-back + dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + // Prefetch next block (if any) + if (ir_prefetch < total_rows) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), S_v * sizeof(float), S_v * sizeof(float), S_v * sizeof(float), S_v); - dma_queue_pop(dma); + + ir_prefetch += nth; + spad_idx ^= 1; } + + curr_spad_idx ^= 1; } + dma_queue_flush(dma); } + int op_gated_delta_net(struct htp_ops_context * octx) { const struct htp_tensor * q = octx->src[0]; const struct htp_tensor * k = octx->src[1]; @@ -952,18 +1129,11 @@ int op_gated_delta_net(struct htp_ops_context * octx) { size_t state_aligned = (size_t) S_v * S_v * sizeof(float); state_aligned = (state_aligned + 127) & ~(size_t)127; - gctx.use_vtcm = false; - gctx.vtcm_state_base = NULL; - gctx.vtcm_state_per_thread = 0; + assert(octx->ctx->vtcm_base != NULL); + assert(octx->ctx->vtcm_size >= 2 * state_aligned * octx->n_threads); - if (n_tokens == 1 && octx->ctx->vtcm_base) { - size_t vtcm_total = state_aligned * octx->n_threads; - if (octx->ctx->vtcm_size >= vtcm_total) { - gctx.use_vtcm = true; - gctx.vtcm_state_base = octx->ctx->vtcm_base; - gctx.vtcm_state_per_thread = state_aligned; - } - } + gctx.vtcm_base = octx->ctx->vtcm_base; + gctx.vtcm_per_thread = 2 * state_aligned; if (n_tokens == 1) { worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads); diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index f132c0850..2796564fb 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -17,14 +17,17 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "hex-dma.h" +#include "hex-fastdiv.h" #include "hmx-profile.h" #include "hmx-queue.h" #include "hmx-utils.h" #include "htp-ctx.h" #include "htp-ops.h" #include "hvx-dump.h" +#include "hvx-copy.h" #include "hvx-reduce.h" #include "hvx-utils.h" +#include "hvx-flash-attn.h" #include "vtcm-utils.h" #include "worker-pool.h" @@ -46,7 +49,7 @@ // g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. // Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales // Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. -static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads) { +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) { const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong @@ -67,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, + k_dma_size * 2 // K DMA x2 + v_dma_size * 2 // V DMA x2 + k_tile_size * 1 // K tiles - + v_tile_size * 1 // V tiles + + v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining) + s_tile_size * 2 // S + P + d_tile_size * 1 // D (diagonal matrix) + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum @@ -144,12 +147,13 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, // See .cursor/todos/hmx-flash-attn-bc-search-space.md for the perf trade-off. const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64 const size_t fp16 = sizeof(__fp16); + const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); // Approximate per-unit VTCM costs (without per-buffer alignment padding). const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16; // Q + O×2 + 4 col vectors const size_t per_gbr2 = fp16; // D diagonal matrix const size_t per_bc = - 3 * (DK + DV) * fp16 + 2 * n_threads * fp16; // K_dma×2 + V_dma×2 + K_tile + V_tile + row bufs + 3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs const size_t per_gbr_bc = 2 * fp16; // S + P const size_t overhead = 256 * 2 + 13 * 4096; @@ -164,7 +168,6 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, // Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS. // Only relax when kv_len is too short to form enough blocks. - const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) : (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit); // Cost coefficients calibrated from profiling @@ -200,7 +203,7 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, } // Exact VTCM verification (alignment padding may push over budget) - while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads) > vtcm_budget) { + while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) { Bc -= bc_unit; } if (Bc < bc_unit) { @@ -303,6 +306,7 @@ struct hmx_fa_context { uint32_t n_kv_heads; // number of KV heads uint32_t n_heads; // number of Q heads uint32_t G; // GQA factor = n_heads / n_kv_heads + struct fastdiv_values div_G; uint32_t n_kv_blocks; uint32_t neq1; // Q token count @@ -321,7 +325,7 @@ struct hmx_fa_context { __fp16 * vtcm_k_fp16[2]; // K DMA double-buffer [Bc, D] __fp16 * vtcm_v_fp16[2]; // V DMA double-buffer [Bc, D] __fp16 * vtcm_k_tiles; // K tiles (transposed) - __fp16 * vtcm_v_tiles; // V tiles (column-major) + __fp16 * vtcm_v_tiles[2]; // V tiles (column-major, double-buffered) __fp16 * vtcm_s_tiles; // S = QK^T [g_br, Bc] __fp16 * vtcm_p_tiles; // P = softmax(S) [g_br, Bc] __fp16 * vtcm_d_tiles; // Diagonal rescale [g_br, g_br] @@ -402,7 +406,9 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) return; } - hmx_interleave_cols_to_tiles(factx->vtcm_v_tiles, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, + __fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0]; + + hmx_interleave_cols_to_tiles(v_tiles_dest, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, (int) args->src_stride, (int) args->n_col_tiles, start, end); } @@ -464,10 +470,10 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) { for (size_t r = start; r < end; r += 2) { const bool next_row_valid = (r + 1) < n_rows_g; - const size_t q_idx0 = (r + 0) / G; - const size_t h_idx0 = (r + 0) % G; - const size_t q_idx1 = (r + 1) / G; - const size_t h_idx1 = (r + 1) % G; + const size_t q_idx0 = fastdiv(r + 0, &factx->div_G); + const size_t h_idx0 = fastmodulo(r + 0, G, &factx->div_G); + const size_t q_idx1 = fastdiv(r + 1, &factx->div_G); + const size_t h_idx1 = fastmodulo(r + 1, G, &factx->div_G); const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] + (kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3]; @@ -567,8 +573,8 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) { const uint32_t ib3 = args->ib3; for (size_t r = start; r < end; ++r) { - const size_t q_idx = r / G; - const size_t h_idx = r % G; + const size_t q_idx = fastdiv(r, &factx->div_G); + const size_t h_idx = fastmodulo(r, G, &factx->div_G); // FIX(dst-indexing): ggml_flash_attn_ext() creates dst as permute(0,2,1,3) -> // [DV, n_heads, n_tokens, n_seq], so head stride is nb[1] and token stride is nb[2]. @@ -780,11 +786,11 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { if (args->mask_vtcm) { // Read mask from VTCM buffer (DMA'd per KV block). // GQA dedup (scheme B): skip load when qi unchanged. - const size_t qi0 = (r + 0) / G; + const size_t qi0 = fastdiv(r + 0, &factx->div_G); v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c); v_mask1 = v_neg_inf; if (r + 1 < (int) n_rows_g) { - const size_t qi1 = (r + 1) / G; + const size_t qi1 = fastdiv(r + 1, &factx->div_G); if (qi1 == qi0) { v_mask1 = v_mask0; // scheme B: reuse — same mask row } else { @@ -794,8 +800,8 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { } else { // Fallback: read mask directly from DDR (when mask->ne[2] > 1). const struct htp_tensor * mask = args->mask; - const size_t q_idx0 = args->q_start + ((r + 0) / G); - const size_t h_idx0 = args->kv_head * G + (r + 0) % G; + const size_t q_idx0 = args->q_start + fastdiv(r + 0, &factx->div_G); + const size_t h_idx0 = args->kv_head * G + fastmodulo(r + 0, G, &factx->div_G); const uint32_t im2_0 = h_idx0 % mask->ne[2]; const uint32_t im3_0 = args->ib3 % mask->ne[3]; @@ -805,12 +811,12 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { v_mask1 = v_neg_inf; if (r + 1 < (int) n_rows_g) { - const size_t q_idx1 = args->q_start + ((r + 1) / G); + const size_t q_idx1 = args->q_start + fastdiv(r + 1, &factx->div_G); if (q_idx1 == q_idx0) { // scheme B: same mask row in DDR path v_mask1 = v_mask0; } else { - const size_t h_idx1 = args->kv_head * G + (r + 1) % G; + const size_t h_idx1 = args->kv_head * G + fastmodulo(r + 1, G, &factx->div_G); const uint32_t im2_1 = h_idx1 % mask->ne[2]; const uint32_t im3_1 = args->ib3 % mask->ne[3]; const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] + @@ -1191,14 +1197,13 @@ static void hmx_fa_o_norm_worker(void * data) { // Row r in the GQA-merged block maps to Q head h = kv_head * G + r % G. // slope(h) = m0^(h+1) when h < n_head_log2, else m1^(2*(h-n_head_log2)+1). // When max_bias == 0, all slopes are 1.0 (no ALiBi). -static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sargs, +static __attribute__((noinline)) void fa_compute_slopes( const struct hmx_fa_context * factx, uint32_t kv_head, size_t n_rows_g) { + __fp16 * slopes = factx->vtcm_slopes; if (factx->max_bias == 0.0f) { - for (size_t r = 0; r < n_rows_g; ++r) { - sargs->slopes[r] = 1.0f; - } + hvx_splat_f16_a(slopes, 1.0f, n_rows_g); return; } @@ -1207,10 +1212,32 @@ static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sarg const float m0 = factx->m0; const float m1 = factx->m1; - for (size_t r = 0; r < n_rows_g; ++r) { - const uint32_t h = kv_head * G + r % G; - sargs->slopes[r] = (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); + __fp16 temp_slopes[512] __attribute__((aligned(128))); + if (G <= 32) { + // Fast path: Compute G unique slope values in vector registers + HVX_Vector v_val = hvx_alibi_slopes(kv_head, G, n_head_log2, m0, m1); + + __fp16 temp_slopes_aligned[64] __attribute__((aligned(128))); + hvx_vmem(temp_slopes_aligned) = hvx_vec_f32_to_f16(v_val, Q6_V_vzero()); + + for (uint32_t i = 0; i < G; ++i) { + temp_slopes[i] = temp_slopes_aligned[i]; + } + } else { + // Fallback path: G > 32 (rare configurations) + for (uint32_t i = 0; i < G; ++i) { + temp_slopes[i] = (__fp16)alibi_slope(kv_head * G + i, n_head_log2, m0, m1); + } } + + // Allocate stack buffer to avoid scalar writes to VTCM (which generates L2 misses) + __fp16 local_slopes[n_rows_g] __attribute__((aligned(128))); + for (size_t r = 0; r < n_rows_g; ++r) { + local_slopes[r] = temp_slopes[fastmodulo(r, G, &factx->div_G)]; + } + + // Copy to VTCM slopes using HVX block copy (both are aligned to 128 bytes) + hvx_copy_f16_aa((uint8_t *)slopes, (const uint8_t *)local_slopes, n_rows_g); } // ============================================================================ @@ -1254,19 +1281,22 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const uint32_t G = neq2 / n_kv_heads; // Thread count for multi-thread HVX phases - const uint32_t n_threads = octx->n_threads; + const uint32_t n_threads_init = octx->n_threads; // Compute dynamic block sizes (GQA-aware, accounting for per-thread row bufs) size_t Br, Bc; const size_t vtcm_budget = ctx->vtcm_size; - if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads) != 0) { + if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads_init) != 0) { return HTP_STATUS_VTCM_TOO_SMALL; } const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; - const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2); + const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2); + + // Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1 + const uint32_t n_threads = use_pipeline ? n_threads_init : 1; FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); @@ -1282,6 +1312,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.n_kv_heads = n_kv_heads; factx.n_heads = neq2; factx.G = G; + factx.div_G = init_fastdiv_values(G); factx.neq1 = neq1; factx.Br = (uint32_t) Br; factx.Bc = (uint32_t) Bc; @@ -1354,7 +1385,12 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); - factx.vtcm_v_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + if (use_pipeline) { + factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + } else { + factx.vtcm_v_tiles[1] = NULL; + } factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes); @@ -1457,6 +1493,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { // ---- KV block loop with DMA double-buffering ---- size_t buf_idx = 0; + fa_compute_slopes(&factx, kv_head, n_rows_g); + // Prefetch first KV block if (factx.n_kv_blocks > 0) { const uint32_t kv_rows0 = hex_smin(Bc, nek1); @@ -1535,7 +1573,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { ou_job.o_curr = o_tile_curr; ou_job.o_prev = o_tile_prev; ou_job.p_tiles = factx.vtcm_p_tiles; - ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx]; ou_job.d_tiles = factx.vtcm_d_tiles; ou_job.hmx_scales = factx.vtcm_hmx_scales_id; ou_job.n_row_tiles = n_row_tiles; @@ -1550,11 +1588,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); TIMER_STOP(k_interleave); - if (kv_blk > 0) { - hmx_queue_pop(hmx_q); - hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); - } - // ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ---- qk_job.q_tiles = factx.vtcm_q_tiles; qk_job.k_tiles = factx.vtcm_k_tiles; @@ -1574,6 +1607,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); TIMER_STOP(v_interleave); + // Pop and swap previous block's output update (deferred HMX pop) + if (kv_blk > 0) { + hmx_queue_pop(hmx_q); + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + // Pop current block's dot product job hmx_queue_pop(hmx_q); TIMER_STOP(qk_dot); @@ -1601,7 +1641,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; sargs.slopes = factx.vtcm_slopes; - fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); TIMER_START(softmax); fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); @@ -1617,7 +1656,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { ou_job.o_curr = o_tile_curr; ou_job.o_prev = o_tile_prev; ou_job.p_tiles = factx.vtcm_p_tiles; - ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx]; ou_job.d_tiles = factx.vtcm_d_tiles; ou_job.hmx_scales = factx.vtcm_hmx_scales_id; ou_job.n_row_tiles = n_row_tiles; @@ -1712,7 +1751,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; sargs.slopes = factx.vtcm_slopes; - fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); TIMER_START(softmax); fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); @@ -1732,7 +1770,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const size_t DV_tiles = (size_t) (DV / 32); const __fp16 * restrict d_base = factx.vtcm_d_tiles; const __fp16 * restrict p_base = factx.vtcm_p_tiles; - const __fp16 * restrict v_base = factx.vtcm_v_tiles; + const __fp16 * restrict v_base = factx.vtcm_v_tiles[0]; const __fp16 * restrict op_base = o_tile_prev; __fp16 * restrict oc_base = o_tile_curr; __builtin_assume(n_row_tiles > 0); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 083d12588..dab605210 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -73,6 +73,10 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) { return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb case HTP_TYPE_MXFP4: return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb + case HTP_TYPE_F16: + return (size_t) k * sizeof(__fp16); + case HTP_TYPE_F32: + return (size_t) k * sizeof(float); default: return 0; } @@ -545,7 +549,7 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4( int start_tile, int end_tile) { const int n_k_tiles = state->n_k_tiles; - const int qrow_size = state->k_block; + const int qrow_size = (unsigned)state->k_block / 2; const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); @@ -720,12 +724,129 @@ static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, voi } } +static void convert_f16_weight_to_fp16_tiles_task( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int byte_off = kt * 32 * sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0 = hvx_vmemu((const __fp16 *)(r0 + byte_off)); + HVX_Vector v1 = (row1 < state->n_cols) ? hvx_vmemu((const __fp16 *)(r1 + byte_off)) : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + convert_f16_weight_to_fp16_tiles_task(state, start, end); + } +} + +static void quantize_f32_weight_to_fp16_tiles_task( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int byte_off = kt * 32 * sizeof(float); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0_f32 = hvx_vmemu((const float *)(r0 + byte_off)); + HVX_Vector v1_f32 = (row1 < state->n_cols) ? hvx_vmemu((const float *)(r1 + byte_off)) : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16(v0_f32, v1_f32); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + HVX_Vector v_out_hi = Q6_V_vror_VR(v_out, 64); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out_hi); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + quantize_f32_weight_to_fp16_tiles_task(state, start, end); + } +} + + static void dequantize_x4x2_weight_chunk_to_fp16_tiles( struct htp_context *ctx, __fp16 *vtcm_dst, const void *vtcm_src, int n_cols, int k_block, size_t row_stride, int weight_type, int n_k_tiles, struct fastdiv_values n_k_tiles_div, - worker_callback_t dequant_worker_fn) { + worker_callback_t dequant_worker_fn, int n_threads) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); assert(k_block % HMX_FP16_TILE_N_COLS == 0); @@ -733,7 +854,7 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; size_t n_tot_tiles = n_col_tiles * n_k_tiles; - size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads); + size_t n_tiles_per_task = (n_threads == 1) ? n_tot_tiles : hmx_ceil_div(n_tot_tiles, n_threads); x4x2_dequantize_state_t state; state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; @@ -748,7 +869,11 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( state.n_k_tiles = n_k_tiles; state.n_k_tiles_div = n_k_tiles_div; - worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + dequant_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, n_threads); + } } // --- End x4x2 dequantizers --- @@ -876,11 +1001,11 @@ static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void } static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, - int n_rows, int n_cols, int n) { + int n_rows, int n_cols, int n, int n_threads) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) output_transfer_task_state_t state; state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; @@ -891,7 +1016,11 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, state.n_cols = n_cols; state.n = n; - worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + transfer_output_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, n_threads); + } } // activations : fp32 -> fp16 @@ -973,12 +1102,12 @@ static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, } } -static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { +static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) { assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); assert(VLEN == 32 * sizeof(float)); size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : 32; // must be multiple of 32 to ensure correct destination address activation_transfer_task_state_t state; state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; @@ -989,7 +1118,11 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 * state.k_block = k_block; state.k_stride = k_stride; - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + transfer_activation_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, n_threads); + } } // C += AB @@ -1031,9 +1164,9 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co } } -int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, - int weight_type) { + int act_stride, int weight_stride, int weight_type) { if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { @@ -1052,6 +1185,8 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; default: return -1; } @@ -1059,21 +1194,25 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + // --- Dynamic Mode Configuration --- + const bool use_pipeline = (m > 32); + const int num_threads = (m <= 32) ? 1 : ctx->n_threads; + // --- Dynamic VTCM layout --- const size_t vec_dot_size = k * sizeof(__fp16); const size_t vtcm_budget = ctx->vtcm_size; size_t vtcm_used = 0; // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. - const size_t size_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - const size_t size_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) + const size_t size_per_n = row_stride + (use_pipeline ? 2 * vec_dot_size : vec_dot_size); // Q + S0 + S1 (dequant bufs) + const size_t size_per_mn = (use_pipeline ? 2 : 1) * sizeof(__fp16); // O x 2 (output double buffer) size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, /*m_block_cost=*/(size_t) n * 3, /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { - FARF(HIGH, "hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); + FARF(HIGH, "hmx-mm-2d: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); return -1; } @@ -1083,27 +1222,27 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * size_t scratch0_size, scratch1_size, scratch2_size; scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = scratch0_size; // dequant buf 1 - scratch2_size = output_area_size; // output buf 1 + scratch1_size = use_pipeline ? scratch0_size : 0; // dequant buf 1 + scratch2_size = use_pipeline ? output_area_size : 0; // output buf 1 uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); + void *vtcm_scratch1 = scratch1_size ? vtcm_seq_alloc(&vtcm_ptr, scratch1_size) : NULL; void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; if (vtcm_used > vtcm_budget) { - FARF(ERROR, "hmx-mm-q: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + FARF(ERROR, "hmx-mm-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); return -1; } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); TIMER_DEFINE(activation_load); @@ -1114,115 +1253,137 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * TIMER_DEFINE(total); TIMER_START(total); - // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) - // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - - // A --> B: vtcm_qweight, 1 buffer - // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers - // C --> D: vtcm_output0/vtcm_output1, 2 buffers - - // Async timeline (C overlaps B+D): - // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] - // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + if (use_pipeline) { + // --- Asynchronous Pipelined Loop (Current implementation) --- + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); - } + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } - - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); - - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, weight_stride, row_stride, n_cols_A0); } - // submit C0 (non-blocking — HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + { + const float *activation_chunk = activation + mr * act_stride; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); + } - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) + { + // B0: wait for DMA, dequant weight chunk 0 dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + // A1: issue DMA for weight chunk 1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, weight_stride, row_stride, n_cols_A1); + } + + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) + if (1 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + } + } + + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + + // issue A_{i+2}: DMA push (non-blocking) + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, weight_stride, row_stride, n_cols_p2); + } + + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); + + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } + + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n, num_threads); + + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + } } } + hmx_queue_suspend(ctx->hmx_queue); + } else { + // --- Synchronous Loop (Optimized for small/non-pipelined cases) --- + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + // Load Activation + const float *activation_chunk = activation + mr * act_stride; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); - } + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's - // counterpart — and (i+1)%2 was last used by C_{i-1} which completed - // before C_i was submitted. - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); - } - - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); - - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { + // A: DMA Load Weight + const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); + + // B: Dequantize / Convert Weight + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + // C: HMX Compute (Synchronous) + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); + + // D: Output Store + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output, n_rows, n_cols, n, num_threads); } } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); } - hmx_queue_suspend(ctx->hmx_queue); - TIMER_STOP(total); #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "hex-mm-q: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); + FARF(HIGH, "hex-mm-2d: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); if (!use_pipeline) { FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); @@ -1401,11 +1562,11 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 dma_queue_pop(ctx->dma[0]); transfer_activation_chunk_threaded(ctx, vtcm_act_g, vtcm_f32_act, (int) n_rows, - params->k, params->k); + params->k, params->k, ctx->n_threads); } else { transfer_activation_chunk_threaded(ctx, vtcm_act_g, activation_chunk, (int) n_rows, - params->k, params->act_stride); + params->k, params->act_stride, ctx->n_threads); } } TIMER_STOP(activation_load); @@ -1455,7 +1616,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 TIMER_START(output_store); { float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; - transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads); } TIMER_STOP(output_store); } @@ -1475,177 +1636,431 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); #endif - return 0; + return 0; } -// - int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const __fp16 *restrict permuted_weight, int m, int k, int n, int act_stride, int weight_stride) { if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - if (act_stride < k || weight_stride < k) { return -1; } + return hmx_matmul_2d_f32(ctx, dst, activation, (const uint8_t *)permuted_weight, m, k, n, + act_stride, weight_stride * (int)sizeof(__fp16), HTP_TYPE_F16); +} + +struct mmid_row_mapping { + uint32_t i1; + uint32_t i2; +}; + +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + const struct mmid_row_mapping *matrix_rows; + int cur_a; + int mapping_stride; + int ne11; + struct fastdiv_values ne11_div; + size_t nb11; + size_t nb12; + int start_row; + int cne1; +} activation_transfer_gathered_task_state_t; + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int n_cols; + const struct mmid_row_mapping *matrix_rows; + int cur_a; + int mapping_stride; + size_t dst_nb1; + size_t dst_nb2; + int start_row; + int cne1; +} output_transfer_scattered_task_state_t; + +static void transfer_activation_chunk_fp32_to_fp16_gathered( + __fp16 *restrict vtcm_dst, + const float *restrict src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + const struct fastdiv_values * ne11_div, + size_t nb11, + size_t nb12, + int cne1) { + const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; + + int r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + int r_idx0 = start_row + r + 0; + int r_idx1 = start_row + r + 1; + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + + int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + + const float *row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + const float *row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + + const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; + const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; + + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = *pv_in1++; + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = (start_row + r + 0) < cne1; + const bool row1_valid = (start_row + r + 1) < cne1; + + const float *row0_ptr = NULL; + const float *row1_ptr = NULL; + + if (row0_valid) { + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; + int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + } + if (row1_valid) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; + int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + } + + const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; + const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; + + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); + HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_gathered_task_state_t *st = data; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + __fp16 *dst = st->dst + (size_t)(start_row - st->start_row) * st->k_block; + transfer_activation_chunk_fp32_to_fp16_gathered( + dst, st->src, start_row, n_rows, st->k_block, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1); + } +} + +static void transfer_activation_chunk_gathered_threaded( + struct htp_context *ctx, + __fp16 *dst, + const float *src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + size_t nb11, + size_t nb12, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + activation_transfer_gathered_task_state_t state = { + .dst = dst, + .src = src, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .k_block = k_block, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .ne11 = ne11, + .ne11_div = init_fastdiv_values(ne11), + .nb11 = nb11, + .nb12 = nb12, + .start_row = start_row, + .cne1 = cne1, + }; + + if (actual_threads <= 1) { + transfer_activation_chunk_gathered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_gathered_worker_fn, &state, actual_threads); + } +} + +static void transfer_output_chunk_fp16_to_fp32_scattered( + float *restrict dst, + const __fp16 *restrict vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r0 = r / HMX_FP16_TILE_N_ROWS; + const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + + int r_idx0 = start_row + (int)r + 0; + int r_idx1 = start_row + (int)r + 1; + + if (r_idx0 >= cne1) break; + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + float *output_row0 = (float *) ((uint8_t *) dst + mapping0.i1 * dst_nb1 + mapping0.i2 * dst_nb2); + + float *output_row1 = NULL; + if (r_idx1 < cne1) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + output_row1 = (float *) ((uint8_t *) dst + mapping1.i1 * dst_nb1 + mapping1.i2 * dst_nb2); + } + + #pragma unroll(4) + for (size_t c = 0; c < (size_t)n_cols; c += HMX_FP16_TILE_N_COLS) { + const size_t c0 = c / HMX_FP16_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row0 + c); + volatile HVX_Vector *pv_out1 = output_row1 ? (volatile HVX_Vector *) (output_row1 + c) : NULL; + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (pv_out1) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + } +} + +static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_scattered_task_state_t *st = data; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + const __fp16 *src = st->vtcm_src + (size_t)(start_row - st->start_row) * st->n_cols; + transfer_output_chunk_fp16_to_fp32_scattered( + st->dst, src, start_row, n_rows, st->n_cols, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->dst_nb1, st->dst_nb2, st->cne1); + } +} + +static void transfer_output_chunk_scattered_threaded( + struct htp_context *ctx, + float *dst, + const __fp16 *vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + output_transfer_scattered_task_state_t state = { + .vtcm_src = vtcm_src, + .dst = dst, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .n_cols = n_cols, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .dst_nb1 = dst_nb1, + .dst_nb2 = dst_nb2, + .start_row = start_row, + .cne1 = cne1, + }; + + if (actual_threads <= 1) { + transfer_output_chunk_scattered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_scattered_worker_fn, &state, actual_threads); + } +} + +int hmx_matmul_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride) { + const int cne1 = m; + const int m_padded = hex_align_up(m, 32); + if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; - } - - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); - - // DMA-based activation gather for strided tensors (see batched path comment). - const bool use_dma_activation = (act_stride > k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, - /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, // W + S0 + S1 - /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch - /*per_mn=*/sizeof(__fp16), // O - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n, - /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); return -1; } - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + + const int num_threads = ctx->n_threads; + + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; + + const size_t size_per_n = row_stride + vec_dot_size; + const size_t size_per_mn = sizeof(__fp16); + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, + m_padded, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m_padded * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(HIGH, "hmx-mm-id-2d: VTCM too small : m %d k %d n %d budget %zu", m_padded, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + + size_t scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-id-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); return -1; } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - FARF(HIGH, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - - TIMER_DEFINE(activation_load); - TIMER_DEFINE(weight_load); - TIMER_DEFINE(hmx_core); - TIMER_DEFINE(output_store); - - TIMER_DEFINE(total); - TIMER_START(total); + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + for (size_t mr = 0; mr < (size_t) m_padded; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m_padded - mr, m_chunk_n_rows); const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * act_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) k * sizeof(float); - const size_t stride_bytes = (size_t) act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_activation, - vtcm_f32_act, n_rows, k, k); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_activation, - activation_chunk, n_rows, k, act_stride); - } - } - TIMER_STOP(activation_load); + transfer_activation_chunk_gathered_threaded( + ctx, vtcm_activation, activation, (int) mr, (int) n_rows, k, + matrix_rows, cur_a, mapping_stride, ne11, act_nb1, act_nb2, cne1, num_threads); - const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); - - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; - - // issue async DMA for the first weight chunk - // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. - // The source rows can be strided (e.g. KV-cache K after ggml_permute). - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } - - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + for (size_t nc = 0; nc < (size_t) n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) n - nc, n_chunk_n_cols); const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready + const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); + dma_queue_pop(ctx->dma[0]); - // issue async DMA for the next weight chunk (double buffering) - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); - } + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); - // interleave row-major fp16 from scratch into tile-major in vtcm_weight - hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, k, k, 0, n_cols); - - hex_swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); - - TIMER_START(hmx_core); - { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); - } - TIMER_STOP(hmx_core); - - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); + transfer_output_chunk_scattered_threaded( + ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols, + matrix_rows, cur_a, mapping_stride, dst_nb1, dst_nb2, cne1, num_threads); } - } HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - - TIMER_STOP(total); - -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); - FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", - TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); - { - size_t weight_size = (size_t)k * n * sizeof(__fp16); - float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); - FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); - } -#endif - return 0; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.c b/ggml/src/ggml-hexagon/htp/hmx-ops.c new file mode 100644 index 000000000..114d8c148 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.c @@ -0,0 +1,6 @@ +// HMX operations compiled as a single translation unit. +// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO. + +#include "hmx-queue.c" +#include "hmx-matmul-ops.c" +#include "hmx-flash-attn-ops.c" diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index f114edb82..a67842f3f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -52,14 +52,32 @@ int hmx_matmul_f16_f32(struct htp_context *ctx, // Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); -// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4) -int hmx_matmul_q_f32(struct htp_context *ctx, +// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4) +int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const uint8_t *permuted_weight, int m, int k, int n, + int act_stride, + int weight_stride, int weight_type); +struct mmid_row_mapping; + +int hmx_matmul_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride); + // HMX flash attention int hmx_flash_attn_ext(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 51f9243ce..0f1676f07 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -79,6 +79,10 @@ struct htp_context { uint64_t max_vmem; + // Persistent DDR scratchpad for MUL_MAT_ID mappings + void * ddr_spad_base; + size_t ddr_spad_size; + struct htp_ops_context octx; #ifdef HTP_HAS_HMX diff --git a/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h b/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h new file mode 100644 index 000000000..f1f2e49e4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h @@ -0,0 +1,47 @@ +#ifndef HVX_FLASH_ATTN_H +#define HVX_FLASH_ATTN_H + +#include +#include "hvx-utils.h" + +// Scalar helper to compute a single ALiBi slope. +static inline float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) { + return (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); +} + +// Vectorized helper to compute 32 ALiBi slopes starting from (kv_head * G). +static inline HVX_Vector hvx_alibi_slopes( + uint32_t kv_head, + uint32_t G, + uint32_t n_head_log2, + float m0, + float m1 +) { + static const float ramp_32[32] __attribute__((aligned(128))) = { + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, + 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f + }; + HVX_Vector v_ramp = hvx_vmem(ramp_32); + HVX_Vector v_h_base = hvx_vec_splat_f32((float)(kv_head * G)); + HVX_Vector v_h = hvx_vec_add_f32_f32(v_h_base, v_ramp); + + // Compute exponent_m0: h + 1 + HVX_Vector v_exp_m0 = hvx_vec_add_f32_f32(v_h, hvx_vec_splat_f32(1.0f)); + + // Compute exponent_m1: 2 * (h - n_head_log2) + 1 + HVX_Vector v_n_head_log2 = hvx_vec_splat_f32((float)n_head_log2); + HVX_Vector v_h_minus = hvx_vec_sub_f32_f32(v_h, v_n_head_log2); + HVX_Vector v_exp_m1 = hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(hvx_vec_splat_f32(2.0f), v_h_minus), hvx_vec_splat_f32(1.0f)); + + // Compute powers + HVX_Vector v_pow_m0 = hvx_vec_pow_const_base_f32(m0, v_exp_m0); + HVX_Vector v_pow_m1 = hvx_vec_pow_const_base_f32(m1, v_exp_m1); + + // Select based on h < n_head_log2 + HVX_VectorPred p_cond = Q6_Q_vcmp_gt_VsfVsf(v_n_head_log2, v_h); // v_n_head_log2 > v_h <=> h < n_head_log2 + return Q6_V_vmux_QVV(p_cond, v_pow_m0, v_pow_m1); +} + +#endif /* HVX_FLASH_ATTN_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-log.h b/ggml/src/ggml-hexagon/htp/hvx-log.h new file mode 100644 index 000000000..7013dae78 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-log.h @@ -0,0 +1,65 @@ +#ifndef HVX_LOG_H +#define HVX_LOG_H + +#include "hvx-base.h" + +// Approximates ln(x) element-wise for float vectors. +// x must contain positive float elements. +// Uses Abramowitz & Stegun polynomial approximation 4.1.44 for ln(1+y) over [0, 1]. +static inline HVX_Vector hvx_vec_log_f32(HVX_Vector x) { + // x = m * 2^e, where m in [1, 2) + HVX_Vector biased_e = Q6_Vuw_vlsr_VuwR(x, 23); + HVX_Vector e_int = Q6_Vw_vsub_VwVw(biased_e, Q6_V_vsplat_R(127)); + HVX_Vector e_float = Q6_Vsf_equals_Vw(e_int); + + // Extract mantissa and set exponent to 127 (which represents float value in [1.0, 2.0)) + HVX_Vector mant_mask = Q6_V_vsplat_R(0x007FFFFF); + HVX_Vector exp_127 = Q6_V_vsplat_R(0x3F800000); + HVX_Vector m = Q6_V_vor_VV(Q6_V_vand_VV(x, mant_mask), exp_127); + + // y = m - 1.0f, y in [0, 1) + HVX_Vector y = hvx_vec_sub_f32_f32(m, hvx_vec_splat_f32(1.0f)); + + // Abramowitz & Stegun 4.1.44 polynomial approximation of ln(1+y) + HVX_Vector c; + HVX_Vector res; + + c = hvx_vec_splat_f32(-0.0064535442f); + res = hvx_vec_mul_f32_f32(y, c); + + c = hvx_vec_splat_f32(0.0360884937f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.0953293897f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.1676540711f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.2407338084f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.3317990258f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.4998741238f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.9999964239f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + // ln(x) = e * ln(2) + ln(1+y) + HVX_Vector ln2 = hvx_vec_splat_f32(0.69314718056f); + HVX_Vector term_e = hvx_vec_mul_f32_f32(e_float, ln2); + + return hvx_vec_add_f32_f32(term_e, res); +} + +#endif /* HVX_LOG_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-pow.h b/ggml/src/ggml-hexagon/htp/hvx-pow.h new file mode 100644 index 000000000..48fe0e8ea --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-pow.h @@ -0,0 +1,42 @@ +#ifndef HVX_POW_H +#define HVX_POW_H + +#include +#include "hvx-base.h" +#include "hvx-exp.h" +#include "hvx-log.h" + +// Approximates base^exponent element-wise for float vectors. +// base must be a positive constant. exponent is an HVX f32 vector. +// Uses base^x = exp(x * ln(base)). +static inline HVX_Vector hvx_vec_pow_const_base_f32(float base, HVX_Vector exponent) { + float ln_base = logf(base); + HVX_Vector ln_base_v = hvx_vec_splat_f32(ln_base); + HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base_v); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + return hvx_vec_exp_f32_guard(x, max_exp, inf); +} + +// Approximates base^exponent element-wise for float vectors. +// base and exponent are HVX f32 vectors. base elements must be positive. +// Uses base^exponent = exp(exponent * ln(base)). +static inline HVX_Vector hvx_vec_pow_f32(HVX_Vector base, HVX_Vector exponent) { + HVX_Vector ln_base = hvx_vec_log_f32(base); + HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + return hvx_vec_exp_f32_guard(x, max_exp, inf); +} + +#endif /* HVX_POW_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 0a760cd34..23373f73a 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -17,5 +17,7 @@ #include "hvx-floor.h" #include "hvx-sin-cos.h" #include "hvx-base.h" +#include "hvx-pow.h" +#include "hvx-log.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 623008be4..3715227d2 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -63,8 +64,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.type = HAP_power_set_DCVS_v3; request.dcvs_v3.set_dcvs_enable = TRUE; - request.dcvs_v3.dcvs_enable = TRUE; - request.dcvs_v3.dcvs_option = HAP_DCVS_V2_PERFORMANCE_MODE; + request.dcvs_v3.dcvs_enable = FALSE; request.dcvs_v3.set_bus_params = TRUE; request.dcvs_v3.bus_params.min_corner = HAP_DCVS_VCORNER_MAX; request.dcvs_v3.bus_params.max_corner = HAP_DCVS_VCORNER_MAX; @@ -75,6 +75,10 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX; request.dcvs_v3.set_sleep_disable = TRUE; request.dcvs_v3.sleep_disable = TRUE; + +#if (__HEXAGON_ARCH__ >= 79) + HAP_set_dcvs_v3_protected_bus_corners(&request, 1); +#endif if ((err = HAP_power_set((void *) ctx, &request)) != 0) { return err; } @@ -103,7 +107,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { FARF(ALWAYS, "Setting HMX clock\n"); err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error setting HMX clock."); + FARF(ERROR, "ggml-hex: error setting HMX clock."); return err; } } @@ -117,7 +121,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { FARF(ALWAYS, "Powering HMX on\n"); err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error powering on HMX."); + FARF(ERROR, "ggml-hex: error powering on HMX."); return err; } } @@ -423,10 +427,18 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->dma[i] = dma_queue_create(256); // queue depth } + ctx->ddr_spad_size = 512 * 1024; // 512 KB + ctx->ddr_spad_base = memalign(128, ctx->ddr_spad_size); + // init worker pool err = worker_pool_init(&ctx->worker_pool, n_hvx); if (err != AEE_SUCCESS) { FARF(ERROR, "Unable to create worker pool"); + if (ctx->ddr_spad_base) { + free(ctx->ddr_spad_base); + ctx->ddr_spad_base = NULL; + ctx->ddr_spad_size = 0; + } return err; } @@ -474,6 +486,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) { vtcm_free(ctx); + if (ctx->ddr_spad_base) { + free(ctx->ddr_spad_base); + ctx->ddr_spad_base = NULL; + ctx->ddr_spad_size = 0; + } + return AEE_SUCCESS; } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 7036c491b..5121c6f9b 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -53,6 +53,11 @@ struct htp_matmul_context { struct fastdiv_values mm_div_ne1; struct fastdiv_values mm_div_r2; struct fastdiv_values mm_div_r3; + + // Fields for scattered mapping & HMX support in MUL_MAT_ID + const uint32_t * matrix_row_counts; + const struct mmid_row_mapping * matrix_rows; + bool hmx_eligible; }; // vdelta control to expand first 32 e8m0 values into 32 uint32 elements @@ -2913,6 +2918,176 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } +#if __HVX_ARCH__ < 79 +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_Vector * restrict x = (const HVX_Vector *) vx; + const HVX_Vector * restrict y = (const HVX_Vector *) vy; + + uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors + uint32_t nloe = n % VLEN_FP32; // leftover elements + + HVX_Vector rsum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]); + rsum = HVX_OP_ADD_F32(rsum, prod); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]); + HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf); + rsum = HVX_OP_ADD_F32(rsum, prod); + } + + *s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum)); +} + +static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y = (const HVX_Vector *) vy0; + + uint32_t nvec = n / VLEN_FP32; + uint32_t nloe = n % VLEN_FP32; + + HVX_Vector rsum0 = Q6_V_vzero(); + HVX_Vector rsum1 = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector y_sf = y[i]; + HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf); + HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf); + rsum0 = HVX_OP_ADD_F32(rsum0, prod0); + rsum1 = HVX_OP_ADD_F32(rsum1, prod1); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]); + HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf); + HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf); + rsum0 = HVX_OP_ADD_F32(rsum0, prod0); + rsum1 = HVX_OP_ADD_F32(rsum1, prod1); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1); + HVX_VectorAlias va; + va.v = rsum; + s0[0] = va.fp32[0]; + s0[1] = va.fp32[1]; +} + +static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0; + const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1; + + uint32_t nvec = n / VLEN_FP32; + uint32_t nloe = n % VLEN_FP32; + + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector r0_sf = x0[i]; + HVX_Vector r1_sf = x1[i]; + HVX_Vector c0_sf = y0[i]; + HVX_Vector c1_sf = y1[i]; + + r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf)); + r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf)); + r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf)); + r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf)); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + + HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]); + HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]); + + r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf)); + r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf)); + r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf)); + r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + HVX_VectorAlias va0, va1; + va0.v = r0_r1_c0_sum; + va1.v = r0_r1_c1_sum; + s0[0] = va0.fp32[0]; + s0[1] = va0.fp32[1]; + s1[0] = va1.fp32[0]; + s1[1] = va1.fp32[1]; +} + +static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; + const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; + + uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors + uint32_t nloe = n % VLEN_FP32; // leftover elements + + HVX_Vector rsum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector x_sf = vx[i]; + HVX_Vector y_sf = vy[i]; + + rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf)); + } + + if (nloe) { + HVX_Vector x_sf = vx[i]; + HVX_Vector y_sf = vy[i]; + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + x_sf = Q6_V_vand_QV(bmask, x_sf); + y_sf = Q6_V_vand_QV(bmask, y_sf); + + rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf)); + } + + rsum = hvx_vec_reduce_sum_f32(rsum); + hvx_vec_store_u(&s[0], 4, rsum); +} + static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -3331,7 +3506,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const int is0 = (ir0 - src0_start_row); + const int is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -3466,7 +3641,7 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { const uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -3516,11 +3691,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { const uint32_t n_ids = ids->ne[0]; // n_expert_used const uint32_t n_as = ne02; // n_expert - const size_t matrix_row_counts_size = n_as * sizeof(uint32_t); - const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); - - const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0; - const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size; + const uint32_t * matrix_row_counts = mmctx->matrix_row_counts; + const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows; const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; @@ -3542,6 +3714,10 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { continue; } + if (mmctx->hmx_eligible) { + continue; + } + const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0); // Prefill spad with src0 rows @@ -3583,7 +3759,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -3685,7 +3861,7 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -4086,6 +4262,47 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; @@ -4328,6 +4545,60 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + need_quant = false; + } + } else if (src0->type == HTP_TYPE_F32) { + // Try optimized f32-f32 path first (src1 in VTCM) + const size_t f32_src1_row_size = hex_round_up(ne10 * 4, 128); + const size_t f32_src1_spad_size = hex_round_up(f32_src1_row_size * src1_nrows, 256); + const size_t f32_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f32_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; + + const size_t f32_total_size = f32_src1_spad_size + f32_src0_spad_size + f32_dst_spad_size; + + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); + + if (!is_batched && !is_permuted && f32_total_size <= octx->ctx->vtcm_size) { + // Optimized path + quant_job_func = quantize_f32_f32; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2; + + src1_row_size = f32_src1_row_size; + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + } else { + // Fallback to DDR / broadcasting + quant_job_func = NULL; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1; + matmul_job_func = matmul_4d; + + src1_row_size = nb11; + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + // Init fastdiv for matmul_4d (supports broadcasting) + mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + need_quant = false; } } else { @@ -4405,20 +4676,20 @@ int op_matmul(struct htp_ops_context * octx) { return op_matmul_hvx(octx); } - // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. + // HMX supports F16, F32, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. // Other types fall back to HVX. uint32_t wtype = src0->type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { return op_matmul_hvx(octx); } // Quantised HMX path requires K aligned to 256 (x4x2 super-block). - // F16 HMX path requires K aligned to 32 (tile width). - if (wtype != HTP_TYPE_F16 && src0->ne[0] % 256 != 0) { + // F16 and F32 HMX paths require K aligned to 32 (tile width). + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && src0->ne[0] % 256 != 0) { return op_matmul_hvx(octx); } - if (wtype == HTP_TYPE_F16 && src0->ne[0] % 32 != 0) { + if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && src0->ne[0] % 32 != 0) { return op_matmul_hvx(octx); } @@ -4463,8 +4734,8 @@ int op_matmul(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - if (src0->type == HTP_TYPE_F16) { - if (is_batched) { + if (is_batched) { + if (src0->type == HTP_TYPE_F16) { hmx_matmul_f16_f32_batched_params_t batch_params = { .dst = (float *) dst->data, .activation = (float *) src1->data, @@ -4488,13 +4759,11 @@ int op_matmul(struct htp_ops_context * octx) { }; ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); } else { - ret = hmx_matmul_f16_f32(octx->ctx, - (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, - m_total, k, n, act_stride, wgt_stride); + return op_matmul_hvx(octx); } } else { - ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, - m_total, k, n, (int) src0->type); + ret = hmx_matmul_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type); } if (ret != 0) { @@ -4539,8 +4808,30 @@ int op_matmul_id(struct htp_ops_context * octx) { size_t matrix_row_counts_size = n_as * sizeof(uint32_t); size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); + const size_t total_map_size = matrix_row_counts_size + matrix_row_map_size; + + void * mapping_buf = NULL; + bool must_free_mapping = false; + + if (octx->ctx->ddr_spad_base && total_map_size <= octx->ctx->ddr_spad_size) { + mapping_buf = octx->ctx->ddr_spad_base; + } else { + mapping_buf = memalign(128, total_map_size); + if (mapping_buf) { + must_free_mapping = true; + } else { + return HTP_STATUS_INTERNAL_ERR; + } + } + + uint32_t * matrix_row_counts = (uint32_t *) mapping_buf; + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) ((uint8_t *) mapping_buf + matrix_row_counts_size); + + mmctx->matrix_row_counts = matrix_row_counts; + mmctx->matrix_rows = matrix_rows; if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_NO_SUPPORT; } @@ -4552,7 +4843,7 @@ int op_matmul_id(struct htp_ops_context * octx) { src1_row_size = q8x4x2_row_size(ne10); } - const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + const size_t src2_spad_size_per_thread = 0; // We moved the mapping to DDR! htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; @@ -4568,6 +4859,7 @@ int op_matmul_id(struct htp_ops_context * octx) { // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -4587,9 +4879,6 @@ int op_matmul_id(struct htp_ops_context * octx) { if (src1_nrows > 1) { // initialize matrix_row_counts and map - uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0; - struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size; - memset(matrix_row_counts, 0, n_as * sizeof(uint32_t)); // group rows by src0 matrix @@ -4599,14 +4888,60 @@ int op_matmul_id(struct htp_ops_context * octx) { assert(i02 >= 0 && i02 < n_as); - MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 }; + matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 }; matrix_row_counts[i02] += 1; } } } - if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; + } + + bool hmx_eligible = false; +#ifdef HTP_HAS_HMX + if (octx->ctx->hmx_enabled && src1_nrows > 1) { + uint32_t wtype = src0->type; + if (ne01 % 32 == 0 && + (wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32 || wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || wtype == HTP_TYPE_MXFP4)) { + if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && ne00 % 32 == 0) { + hmx_eligible = true; + } else if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && ne00 % 256 == 0) { + hmx_eligible = true; + } + } + } +#endif + + mmctx->hmx_eligible = hmx_eligible; + + if (hmx_eligible) { + for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int32_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) continue; + + int ret = hmx_matmul_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, + (const uint8_t *) src0->data + cur_a * nb02, + cne1, ne00, ne01, + ne11, + nb11, nb12, + nb1, nb2, + (int) src0->nb[1], (int) src0->type, + matrix_rows, cur_a, n_ids * ids->ne[1]); + if (ret != 0) { + FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_NO_SUPPORT; + } + } + + // HMX has overwritten VTCM, so force dynamic quantization cache to clear + octx->src1_spad.src = NULL; + + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; + } if (octx->src1_spad.src != src1) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); @@ -4618,5 +4953,6 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t n_matmul_jobs = octx->n_threads; worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/pad-ops.c b/ggml/src/ggml-hexagon/htp/pad-ops.c index 3abc3c2ea..aaa72b315 100644 --- a/ggml/src/ggml-hexagon/htp/pad-ops.c +++ b/ggml/src/ggml-hexagon/htp/pad-ops.c @@ -511,6 +511,8 @@ int op_pad(struct htp_ops_context * octx) { octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; octx->src0_spad.data = octx->ctx->vtcm_base; octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->dst_spad.src = NULL; } struct htp_pad_context pctx = { diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 770a66732..71fab2cdb 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -692,6 +692,11 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * const uint8_t * restrict data_src1 = uctx->data_src1; uint8_t * restrict data_dst = uctx->data_dst; + const struct htp_tensor * src1 = (htp_op == HTP_OP_RMS_NORM_MUL) ? octx->src[1] : NULL; + const uint32_t nb11 = src1 ? src1->nb[1] : 0; + const uint32_t nb12 = src1 ? src1->nb[2] : 0; + const uint32_t nb13 = src1 ? src1->nb[3] : 0; + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); @@ -738,10 +743,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * src0_row_size_aligned, nb01, src0_data_row_size, block_size); if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { - const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03); + const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb11, nb12, nb13); dma_queue_push(dma_queue, dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + src1_off), - uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, block_size); + uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, block_size); } ir += block_size; @@ -823,10 +828,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { - const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); + const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb11, nb12, nb13); dma_queue_push(dma_queue, dma_make_ptr(src1_spad, data_src1 + src1_pref_off), - uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, pref_block_size); + uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, pref_block_size); } } } @@ -977,6 +982,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; } + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->dst_spad.src = NULL; + FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);