hexagon: MUL_MAT, MUL_MAT_ID, FLASH_ATTN and GDN cleanup and optimizations for latest models (llama/23989)

* hex-mm: initial support for F32 * F32 -> F32 matmuls

* hex-rms-norm: fix src1 stride use in fused rms_norm_mul

* hex-ops: clear spad pointers in the ops that clober it

This fixes an odd case where fused rms-norm-mul was failing but only in qwen3.5-2B and only at searth op-bath sizes.

* hmx-mm: add support for F32 * F32 -> F32 matmul_2d on HMX

Decided to use Q4_0 * F32 -> F32 matmul for this.
Q4_0 gets dequantized and tiled into F16, and here we quantize and tile F32 into F16.
Super simple and pretty efficient.

* hmx-mm: route f16 2D matmuls through the same kernel used for all other types

* hmx-mm: re-introduce pipelined vs non-pipelined mode that we used to have but is much more generic way

This update futher improves matmul performance and at the same time removes most of the redudant logic
we had in different paths.

* hmx-fa: slighlty improved pipeline simimar to matmul updates

* hmx-mm: initial version of MAT_MUL_ID support for HMX

* hmx-mm: fixed mxfp4 handling for MUL_MAT_ID

* hex-gdn: optimize GATED_DELTA_NET

DMA prefetch/double-buff, vectorize everything with HVX, in other words -- the usual :)

* hmx-mm: missed one more case where we can use fastmod

* hexagon: update DCVS settings for a slight perf bump

* hmx-fa: use fastdiv in hmx-flash-attn

* hmx-fa: precompute slope values to avoid disrupting the inner loop

* hvx-utils/fa: new HVX helpers for powf and logf and using those to speed up FA alibi

* hex-ops: fixed a bug in fusion logic that was messing up the order of the src tensors when some srcs are empty

* hex-fa: correctly fallback to HVX if we have sinks or the dims are not quite right
This commit is contained in:
Max Krasnyansky 2026-06-01 23:40:08 -07:00 committed by Georgi Gerganov
parent 754247f28b
commit 8d61a9edf0
20 changed files with 1859 additions and 626 deletions

View File

@ -1927,6 +1927,7 @@ struct ggml_hexagon_opbatch {
size_t extra_tens = 0;
auto fit_tensor = [&](const ggml_tensor *t) {
if (!t) return;
if (!t_map.count(t)) {
extra_tens++;
@ -2602,6 +2603,27 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n");
return false;
}
if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) {
GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n");
return false;
}
if (ggml_nrows(src1) > 1024) {
return false; // no huge batches (for now)
}
break;
case GGML_TYPE_F32:
if (src1->type != GGML_TYPE_F32) {
return false;
}
if (src0->nb[1] < src0->nb[0]) {
GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F32 src0 not supported\n");
return false;
}
if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) {
GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n");
return false;
}
if (ggml_nrows(src1) > 1024) {
return false; // no huge batches (for now)
}

View File

