From 84a6b5c03903504ccfd7bfad321d8c6dc9fbd708 Mon Sep 17 00:00:00 2001 From: Shreya Jain Date: Tue, 21 Apr 2026 14:16:04 -0700 Subject: [PATCH] Hexagon: DAIG op (llama/22195) * hexagon: Add DIAG op * hexagon: add HVX support and DMA double buffering * hexagon: fix fatal error * hexagon: remove as many pragma(s) as possible --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 28 +++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/diag-ops.c | 216 +++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + 6 files changed, 250 insertions(+) create mode 100644 ggml/src/ggml-hexagon/htp/diag-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3d68b800..5e206c5e 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2596,6 +2596,29 @@ static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + // diag only supports F32 currently + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + // Input must have ne[1] == 1 (vector input) + if (src0->ne[1] != 1) { + return false; + } + + // Output must be square in first two dimensions + if (dst->ne[0] != dst->ne[1] || dst->ne[0] != src0->ne[0]) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->c_name(); @@ -2632,6 +2655,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_ROPE: return HTP_OP_ROPE; case GGML_OP_REPEAT: return HTP_OP_REPEAT; case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; + case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { @@ -3159,6 +3183,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cumsum(sess, op); break; + case GGML_OP_DIAG: + supp = ggml_hexagon_supported_diag(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 9ca75945..82c10b57 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -34,6 +34,7 @@ add_library(${HTP_LIB} SHARED argsort-ops.c ssm-conv.c cumsum-ops.c + diag-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/diag-ops.c b/ggml/src/ggml-hexagon/htp/diag-ops.c new file mode 100644 index 00000000..9b3194d9 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/diag-ops.c @@ -0,0 +1,216 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hex-utils.h" +#include "hvx-copy.h" +#include "hex-dma.h" + +#define htp_diag_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + \ + const uint32_t ne02 = src0->ne[2]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_diag_context { + struct htp_ops_context * octx; + size_t src_batch_size; + size_t dst_row_size; + size_t src_batch_size_aligned; + size_t dst_row_size_aligned; + uint32_t batches_per_thread; + uint32_t total_batches; +}; + +#define htp_diag_preamble \ + struct htp_diag_context * dctx = (struct htp_diag_context *) data; \ + struct htp_ops_context * octx = dctx->octx; \ + htp_diag_tensors_preamble; + +static inline void hvx_diag_row_f32(const float * restrict src, float * restrict dst, + uint32_t row_idx, uint32_t n) { + hvx_splat_f32_a((uint8_t *) dst, 0.0f, n); + dst[row_idx] = src[row_idx]; +} + +// --------------------------------------------------------------------------- +// Per thread worker: DMA src fetch, compute in VTCM, DMA dst writeback +// --------------------------------------------------------------------------- + +static void diag_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + dma_queue * dma_queue = octx->ctx->dma[ith]; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + if (ib0 >= ib1) { + return; + } + + const size_t src_batch_size = dctx->src_batch_size; + const size_t dst_row_size = dctx->dst_row_size; + const size_t src_batch_size_aligned = dctx->src_batch_size_aligned; + const size_t dst_row_size_aligned = dctx->dst_row_size_aligned; + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + // 1 src buffer + 1 dst row buffer per thread in VTCM + uint8_t * src_spad = octx->src0_spad.data + (ith * src_batch_size_aligned); + uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const uint8_t * src_batch = src_data + i3 * nb03 + i2 * nb02; + + // Fetch source vector into VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src_spad, src_batch), + src_batch_size_aligned, src_batch_size, 1); + dma_queue_flush(dma_queue); + + const float * src_spad_f32 = (const float *) src_spad; + float * dst_spad_f32 = (float *) dst_spad; + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + // Compute row in VTCM + hvx_diag_row_f32(src_spad_f32, dst_spad_f32, i1, ne0); + + // Write completed row back to DDR + uint8_t * dst_row = dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1; + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_row, dst_spad), + dst_row_size, dst_row_size_aligned, 1); + dma_queue_flush(dma_queue); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// Per thread worker: Direct HVX (no DMA) +// --------------------------------------------------------------------------- + +static void diag_thread_f32(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const float * restrict src_batch = (const float *)(src_data + i3 * nb03 + i2 * nb02); + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + float * restrict dst_row = (float *)(dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1); + hvx_diag_row_f32(src_batch, dst_row, i1, ne0); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_diag_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_batches); + + const size_t src_batch_size = src0->ne[0] * sizeof(float); + const size_t dst_row_size = dst->ne[0] * sizeof(float); + const size_t src_batch_size_aligned = hex_round_up(src_batch_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // 1 src buffer + 1 dst row buffer per thread + const size_t spad_per_thread = src_batch_size_aligned + dst_row_size_aligned; + + octx->src0_spad.size_per_thread = src_batch_size_aligned; + octx->dst_spad.size_per_thread = dst_row_size_aligned; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; + + struct htp_diag_context dctx = { + .octx = octx, + .src_batch_size = src_batch_size, + .dst_row_size = dst_row_size, + .src_batch_size_aligned = src_batch_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + .batches_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_batches = total_batches, + }; + + if (octx->ctx->vtcm_size < spad_per_thread * n_threads) { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32, &dctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32_dma, &dctx, n_threads); + } + + return HTP_STATUS_OK; +} + +int op_diag(struct htp_ops_context * octx) { + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_diag_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 8b5e47ad..038941af 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -98,5 +98,6 @@ int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); +int op_diag(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 79b5ecd2..002dd1c1 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -80,6 +80,7 @@ enum htp_op_code { HTP_OP_SSM_CONV, HTP_OP_REPEAT, HTP_OP_CUMSUM, + HTP_OP_DIAG, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 5091623a..d633145c 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -514,6 +514,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_CUMSUM: return op_cumsum(octx); + case HTP_OP_DIAG: + return op_diag(octx); + case HTP_OP_INVALID: break;