From 4861a3eeb5cb86df2de29c38c488e44d8dc9f6ca Mon Sep 17 00:00:00 2001 From: Yiwei Shao <44545837+njsyw1997@users.noreply.github.com> Date: Fri, 1 May 2026 20:29:13 -0700 Subject: [PATCH] hexagon: hmx flash attention (llama/22347) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * hmx: extract shared interleave headers and unify matmul batched * hmx: add HMX-accelerated flash attention for prefill * hmx: replace asm wrappers with Q6_ intrinsics in hmx-utils.h Switches three single-instruction helpers from inline asm to the matching Q6_ intrinsics, matching the style established by aizip f8737609a and used by the upstream PR #21554 hmx-matmul-ops.c rewrite: hmx_set_output_scales asm "bias=mxmem2" -> Q6_bias_mxmem2_A hmx_load_tile_pair_fp16 asm packet -> Q6_activation_hf_mxmem_RR + Q6_weight_hf_mxmem_RR hmx_consume_accumulator_fp16 asm "mxmem=acc" -> Q6_mxmem_AR_after_hf hmx_load_tiles_fp16 stays on inline asm: it uses ":deep" activation streaming, and the mixed Q6_activation_hf_mxmem_RR_deep + non-deep Q6_weight_hf_mxmem_RR pair fails the HMX backend constraint check ("activate weight pair (1) exceeds limit (1)"). The asm bundle keeps both halves in one VLIW packet and avoids the diagnostic. Functionally equivalent — same instructions emitted; the Q6_ intrinsics just give the compiler more visibility for scheduling. * hmx: drop the duplicate interleave_fp16_weight_chunk_to_tiles * hmx: apply upstream optimization to hmx-flash-attn-ops.c apply restrict, __builtin_assume, and pointer accumulation to the three HMX workers (qk_dot, o_update, o_norm) and the matching inline HMX loops in op_hmx_flash_attn_ext. * hmx: unify interleave helper * hmx: multi-thread Q load / O store and enable prefill FA dispatch Extract inline Q-load and O-store loops into worker_pool-parallel helpers (fa_phase_q_load, fa_phase_o_store) so HVX threads split the F32↔F16 conversion work across row ranges. Also relax the softmax threading gate from n_row_vec_cnt >= n_threads to >= 2, which was unnecessarily forcing single-thread fallback when n_rows_g < 512. On the dispatch side, remove the ne[2] != 1 guard that blocked multi-head (prefill) FA from reaching the HTP backend — GQA is already handled internally by both the HMX and HVX flash-attention paths. * hmx: relax matmul pipeline gate to cover k > n shapes (e.g. FFN_down) * hmx: optimize FA softmax mask phase (no-ALiBi fast path + GQA dedup) * hmx: Add an asm memory clobber at the phase boundary to prevent reorder bug * [experimental]: fp16 softmax (EXP2_HF) to accelerate fa Bake log2(e) into qk_scale and use hvx_exp2_hf directly for P and m_diff (base-2 consistent, matches htp-ops-lib). ~22 ALU ops for 64 lanes vs ~44 for the F32 round-trip path. * hmx flash-attn: refine cost model coefficients based on profiling data * hmx flash-attn: replace asm clobber with targeted volatile reads on vtcm_d_tiles * hmx flash-attn: fix prefill correctness (dst indexing, softmax reduce, V stride) * hmx flash-attn: fix p_tiles dual-tile OOB race; enable MT + pipeline * hmx flash-attn: preserve additive mask bias in no-ALiBi fast path The no-ALiBi fast path (max_bias==0) was skipping mask add entirely on the assumption that mask values are only {0, -inf}. This is wrong when the mask carries additive positional bias — those terms were silently dropped. Keep the slope-mul skip (slope≡1.0) but add mask back so the bias survives; vmux still clamps below -16 to -inf. Also add HMX FA coverage to test-backend-ops: prefill shapes (nb=64, nb=32) × {mask on/off} × {ALiBi on/off} × {softcap on/off}, F16 KV, hs ∈ {64, 128}. * hmx: fix softcap+EXP2_HF interaction, tighten matmul pipeline gate, add FA tests - flash-attn: when EXP2_HF is on AND logit_softcap is active, fold log2(e) into the post-tanh multiplier (v_cap) instead of pre-baking it into qk_scale. Pre-baking shifted the tanh knee from x≈c to x≈c/log2(e) and produced numerically wrong softcapped outputs whenever both knobs were enabled. - flash-attn softmax (fa_softmax_thread): replace the union+memcpy scalar extract pattern with HVX vmux-based per-row accumulators on rowmax/rowsum. Add hvx_vec_get_f16 helper in hvx-base.h. Functional parity, less scalar code, clearer hf/qf16 lane-format contract. - matmul (hmx_mat_mul_permuted_qk_0_d16a32): pick pipeline vs sequential layout based on whether the chunker actually yields >=2 n-chunks, instead of the static (m>=128 && n>=256) gate. Avoids paying for output double-buffer + worker dispatch when there is no HMX/HVX overlap to gain (e.g. shapes that collapse to one n-chunk). - tests: add HMX flash-attention coverage over the {mask, ALiBi (max_bias), logit_softcap} cross-product for the prefill path — head_dim 64/128, GQA 4×4, kv=512/nb=64 plus a kv=113/nb=32 non-aligned case. * [Help Wanted]: refactor D matrix computation into separate function for clarity and maintainability * format code * hexagon: looks like -O3 is causing issues with the large code base, switch to -O2 and -flto instead * hexagon: use hex_ prefix for swap_ptr * hexagon: move vtcm_seq_alloc into vtcm-utils.h More vtcm allocator updates are coming so it makes sense to start the separate hdr for it. * hmx-utils: add hmx_prefix for layout converters * hmx-mm: move main hmx_mm functions to the end, remove unused fwd decls, etc * hmx-mm: remove unused qweight_fetch_task_state_t and minor alignment fixes * hmx-fa: minor alignment fixes * hmx-fa: move hmx_flash_atten into hmx-ops.h * hmx-fa: remove redundant workpool pointer in the hmx_fa_ctx, plus minor alignment updates * hmx-fa: minor alignment and simplifications * hexagon: move FA_EXP_F16 option to hostside CMake file * hmx-fa: use hvx_vec_splat_f16 instead of fp16_to_bits * hmx-fa: add hvx_splat_u16/u8 and use that in the fa instead custom hvx_fill * hmx-fa: some more alignment updates in the core fa function * hmx-fa: keep slopes in vtcm in fp16 Saves malloc/free and removes the need for float -> fp16 downcast on every use. * hexagon: consistent noinline usage (after static) * hex-hmx: consistent use FARF_HIGH to enable debug output * hmx-utils: no need for always_inline attr * hex-hmx: consistent noinline usage (static noinline ...) * hex-hmx: simplify init_col_scales * hexagon: fix editorconfig errors * hmx-mm: minor alignment fixes --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/CMakeLists.txt | 3 +- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 3 +- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 7 + .../ggml-hexagon/htp/cmake-toolchain.cmake | 10 +- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 14 +- ggml/src/ggml-hexagon/htp/hex-utils.h | 6 + .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 1840 +++++++++++++++++ ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 1375 ++++++------ ggml/src/ggml-hexagon/htp/hmx-ops.h | 3 + ggml/src/ggml-hexagon/htp/hmx-utils.h | 192 +- ggml/src/ggml-hexagon/htp/hvx-base.h | 6 + ggml/src/ggml-hexagon/htp/hvx-copy.h | 37 +- ggml/src/ggml-hexagon/htp/vtcm-utils.h | 16 + 13 files changed, 2768 insertions(+), 744 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/vtcm-utils.h diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index f3a58354..b82bae0c 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -22,7 +22,8 @@ message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include(ExternalProject) -option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) +option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) +option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF) set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate") set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)") diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 6bb07310..df4ed101 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2254,8 +2254,7 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess return false; } - if (dst->ne[2] != 1 || dst->ne[3] != 1) { - // FA during prompt still needs work + if (dst->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 8bd52847..7c9e4cda 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -44,6 +44,11 @@ target_compile_definitions(${HTP_LIB} PRIVATE $,FARF_HIGH=1,> FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) +if (GGML_HEXAGON_FA_EXP2_HF) + message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") + target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) +endif() + # HMX acceleration: available on v73+ architectures set(HTP_HMX_VERSIONS v73 v75 v79 v81) list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) @@ -52,11 +57,13 @@ if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE hmx-queue.c hmx-matmul-ops.c + hmx-flash-attn-ops.c ) # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) set_source_files_properties( hmx-matmul-ops.c + hmx-flash-attn-ops.c PROPERTIES COMPILE_OPTIONS "-mhmx" ) diff --git a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake index 7fa236e3..ed5c1984 100644 --- a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +++ b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake @@ -138,15 +138,15 @@ set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,") set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,") #Compiler Options -set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") +set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") -set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") -set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3") +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") +set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O2") set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") -set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") -set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3") +set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") +set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O2") set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}") set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}") diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index d296a322..d95df6ac 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -17,13 +17,14 @@ #include "htp-ctx.h" #include "htp-ops.h" #include "htp-ops.h" +#include "hmx-ops.h" // Must be multiple of 32 #define FLASH_ATTN_BLOCK_SIZE (32 * 2) // This is a bit of a hack because the compiler is strugling to properly inline // the default hvx_vec_f32_to_f16 with output into the local array. -static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) +static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) { *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1); } @@ -621,6 +622,17 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } +#ifdef HTP_HAS_HMX + // HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV + if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) { + int ret = hmx_flash_attn_ext(octx); + if (ret == HTP_STATUS_OK) { + return ret; + } + // VTCM too small or other failure -> fall through to HVX path + } +#endif + struct htp_fa_context factx; factx.octx = octx; diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index 329249e1..6239ceff 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -74,6 +74,12 @@ static inline size_t hex_smax(size_t a, size_t b) { return a > b ? a : b; } +static inline void hex_swap_ptr(void ** p1, void ** p2) { + void * t = *p1; + *p1 = *p2; + *p2 = t; +} + static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); Q6_l2fetch_AP((void *) p, control); diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c new file mode 100644 index 00000000..8a6d7c14 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -0,0 +1,1840 @@ +// HMX-accelerated Flash Attention for prefill (neq1 >= 32). +// Ported from htp-ops-lib/src/dsp/ops/flash_attn.c, adapted to the htp/ codebase. + +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "hex-dma.h" +#include "hmx-profile.h" +#include "hmx-queue.h" +#include "hmx-utils.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-dump.h" +#include "hvx-reduce.h" +#include "hvx-utils.h" +#include "vtcm-utils.h" +#include "worker-pool.h" + +// ============================================================================ +// Constants +// ============================================================================ + +// Tile constants from hmx-utils.h +// HMX_FP16_TILE_N_ROWS = 32 +// HMX_FP16_TILE_N_COLS = 32 +// HMX_FP16_TILE_N_ELMS = 1024 +// HMX_FP16_TILE_SIZE = 2048 + +// ============================================================================ +// Dynamic block size computation (GQA-aware) +// ============================================================================ + +// Exact VTCM usage for a given (gqa_factor, DK, DV, Br, Bc) configuration. +// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. +// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales +// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads) { + const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); + const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] + const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong + const size_t k_dma_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K DMA: [Bc, DK] x2 double-buf + const size_t v_dma_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V DMA: [Bc, DV] x2 double-buf + const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved + const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved + const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc] + const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); // D: [g_br, g_br] + const size_t col_vec_size = hex_align_up(g_br * sizeof(__fp16), 256); // m, l, etc. + const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256); + const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128); + const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096); + const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128); + + return q_tile_size * 1 // Q tiles + + o_tile_size * 2 // O ping-pong + + k_dma_size * 2 // K DMA x2 + + v_dma_size * 2 // V DMA x2 + + k_tile_size * 1 // K tiles + + v_tile_size * 1 // V tiles + + s_tile_size * 2 // S + P + + d_tile_size * 1 // D (diagonal matrix) + + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum + + row_vec_size * 2 * n_threads // per-thread softmax row scratch + + m_buf_size * 1 // mask VTCM buffer [Br rows] + + slopes_size // Slopes + + 256 * 2; // HMX scales (id + qk) +} + +// ============================================================================ +// FP16 exp2 polynomial (ported from htp-ops-lib/include/dsp/hvx_math.h) +// ============================================================================ +// 5th-order Horner polynomial for exp2(x) in qf16/hf16 domain. Input must be +// ≤ 0 (safe softmax invariant — overflow handling omitted). ~18 ALU ops per +// 64 fp16 lanes, fully parallel across HVX threads (no scatter/gather engine). +// Replaces the F32 round-trip (qf16→f32→exp→f32→f16, ~44 ops for 2×32 lanes). +static inline HVX_Vector hvx_exp2_hf(HVX_Vector x_v) { + const HVX_Vector zero_v = Q6_V_vzero(); + const HVX_Vector half_hf_v = Q6_Vh_vsplat_R(0x3800); // fp16 0.5 + + // k = round_toward_neg_inf(x); f = (float)k; frac = x - f + HVX_Vector x_minus_half = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, half_hf_v)); + HVX_Vector k_v = Q6_Vh_equals_Vhf(x_minus_half); // truncate to int16 + HVX_Vector f_v = Q6_Vhf_equals_Vh(k_v); // back to fp16 + + HVX_Vector x_qf16 = Q6_Vqf16_vsub_VhfVhf(x_v, f_v); // fractional part in qf16 + + // Horner: y = ((((E5*x + E4)*x + E3)*x + E2)*x + E1)*x + E0 + HVX_Vector y = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), x_qf16); // E5*x + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x157d)); // + E4 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x20ed)); // + E3 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x2b1b)); // + E2 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x33b0)); // + E1 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x398c)); // + E0 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); // y = y * x + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x3c00)); // + 1.0 + + // Combine polynomial (mantissa) with integer part (exponent): result = y * 2^k + y = Q6_Vhf_equals_Vqf16(y); + HVX_Vector y_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(y, 1), 11); + y_exp = Q6_Vh_vadd_VhVh(k_v, y_exp); + HVX_VectorPred q_underflow = Q6_Q_vcmp_gt_VhVh(zero_v, y_exp); + y = Q6_Vh_vaslacc_VhVhR(y, k_v, 10); + return Q6_V_vmux_QVV(q_underflow, zero_v, y); +} + +#define FA_MIN_KV_BLOCKS 3 + +// Cost-based (Br, Bc) search for flash attention with pipeline constraint. +// +// VTCM model (same as before): +// overhead + g_br * per_gbr + g_br² * per_gbr2 + Bc * per_bc + g_br * Bc * per_gbr_bc +// +// Cost model (minimization objective): +// Q * (c_q_fixed + K * c_iter_fixed), where Q = ceil(qo/Br), K = ceil(kv/Bc) +static int hmx_fa_find_chunk_size(size_t * Br_out, + size_t * Bc_out, + size_t gqa_factor, + size_t DK, + size_t DV, + size_t qo_len, + size_t kv_len, + size_t vtcm_budget, + size_t n_threads) { + const size_t T = HMX_FP16_TILE_N_ROWS; // 32 + const size_t br_unit = hmx_ceil_div(T, gqa_factor); + // Bc must be a multiple of 64 so that n_tiles_per_bc is even. The softmax + // P-tile write uses a dual-tile pattern (vshuff + two stores 16 slots apart) + // that would race across r0 blocks if the last dual-tile is half-occupied. + // See .cursor/todos/hmx-flash-attn-bc-search-space.md for the perf trade-off. + const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64 + const size_t fp16 = sizeof(__fp16); + + // Approximate per-unit VTCM costs (without per-buffer alignment padding). + const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16; // Q + O×2 + 4 col vectors + const size_t per_gbr2 = fp16; // D diagonal matrix + const size_t per_bc = + 3 * (DK + DV) * fp16 + 2 * n_threads * fp16; // K_dma×2 + V_dma×2 + K_tile + V_tile + row bufs + const size_t per_gbr_bc = 2 * fp16; // S + P + + const size_t overhead = 256 * 2 + 13 * 4096; + + if (vtcm_budget <= overhead) { + return -1; + } + const size_t usable = vtcm_budget - overhead; + + // Br_max: largest Br aligned to br_unit that does not exceed qo_len. + const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit; + + // Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS. + // Only relax when kv_len is too short to form enough blocks. + const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); + const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) : + (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit); + // Cost coefficients calibrated from profiling + const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store + const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers + + size_t best_cost = SIZE_MAX, best_mn = 0; + size_t best_Br = 0, best_Bc = 0; + + for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) { + const size_t g_br = hex_align_up(gqa_factor * Br, T); + + // g_br-dependent VTCM cost: g_br * per_gbr + g_br² * per_gbr2 + const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2; + if (gbr_cost >= usable) { + if (Br == br_unit) { + break; + } + continue; + } + + // Analytically solve for max Bc: + // remain >= Bc * (per_bc + g_br * per_gbr_bc + Br * fp16_mask) + // The Br * fp16 term accounts for the VTCM mask buffer [Br × Bc]. + const size_t remain = usable - gbr_cost; + const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16; + size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit); + if (Bc < bc_unit) { + if (Br == br_unit) { + break; + } + continue; + } + + // Exact VTCM verification (alignment padding may push over budget) + while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads) > vtcm_budget) { + Bc -= bc_unit; + } + if (Bc < bc_unit) { + if (Br == br_unit) { + break; + } + continue; + } + + const size_t q_blocks = (qo_len + Br - 1) / Br; + const size_t kv_blocks = (kv_len + Bc - 1) / Bc; + const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed); + const size_t mn = Br * Bc; + + if (cost < best_cost || (cost == best_cost && mn > best_mn)) { + best_cost = cost; + best_mn = mn; + best_Br = Br; + best_Bc = Bc; + } + + if (Br == br_unit) { + break; + } + } + + if (best_Br == 0) { + return -1; + } + + *Br_out = best_Br; + *Bc_out = best_Bc; + return 0; +} + +// ============================================================================ +// Tile interleave / extract helpers +// ============================================================================ + +// transpose scatter offsets moved to hmx-utils.h as hmx_transpose_scatter_offsets + +// Scatter offsets for diagonal tile: entry[2i] = i*136, entry[2i+1] = i*136+6 +// 136 = 4 * 32 + 8 = byte offset to diagonal in a 32x32 fp16 interleaved tile +static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = { + 0 * 136, 0 * 136 + 6, + 1 * 136, 1 * 136 + 6, + 2 * 136, 2 * 136 + 6, + 3 * 136, 3 * 136 + 6, + 4 * 136, 4 * 136 + 6, + 5 * 136, 5 * 136 + 6, + 6 * 136, 6 * 136 + 6, + 7 * 136, 7 * 136 + 6, + 8 * 136, 8 * 136 + 6, + 9 * 136, 9 * 136 + 6, + 10 * 136, 10 * 136 + 6, + 11 * 136, 11 * 136 + 6, + 12 * 136, 12 * 136 + 6, + 13 * 136, 13 * 136 + 6, + 14 * 136, 14 * 136 + 6, + 15 * 136, 15 * 136 + 6, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, +}; + +// hmx_interleave_rows_to_tiles and hmx_interleave_cols_to_tiles are in hmx-utils.h + +// ============================================================================ +// HMX Flash Attention context (GQA-merged) +// ============================================================================ + +struct hmx_fa_context { + const struct htp_ops_context * octx; + bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2 + uint32_t n_threads; + + // Op parameters + float scale; + float max_bias; + float logit_softcap; + uint32_t n_head_log2; + float m0, m1; + + // Dimensions + uint32_t DK, DV; + uint32_t n_kv; // kv_len + uint32_t n_kv_heads; // number of KV heads + uint32_t n_heads; // number of Q heads + uint32_t G; // GQA factor = n_heads / n_kv_heads + uint32_t n_kv_blocks; + uint32_t neq1; // Q token count + + // Types + bool is_q_fp32; + bool is_dst_fp32; + + // Dynamic block sizes + uint32_t Br; // Q tokens per block (before GQA expansion) + uint32_t Bc; + uint32_t g_br; // hex_align_up(G * Br, 32) - actual tile row dim + + // VTCM buffers (allocated by vtcm_seq_alloc) + __fp16 * vtcm_q_tiles; // Q tile format [g_br, D] + __fp16 * vtcm_o_tiles[2]; // O ping-pong [g_br, D] + __fp16 * vtcm_k_fp16[2]; // K DMA double-buffer [Bc, D] + __fp16 * vtcm_v_fp16[2]; // V DMA double-buffer [Bc, D] + __fp16 * vtcm_k_tiles; // K tiles (transposed) + __fp16 * vtcm_v_tiles; // V tiles (column-major) + __fp16 * vtcm_s_tiles; // S = QK^T [g_br, Bc] + __fp16 * vtcm_p_tiles; // P = softmax(S) [g_br, Bc] + __fp16 * vtcm_d_tiles; // Diagonal rescale [g_br, g_br] + HVX_Vector * vtcm_m_vec; // Row max [g_br] + HVX_Vector * vtcm_l_vec; // Row sum [g_br] + HVX_Vector * vtcm_s_rowmax; // Softmax intermediate [g_br] + HVX_Vector * vtcm_p_rowsum; // Softmax intermediate [g_br] + HVX_Vector * vtcm_row_bufs; // Per-thread softmax row scratch [n_threads][2][Bc/64] + uint8_t * vtcm_hmx_scales_id; // HMX output scales (identity) + uint8_t * vtcm_hmx_scales_qk; // HMX output scales (qk_scale) + __fp16 * vtcm_mask_buf; // VTCM mask buffer [Br × m_line], DMA'd per KV block + __fp16 * vtcm_slopes; // ALiBi slopes [g_br] + size_t row_buf_stride; // HVX vectors per row buffer (Bc/64) + size_t mask_buf_row_stride; // elements (__fp16) per row in mask buffer + bool mask_broadcast; // true when mask->ne[2] == 1 (head-independent, single 2D DMA) +}; + +// ============================================================================ +// Multi-thread K interleave phase +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + int kv_rows; + size_t src_stride; + size_t buf_idx; +} fa_k_int_args_t; + +static void fa_k_interleave_thread(unsigned int n, unsigned int i, void * data) { + fa_k_int_args_t * args = (fa_k_int_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const int total_rows = args->kv_rows; + const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2); // ensure even (row pairs) + const int start = i * rows_per_t; + const int end = hex_smin(start + rows_per_t, total_rows); + + if (start >= total_rows) { + return; + } + + hmx_interleave_rows_to_tiles(factx->vtcm_k_tiles, factx->vtcm_k_fp16[args->buf_idx], total_rows, (int) factx->DK, + (int) args->src_stride, start, end); +} + +static void fa_phase_k_interleave(struct hmx_fa_context * factx, int kv_rows, size_t src_stride, size_t buf_idx) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_k_int_args_t args = { factx, kv_rows, src_stride, buf_idx }; + if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_k_interleave_thread, &args, factx->n_threads); + } else { + fa_k_interleave_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread V interleave phase +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + int kv_rows; + size_t src_stride; + size_t buf_idx; + size_t n_col_tiles; +} fa_v_int_args_t; + +static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) { + fa_v_int_args_t * args = (fa_v_int_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const int total_rows = args->kv_rows; + const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2); + const int start = i * rows_per_t; + const int end = hex_smin(start + rows_per_t, total_rows); + + if (start >= total_rows) { + return; + } + + hmx_interleave_cols_to_tiles(factx->vtcm_v_tiles, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, + (int) args->src_stride, (int) args->n_col_tiles, start, end); +} + +static void fa_phase_v_interleave(struct hmx_fa_context * factx, + int kv_rows, + size_t src_stride, + size_t buf_idx, + size_t n_col_tiles) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_v_int_args_t args = { factx, kv_rows, src_stride, buf_idx, n_col_tiles }; + if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_v_interleave_thread, &args, factx->n_threads); + } else { + fa_v_interleave_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread Q load phase: read Q[G × neq1, DK] from DDR, convert F32→F16 +// (or deal F16 pairs), and write interleaved into vtcm_q_tiles. +// Each thread owns a disjoint range of row pairs; writes target distinct tile +// slots (r0 selects tile row, r1 selects intra-tile slot), so there is no +// write conflict. Padding fill (when n_rows_g < g_br) is done single-threaded +// by the caller before dispatching. +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + const struct htp_tensor * q; + uint32_t q_start; + uint32_t kv_head; + uint32_t ib3; + size_t n_rows_g; +} fa_q_load_args_t; + +static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) { + fa_q_load_args_t * args = (fa_q_load_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t G = factx->G; + const size_t DK = factx->DK; + + // Partition row pairs across threads. Keep each thread's start even so r/r+1 + // are always in the same thread's range. + const size_t rows_per_t = hex_align_up(hmx_ceil_div(n_rows_g, n), 2); + const size_t start = (size_t) i * rows_per_t; + const size_t end = hex_smin(start + rows_per_t, n_rows_g); + + if (start >= n_rows_g) { + return; + } + + const struct htp_tensor * q = args->q; + const uint32_t q_start = args->q_start; + const uint32_t kv_head = args->kv_head; + const uint32_t ib3 = args->ib3; + + for (size_t r = start; r < end; r += 2) { + const bool next_row_valid = (r + 1) < n_rows_g; + + const size_t q_idx0 = (r + 0) / G; + const size_t h_idx0 = (r + 0) % G; + const size_t q_idx1 = (r + 1) / G; + const size_t h_idx1 = (r + 1) % G; + + const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] + + (kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3]; + const uint8_t * q_ptr1 = next_row_valid ? ((const uint8_t *) q->data + (q_start + q_idx1) * q->nb[1] + + (kv_head * G + h_idx1) * q->nb[2] + ib3 * q->nb[3]) : + NULL; + + size_t r0 = r / HMX_FP16_TILE_N_ROWS; + size_t r1 = r % HMX_FP16_TILE_N_ROWS; + __fp16 * out_base = factx->vtcm_q_tiles + r0 * HMX_FP16_TILE_N_ROWS * DK; + + if (factx->is_q_fp32) { + const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0; + const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL; + + for (uint32_t d = 0; d < DK / 32; ++d) { + HVX_Vector v0 = pv_in0[d]; + HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero(); + HVX_Vector v_hf = hvx_vec_f32_to_f16_shuff(v0, v1); + + HVX_Vector * out_tile = (HVX_Vector *) (out_base + d * HMX_FP16_TILE_N_ELMS); + out_tile[r1 / 2] = v_hf; + } + } else { + const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0; + const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL; + + for (uint32_t d = 0; d < DK / 64; ++d) { + HVX_Vector v0 = pv_in0[d]; + HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero(); + HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2); + + __fp16 * out_dual_tile = out_base + d * HMX_FP16_TILE_N_ELMS * 2; + HVX_Vector * pv_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2; + HVX_Vector * pv_out1 = pv_out0 + 16; + + *pv_out0 = Q6_V_lo_W(vp); + *pv_out1 = Q6_V_hi_W(vp); + } + } + } +} + +static void fa_phase_q_load(struct hmx_fa_context * factx, + const struct htp_tensor * q, + uint32_t q_start, + uint32_t kv_head, + uint32_t ib3, + size_t n_rows_g) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_q_load_args_t args = { factx, q, q_start, kv_head, ib3, n_rows_g }; + // Require >= 2 row pairs per thread so partitioning is worthwhile. + if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_q_load_thread, &args, factx->n_threads); + } else { + fa_q_load_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread O store phase: read O tiles from VTCM, convert F16->F32 (or +// deal F16 pairs), and write to strided DDR dst tensor. Each thread owns a +// disjoint row range; writes target distinct dst rows (different q_idx/h_idx +// pairs produced by r/G and r%G), so there is no write conflict. +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + const struct htp_tensor * dst; + const __fp16 * o_tile_src; + uint32_t q_start; + uint32_t kv_head; + uint32_t ib3; + size_t n_rows_g; +} fa_o_store_args_t; + +static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) { + fa_o_store_args_t * args = (fa_o_store_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t G = factx->G; + const size_t DV = factx->DV; + + const size_t rows_per_t = hmx_ceil_div(n_rows_g, n); + const size_t start = (size_t) i * rows_per_t; + const size_t end = hex_smin(start + rows_per_t, n_rows_g); + + if (start >= n_rows_g) { + return; + } + + const struct htp_tensor * dst = args->dst; + const __fp16 * o_tile_src = args->o_tile_src; + const uint32_t q_start = args->q_start; + const uint32_t kv_head = args->kv_head; + const uint32_t ib3 = args->ib3; + + for (size_t r = start; r < end; ++r) { + const size_t q_idx = r / G; + const size_t h_idx = r % G; + + // FIX(dst-indexing): ggml_flash_attn_ext() creates dst as permute(0,2,1,3) -> + // [DV, n_heads, n_tokens, n_seq], so head stride is nb[1] and token stride is nb[2]. + uint8_t * dst_row = (uint8_t *) dst->data + (kv_head * G + h_idx) * dst->nb[1] + + (q_start + q_idx) * dst->nb[2] + ib3 * dst->nb[3]; + + size_t r0 = r / HMX_FP16_TILE_N_ROWS; + size_t r1 = r % HMX_FP16_TILE_N_ROWS; + const __fp16 * tile_row_base = o_tile_src + r0 * HMX_FP16_TILE_N_ROWS * DV; + + if (factx->is_dst_fp32) { + float * out = (float *) dst_row; + for (uint32_t d = 0; d < DV / 32; ++d) { + const HVX_Vector * in_tile = (const HVX_Vector *) (tile_row_base + d * HMX_FP16_TILE_N_ELMS); + HVX_VectorPair vp = hvx_vec_f16_to_f32_shuff(in_tile[r1 / 2]); + if (r1 % 2 == 0) { + *(HVX_UVector *) (out + d * 32) = Q6_V_lo_W(vp); + } else { + *(HVX_UVector *) (out + d * 32) = Q6_V_hi_W(vp); + } + } + } else { + __fp16 * out = (__fp16 *) dst_row; + for (uint32_t d = 0; d < DV / 64; ++d) { + const __fp16 * in_dual_tile = tile_row_base + d * HMX_FP16_TILE_N_ELMS * 2; + const HVX_Vector * pv_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2; + const HVX_Vector * pv_in1 = pv_in0 + 16; + HVX_VectorPair vp = Q6_W_vdeal_VVR(*pv_in1, *pv_in0, -2); + if (r1 % 2 == 0) { + *(HVX_UVector *) (out + d * 64) = Q6_V_lo_W(vp); + } else { + *(HVX_UVector *) (out + d * 64) = Q6_V_hi_W(vp); + } + } + } + } +} + +static void fa_phase_o_store(struct hmx_fa_context * factx, + const struct htp_tensor * dst, + const __fp16 * o_tile_src, + uint32_t q_start, + uint32_t kv_head, + uint32_t ib3, + size_t n_rows_g) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_o_store_args_t args = { factx, dst, o_tile_src, q_start, kv_head, ib3, n_rows_g }; + if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_o_store_thread, &args, factx->n_threads); + } else { + fa_o_store_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread softmax phase + serial m/l update + build_D +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + size_t kv_rows; + size_t n_rows_g; + size_t n_col_tiles; + size_t n_tiles_per_bc; + size_t n_row_tiles; + size_t n_row_tiles_g_br; + uint32_t Bc; + uint32_t G; + uint32_t kv_head; + uint32_t kv_start; + uint32_t q_start; + uint32_t ib3; + bool has_alibi; // true when max_bias != 0 (need slope * mask + add) + + // ALiBi per-head slopes (indexed by GQA-merged row: slope[r] for r in [0, n_rows_g)) + // slope[r] = 1.0 when max_bias == 0 (no ALiBi) + // Pointer into hmx_fa_context.vtcm_slopes (sized to g_br) + __fp16 * slopes; + + // Mask info (preloaded before softmax) + const struct htp_tensor * mask; + const __fp16 * mask_vtcm; // VTCM mask buffer base (NULL = DDR fallback) + size_t mask_vtcm_row_stride; // elements (__fp16) per row in VTCM mask buffer +} fa_softmax_args_t; + +static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { + fa_softmax_args_t * args = (fa_softmax_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t kv_rows = args->kv_rows; + const size_t Bc = args->Bc; + const size_t G = args->G; + const size_t n_tiles_per_bc = args->n_tiles_per_bc; + const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64); + + // Partition r_vec_idx across threads + const size_t vecs_per_t = hmx_ceil_div(n_row_vec_cnt, n); + const size_t vec_start = i * vecs_per_t; + const size_t vec_end = hex_smin(vec_start + vecs_per_t, n_row_vec_cnt); + + if (vec_start >= n_row_vec_cnt) { + return; + } + + // Per-thread row scratch: thread i uses bufs at offset i * 2 * stride + const size_t row_buf_stride = factx->row_buf_stride; + HVX_Vector * my_row_buf0 = factx->vtcm_row_bufs + i * 2 * row_buf_stride; + HVX_Vector * my_row_buf1 = my_row_buf0 + row_buf_stride; + + const HVX_Vector v_neg_inf = Q6_Vh_vsplat_R(0xfbff); + + // Per-row accumulators: each fp16 lane in a 64-lane vector holds one row's scalar. + // CONTRACT: lane bits must be IEEE fp16 (hf), never qf16 — qf16 uses a different + // bit layout, so a later hf-domain read would silently produce wrong values. + // Convert first via Q6_Vhf_equals_Vqf16(). For reference: vtcm_m_vec/vtcm_s_rowmax + // are hf; vtcm_l_vec is qf16 — don't mix them up. + + for (size_t r_vec_idx = vec_start; r_vec_idx < vec_end; ++r_vec_idx) { + HVX_Vector rowmax_acc_v = v_neg_inf; + HVX_Vector rowsum_acc_v = Q6_V_vzero(); + HVX_Vector m_prev_v = factx->vtcm_m_vec[r_vec_idx]; + + for (int r_vec_off = 0; r_vec_off < 64; r_vec_off += 2) { + int r = r_vec_idx * 64 + r_vec_off; + if (r >= (int) hex_align_up(n_rows_g, 2)) { + break; + } + + int r0 = r / HMX_FP16_TILE_N_ROWS; + int r1 = r % HMX_FP16_TILE_N_ROWS; + + const __fp16 * s_ld_base = factx->vtcm_s_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc; + __fp16 * p_st_base = factx->vtcm_p_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc; + + // Decode 2 rows from S tiles into per-thread row buffers + HVX_Vector * pv_row_buf0 = my_row_buf0; + HVX_Vector * pv_row_buf1 = my_row_buf1; + for (size_t c = 0; c < kv_rows; c += 64) { + const __fp16 * in_dual_tile = s_ld_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2; + const HVX_Vector * pv_s_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2; + const HVX_Vector * pv_s_in1 = pv_s_in0 + 16; + + HVX_VectorPair vp_s_dual_row = Q6_W_vdeal_VVR(*pv_s_in1, *pv_s_in0, -2); + *pv_row_buf0++ = Q6_V_lo_W(vp_s_dual_row); + *pv_row_buf1++ = Q6_V_hi_W(vp_s_dual_row); + } + + // Apply softcap if enabled (in F32 precision) + if (factx->logit_softcap != 0.0f) { + // When EXP2_HF is on, fold log2(e) into v_cap so the output lands in + // log2(e)-scaled space for the downstream exp2. log2(e) is kept OUT + // of qk_scale in this configuration (see scale setup) so tanh sees + // the physical QK/(√d·c) argument. + float cap = factx->logit_softcap; +#ifdef HMX_FA_USE_EXP2_HF + cap *= 1.44269504f; // log2(e) +#endif + const HVX_Vector v_cap = hvx_vec_splat_f32(cap); + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + + HVX_VectorPair r0_f32 = hvx_vec_f16_to_f32(my_row_buf0[ci]); + HVX_Vector t0_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r0_f32)); + HVX_Vector t0_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r0_f32)); + t0_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_lo, v_cap)); + t0_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_hi, v_cap)); + my_row_buf0[ci] = hvx_vec_f32_to_f16(t0_lo, t0_hi); + + HVX_VectorPair r1_f32 = hvx_vec_f16_to_f32(my_row_buf1[ci]); + HVX_Vector t1_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r1_f32)); + HVX_Vector t1_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r1_f32)); + t1_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_lo, v_cap)); + t1_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_hi, v_cap)); + my_row_buf1[ci] = hvx_vec_f32_to_f16(t1_lo, t1_hi); + } + } + + // Apply mask & compute rowmax(S) + // + // Optimizations over baseline: + // A. No-ALiBi fast path: when max_bias==0 (slope≡1.0), skip the + // slope multiplication — still add mask (additive bias) but + // avoid the mul_f16_f16. Saves 2 ops/dual-row vs ALiBi path. + // B. GQA mask row dedup: G consecutive Q rows share one mask row + // (qi = r / G). Reuse mask vector when qi is unchanged between + // row0 and row1 (saves ~75% of VTCM loads for G=4). + + // ALiBi slopes — only needed when has_alibi (scheme A) + HVX_Vector v_slope0, v_slope1; + if (args->has_alibi) { + v_slope0 = hvx_vec_splat_f16(args->slopes[r + 0]); + v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_splat_f16(args->slopes[r + 1]) : Q6_V_vzero(); + } + + const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c) + + HVX_Vector v_s_rowmax0 = v_neg_inf; + HVX_Vector v_s_rowmax1 = v_neg_inf; + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + const size_t ne = hex_smin(kv_rows - c, 64); + HVX_VectorPred q_tail_keep = Q6_Q_vsetq2_R(ne * sizeof(__fp16)); + + if (args->mask) { + HVX_Vector v_mask0, v_mask1; + + if (args->mask_vtcm) { + // Read mask from VTCM buffer (DMA'd per KV block). + // GQA dedup (scheme B): skip load when qi unchanged. + const size_t qi0 = (r + 0) / G; + v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c); + v_mask1 = v_neg_inf; + if (r + 1 < (int) n_rows_g) { + const size_t qi1 = (r + 1) / G; + if (qi1 == qi0) { + v_mask1 = v_mask0; // scheme B: reuse — same mask row + } else { + v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c); + } + } + } else { + // Fallback: read mask directly from DDR (when mask->ne[2] > 1). + const struct htp_tensor * mask = args->mask; + const size_t q_idx0 = args->q_start + ((r + 0) / G); + const size_t h_idx0 = args->kv_head * G + (r + 0) % G; + const uint32_t im2_0 = h_idx0 % mask->ne[2]; + const uint32_t im3_0 = args->ib3 % mask->ne[3]; + + const __fp16 * m0_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx0 * mask->nb[1] + + im2_0 * mask->nb[2] + im3_0 * mask->nb[3]) + args->kv_start + c; + v_mask0 = *(const HVX_UVector *) m0_ptr; + v_mask1 = v_neg_inf; + + if (r + 1 < (int) n_rows_g) { + const size_t q_idx1 = args->q_start + ((r + 1) / G); + if (q_idx1 == q_idx0) { + // scheme B: same mask row in DDR path + v_mask1 = v_mask0; + } else { + const size_t h_idx1 = args->kv_head * G + (r + 1) % G; + const uint32_t im2_1 = h_idx1 % mask->ne[2]; + const uint32_t im3_1 = args->ib3 % mask->ne[3]; + const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] + + im2_1 * mask->nb[2] + im3_1 * mask->nb[3]) + args->kv_start + c; + v_mask1 = *(const HVX_UVector *) m1_ptr; + } + } + } + + // Threshold: mask values below -16.0 are treated as -inf (causal mask). + HVX_VectorPred q_keep0 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask0, v_threshold), q_tail_keep); + HVX_VectorPred q_keep1 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask1, v_threshold), q_tail_keep); + + if (args->has_alibi) { + // ALiBi path: S += slope * mask (full mul + add) + HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0); + HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1); + my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf); + } else { + // No-ALiBi fast path (scheme A): slope≡1.0, skip the mul + // but still add mask (additive positional bias). vmux + // clamps mask < -16 to -inf as a numerical safeguard. + my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_mask0), v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_mask1), v_neg_inf); + } + } else { + if (ne < 64) { + my_row_buf0[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf0[ci], v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf1[ci], v_neg_inf); + } + } + + v_s_rowmax0 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax0, my_row_buf0[ci]); + v_s_rowmax1 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax1, my_row_buf1[ci]); + } + + v_s_rowmax0 = hvx_vec_reduce_max_f16(v_s_rowmax0); + v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1); + + // Splat m_prev[r], m_prev[r+1] from the per-row accumulator. + // vror brings the target lane to lane 0, then extract + re-splat. + HVX_Vector v_m_prev0 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2))); + HVX_Vector v_m_prev1 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2))); + + // HVX max — both operands are splats, so result is splat of m_new. + HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0); + HVX_Vector v_dup_m1 = Q6_Vhf_vmax_VhfVhf(v_m_prev1, v_s_rowmax1); + + // Insert row r, r+1 rowmax into rowmax_acc_v via 2-byte-wide vmux. + // Byte ranges: lane0 = [r_vec_off*2 .. r_vec_off*2+1], lane1 shifted by 2. + // vsetq2 handles the n=128 corner case when r_vec_off reaches 62. + { + HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2); + HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2); + HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2); + HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start); + HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid); + rowmax_acc_v = Q6_V_vmux_QVV(p_lane0, v_dup_m0, rowmax_acc_v); + rowmax_acc_v = Q6_V_vmux_QVV(p_lane1, v_dup_m1, rowmax_acc_v); + } + + // Compute P = exp(S - m_new), using HVX exp + const HVX_Vector v_zero = Q6_V_vzero(); + HVX_Vector v_p_rowsum0 = v_zero; + HVX_Vector v_p_rowsum1 = v_zero; + +#ifdef HMX_FA_USE_EXP2_HF + // FP16 exp2 polynomial path (matches htp-ops-lib flash_attn.c): + // P = exp2(S - m_new) + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0); + HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1); + + HVX_Vector v_p_row0_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m0)); + HVX_Vector v_p_row1_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m1)); +#else + // F32 exp path: qf16 → f32 → exp → f32 → f16. Higher precision, + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0); + HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1); + + HVX_VectorPair vp0 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m0)); + HVX_Vector p0_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp0)); + HVX_Vector p0_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp0)); + HVX_Vector v_p_row0_hf = hvx_vec_f32_to_f16_shuff(p0_lo, p0_hi); + + HVX_VectorPair vp1 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m1)); + HVX_Vector p1_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp1)); + HVX_Vector p1_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp1)); + HVX_Vector v_p_row1_hf = hvx_vec_f32_to_f16_shuff(p1_lo, p1_hi); +#endif + // Write P to tile format. Dual-tile pattern assumes Bc is a + // multiple of 64 (enforced by bc_unit=64 in hmx_fa_find_chunk_size), + // so both tile halves are always in the current r0 block. + __fp16 * out_dual_tile = p_st_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2; + HVX_Vector * pv_p_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2; + HVX_Vector * pv_p_out1 = pv_p_out0 + 16; + + HVX_VectorPair vp_p_dual = Q6_W_vshuff_VVR(v_p_row1_hf, v_p_row0_hf, -2); + *pv_p_out0 = Q6_V_lo_W(vp_p_dual); + *pv_p_out1 = Q6_V_hi_W(vp_p_dual); + + HVX_VectorPair vp_p0 = hvx_vec_f16_to_f32_shuff(v_p_row0_hf); + HVX_VectorPair vp_p1 = hvx_vec_f16_to_f32_shuff(v_p_row1_hf); + + v_p_rowsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum0, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p0), Q6_V_hi_W(vp_p0))); + v_p_rowsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum1, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p1), Q6_V_hi_W(vp_p1))); + } + + HVX_Vector rowsum0_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum0)); + HVX_Vector rowsum1_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum1)); + { + // Both inputs are f32 splats, so the f32->f16 output is an fp16 splat. + HVX_Vector rv0_v = hvx_vec_f32_to_f16(rowsum0_sf, rowsum0_sf); + HVX_Vector rv1_v = hvx_vec_f32_to_f16(rowsum1_sf, rowsum1_sf); + + HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2); + HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2); + HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2); + HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start); + HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid); + rowsum_acc_v = Q6_V_vmux_QVV(p_lane0, rv0_v, rowsum_acc_v); + rowsum_acc_v = Q6_V_vmux_QVV(p_lane1, rv1_v, rowsum_acc_v); + } + } + + factx->vtcm_s_rowmax[r_vec_idx] = rowmax_acc_v; + factx->vtcm_p_rowsum[r_vec_idx] = rowsum_acc_v; + } +} + +// Serial m/l update + build_D. Must run after softmax barrier (s_rowmax written by all threads). +// +// noinline: function boundary acts as a hard compiler barrier so the (size_t)addr scatter +// intrinsics inside cannot be hoisted past the call site. Mirrors the structural protection +// matmul gets for free via worker_pool function-pointer dispatch. Without this, the compiler +// can reorder the scatter past the subsequent hmx_queue_push and the HMX-queue worker thread +// reads stale VTCM (PPL → ~vocab-size). +static __attribute__((noinline)) void fa_ml_update_and_build_d(struct hmx_fa_context * factx, + size_t n_rows_g, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + // Reuse s_rowmax buffer for exp(m_diff) — safe because softmax is fully complete + HVX_Vector * const mvec_exp_m_diff = factx->vtcm_s_rowmax; + + const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64); + for (size_t i = 0; i < n_row_vec_cnt; ++i) { + HVX_Vector v_m_prev = factx->vtcm_m_vec[i]; + HVX_Vector v_m_curr = Q6_Vhf_vmax_VhfVhf(v_m_prev, factx->vtcm_s_rowmax[i]); + HVX_Vector v_m_diff = Q6_Vqf16_vsub_VhfVhf(v_m_prev, v_m_curr); + +#ifdef HMX_FA_USE_EXP2_HF + // Base-2 path: must match P = exp2(S - m_new) in fa_softmax_thread. + HVX_Vector v_exp_m_diff = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_m_diff)); +#else + HVX_VectorPair vp_diff = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_m_diff)); + HVX_Vector exp_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp_diff)); + HVX_Vector exp_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp_diff)); + HVX_Vector v_exp_m_diff = hvx_vec_f32_to_f16_shuff(exp_lo, exp_hi); +#endif + + HVX_Vector v_l_curr = Q6_Vqf16_vmpy_Vqf16Vhf(factx->vtcm_l_vec[i], v_exp_m_diff); + v_l_curr = Q6_Vqf16_vadd_Vqf16Vhf(v_l_curr, factx->vtcm_p_rowsum[i]); + + factx->vtcm_m_vec[i] = v_m_curr; + factx->vtcm_l_vec[i] = v_l_curr; + mvec_exp_m_diff[i] = v_exp_m_diff; + } + + // Build diagonal tile D = diag(exp(m_diff)) + const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets; + const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16)); + for (size_t i = 0; i < n_row_tiles; ++i) { + const HVX_Vector v_content = Q6_V_vror_VR(mvec_exp_m_diff[i / 2], (i % 2) * 64); + __fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content); + // Compiler barrier — Q6_vscatter takes (size_t)addr; without this the + // compiler may not recognize the volatile read below as aliasing and + // could reorder it before the scatter, defeating the HW drain. + __asm__ __volatile__("" ::: "memory"); + // Per-tile drain: scatter regions are disjoint (stride > tile size), + // so a single drain at tile 0 does NOT retire later tiles' entries. + (void) *(volatile HVX_Vector *) out_base; + } +} + +// Build D = diag(1/l) tile for the final O = D @ O normalization. +// +// noinline: same rationale as fa_ml_update_and_build_d — keeps Q6_vscatter from +// being hoisted past the subsequent hmx_queue_push at the o_norm call site. +static __attribute__((noinline)) void fa_build_d_diag_inv_l(struct hmx_fa_context * factx, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets; + const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16)); + const HVX_Vector one = hvx_vec_splat_f32(1.0f); + + HVX_Vector v_content = Q6_V_vzero(); + for (size_t i = 0; i < n_row_tiles; ++i) { + if ((i % 2) == 0) { + HVX_Vector v_l_hf = Q6_Vhf_equals_Vqf16(factx->vtcm_l_vec[i / 2]); + HVX_VectorPair vp_l = hvx_vec_f16_to_f32_shuff(v_l_hf); + HVX_Vector inv_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_lo_W(vp_l)))); + HVX_Vector inv_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_hi_W(vp_l)))); + v_content = hvx_vec_f32_to_f16_shuff(inv_lo, inv_hi); + } else { + v_content = Q6_V_vror_VR(v_content, 64); + } + + __fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content); + // Compiler barrier — see fa_ml_update_and_build_d for rationale. + __asm__ __volatile__("" ::: "memory"); + (void) *(volatile HVX_Vector *) out_base; + } +} + +// Combined: multi-thread softmax -> barrier -> serial m/l update + build_D +static void fa_phase_softmax_and_build_d(struct hmx_fa_context * factx, + fa_softmax_args_t * sargs, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + const size_t n_row_vec_cnt = hmx_ceil_div(sargs->n_rows_g, 64); + + if (factx->n_threads > 1 && n_row_vec_cnt >= 2) { + uint32_t n_use = (uint32_t) hex_smin((size_t) factx->n_threads, n_row_vec_cnt); + worker_pool_run_func(wp, fa_softmax_thread, sargs, n_use); + } else { + fa_softmax_thread(1, 0, sargs); + } + // barrier implicit in worker_pool_run_func return + + fa_ml_update_and_build_d(factx, sargs->n_rows_g, n_row_tiles, n_row_tiles_g_br); +} + +// ============================================================================ +// HMX job structs and worker functions +// ============================================================================ + +typedef struct { + const __fp16 * q_tiles; + const __fp16 * k_tiles; + __fp16 * s_tiles; + size_t n_row_tiles; + size_t n_col_tiles; + size_t n_dot_tiles; // DK / 32 + size_t n_tiles_per_bc; + uint8_t * hmx_scales; +} hmx_fa_qk_job_t; + +static void hmx_fa_qk_dot_worker(void * data) { + hmx_fa_qk_job_t * job = (hmx_fa_qk_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_col_tiles = job->n_col_tiles; + const size_t n_dot_tiles = job->n_dot_tiles; + const size_t n_tiles_per_bc = job->n_tiles_per_bc; + const __fp16 * restrict q_tiles = job->q_tiles; + const __fp16 * restrict k_tiles = job->k_tiles; + __fp16 * restrict s_tiles = job->s_tiles; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + const __fp16 * row_tiles = q_tiles + r * HMX_FP16_TILE_N_ROWS * n_dot_tiles * HMX_FP16_TILE_N_COLS; + const __fp16 * col_tiles = k_tiles + c * HMX_FP16_TILE_N_COLS * n_dot_tiles * HMX_FP16_TILE_N_COLS; + __fp16 * out_tile = s_tiles + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS; + + for (size_t k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } +} + +typedef struct { + __fp16 * o_curr; + const __fp16 * o_prev; + const __fp16 * p_tiles; + const __fp16 * v_tiles; + const __fp16 * d_tiles; + uint8_t * hmx_scales; + size_t n_row_tiles; + size_t n_col_tiles; + size_t n_row_tiles_g_br; + size_t n_tiles_per_bc; + size_t DV; +} hmx_fa_o_update_job_t; + +static void hmx_fa_o_update_worker(void * data) { + hmx_fa_o_update_job_t * job = (hmx_fa_o_update_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_col_tiles = job->n_col_tiles; + const size_t n_row_tiles_g_br = job->n_row_tiles_g_br; + const size_t n_tiles_per_bc = job->n_tiles_per_bc; + const size_t DV_tiles = job->DV / 32; + const __fp16 * restrict d_tiles = job->d_tiles; + const __fp16 * restrict p_tiles = job->p_tiles; + const __fp16 * restrict v_tiles = job->v_tiles; + const __fp16 * restrict o_prev = job->o_prev; + __fp16 * restrict o_curr = job->o_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + // D[r,r] @ O_prev[r,c] — only the diagonal tile + const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + + // P @ V (accumulate on same accumulator) + const __fp16 * p_tile_in = p_tiles + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + const __fp16 * v_tile_in = v_tiles + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_col_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047); + p_tile_in += HMX_FP16_TILE_N_ELMS; + v_tile_in += HMX_FP16_TILE_N_ELMS; + } + + __fp16 * o_tile_out = o_curr + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(o_tile_out, 0); + } + } +} + +typedef struct { + __fp16 * o_curr; // output (row-major tile layout) + const __fp16 * o_prev; // input (column-major tile layout) + const __fp16 * d_tiles; // diag(1/l) tiles + uint8_t * hmx_scales; + size_t n_row_tiles; + size_t n_row_tiles_g_br; + size_t DV; +} hmx_fa_o_norm_job_t; + +static void hmx_fa_o_norm_worker(void * data) { + hmx_fa_o_norm_job_t * job = (hmx_fa_o_norm_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_row_tiles_g_br = job->n_row_tiles_g_br; + const size_t DV_tiles = job->DV / 32; + const __fp16 * restrict d_tiles = job->d_tiles; + const __fp16 * restrict o_prev = job->o_prev; + __fp16 * restrict o_curr = job->o_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + __fp16 * o_out = o_curr + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS; + + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + Q6_mxmem_AR_after_hf(o_out, 0); + } + } +} + +// Populate per-GQA-row ALiBi slopes for a given KV head. +// Row r in the GQA-merged block maps to Q head h = kv_head * G + r % G. +// slope(h) = m0^(h+1) when h < n_head_log2, else m1^(2*(h-n_head_log2)+1). +// When max_bias == 0, all slopes are 1.0 (no ALiBi). +static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sargs, + const struct hmx_fa_context * factx, + uint32_t kv_head, + size_t n_rows_g) { + if (factx->max_bias == 0.0f) { + for (size_t r = 0; r < n_rows_g; ++r) { + sargs->slopes[r] = 1.0f; + } + return; + } + + const uint32_t G = factx->G; + const uint32_t n_head_log2 = factx->n_head_log2; + const float m0 = factx->m0; + const float m1 = factx->m1; + + for (size_t r = 0; r < n_rows_g; ++r) { + const uint32_t h = kv_head * G + r % G; + sargs->slopes[r] = (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); + } +} + +// ============================================================================ +// Core HMX flash attention algorithm (GQA-merged) +// ============================================================================ + +int hmx_flash_attn_ext(struct htp_ops_context * octx) { + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = (octx->src[3] && octx->src[3]->data) ? octx->src[3] : NULL; + const struct htp_tensor * dst = octx->dst; + + struct htp_context * const ctx = octx->ctx; + + if (!ctx->hmx_enabled) { + return HTP_STATUS_NO_SUPPORT; + } + + // Dimensions + const uint32_t neq0 = q->ne[0]; // head_dim (DK) + const uint32_t neq1 = q->ne[1]; // n_tokens + const uint32_t neq2 = q->ne[2]; // n_heads + const uint32_t neq3 = q->ne[3]; // n_seqs + + const uint32_t nek0 = k->ne[0]; // head_dim + const uint32_t nek1 = k->ne[1]; // kv_len + + const uint32_t nev0 = v->ne[0]; // head_dim (DV) + + const uint32_t DK = neq0; + const uint32_t DV = nev0; + + // HMX requires head_dim to be multiple of 32 + if (DK % 32 != 0 || DV % 32 != 0) { + return HTP_STATUS_NO_SUPPORT; + } + if (neq1 < 32) { + return HTP_STATUS_NO_SUPPORT; + } + + // GQA factor + const uint32_t n_kv_heads = k->ne[2]; + const uint32_t G = neq2 / n_kv_heads; + + // Thread count for multi-thread HVX phases + const uint32_t n_threads = octx->n_threads; + + // Compute dynamic block sizes (GQA-aware, accounting for per-thread row bufs) + size_t Br, Bc; + const size_t vtcm_budget = ctx->vtcm_size; + if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads) != 0) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); + + const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; + const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2); + + FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", + neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); + + // ======== Build context ======== + struct hmx_fa_context factx; + memset(&factx, 0, sizeof(factx)); + factx.octx = octx; + factx.n_threads = octx->ctx->n_threads; + factx.DK = DK; + factx.DV = DV; + factx.n_kv = nek1; + factx.n_kv_heads = n_kv_heads; + factx.n_heads = neq2; + factx.G = G; + factx.neq1 = neq1; + factx.Br = (uint32_t) Br; + factx.Bc = (uint32_t) Bc; + factx.g_br = (uint32_t) g_br; + factx.n_kv_blocks = n_kv_blocks; + factx.is_q_fp32 = (q->type == HTP_TYPE_F32); + factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32); + factx.use_pipeline = use_pipeline; + factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1); + + // Extract op parameters (mutable during softcap adjustment, then stored as const in factx) + float scale = 1.0f, max_bias = 0.0f, logit_softcap = 0.0f; + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + +#ifdef HMX_FA_USE_EXP2_HF + // Pre-bake log2(e) into qk_scale so HMX-produced S tiles are in log2(e)-scaled + // space. Then exp2(S - m) in the softmax equals base-e exp((S - m) / log2(e)), + // preserving ggml's base-e softmax semantics. Matches htp-ops-lib flash_attn.c. + // + // When softcap is active we cannot pre-bake log2(e) here — it would land inside + // the tanh argument and shift the softcap knee from x≈c to x≈c/log2(e), giving + // numerically wrong softcapped values. Instead fold log2(e) into the post-tanh + // multiplier (see softcap block: v_cap absorbs log2(e)). + if (logit_softcap == 0.0f) { + scale *= 1.44269504f; // log2(e) + } +#endif + + factx.scale = scale; + factx.max_bias = max_bias; + factx.logit_softcap = logit_softcap; + + factx.n_head_log2 = 1u << (uint32_t) floor(log2(neq2)); + factx.m0 = powf(2.0f, -(max_bias) / factx.n_head_log2); + factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + + // ======== VTCM allocation (GQA-aware) ======== + const size_t q_tile_bytes = hex_align_up(g_br * DK * sizeof(__fp16), 4096); + const size_t o_tile_bytes = hex_align_up(g_br * DV * sizeof(__fp16), 4096); + const size_t k_dma_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); + const size_t v_dma_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t k_tile_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); + const size_t v_tile_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t s_tile_bytes = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); + const size_t d_tile_bytes = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); + const size_t col_vec_bytes = hex_align_up(g_br * sizeof(__fp16), 256); + const size_t row_vec_bytes = hex_align_up(Bc * sizeof(__fp16), 256); + const size_t m_line_bytes = hex_align_up(Bc * sizeof(__fp16), 128); + const size_t m_buf_bytes = hex_align_up(Br * m_line_bytes, 4096); + const size_t slopes_bytes = hex_align_up(g_br * sizeof(__fp16), 128); + + uint8_t * vtcm_cur = ctx->vtcm_base; + + factx.vtcm_q_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, q_tile_bytes); + factx.vtcm_o_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes); + factx.vtcm_o_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes); + factx.vtcm_k_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes); + factx.vtcm_k_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes); + factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); + factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); + factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); + factx.vtcm_v_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); + factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); + factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes); + factx.vtcm_m_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_l_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_s_rowmax = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_p_rowsum = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_row_bufs = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, row_vec_bytes * 2 * n_threads); + factx.row_buf_stride = row_vec_bytes / sizeof(HVX_Vector); + factx.vtcm_hmx_scales_id = vtcm_seq_alloc(&vtcm_cur, 256); + factx.vtcm_hmx_scales_qk = vtcm_seq_alloc(&vtcm_cur, 256); + factx.vtcm_mask_buf = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, m_buf_bytes); + factx.mask_buf_row_stride = m_line_bytes / sizeof(__fp16); + factx.vtcm_slopes = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, slopes_bytes); + + if ((size_t) (vtcm_cur - ctx->vtcm_base) > ctx->vtcm_size) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + // ======== Initialize HMX output scales ======== + // Identity scale (1.0) for O updates and normalization + hmx_init_column_scales(factx.vtcm_hmx_scales_id, Q6_V_vsplat_R(0x3c00)); // 1.0 + + // QK scale embedded in HMX output + hmx_init_column_scales(factx.vtcm_hmx_scales_qk, hvx_vec_splat_f16(factx.scale)); + + // ======== Skip compute if profiling ======== + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + // Profiling timers + TIMER_DEFINE(total); + TIMER_DEFINE(q_load); + TIMER_DEFINE(kv_dma); + TIMER_DEFINE(k_interleave); + TIMER_DEFINE(v_interleave); + TIMER_DEFINE(qk_dot); + TIMER_DEFINE(softmax); + TIMER_DEFINE(o_update); + TIMER_DEFINE(o_norm); + TIMER_DEFINE(o_store); + + TIMER_START(total); + + // ======== DMA setup ======== + dma_queue * const dma = ctx->dma[0]; + + // Padded row sizes for DMA + const size_t size_k_row = nek0 * sizeof(__fp16); + const size_t size_v_row = nev0 * sizeof(__fp16); + const size_t size_k_row_padded = hex_round_up(nek0 * sizeof(__fp16), 128); + const size_t size_v_row_padded = hex_round_up(nev0 * sizeof(__fp16), 128); + + const size_t n_row_tiles_g_br = g_br / HMX_FP16_TILE_N_ROWS; + const size_t n_tiles_per_bc = Bc / HMX_FP16_TILE_N_COLS; + + // Q/O element size for Q load and O store + const size_t qo_element_size = factx.is_q_fp32 ? sizeof(float) : sizeof(__fp16); + + // ======== HMX lock strategy ======== + // Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend. + // Fallback: main thread holds the lock (original behavior). + if (!factx.use_pipeline) { + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + } + + // ======== Reusable job descriptors for pipeline ======== + hmx_fa_qk_job_t qk_job; + hmx_fa_o_update_job_t ou_job; + hmx_fa_o_norm_job_t on_job; + + // ======== Main loop: per batch, per KV head, per Q block ======== + for (uint32_t ib3 = 0; ib3 < neq3; ++ib3) { + for (uint32_t kv_head = 0; kv_head < n_kv_heads; ++kv_head) { + const uint32_t ik2 = kv_head; + const uint32_t ik3 = ib3 / (neq3 / k->ne[3]); + const uint32_t iv2 = kv_head; + const uint32_t iv3 = ib3 / (neq3 / v->ne[3]); + + for (uint32_t q_start = 0; q_start < neq1; q_start += Br) { + const uint32_t n_q_rows = hex_smin(Br, neq1 - q_start); + const size_t n_rows_g = n_q_rows * G; + const size_t g_br_actual = hex_align_up(n_rows_g, HMX_FP16_TILE_N_ROWS); + const size_t n_row_tiles = g_br_actual / HMX_FP16_TILE_N_ROWS; + + // ---- Load Q block [g_br, D] -> tiles, interleaving G heads ---- + TIMER_START(q_load); + if (n_rows_g < g_br) { + hvx_splat_u8_a(factx.vtcm_q_tiles, 0, q_tile_bytes); + } + fa_phase_q_load(&factx, q, q_start, kv_head, ib3, n_rows_g); + TIMER_STOP(q_load); + + // ---- Initialize per-block state ---- + hvx_splat_u8_a(factx.vtcm_l_vec, 0, col_vec_bytes); + hvx_splat_u8_a(factx.vtcm_d_tiles, 0, d_tile_bytes); + hvx_splat_u16_a(factx.vtcm_m_vec, 0xfbff, col_vec_bytes/2); + + __fp16 * o_tile_prev = factx.vtcm_o_tiles[0]; + __fp16 * o_tile_curr = factx.vtcm_o_tiles[1]; + hvx_splat_u8_a(o_tile_prev, 0, o_tile_bytes); + + // ---- KV block loop with DMA double-buffering ---- + size_t buf_idx = 0; + + // Prefetch first KV block + if (factx.n_kv_blocks > 0) { + const uint32_t kv_rows0 = hex_smin(Bc, nek1); + + const uint8_t * k_src = (const uint8_t *) k->data + ik2 * k->nb[2] + ik3 * k->nb[3]; + dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[0], k_src), size_k_row_padded, k->nb[1], + size_k_row, kv_rows0); + + const uint8_t * v_src = (const uint8_t *) v->data + iv2 * v->nb[2] + iv3 * v->nb[3]; + dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[0], v_src), size_v_row_padded, v->nb[1], + size_v_row, kv_rows0); + } + + // Mask DMA: single 2D transfer of n_q_rows unique mask rows into VTCM buffer. + // Only when mask is head-broadcast (ne[2]==1); otherwise softmax reads DDR directly. + #define MASK_DMA_PUSH(kv_start_val, kv_rows_val, has_mask_dma_var) \ + do { \ + has_mask_dma_var = false; \ + if (mask && factx.mask_broadcast) { \ + const uint32_t _im3 = ib3 % mask->ne[3]; \ + const uint8_t * _ms = (const uint8_t *) mask->data + q_start * mask->nb[1] + _im3 * mask->nb[3] + \ + (kv_start_val) * sizeof(__fp16); \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_mask_buf, _ms), m_line_bytes, mask->nb[1], \ + (kv_rows_val) * sizeof(__fp16), n_q_rows); \ + has_mask_dma_var = true; \ + } \ + } while (0) + + #define MASK_DMA_POP(has_mask_dma_var) \ + do { \ + if (has_mask_dma_var) { \ + dma_queue_pop(dma); \ + } \ + } while (0) + + #define DMA_PREFETCH_KV(blk_val) \ + do { \ + if ((blk_val) < factx.n_kv_blocks) { \ + const uint32_t _ns = (blk_val) * Bc; \ + const uint32_t _nr = hex_smin(Bc, nek1 - _ns); \ + size_t _nb = 1 - buf_idx; \ + const uint8_t * _ks = (const uint8_t *) k->data + _ns * k->nb[1] + ik2 * k->nb[2] + ik3 * k->nb[3]; \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[_nb], _ks), size_k_row_padded, k->nb[1], size_k_row, _nr); \ + const uint8_t * _vs = (const uint8_t *) v->data + _ns * v->nb[1] + iv2 * v->nb[2] + iv3 * v->nb[3]; \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[_nb], _vs), size_v_row_padded, v->nb[1], size_v_row, _nr); \ + } \ + } while (0) + + const size_t k_src_stride = size_k_row_padded / sizeof(__fp16); + const size_t v_src_stride = size_v_row_padded / sizeof(__fp16); + + if (factx.use_pipeline) { + // ================================================================== + // Pipeline path: HVX phases ‖ HMX queue worker + // ================================================================== + struct hmx_queue * hmx_q = ctx->hmx_queue; + + for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) { + const uint32_t kv_start = kv_blk * Bc; + const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start); + const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS); + + // Wait for current KV DMA + TIMER_START(kv_dma); + dma_queue_pop(dma); // K + dma_queue_pop(dma); // V + TIMER_STOP(kv_dma); + + // Push mask DMA for this block (single 2D DMA when broadcast) + bool has_mask_dma = false; + MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma); + + // ---- Phase 1: K_int(blk) ‖ O_update(blk-1) ---- + if (kv_blk > 0) { + // Submit O_update for previous block (HMX worker) + ou_job.o_curr = o_tile_curr; + ou_job.o_prev = o_tile_prev; + ou_job.p_tiles = factx.vtcm_p_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.d_tiles = factx.vtcm_d_tiles; + ou_job.hmx_scales = factx.vtcm_hmx_scales_id; + ou_job.n_row_tiles = n_row_tiles; + ou_job.n_col_tiles = hmx_ceil_div(hex_smin(Bc, nek1 - (kv_blk - 1) * Bc), HMX_FP16_TILE_N_COLS); + ou_job.n_row_tiles_g_br = n_row_tiles_g_br; + ou_job.n_tiles_per_bc = n_tiles_per_bc; + ou_job.DV = DV; + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job)); + } + + TIMER_START(k_interleave); + fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); + TIMER_STOP(k_interleave); + + if (kv_blk > 0) { + hmx_queue_pop(hmx_q); + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + // ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ---- + qk_job.q_tiles = factx.vtcm_q_tiles; + qk_job.k_tiles = factx.vtcm_k_tiles; + qk_job.s_tiles = factx.vtcm_s_tiles; + qk_job.n_row_tiles = n_row_tiles; + qk_job.n_col_tiles = n_col_tiles; + qk_job.n_dot_tiles = DK / 32; + qk_job.n_tiles_per_bc = n_tiles_per_bc; + qk_job.hmx_scales = factx.vtcm_hmx_scales_qk; + TIMER_START(qk_dot); + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_qk_dot_worker, &qk_job)); + + // DMA push next block (non-blocking, before worker_pool) + DMA_PREFETCH_KV(kv_blk + 1); + + TIMER_START(v_interleave); + fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); + TIMER_STOP(v_interleave); + + hmx_queue_pop(hmx_q); + TIMER_STOP(qk_dot); + + // ---- Phase 3: softmax(blk) + build_D(blk) | HMX idle ---- + // Pop mask DMA before softmax (ensures VTCM buffer is ready) + MASK_DMA_POP(has_mask_dma); + + fa_softmax_args_t sargs; + memset(&sargs, 0, sizeof(sargs)); + sargs.factx = &factx; + sargs.kv_rows = kv_rows; + sargs.n_rows_g = n_rows_g; + sargs.n_col_tiles = n_col_tiles; + sargs.n_tiles_per_bc = n_tiles_per_bc; + sargs.n_row_tiles = n_row_tiles; + sargs.n_row_tiles_g_br = n_row_tiles_g_br; + sargs.Bc = Bc; + sargs.G = G; + sargs.kv_head = kv_head; + sargs.kv_start = kv_start; + sargs.q_start = q_start; + sargs.ib3 = ib3; + sargs.has_alibi = (factx.max_bias != 0.0f); + sargs.mask = mask; + sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; + sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; + sargs.slopes = factx.vtcm_slopes; + fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); + + TIMER_START(softmax); + fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); + TIMER_STOP(softmax); + + buf_idx = 1 - buf_idx; + } // end KV block loop (pipeline) + + // Epilogue: O_update for last block + if (factx.n_kv_blocks > 0) { + const uint32_t last_blk = factx.n_kv_blocks - 1; + const size_t last_cols = hmx_ceil_div(hex_smin(Bc, nek1 - last_blk * Bc), HMX_FP16_TILE_N_COLS); + ou_job.o_curr = o_tile_curr; + ou_job.o_prev = o_tile_prev; + ou_job.p_tiles = factx.vtcm_p_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.d_tiles = factx.vtcm_d_tiles; + ou_job.hmx_scales = factx.vtcm_hmx_scales_id; + ou_job.n_row_tiles = n_row_tiles; + ou_job.n_col_tiles = last_cols; + ou_job.n_row_tiles_g_br = n_row_tiles_g_br; + ou_job.n_tiles_per_bc = n_tiles_per_bc; + ou_job.DV = DV; + + TIMER_START(o_update); + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job)); + hmx_queue_pop(hmx_q); + TIMER_STOP(o_update); + + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + } else { + // ================================================================== + // Fallback path: sequential with multi-thread HVX phases + // Main thread holds HMX lock, runs HMX inline. + // ================================================================== + + for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) { + const uint32_t kv_start = kv_blk * Bc; + const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start); + const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS); + + TIMER_START(kv_dma); + dma_queue_pop(dma); // K + dma_queue_pop(dma); // V + TIMER_STOP(kv_dma); + + bool has_mask_dma = false; + MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma); + DMA_PREFETCH_KV(kv_blk + 1); + + // K interleave (multi-thread HVX) + TIMER_START(k_interleave); + fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); + TIMER_STOP(k_interleave); + + // QK dot (inline HMX on main thread) + TIMER_START(qk_dot); + { + const size_t n_dot_tiles = (size_t) (DK / 32); + const __fp16 * restrict q_base = factx.vtcm_q_tiles; + const __fp16 * restrict k_base = factx.vtcm_k_tiles; + __fp16 * restrict s_base = factx.vtcm_s_tiles; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_qk); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + const __fp16 * row_tiles = q_base + r * HMX_FP16_TILE_N_ROWS * DK; + const __fp16 * col_tiles = k_base + c * HMX_FP16_TILE_N_COLS * DK; + __fp16 * out_tile = s_base + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } + } + TIMER_STOP(qk_dot); + + // Pop mask DMA + MASK_DMA_POP(has_mask_dma); + + // Softmax + build_D (multi-thread HVX + serial m/l update) + fa_softmax_args_t sargs; + memset(&sargs, 0, sizeof(sargs)); + sargs.factx = &factx; + sargs.kv_rows = kv_rows; + sargs.n_rows_g = n_rows_g; + sargs.n_col_tiles = n_col_tiles; + sargs.n_tiles_per_bc = n_tiles_per_bc; + sargs.n_row_tiles = n_row_tiles; + sargs.n_row_tiles_g_br = n_row_tiles_g_br; + sargs.Bc = Bc; + sargs.G = G; + sargs.kv_head = kv_head; + sargs.kv_start = kv_start; + sargs.q_start = q_start; + sargs.ib3 = ib3; + sargs.has_alibi = (factx.max_bias != 0.0f); + sargs.mask = mask; + sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; + sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; + sargs.slopes = factx.vtcm_slopes; + fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); + + TIMER_START(softmax); + fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); + TIMER_STOP(softmax); + + // V interleave (multi-thread HVX) + TIMER_START(v_interleave); + // FIX(v-stride): use n_tiles_per_bc (block-invariant) as V tile layout + // stride to match o_update's v_tile access. Using per-block n_col_tiles + // misplaces DV_tile 1..3 in the last partial KV block. + fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); + TIMER_STOP(v_interleave); + + // O update (inline HMX on main thread) + TIMER_START(o_update); + { + const size_t DV_tiles = (size_t) (DV / 32); + const __fp16 * restrict d_base = factx.vtcm_d_tiles; + const __fp16 * restrict p_base = factx.vtcm_p_tiles; + const __fp16 * restrict v_base = factx.vtcm_v_tiles; + const __fp16 * restrict op_base = o_tile_prev; + __fp16 * restrict oc_base = o_tile_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + + const __fp16 * p_tile_in = p_base + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + const __fp16 * v_tile_in = v_base + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_col_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047); + p_tile_in += HMX_FP16_TILE_N_ELMS; + v_tile_in += HMX_FP16_TILE_N_ELMS; + } + + __fp16 * o_tile_out = oc_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(o_tile_out, 0); + } + } + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + TIMER_STOP(o_update); + + buf_idx = 1 - buf_idx; + } // end KV block loop (fallback) + } + + // ---- Final normalization: O = diag(1/l) @ O ---- + TIMER_START(o_norm); + { + fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br); + + // HMX: O_final = diag(1/l) @ O_prev + if (factx.use_pipeline) { + on_job.o_curr = o_tile_curr; + on_job.o_prev = o_tile_prev; + on_job.d_tiles = factx.vtcm_d_tiles; + on_job.hmx_scales = factx.vtcm_hmx_scales_id; + on_job.n_row_tiles = n_row_tiles; + on_job.n_row_tiles_g_br = n_row_tiles_g_br; + on_job.DV = DV; + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_fa_o_norm_worker, &on_job)); + hmx_queue_pop(ctx->hmx_queue); + } else { + const size_t DV_tiles = (size_t) (DV / 32); + const __fp16 * restrict d_base = factx.vtcm_d_tiles; + const __fp16 * restrict op_base = o_tile_prev; + __fp16 * restrict oc_base = o_tile_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + __fp16 * o_out = oc_base + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS; + + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + Q6_mxmem_AR_after_hf(o_out, 0); + } + } + } + } + TIMER_STOP(o_norm); + + // ---- Store O block ---- + TIMER_START(o_store); + fa_phase_o_store(&factx, dst, o_tile_curr, q_start, kv_head, ib3, n_rows_g); + TIMER_STOP(o_store); + +#undef MASK_DMA_PUSH +#undef MASK_DMA_POP +#undef DMA_PREFETCH_KV + + } // end Q block loop + } // end KV head loop + } // end batch loop + + if (factx.use_pipeline) { + hmx_queue_suspend(ctx->hmx_queue); + } else { + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + } + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "hmx-fa: %lld us, q_load=%lld kv_dma=%lld k_interleave=%lld v_interleave=%lld", TIMER_US(total), + TIMER_US(q_load), TIMER_US(kv_dma), TIMER_US(k_interleave), TIMER_US(v_interleave)); + FARF(HIGH, " qk_dot=%lld softmax=%lld o_update=%lld o_norm=%lld o_store=%lld", TIMER_US(qk_dot), TIMER_US(softmax), + TIMER_US(o_update), TIMER_US(o_norm), TIMER_US(o_store)); +#endif + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 05e3c6c2..2666a78a 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -28,6 +28,8 @@ #include "hmx-queue.h" #include "hmx-profile.h" +#include "vtcm-utils.h" + static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, }; @@ -43,40 +45,11 @@ static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, }; -// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile. -// word[i] = i*128 maps K-row-pair i to byte offset i*128 in the tile. -// Column offset (n*4) is added at runtime. Only entries 0..15 are used (masked by predicate). -static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = { - 0*128, 1*128, 2*128, 3*128, 4*128, 5*128, 6*128, 7*128, - 8*128, 9*128, 10*128, 11*128, 12*128, 13*128, 14*128, 15*128, - 16*128, 17*128, 18*128, 19*128, 20*128, 21*128, 22*128, 23*128, - 24*128, 25*128, 26*128, 27*128, 28*128, 29*128, 30*128, 31*128 -}; - // Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes #define HMX_X4X2_SCALES_PER_BLK 8 #define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL) #define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4) -static inline void swap_ptr(void **p1, void **p2) { - void *t = *p1; - *p1 = *p2; - *p2 = t; -} - -typedef struct { - uint8_t *dst; - const uint8_t *src; - dma_queue *dma; - size_t n_rows; - size_t src_stride; // DDR row stride (full row_stride) - size_t dst_stride; // VTCM sub-block row stride - size_t quant_off; // quant byte offset in each DDR row - size_t quant_width; // quant bytes to copy per row - size_t scale_off; // scale byte offset in each DDR row - size_t scale_width; // scale bytes to copy per row -} qweight_fetch_task_state_t; - // Compute the byte stride of one row in x4x2 format. // Numerically equals ggml_row_size(type, k) when k is 256-aligned, because // x4x2 packing has the same density as block_q4_0 / block_q8_0. @@ -202,46 +175,6 @@ next_nc: return 0; } -// forward declaration – defined after transfer_activation_chunk_fp32_to_fp16 -void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride); - -// Scatter row-major FP16 weight (already in VTCM scratch) directly into transposed [K][N] tiles. -// vtcm_src: [n_cols][k] row-major fp16 in VTCM scratch buffer -// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16 -static void interleave_fp16_weight_chunk_to_tiles(__fp16 *restrict vtcm_dst, - const __fp16 *restrict vtcm_src, - int n_cols, int k) { - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - assert(k % HMX_FP16_TILE_N_COLS == 0); - - const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; - const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - for (int r = 0; r < n_cols; r += 2) { - int ct = r / HMX_FP16_TILE_N_ROWS; // N-dimension tile index - int local_r = r % HMX_FP16_TILE_N_ROWS; // intra-tile row index - const bool next_row_valid = (r + 1) < n_cols; - - // Offset vectors for N-columns local_r and local_r+1, reused across K-tiles. - HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); - HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); - - for (int c = 0; c < k; c += HMX_FP16_TILE_N_COLS) { - int kt = c / HMX_FP16_TILE_N_COLS; - int tile_idx = ct * n_k_tiles + kt; - __fp16 *tile_base = vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS; - - HVX_Vector v0 = hvx_vmemu(vtcm_src + r * k + c); - HVX_Vector v1 = next_row_valid ? hvx_vmemu(vtcm_src + (r + 1) * k + c) : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off0, v0); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off1, v1); - } - } -} - // --- x4x2 format dequantizers --- // Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. @@ -303,8 +236,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( } // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. -static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx( - const int8_t *quants_32, const __fp16 *scale) { +static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { HVX_Vector vq = hvx_vmemu(quants_32); HVX_Vector v_scales = hvx_vec_splat_f16(*scale); HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); @@ -414,8 +346,8 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128 // maps to K-rows 2i and 2i+1. Column offset (n*4) added per row. - const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes) unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index @@ -658,12 +590,12 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; state.n_tot_tiles = n_tot_tiles; state.n_tiles_per_task = n_tiles_per_task; - state.dst = vtcm_dst; - state.src = (const uint8_t *)vtcm_src; - state.n_cols = n_cols; - state.k_block = k_block; - state.row_stride = row_stride; - state.weight_type = weight_type; + state.dst = vtcm_dst; + state.src = (const uint8_t *)vtcm_src; + state.n_cols = n_cols; + state.k_block = k_block; + state.row_stride = row_stride; + state.weight_type = weight_type; worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads); } @@ -733,7 +665,7 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, job->n_dot_tiles = n_dot_tiles; } -// --- End async HMX matmul job --- +// output : fp16 -> f32p static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); @@ -807,421 +739,304 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads); } -static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { - return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; -} +// activations : fp32 -> fp16 -static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { - return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; -} +static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) { + for (int r = 0; r < n_rows; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx -static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, - int dst_b2, int dst_b3) { - const int r2 = hmx_matmul_batch_r2(params); - const int r3 = hmx_matmul_batch_r3(params); - return (const __fp16 *) ((const uint8_t *) params->permuted_weight + - (size_t) (dst_b2 / r2) * params->src0_nb2 + - (size_t) (dst_b3 / r3) * params->src0_nb3); -} + const bool next_row_valid = (r + 1) < n_rows; -static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, - int dst_b2, int dst_b3) { - return (const float *) ((const uint8_t *) params->activation + - (size_t) dst_b2 * params->src1_nb2 + - (size_t) dst_b3 * params->src1_nb3); -} + const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); + const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero(); -static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, - int dst_b2, int dst_b3) { - return (float *) ((uint8_t *) params->dst + - (size_t) dst_b2 * params->dst_nb2 + - (size_t) dst_b3 * params->dst_nb3); -} + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); -static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params) { - int ret = 0; - for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { - for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { - ret = hmx_mat_mul_permuted_w16a32(ctx, - hmx_matmul_dst_batch_ptr(params, b2, b3), - hmx_matmul_activation_batch_ptr(params, b2, b3), - hmx_matmul_weight_batch_ptr(params, b2, b3), - params->m, params->k, params->n, - params->act_stride, params->weight_stride); + // compute output position + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; } } - return ret; } -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { - if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } - if (!params->m || !params->k || !params->n) { return -1; } - if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } - if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } - if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } - if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + int k_stride; +} activation_transfer_task_state_t; - if (!hex_is_aligned(params->dst, VLEN) || - !hex_is_aligned(params->activation, VLEN) || - !hex_is_aligned(params->permuted_weight, VLEN)) { +static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + // one chunk: one row + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + __fp16 *dst = st->dst + chunk_idx * st->k_block; + const float *src = st->src + chunk_idx * st->k_stride; + transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); + } +} + +static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { + assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); + assert(VLEN == 32 * sizeof(float)); + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address + + activation_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.src = src; + state.k_block = k_block; + state.k_stride = k_stride; + + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); +} + +// + +#define FALLBACK_TO_STANDARD 1 + +// C += AB +static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, + const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, + int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)col_scales); + + const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t i = 0; i < n_row_tiles; ++i) { + const __fp16 *row_base = a + i * dot_tile_stride; + __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t j = 0; j < n_col_tiles; ++j) { + Q6_mxclracc_hf(); + + const __fp16 *col_tiles = b + j * dot_tile_stride; + const __fp16 *row_tiles = row_base; + __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; + if (!zero_init) { + Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); + } + + for (int k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(accum_tile, 0); + } + } +} + +static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, + float *restrict out, const float *restrict x, const uint8_t *restrict w, + int m, int k, int n, int weight_type) { + // assume k % 32 == 0 && n % 32 == 0 + const size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { return -1; } - const int group_size = hmx_matmul_batch_r2(params); + const size_t vtcm_budget = ctx->vtcm_size; - if (group_size <= 1) { - FARF(MEDIUM, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + const size_t K_BLOCK_SIZE = 1024; + + // Fallback: if k doesn't need K-blocking, out-stationary has no advantage + const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; + if (k_iters_check <= 1) { + FARF(HIGH, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); + return FALLBACK_TO_STANDARD; } - // Grouped path: reuse interleaved weight across all q_heads sharing a - // kv_head. Each q_head gets its own activation buffer in VTCM (so - // activation is loaded once per m_chunk and reused across all n_chunks), - // and each q_head is computed individually to avoid tile-major packing - // issues. m_chunk_n_rows is always a multiple of 32 (from - // hmx_compute_chunks), so per-head tile arrays don't overlap. - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = params->k * sizeof(__fp16); + // Dynamic M,N search via hmx_compute_chunks + const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); + const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) + + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) + const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) + + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) + const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) - // When the activation has a large stride (e.g. permuted Q tensor with - // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. - // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather - // strided rows into a contiguous block before the F32->F16 conversion. - const bool use_dma_activation = (params->act_stride > params->k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; + // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; + // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin + const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; + const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, - /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, - /*per_mn=*/sizeof(__fp16), params->m, params->n, - /*m_block_cost=*/(size_t) params->n, - /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); - } - - const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; - - if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { - FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); - } - - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - FARF(MEDIUM, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, params->m, params->k, params->n, group_size, params->ne13, - m_chunk_n_rows, n_chunk_n_cols, - (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - - TIMER_DEFINE(activation_load); - TIMER_DEFINE(weight_load); - TIMER_DEFINE(hmx_core); - TIMER_DEFINE(output_store); - TIMER_DEFINE(total); - - TIMER_START(total); - - const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); - - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (int b3 = 0; b3 < params->ne13; ++b3) { - for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { - const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); - - for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); - - // Pre-load activations for all heads in the group (once per m_chunk). - // When the source is strided (permuted Q), use 2D DMA to gather - // contiguous rows into a VTCM scratch buffer first, then HVX - // converts from the contiguous VTCM buffer. This avoids L2 cache - // thrashing from HVX loads at large strides. - TIMER_START(activation_load); - for (int g = 0; g < group_size; ++g) { - const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; - __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) params->k * sizeof(float); - const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_act_g, - vtcm_f32_act, (int) n_rows, - params->k, params->k); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_act_g, - activation_chunk, (int) n_rows, - params->k, params->act_stride); - } - } - TIMER_STOP(activation_load); - - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; - - { - const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } - - for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); - - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); - - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < (size_t) params->n) { - const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); - } - - interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k); - swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); - - // Reuse the interleaved weight for every q_head in this GQA group - for (int g = 0; g < group_size; ++g) { - TIMER_START(hmx_core); - { - const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, - params->k / 32); - } - TIMER_STOP(hmx_core); - - TIMER_START(output_store); - { - float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; - transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); - } - TIMER_STOP(output_store); - } - } - } - } - } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - - TIMER_STOP(total); - -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total), - params->m, params->k, params->n, group_size); - FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", - TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); -#endif - - return 0; -} - -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, - const __fp16 *restrict permuted_weight, int m, int k, int n, - int act_stride, int weight_stride) { - if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - if (act_stride < k || weight_stride < k) { return -1; } - if (k % 32 != 0 || n % 32 != 0) { return -1; } - - if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; - } - - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); - - // DMA-based activation gather for strided tensors (see batched path comment). - const bool use_dma_activation = (act_stride > k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, - /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, // W + S0 + S1 - /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch - /*per_mn=*/sizeof(__fp16), // O - m, n, - /*m_block_cost=*/(size_t) n, - /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; + // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. + // From profiling: wt_dequant per element ≈ 1.5× activation load per element. + // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). + // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). + const size_t m_block_cost = (size_t) n * 3; + const size_t n_block_cost = (size_t) m * 2; + if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE, + &N_BLOCK_SIZE, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); return -1; } - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + // Compute precise buffer sizes from searched M,N and fixed K + const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); + const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); - // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; + if (total_vtcm > vtcm_budget) { + FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, + vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); return -1; } + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); + uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); + uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); + __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); + + FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, + M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + + // initialize eye tile (32x32 identity matrix) + { + HVX_Vector v; + v = Q6_V_vzero(); + v = Q6_Vw_vinsert_VwR(v, 0x3c000000); + v = Q6_V_vror_VR(v, VLEN - 4); + v = Q6_Vw_vinsert_VwR(v, 0x00003c00); + for (int i = 0; i < 16; ++i) { + ((HVX_Vector *) vtcm_eye_tile)[i] = v; + v = Q6_V_vror_VR(v, VLEN - 8); + } + } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(MEDIUM, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - - TIMER_DEFINE(activation_load); - TIMER_DEFINE(weight_load); - TIMER_DEFINE(hmx_core); - TIMER_DEFINE(output_store); - - TIMER_DEFINE(total); - TIMER_START(total); + TIMER_DEFINE(fetch); + TIMER_DEFINE(act_load); + TIMER_DEFINE(wt_dequant); + TIMER_DEFINE(core); HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { + size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); + for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { + size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * act_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) k * sizeof(float); - const size_t stride_bytes = (size_t) act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_activation, - vtcm_f32_act, n_rows, k, k); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_activation, - activation_chunk, n_rows, k, act_stride); - } - } - TIMER_STOP(activation_load); + const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); + const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); - const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); + for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { + const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; + TIMER_START(fetch); + // fetch activation block into VTCM + { + const float *activation_block = x + mr * k + kk; - // issue async DMA for the first weight chunk - // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. - // The source rows can be strided (e.g. KV-cache K after ggml_permute). - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } - - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready - - // issue async DMA for the next weight chunk (double buffering) - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_scratch1, activation_block), + k_blk_sz * sizeof(float), + k * sizeof(float), + k_blk_sz * sizeof(float), + m_blk_sz); } - // interleave row-major fp16 from scratch into tile-major in vtcm_weight - interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *)buf_curr, n_cols, k); + // fetch weight block into VTCM (x4x2 sub-block: quants + scales) + const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); + { + const int blk_start = kk / QK_Q4_0x4x2; + const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; + const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); + const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; + uint8_t *dst = vtcm_scratch0; + const uint8_t *src = w + nc * row_stride; + const size_t n_rows = n_blk_sz; + const size_t src_stride = row_stride; + const size_t dst_stride = sub_row_stride; + const size_t quant_off = (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); + const size_t quant_width = (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); + const size_t scale_off = full_qrow + blk_start * scale_blk_size; + const size_t scale_width = nb_sub * scale_blk_size; - swap_ptr(&buf_curr, &buf_next); + // 2D DMA: quants sub-range + dma_queue_push(ctx->dma[0], dma_make_ptr(dst, src + quant_off), dst_stride, src_stride, quant_width, n_rows); + // 2D DMA: scales sub-range + dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width, src + scale_off), dst_stride, src_stride, scale_width, n_rows); + } + TIMER_STOP(fetch); + + TIMER_START(act_load); + // load activation block + { + dma_queue_pop(ctx->dma[0]); // wait for act DNA + transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); + } + TIMER_STOP(act_load); + + TIMER_START(wt_dequant); + // dequantize weight block + { + dma_queue_pop(ctx->dma[0]); + dma_queue_pop(ctx->dma[0]); + // vtcm_scratch0 is used to store the qweight chunk + // worker_pool_run_func already returned, so fetch is done + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, + n_blk_sz, k_blk_sz, sub_row_stride, weight_type); + } + TIMER_STOP(wt_dequant); + + // core mma + TIMER_START(core); + { + core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, + n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); + } + TIMER_STOP(core); } - TIMER_STOP(weight_load); - TIMER_START(hmx_core); + // store output block { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); + float *output_block = out + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); } - TIMER_STOP(hmx_core); - - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); } - } HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - TIMER_STOP(total); - #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); - FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", - TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); - { - size_t weight_size = (size_t)k * n * sizeof(__fp16); - float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); - FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); - } + FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", + TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); #endif - return 0; } -int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m, - int k, int n, int w_type); - -#define FALLBACK_TO_STANDARD 1 - int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, int weight_type) { @@ -1238,7 +1053,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds if (rc != FALLBACK_TO_STANDARD) { return rc; // 0 success, -1 error } - FARF(MEDIUM, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); + FARF(HIGH, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); // fall through to standard path } @@ -1247,31 +1062,46 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds return -1; } - FARF(MEDIUM, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); + FARF(HIGH, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); // --- Dynamic VTCM layout --- const size_t vtcm_budget = ctx->vtcm_size; const size_t vec_dot_size = k * sizeof(__fp16); - const bool use_pipeline = (m >= 128) && (k <= n); - // Select cost parameters based on execution path - size_t per_n_cost, per_mn_cost; - if (use_pipeline) { - per_n_cost = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - per_mn_cost = 2 * sizeof(__fp16); // O x 2 (output double buffer) - } else { - per_n_cost = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) - per_mn_cost = sizeof(__fp16); // O x 1 - } + // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. + // Only pays off when the chunker yields >=2 n-chunks, so the main loop can + // overlap HMX (C) with HVX (B/D); with a single n-chunk the extra VTCM for + // double-buffered output and the worker-dispatch overhead are pure loss. + // Try pipeline costs first; fall back to sequential if the layout collapses + // to one n-chunk. m >= 128 floor keeps HMX utilization reasonable. + const size_t pipe_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) + const size_t pipe_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) + const size_t seq_per_n = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) + const size_t seq_per_mn = sizeof(__fp16); // O x 1 size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // Quantized weight: dequant ~1.5x more expensive per element than activation load. - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost, m, n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d pipe=%d budget=%zu)", - __func__, m, k, n, use_pipeline, vtcm_budget); - return -1; + bool use_pipeline = false; + + if (m >= 128) { + size_t mc = 0, nc = 0, used = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, m, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 && + hmx_ceil_div((size_t) n, nc) >= 2) { + m_chunk_n_rows = mc; + n_chunk_n_cols = nc; + vtcm_used = used; + use_pipeline = true; + } + } + + if (!use_pipeline) { + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, m, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } } // Compute precise buffer sizes per execution path @@ -1308,7 +1138,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", + FARF(HIGH, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", __func__, m, k, n, weight_type, use_pipeline, m_chunk_n_rows, n_chunk_n_cols, (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); @@ -1321,7 +1151,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds TIMER_DEFINE(total); TIMER_START(total); - FARF(MEDIUM, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", + FARF(HIGH, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); @@ -1368,7 +1198,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight. dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type); - swap_ptr(&buf_curr, &buf_next); + hex_swap_ptr(&buf_curr, &buf_next); } TIMER_STOP(weight_load); @@ -1511,300 +1341,417 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds return 0; } -// C += AB -void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, - int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(n_dot_tiles > 0); +// - Q6_bias_mxmem2_A((void *)col_scales); +static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { + return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; +} - const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; - for (size_t i = 0; i < n_row_tiles; ++i) { - const __fp16 *row_base = a + i * dot_tile_stride; - __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; - for (size_t j = 0; j < n_col_tiles; ++j) { - Q6_mxclracc_hf(); +static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { + return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; +} - const __fp16 *col_tiles = b + j * dot_tile_stride; - const __fp16 *row_tiles = row_base; - __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; - if (!zero_init) { - Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); - } +static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + const int r2 = hmx_matmul_batch_r2(params); + const int r3 = hmx_matmul_batch_r3(params); + return (const __fp16 *) ((const uint8_t *) params->permuted_weight + + (size_t) (dst_b2 / r2) * params->src0_nb2 + + (size_t) (dst_b3 / r3) * params->src0_nb3); +} - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; - } - Q6_mxmem_AR_after_hf(accum_tile, 0); +static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (const float *) ((const uint8_t *) params->activation + + (size_t) dst_b2 * params->src1_nb2 + + (size_t) dst_b3 * params->src1_nb3); +} + +static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (float *) ((uint8_t *) params->dst + + (size_t) dst_b2 * params->dst_nb2 + + (size_t) dst_b3 * params->dst_nb3); +} + +static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, + const hmx_matmul_w16a32_batched_params_t *params) { + int ret = 0; + for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { + for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { + ret = hmx_mat_mul_permuted_w16a32(ctx, + hmx_matmul_dst_batch_ptr(params, b2, b3), + hmx_matmul_activation_batch_ptr(params, b2, b3), + hmx_matmul_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride); } } + return ret; } -static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, - int k_block, int k_stride) { - for (int r = 0; r < n_rows; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx +int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { + if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } + if (!params->m || !params->k || !params->n) { return -1; } + if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } + if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } + if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } + if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } - const bool next_row_valid = (r + 1) < n_rows; - - const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); - const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = *pv_in0++; - HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero(); - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - // compute output position - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } -} - -typedef struct { - __fp16 *dst; - const float *src; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int k_block; - int k_stride; -} activation_transfer_task_state_t; - -static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { - activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; - - for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { - // one chunk: one row - int chunk_idx = task_id * st->n_chunks_per_task; - size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); - - __fp16 *dst = st->dst + chunk_idx * st->k_block; - const float *src = st->src + chunk_idx * st->k_stride; - transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); - } -} - -void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { - assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); - assert(VLEN == 32 * sizeof(float)); - - size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address - - activation_transfer_task_state_t state; - state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; - state.n_tot_chunks = n_tot_chunks; - state.n_chunks_per_task = n_chunks_per_task; - state.dst = dst; - state.src = src; - state.k_block = k_block; - state.k_stride = k_stride; - - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); -} - -int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, - int m, int k, int n, int weight_type) { - // assume k % 32 == 0 && n % 32 == 0 - const size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { + if (!hex_is_aligned(params->dst, VLEN) || + !hex_is_aligned(params->activation, VLEN) || + !hex_is_aligned(params->permuted_weight, VLEN)) { return -1; } - const size_t vtcm_budget = ctx->vtcm_size; + const int group_size = hmx_matmul_batch_r2(params); - const size_t K_BLOCK_SIZE = 1024; - - // Fallback: if k doesn't need K-blocking, out-stationary has no advantage - const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; - if (k_iters_check <= 1) { - FARF(MEDIUM, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); - return FALLBACK_TO_STANDARD; + if (group_size <= 1) { + FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); } - // Dynamic M,N search via hmx_compute_chunks - const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); - const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) - + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) - const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) - + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) - const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) - // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; - // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin - const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; - const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment + // Grouped path: reuse interleaved weight across all q_heads sharing a + // kv_head. Each q_head gets its own activation buffer in VTCM (so + // activation is loaded once per m_chunk and reused across all n_chunks), + // and each q_head is computed individually to avoid tile-major packing + // issues. m_chunk_n_rows is always a multiple of 32 (from + // hmx_compute_chunks), so per-head tile arrays don't overlap. + const size_t vtcm_budget = ctx->vtcm_size; + const size_t vec_dot_size = params->k * sizeof(__fp16); - size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; - // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. - // From profiling: wt_dequant per element ≈ 1.5× activation load per element. - // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). - // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). - const size_t m_block_cost = (size_t) n * 3; - const size_t n_block_cost = (size_t) m * 2; - if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE, - &N_BLOCK_SIZE, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; + // When the activation has a large stride (e.g. permuted Q tensor with + // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. + // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather + // strided rows into a contiguous block before the F32->F16 conversion. + const bool use_dma_activation = (params->act_stride > params->k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, + /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, + /*per_mn=*/sizeof(__fp16), params->m, params->n, + /*m_block_cost=*/(size_t) params->n, + /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); } - // Compute precise buffer sizes from searched M,N and fixed K - const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); - const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); - - const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; - if (total_vtcm > vtcm_budget) { - FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, - vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); - return -1; - } + const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); - uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); - uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); - __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; - FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, - M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - - // initialize eye tile (32x32 identity matrix) - { - HVX_Vector v; - v = Q6_V_vzero(); - v = Q6_Vw_vinsert_VwR(v, 0x3c000000); - v = Q6_V_vror_VR(v, VLEN - 4); - v = Q6_Vw_vinsert_VwR(v, 0x00003c00); - for (int i = 0; i < 16; ++i) { - ((HVX_Vector *) vtcm_eye_tile)[i] = v; - v = Q6_V_vror_VR(v, VLEN - 8); - } + if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { + FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); } + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - TIMER_DEFINE(fetch); - TIMER_DEFINE(act_load); - TIMER_DEFINE(wt_dequant); - TIMER_DEFINE(core); + FARF(HIGH, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, params->m, params->k, params->n, group_size, params->ne13, + m_chunk_n_rows, n_chunk_n_cols, + (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + TIMER_DEFINE(total); + + TIMER_START(total); + + const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { - size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); - for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { - size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); + for (int b3 = 0; b3 < params->ne13; ++b3) { + for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { + const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); - const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); + for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); - for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { - const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); + // Pre-load activations for all heads in the group (once per m_chunk). + // When the source is strided (permuted Q), use 2D DMA to gather + // contiguous rows into a VTCM scratch buffer first, then HVX + // converts from the contiguous VTCM buffer. This avoids L2 cache + // thrashing from HVX loads at large strides. + TIMER_START(activation_load); + for (int g = 0; g < group_size; ++g) { + const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; + __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) params->k * sizeof(float); + const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + vtcm_f32_act, (int) n_rows, + params->k, params->k); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride); + } + } + TIMER_STOP(activation_load); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; - TIMER_START(fetch); - // fetch activation block into VTCM { - const float *activation_block = x + mr * k + kk; - - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_scratch1, activation_block), - k_blk_sz * sizeof(float), - k * sizeof(float), - k_blk_sz * sizeof(float), - m_blk_sz); + const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); } - // fetch weight block into VTCM (x4x2 sub-block: quants + scales) - const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); - { - qweight_fetch_task_state_t s; + for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); - const int blk_start = kk / QK_Q4_0x4x2; - const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); - const int scale_blk_size = - (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); - s.dst = vtcm_scratch0; - s.src = w + nc * row_stride; - s.n_rows = n_blk_sz; - s.src_stride = row_stride; - s.dst_stride = sub_row_stride; - s.quant_off = - (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); - s.quant_width = - (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); - s.scale_off = full_qrow + blk_start * scale_blk_size; - s.scale_width = nb_sub * scale_blk_size; + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < (size_t) params->n) { + const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; - // 2D DMA: quants sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off), - s.dst_stride, s.src_stride, s.quant_width, s.n_rows); - // 2D DMA: scales sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off), - s.dst_stride, s.src_stride, s.scale_width, s.n_rows); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k, params->k, + 0, n_cols); + hex_swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + // Reuse the interleaved weight for every q_head in this GQA group + for (int g = 0; g < group_size; ++g) { + TIMER_START(hmx_core); + { + const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, + params->k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); + } + TIMER_STOP(output_store); + } } - TIMER_STOP(fetch); - - TIMER_START(act_load); - // load activation block - { - dma_queue_pop(ctx->dma[0]); // wait for act DNA - transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); - } - TIMER_STOP(act_load); - - TIMER_START(wt_dequant); - // dequantize weight block - { - dma_queue_pop(ctx->dma[0]); - dma_queue_pop(ctx->dma[0]); - // vtcm_scratch0 is used to store the qweight chunk - // worker_pool_run_func already returned, so fetch is done - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, - n_blk_sz, k_blk_sz, sub_row_stride, weight_type); - } - TIMER_STOP(wt_dequant); - - // core mma - TIMER_START(core); - { - core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, - n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); - } - TIMER_STOP(core); - } - - // store output block - { - float *output_block = out + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); } } } HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + TIMER_STOP(total); + #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", - TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total), + params->m, params->k, params->n, group_size); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); #endif + + return 0; +} + +// + +int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const __fp16 *restrict permuted_weight, int m, int k, int n, + int act_stride, int weight_stride) { + if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } + if (act_stride < k || weight_stride < k) { return -1; } + if (k % 32 != 0 || n % 32 != 0) { return -1; } + + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; + } + + // --- Dynamic VTCM layout --- + const size_t vtcm_budget = ctx->vtcm_size; + const size_t vec_dot_size = k * sizeof(__fp16); + + // DMA-based activation gather for strided tensors (see batched path comment). + const bool use_dma_activation = (act_stride > k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. + if (hmx_compute_chunks(vtcm_budget, + /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, // W + S0 + S1 + /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch + /*per_mn=*/sizeof(__fp16), // O + m, n, + /*m_block_cost=*/(size_t) n, + /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + + // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { + FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + + FARF(HIGH, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + + TIMER_DEFINE(total); + TIMER_START(total); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + // transfer activation matrix chunk into VTCM + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + + TIMER_START(activation_load); + { + const float *activation_chunk = activation + mr * act_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) k * sizeof(float); + const size_t stride_bytes = (size_t) act_stride * sizeof(float); + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_activation, + vtcm_f32_act, n_rows, k, k); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_activation, + activation_chunk, n_rows, k, act_stride); + } + } + TIMER_STOP(activation_load); + + const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + // issue async DMA for the first weight chunk + // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. + // The source rows can be strided (e.g. KV-cache K after ggml_permute). + { + const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); + + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } + + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready + + // issue async DMA for the next weight chunk (double buffering) + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < n) { + const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; + + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + // interleave row-major fp16 from scratch into tile-major in vtcm_weight + hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, k, k, 0, n_cols); + + hex_swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + TIMER_START(hmx_core); + { + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); + } + TIMER_STOP(output_store); + } + + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); + { + size_t weight_size = (size_t)k * n * sizeof(__fp16); + float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); + FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + } +#endif + return 0; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index fb95d36f..1c78ffad 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -61,6 +61,9 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, int m, int k, int n, int weight_type); +// HMX flash attention +int hmx_flash_attn_ext(struct htp_ops_context * octx); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-hexagon/htp/hmx-utils.h b/ggml/src/ggml-hexagon/htp/hmx-utils.h index af04619c..68f174d6 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hmx-utils.h @@ -4,6 +4,9 @@ #ifndef HMX_UTILS_H #define HMX_UTILS_H +#include "hvx-base.h" + +#include #include #include @@ -12,21 +15,188 @@ #define HMX_FP16_TILE_N_ELMS 1024 #define HMX_FP16_TILE_SIZE 2048 -#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline)) - // Initialise aligned 256-byte area with scale vector + zero padding. -static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) { - HVX_Vector *pv = (HVX_Vector *)out_scales; - *pv++ = v_scale; - *pv = Q6_V_vzero(); +static inline void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) { + volatile HVX_Vector *pv = (HVX_Vector *) out_scales; + pv[0] = v_scale; + pv[1] = Q6_V_vzero(); } -// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) --- +// --- Shared scatter offsets and interleave helper --- -static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { - uint8_t *p = *vtcm_ptr; - *vtcm_ptr += size; - return p; +// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile. +// word[i] = i*128 maps K-row-pair i to byte offset i*128. +// Column offset (n*4) is added at runtime. Entries 0..15 cover one tile (region 2047); +// entries 16..31 cover the next adjacent tile (region 4095) — pick region size at the +// call site to scatter into one tile (masked) or two contiguous tiles (unmasked). +static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = { + 0 * 128, 1 * 128, 2 * 128, 3 * 128, 4 * 128, 5 * 128, 6 * 128, 7 * 128, 8 * 128, 9 * 128, 10 * 128, + 11 * 128, 12 * 128, 13 * 128, 14 * 128, 15 * 128, 16 * 128, 17 * 128, 18 * 128, 19 * 128, 20 * 128, 21 * 128, + 22 * 128, 23 * 128, 24 * 128, 25 * 128, 26 * 128, 27 * 128, 28 * 128, 29 * 128, 30 * 128, 31 * 128, +}; + +// Scatter row-major FP16 data (in VTCM scratch) into transposed [K][N] tiles. +// vtcm_src: [n_cols][src_stride] row-major fp16 (only first k elements per row are used) +// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16 +// Processes rows [start_row, end_row) for multi-thread slicing. +// Full range: start_row=0, end_row=n_cols. +static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, + const __fp16 * restrict vtcm_src, + int n_cols, + int k, + int src_stride, + int start_row, + int end_row) { + assert(k % HMX_FP16_TILE_N_COLS == 0); + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + // Each hvx_vmemu load brings 64 fp16 = 128 bytes covering 2 adjacent K-tiles. + // When n_k_tiles is even, scatter into 2 K-tiles per call (region 4095, no mask) + // using the upper half of hmx_transpose_scatter_offsets. Tail one K-tile (when + // n_k_tiles is odd) falls back to single-tile masked scatter. + const bool pair_scatter = (n_k_tiles & 1) == 0; + const size_t pair_region = (size_t) (2 * HMX_FP16_TILE_SIZE - 1); + const size_t single_region = (size_t) (HMX_FP16_TILE_SIZE - 1); + __builtin_assume(k > 0); + __builtin_assume(end_row > start_row); + + if (pair_scatter) { + // Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter. + const int c_step = 2 * HMX_FP16_TILE_N_COLS; + const size_t c_byte_step = (size_t) c_step * sizeof(__fp16); + const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS; + const int n_c_iters = k / c_step; + + for (int r = start_row; r < end_row; r += 2) { + const int ct = r / HMX_FP16_TILE_N_ROWS; + const int local_r = r % HMX_FP16_TILE_N_ROWS; + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols; + const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); + const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); + + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + + if (p1) { + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + HVX_Vector v1 = hvx_vmemu(p1); + p1 += c_byte_step; + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1); + tile_base += dst_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero); + tile_base += dst_step; + } + } + } + } else { + // Fallback: scatter one K-tile per call (region 2047, masked). + const int c_step = HMX_FP16_TILE_N_COLS; + const size_t c_byte_step = (size_t) c_step * sizeof(__fp16); + const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS; + const int n_c_iters = k / c_step; + + for (int r = start_row; r < end_row; r += 2) { + const int ct = r / HMX_FP16_TILE_N_ROWS; + const int local_r = r % HMX_FP16_TILE_N_ROWS; + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols; + const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); + const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); + + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + + if (p1) { + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + HVX_Vector v1 = hvx_vmemu(p1); + p1 += c_byte_step; + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1); + tile_base += dst_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero); + tile_base += dst_step; + } + } + } + } +} + +// Interleave row-major FP16 data into column-major tile format. +// Input: [n_rows, head_dim] row-major. Output: tile[dim_tile][row_tile]. +// Processes rows [start_row, end_row) for multi-thread slicing. +// Full range: start_row=0, end_row=n_rows. +static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out, + const __fp16 * restrict src, + int n_rows, + int head_dim, + int src_stride, + int n_row_tiles, + int start_row, + int end_row) { + __builtin_assume(head_dim > 0); + const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS; + + for (int r = start_row; r < end_row; r += 2) { + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows; + + const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride); + const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL; + + // Row-pair invariants hoisted out of the c loop. + const int r0 = r / HMX_FP16_TILE_N_ROWS; + const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2; + + // tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0). + // Each c step (+= 64) advances both by 2 dim-tiles worth of fp16. + __fp16 * tb0 = tiles_out + (size_t) r0 * HMX_FP16_TILE_N_ELMS; + __fp16 * tb1 = tb0 + tile_stride_elms; + const size_t tb_step = 2 * tile_stride_elms; + + if (pv_in1) { + for (int c = 0; c < head_dim; c += 64) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = *pv_in1++; + HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2); + ((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp); + ((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp); + tb0 += tb_step; + tb1 += tb_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int c = 0; c < head_dim; c += 64) { + HVX_Vector v0 = *pv_in0++; + HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2); + ((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp); + ((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp); + tb0 += tb_step; + tb1 += tb_step; + } + } + } } #endif // HMX_UTILS_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index d0926ded..f6cb0295 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -77,6 +77,12 @@ static inline int32_t hvx_vec_get_i32(HVX_Vector v) { return x; } +static inline _Float16 hvx_vec_get_f16(HVX_Vector v) { + _Float16 __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 2, v); + return x; +} + static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) { // abs by clearing the fp16 sign bit HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h index 851482e0..a3e33c3b 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-copy.h +++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h @@ -7,7 +7,8 @@ #include "hvx-base.h" -#define hvx_splat_loop_body(dst_type, vec_store) \ +#define hvx_splat_pragma(x) _Pragma(#x) +#define hvx_splat_loop_body(dst_type, vec_store, unroll_cnt) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ \ @@ -16,7 +17,7 @@ \ uint32_t i = 0; \ \ - _Pragma("unroll(4)") \ + hvx_splat_pragma(unroll(unroll_cnt)) \ for (; i < nvec; i++) { \ vdst[i] = src; \ } \ @@ -25,31 +26,47 @@ } \ } while(0) -static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { +static inline void hvx_splat_a(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { assert((unsigned long) dst % 128 == 0); - hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a); + hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a, 4); } -static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { - hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u); +static inline void hvx_splat_u(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { + hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u, 4); } -static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f32_a(void * restrict dst, float v, uint32_t n) { hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float)); } -static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f32_u(void * restrict dst, float v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float)); } -static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) { +static inline void hvx_splat_f16_a(void * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } -static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) { +static inline void hvx_splat_f16_u(void * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } +static inline void hvx_splat_u16_a(void * restrict dst, uint16_t v, uint32_t n) { + hvx_splat_a(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t)); +} + +static inline void hvx_splat_u16_u(void * restrict dst, uint16_t v, uint32_t n) { + hvx_splat_u(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t)); +} + +static inline void hvx_splat_u8_a(void * restrict dst, uint8_t v, uint32_t n) { + hvx_splat_a(dst, Q6_Vb_vsplat_R(v), n, 1); +} + +static inline void hvx_splat_u8_u(void * restrict dst, uint8_t v, uint32_t n) { + hvx_splat_u(dst, Q6_Vb_vsplat_R(v), n, 1); +} + #define hvx_copy_loop_body(dst_type, src_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ diff --git a/ggml/src/ggml-hexagon/htp/vtcm-utils.h b/ggml/src/ggml-hexagon/htp/vtcm-utils.h new file mode 100644 index 00000000..b129fb74 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/vtcm-utils.h @@ -0,0 +1,16 @@ +#ifndef VTCM_UTILS_H +#define VTCM_UTILS_H + +#include "hex-utils.h" + +#include +#include +#include + +static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { + uint8_t *p = *vtcm_ptr; + *vtcm_ptr += size; + return p; +} + +#endif // VTCM_UTILS_H