@ -56,7 +56,7 @@ struct htp_opnode {
}
std::vector<const ggml_tensor *> get_inputs() const {
std::vector<const ggml_tensor *> inputs;
std::vector<const ggml_tensor *> inputs(GGML_MAX_SRC, nullptr);
std::vector<const ggml_tensor *> outputs;
outputs.push_back(node);
for (const auto * f : fused) {
@ -70,20 +70,38 @@ struct htp_opnode {
return false;
};
int count = 0;
auto add_input = [&](const ggml_tensor * t) {
if (t && !contains(outputs, t) && !contains(inputs, t)) {
inputs.push_back(t);
if (count < (int)inputs.size()) {
inputs[count++] = t;
} else {
inputs.push_back(t);
}
}
};
for (int i = 0; i < GGML_MAX_SRC && node->src[i]; i++) {
add_input(node->src[i]);
}
for (const auto * f : fused) {
for (int i = 0; i < GGML_MAX_SRC && f->src[i]; i++) {
add_input(f->src[i]);
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (fused.empty()) {
inputs[i] = node->src[i];
} else {
if (node->src[i]) {
add_input(node->src[i]);
}
}
}
for (const auto * f : fused) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (f->src[i]) {
add_input(f->src[i]);
}
}
}
if (!fused.empty()) {
inputs.resize(count);
}
return inputs;
}
@ -108,6 +126,9 @@ struct htp_opformat {
char names[64 * GGML_MAX_SRC];
int format_tensor_dims(char * str, const struct ggml_tensor * t) {
if (!t) {
return sprintf(str, "NONE");
}
if (t->ne[2] == 1 && t->ne[3] == 1) {
return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
} else {
@ -136,6 +157,9 @@ struct htp_opformat {
}
int format_tensor_strides(char * str, const struct ggml_tensor * t) {
if (!t) {
return sprintf(str, "NONE");
}
const char * c = ggml_is_contiguous(t) ? "" : "!";
if (t->ne[2] == 1 && t->ne[3] == 1) {
@ -170,11 +194,11 @@ struct htp_opformat {
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += sprintf(p, "%s", ggml_type_name(inputs[0]->type));
p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE");
for (size_t i = 1; i < inputs.size(); i++) {
p += sprintf(p, " x ");
p += sprintf(p, "%s", ggml_type_name(inputs[i]->type));
p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE");
}
p += sprintf(p, " -> ");
@ -184,7 +208,7 @@ struct htp_opformat {
}
const char * tensor_buff_name(const struct ggml_tensor * t) {
if (t->buffer) {
if (t && t->buffer) {
return ggml_backend_buffer_name(t->buffer);
}
return "NONE";
@ -213,11 +237,11 @@ struct htp_opformat {
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += sprintf(p, "%s", inputs[0]->name);
p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE");
for (size_t i = 1; i < inputs.size(); i++) {
p += sprintf(p, " x ");
p += sprintf(p, "%s", inputs[i]->name);
p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE");
}
p += sprintf(p, " -> ");

View File

@ -19,6 +19,43 @@ add_library(${HTP_LIB} SHARED
htp_iface_skel.c
worker-pool.c
hex-dma.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
# HMX acceleration: available on v73+ architectures
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-matmul-ops.c
hmx-flash-attn-ops.c
hmx-queue.c
)
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
set_source_files_properties(
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
PROPERTIES COMPILE_OPTIONS "-mhmx"
)
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
target_sources(${HTP_LIB} PRIVATE
matmul-ops.c
binary-ops.c
unary-ops.c
@ -42,40 +79,6 @@ add_library(${HTP_LIB} SHARED
pad-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
# HMX acceleration: available on v73+ architectures
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
)
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
set_source_files_properties(
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
PROPERTIES COMPILE_OPTIONS "-mhmx"
)
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
install(TARGETS ${HTP_LIB})

View File

@ -276,6 +276,7 @@ int op_argsort(struct htp_ops_context * octx) {
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src0_spad.size = total_spad_size;
octx->src0_spad.size_per_thread = spad_per_thread;
octx->src0_spad.src = NULL;
FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3],

View File

@ -262,6 +262,8 @@ int op_concat(struct htp_ops_context * octx) {
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
octx->src0_spad.src = NULL;
octx->src1_spad.src = NULL;
if (type_size == 4) {
worker_func = concat_2d_f32_transposed;

View File

@ -11,6 +11,7 @@
#include "hex-dma.h"
#include "hvx-utils.h"
#include "hvx-dump.h"
#include "hvx-flash-attn.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
@ -245,6 +246,7 @@ struct htp_fa_context {
uint32_t n_head_log2;
float m0;
float m1;
float slopes[512];
uint32_t n_blocks;
@ -412,7 +414,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
}
const uint32_t h = iq2; // head index
const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
const float slope = factx->slopes[h];
HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
@ -628,8 +630,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
}
#ifdef HTP_HAS_HMX
// HMX path: head_dim multiple of 32, F16 KV
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0) {
// HMX path: head_dim multiple of 64, F16 KV, and no sinks
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) {
int ret = hmx_flash_attn_ext(octx);
if (ret == HTP_STATUS_OK) {
return ret;
@ -689,6 +691,13 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
if (n_head > 512) {
return HTP_STATUS_NO_SUPPORT;
}
for (uint32_t h = 0; h < n_head; ++h) {
factx.slopes[h] = (max_bias > 0.0f) ? alibi_slope(h, factx.n_head_log2, factx.m0, factx.m1) : 1.0f;
}
// total rows in q
const uint32_t neq0 = q->ne[0];
const uint32_t neq1 = q->ne[1];

View File

@ -3,6 +3,7 @@
#include <string.h>
#include "hvx-utils.h"
#include "hex-fastdiv.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
@ -14,106 +15,103 @@
#define HTP_GDN_MAX_SV 128
struct htp_gdn_context {
struct htp_ops_context * octx;
uint32_t rows_per_thread;
size_t state_bytes;
bool use_vtcm;
uint8_t * vtcm_state_base;
size_t vtcm_state_per_thread;
size_t state_bytes;
uint8_t * vtcm_base;
size_t vtcm_per_thread;
};
static inline float gdn_mul_dot_f32(float * restrict dst, const float * restrict mul,
const float * restrict dot, uint32_t n) {
static inline HVX_Vector gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, const float * restrict dot, uint32_t n) {
HVX_Vector acc = Q6_V_vzero();
const uint32_t epv = 128 / sizeof(float);
const uint32_t epv = 128 / sizeof(float);
const uint32_t nvec = n / epv;
const uint32_t tail = n % epv;
const uint32_t nloe = n % epv;
for (uint32_t i = 0; i < nvec; ++i) {
HVX_Vector vd = hvx_vmemu(dst + i * epv);
HVX_Vector vm = hvx_vmem(mul + i * epv);
HVX_Vector vd = hvx_vmemu(dst + i * epv);
HVX_Vector vm = hvx_vmem(mul + i * epv);
HVX_Vector vdot = hvx_vmem(dot + i * epv);
HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm);
HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm);
hvx_vmemu(dst + i * epv) = out;
acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
}
if (tail) {
if (nloe) {
const uint32_t off = nvec * epv;
HVX_Vector vd = hvx_vmemu(dst + off);
HVX_Vector vm = hvx_vmem(mul + off);
HVX_Vector vd = hvx_vmemu(dst + off);
HVX_Vector vm = hvx_vmem(mul + off);
HVX_Vector vdot = hvx_vmem(dot + off);
HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm);
hvx_vec_store_u(dst + off, tail * sizeof(float), out);
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm);
hvx_vec_store_u(dst + off, nloe * sizeof(float), out);
HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot);
acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
}
return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc));
return hvx_vec_reduce_sum_f32(acc);
}
static inline float gdn_mul_scalar_dot_f32(float * restrict dst, float mul,
const float * restrict dot, uint32_t n) {
static inline HVX_Vector gdn_mul_scalar_dot_f32(float * restrict dst, float mul, const float * restrict dot, uint32_t n) {
HVX_Vector acc = Q6_V_vzero();
const HVX_Vector vmul = hvx_vec_splat_f32(mul);
const uint32_t epv = 128 / sizeof(float);
const uint32_t epv = 128 / sizeof(float);
const uint32_t nvec = n / epv;
const uint32_t tail = n % epv;
const uint32_t nloe = n % epv;
for (uint32_t i = 0; i < nvec; ++i) {
HVX_Vector vd = hvx_vmemu(dst + i * epv);
HVX_Vector vd = hvx_vmemu(dst + i * epv);
HVX_Vector vdot = hvx_vmem(dot + i * epv);
HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul);
HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul);
hvx_vmemu(dst + i * epv) = out;
acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
}
if (tail) {
if (nloe) {
const uint32_t off = nvec * epv;
HVX_Vector vd = hvx_vmemu(dst + off);
HVX_Vector vd = hvx_vmemu(dst + off);
HVX_Vector vdot = hvx_vmem(dot + off);
HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul);
hvx_vec_store_u(dst + off, tail * sizeof(float), out);
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul);
hvx_vec_store_u(dst + off, nloe * sizeof(float), out);
HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot);
acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
}
return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc));
return hvx_vec_reduce_sum_f32(acc);
}
static inline float gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src,
float scale, const float * restrict dot, uint32_t n) {
static inline HVX_Vector gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src,
HVX_Vector vscale, const float * restrict dot, uint32_t n) {
HVX_Vector acc = Q6_V_vzero();
const HVX_Vector vscale = hvx_vec_splat_f32(scale);
const uint32_t epv = 128 / sizeof(float);
const uint32_t epv = 128 / sizeof(float);
const uint32_t nvec = n / epv;
const uint32_t tail = n % epv;
const uint32_t nloe = n % epv;
for (uint32_t i = 0; i < nvec; ++i) {
HVX_Vector vd = hvx_vmemu(dst + i * epv);
HVX_Vector vs = hvx_vmem(src + i * epv);
HVX_Vector vd = hvx_vmemu(dst + i * epv);
HVX_Vector vs = hvx_vmem(src + i * epv);
HVX_Vector vdot = hvx_vmem(dot + i * epv);
HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale));
HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale));
hvx_vmemu(dst + i * epv) = out;
acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
}
if (tail) {
if (nloe) {
const uint32_t off = nvec * epv;
HVX_Vector vd = hvx_vmemu(dst + off);
HVX_Vector vs = hvx_vmem(src + off);
HVX_Vector vd = hvx_vmemu(dst + off);
HVX_Vector vs = hvx_vmem(src + off);
HVX_Vector vdot = hvx_vmem(dot + off);
HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale));
hvx_vec_store_u(dst + off, tail * sizeof(float), out);
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale));
hvx_vec_store_u(dst + off, nloe * sizeof(float), out);
HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot);
acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
}
return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc));
return hvx_vec_reduce_sum_f32(acc);
}
static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1,
@ -126,7 +124,7 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1
const uint32_t epv = 128 / sizeof(float);
const uint32_t nvec = n / epv;
const uint32_t tail = n % epv;
const uint32_t nloe = n % epv;
for (uint32_t i = 0; i < nvec; ++i) {
HVX_Vector vm = hvx_vmem(mul + i * epv);
HVX_Vector vdot = hvx_vmem(dot + i * epv);
@ -147,11 +145,11 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1
acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
}
if (tail) {
if (nloe) {
const uint32_t off = nvec * epv;
HVX_Vector vm = hvx_vmem(mul + off);
HVX_Vector vm = hvx_vmem(mul + off);
HVX_Vector vdot = hvx_vmem(dot + off);
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
HVX_Vector zero = Q6_V_vzero();
HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm);
@ -159,10 +157,10 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1
HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm);
HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm);
hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0);
hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1);
hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2);
hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3);
acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
@ -185,7 +183,7 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri
const uint32_t epv = 128 / sizeof(float);
const uint32_t nvec = n / epv;
const uint32_t tail = n % epv;
const uint32_t nloe = n % epv;
for (uint32_t i = 0; i < nvec; ++i) {
HVX_Vector vdot = hvx_vmem(dot + i * epv);
@ -205,10 +203,10 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri
acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
}
if (tail) {
if (nloe) {
const uint32_t off = nvec * epv;
HVX_Vector vdot = hvx_vmem(dot + off);
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
HVX_Vector zero = Q6_V_vzero();
HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul);
@ -216,10 +214,10 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri
HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul);
HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul);
hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0);
hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1);
hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2);
hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3);
acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
@ -246,7 +244,7 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri
const uint32_t epv = 128 / sizeof(float);
const uint32_t nvec = n / epv;
const uint32_t tail = n % epv;
const uint32_t nloe = n % epv;
for (uint32_t i = 0; i < nvec; ++i) {
HVX_Vector vs = hvx_vmem(src + i * epv);
HVX_Vector vdot = hvx_vmem(dot + i * epv);
@ -267,11 +265,11 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri
acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
}
if (tail) {
if (nloe) {
const uint32_t off = nvec * epv;
HVX_Vector vs = hvx_vmem(src + off);
HVX_Vector vdot = hvx_vmem(dot + off);
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
HVX_Vector zero = Q6_V_vzero();
HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0));
@ -279,10 +277,10 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri
HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2));
HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3));
hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0);
hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1);
hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2);
hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3);
acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
@ -310,7 +308,7 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1
const uint32_t epv = 128 / sizeof(float);
const uint32_t nvec = n / epv;
const uint32_t tail = n % epv;
const uint32_t nloe = n % epv;
for (uint32_t i = 0; i < nvec; ++i) {
HVX_Vector vm = hvx_vmem(mul + i * epv);
HVX_Vector vdot = hvx_vmem(dot + i * epv);
@ -343,11 +341,11 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1
acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
}
if (tail) {
if (nloe) {
const uint32_t off = nvec * epv;
HVX_Vector vm = hvx_vmem(mul + off);
HVX_Vector vdot = hvx_vmem(dot + off);
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
HVX_Vector zero = Q6_V_vzero();
HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm);
@ -359,14 +357,14 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1
HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm);
HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm);
hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4);
hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5);
hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6);
hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7);
hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0);
hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1);
hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2);
hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3);
hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4);
hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5);
hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6);
hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7);
acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
@ -400,7 +398,7 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri
const uint32_t epv = 128 / sizeof(float);
const uint32_t nvec = n / epv;
const uint32_t tail = n % epv;
const uint32_t nloe = n % epv;
for (uint32_t i = 0; i < nvec; ++i) {
HVX_Vector vdot = hvx_vmem(dot + i * epv);
@ -432,10 +430,10 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri
acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
}
if (tail) {
if (nloe) {
const uint32_t off = nvec * epv;
HVX_Vector vdot = hvx_vmem(dot + off);
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
HVX_Vector zero = Q6_V_vzero();
HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul);
@ -447,14 +445,14 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri
HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul);
HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul);
hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4);
hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5);
hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6);
hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7);
hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0);
hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1);
hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2);
hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3);
hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4);
hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5);
hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6);
hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7);
acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
@ -496,7 +494,7 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri
const uint32_t epv = 128 / sizeof(float);
const uint32_t nvec = n / epv;
const uint32_t tail = n % epv;
const uint32_t nloe = n % epv;
for (uint32_t i = 0; i < nvec; ++i) {
HVX_Vector vs = hvx_vmem(src + i * epv);
HVX_Vector vdot = hvx_vmem(dot + i * epv);
@ -529,11 +527,11 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri
acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
}
if (tail) {
if (nloe) {
const uint32_t off = nvec * epv;
HVX_Vector vs = hvx_vmem(src + off);
HVX_Vector vdot = hvx_vmem(dot + off);
HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
HVX_Vector zero = Q6_V_vzero();
HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0));
@ -545,14 +543,14 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri
HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6));
HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7));
hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4);
hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5);
hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6);
hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7);
hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0);
hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1);
hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2);
hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3);
hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4);
hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5);
hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6);
hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7);
acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
@ -605,26 +603,65 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
float local_sums[4] __attribute__((aligned(128)));
float local_sums[32] __attribute__((aligned(128)));
dma_queue * dma = octx->ctx->dma[ith];
size_t state_aligned = (size_t) S_v * S_v * sizeof(float);
state_aligned = (state_aligned + 127) & ~(size_t)127;
float * s_work[2];
s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith);
s_work[1] = s_work[0] + state_aligned / sizeof(float);
struct fastdiv_values fd_H = init_fastdiv_values(H);
struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]);
struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]);
struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3);
struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3);
const uint64_t state_seq_stride = state->nb[2] / sizeof(float);
const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs;
const int64_t shift = (int64_t) n_tokens - (int64_t) K;
for (uint32_t ir = ith; ir < total_rows; ir += nth) {
const uint32_t iv1 = ir % H;
const uint32_t iv3 = ir / H;
uint32_t ir_prefetch = ith;
int spad_idx = 0;
const uint32_t iq1 = iv1 % q->ne[1];
const uint32_t ik1 = iv1 % k->ne[1];
const uint32_t iq3 = iv3 / rq3;
const uint32_t ik3 = iv3 / rk3;
// Prefetch preamble (up to 2 steps)
for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) {
const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H);
const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H);
const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v;
float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v;
// Push dummy write-back
dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]),
S_v * sizeof(float), S_v * sizeof(float),
S_v * sizeof(float), 0);
// Push fetch
dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in),
S_v * sizeof(float), S_v * sizeof(float),
S_v * sizeof(float), S_v);
ir_prefetch += nth;
spad_idx ^= 1;
}
int curr_spad_idx = 0;
for (uint32_t ir = ith; ir < total_rows; ir += nth) {
dma_queue_pop(dma);
dma_queue_pop(dma);
float * s_work_curr = s_work[curr_spad_idx];
const uint32_t iv1 = fastmodulo(ir, H, &fd_H);
const uint32_t iv3 = fastdiv(ir, &fd_H);
const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1);
const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1);
const uint32_t iq3 = fastdiv(iv3, &fd_rq3);
const uint32_t ik3 = fastdiv(iv3, &fd_rk3);
float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v;
memcpy(s_out, s_in, gctx->state_bytes);
float * s_work = s_out;
float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v;
@ -640,57 +677,117 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data +
(uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]);
memcpy(local_q, q_t, (size_t) S_v * sizeof(float));
memcpy(local_k, k_t, (size_t) S_v * sizeof(float));
hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v);
hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v);
if (kda) {
hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false);
uint32_t j = 0;
for (; j + 4 <= S_v; j += 4) {
float * row0 = s_work + (uint64_t) (j + 0) * S_v;
float * row1 = s_work + (uint64_t) (j + 1) * S_v;
float * row2 = s_work + (uint64_t) (j + 2) * S_v;
float * row3 = s_work + (uint64_t) (j + 3) * S_v;
gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums);
float local_delta_b[4] __attribute__((aligned(128)));
for (uint32_t r = 0; r < 4; ++r) {
local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
}
gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
for (uint32_t r = 0; r < 4; ++r) {
attn_data[j + r] = local_sums[r] * scale;
}
for (; j + 8 <= S_v; j += 8) {
float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v;
float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v;
float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v;
float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v;
gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
local_gate, local_k, S_v, local_sums);
float local_delta_b[32] __attribute__((aligned(128)));
HVX_Vector vv_t = hvx_vmemu(v_t + j);
HVX_Vector v_local_sums = hvx_vmem(local_sums);
HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
local_k, local_delta_b, local_q, S_v, local_sums);
HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn);
}
for (; j + 4 <= S_v; j += 4) {
float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums);
float local_delta_b[32] __attribute__((aligned(128)));
HVX_Vector vv_t = hvx_vmemu(v_t + j);
HVX_Vector v_local_sums = hvx_vmem(local_sums);
HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn);
}
HVX_Vector vscale_splat = hvx_vec_splat_f32(scale);
for (; j < S_v; ++j) {
float * row = s_work + (uint64_t) j * S_v;
const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v);
const float dj = (v_t[j] - sum) * beta_val;
attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
float * row = s_work_curr + (uint64_t) j * S_v;
HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v);
HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]);
HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val));
HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v);
attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat));
}
} else {
const float gate = expf(g_t[0]);
uint32_t j = 0;
for (; j + 4 <= S_v; j += 4) {
float * row0 = s_work + (uint64_t) (j + 0) * S_v;
float * row1 = s_work + (uint64_t) (j + 1) * S_v;
float * row2 = s_work + (uint64_t) (j + 2) * S_v;
float * row3 = s_work + (uint64_t) (j + 3) * S_v;
gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums);
float local_delta_b[4] __attribute__((aligned(128)));
for (uint32_t r = 0; r < 4; ++r) {
local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
}
gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
for (uint32_t r = 0; r < 4; ++r) {
attn_data[j + r] = local_sums[r] * scale;
}
for (; j + 8 <= S_v; j += 8) {
float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v;
float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v;
float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v;
float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v;
gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
gate, local_k, S_v, local_sums);
float local_delta_b[32] __attribute__((aligned(128)));
HVX_Vector vv_t = hvx_vmemu(v_t + j);
HVX_Vector v_local_sums = hvx_vmem(local_sums);
HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
local_k, local_delta_b, local_q, S_v, local_sums);
HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn);
}
for (; j + 4 <= S_v; j += 4) {
float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums);
float local_delta_b[32] __attribute__((aligned(128)));
HVX_Vector vv_t = hvx_vmemu(v_t + j);
HVX_Vector v_local_sums = hvx_vmem(local_sums);
HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn);
}
HVX_Vector vscale_splat = hvx_vec_splat_f32(scale);
for (; j < S_v; ++j) {
float * row = s_work + (uint64_t) j * S_v;
const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v);
const float dj = (v_t[j] - sum) * beta_val;
attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
float * row = s_work_curr + (uint64_t) j * S_v;
HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v);
HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]);
HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val));
HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v);
attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat));
}
}
@ -698,17 +795,40 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
const int64_t target_slot = (int64_t) t - shift;
if (target_slot >= 0 && target_slot < (int64_t) K) {
float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
if (curr_state_o != s_work) {
memcpy(curr_state_o, s_work, gctx->state_bytes);
if (curr_state_o != s_out) {
hvx_copy_f32_uu((uint8_t *) curr_state_o, (const uint8_t *) s_work_curr, S_v * S_v);
}
}
}
attn_data += (uint64_t) S_v * H;
}
// Push real write-back
dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr),
S_v * sizeof(float), S_v * sizeof(float),
S_v * sizeof(float), S_v);
// Prefetch next block (if any)
if (ir_prefetch < total_rows) {
const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H);
const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H);
const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v;
dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in),
S_v * sizeof(float), S_v * sizeof(float),
S_v * sizeof(float), S_v);
ir_prefetch += nth;
spad_idx ^= 1;
}
curr_spad_idx ^= 1;
}
dma_queue_flush(dma);
}
static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_gdn_context * gctx = (struct htp_gdn_context *) data;
struct htp_ops_context * octx = gctx->octx;
@ -743,41 +863,64 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
float local_sums[8] __attribute__((aligned(128)));
float local_sums[32] __attribute__((aligned(128)));
dma_queue * dma = octx->ctx->dma[ith];
size_t state_aligned = (size_t) S_v * S_v * sizeof(float);
state_aligned = (state_aligned + 127) & ~(size_t)127;
float * s_work[2];
s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith);
s_work[1] = s_work[0] + state_aligned / sizeof(float);
uint8_t * spad = NULL;
if (gctx->use_vtcm) {
spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith;
}
struct fastdiv_values fd_H = init_fastdiv_values(H);
struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]);
struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]);
struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3);
struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3);
const uint64_t state_seq_stride = state->nb[2] / sizeof(float);
const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs;
for (uint32_t ir = ith; ir < total_rows; ir += nth) {
const uint32_t iv1 = ir % H;
const uint32_t iv3 = ir / H;
uint32_t ir_prefetch = ith;
int spad_idx = 0;
const uint32_t iq1 = iv1 % q->ne[1];
const uint32_t ik1 = iv1 % k->ne[1];
const uint32_t iq3 = iv3 / rq3;
const uint32_t ik3 = iv3 / rk3;
// Prefetch preamble (up to 2 steps)
for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) {
const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H);
const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H);
const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v;
float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v;
// Push dummy write-back
dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]),
S_v * sizeof(float), S_v * sizeof(float),
S_v * sizeof(float), 0);
// Push fetch
dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in),
S_v * sizeof(float), S_v * sizeof(float),
S_v * sizeof(float), S_v);
ir_prefetch += nth;
spad_idx ^= 1;
}
int curr_spad_idx = 0;
for (uint32_t ir = ith; ir < total_rows; ir += nth) {
dma_queue_pop(dma);
dma_queue_pop(dma);
float * s_work_curr = s_work[curr_spad_idx];
const uint32_t iv1 = fastmodulo(ir, H, &fd_H);
const uint32_t iv3 = fastdiv(ir, &fd_H);
const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1);
const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1);
const uint32_t iq3 = fastdiv(iv3, &fd_rq3);
const uint32_t ik3 = fastdiv(iv3, &fd_rk3);
float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v;
float * s_work;
if (spad) {
dma_queue_push(dma, dma_make_ptr(spad, s_in),
S_v * sizeof(float), S_v * sizeof(float),
S_v * sizeof(float), S_v);
dma_queue_pop(dma);
s_work = (float *) spad;
} else {
s_work = s_out;
memcpy(s_work, s_in, gctx->state_bytes);
}
float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v;
@ -792,111 +935,145 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data +
(uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]);
memcpy(local_q, q_t, (size_t) S_v * sizeof(float));
memcpy(local_k, k_t, (size_t) S_v * sizeof(float));
hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v);
hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v);
if (kda) {
hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false);
uint32_t j = 0;
for (; j + 8 <= S_v; j += 8) {
float * row0 = s_work + (uint64_t) (j + 0) * S_v;
float * row1 = s_work + (uint64_t) (j + 1) * S_v;
float * row2 = s_work + (uint64_t) (j + 2) * S_v;
float * row3 = s_work + (uint64_t) (j + 3) * S_v;
float * row4 = s_work + (uint64_t) (j + 4) * S_v;
float * row5 = s_work + (uint64_t) (j + 5) * S_v;
float * row6 = s_work + (uint64_t) (j + 6) * S_v;
float * row7 = s_work + (uint64_t) (j + 7) * S_v;
float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v;
float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v;
float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v;
float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v;
gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
local_gate, local_k, S_v, local_sums);
float local_delta_b[8] __attribute__((aligned(128)));
for (uint32_t r = 0; r < 8; ++r) {
local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
}
float local_delta_b[32] __attribute__((aligned(128)));
HVX_Vector vv_t = hvx_vmemu(v_t + j);
HVX_Vector v_local_sums = hvx_vmem(local_sums);
HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
local_k, local_delta_b, local_q, S_v, local_sums);
for (uint32_t r = 0; r < 8; ++r) {
attn_data[j + r] = local_sums[r] * scale;
}
HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn);
}
for (; j + 4 <= S_v; j += 4) {
float * row0 = s_work + (uint64_t) (j + 0) * S_v;
float * row1 = s_work + (uint64_t) (j + 1) * S_v;
float * row2 = s_work + (uint64_t) (j + 2) * S_v;
float * row3 = s_work + (uint64_t) (j + 3) * S_v;
float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums);
float local_delta_b[4] __attribute__((aligned(128)));
for (uint32_t r = 0; r < 4; ++r) {
local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
}
float local_delta_b[32] __attribute__((aligned(128)));
HVX_Vector vv_t = hvx_vmemu(v_t + j);
HVX_Vector v_local_sums = hvx_vmem(local_sums);
HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
for (uint32_t r = 0; r < 4; ++r) {
attn_data[j + r] = local_sums[r] * scale;
}
HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn);
}
HVX_Vector vscale_splat = hvx_vec_splat_f32(scale);
for (; j < S_v; ++j) {
float * row = s_work + (uint64_t) j * S_v;
const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v);
const float dj = (v_t[j] - sum) * beta_val;
attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
float * row = s_work_curr + (uint64_t) j * S_v;
HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v);
HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]);
HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val));
HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v);
attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat));
}
} else {
const float gate = expf(g_t[0]);
uint32_t j = 0;
for (; j + 8 <= S_v; j += 8) {
float * row0 = s_work + (uint64_t) (j + 0) * S_v;
float * row1 = s_work + (uint64_t) (j + 1) * S_v;
float * row2 = s_work + (uint64_t) (j + 2) * S_v;
float * row3 = s_work + (uint64_t) (j + 3) * S_v;
float * row4 = s_work + (uint64_t) (j + 4) * S_v;
float * row5 = s_work + (uint64_t) (j + 5) * S_v;
float * row6 = s_work + (uint64_t) (j + 6) * S_v;
float * row7 = s_work + (uint64_t) (j + 7) * S_v;
float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v;
float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v;
float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v;
float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v;
gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
gate, local_k, S_v, local_sums);
float local_delta_b[8] __attribute__((aligned(128)));
for (uint32_t r = 0; r < 8; ++r) {
local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
}
float local_delta_b[32] __attribute__((aligned(128)));
HVX_Vector vv_t = hvx_vmemu(v_t + j);
HVX_Vector v_local_sums = hvx_vmem(local_sums);
HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
local_k, local_delta_b, local_q, S_v, local_sums);
for (uint32_t r = 0; r < 8; ++r) {
attn_data[j + r] = local_sums[r] * scale;
}
HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn);
}
for (; j + 4 <= S_v; j += 4) {
float * row0 = s_work + (uint64_t) (j + 0) * S_v;
float * row1 = s_work + (uint64_t) (j + 1) * S_v;
float * row2 = s_work + (uint64_t) (j + 2) * S_v;
float * row3 = s_work + (uint64_t) (j + 3) * S_v;
float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums);
float local_delta_b[4] __attribute__((aligned(128)));
for (uint32_t r = 0; r < 4; ++r) {
local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
}
float local_delta_b[32] __attribute__((aligned(128)));
HVX_Vector vv_t = hvx_vmemu(v_t + j);
HVX_Vector v_local_sums = hvx_vmem(local_sums);
HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
for (uint32_t r = 0; r < 4; ++r) {
attn_data[j + r] = local_sums[r] * scale;
}
HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn);
}
HVX_Vector vscale_splat = hvx_vec_splat_f32(scale);
for (; j < S_v; ++j) {
float * row = s_work + (uint64_t) j * S_v;
const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v);
const float dj = (v_t[j] - sum) * beta_val;
attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
float * row = s_work_curr + (uint64_t) j * S_v;
HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v);
HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]);
HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val));
HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v);
attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat));
}
}
if (spad) {
dma_queue_push(dma, dma_make_ptr(s_out, spad),
// Push real write-back
dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr),
S_v * sizeof(float), S_v * sizeof(float),
S_v * sizeof(float), S_v);
// Prefetch next block (if any)
if (ir_prefetch < total_rows) {
const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H);
const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H);
const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v;
dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in),
S_v * sizeof(float), S_v * sizeof(float),
S_v * sizeof(float), S_v);
dma_queue_pop(dma);
ir_prefetch += nth;
spad_idx ^= 1;
}
curr_spad_idx ^= 1;
}
dma_queue_flush(dma);
}
int op_gated_delta_net(struct htp_ops_context * octx) {
const struct htp_tensor * q = octx->src[0];
const struct htp_tensor * k = octx->src[1];
@ -952,18 +1129,11 @@ int op_gated_delta_net(struct htp_ops_context * octx) {
size_t state_aligned = (size_t) S_v * S_v * sizeof(float);
state_aligned = (state_aligned + 127) & ~(size_t)127;
gctx.use_vtcm = false;
gctx.vtcm_state_base = NULL;
gctx.vtcm_state_per_thread = 0;
assert(octx->ctx->vtcm_base != NULL);
assert(octx->ctx->vtcm_size >= 2 * state_aligned * octx->n_threads);
if (n_tokens == 1 && octx->ctx->vtcm_base) {
size_t vtcm_total = state_aligned * octx->n_threads;
if (octx->ctx->vtcm_size >= vtcm_total) {
gctx.use_vtcm = true;
gctx.vtcm_state_base = octx->ctx->vtcm_base;
gctx.vtcm_state_per_thread = state_aligned;
}
}
gctx.vtcm_base = octx->ctx->vtcm_base;
gctx.vtcm_per_thread = 2 * state_aligned;
if (n_tokens == 1) {
worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads);

View File

@ -17,14 +17,17 @@
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "hex-dma.h"
#include "hex-fastdiv.h"
#include "hmx-profile.h"
#include "hmx-queue.h"
#include "hmx-utils.h"
#include "htp-ctx.h"
#include "htp-ops.h"
#include "hvx-dump.h"
#include "hvx-copy.h"
#include "hvx-reduce.h"
#include "hvx-utils.h"
#include "hvx-flash-attn.h"
#include "vtcm-utils.h"
#include "worker-pool.h"
@ -46,7 +49,7 @@
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads) {
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) {
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
@ -67,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV,
+ k_dma_size * 2 // K DMA x2
+ v_dma_size * 2 // V DMA x2
+ k_tile_size * 1 // K tiles
+ v_tile_size * 1 // V tiles
+ v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
+ s_tile_size * 2 // S + P
+ d_tile_size * 1 // D (diagonal matrix)
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
@ -144,12 +147,13 @@ static int hmx_fa_find_chunk_size(size_t * Br_out,
// See .cursor/todos/hmx-flash-attn-bc-search-space.md for the perf trade-off.
const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64
const size_t fp16 = sizeof(__fp16);
const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2);
// Approximate per-unit VTCM costs (without per-buffer alignment padding).
const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16; // Q + O×2 + 4 col vectors
const size_t per_gbr2 = fp16; // D diagonal matrix
const size_t per_bc =
3 * (DK + DV) * fp16 + 2 * n_threads * fp16; // K_dma×2 + V_dma×2 + K_tile + V_tile + row bufs
3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs
const size_t per_gbr_bc = 2 * fp16; // S + P
const size_t overhead = 256 * 2 + 13 * 4096;
@ -164,7 +168,6 @@ static int hmx_fa_find_chunk_size(size_t * Br_out,
// Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS.
// Only relax when kv_len is too short to form enough blocks.
const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2);
const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) :
(kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit);
// Cost coefficients calibrated from profiling
@ -200,7 +203,7 @@ static int hmx_fa_find_chunk_size(size_t * Br_out,
}
// Exact VTCM verification (alignment padding may push over budget)
while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads) > vtcm_budget) {
while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) {
Bc -= bc_unit;
}
if (Bc < bc_unit) {
@ -303,6 +306,7 @@ struct hmx_fa_context {
uint32_t n_kv_heads; // number of KV heads
uint32_t n_heads; // number of Q heads
uint32_t G; // GQA factor = n_heads / n_kv_heads
struct fastdiv_values div_G;
uint32_t n_kv_blocks;
uint32_t neq1; // Q token count
@ -321,7 +325,7 @@ struct hmx_fa_context {
__fp16 * vtcm_k_fp16[2]; // K DMA double-buffer [Bc, D]
__fp16 * vtcm_v_fp16[2]; // V DMA double-buffer [Bc, D]
__fp16 * vtcm_k_tiles; // K tiles (transposed)
__fp16 * vtcm_v_tiles; // V tiles (column-major)
__fp16 * vtcm_v_tiles[2]; // V tiles (column-major, double-buffered)
__fp16 * vtcm_s_tiles; // S = QK^T [g_br, Bc]
__fp16 * vtcm_p_tiles; // P = softmax(S) [g_br, Bc]
__fp16 * vtcm_d_tiles; // Diagonal rescale [g_br, g_br]
@ -402,7 +406,9 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data)
return;
}
hmx_interleave_cols_to_tiles(factx->vtcm_v_tiles, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV,
__fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
hmx_interleave_cols_to_tiles(v_tiles_dest, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV,
(int) args->src_stride, (int) args->n_col_tiles, start, end);
}
@ -464,10 +470,10 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) {
for (size_t r = start; r < end; r += 2) {
const bool next_row_valid = (r + 1) < n_rows_g;
const size_t q_idx0 = (r + 0) / G;
const size_t h_idx0 = (r + 0) % G;
const size_t q_idx1 = (r + 1) / G;
const size_t h_idx1 = (r + 1) % G;
const size_t q_idx0 = fastdiv(r + 0, &factx->div_G);
const size_t h_idx0 = fastmodulo(r + 0, G, &factx->div_G);
const size_t q_idx1 = fastdiv(r + 1, &factx->div_G);
const size_t h_idx1 = fastmodulo(r + 1, G, &factx->div_G);
const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] +
(kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3];
@ -567,8 +573,8 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) {
const uint32_t ib3 = args->ib3;
for (size_t r = start; r < end; ++r) {
const size_t q_idx = r / G;
const size_t h_idx = r % G;
const size_t q_idx = fastdiv(r, &factx->div_G);
const size_t h_idx = fastmodulo(r, G, &factx->div_G);
// FIX(dst-indexing): ggml_flash_attn_ext() creates dst as permute(0,2,1,3) ->
// [DV, n_heads, n_tokens, n_seq], so head stride is nb[1] and token stride is nb[2].
@ -780,11 +786,11 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
if (args->mask_vtcm) {
// Read mask from VTCM buffer (DMA'd per KV block).
// GQA dedup (scheme B): skip load when qi unchanged.
const size_t qi0 = (r + 0) / G;
const size_t qi0 = fastdiv(r + 0, &factx->div_G);
v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c);
v_mask1 = v_neg_inf;
if (r + 1 < (int) n_rows_g) {
const size_t qi1 = (r + 1) / G;
const size_t qi1 = fastdiv(r + 1, &factx->div_G);
if (qi1 == qi0) {
v_mask1 = v_mask0; // scheme B: reuse — same mask row
} else {
@ -794,8 +800,8 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
} else {
// Fallback: read mask directly from DDR (when mask->ne[2] > 1).
const struct htp_tensor * mask = args->mask;
const size_t q_idx0 = args->q_start + ((r + 0) / G);
const size_t h_idx0 = args->kv_head * G + (r + 0) % G;
const size_t q_idx0 = args->q_start + fastdiv(r + 0, &factx->div_G);
const size_t h_idx0 = args->kv_head * G + fastmodulo(r + 0, G, &factx->div_G);
const uint32_t im2_0 = h_idx0 % mask->ne[2];
const uint32_t im3_0 = args->ib3 % mask->ne[3];
@ -805,12 +811,12 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
v_mask1 = v_neg_inf;
if (r + 1 < (int) n_rows_g) {
const size_t q_idx1 = args->q_start + ((r + 1) / G);
const size_t q_idx1 = args->q_start + fastdiv(r + 1, &factx->div_G);
if (q_idx1 == q_idx0) {
// scheme B: same mask row in DDR path
v_mask1 = v_mask0;
} else {
const size_t h_idx1 = args->kv_head * G + (r + 1) % G;
const size_t h_idx1 = args->kv_head * G + fastmodulo(r + 1, G, &factx->div_G);
const uint32_t im2_1 = h_idx1 % mask->ne[2];
const uint32_t im3_1 = args->ib3 % mask->ne[3];
const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] +
@ -1191,14 +1197,13 @@ static void hmx_fa_o_norm_worker(void * data) {
// Row r in the GQA-merged block maps to Q head h = kv_head * G + r % G.
// slope(h) = m0^(h+1) when h < n_head_log2, else m1^(2*(h-n_head_log2)+1).
// When max_bias == 0, all slopes are 1.0 (no ALiBi).
static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sargs,
static __attribute__((noinline)) void fa_compute_slopes(
const struct hmx_fa_context * factx,
uint32_t kv_head,
size_t n_rows_g) {
__fp16 * slopes = factx->vtcm_slopes;
if (factx->max_bias == 0.0f) {
for (size_t r = 0; r < n_rows_g; ++r) {
sargs->slopes[r] = 1.0f;
}
hvx_splat_f16_a(slopes, 1.0f, n_rows_g);
return;
}
@ -1207,10 +1212,32 @@ static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sarg
const float m0 = factx->m0;
const float m1 = factx->m1;
for (size_t r = 0; r < n_rows_g; ++r) {
const uint32_t h = kv_head * G + r % G;
sargs->slopes[r] = (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1);
__fp16 temp_slopes[512] __attribute__((aligned(128)));
if (G <= 32) {
// Fast path: Compute G unique slope values in vector registers
HVX_Vector v_val = hvx_alibi_slopes(kv_head, G, n_head_log2, m0, m1);
__fp16 temp_slopes_aligned[64] __attribute__((aligned(128)));
hvx_vmem(temp_slopes_aligned) = hvx_vec_f32_to_f16(v_val, Q6_V_vzero());
for (uint32_t i = 0; i < G; ++i) {
temp_slopes[i] = temp_slopes_aligned[i];
}
} else {
// Fallback path: G > 32 (rare configurations)
for (uint32_t i = 0; i < G; ++i) {
temp_slopes[i] = (__fp16)alibi_slope(kv_head * G + i, n_head_log2, m0, m1);
}
}
// Allocate stack buffer to avoid scalar writes to VTCM (which generates L2 misses)
__fp16 local_slopes[n_rows_g] __attribute__((aligned(128)));
for (size_t r = 0; r < n_rows_g; ++r) {
local_slopes[r] = temp_slopes[fastmodulo(r, G, &factx->div_G)];
}
// Copy to VTCM slopes using HVX block copy (both are aligned to 128 bytes)
hvx_copy_f16_aa((uint8_t *)slopes, (const uint8_t *)local_slopes, n_rows_g);
}
// ============================================================================
@ -1254,19 +1281,22 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const uint32_t G = neq2 / n_kv_heads;
// Thread count for multi-thread HVX phases
const uint32_t n_threads = octx->n_threads;
const uint32_t n_threads_init = octx->n_threads;
// Compute dynamic block sizes (GQA-aware, accounting for per-thread row bufs)
size_t Br, Bc;
const size_t vtcm_budget = ctx->vtcm_size;
if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads) != 0) {
if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads_init) != 0) {
return HTP_STATUS_VTCM_TOO_SMALL;
}
const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS);
const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc;
const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2);
const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
// Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1
const uint32_t n_threads = use_pipeline ? n_threads_init : 1;
FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu",
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget);
@ -1282,6 +1312,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
factx.n_kv_heads = n_kv_heads;
factx.n_heads = neq2;
factx.G = G;
factx.div_G = init_fastdiv_values(G);
factx.neq1 = neq1;
factx.Br = (uint32_t) Br;
factx.Bc = (uint32_t) Bc;
@ -1354,7 +1385,12 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes);
factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes);
factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes);
factx.vtcm_v_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
if (use_pipeline) {
factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
} else {
factx.vtcm_v_tiles[1] = NULL;
}
factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes);
factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes);
factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes);
@ -1457,6 +1493,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
// ---- KV block loop with DMA double-buffering ----
size_t buf_idx = 0;
fa_compute_slopes(&factx, kv_head, n_rows_g);
// Prefetch first KV block
if (factx.n_kv_blocks > 0) {
const uint32_t kv_rows0 = hex_smin(Bc, nek1);
@ -1535,7 +1573,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
ou_job.o_curr = o_tile_curr;
ou_job.o_prev = o_tile_prev;
ou_job.p_tiles = factx.vtcm_p_tiles;
ou_job.v_tiles = factx.vtcm_v_tiles;
ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx];
ou_job.d_tiles = factx.vtcm_d_tiles;
ou_job.hmx_scales = factx.vtcm_hmx_scales_id;
ou_job.n_row_tiles = n_row_tiles;
@ -1550,11 +1588,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
TIMER_STOP(k_interleave);
if (kv_blk > 0) {
hmx_queue_pop(hmx_q);
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
}
// ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ----
qk_job.q_tiles = factx.vtcm_q_tiles;
qk_job.k_tiles = factx.vtcm_k_tiles;
@ -1574,6 +1607,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
TIMER_STOP(v_interleave);
// Pop and swap previous block's output update (deferred HMX pop)
if (kv_blk > 0) {
hmx_queue_pop(hmx_q);
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
}
// Pop current block's dot product job
hmx_queue_pop(hmx_q);
TIMER_STOP(qk_dot);
@ -1601,7 +1641,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
sargs.slopes = factx.vtcm_slopes;
fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g);
TIMER_START(softmax);
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
@ -1617,7 +1656,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
ou_job.o_curr = o_tile_curr;
ou_job.o_prev = o_tile_prev;
ou_job.p_tiles = factx.vtcm_p_tiles;
ou_job.v_tiles = factx.vtcm_v_tiles;
ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx];
ou_job.d_tiles = factx.vtcm_d_tiles;
ou_job.hmx_scales = factx.vtcm_hmx_scales_id;
ou_job.n_row_tiles = n_row_tiles;
@ -1712,7 +1751,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
sargs.slopes = factx.vtcm_slopes;
fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g);
TIMER_START(softmax);
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
@ -1732,7 +1770,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t DV_tiles = (size_t) (DV / 32);
const __fp16 * restrict d_base = factx.vtcm_d_tiles;
const __fp16 * restrict p_base = factx.vtcm_p_tiles;
const __fp16 * restrict v_base = factx.vtcm_v_tiles;
const __fp16 * restrict v_base = factx.vtcm_v_tiles[0];
const __fp16 * restrict op_base = o_tile_prev;
__fp16 * restrict oc_base = o_tile_curr;
__builtin_assume(n_row_tiles > 0);

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,6 @@
// HMX operations compiled as a single translation unit.
// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO.
#include "hmx-queue.c"
#include "hmx-matmul-ops.c"
#include "hmx-flash-attn-ops.c"

View File

@ -52,14 +52,32 @@ int hmx_matmul_f16_f32(struct htp_context *ctx,
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params);
// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4)
int hmx_matmul_q_f32(struct htp_context *ctx,
// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4)
int hmx_matmul_2d_f32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const uint8_t *permuted_weight,
int m, int k, int n,
int act_stride,
int weight_stride,
int weight_type);
struct mmid_row_mapping;
int hmx_matmul_id_2d_f32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const uint8_t *permuted_weight,
int m, int k, int n,
int ne11,
size_t act_nb1, size_t act_nb2,
size_t dst_nb1, size_t dst_nb2,
int weight_stride,
int weight_type,
const struct mmid_row_mapping *matrix_rows,
int cur_a,
int mapping_stride);
// HMX flash attention
int hmx_flash_attn_ext(struct htp_ops_context * octx);

View File

@ -79,6 +79,10 @@ struct htp_context {
uint64_t max_vmem;
// Persistent DDR scratchpad for MUL_MAT_ID mappings
void * ddr_spad_base;
size_t ddr_spad_size;
struct htp_ops_context octx;
#ifdef HTP_HAS_HMX

View File

@ -0,0 +1,47 @@
#ifndef HVX_FLASH_ATTN_H
#define HVX_FLASH_ATTN_H
#include <math.h>
#include "hvx-utils.h"
// Scalar helper to compute a single ALiBi slope.
static inline float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
return (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1);
}
// Vectorized helper to compute 32 ALiBi slopes starting from (kv_head * G).
static inline HVX_Vector hvx_alibi_slopes(
uint32_t kv_head,
uint32_t G,
uint32_t n_head_log2,
float m0,
float m1
) {
static const float ramp_32[32] __attribute__((aligned(128))) = {
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f,
24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f
};
HVX_Vector v_ramp = hvx_vmem(ramp_32);
HVX_Vector v_h_base = hvx_vec_splat_f32((float)(kv_head * G));
HVX_Vector v_h = hvx_vec_add_f32_f32(v_h_base, v_ramp);
// Compute exponent_m0: h + 1
HVX_Vector v_exp_m0 = hvx_vec_add_f32_f32(v_h, hvx_vec_splat_f32(1.0f));
// Compute exponent_m1: 2 * (h - n_head_log2) + 1
HVX_Vector v_n_head_log2 = hvx_vec_splat_f32((float)n_head_log2);
HVX_Vector v_h_minus = hvx_vec_sub_f32_f32(v_h, v_n_head_log2);
HVX_Vector v_exp_m1 = hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(hvx_vec_splat_f32(2.0f), v_h_minus), hvx_vec_splat_f32(1.0f));
// Compute powers
HVX_Vector v_pow_m0 = hvx_vec_pow_const_base_f32(m0, v_exp_m0);
HVX_Vector v_pow_m1 = hvx_vec_pow_const_base_f32(m1, v_exp_m1);
// Select based on h < n_head_log2
HVX_VectorPred p_cond = Q6_Q_vcmp_gt_VsfVsf(v_n_head_log2, v_h); // v_n_head_log2 > v_h <=> h < n_head_log2
return Q6_V_vmux_QVV(p_cond, v_pow_m0, v_pow_m1);
}
#endif /* HVX_FLASH_ATTN_H */

View File

@ -0,0 +1,65 @@
#ifndef HVX_LOG_H
#define HVX_LOG_H
#include "hvx-base.h"
// Approximates ln(x) element-wise for float vectors.
// x must contain positive float elements.
// Uses Abramowitz & Stegun polynomial approximation 4.1.44 for ln(1+y) over [0, 1].
static inline HVX_Vector hvx_vec_log_f32(HVX_Vector x) {
// x = m * 2^e, where m in [1, 2)
HVX_Vector biased_e = Q6_Vuw_vlsr_VuwR(x, 23);
HVX_Vector e_int = Q6_Vw_vsub_VwVw(biased_e, Q6_V_vsplat_R(127));
HVX_Vector e_float = Q6_Vsf_equals_Vw(e_int);
// Extract mantissa and set exponent to 127 (which represents float value in [1.0, 2.0))
HVX_Vector mant_mask = Q6_V_vsplat_R(0x007FFFFF);
HVX_Vector exp_127 = Q6_V_vsplat_R(0x3F800000);
HVX_Vector m = Q6_V_vor_VV(Q6_V_vand_VV(x, mant_mask), exp_127);
// y = m - 1.0f, y in [0, 1)
HVX_Vector y = hvx_vec_sub_f32_f32(m, hvx_vec_splat_f32(1.0f));
// Abramowitz & Stegun 4.1.44 polynomial approximation of ln(1+y)
HVX_Vector c;
HVX_Vector res;
c = hvx_vec_splat_f32(-0.0064535442f);
res = hvx_vec_mul_f32_f32(y, c);
c = hvx_vec_splat_f32(0.0360884937f);
res = hvx_vec_add_f32_f32(res, c);
res = hvx_vec_mul_f32_f32(y, res);
c = hvx_vec_splat_f32(-0.0953293897f);
res = hvx_vec_add_f32_f32(res, c);
res = hvx_vec_mul_f32_f32(y, res);
c = hvx_vec_splat_f32(0.1676540711f);
res = hvx_vec_add_f32_f32(res, c);
res = hvx_vec_mul_f32_f32(y, res);
c = hvx_vec_splat_f32(-0.2407338084f);
res = hvx_vec_add_f32_f32(res, c);
res = hvx_vec_mul_f32_f32(y, res);
c = hvx_vec_splat_f32(0.3317990258f);
res = hvx_vec_add_f32_f32(res, c);
res = hvx_vec_mul_f32_f32(y, res);
c = hvx_vec_splat_f32(-0.4998741238f);
res = hvx_vec_add_f32_f32(res, c);
res = hvx_vec_mul_f32_f32(y, res);
c = hvx_vec_splat_f32(0.9999964239f);
res = hvx_vec_add_f32_f32(res, c);
res = hvx_vec_mul_f32_f32(y, res);
// ln(x) = e * ln(2) + ln(1+y)
HVX_Vector ln2 = hvx_vec_splat_f32(0.69314718056f);
HVX_Vector term_e = hvx_vec_mul_f32_f32(e_float, ln2);
return hvx_vec_add_f32_f32(term_e, res);
}
#endif /* HVX_LOG_H */

View File

@ -0,0 +1,42 @@
#ifndef HVX_POW_H
#define HVX_POW_H
#include <math.h>
#include "hvx-base.h"
#include "hvx-exp.h"
#include "hvx-log.h"
// Approximates base^exponent element-wise for float vectors.
// base must be a positive constant. exponent is an HVX f32 vector.
// Uses base^x = exp(x * ln(base)).
static inline HVX_Vector hvx_vec_pow_const_base_f32(float base, HVX_Vector exponent) {
float ln_base = logf(base);
HVX_Vector ln_base_v = hvx_vec_splat_f32(ln_base);
HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base_v);
static const float kInf = INFINITY;
static const float kMaxExp = 88.7228f;
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
return hvx_vec_exp_f32_guard(x, max_exp, inf);
}
// Approximates base^exponent element-wise for float vectors.
// base and exponent are HVX f32 vectors. base elements must be positive.
// Uses base^exponent = exp(exponent * ln(base)).
static inline HVX_Vector hvx_vec_pow_f32(HVX_Vector base, HVX_Vector exponent) {
HVX_Vector ln_base = hvx_vec_log_f32(base);
HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base);
static const float kInf = INFINITY;
static const float kMaxExp = 88.7228f;
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
return hvx_vec_exp_f32_guard(x, max_exp, inf);
}
#endif /* HVX_POW_H */

View File

@ -17,5 +17,7 @@
#include "hvx-floor.h"
#include "hvx-sin-cos.h"
#include "hvx-base.h"
#include "hvx-pow.h"
#include "hvx-log.h"
#endif /* HVX_UTILS_H */

View File

@ -12,6 +12,7 @@
#include <HAP_mem.h>
#include <HAP_power.h>
#include <HAP_ps.h>
#include <HAP_dcvs.h>
#include <qurt.h>
#include <qurt_thread.h>
#include <qurt_memory.h>
@ -63,8 +64,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
request.type = HAP_power_set_DCVS_v3;
request.dcvs_v3.set_dcvs_enable = TRUE;
request.dcvs_v3.dcvs_enable = TRUE;
request.dcvs_v3.dcvs_option = HAP_DCVS_V2_PERFORMANCE_MODE;
request.dcvs_v3.dcvs_enable = FALSE;
request.dcvs_v3.set_bus_params = TRUE;
request.dcvs_v3.bus_params.min_corner = HAP_DCVS_VCORNER_MAX;
request.dcvs_v3.bus_params.max_corner = HAP_DCVS_VCORNER_MAX;
@ -75,6 +75,10 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX;
request.dcvs_v3.set_sleep_disable = TRUE;
request.dcvs_v3.sleep_disable = TRUE;
#if (__HEXAGON_ARCH__ >= 79)
HAP_set_dcvs_v3_protected_bus_corners(&request, 1);
#endif
if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
return err;
}
@ -103,7 +107,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
FARF(ALWAYS, "Setting HMX clock\n");
err = HAP_power_set((void *) ctx, &request);
if (err != AEE_SUCCESS) {
FARF(ERROR, "Error setting HMX clock.");
FARF(ERROR, "ggml-hex: error setting HMX clock.");
return err;
}
}
@ -117,7 +121,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
FARF(ALWAYS, "Powering HMX on\n");
err = HAP_power_set((void *) ctx, &request);
if (err != AEE_SUCCESS) {
FARF(ERROR, "Error powering on HMX.");
FARF(ERROR, "ggml-hex: error powering on HMX.");
return err;
}
}
@ -423,10 +427,18 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
ctx->dma[i] = dma_queue_create(256); // queue depth
}
ctx->ddr_spad_size = 512 * 1024; // 512 KB
ctx->ddr_spad_base = memalign(128, ctx->ddr_spad_size);
// init worker pool
err = worker_pool_init(&ctx->worker_pool, n_hvx);
if (err != AEE_SUCCESS) {
FARF(ERROR, "Unable to create worker pool");
if (ctx->ddr_spad_base) {
free(ctx->ddr_spad_base);
ctx->ddr_spad_base = NULL;
ctx->ddr_spad_size = 0;
}
return err;
}
@ -474,6 +486,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
vtcm_free(ctx);
if (ctx->ddr_spad_base) {
free(ctx->ddr_spad_base);
ctx->ddr_spad_base = NULL;
ctx->ddr_spad_size = 0;
}
return AEE_SUCCESS;
}

View File

@ -53,6 +53,11 @@ struct htp_matmul_context {
struct fastdiv_values mm_div_ne1;
struct fastdiv_values mm_div_r2;
struct fastdiv_values mm_div_r3;
// Fields for scattered mapping & HMX support in MUL_MAT_ID
const uint32_t * matrix_row_counts;
const struct mmid_row_mapping * matrix_rows;
bool hmx_eligible;
};
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
@ -2913,6 +2918,176 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
}
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
uint32_t nloe = n % VLEN_FP32; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]);
rsum = HVX_OP_ADD_F32(rsum, prod);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf);
rsum = HVX_OP_ADD_F32(rsum, prod);
}
*s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum));
}
static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
uint32_t nvec = n / VLEN_FP32;
uint32_t nloe = n % VLEN_FP32;
HVX_Vector rsum0 = Q6_V_vzero();
HVX_Vector rsum1 = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y_sf = y[i];
HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf);
HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf);
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf);
HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf);
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
}
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
HVX_VectorAlias va;
va.v = rsum;
s0[0] = va.fp32[0];
s0[1] = va.fp32[1];
}
static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * restrict s1,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
uint32_t nvec = n / VLEN_FP32;
uint32_t nloe = n % VLEN_FP32;
HVX_Vector r0_c0_sum = Q6_V_vzero();
HVX_Vector r0_c1_sum = Q6_V_vzero();
HVX_Vector r1_c0_sum = Q6_V_vzero();
HVX_Vector r1_c1_sum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector r0_sf = x0[i];
HVX_Vector r1_sf = x1[i];
HVX_Vector c0_sf = y0[i];
HVX_Vector c1_sf = y1[i];
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]);
HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]);
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
}
// Reduce and store results
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
HVX_VectorAlias va0, va1;
va0.v = r0_r1_c0_sum;
va1.v = r0_r1_c1_sum;
s0[0] = va0.fp32[0];
s0[1] = va0.fp32[1];
s1[0] = va1.fp32[0];
s1[1] = va1.fp32[1];
}
static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
uint32_t nloe = n % VLEN_FP32; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector x_sf = vx[i];
HVX_Vector y_sf = vy[i];
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
}
if (nloe) {
HVX_Vector x_sf = vx[i];
HVX_Vector y_sf = vy[i];
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
x_sf = Q6_V_vand_QV(bmask, x_sf);
y_sf = Q6_V_vand_QV(bmask, y_sf);
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
}
rsum = hvx_vec_reduce_sum_f32(rsum);
hvx_vec_store_u(&s[0], 4, rsum);
}
static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
@ -3331,7 +3506,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
// Process the last row (if any)
if (src0_end_row != src0_end_row_x2) {
uint32_t ir0 = src0_end_row_x2;
const int is0 = (ir0 - src0_start_row);
const int is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
@ -3466,7 +3641,7 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
// Process the last row (if any)
if (src0_end_row != src0_end_row_x2) {
const uint32_t ir0 = src0_end_row_x2;
const uint32_t is0 = (ir0 - src0_start_row);
const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
@ -3516,11 +3691,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
const uint32_t n_ids = ids->ne[0]; // n_expert_used
const uint32_t n_as = ne02; // n_expert
const size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0;
const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size;
const uint32_t * matrix_row_counts = mmctx->matrix_row_counts;
const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows;
const size_t dst_row_size = nb1;
const size_t src0_row_size = nb01;
@ -3542,6 +3714,10 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
continue;
}
if (mmctx->hmx_eligible) {
continue;
}
const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0);
// Prefill spad with src0 rows
@ -3583,7 +3759,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
// Process the last row (if any)
if (src0_end_row != src0_end_row_x2) {
uint32_t ir0 = src0_end_row_x2;
const uint32_t is0 = (ir0 - src0_start_row);
const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
src0_row_size_padded, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
@ -3685,7 +3861,7 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
// Process the last row (if any)
if (src0_end_row != src0_end_row_x2) {
uint32_t ir0 = src0_end_row_x2;
const uint32_t is0 = (ir0 - src0_start_row);
const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
src0_row_size_padded, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
@ -4086,6 +4262,47 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
uint32_t dst_stride = octx->src1_spad.stride;
uint64_t t1 = HAP_perf_get_qtimer_count();
const uint32_t ne0 = src->ne[0];
const uint32_t ne1 = src->ne[1];
const uint32_t ne2 = src->ne[2];
const uint32_t ne3 = src->ne[3];
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = ne0 * sizeof(float);
const size_t src_stride = src->nb[1];
uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
for (uint32_t i = ir_first; i < ir_last; ++i) {
hex_l2fetch(src_data, src_row_size, src_stride, 2);
hvx_copy_f32_au(dst_data, src_data, ne0);
dst_data += dst_stride;
src_data += src_stride;
}
uint64_t t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
@ -4328,6 +4545,60 @@ static int op_matmul_hvx(struct htp_ops_context * octx) {
mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
need_quant = false;
}
} else if (src0->type == HTP_TYPE_F32) {
// Try optimized f32-f32 path first (src1 in VTCM)
const size_t f32_src1_row_size = hex_round_up(ne10 * 4, 128);
const size_t f32_src1_spad_size = hex_round_up(f32_src1_row_size * src1_nrows, 256);
const size_t f32_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
const size_t f32_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
const size_t f32_total_size = f32_src1_spad_size + f32_src0_spad_size + f32_dst_spad_size;
const bool is_batched = (ne02 > 1) || (ne03 > 1);
const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]);
if (!is_batched && !is_permuted && f32_total_size <= octx->ctx->vtcm_size) {
// Optimized path
quant_job_func = quantize_f32_f32;
mmctx->type = "f32-f32";
mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1;
mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1;
mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2;
src1_row_size = f32_src1_row_size;
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
octx->src1_spad.size = octx->src1_spad.size_per_thread;
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
} else {
// Fallback to DDR / broadcasting
quant_job_func = NULL;
mmctx->type = "f32-f32";
mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1;
matmul_job_func = matmul_4d;
src1_row_size = nb11;
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
// Init fastdiv for matmul_4d (supports broadcasting)
mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
need_quant = false;
}
} else {
@ -4405,20 +4676,20 @@ int op_matmul(struct htp_ops_context * octx) {
return op_matmul_hvx(octx);
}
// HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights.
// HMX supports F16, F32, Q4_0, Q8_0, IQ4_NL, MXFP4 weights.
// Other types fall back to HVX.
uint32_t wtype = src0->type;
if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) {
if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) {
return op_matmul_hvx(octx);
}
// Quantised HMX path requires K aligned to 256 (x4x2 super-block).
// F16 HMX path requires K aligned to 32 (tile width).
if (wtype != HTP_TYPE_F16 && src0->ne[0] % 256 != 0) {
// F16 and F32 HMX paths require K aligned to 32 (tile width).
if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && src0->ne[0] % 256 != 0) {
return op_matmul_hvx(octx);
}
if (wtype == HTP_TYPE_F16 && src0->ne[0] % 32 != 0) {
if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && src0->ne[0] % 32 != 0) {
return op_matmul_hvx(octx);
}
@ -4463,8 +4734,8 @@ int op_matmul(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
if (src0->type == HTP_TYPE_F16) {
if (is_batched) {
if (is_batched) {
if (src0->type == HTP_TYPE_F16) {
hmx_matmul_f16_f32_batched_params_t batch_params = {
.dst = (float *) dst->data,
.activation = (float *) src1->data,
@ -4488,13 +4759,11 @@ int op_matmul(struct htp_ops_context * octx) {
};
ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params);
} else {
ret = hmx_matmul_f16_f32(octx->ctx,
(float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data,
m_total, k, n, act_stride, wgt_stride);
return op_matmul_hvx(octx);
}
} else {
ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
m_total, k, n, (int) src0->type);
ret = hmx_matmul_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type);
}
if (ret != 0) {
@ -4539,8 +4808,30 @@ int op_matmul_id(struct htp_ops_context * octx) {
size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
const size_t total_map_size = matrix_row_counts_size + matrix_row_map_size;
void * mapping_buf = NULL;
bool must_free_mapping = false;
if (octx->ctx->ddr_spad_base && total_map_size <= octx->ctx->ddr_spad_size) {
mapping_buf = octx->ctx->ddr_spad_base;
} else {
mapping_buf = memalign(128, total_map_size);
if (mapping_buf) {
must_free_mapping = true;
} else {
return HTP_STATUS_INTERNAL_ERR;
}
}
uint32_t * matrix_row_counts = (uint32_t *) mapping_buf;
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) ((uint8_t *) mapping_buf + matrix_row_counts_size);
mmctx->matrix_row_counts = matrix_row_counts;
mmctx->matrix_rows = matrix_rows;
if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
if (must_free_mapping) free(mapping_buf);
return HTP_STATUS_NO_SUPPORT;
}
@ -4552,7 +4843,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
src1_row_size = q8x4x2_row_size(ne10);
}
const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
const size_t src2_spad_size_per_thread = 0; // We moved the mapping to DDR!
htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);
size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
@ -4568,6 +4859,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
// Make sure the reserved vtcm size is sufficient
if (octx->ctx->vtcm_size < spad_size) {
FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size);
if (must_free_mapping) free(mapping_buf);
return HTP_STATUS_VTCM_TOO_SMALL;
}
@ -4587,9 +4879,6 @@ int op_matmul_id(struct htp_ops_context * octx) {
if (src1_nrows > 1) {
// initialize matrix_row_counts and map
uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0;
struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size;
memset(matrix_row_counts, 0, n_as * sizeof(uint32_t));
// group rows by src0 matrix
@ -4599,14 +4888,60 @@ int op_matmul_id(struct htp_ops_context * octx) {
assert(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };
matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 };
matrix_row_counts[i02] += 1;
}
}
}
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
if (must_free_mapping) free(mapping_buf);
return HTP_STATUS_OK;
}
bool hmx_eligible = false;
#ifdef HTP_HAS_HMX
if (octx->ctx->hmx_enabled && src1_nrows > 1) {
uint32_t wtype = src0->type;
if (ne01 % 32 == 0 &&
(wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32 || wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || wtype == HTP_TYPE_MXFP4)) {
if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && ne00 % 32 == 0) {
hmx_eligible = true;
} else if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && ne00 % 256 == 0) {
hmx_eligible = true;
}
}
}
#endif
mmctx->hmx_eligible = hmx_eligible;
if (hmx_eligible) {
for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) {
const int32_t cne1 = matrix_row_counts[cur_a];
if (cne1 == 0) continue;
int ret = hmx_matmul_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data,
(const uint8_t *) src0->data + cur_a * nb02,
cne1, ne00, ne01,
ne11,
nb11, nb12,
nb1, nb2,
(int) src0->nb[1], (int) src0->type,
matrix_rows, cur_a, n_ids * ids->ne[1]);
if (ret != 0) {
FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret);
if (must_free_mapping) free(mapping_buf);
return HTP_STATUS_NO_SUPPORT;
}
}
// HMX has overwritten VTCM, so force dynamic quantization cache to clear
octx->src1_spad.src = NULL;
if (must_free_mapping) free(mapping_buf);
return HTP_STATUS_OK;
}
if (octx->src1_spad.src != src1) {
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
@ -4618,5 +4953,6 @@ int op_matmul_id(struct htp_ops_context * octx) {
const uint32_t n_matmul_jobs = octx->n_threads;
worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
if (must_free_mapping) free(mapping_buf);
return HTP_STATUS_OK;
}

View File

@ -511,6 +511,8 @@ int op_pad(struct htp_ops_context * octx) {
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
octx->src0_spad.src = NULL;
octx->dst_spad.src = NULL;
}
struct htp_pad_context pctx = {

View File

@ -692,6 +692,11 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
const uint8_t * restrict data_src1 = uctx->data_src1;
uint8_t * restrict data_dst = uctx->data_dst;
const struct htp_tensor * src1 = (htp_op == HTP_OP_RMS_NORM_MUL) ? octx->src[1] : NULL;
const uint32_t nb11 = src1 ? src1->nb[1] : 0;
const uint32_t nb12 = src1 ? src1->nb[2] : 0;
const uint32_t nb13 = src1 ? src1->nb[3] : 0;
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
@ -738,10 +743,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
src0_row_size_aligned, nb01, src0_data_row_size, block_size);
if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) {
const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03);
const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb11, nb12, nb13);
dma_queue_push(dma_queue,
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + src1_off),
uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, block_size);
uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, block_size);
}
ir += block_size;
@ -823,10 +828,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size);
if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) {
const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03);
const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb11, nb12, nb13);
dma_queue_push(dma_queue,
dma_make_ptr(src1_spad, data_src1 + src1_pref_off),
uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, pref_block_size);
uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, pref_block_size);
}
}
}
@ -977,6 +982,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
}
octx->src0_spad.src = NULL;
octx->src1_spad.src = NULL;
octx->dst_spad.src = NULL;
FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);