diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 869c7b238..f3eccff7d 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -450,12 +450,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/arch/riscv/repack.cpp ) if (GGML_CPU_RISCV64_SPACEMIT) + include(ggml-cpu/cmake/FindSMTIME.cmake) target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC}) list(APPEND GGML_CPU_SOURCES ggml-cpu/spacemit/ime.cpp ggml-cpu/spacemit/ime.h + ggml-cpu/spacemit/spine_mem_pool.cpp + ggml-cpu/spacemit/spine_mem_pool.h + ggml-cpu/spacemit/repack.cpp + ggml-cpu/spacemit/repack.h + ggml-cpu/spacemit/ime_env.cpp + ggml-cpu/spacemit/ime_env.h ggml-cpu/spacemit/ime1_kernels.cpp + ggml-cpu/spacemit/ime2_kernels.cpp ggml-cpu/spacemit/ime_kernels.h + ggml-cpu/spacemit/rvv_kernels.cpp + ggml-cpu/spacemit/rvv_kernels.h ) endif() if(NOT GGML_CPU_ALL_VARIANTS) @@ -485,6 +495,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZIHINTPAUSE) string(APPEND MARCH_STR "_zihintpause") endif() + if (GGML_RV_ZBA) + string(APPEND MARCH_STR "_zba") + endif() if (GGML_CPU_RISCV64_SPACEMIT) # `xsmtvdotii' is only required for GCC >= 15. if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND diff --git a/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake b/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake new file mode 100644 index 000000000..c8a4d4b4e --- /dev/null +++ b/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake @@ -0,0 +1,32 @@ +include(CheckCSourceRuns) + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(riscv)" AND GGML_CPU_RISCV64_SPACEMIT) + set(SMT_MARCH_STR "-march=rv64gcv_zfh_zvfh_zba_zicbop") + if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND + CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 15) + string(APPEND SMT_MARCH_STR "_xsmtvdotii") + endif() + set(CMAKE_REQUIRED_FLAGS "${SMT_MARCH_STR}") + + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_IME1) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1, i4\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S4) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1, i8\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S8) + check_c_source_compiles("int main() {__asm__ volatile(\"vfwmadot v2, v0, v1, fp16\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFWMADOT_FP16) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot.hp v2, v0, v1, v0, 0, i4\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFMADOT_S4) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot.hp v2, v0, v1, v0, 0, i8\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFMADOT_S8) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot1 v2, v0, v1\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOTN) + check_c_source_compiles("int main() {__asm__ volatile(\"vpack.vv v2, v0, v1, 2\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VPACK) + check_c_source_compiles("int main() {__asm__ volatile(\"vnspack.vv v2, v0, v1, 2\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VNPACK) + unset(CMAKE_REQUIRED_FLAGS) + + list(APPEND RISCV64_SPACEMIT_IME_SPEC "") + if (SPACEMIT_RISCV_COMPILER_SUPPORT_IME1) + set(RISCV64_SPACEMIT_IME_SPEC "RISCV64_SPACEMIT_IME1") + endif() + + if (SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S4 AND SPACEMIT_RISCV_COMPILER_SUPPORT_VPACK AND SPACEMIT_RISCV_COMPILER_SUPPORT_VNPACK) + list(APPEND RISCV64_SPACEMIT_IME_SPEC "RISCV64_SPACEMIT_IME2") + endif() + + message("RISCV64_SPACEMIT_IME_SPEC: ${RISCV64_SPACEMIT_IME_SPEC}") +endif() diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8b7acafda..7b05edf6b 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -50,6 +50,10 @@ #include "llamafile/sgemm.h" #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT +# include "spacemit/ime.h" +#endif + // Note: once we move threading into a separate C++ file // will use std::hardware_destructive_interference_size instead of hardcoding it here // and we'll use C++ attribute syntax. @@ -3011,7 +3015,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { const struct ggml_cgraph * cgraph = tp->cgraph; const struct ggml_cplan * cplan = tp->cplan; +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(state->ith); +#else set_numa_thread_affinity(state->ith); +#endif struct ggml_compute_params params = { /*.ith =*/ state->ith, @@ -3068,6 +3076,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_barrier(state->threadpool); +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(state->ith); +#endif + return 0; } diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp index 91fe1925e..9563ea3e4 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -3,19 +3,32 @@ #include "ime.h" +#include "binary-ops.h" +#include "common.h" #include "ggml-backend-impl.h" #include "ggml-common.h" #include "ggml-cpu.h" +#include "ime_env.h" #include "ime_kernels.h" +#include "ops.h" +#include "repack.h" +#include "rvv_kernels.h" +#include "spine_mem_pool.h" #include "traits.h" +#include "vec.h" + +#include +#include +#include #include +#include #include +#include #include #include // for GGML_ASSERT #include #include - // clang-format off #if defined(__riscv) @@ -25,13 +38,17 @@ #include #endif -#if !defined(__riscv_zfh) -#error "riscv zfh extension not enabled" +#if !defined(__riscv_zfh) || !defined(__riscv_zvfh) +#error "riscv zfh extension not enabled, GGML_RV_ZFH and GGML_RV_ZVFH must be defined to 1" #endif -#if defined(RISCV64_SPACEMIT_IME1) +#if !defined(__riscv_zba) +#error "riscv zba extension not enabled, GGML_RV_ZBA must be defined to 1" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) || defined(RISCV64_SPACEMIT_IME2) #else -#error "RISCV64_SPACEMIT_IME1 not defined" +#error "RISCV64_SPACEMIT_IME1 or RISCV64_SPACEMIT_IME2 not defined" #endif #else @@ -46,370 +63,118 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif -#if defined(RISCV64_SPACEMIT_IME1) -#define QGEMM_STRIDEN_THREAD_ALIGN 16 -#else -#define QGEMM_STRIDEN_THREAD_ALIGN 32 -#endif - // clang-format on -struct qnbitgemm_spacemit_ime_args { - const float * a_ptr = nullptr; - size_t lda = 0; - const std::byte * packed_quant_b_data = nullptr; - const float * quant_b_scale = nullptr; - const void * quant_b_zp = nullptr; - const float * quant_b_blksum = nullptr; - const float * bias = nullptr; - float * c_ptr = nullptr; - size_t ldc = 0; -}; - -constexpr size_t div_round_up(size_t up, size_t down) { - return (up + down - 1) / down; -} - -constexpr size_t q8_blk_size(size_t blk_len) { - const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t); - // Currently, the strictest alignment requirement of a block is for a float. - // Ensure contiguous blocks are suitably aligned. - assert(blk_size % alignof(float) == 0); - return blk_size; +extern "C" { +extern void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value); +extern int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value); } namespace ggml::cpu::riscv64_spacemit { -const int num_ai_cores = std::thread::hardware_concurrency() / 2; - -} // namespace ggml::cpu::riscv64_spacemit - -static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len, - const size_t gemm_k, - const qnbitgemm_spacemit_ime_args * gemm_args, - void * const per_gemm_ws, - const size_t m_start, - const size_t m_count, - const size_t n_start, - const size_t n_count) { - constexpr size_t scale_stride = sizeof(uint16_t); - constexpr size_t blk_bitwidth = 4; - - const size_t k_blks = div_round_up(gemm_k, blk_len); - - const size_t lda = k_blks * q8_blk_size(blk_len); - const size_t ldc = gemm_args->ldc; - const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8); - const std::byte * quant_a_ptr = static_cast(per_gemm_ws) + m_start * lda; - - const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0; - const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride); - const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride; - - float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start; - - size_t count_n = 0; - const size_t compute_block_count_n = m_count == 1 ? n_count : 16; - for (size_t n = 0; n < n_count; n += count_n) { - count_n = std::min(n_count - n, compute_block_count_n); - - const std::byte * a_row = quant_a_ptr; - const std::byte * b_col = packed_quant_b_data + n * packed_b_stride; - const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr; - float * c_blk = c_ptr + n; - - int32_t rows_remaining = m_count; - - while (rows_remaining > 0) { - const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4( - blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr, - scale_stride); - - c_blk += rows_handled * ldc; - a_row += rows_handled * lda; - - rows_remaining -= rows_handled; - } - } -} - -template constexpr int QK_0() { - if constexpr (K == 4) { - return QK4_0; - } - if constexpr (K == 8) { - return QK8_0; - } - return -1; -} - -template struct block { - ggml_half d[N]; // deltas for N qK_0 blocks - uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +struct TLSContext { + int cpu_id{ -1 }; + cpu_set_t cpuset; + void * tcm_buffer{ nullptr }; + size_t tcm_buffer_size{ 0 }; }; -template struct block_with_zp { - ggml_half d[N]; // deltas for N qK_1 blocks - uint8_t zp[N]; // zero points for N qK_1 blocks - uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks -}; +thread_local TLSContext tls_context; -// control size -static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); -static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), - "wrong block_with_zp<4,16> size/padding"); -static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); - -using block_q4_0x16 = block<4, 16>; -using block_q4_1x16 = block_with_zp<4, 16>; -using block_q8_0x16 = block<8, 16>; - -static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { - block_q4_0x16 out; - GGML_ASSERT(QK4_0 / blck_size_interleave == 2); - - for (int i = 0; i < 16; i++) { - out.d[i] = in[i].d; - } - - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); - } - } - - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); - } - } - - return out; -} - -static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { - block_q4_1x16 out; - GGML_ASSERT(QK4_1 / blck_size_interleave == 2); - - for (int i = 0; i < 16; i++) { - float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); - float mid = -std::nearbyintf(m / d); - mid = std::min(15.0f, std::max(0.0f, mid)); - out.d[i] = GGML_FP32_TO_FP16(d); - out.zp[i] = static_cast(mid); - } - - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); - } - } - - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); - } - } - - return out; -} - -static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - GGML_ASSERT(interleave_block == 16); - - constexpr int nrows_interleaved = 16; - - block_q4_0x16 * dst = (block_q4_0x16 *) t->data; - const block_q4_0 * src = (const block_q4_0 *) data; - block_q4_0 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_0; - - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); - - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; - } - *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; - - GGML_UNUSED(data_size); -} - -static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_1); - GGML_ASSERT(interleave_block == 16); - - constexpr int nrows_interleaved = 16; - - block_q4_1x16 * dst = (block_q4_1x16 *) t->data; - const block_q4_1 * src = (const block_q4_1 *) data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_1; - - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); - - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; - } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; - - GGML_UNUSED(data_size); -} - -static inline void get_scale_min_k4(int j, - const uint8_t * GGML_RESTRICT q, - uint8_t * GGML_RESTRICT d, - uint8_t * GGML_RESTRICT m) { - if (j < 4) { - *d = q[j] & 63; - *m = q[j + 4] & 63; +template constexpr size_t get_repacked_block_type_size() { + if constexpr (std::is_same_v || std::is_same_v) { + return sizeof(block_q8_0); + } else if constexpr (std::is_same_v) { + return sizeof(block_q4_0) * INTER_SIZE / QK4_0; + } else if constexpr (std::is_same_v || std::is_same_v) { + return (sizeof(block_q4_0) + sizeof(uint8_t)) * INTER_SIZE / QK4_1; + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q2_k<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q3_k<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_mxfp4<1>); + } else if constexpr (std::is_same_v || std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q5_1<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q5_0<1>); } else { - *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); - *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + assert(false); + return 0; } } -static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 16); - GGML_ASSERT(QK_K / QK4_1 == 8); - - constexpr int nrows_interleaved = 16; - - block_q4_1x16 * dst = (block_q4_1x16 *) t->data; - const block_q4_K * src = (const block_q4_K *) data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; - - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { - return -1; +template constexpr bool block_type_has_zp() { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return false; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return true; + } else { + assert(false); + return false; } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int j = 0; j < 8; j++) { - for (int i = 0; i < nrows_interleaved; i++) { - uint8_t sc, m; - const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - const float min = - GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); - get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); - const float d1 = d * sc; - const float m1 = min * m; - - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); - // src -> [b0, b32] [b1, b33] ... [b31, b63] - // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] - const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; - if (j % 2 == 0) { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); - } - } else { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); - } - } - } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); - } - } - src += nrows_interleaved * nblocks; - } - return 0; - - GGML_UNUSED(data_size); -} - -namespace ggml::cpu::riscv64_spacemit { - -template -int repack(struct ggml_tensor *, const void *, size_t); - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); -} - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); -} - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); } class tensor_traits_base : public ggml::cpu::tensor_traits { public: - virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; + virtual int repack(ggml_tensor * t, const void * data, size_t data_size) = 0; }; template class tensor_traits : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + bool work_size(int /* n_threads */, const ggml_tensor * op, size_t & size) override { switch (op->op) { case GGML_OP_MUL_MAT: - size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4; - size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float)); - return true; - default: - // GGML_ABORT("fatal error"); - break; - } - return false; - } + { + int64_t src1_nelements = ggml_nelements(op->src[1]); + + if constexpr (std::is_same_v || std::is_same_v) { + size = + spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K); + } else if constexpr (INTER_SIZE == QK4_0) { + size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) * + spacemit_kernels::q8_blk_size(QK4_0, true); + } else if constexpr (INTER_SIZE == 256) { + size = spacemit_kernels::div_round_up(src1_nelements, 256) * + spacemit_kernels::q8_hp_blk_size(256, true, true); + } else { + GGML_ABORT("unsupported block type"); + } + + size = GGML_PAD(size, sizeof(int64_t)); + + return true; + } + case GGML_OP_MUL_MAT_ID: + { + int64_t src1_nelements = ggml_nelements(op->src[1]); + + if constexpr (std::is_same_v || std::is_same_v) { + size = + spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K); + } else if constexpr (INTER_SIZE == QK4_0) { + size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) * + spacemit_kernels::q8_blk_size(QK4_0, true); + } else if constexpr (INTER_SIZE == 256) { + size = spacemit_kernels::div_round_up(src1_nelements, 256) * + spacemit_kernels::q8_hp_blk_size(256, true, true); + } else { + GGML_ABORT("unsupported block type"); + } + + size = GGML_PAD(size, sizeof(int64_t)); + + const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert + const int64_t ne12 = op->src[1]->ne[2]; // n_tokens + + const size_t sizeof_mmid_row_mapping = sizeof(int64_t); + size += sizeof_mmid_row_mapping * ne02 * (ne12 + 1) + (ne02 + 1) * sizeof(int64_t); + + size = GGML_PAD(size, sizeof(int64_t)); - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - if (op->src[0]->type == GGML_TYPE_Q4_0 || // - op->src[0]->type == GGML_TYPE_Q4_1 || // - op->src[0]->type == GGML_TYPE_Q4_K) { - forward_mul_mat_q4(params, op); return true; } default: @@ -419,7 +184,57 @@ template class tensor_ return false; } - void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) { + bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + switch (op->src[0]->type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + //case GGML_TYPE_MXFP4: + forward_mul_mat(params, op); + return true; + default: + // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT"); + return false; + } + break; + case GGML_OP_MUL_MAT_ID: + switch (op->src[0]->type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + //case GGML_TYPE_MXFP4: + forward_mul_mat_id(params, op); + return true; + default: + // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT_ID"); + return false; + } + break; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) { + constexpr size_t a_blk_len = INTER_SIZE; + constexpr size_t b_blk_len = INTER_SIZE; + const ggml_tensor * src0 = op->src[0]; const ggml_tensor * src1 = op->src[1]; ggml_tensor * dst = op; @@ -435,73 +250,137 @@ template class tensor_ const float * feature = (const float *) src1->data; float * output = (float *) dst->data; - const size_t batch_feature = ne12 * ne13; - [[maybe_unused]] const size_t batch_weight = ne02 * ne03; - const size_t gemm_m = ne11; - const size_t gemm_k = ne10; - const size_t gemm_n = ne01; + const int64_t gemm_m = ne11 * ne12 * ne13; + const int64_t gemm_k = ne10; + const int64_t gemm_n = ne01; - GGML_ASSERT(batch_weight == 1); + spacemit_kernels::quantize_a_row_def quantize_a_row_i8; + spacemit_kernels::quantize_a_row_def quantize_a_4row_i8; + spacemit_kernels::gemm_kernel_quantize_def gemm_kernel; + bool set_kernel_impl = false; - const size_t block_count_k = div_round_up(gemm_k, QK4_0); - const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0); - const size_t per_gemm_workspace_stride = - div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t); - const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride; - const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1; + int64_t block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len); - if (ith == 0 && params->wsize < desired_wsize) { - throw std::runtime_error("wsize less than desired_wsize"); - } +#if defined(RISCV64_SPACEMIT_IME2) + if (!set_kernel_impl && (global_spine_env_info.use_ime2)) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); - std::vector qnbitgemm_args(batch_feature); + if constexpr (std::is_same_v || std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if constexpr (INTER_SIZE == 256) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8_hp; + block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true); + set_kernel_impl = true; + } else { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + set_kernel_impl = true; + } + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); - for (size_t i = 0; i < batch_feature; i++) { - qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i; - qnbitgemm_args[i].lda = gemm_k; - qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data; - qnbitgemm_args[i].quant_b_scale = nullptr; + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); - if constexpr (std::is_same_v) { - qnbitgemm_args[i].quant_b_zp = nullptr; - } else { - qnbitgemm_args[i].quant_b_zp = w_data; + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5; + set_kernel_impl = true; } + } +#endif - qnbitgemm_args[i].bias = nullptr; - qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i; - qnbitgemm_args[i].ldc = gemm_n; +#if defined(RISCV64_SPACEMIT_IME1) + if (!set_kernel_impl && (global_spine_env_info.use_ime1)) { + quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::ime1::quantize_a_4row_i8; + + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4; + set_kernel_impl = true; + } + } +#endif + if (!set_kernel_impl) { + GGML_ABORT("no kernel implementation found for the block type"); } - const uintptr_t ws_ptr = reinterpret_cast(params->wdata); - void * ws = reinterpret_cast((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1))); - const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0); + const int64_t a_k_blks = spacemit_kernels::div_round_up(gemm_k, a_blk_len); + const int64_t b_k_blks = spacemit_kernels::div_round_up(gemm_k, b_blk_len); - { - constexpr size_t block_size_m = 4; - size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m); - int32_t task_count = batch_feature * per_gemm_block_count_m; - int32_t task_per_thread = (task_count + nth - 1) / nth; - int32_t start = ith * task_per_thread; - int32_t end = std::min((ith + 1) * task_per_thread, task_count); - for (int32_t compute_idx = start; compute_idx < end; compute_idx++) { - int32_t gemm_idx = compute_idx / per_gemm_block_count_m; - int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m; - int32_t m_idx = block_idx_in_gemm * block_size_m; - const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx]; - int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx); + const int64_t row_stride_a = a_k_blks * block_stride_a; + const int64_t gemm_workspace_size = GGML_PAD(gemm_m * row_stride_a, alignof(int64_t)); - if (rows_tobe_handled == block_size_m) { - const float * a_row_ptr = data.a_ptr + m_idx * data.lda; - std::byte * quant_a_row_ptr = - static_cast(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; - sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); + if (ith == 0 && params->wsize < gemm_workspace_size) { + GGML_ABORT("wsize less than gemm_workspace_size"); + } + + uintptr_t ws_ptr = reinterpret_cast(params->wdata); + + void * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer; + const int64_t tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size; + + auto * quant_a_buffer = reinterpret_cast(ws_ptr); + + constexpr int64_t row_align = 4; + const int64_t row_blks = spacemit_kernels::div_round_up(gemm_m, row_align); + + const int64_t row_stride_b = b_k_blks * get_repacked_block_type_size(); + const int64_t per_mb_rows_wsize = row_align * row_stride_a; + const int64_t per_nb_cols_wsize = NB_COLS * row_stride_b; + + const int64_t barrier_idx = static_cast(ith / 2); + + GGML_ASSERT(global_spine_env_info.init_barrier != nullptr); + GGML_ASSERT(barrier_idx < spine_init_barrier_count); + spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx]; + + if (gemm_m == 1) { + int task_per_thread = spacemit_kernels::div_round_up(a_k_blks, nth); + int a_blk_start = ith * task_per_thread; + int a_blk_end = std::min(a_blk_start + task_per_thread, (int) a_k_blks); + if (a_blk_start < a_blk_end) { + quantize_a_row_i8(a_blk_len, feature + a_blk_start * a_blk_len, (a_blk_end - a_blk_start) * a_blk_len, + quant_a_buffer + a_blk_start * block_stride_a); + } + } else { + int task_per_thread = spacemit_kernels::div_round_up(row_blks, nth); + int m_row_blk_start = ith * task_per_thread; + int m_row_blk_end = std::min(m_row_blk_start + task_per_thread, (int) row_blks); + for (int m_row_blk = m_row_blk_start; m_row_blk < m_row_blk_end; m_row_blk++) { + int m_idx = m_row_blk * row_align; + int rows_tobe_handled = (gemm_m - m_idx) > row_align ? row_align : (gemm_m - m_idx); + + if (rows_tobe_handled == row_align && quantize_a_4row_i8 != nullptr) { + const float * a_row_ptr = feature + m_idx * gemm_k; + auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a; + quantize_a_4row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr); } else { while (rows_tobe_handled) { - const float * a_row_ptr = data.a_ptr + m_idx * data.lda; - std::byte * quant_a_row_ptr = static_cast(ws) + - gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; - sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); + const float * a_row_ptr = feature + m_idx * gemm_k; + auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a; + quantize_a_row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr); rows_tobe_handled -= 1; m_idx += 1; } @@ -511,51 +390,545 @@ template class tensor_ ggml_barrier(params->threadpool); - if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) { - return; + const int64_t gemm_m_stride = gemm_n / gemm_m > 64 ? gemm_m : 16; + const int64_t gemm_m_blocked = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride); + const int64_t max_gemm_n_stride = spacemit_kernels::div_round_up(gemm_n * gemm_m_blocked, nth); + + int64_t gemm_n_stride = gemm_n; + if (max_gemm_n_stride < gemm_n) { + gemm_n_stride = + std::min(gemm_n_stride, spacemit_kernels::div_round_up(max_gemm_n_stride, NB_COLS) * NB_COLS); } - nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores }); - size_t threads_per_gemm = nth / batch_feature; - constexpr size_t gemm_m_stride = 128; - size_t nc = gemm_n; - const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride); - const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm); - if (max_nc < nc) { - nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN); - } - const size_t gemm_n_stride = nc; - const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride); - const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride); - threads_per_gemm = thread_count_m * thread_count_n; + if (gemm_n_stride == gemm_n && tcm_buffer != nullptr && per_mb_rows_wsize <= tcm_buffer_size) { + for (int64_t m_start = ith * row_align; m_start < gemm_m; m_start += row_align * nth) { + uint8_t * b_col = reinterpret_cast(w_data); + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; - { - int task_count = batch_feature * threads_per_gemm; - int task_per_thread = (task_count + nth - 1) / nth; - int start = ith * task_per_thread; - int end = std::min((ith + 1) * task_per_thread, task_count); - for (int compute_idx = start; compute_idx < end; compute_idx++) { - const auto gemm_i = compute_idx / threads_per_gemm; - const auto blk_i = compute_idx % threads_per_gemm; - const auto * data = &qnbitgemm_args[gemm_i]; + int64_t m_row_real = std::min(gemm_m - m_start, row_align); - const auto tid_n = blk_i / thread_count_m; - const auto tid_m = blk_i % thread_count_m; + spacemit_kernels::rvv::memcpy1d(tcm_buffer, quant_a_buffer + m_start * row_stride_a, + m_row_real * row_stride_a); - const size_t m_start = tid_m * gemm_m_stride; - const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride); + int64_t n_blk_real = 0; + for (int64_t ni = 0; ni < gemm_n; ni += n_blk_real, b_col += n_blk_real * row_stride_b) { + n_blk_real = std::min(gemm_n - ni, (int64_t) NB_COLS); - const size_t n_start = tid_n * gemm_n_stride; - const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride); + uint8_t * a_row_ptr = (uint8_t *) tcm_buffer; + float * c_blk = output + m_start * gemm_n + ni; - void * per_gemm_ws = reinterpret_cast(ws) + gemm_i * per_gemm_workspace_stride; + int32_t rows_remaining = m_row_real; - sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count); + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_ptr, b_col, b_col_zp, c_blk, rows_remaining, + n_blk_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_ptr += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + } + } + } else if (tcm_buffer != nullptr && per_nb_cols_wsize <= tcm_buffer_size) { + uint8_t * a_row = quant_a_buffer; + uint8_t * b_col = reinterpret_cast(tcm_buffer); + if ((gemm_workspace_size + per_nb_cols_wsize) <= tcm_buffer_size) { + a_row = (uint8_t *) tcm_buffer; + b_col = reinterpret_cast(tcm_buffer) + gemm_workspace_size; + } + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; + + int64_t ni = ith * NB_COLS; + int64_t nb_real = std::min(gemm_n - ni, NB_COLS); + + if (ith % 2 == 0 && nb_real > 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + } + + spine_barrier_wait(cur_barrier); + + if (ith % 2 != 0 && nb_real > 0) { + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + } + + for (; ni < gemm_n; ni += NB_COLS * nth) { + int64_t rows_remaining = gemm_m; + float * c_blk = output + ni; + auto * a_row_cur = a_row; + + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } + + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_cur, b_col, b_col_zp, c_blk, rows_remaining, + nb_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_cur += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS * nth; + if (next_ni < gemm_n) { + nb_real = std::min(gemm_n - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + next_ni * row_stride_b, + nb_real * row_stride_b); + } + } + } else { + const int64_t task_count_m = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride); + const int64_t task_count_n = spacemit_kernels::div_round_up(gemm_n, gemm_n_stride); + + int64_t task_count = task_count_m * task_count_n; + int64_t task_per_thread = (task_count + nth - 1) / nth; + int64_t start = ith * task_per_thread; + int64_t end = std::min((ith + 1) * task_per_thread, task_count); + for (int64_t compute_idx = start; compute_idx < end; compute_idx++) { + const auto tid_n = compute_idx / task_count_m; + const auto tid_m = compute_idx % task_count_m; + + const int64_t m_start = tid_m * gemm_m_stride; + const int64_t m_count = std::min(gemm_m - m_start, (int64_t) gemm_m_stride); + + const int64_t n_start = tid_n * gemm_n_stride; + const int64_t n_count = std::min(gemm_n - n_start, (int64_t) gemm_n_stride); + + const int64_t n_blk = m_count == 1 ? n_count : NB_COLS; + + uint8_t * b_col = reinterpret_cast(w_data) + n_start * row_stride_b; + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; + + int64_t n_blk_real = 0; + for (int64_t ni = 0; ni < n_count; ni += n_blk_real, b_col += n_blk_real * row_stride_b) { + n_blk_real = std::min(n_count - ni, n_blk); + + uint8_t * a_row = quant_a_buffer + m_start * row_stride_a; + + float * c_blk = output + m_start * gemm_n + n_start + ni; + + int64_t rows_remaining = m_count; + + uint8_t * b_col_cur = b_col; + uint8_t * b_col_zp_cur = b_col_zp; + + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row, b_col_cur, b_col_zp_cur, c_blk, + rows_remaining, n_blk_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + } } } } - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) { + constexpr size_t a_blk_len = INTER_SIZE; + constexpr size_t b_blk_len = INTER_SIZE; + + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + const ggml_tensor * ids = op->src[2]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + int ith = params->ith; + int nth = params->nth; + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + + spacemit_kernels::quantize_a_row_def quantize_a_row_i8; + spacemit_kernels::gemm_kernel_quantize_def gemm_kernel; + spacemit_kernels::moe_gemm_kernel_quantize_def moe_gemm_kernel_m2; + bool set_kernel_impl = false; + size_t block_stride_a = spacemit_kernels::q8_blk_size(QK4_0); + +#if defined(RISCV64_SPACEMIT_IME2) + if (!set_kernel_impl && (global_spine_env_info.use_ime2)) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(QK4_0, true); + + if constexpr (std::is_same_v || std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if constexpr (INTER_SIZE == 256) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp; + block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true); + set_kernel_impl = true; + } else { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i4; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + set_kernel_impl = true; + } + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8mxfp4; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i5; + set_kernel_impl = true; + } + } +#endif + +#if defined(RISCV64_SPACEMIT_IME1) + if (!set_kernel_impl && (global_spine_env_info.use_ime1)) { + quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8; + + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4; + set_kernel_impl = true; + } + } +#endif + if (!set_kernel_impl) { + GGML_ABORT("no kernel implementation found for the block type"); + } + + const size_t a_k_blks = spacemit_kernels::div_round_up(ne10, a_blk_len); + const size_t b_k_blks = spacemit_kernels::div_round_up(ne10, b_blk_len); + + const size_t nbw1 = a_k_blks * block_stride_a; + const size_t nbw2 = ne11 * nbw1; + const size_t nbw3 = nbw2 * ne12; + const size_t gemm_workspace_size = GGML_PAD(nbw3, alignof(int64_t)); + + const uintptr_t ws_ptr = reinterpret_cast(params->wdata); + auto * quant_a_buffer = reinterpret_cast(ws_ptr); + + if (ne11 == 1) { + for (int64_t ii = ith; ii < ne12 * a_k_blks; ii += nth) { + int64_t i12 = ii / a_k_blks; + int64_t ak_blk_id = ii % a_k_blks; + quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12) + ak_blk_id * a_blk_len, + a_blk_len, quant_a_buffer + i12 * nbw2 + ak_blk_id * block_stride_a); + } + } else { + for (int64_t ii = ith; ii < ne12 * ne11; ii += nth) { + int64_t i12 = ii / ne11; + int64_t i11 = ii % ne11; + quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12 + i11 * nb11), ne10, + quant_a_buffer + i12 * nbw2 + i11 * nbw1); + } + } + +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) *ne12 + (i1)] + + int64_t * matrix_row_counts = (int64_t *) (ws_ptr + gemm_workspace_size); + int32_t * valid_ep_count = (int32_t *) (matrix_row_counts + n_as); + int32_t * valid_act_count = (int32_t *) (valid_ep_count + 1); + int64_t * valid_matrix_row_counts = (int64_t *) (valid_act_count + 1); + mmid_row_mapping * matrix_rows = (mmid_row_mapping *) (valid_matrix_row_counts + n_as); + + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + + // group rows by src0 matrix + for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int32_t id = 0; id < n_ids; ++id) { + const int32_t i02 = + *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 }; + matrix_row_counts[i02] += 1; + } + } + + int32_t valid_ep_count_t = 0; + int32_t valid_act_count_t = 0; + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) { + continue; + } + valid_matrix_row_counts[valid_ep_count_t] = cur_a; + valid_act_count_t += cne1; + valid_ep_count_t += 1; + } + valid_ep_count[0] = valid_ep_count_t; + valid_act_count[0] = valid_act_count_t; + } + + const int64_t barrier_idx = static_cast(ith / 2); + + GGML_ASSERT(global_spine_env_info.init_barrier != nullptr); + GGML_ASSERT(barrier_idx < spine_init_barrier_count); + spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx]; + + ggml_barrier(params->threadpool); + + const size_t row_stride_b = b_k_blks * get_repacked_block_type_size(); + const size_t expert_b_stride = ne01 * row_stride_b; + const size_t per_nb_cols_wsize = NB_COLS * row_stride_b; + + std::array src_workspaces; + std::array dst_workspaces; + + auto * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer; + const auto tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size; + + const auto valid_ep_count_t = valid_ep_count[0]; + const auto valid_act_count_t = valid_act_count[0]; + + int nth_es = 1; + int nth_n = nth; + + int ith_es = ith % nth_es; + int ith_n = (ith / nth_es) % nth_n; + + if (valid_ep_count_t % nth == 0 && tcm_buffer != nullptr && valid_ep_count_t == n_as && + valid_act_count_t == n_as && per_nb_cols_wsize <= tcm_buffer_size) { + for (int64_t valid_id = ith; valid_id < valid_ep_count_t; valid_id += nth) { + const int64_t cur_a = valid_matrix_row_counts[valid_id]; + + auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride; + + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, 0); + const int id = row_mapping.i1; + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; + const int64_t i1 = id; + const int64_t i2 = i12; + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + float * c_blk = (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)); + + uint8_t * a_row = src1_col; + uint8_t * b_col = reinterpret_cast(tcm_buffer); + if ((nbw1 + per_nb_cols_wsize) <= tcm_buffer_size) { + a_row = (uint8_t *) tcm_buffer; + b_col = reinterpret_cast(tcm_buffer) + nbw1; + } + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; + + if (ith % 2 == 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(src0_cur), per_nb_cols_wsize); + + if (a_row != src1_col) { + spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1); + } + } + + spine_barrier_wait(cur_barrier); + + if (ith % 2 != 0) { + if (a_row != src1_col) { + spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1); + } + + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(src0_cur), per_nb_cols_wsize); + } + + int64_t nb_real = std::min(ne01, NB_COLS); + for (int64_t ni = 0; ni < ne01; ni += NB_COLS) { + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } + + gemm_kernel(b_blk_len, a_row, b_col, b_col_zp, c_blk + ni, 1, nb_real, b_k_blks, ne01); + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS; + if (next_ni < ne01) { + nb_real = std::min(ne01 - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d( + b_col, reinterpret_cast(src0_cur) + next_ni * row_stride_b, per_nb_cols_wsize); + } + } + } + } else { + for (int64_t valid_id = ith_es; valid_id < valid_ep_count_t; valid_id += nth_es) { + const int64_t cur_a = valid_matrix_row_counts[valid_id]; + const int64_t cne1 = matrix_row_counts[cur_a]; + + int64_t src1_cur_start = 0; + int64_t src1_cur_end = cne1; + + int64_t src0_cur_start = (ith_n * ne01) / nth_n; + int64_t src0_cur_end = MIN(((ith_n + 1) * ne01) / nth_n, ne01); + + if (src1_cur_start >= src1_cur_end || src0_cur_start >= src0_cur_end) { + continue; + } + + src0_cur_start = + (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; + src0_cur_end = + (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; + + auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride + src0_cur_start * row_stride_b; + uint8_t * b_col_zp = block_type_has_zp() ? src0_cur : nullptr; + + size_t extra_tcm_buffer_size = tcm_buffer_size; + void * extra_tcm_buffer = tcm_buffer; + if (tcm_buffer != nullptr && (src1_cur_end - src1_cur_start) >= 4 && + (src0_cur_end - src0_cur_start) * row_stride_b <= tcm_buffer_size) { + spacemit_kernels::rvv::memcpy1d(tcm_buffer, src0_cur, + (src0_cur_end - src0_cur_start) * row_stride_b); + src0_cur = reinterpret_cast(tcm_buffer); + b_col_zp = block_type_has_zp() ? src0_cur : nullptr; + extra_tcm_buffer_size -= (src0_cur_end - src0_cur_start) * row_stride_b; + extra_tcm_buffer = reinterpret_cast(reinterpret_cast(tcm_buffer) + + (src0_cur_end - src0_cur_start) * row_stride_b); + } + + int ir1 = src1_cur_start; + + if (extra_tcm_buffer_size >= nbw1 && extra_tcm_buffer != nullptr) { + int64_t quant_a_tile_size = extra_tcm_buffer_size / nbw1; + do { + quant_a_tile_size = MIN(quant_a_tile_size, src1_cur_end - ir1); + + uint8_t * quant_a_tile_buffer = reinterpret_cast(extra_tcm_buffer); + + int iir1 = ir1; + for (; iir1 < (ir1 + quant_a_tile_size); ++iir1) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, iir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + spacemit_kernels::rvv::memcpy1d(quant_a_tile_buffer, src1_col, nbw1); + quant_a_tile_buffer = quant_a_tile_buffer + nbw1; + } + + quant_a_tile_buffer = reinterpret_cast(extra_tcm_buffer); + iir1 = ir1; + + if (moe_gemm_kernel_m2 != nullptr) { + for (; iir1 < (ir1 + quant_a_tile_size - 1); iir1 += 2, quant_a_tile_buffer += 2 * nbw1) { + mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1); + mmid_row_mapping row_mapping_1 = MMID_MATRIX_ROW(cur_a, iir1 + 1); + + src_workspaces[0] = quant_a_tile_buffer; + src_workspaces[1] = quant_a_tile_buffer + nbw1; + + dst_workspaces[0] = + (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) + + src0_cur_start; + dst_workspaces[1] = (float *) ((char *) dst->data + + ((row_mapping_1.i1) * nb1 + (row_mapping_1.i2) * nb2)) + + src0_cur_start; + moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp, + dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, + ne01); + } + } + + for (; iir1 < (ir1 + quant_a_tile_size); iir1++, quant_a_tile_buffer += nbw1) { + mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1); + + gemm_kernel( + b_blk_len, quant_a_tile_buffer, src0_cur, b_col_zp, + (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) + + src0_cur_start, + 1, src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + + ir1 += quant_a_tile_size; + } while (ir1 < src1_cur_end); + } else { + if (moe_gemm_kernel_m2 != nullptr) { + for (; ir1 < src1_cur_end - 1; ir1 += 2) { + for (int iir1 = 0; iir1 < 2; ++iir1) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1 + iir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + src_workspaces[iir1] = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + + dst_workspaces[iir1] = + (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start; + } + + moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp, + dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + } + + for (; ir1 < src1_cur_end; ir1++) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + + gemm_kernel(b_blk_len, src1_col, src0_cur, b_col_zp, + (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, 1, + src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + } + } + } +#undef MMID_MATRIX_ROW + } + + int repack(ggml_tensor * t, const void * data, size_t data_size) override { GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), (int) NB_COLS, (int) INTER_SIZE); return ggml::cpu::riscv64_spacemit::repack(t, data, data_size); @@ -563,309 +936,464 @@ template class tensor_ }; class tensor_traits_common : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + bool work_size(int n_threads, const ggml_tensor * op, size_t & size) override { switch (op->op) { - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - size = 0; + case GGML_OP_FLASH_ATTN_EXT: + { + const int n_tasks = n_threads; + const int64_t neq2 = op->src[0]->ne[2]; // number of query heads + const int64_t DK = op->src[1]->ne[0]; + const int64_t DV = op->src[2]->ne[0]; // DV + + // Tiled flash attention scratch (tile sizes defined in common.h) + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding + size_t prefill = sizeof(float) * + (GGML_FA_TILE_Q * DK + 2 * GGML_FA_TILE_Q * GGML_FA_TILE_KV + GGML_FA_TILE_Q * DV + + GGML_FA_TILE_KV * DV + GGML_FA_TILE_KV * DK) * + n_tasks; + + // Decode path: n_kv_chunks = n_tasks (one chunk per thread) + // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ + size_t n_chunks = n_tasks; + size_t decode = sizeof(float) * (neq2 * n_chunks * (2 + DV) + n_tasks * (DK + 2 * DV)); + + size = MAX(prefill, decode); + } return true; default: - // GGML_ABORT("fatal error"); break; } return false; } - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override { switch (op->op) { case GGML_OP_NORM: - forward_norm_f32(params, op); - return true; + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_norm_f32(params, op); + return true; + default: + GGML_ABORT("fatal error"); + } case GGML_OP_RMS_NORM: - forward_rms_norm_f32(params, op); + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_rms_norm_f32(params, op); + return true; + default: + GGML_ABORT("fatal error"); + } + case GGML_OP_ADD: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_add(params, op); + return true; + } + case GGML_OP_SUB: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_sub(params, op); + return true; + } + case GGML_OP_MUL: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_mul(params, op); + return true; + } + case GGML_OP_DIV: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_div(params, op); + return true; + } + case GGML_OP_FLASH_ATTN_EXT: + forward_flash_attn_ext_f16(params, op); return true; + case GGML_OP_CONT: + { + const ggml_tensor * src0 = op->src[0]; + if (op->type == src0->type && op->nb[0] != src0->nb[0] && op->nb[0] == src0->nb[1] && + op->ne[3] * op->ne[2] * op->nb[2] == src0->ne[3] * src0->ne[2] * src0->nb[2]) { + spacemit_kernels::rvv::forward_cont_with_permute(params, op); + } else { + ggml_compute_forward_cont(params, op); + } + return true; + } + case GGML_OP_CPY: + { + const ggml_tensor * src0 = op->src[0]; + if (op->type == src0->type && op->nb[0] == src0->nb[1] && src0->nb[0] != src0->nb[1] && + ggml_nelements(src0) == ggml_nelements(op)) { + spacemit_kernels::rvv::forward_cpy_with_permute(params, op); + } else { + ggml_compute_forward_cpy(params, op); + } + return true; + } + case GGML_OP_REPEAT: + { + const bool rows_equal = ggml_nrows(op->src[0]) == ggml_nrows(op); + const bool broadcast_or_equal = op->src[0]->ne[0] == 1 || op->src[0]->ne[0] == op->ne[0]; + + if (rows_equal && broadcast_or_equal) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_repeat_nrows(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_repeat_nrows(params, op); + return true; + default: + break; + } + } + + if (op->src[0]->ne[1] == 1 && op->src[0]->ne[0] == op->ne[0]) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_repeat_dim1(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_repeat_dim1(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_repeat(params, op); + } + return true; + case GGML_OP_SUM_ROWS: + { + if (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) { + spacemit_kernels::rvv::forward_sum_rows(params, op); + } else { + ggml_compute_forward_sum_rows(params, op); + } + } + return true; + case GGML_OP_GET_ROWS: + { + if (op->src[0]->type == op->type) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_get_rows(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_get_rows(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_get_rows(params, op); + } + return true; + case GGML_OP_CONCAT: + { + const int32_t dim = ggml_get_op_params_i32(op, 0); + if (dim == 0 && op->type == op->src[0]->type) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_concat(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_concat(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_concat(params, op); + } + return true; + // TODO For GGML_OP_GATED_DELTA_NET + // case GGML_OP_GATED_DELTA_NET: + // return true; default: - // GGML_ABORT("fatal error"); break; } return false; } - void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - ggml_tensor * dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); + void forward_flash_attn_ext_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + + const bool supported_prec = (dst->op_params[3] == GGML_PREC_F32 || dst->op_params[3] == GGML_PREC_DEFAULT); + const bool supported_types = (q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16); + const bool supported_shape = (DK > 0 && DK <= 128 && DV > 0 && DV <= 128); + const bool supported_vlen = (__riscv_vlenb() == 128); + + if (!(supported_prec && supported_types && supported_shape && supported_vlen)) { + ggml_compute_forward_flash_attn_ext(params, dst); + return; + } + + // total rows in q + const int64_t nr = neq1 * neq2 * neq3; + + // rows per thread const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS + static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + const bool use_tiled = !params->use_ref && (neq1 >= Q_TILE_SZ); - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); + // 4x chunks per thread + // int nth_scaled = nth * 4; + // int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + // int64_t nchunk = (nr + chunk_size - 1) / chunk_size; - GGML_ASSERT(epsilon > 0.0f); + // if (nth == 1 || nchunk < nth) { + // nchunk = nth; + // } - auto * input = (float *) src0->data; - auto * output = (float *) dst->data; + int64_t nchunk = nth; - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + ggml_threadpool_chunk_set(params->threadpool, nth); + } - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + ggml_barrier(params->threadpool); - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto * p_input = const_cast(input + offset); + // The number of elements in each chunk + const int64_t dr = (nr + nchunk - 1) / nchunk; - auto * p_output = output + offset; - auto * p_temp_output = p_output; - auto * p_gamma_data = (const float *) nullptr; - auto * p_beta_data = (const float *) nullptr; - size_t gvl = __riscv_vsetvlmax_e32m4(); - vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + // The first chunk comes from our thread_id, the rest will get auto-assigned. + int current_chunk = ith; - sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); - - p_input += gvl; - p_temp_output += gvl; - length -= gvl; + if (use_tiled) { + spacemit_kernels::rvv::forward_flash_attn_ext_f16_tiled_vlen1024_vf16( + params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer, + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size); + } else { + spacemit_kernels::rvv::forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16( + params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer, + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size); } - gvl = __riscv_vsetvlmax_e32m1(); - - float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - vfloat32m1_t mean_v = - __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); - mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); - mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); - mean /= hidden_size; - - vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), - __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; - mean_square = sqrt(mean_square - mean * mean + epsilon); - - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; - - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; - } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } } - void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - ggml_tensor * dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); - - GGML_ASSERT(epsilon > 0.0f); - - auto * input = (float *) src0->data; - auto * output = (float *) dst->data; - - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); - - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto * p_input = const_cast(input + offset); - auto * p_output = output + offset; - auto * p_temp_output = p_output; - auto * p_gamma_data = (const float *) nullptr; - auto * p_beta_data = (const float *) nullptr; - - size_t gvl = __riscv_vsetvlmax_e32m4(); - // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); - - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); - - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); - - p_input += gvl; - p_temp_output += gvl; - length -= gvl; - } - - gvl = __riscv_vsetvlmax_e32m1(); - - // float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - - vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), - __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; - - mean_square = sqrt(mean_square + epsilon); - - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; - - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; - } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } - } - } - - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + int repack(ggml_tensor * t, const void * data, size_t data_size) override { memcpy(t->data, data, data_size); return 0; } }; -static const tensor_traits q4_0_16x8_q8_0; -static const tensor_traits q4_1_16x8_q8_0; -static const tensor_traits q4_k_16x8_q8_0; -static const tensor_traits_common rvv_impl; +// Impl By IME1 +static const tensor_traits q4_0_16x32_q8_0; +static const tensor_traits q4_1_16x32_q8_0; +static const tensor_traits q4_k_16x32_q8_0; +// Impl By IME2 +static const tensor_traits q2_k_32x256_q8_0; +static const tensor_traits q3_k_32x256_q8_0; +static const tensor_traits q4_0_32x32_q8_0; +static const tensor_traits q4_1_32x32_q8_0; +static const tensor_traits q4_0_32x256_q8_0; +static const tensor_traits q4_1_32x256_q8_0; +static const tensor_traits q4_k_32x32_q8_0; +static const tensor_traits q6_k_32x32_q8_0; +static const tensor_traits q8_0_32x32_q8_0; +static const tensor_traits mxfp4_32x32_q8_0; +static const tensor_traits q5_k_32x32_q8_0; +static const tensor_traits q5_1_32x32_q8_0; +static const tensor_traits q5_0_32x32_q8_0; +// Impl By RVV +static const tensor_traits_common rvv_impl; } // namespace ggml::cpu::riscv64_spacemit -static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) { - if (cur->type == GGML_TYPE_Q4_0) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_1) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_K) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_F32) { - return &ggml::cpu::riscv64_spacemit::rvv_impl; +static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const ggml_tensor * cur) { + switch (cur->type) { + case GGML_TYPE_Q2_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q2_k_32x256_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q3_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q3_k_32x256_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_0_32x256_q8_0; + } + + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_0_32x32_q8_0; + } +#endif + +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_0_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_1: + { +#if defined(RISCV64_SPACEMIT_IME2) + // TODO + // if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + // (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + // return &ggml::cpu::riscv64_spacemit::q4_1_32x256_q8_0; + // } + + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_1_32x32_q8_0; + } +#endif + +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_1_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_k_32x32_q8_0; + } +#endif + +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_k_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q6_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q6_k_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q8_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q8_0_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_MXFP4: + { +#if defined(RISCV64_SPACEMIT_IME2) + // TODO + // if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + // return &ggml::cpu::riscv64_spacemit::mxfp4_32x32_q8_0; + // } +#endif + } + break; + case GGML_TYPE_Q5_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_k_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q5_1: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_1_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q5_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_0_32x32_q8_0; + } +#endif + } + break; + default: + break; } return nullptr; } static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer, - struct ggml_tensor * tensor) { + ggml_tensor * tensor) { tensor->extra = (void *) const_cast(ggml_riscv64_spacemit_get_optimal_repack_type(tensor)); @@ -874,8 +1402,46 @@ static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_ba return GGML_STATUS_SUCCESS; } +static void ggml_backend_riscv64_spacemit_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + if (base == nullptr) { + return; + } + + ggml::cpu::riscv64_spacemit::spine_mem_pool_free(base); +} + +static void * ggml_backend_riscv64_spacemit_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + GGML_ASSERT(base != nullptr); + return base; +} + +static void ggml_backend_riscv64_spacemit_buffer_memset_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + uint8_t value, + size_t offset, + size_t size) { + GGML_ASSERT(tensor); + memset((char *) tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_riscv64_spacemit_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + GGML_ASSERT(base != nullptr); + memset(base, value, buffer->size); +} + static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer, - struct ggml_tensor * tensor, + ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -891,6 +1457,20 @@ static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_ GGML_UNUSED(buffer); } +static const ggml_backend_buffer_i ggml_backend_riscv64_spacemit_buffer_i = { + /* .free_buffer = */ ggml_backend_riscv64_spacemit_buffer_free_buffer, + /* .get_base = */ ggml_backend_riscv64_spacemit_buffer_get_base, + /* .init_tensor = */ ggml_backend_riscv64_spacemit_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_riscv64_spacemit_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_riscv64_spacemit_buffer_set_tensor, + /* .get_tensor = */ nullptr, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_riscv64_spacemit_buffer_clear, + /* .reset = */ nullptr, +}; + static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) { return "CPU_RISCV64_SPACEMIT"; @@ -899,18 +1479,12 @@ static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_ static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); - - if (buffer == nullptr) { + void * base = ggml::cpu::riscv64_spacemit::spine_mem_pool_alloc(size, 64); + if (base == nullptr) { return nullptr; } - buffer->buft = buft; - buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor; - buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor; - buffer->iface.get_tensor = nullptr; - buffer->iface.cpy_tensor = nullptr; - return buffer; + return ggml_backend_buffer_init(buft, ggml_backend_riscv64_spacemit_buffer_i, base, size); } static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { @@ -919,44 +1493,91 @@ static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_b GGML_UNUSED(buft); } -static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, - const struct ggml_tensor * tensor) { +static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { for (int i = 0; i < GGML_MAX_DIMS; ++i) { if (tensor->ne[i] <= 0) { return 0; } } - size_t nbytes; + GGML_UNUSED(buft); + + const auto plain_nbytes = [&]() { + size_t total = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + total += (tensor->ne[i] - 1) * tensor->nb[i]; + } + return total; + }; + const size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { - nbytes = ggml_type_size(tensor->type); - for (int i = 0; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; - } - } else { - nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; - if (tensor->type == GGML_TYPE_Q4_K) { - GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0); - nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - } - } else { - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; - } - } + return plain_nbytes(); + } + + const size_t row_nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; + + const auto add_strided_nbytes = [&](size_t total, size_t src_block_size, size_t dst_block_size) { + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + total += (tensor->ne[i] - 1) * (tensor->nb[i] / src_block_size) * dst_block_size; + } + return total; + }; + + const auto remap_block_nbytes = [&](size_t src_block_size, size_t dst_block_size, int64_t padded_rows = 0) { + GGML_ASSERT(row_nbytes % src_block_size == 0); + + size_t total = + add_strided_nbytes((row_nbytes / src_block_size) * dst_block_size, src_block_size, dst_block_size); + + if (padded_rows > 0 && tensor->ne[1] % padded_rows != 0) { + total += (padded_rows - tensor->ne[1] % padded_rows) * (tensor->nb[1] / src_block_size) * dst_block_size; + } + + return total; + }; + + size_t nbytes = row_nbytes; + switch (tensor->type) { + case GGML_TYPE_Q4_K: + nbytes = remap_block_nbytes(sizeof(block_q4_K), sizeof(block_q4_1) * 8); + break; + case GGML_TYPE_Q6_K: + nbytes = remap_block_nbytes(sizeof(block_q6_K), sizeof(block_q8_0) * 8, 32); + break; + case GGML_TYPE_Q8_0: + nbytes = remap_block_nbytes(sizeof(block_q8_0), sizeof(block_q8_0), 32); + break; + case GGML_TYPE_Q2_K: + nbytes = remap_block_nbytes(sizeof(block_q2_K), sizeof(spacemit_kernels::nrow_block_q2_k<1>)); + break; + case GGML_TYPE_Q3_K: + nbytes = remap_block_nbytes(sizeof(block_q3_K), sizeof(spacemit_kernels::nrow_block_q3_k<1>)); + break; + case GGML_TYPE_MXFP4: + nbytes = remap_block_nbytes(sizeof(block_mxfp4), sizeof(spacemit_kernels::nrow_block_mxfp4<1>)); + break; + case GGML_TYPE_Q5_K: + nbytes = remap_block_nbytes(sizeof(block_q5_K), sizeof(spacemit_kernels::nrow_block_q5_1<1>) * 8); + break; + case GGML_TYPE_Q5_1: + nbytes = remap_block_nbytes(sizeof(block_q5_1), sizeof(spacemit_kernels::nrow_block_q5_1<1>)); + break; + case GGML_TYPE_Q5_0: + nbytes = remap_block_nbytes(sizeof(block_q5_0), sizeof(spacemit_kernels::nrow_block_q5_0<1>)); + break; + default: + nbytes = add_strided_nbytes(row_nbytes, 1, 1); + break; } - GGML_UNUSED(buft); return nbytes; } namespace ggml::cpu::riscv64_spacemit { class extra_buffer_type : ggml::cpu::extra_buffer_type { - bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + bool supports_op(ggml_backend_dev_t, const ggml_tensor * op) override { switch (op->op) { case GGML_OP_MUL_MAT: if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && @@ -970,10 +1591,16 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } } break; - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - if (op->src[0]->type == GGML_TYPE_F32) { - return true; + case GGML_OP_MUL_MAT_ID: + if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 3) && + op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() && + ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } } break; default: @@ -983,15 +1610,28 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { return false; } - ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + ggml::cpu::tensor_traits * get_tensor_traits(const ggml_tensor * op) override { switch (op->op) { case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_REPEAT: + case GGML_OP_SUM_ROWS: + case GGML_OP_GET_ROWS: + case GGML_OP_CONCAT: + // case GGML_OP_GATED_DELTA_NET: return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl); default: // GGML_ABORT("fatal error"); @@ -1005,7 +1645,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } // namespace ggml::cpu::riscv64_spacemit ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { + static ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { /* .iface = */ { /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name, @@ -1023,3 +1663,78 @@ ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { return &ggml_backend_cpu_buffer_type_riscv64_spacemit; } + +extern "C" { +static int bind_ai_thread() { + int fd, bytes; + char str[32]; + + fd = open("/proc/set_ai_thread", O_WRONLY); + if (fd < 0) { + GGML_LOG_ERROR("try open /proc/set_ai_thread failed\n"); + return -1; + } + + snprintf(str, 16, "%d", 0); + bytes = write(fd, str, strlen(str)); + if (bytes < 0) { + GGML_LOG_ERROR("try write /proc/set_ai_thread failed\n"); + close(fd); + return -1; + } + + close(fd); + return 0; +} + +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n) { + int cpu_id = sched_getcpu(); + if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2 && + !((1 << cpu_id) & ggml::cpu::riscv64_spacemit::global_spine_env_info.cpu_mask)) { + GGML_PRINT_DEBUG("bind_ai_thread for thread %d, pid %d\n", thread_n, getpid()); + bind_ai_thread(); + } + + if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_tcm && + ggml::cpu::riscv64_spacemit::tls_context.cpu_id == -1) { + CPU_ZERO(&(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + pthread_t main_thread = pthread_self(); + const auto & perfer_core_ids = ggml::cpu::riscv64_spacemit::global_spine_env_info.perfer_core_ids; + if (thread_n < 0 || static_cast(thread_n) >= perfer_core_ids.size()) { + GGML_ABORT("thread_n %d exceeds perfer_core_ids size %zu\n", thread_n, perfer_core_ids.size()); + } + auto perfer_cpu_id = perfer_core_ids[static_cast(thread_n)]; + CPU_SET(perfer_cpu_id, &(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + int s = + pthread_setaffinity_np(main_thread, sizeof(cpu_set_t), &(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + if (s != 0) { + GGML_ABORT("set thread affinity error for thread_n %d, cpu_id %d\n", thread_n, perfer_cpu_id); + } + + int ai_cpu_id = perfer_cpu_id - ggml::cpu::riscv64_spacemit::global_spine_env_info.aicpu_id_offset; + ggml::cpu::riscv64_spacemit::tls_context.cpu_id = ai_cpu_id; + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer = + ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_get(ai_cpu_id); + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size = + ggml::cpu::riscv64_spacemit::global_spine_env_info.tcm_blk_size; + } + + if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) { + void * rt = + ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_wait(ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + if (rt == nullptr) { + GGML_ABORT("wait tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + } + } +} + +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n) { + if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) { + auto rt = ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_release( + ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + if (rt != 0) { + GGML_ABORT("release tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + } + } +} +} diff --git a/ggml/src/ggml-cpu/spacemit/ime.h b/ggml/src/ggml-cpu/spacemit/ime.h index 800d91acd..6849dd95e 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.h +++ b/ggml/src/ggml-cpu/spacemit/ime.h @@ -8,6 +8,14 @@ extern "C" { ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n); + +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n); + +void * ggml_backend_cpu_riscv64_spacemit_alloc_shared(size_t size, size_t alignment); + +void ggml_backend_cpu_riscv64_spacemit_free_shared(void * ptr); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp index cbbb6cd91..6acc6819d 100644 --- a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp @@ -1,8 +1,26 @@ +#include "ggml-impl.h" #include "ggml.h" #include "ime_kernels.h" +#include "rvv_kernels.h" #include #include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) +#else +# error "RISCV64_SPACEMIT_IME1 not defined" +#endif // clang-format off #if defined(__GNUC__) @@ -11,7 +29,7 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif // clang-format on -namespace sqnbitgemm_spacemit_ime { +namespace spacemit_kernels { #define QUANTIZEM4ROW_KERNEL \ "vmv.s.x v16, zero \n\t" \ @@ -76,1093 +94,208 @@ namespace sqnbitgemm_spacemit_ime { "vse8.v v31, (s1) \n\t" namespace ime1 { -void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { +void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) { constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); const float fone = 1.0f; - if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float * SRC = A + row_index * CountK; + uint8_t * DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" - "LOOP%=: \n\t" - "vsetvli t0, %[BlkLen], e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "sub t2, t2, t0 \n\t" - "slli t1, t0, 2 \n\t" - "add %[SRC], %[SRC], t1 \n\t" - "add s1, a1, %[OFFSET] \n\t" + "LOOP%=: \n\t" + "vsetvli t0, %[BlkLen], e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "sub t2, t2, t0 \n\t" + "slli t1, t0, 2 \n\t" + "add %[SRC], %[SRC], t1 \n\t" + "add s1, a1, %[OFFSET] \n\t" - QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE + QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "add s1, a1, %[OFFSET] \n\t" + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "add s1, a1, %[OFFSET] \n\t" - QUANTIZEM4ROW_KERNEL + QUANTIZEM4ROW_KERNEL - "addi t3, %[BlkLen], 0 \n\t" - "addi s2, s1, 0 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "SET_ZERO%=: \n\t" - "vse8.v v8, (s2) \n\t" - "addi s2, s2, 32 \n\t" - "addi t3, t3, -8 \n\t" - "bnez t3, SET_ZERO%= \n\t" + "addi t3, %[BlkLen], 0 \n\t" + "addi s2, s1, 0 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "SET_ZERO%=: \n\t" + "vse8.v v8, (s2) \n\t" + "addi s2, s2, 32 \n\t" + "addi t3, t3, -8 \n\t" + "bnez t3, SET_ZERO%= \n\t" - QUANTIZEM4ROW_STORE + QUANTIZEM4ROW_STORE - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); - } - } else if (BlkLen == 128) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); - - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "li t6, 32 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "addi t2, t2, -128 \n\t" - - "QUANTIZE%=: \n\t" - "add s1, a1, %[OFFSET] \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vfmax.vv v16, v24, v16 \n\t" - "vfredmax.vs v24, v16, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e64, m4 \n\t" - "vsse64.v v16, (s1), t6 \n\t" - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "sub t2, t2, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "sub t2, t2, t2 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "jal x0, QUANTIZE%= \n\t" - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); - } - } else if (BlkLen == 256) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "li t6, 32 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], -768 \n\t" - "addi t2, t2, -256 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v24, v24, v16 \n\t" - "vfmax.vv v8, v8, v24 \n\t" - "vfredmax.vs v24, v8, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - - "QUANTIZE%=: \n\t" - "add s1, a1, %[OFFSET] \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfmul.vf v8, v8, f11 \n\t" - "vfmul.vf v16, v16, f11 \n\t" - "vfmul.vf v24, v24, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vfcvt.x.f.v v8, v8 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vnclip.wx v8, v16, zero \n\t" - "vnclip.wx v12, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vsetvli t0, zero, e64, m8 \n\t" - "vsse64.v v0, (s1), t6 \n\t" - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t1, t2, 0 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], -768 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v24, v16, v24 \n\t" - "vfmax.vv v8, v8, v24 \n\t" - "vfredmax.vs v24, v8, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e64, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsse64.v v0, (s1), t6 \n\t" - - "TAIL_LOOP%=: \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t0, t2, e32, m1 \n\t" - "sub t2, t2, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 32 \n\t" - "vfmul.vf v1, v0, f11 \n\t" - "vfcvt.x.f.v v2, v1 \n\t" - "vsetvli t0, zero, e16, mf2 \n\t" - "vnclip.wx v3, v2, zero \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vnclip.wx v3, v3, zero \n\t" - "vse8.v v3, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "bnez t2, TAIL_LOOP%= \n\t" - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); - } + "QUIT%=: \n\t" + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), [CountK] "r"(CountK), + [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); } } -void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { +void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) { const float * SRC = A; - std::byte * DST = QuantA; + uint8_t * DST = QuantA; constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); const float fone = 1.0f; - std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); + uint8_t * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK; - if (CountK <= BlkLen) { - float max_abs_A = 0.0f; - for (size_t k = 0; k < CountK; k++) { - max_abs_A = std::max(max_abs_A, fabsf(A[k])); - } - float scale_A = max_abs_A * range_max_reciprocal; + __asm__ volatile( + "addi t3, zero, 32*4 \n\t" + "addi t2, zero, 32 \n\t" - ((float *) QuantA)[0] = scale_A; + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 128 \n\t" + "addi a3, %[SRC], 256 \n\t" + "addi a4, %[SRC], 384 \n\t" - auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float)); + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 36 \n\t" + "addi s3, %[DST], 72 \n\t" + "addi s4, %[DST], 108 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" - for (size_t k = 0; k < CountK; k++) { - QuantAData_offset[k] = - (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits::lowest(), - (float) std::numeric_limits::max()); - } - for (size_t k = CountK; k < BlkLen; k++) { - QuantAData_offset[k] = 0; - } + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v4, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vle32.v v8, (a3) \n\t" + "addi a3, a3, 512 \n\t" + "vle32.v v12, (a4) \n\t" + "addi a4, a4, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v28, v12 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v20, v20, v22 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vfmax.vv v28, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v28, v28, v29 \n\t" - return; - } + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v21, v20, v21 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfredmax.vs v29, v28, v29 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v21 \n\t" + "vfmv.f.s f12, v25 \n\t" + "vfmv.f.s f13, v29 \n\t" - if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) { - __asm__ volatile( - "vsetvli t0, zero, e8, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "LOOP%=: \n\t" - "vsetvli t0, %[CNT], e8, m8 \n\t" - "vse8.v v24, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "sub %[CNT], %[CNT], t0 \n\t" - "bnez %[CNT], LOOP%= \n\t" - : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset) - : - : "cc", "t0"); - } - if (BlkLen == 16) { - float buffer[64] = { 0.0f }; - __asm__ volatile( - "addi t3, zero, 16*8 \n\t" - "addi t2, zero, 16 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m2 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v2, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v4, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v6, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v10, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v12, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v14, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "addi a1, %[BUFFER], 0 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v18, v2 \n\t" - "vfabs.v v20, v4 \n\t" - "vfabs.v v22, v6 \n\t" - "vfabs.v v24, v8 \n\t" - "vfabs.v v26, v10 \n\t" - "vfabs.v v28, v12 \n\t" - "vfabs.v v30, v14 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v18, v18, v19 \n\t" - "vfmax.vv v20, v20, v21 \n\t" - "vfmax.vv v22, v22, v23 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfmax.vv v26, v26, v27 \n\t" - "vfmax.vv v28, v28, v29 \n\t" - "vfmax.vv v30, v30, v31 \n\t" - "vse32.v v16, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v18, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v20, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v22, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v24, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v26, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v28, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v30, (a1) \n\t" - "addi a1, %[BUFFER], 0 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f11, f3, f7 \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fsw f11, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f12, f3, f7 \n\t" - "fmul.s f12, f12, %[RMAXREC] \n\t" - "fsw f12, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f12, %[FONE], f12 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f13, f3, f7 \n\t" - "fmul.s f13, f13, %[RMAXREC] \n\t" - "fsw f13, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f13, %[FONE], f13 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f14, f3, f7 \n\t" - "fmul.s f14, f14, %[RMAXREC] \n\t" - "fsw f14, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f14, %[FONE], f14 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f15, f3, f7 \n\t" - "fmul.s f15, f15, %[RMAXREC] \n\t" - "fsw f15, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f15, %[FONE], f15 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f16, f3, f7 \n\t" - "fmul.s f16, f16, %[RMAXREC] \n\t" - "fsw f16, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f16, %[FONE], f16 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f17, f3, f7 \n\t" - "fmul.s f17, f17, %[RMAXREC] \n\t" - "fsw f17, (%[DST]) \n\t" - "addi %[DST], %[DST], -136 \n\t" - "fdiv.s f17, %[FONE], f17 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v18, v2, f11 \n\t" - "vfmul.vf v20, v4, f12 \n\t" - "vfmul.vf v22, v6, f13 \n\t" - "vfmul.vf v24, v8, f14 \n\t" - "vfmul.vf v26, v10, f15 \n\t" - "vfmul.vf v28, v12, f16 \n\t" - "vfmul.vf v30, v14, f17 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v18, v18 \n\t" - "vfcvt.x.f.v v20, v20 \n\t" - "vfcvt.x.f.v v22, v22 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vfcvt.x.f.v v26, v26 \n\t" - "vfcvt.x.f.v v28, v28 \n\t" - "vfcvt.x.f.v v30, v30 \n\t" - "vsetvli t0, zero, e16, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v18, v18, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v22, v22, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v26, v26, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vnclip.wx v30, v30, zero \n\t" - "vsetvli t0, t1, e8, mf2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v18, v18, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v22, v22, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v26, v26, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vnclip.wx v30, v30, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v18, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v20, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v22, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v24, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v26, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v28, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v30, (%[DST]) \n\t" - "addi %[DST], %[DST], 16 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vse32.v v16, (%[BUFFER]) \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, t1, e8, mf2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 16 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m2 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer) - : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12", - "f13", "f14", "f15", "f16", "f17"); - } else if (BlkLen == 32) { - __asm__ volatile( - "addi t3, zero, 32*4 \n\t" - "addi t2, zero, 32 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 128 \n\t" - "addi a3, %[SRC], 256 \n\t" - "addi a4, %[SRC], 384 \n\t" + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fsw f12, (s3) \n\t" + "addi s3, s3, 4 \n\t" + "fsw f13, (s4) \n\t" + "addi s4, s4, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v20, v4, f11 \n\t" + "vfmul.vf v24, v8, f12 \n\t" + "vfmul.vf v28, v12, f13 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vsetvli t0, t1, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 140 \n\t" + "vse8.v v20, (s2) \n\t" + "addi s2, s2, 140 \n\t" + "vse8.v v24, (s3) \n\t" + "addi s3, s3, 140 \n\t" + "vse8.v v28, (s4) \n\t" + "addi s4, s4, 140 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m4 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 128 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" - "addi s1, %[DST], 0 \n\t" - "addi s2, %[DST], 36 \n\t" - "addi s3, %[DST], 72 \n\t" - "addi s4, %[DST], 108 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m4 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v4, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "vle32.v v8, (a3) \n\t" - "addi a3, a3, 512 \n\t" - "vle32.v v12, (a4) \n\t" - "addi a4, a4, 512 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v20, v4 \n\t" - "vfabs.v v24, v8 \n\t" - "vfabs.v v28, v12 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vfmax.vv v20, v20, v22 \n\t" - "vfmax.vv v24, v24, v26 \n\t" - "vfmax.vv v28, v28, v30 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v20, v20, v21 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfmax.vv v28, v28, v29 \n\t" - - "vfredmax.vs v17, v16, v17 \n\t" - "vfredmax.vs v21, v20, v21 \n\t" - "vfredmax.vs v25, v24, v25 \n\t" - "vfredmax.vs v29, v28, v29 \n\t" - "vfmv.f.s f10, v17 \n\t" - "vfmv.f.s f11, v21 \n\t" - "vfmv.f.s f12, v25 \n\t" - "vfmv.f.s f13, v29 \n\t" - - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fmul.s f12, f12, %[RMAXREC] \n\t" - "fmul.s f13, f13, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - - "fsw f11, (s2) \n\t" - "addi s2, s2, 4 \n\t" - "fsw f12, (s3) \n\t" - "addi s3, s3, 4 \n\t" - "fsw f13, (s4) \n\t" - "addi s4, s4, 4 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "fdiv.s f12, %[FONE], f12 \n\t" - "fdiv.s f13, %[FONE], f13 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v20, v4, f11 \n\t" - "vfmul.vf v24, v8, f12 \n\t" - "vfmul.vf v28, v12, f13 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v20, v20 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vfcvt.x.f.v v28, v28 \n\t" - "vsetvli t0, zero, e16, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vsetvli t0, t1, e8, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 140 \n\t" - "vse8.v v20, (s2) \n\t" - "addi s2, s2, 140 \n\t" - "vse8.v v24, (s3) \n\t" - "addi s3, s3, 140 \n\t" - "vse8.v v28, (s4) \n\t" - "addi s4, s4, 140 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m4 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 128 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfmv.f.s f10, v17 \n\t" - - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m4 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) - : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); - } else if (BlkLen == 64) { - __asm__ volatile( - "addi t3, zero, 64*2 \n\t" - "addi t2, zero, 64 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 256 \n\t" - "addi s1, %[DST], 0 \n\t" - "addi s2, %[DST], 68 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v8, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v16, v16, v20 \n\t" - "vfmax.vv v24, v24, v28 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vfmax.vv v24, v24, v26 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfredmax.vs v25, v24, v25 \n\t" - "vfmv.f.s f10, v17 \n\t" - "vfmv.f.s f11, v25 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fsw f11, (s2) \n\t" - "addi s2, s2, 4 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vsetvli t0, t1, e8, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 132 \n\t" - "vse8.v v24, (s2) \n\t" - "addi s2, s2, 132 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 256 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v16, v16, v20 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfmv.f.s f10, v17 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e8, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [K] "+r"(CountK) - : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11"); - } else if (BlkLen == 128) { - __asm__ volatile( - "addi t2, zero, 128 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 256 \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v8, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "sub %[K], %[K], t2 \n\t" - "QUANT%=: \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vfmax.vv v24, v16, v24 \n\t" - "vsetvli t1, zero, e32, m4 \n\t" - "vfmax.vv v28, v24, v28 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v30, v28, v30 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v30, v30, v31 \n\t" - "vfredmax.vs v31, v30, v31 \n\t" - "vfmv.f.s f10, v31 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vsetvli t0, %[K], e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "sub %[K], %[K], t0 \n\t" - "vsetvli t0, %[K], e32, m8 \n\t" - "vle32.v v8, (a2) \n\t" - "sub %[K], %[K], t0 \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "jal x0, QUANT%= \n\t" - "END%=: \n\t" - - : [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC) - : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11"); - } else { - float buffer[8] = { 0.0f }; - size_t cnt = BlkLen / 256; - - __asm__ volatile( - "slli t3, %[BLK], 2 \n\t" - "blt %[K], %[BLK], LOOP_TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vxor.vv v31, v31, v31 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "addi t6, %[CNT], 0 \n\t" - "LOOP_CMP%=: \n\t" - "addi t6, t6, -1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v16, v16, v24 \n\t" - "vfmax.vv v0, v0, v16 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v0, v0, v4 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v0, v0, v2 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v0, v0, v1 \n\t" - "vle32.v v30, (%[BUFFER]) \n\t" - "vfmax.vv v31, v30, v0 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "bnez t6, LOOP_CMP%= \n\t" - "sub %[SRC], %[SRC], t3 \n\t" - "addi t6, %[CNT], 0 \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "addi t6, %[CNT], 0 \n\t" - "LOOP_QUANT%=: \n\t" - "addi t6, t6, -1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfmul.vf v8, v8, f11 \n\t" - "vfmul.vf v16, v16, f11 \n\t" - "vfmul.vf v24, v24, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vfcvt.x.f.v v8, v8 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vnclip.wx v8, v16, zero \n\t" - "vnclip.wx v12, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vse8.v v0, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "vse8.v v4, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "bnez t6, LOOP_QUANT%= \n\t" - "sub %[K], %[K], %[BLK] \n\t" - "bge %[K], %[BLK], LOOP_MAIN%= \n\t" - "blez %[K], END%= \n\t" - "LOOP_TAIL%=: \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vxor.vv v31, v31, v31 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "addi t6, %[K], 0 \n\t" - "addi s1, %[SRC], 0 \n\t" - "TAIL_CMP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t0, t6, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "sub t6, t6, t0 \n\t" - "vfabs.v v0, v0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v0, v0, v4 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v0, v0, v2 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v0, v0, v1 \n\t" - "vle32.v v30, (%[BUFFER]) \n\t" - "vfmax.vv v31, v30, v0 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "bnez t6, TAIL_CMP%= \n\t" - "addi t6, %[K], 0 \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "addi t6, %[K], 0 \n\t" - "TAIL_QUANT%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t1, t6, e32, m8 \n\t" - "vle32.v v0, (s1) \n\t" - "addi s1, s1, 256 \n\t" - "sub t6, t6, t1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vsetvli t0, t1, e8, m2 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vse8.v v0, (%[DST]) \n\t" - "addi %[DST], %[DST], 64 \n\t" - "bnez t6, TAIL_QUANT%= \n\t" - "END%=: \n\t" - : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer), - [CNT] "r"(cnt) - : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6"); - } + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); } } // namespace ime1 @@ -1451,1746 +584,444 @@ namespace { "vadd.vi v1, v1, -12 \n\t" template -void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias, - const size_t ldc) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); +void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const uint8_t * QuantA, + const uint8_t * QuantBData, + float * C, + size_t CountN, + size_t BlockCountK, + const size_t ldc) { size_t LDC = ldc * sizeof(float); const size_t INNER = BlkLen / 16; float tmp[4 * 16]; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(_Float16); // scale + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; LDC = 16 * sizeof(float); } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 32 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" + LOAD_B_16x8x2 - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" + SQ4BIT_KERNEL_COMP_4x16x16 - LOAD_B_16x8x2 + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" + LOAD_SCALE_4x16_FP16 - SQ4BIT_KERNEL_COMP_4x16x16 + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" - LOAD_SCALE_4x16_FP16 + SAVE_RESULT_4x16 - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 32 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", + "s5", "s6"); } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(_Float16); // scale + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; LDC = 16 * sizeof(float); } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 32 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" - LOAD_B_16x8x2 + LOAD_B_16x8x2 - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" - SQ4BIT_KERNEL_COMP_4x16x16 + SQ4BIT_KERNEL_COMP_4x16x16 - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" - LOAD_SCALE_4x16_FP16 + LOAD_SCALE_4x16_FP16 - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "RESULT_SAVE%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" - SAVE_RESULT_4x16 + SAVE_RESULT_4x16 - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 32 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", + "s5", "s6"); } } - if (CountN % 16 != 0) { - // stroe output from tmp to C when NBLKS less than 16. - float * CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi s2, %[SRC], 64 \n\t" - "addi s3, %[SRC], 64*2 \n\t" - "addi s4, %[SRC], 64*3 \n\t" - "vle32.v v2, (s2) \n\t" - "vle32.v v4, (s3) \n\t" - "vle32.v v6, (s4) \n\t" - "add t2, %[DST], %[LDC] \n\t" - "add t3, t2, %[LDC] \n\t" - "add t4, t3, %[LDC] \n\t" - "vse32.v v0, (%[DST]) \n\t" - "vse32.v v2, (t2) \n\t" - "vse32.v v4, (t3) \n\t" - "vse32.v v6, (t4) \n\t" - : - : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) - : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); - } } template -void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias, - const size_t ldc) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); - size_t LDC = ldc * sizeof(float); - const size_t INNER = BlkLen / 16; - float tmp[4 * 16]; - - if constexpr (HasZeroPoint) { - for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - if (NBLKS < 16) { - CPtr = tmp; - LDC = 16 * sizeof(float); - } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 64 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 64 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } - } - } else { - for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - if (NBLKS < 16) { - CPtr = tmp; - LDC = 16 * sizeof(float); - } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 64 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 64 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } - } - } - if (CountN % 16 != 0) { - // stroe output from tmp to C when NBLKS less than 16. - float * CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi s2, %[SRC], 64 \n\t" - "addi s3, %[SRC], 64*2 \n\t" - "addi s4, %[SRC], 64*3 \n\t" - "vle32.v v2, (s2) \n\t" - "vle32.v v4, (s3) \n\t" - "vle32.v v6, (s4) \n\t" - "add t2, %[DST], %[LDC] \n\t" - "add t3, t2, %[LDC] \n\t" - "add t4, t3, %[LDC] \n\t" - "vse32.v v0, (%[DST]) \n\t" - "vse32.v v2, (t2) \n\t" - "vse32.v v4, (t3) \n\t" - "vse32.v v6, (t4) \n\t" - : - : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) - : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); - } -} - -template -void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); +void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const uint8_t * QuantA, + const uint8_t * QuantBData, + float * C, + size_t CountN, + size_t BlockCountK, + const size_t ldc) { + GGML_UNUSED(ldc); size_t INNER = BlkLen / 16; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(_Float16); // scale + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - // zp offset - "addi s7, %[B], 32 \n\t" - // a offset - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 48 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 72 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 120 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" + "addi s7, %[B], 32 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" - SQ4BIT_KERNEL_LOAD_ZP_16X1 + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_INNER%=: \n\t" + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + SQ4BIT_KERNEL_LOAD_ZP_16X1 - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" + "LOOP_INNER%=: \n\t" - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" - SQ4BIT_KERNEL_ACC_F16_1X4X4 - "addi s7, s1, 32 \n\t" + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - - "addi s7, %[B], 32 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 48 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 72 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 120 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - "addi s7, s1, 32 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(_Float16); // scale + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 56 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 80 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 104 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - SQ4BIT_KERNEL_ACC_F16_1X4X4 + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" + SQ4BIT_KERNEL_ACC_F16_1X4X4 - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 56 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 80 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 104 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); } } } - -template -void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); - const size_t INNER = BlkLen / 16; - if constexpr (HasZeroPoint) { - for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0 - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - // zp offset - "addi s7, %[B], 64 \n\t" - // a offset - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - - // load scale - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 80 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 96 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 112 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 128 \n\t" - - // load a scale - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - - // a scale * b scale - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - "addi s7, s1, 64 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - - "addi s7, %[B], 64 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 80 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 96 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 112 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 128 \n\t" - - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - "addi s7, s1, 64 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } - } - } else { - for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 80 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 112 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 80 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 112 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } - } - } -} - -template -inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float * Bias, - const size_t ldc, - const size_t scalestride) { - if (scalestride == 4) { - SQ4BitGemmM4Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, - CountN, BlockStrideQuantB, Bias, ldc); - - } else if (scalestride == 2) { - SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl( - BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc); - } -} - -template -inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float * Bias, - const size_t ldc, - const size_t scalestride) { - if (scalestride == 4) { - SQ4BitGemmM1Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, - CountN, BlockStrideQuantB, Bias); - } else if (scalestride == 2) { - SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias); - } -} - } // namespace namespace ime1 { -size_t gemm_kernel_i8i4(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - size_t ldc, - const float * Bias, - const size_t ScaleStride) { - GGML_UNUSED(CountM); - GGML_UNUSED(CountK); - GGML_UNUSED(ldc); - if (CountM >= 4) { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, - C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { + if (quant_b_zp != nullptr) { + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks, + ldc); } else { - SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, - ldc, ScaleStride); + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, + k_blks, ldc); } return 4; } else { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, - C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + if (quant_b_zp != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks, + ldc); } else { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, - ldc, ScaleStride); + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, + k_blks, ldc); } return 1; } } } // namespace ime1 -} // namespace sqnbitgemm_spacemit_ime +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp new file mode 100644 index 000000000..0c7a036a9 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp @@ -0,0 +1,5768 @@ +#include "ggml-impl.h" +#include "ggml.h" +#include "ime_kernels.h" +#include "rvv_kernels.h" +#include "string.h" + +#include +#include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME2) +#else +# error "RISCV64_SPACEMIT_IME2 not defined" +#endif + +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Woverlength-strings" +# pragma GCC diagnostic ignored "-Wcast-qual" +# pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +namespace spacemit_kernels { +namespace ime2 { + +template +void gemm_kernel_i8i2k_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q2_k; + constexpr float refactor_scale = 16.0f; + constexpr float factor_scale = 1.0f / refactor_scale; + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = sizeof(blk_type); + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + blk_type * quant_b_blk_data = (blk_type *) (quant_b_data); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + uint8_t * b_data = quant_b_blk_data->qs; + uint8_t * scales = quant_b_blk_data->scales; + uint8_t * scales16 = (uint8_t *) (quant_b_blk_data->scales16); + uint8_t * zeros16 = (uint8_t *) (quant_b_blk_data->zeros16); + + _Float16 * scales_fp16 = (_Float16 *) scales16; + _Float16 * zeros_fp16 = (_Float16 *) zeros16; + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS * 16); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS * 16); + + memset(output_f16, 0, sizeof(output_f16)); + + uint8_t * scales_temp = scales; + uint8_t * zps_temp = scales; + for (size_t kii = 0; kii < 16; kii++, scales_temp += NB_COLS, zps_temp++) { + size_t b_shift = (kii % 4) * 2; + + uint8_t * b_data_col = b_data + (kii / 4) * NB_COLS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + int16_t a_sum = a_sum_row[mi * 16 + kii]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 acc_0 = 0.0; + + uint8_t b_zp = zps_temp[ci * 16] >> 4; + uint8_t b_scale = scales_temp[ci] & 0x0F; + for (size_t bi = 0; bi < 16; bi++) { + int8_t a0 = a_data[mi * 256 + bi + kii * 16]; + uint8_t b0 = b_data_col[ci * 16 + bi]; + acc_0 += static_cast(a0) * static_cast((b0 >> b_shift) & 0x03); + } + + _Float16 scale_item = + static_cast<_Float16>(b_scale) * static_cast<_Float16>(factor_scale) * scales_fp16[ci]; + + output_f16[ci + mi * NB_COLS] += acc_0 * scale_item; + output[ci + mi * NB_COLS] += b_zp * a_sum * a_scale_row[mi] * zeros_fp16[ci]; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + auto a_scale = a_scale_row[mi] * refactor_scale; + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += output_f16[ci + mi * NB_COLS] * a_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i3k_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q2_k; + constexpr float refactor_scale = 16.0f; + constexpr float factor_scale = 1.0f / refactor_scale; + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = sizeof(blk_type); + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + blk_type * quant_b_blk_data = (blk_type *) (quant_b_data); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + uint8_t * b_data = quant_b_blk_data->qs; + uint8_t * b_hmask = quant_b_blk_data->hmask; + int8_t * scales = quant_b_blk_data->scales; + uint8_t * scales16 = (uint8_t *) (quant_b_blk_data->scales16); + + _Float16 * scales_fp16 = (_Float16 *) scales16; + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS * 16); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS * 16); + + memset(output_f16, 0, sizeof(output_f16)); + + int8_t * scales_temp = scales; + uint16_t * b_mask_col = (uint16_t *) b_hmask; + + float acc_0_max = 0.0f; + for (size_t kii = 0; kii < 16; kii++, scales_temp += NB_COLS, b_mask_col += NB_COLS) { + size_t b_shift = (kii % 4) * 2; + + uint8_t * b_data_col = b_data + (kii / 4) * NB_COLS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 acc_0 = 0; + // blk 2 * kii + 0 + uint16_t b_shift_mask = 1; + for (size_t bi = 0; bi < 16; bi++, b_shift_mask <<= 1) { + int8_t a0 = a_data[mi * 256 + bi + kii * 16]; + int8_t b0 = static_cast((b_data_col[ci * 16 + bi] >> b_shift) & 0x03); + b0 -= b_mask_col[ci] & b_shift_mask ? 0 : 4; + acc_0 += static_cast(a0) * static_cast(b0); + } + + _Float16 scale_item = static_cast<_Float16>(scales_temp[ci]) * scales_fp16[ci] * + static_cast<_Float16>(factor_scale); + + output_f16[ci + mi * NB_COLS] += acc_0 * scale_item; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + auto a_scale = a_scale_row[mi] * refactor_scale; + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += output_f16[ci + mi * NB_COLS] * a_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i4_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t kblks_per_blk = 16; + GGML_ASSERT(k_blks % kblks_per_blk == 0); + + int64_t b_blk_stride = (sizeof(_Float16) + (blk_len / 2) + (quant_b_zp ? sizeof(uint8_t) : 0)); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(_Float16); + if (quant_b_zp) { + b_data += NB_COLS * sizeof(uint8_t); + } + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0.0f; + output_f16[ci + mi * NB_COLS] = static_cast<_Float16>(0.0f); + } + } + + size_t kii = 0; + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride, b_data += b_ncol_block_stride) { + _Float16 * b_scale_fp16 = (_Float16 *) (b_data - NB_COLS * sizeof(_Float16)); + uint8_t * b_zp = nullptr; + if (quant_b_zp) { + b_scale_fp16 = (_Float16 *) (b_data - NB_COLS * sizeof(_Float16) - NB_COLS * sizeof(uint8_t)); + b_zp = (uint8_t *) (b_data - NB_COLS * sizeof(uint8_t)); + } + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + _Float16 a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 b_scale = b_scale_fp16[ci]; + int32_t acc = 0; + if (b_zp) { + acc += a_sum * b_zp[ci]; + } else { + acc += a_sum * 8; + } + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t b = b_data[ci * blk_len / 2 + bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output_f16[ci + mi * NB_COLS] += + static_cast(acc) * static_cast(a_scale) * static_cast(b_scale); + } + } + + if (kii == kblks_per_blk - 1) { + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += static_cast(output_f16[ci + mi * NB_COLS]); + output_f16[ci + mi * NB_COLS] = 0.0f; + } + } + kii = 0; + } else { + kii++; + } + } + + if (kii == kblks_per_blk - 1) { + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += static_cast(output_f16[ci + mi * NB_COLS]); + output_f16[ci + mi * NB_COLS] = 0.0f; + } + } + kii = 0; + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i4_hp_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + const size_t a_nrow_block_stride = q8_hp_blk_size(blk_len, true, true) * MB_ROWS; + const size_t a_subblk_stride = q8_hp_blk_size(32, false, false) * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const uint8_t * b_tile_base = quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0.0f; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride) { + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + const uint8_t * b_superblk_ptr = b_tile_base + ki * b_superblk_stride; + const block_q4_0x32_layout * b_blocks = reinterpret_cast(b_superblk_ptr); + const uint8_t * b_zps = + quant_b_zp ? b_superblk_ptr + sizeof(block_q4_0x32_layout) * k_subblks_per_superblk : nullptr; + + _Float16 * a_sum_row = (_Float16 *) (a_data + a_subblk_stride * k_subblks_per_superblk); + _Float16 * a_scale_avg_row = (_Float16 *) (a_data + a_nrow_block_stride - sizeof(_Float16) * MB_ROWS); + _Float16 scale_factor = a_scale_avg_row[0]; + + for (size_t ksi = 0; ksi < k_subblks_per_superblk; ++ksi) { + const _Float16 * a_scale_row = reinterpret_cast(a_data + a_subblk_stride * ksi); + int8_t * a_subblk = a_data + a_subblk_stride * ksi + MB_ROWS * sizeof(_Float16); + const _Float16 a_scale = a_scale_row[0]; + const block_q4_0x32_layout & b_block = b_blocks[ksi]; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + const uint8_t * b_qs = b_block.qs + ci * 16; + _Float16 b_scale = b_block.d[ci] * a_scale; + + int16_t acc = 0; + for (size_t bi = 0; bi < 16; bi++) { + uint8_t b = b_qs[bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + + acc += static_cast(a_subblk[mi * 32 + 2 * bi]) * static_cast(b0) + + static_cast(a_subblk[mi * 32 + 2 * bi + 1]) * static_cast(b1); + } + + const _Float16 scaled_acc = static_cast<_Float16>(acc) * b_scale; + output_f16[ci + mi * NB_COLS] += scaled_acc; + } + } + } + + for (size_t ksi = 0; ksi < k_subblks_per_superblk; ++ksi) { + const _Float16 * a_scale_row = reinterpret_cast(a_data + a_subblk_stride * ksi); + const block_q4_0x32_layout & b_block = b_blocks[ksi]; + const uint8_t * b_zp_row = b_zps ? b_zps + ksi * NB_COLS : nullptr; + const _Float16 a_scale = a_scale_row[0]; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + const _Float16 a_sum = a_sum_row[mi * k_subblks_per_superblk + ksi]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 b_scale = b_block.d[ci] * a_scale; + _Float16 a_sum_bzp = a_sum; + if (b_zp_row) { + a_sum_bzp = a_sum * static_cast<_Float16>(0.125f) * static_cast<_Float16>(b_zp_row[ci]); + } + + const _Float16 scaled_acc = a_sum_bzp * b_scale; + output[ci + mi * NB_COLS] += scaled_acc * scale_factor; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + auto val = static_cast(output_f16[ci + mi * NB_COLS]) * static_cast(scale_factor); + output[ci + mi * NB_COLS] += val; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void moe_gemm_kernel_i8i4_mrow_ref(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_blk_stride = (sizeof(ggml_fp16_t) + (blk_len / 2) + (quant_b_zp ? sizeof(uint8_t) : 0)); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + std::array a_data; + std::array c_data; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + c_data[mi] = c_ptr[mi]; + } + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(ggml_fp16_t); + if (quant_b_zp) { + b_data += NB_COLS * sizeof(uint8_t); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, b_data += b_ncol_block_stride) { + ggml_fp16_t * b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t)); + uint8_t * b_zp = nullptr; + if (quant_b_zp) { + b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t) - NB_COLS * sizeof(uint8_t)); + b_zp = (uint8_t *) (b_data - NB_COLS * sizeof(uint8_t)); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(b_scale_fp16[ci]); + int32_t acc = 0; + if (b_zp) { + acc += a_sum * b_zp[ci]; + } else { + acc += a_sum * 8; + } + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = (a_data[mi])[2 * bi]; + int8_t a1 = (a_data[mi])[2 * bi + 1]; + uint8_t b = b_data[ci * blk_len / 2 + bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + (c_data[mi])[ci] = output[mi * NB_COLS + ci]; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + c_data[mi] += NB_COLS; + } + } +} + +template +void moe_gemm_kernel_i8i5_mrow_ref(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + GGML_UNUSED(count_m); + GGML_UNUSED(ldc); + + // blk_len is expected to be 32 for Q5 types. + int64_t a_blk_stride = q8_blk_size(blk_len, true); + + float output[MB_ROWS * NB_COLS] = { 0 }; + std::array a_data; + std::array c_data; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + c_data[mi] = c_ptr[mi]; + } + + if (quant_b_zp) { + using blk_type = nrow_block_q5_1; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data + (ni / NB_COLS) * k_blks; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < NB_COLS; ++ci) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ++ki, ++quant_b_blk_data) { + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ++ci) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + uint8_t b_zp_val = quant_b_blk_data->zp[ci]; + int32_t acc = a_sum * static_cast(b_zp_val); + + for (size_t bi = 0; bi < blk_len / 2; ++bi) { + int8_t a0 = a_data[mi][2 * bi]; + int8_t a1 = a_data[mi][2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < nb_real; ++ci) { + c_data[mi][ci] = output[mi * NB_COLS + ci]; + } + c_data[mi] += NB_COLS; + } + } + } else { + using blk_type = nrow_block_q5_0; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data + (ni / NB_COLS) * k_blks; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < NB_COLS; ++ci) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ++ki, ++quant_b_blk_data) { + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ++ci) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + int32_t acc = a_sum * 16; + + for (size_t bi = 0; bi < blk_len / 2; ++bi) { + int8_t a0 = a_data[mi][2 * bi]; + int8_t a1 = a_data[mi][2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < nb_real; ++ci) { + c_data[mi][ci] = output[mi * NB_COLS + ci]; + } + c_data[mi] += NB_COLS; + } + } + } +} + +template +void gemm_kernel_i8i8_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_blk_stride = (sizeof(ggml_fp16_t) + blk_len); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + int8_t * b_data = (int8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(ggml_fp16_t); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride, b_data += b_ncol_block_stride) { + ggml_fp16_t * b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t)); + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(b_scale_fp16[ci]); + int32_t acc = 0; + for (size_t bi = 0; bi < blk_len; bi++) { + int8_t a0 = a_data[mi * blk_len + bi]; + int8_t b0 = b_data[ci * blk_len + bi]; + acc += static_cast(a0) * static_cast(b0); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i5_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // blk_len is expected to be 32 for Q5 types + // quant_b_zp != nullptr => nrow_block_q5_1 (has zp) + // quant_b_zp == nullptr => nrow_block_q5_0 (no zp) + + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + if (quant_b_zp) { + // nrow_block_q5_1: scales16[NB_COLS] + zp[NB_COLS] + qh[4*NB_COLS] + qs[16*NB_COLS] + using blk_type = nrow_block_q5_1; + int64_t b_ncol_block_stride = sizeof(blk_type); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + uint8_t b_zp_val = quant_b_blk_data->zp[ci]; + int32_t acc = a_sum * static_cast(b_zp_val); + + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract high bits from qh + // qh is packed as 4 bytes per column (32 bits for 32 elements) + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } + } else { + // nrow_block_q5_0: scales16[NB_COLS] + qh[4*NB_COLS] + qs[16*NB_COLS] + using blk_type = nrow_block_q5_0; + int64_t b_ncol_block_stride = sizeof(blk_type); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + // Q5_0 has no zp, use default offset 16 (midpoint of 5-bit unsigned range) + int32_t acc = a_sum * 16; + + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract high bits from qh + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } + } +} + +template +void gemm_kernel_i8mxfp4_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // blk_len is expected to be 32 (QK_MXFP4) + // quant_b_zp is unused for MXFP4 (symmetric quantization) + GGML_UNUSED(quant_b_zp); + + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + using blk_type = nrow_block_mxfp4; + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = GGML_E8M0_TO_FP32_HALF(quant_b_blk_data->e[ci]); + + // Read 32 sign bits for this column + uint32_t sign_bits; + memcpy(&sign_bits, &quant_b_blk_data->qh[ci * 4], 4); + + int32_t acc = 0; + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + + // qs[ci*16 + bi] stores abs(vals[bi*2]) in low 4 bits + // and abs(vals[bi*2+1]) in high 4 bits + uint8_t qs_byte = quant_b_blk_data->qs[ci * 16 + bi]; + int8_t b_abs0 = static_cast(qs_byte & 0x0F); + int8_t b_abs1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract sign bits: bit (2*bi) for vals[2*bi], bit (2*bi+1) for vals[2*bi+1] + int8_t b0 = (sign_bits >> (2 * bi)) & 1 ? -b_abs0 : b_abs0; + int8_t b1 = (sign_bits >> (2 * bi + 1)) & 1 ? -b_abs1 : b_abs1; + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +void gemm_kernel_i8i2k_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + using blk_type = nrow_block_q2_k; + + int64_t b_ncol_block_stride = sizeof(blk_type) * k_blks; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_ncol_block_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = (float *) c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv s1, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "addi %[A], %[A], 4 \n\t" + + "li t1, 4 \n\t" + "addi t2, %[B], 512 \n\t" // B data addr + "addi t3, %[A], 32 \n\t" // A data addr + "addi s3, %[B], 0 \n\t" + "vxor.vv v30, v29, v29 \n\t" // tmp result + + "INNER_K_LOOP%=: \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vxor.vv v2, v2, v2 \n\t" + "vxor.vv v3, v3, v3 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + + // load scale B + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (%[B]) \n\t" + "addi %[B], %[B], 128 \n\t" + + // A data, 1x64@i8 + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v2, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v4, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v5, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v6, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetvli t0, x0, e64, mf2 \n\t" + "vslideup.vi v3, v4, 2 \n\t" + "vslideup.vi v28, v5, 4 \n\t" + "vslideup.vi v29, v6, 6 \n\t" + + // init the accumu to zero + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v20, v18, v18 \n\t" + "vxor.vv v22, v18, v18 \n\t" + "vxor.vv v24, v18, v18 \n\t" + "vxor.vv v26, v18, v18 \n\t" + + // B data, 32x64@i2 + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (t2) \n\t" + "addi t2, t2, 512 \n\t" + "vand.vi v8, v4, 0x3 \n\t" // 0-15 + "vsrl.vi v9, v4, 2 \n\t" + "vsrl.vi v10, v4, 4 \n\t" + "vsrl.vi v11, v4, 6 \n\t" // 48-63 + "vand.vi v9, v9, 0x3 \n\t" // 16-31 + "vand.vi v10, v10, 0x3 \n\t" // 32-47 + + "vand.vi v12, v5, 0x3 \n\t" // 0-15 + "vsrl.vi v13, v5, 2 \n\t" + "vsrl.vi v14, v5, 4 \n\t" + "vsrl.vi v15, v5, 6 \n\t" // 48-63 + "vand.vi v13, v13, 0x3 \n\t" // 16-31 + "vand.vi v14, v14, 0x3 \n\t" // 32-47 + + "vand.vi v16, v6, 0x3 \n\t" // 0-15 + "vsrl.vi v17, v6, 2 \n\t" + "vsrl.vi v18, v6, 4 \n\t" + "vsrl.vi v19, v6, 6 \n\t" // 48-63 + "vand.vi v17, v17, 0x3 \n\t" // 16-31 + "vand.vi v18, v18, 0x3 \n\t" // 32-47 + + "vand.vi v4, v7, 0x3 \n\t" // 0-15 + "vsrl.vi v5, v7, 2 \n\t" + "vsrl.vi v6, v7, 4 \n\t" + "vsrl.vi v7, v7, 6 \n\t" // 48-63 + "vand.vi v5, v5, 0x3 \n\t" // 16-31 + "vand.vi v6, v6, 0x3 \n\t" // 32-47 + + // i2 * i8 vmadot + "vsetvli t0, x0, e8, m1 \n\t" + "vmadotsu v20, v2, v8, i8 \n\t" + "vmadotsu v22, v2, v12, i8 \n\t" + "vmadotsu v24, v2, v16, i8 \n\t" + "vmadotsu v26, v2, v4, i8 \n\t" + + "vmadotsu v20, v3, v9, i8 \n\t" + "vmadotsu v22, v3, v13, i8 \n\t" + "vmadotsu v24, v3, v17, i8 \n\t" + "vmadotsu v26, v3, v5, i8 \n\t" + + "vmadotsu v20, v28, v10, i8 \n\t" + "vmadotsu v22, v28, v14, i8 \n\t" + "vmadotsu v24, v28, v18, i8 \n\t" + "vmadotsu v26, v28, v6, i8 \n\t" + + "vmadotsu v20, v29, v11, i8 \n\t" + "vmadotsu v22, v29, v15, i8 \n\t" + "vmadotsu v24, v29, v19, i8 \n\t" + "vmadotsu v26, v29, v7, i8 \n\t" + + "vand.vi v10, v0, 0xf \n\t" // scale + "vwadd.vx v12, v10, x0 \n\t" + "vsetvli t0, x0, e16, m2 \n\t" + "vwadd.vx v16, v12, x0 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v4, v24, v26, 2 \n\t" + "vpack.vv v6, v2, v4, 3 \n\t" // 0,1 + "vpack.vv v8, v3, v5, 3 \n\t" // 2,3 + + // mul scale + "vmacc.vv v30, v6, v16 \n\t" + "vmacc.vv v30, v7, v17 \n\t" + "vmacc.vv v30, v8, v18 \n\t" + "vmacc.vv v30, v9, v19 \n\t" + + "addi t1, t1, -1 \n\t" + "bgtz t1, INNER_K_LOOP%= \n\t" + + // load zp B + "vsetvli t0, x0, e8, m4 \n\t" + "vle8.v v4, (s3) \n\t" + "vsrl.vi v8, v4, 4 \n\t" // zp + + // asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + + "vsetvli t0, x0, e16, mf4 \n\t" + "vle16.v v2, (%[A]) \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vnsrl.wi v12, v2, 0 \n\t" // low 8 + "vnsra.wi v13, v2, 8 \n\t" // high 8 + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v20, v13, v8, i8 \n\t" + "vmadotsu v22, v13, v9, i8 \n\t" + "vmadotsu v24, v13, v10, i8 \n\t" + "vmadotsu v26, v13, v11, i8 \n\t" + + "vsll.vi v20, v20, 8 \n\t" + "vsll.vi v22, v22, 8 \n\t" + "vsll.vi v24, v24, 8 \n\t" + "vsll.vi v26, v26, 8 \n\t" + + "vmadotu v20, v12, v8, i8 \n\t" + "vmadotu v22, v12, v9, i8 \n\t" + "vmadotu v24, v12, v10, i8 \n\t" + "vmadotu v26, v12, v11, i8 \n\t" + + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v4, v24, v26, 2 \n\t" + "vpack.vv v28, v2, v4, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v0, (t2) \n\t" // scale16 + "addi t2, t2, 64 \n\t" + "vle16.v v1, (t2) \n\t" // zero16 + "vfwcvt.f.f.v v2, v0 \n\t" + "vfwcvt.f.f.v v4, v1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v30, v30 \n\t" + "vfcvt.f.x.v v28, v28 \n\t" + "addi %[B], t2, 64 \n\t" + "mv %[A], t3 \n\t" + + "vfmul.vv v30, v30, v2 \n\t" // mul scale16 + "vfmacc.vv v30, v28, v4 \n\t" // + mul zero16 + "vfmacc.vf v31, fa0, v30 \n\t" + "addi s1, s1, -1 \n\t" + "bgtz s1, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v31, (%[DST]) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "t4", "t5", "t6", "s1", "s2", "s3"); + } +} + +void gemm_kernel_i8i2k_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + using blk_type = nrow_block_q2_k; + + int64_t b_ncol_block_stride = sizeof(blk_type) * k_blks; + _Float16 scale = 0.0625f; + _Float16 scale_1 = 16.0f; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_ncol_block_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = (float *) c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v31, v31 \n\t" // init result + "vxor.vv v29, v31, v31 \n\t" + "vxor.vv v30, v31, v31 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv s1, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + "li t1, 4 \n\t" + "addi t2, %[B], 512 \n\t" // B data addr + "addi t3, %[A], 128 \n\t" // A data addr + "addi s4, t2, 1024 \n\t" // scale16 addr + "addi s4, s4, 1024 \n\t" // TODO + "addi s3, %[B], 0 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s4) \n\t" // load scale16 + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v22, v1, v1, 3 \n\t" + + "addi s4, t3, 256 \n\t" // addr 1 + "addi s5, t3, 512 \n\t" // addr 2 + "addi s6, t3, 768 \n\t" // addr 3 + + // init the accu to 0 + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v25, v25, v25 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v27, v27, v27 \n\t" + + "INNER_K_LOOP%=: \n\t" + // load scale B + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (%[B]) \n\t" + "addi %[B], %[B], 128 \n\t" + "vand.vi v1, v1, 0xf \n\t" + + "vfwcvt.f.x.v v20, v1 \n\t" // f16 scale B + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vv v0, v20, v22 \n\t" // mul scale16 + "vfmul.vv v1, v21, v22 \n\t" // mul scale16 + "vfmul.vf v0, v0, %[SCALE] \n\t" // mul magic + "vfmul.vf v1, v1, %[SCALE] \n\t" // mul magic + + // A data, 4x64@i8 + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (t3) \n\t" + "addi t3, t3, 64 \n\t" + "vle8.v v3, (s4) \n\t" + "addi s4, s4, 64 \n\t" + "vle8.v v4, (s5) \n\t" + "addi s5, s5, 64 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 64 \n\t" + + // 4x64 => 4x16x4 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v6, v2, v3, 1 \n\t" + "vpack.vv v8, v4, v5, 1 \n\t" + "vpack.vv v2, v6, v8, 2 \n\t" // 0, 2 + + "vpack.vv v20, v2, v2, 3 \n\t" // 1 + "vor.vv v23, v21, v21 \n\t" + "vpack.vv v20, v3, v3, 3 \n\t" // 3 + + // B data, 32x64@i2 + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (t2) \n\t" + "addi t2, t2, 512 \n\t" + "vand.vi v8, v4, 0x3 \n\t" // 0-15 + "vsrl.vi v9, v4, 2 \n\t" + "vsrl.vi v10, v4, 4 \n\t" + "vsrl.vi v11, v4, 6 \n\t" // 48-63 + "vand.vi v9, v9, 0x3 \n\t" // 16-31 + "vand.vi v10, v10, 0x3 \n\t" // 32-47 + + "vand.vi v12, v5, 0x3 \n\t" // 0-15 + "vsrl.vi v13, v5, 2 \n\t" + "vsrl.vi v14, v5, 4 \n\t" + "vsrl.vi v15, v5, 6 \n\t" // 48-63 + "vand.vi v13, v13, 0x3 \n\t" // 16-31 + "vand.vi v14, v14, 0x3 \n\t" // 32-47 + + "vand.vi v16, v6, 0x3 \n\t" // 0-15 + "vsrl.vi v17, v6, 2 \n\t" + "vsrl.vi v18, v6, 4 \n\t" + "vsrl.vi v19, v6, 6 \n\t" // 48-63 + "vand.vi v17, v17, 0x3 \n\t" // 16-31 + "vand.vi v18, v18, 0x3 \n\t" // 32-47 + + "vand.vi v4, v7, 0x3 \n\t" // 0-15 + "vsrl.vi v5, v7, 2 \n\t" + "vsrl.vi v6, v7, 4 \n\t" + "vsrl.vi v7, v7, 6 \n\t" // 48-63 + "vand.vi v5, v5, 0x3 \n\t" // 16-31 + "vand.vi v6, v6, 0x3 \n\t" // 32-47 + + // i2 * i8 vmadot + "vsetvli t0, x0, e8, m1 \n\t" + "vmadotsu.hp v24, v2, v8, v0, 0, i8 \n\t" + "vmadotsu.hp v25, v2, v12, v0, 1, i8 \n\t" + "vmadotsu.hp v26, v2, v16, v0, 2, i8 \n\t" + "vmadotsu.hp v27, v2, v4, v0, 3, i8 \n\t" + + "vmadotsu.hp v24, v23, v9, v0, 4, i8 \n\t" + "vmadotsu.hp v25, v23, v13, v0, 5, i8\n\t" + "vmadotsu.hp v26, v23, v17, v0, 6, i8\n\t" + "vmadotsu.hp v27, v23, v5, v0, 7, i8 \n\t" + + "vmadotsu.hp v24, v3, v10, v1, 0, i8 \n\t" + "vmadotsu.hp v25, v3, v14, v1, 1, i8 \n\t" + "vmadotsu.hp v26, v3, v18, v1, 2, i8 \n\t" + "vmadotsu.hp v27, v3, v6, v1, 3, i8 \n\t" + + "vmadotsu.hp v24, v21, v11, v1, 4, i8\n\t" + "vmadotsu.hp v25, v21, v15, v1, 5, i8\n\t" + "vmadotsu.hp v26, v21, v19, v1, 6, i8\n\t" + "vmadotsu.hp v27, v21, v7, v1, 7, i8 \n\t" + + "addi t1, t1, -1 \n\t" + "bgtz t1, INNER_K_LOOP%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v2, v24, v25, 1 \n\t" + "vpack.vv v4, v26, v27, 1 \n\t" + "vpack.vv v6, v2, v4, 2 \n\t" // 0,1,2,3 + + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v24, v24, v24 \n\t" + // load zp B, 16x8x4@int4 + "vsetvli t0, x0, e8, m4 \n\t" + "vle8.v v0, (s3) \n\t" + "vsrl.vi v0, v0, 4 \n\t" // zp + + // 4x16@int16 + "vsetvli t0, x0, e16, m1 \n\t" // a sum + "vle16.v v12, (%[A]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnsrl.wi v10, v12, 0 \n\t" // low 8 + "vnsra.wi v11, v12, 8 \n\t" // high 8 + + // asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v18, v11, v0, i8 \n\t" + "vmadotsu v20, v11, v1, i8 \n\t" + "vmadotsu v22, v11, v2, i8 \n\t" + "vmadotsu v24, v11, v3, i8 \n\t" + "vsll.vi v18, v18, 8 \n\t" + "vsll.vi v20, v20, 8 \n\t" + "vsll.vi v22, v22, 8 \n\t" + "vsll.vi v24, v24, 8 \n\t" + "vmadotu v18, v10, v0, i8 \n\t" + "vmadotu v20, v10, v1, i8 \n\t" + "vmadotu v22, v10, v2, i8 \n\t" + "vmadotu v24, v10, v3, i8 \n\t" + + "vpack.vv v10, v18, v20, 2 \n\t" + "vpack.vv v12, v22, v24, 2 \n\t" + "vpack.vv v14, v10, v12, 3 \n\t" + "vpack.vv v16, v11, v13, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "addi t2, t2, 64 \n\t" + "vle16.v v20, (t2) \n\t" // zero16 + "vfwcvt.f.f.v v22, v20 \n\t" + + // mul 1/magic + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmul.vf v0, v6, %[SCALE_1] \n\t" + "vfwmul.vf v2, v7, %[SCALE_1] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v14, v14 \n\t" + "vfcvt.f.x.v v15, v15 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + + "addi %[B], t2, 64 \n\t" + "mv %[A], s6 \n\t" + + "vfmacc.vv v0, v14, v22 \n\t" // + mul zero16 + "vfmacc.vv v1, v15, v22 \n\t" + "vfmacc.vv v2, v16, v22 \n\t" + "vfmacc.vv v3, v17, v22 \n\t" + + "vfmacc.vf v28, fa0, v0 \n\t" // mul a scale + "vfmacc.vf v29, fa1, v1 \n\t" + "vfmacc.vf v30, fa2, v2 \n\t" + "vfmacc.vf v31, fa3, v3 \n\t" + + "addi s1, s1, -1 \n\t" + "bgtz s1, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "vse32.v v29, (t1) \n\t" + "add t1, t1, %[LDC] \n\t" + "vse32.v v30, (t1) \n\t" + "add t1, t1, %[LDC] \n\t" + "vse32.v v31, (t1) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks), [LDC] "r"(ldc * 4), [SCALE] "f"(scale), [SCALE_1] "f"(scale_1) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "s5", "s6"); + } +} + +void gemm_kernel_i8i3k_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; //only support 32 in ASM + using blk_type = nrow_block_q3_k; + + const blk_type * b_base = reinterpret_cast(quant_b_data); + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride; + int64_t b_ncol_block_stride = sizeof(blk_type); + + // Constants used by q3_k scaling in HP branch: + // - k_q3k_scale_step: per-nibble scale factor (1/16). + // - k_a_scale_post_mul: A_scale needs an extra *16 at the end (pairs with 1/16 above). + const _Float16 k_q3k_scale_step = (_Float16) 0.0625f; // 1 / 16 + const float k_a_scale_post_mul = 16.0f; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const blk_type * quant_b_blk_data = b_base + (ni / NB_COLS) * k_blks; +#if 0 + //------------------------------------------------------------------------------ + // A format + // Ascale fp32 * 1 32bit + // Asum int16 * 16 256bit + // A M1K256 int8 2048bit + //------------------------------------------------------------------------------ + // B format + // B_scl uint8*N32*16 4096bit + // B_Hmask N32K16*16 1bit 8192bit + // B_Qs N32K16*16 2bit 16384bit + // B scl16 fp16 * N32 512bit; + //------------------------------------------------------------------------------ + //bias always be nullptr + __asm__ volatile( + // t2 = k_blks (each is K256 superblock) + "mv t2, %[KBLKS] \n\t" + // t3 = 256/64 = 4 (K64 iterations per superblock) + "li t3, 4 \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+32 \n\t" // s3 = pAData, (pA+AScl+ASum) + + // B block layout for nrow_block_q3_k<32>: + // scales: 512B, hmask: 1024B, qs: 2048B, scales16: 64B + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs + "mv s7, %[pB] \n\t" // s7 = pB_base + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v31, v0, v0 \n\t" // clear acc + "vxor.vv v30, v0, v0 \n\t" // clear acc of K256 + + // ordinary vmadot: vle*10 vecIns*78 vmadot*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + "K64_LPST%=: \n\t" + + // K0-15 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v2, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // load B qs chunk (128B per K16, 16 times => 2048B) + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (s3) \n\t" + "addi s3, s3, 64 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" // N0-N31 in v16 + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + //K16-31 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 4 \n\t" + "vsll.vi v9, v5, 4 \n\t" + "vsll.vi v10, v6, 4 \n\t" + "vsll.vi v11, v7, 4 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v8, 6 \n\t" + "vsrl.vi v13, v9, 6 \n\t" + "vsrl.vi v14, v10, 6 \n\t" + "vsrl.vi v15, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" // N0-N31 in v16 + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + //K32-47 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 2 \n\t" + "vsll.vi v9, v5, 2 \n\t" + "vsll.vi v10, v6, 2 \n\t" + "vsll.vi v11, v7, 2 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v8, 6 \n\t" + "vsrl.vi v13, v9, 6 \n\t" + "vsrl.vi v14, v10, 6 \n\t" + "vsrl.vi v15, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + // K48-63 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v12, v4, 6 \n\t" + "vsrl.vi v13, v5, 6 \n\t" + "vsrl.vi v14, v6, 6 \n\t" + "vsrl.vi v15, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // load A scale (fp32) and advance A to next superblock + "flw f0, (s2) \n\t" + "addi s2, s2, 4+32+256 \n\t" + "add t4, s7, %[B_STR] \n\t" // t4 = next B blk base + "addi s3, s2, 4+32 \n\t" + + // load B scales16[32] (fp16) at end of qs region + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v2, (s6) \n\t" + + // pointer modify + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s7, t4, 0 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v2 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v30 \n\t" + "vfmul.vf v1, v24, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v31, v1, v26 \n\t" + + // next K-superblock + "addi t2, t2, -1 \n\t" + "vxor.vv v30, v0, v0 \n\t" // clear acc of K256 + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v31, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "s2", "s3", "s4", "s5", "s6", "s7"); +#else + + __asm__ volatile( + // ========================= + // Kernel overview (M1 x N32) + // ========================= + // Process one output row (M=1) and 32 columns (N=32) per call. + // + // Loop structure: + // - Outer loop: K superblocks of size K=256 (k_blks times) + // - Each K256 superblock is broken into 4 x K64 + // - Each K64 is processed as 4 x K16 "sub-blocks" (via unpack+dot) + // + // Data layout (high level): + // A (q8k K=256, per superblock): + // [ fp32 a_scale ][ int16 a_sum[16] ][ int8 a_qs[256] ] + // B (nrow_block_q3_k<32>, per superblock): + // [ int8 scales[32*16] ][ hmask[1024] ][ qs[2048] ][ fp16 scales16[32] ] + // + // Registers/pointers: + // s2: pA (points at A superblock header; used to load fp32 a_scale) + // s3: pA_qs (points at A int8 data within the current superblock) + // s4: pB_scales (points at B int8 per-K16 scales) + // s5: pB_hmask (points at B sign mask area) + // s6: pB_qs (points at B 2-bit packed qs area) + // s8: pB_scales16 (points at B fp16 scales16[32] at the end of block) + // s7: pB_base (base pointer to current B block; used for block-to-block stride) + + // t2 = number of K256 superblocks + "mv t2, %[KBLKS] \n\t" + // t3 = number of K64 chunks per K256 superblock (256 / 64) + "li t3, 4 \n\t" + + // A pointers + "mv s2, %[pA] \n\t" // s2 = pA_superblock (a_scale at +0) + "addi s3, %[pA], 4+32 \n\t" // s3 = pA_qs (skip a_scale + a_sum[16]) + + // B pointers for nrow_block_q3_k<32> + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask (skip scales[32*16]) + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs (skip hmask) + // scales16 is at the end of the block: qs(2048) after hmask + "addi s8, s6, 1024 \n\t" + "addi s8, s8, 1024 \n\t" // s8 = pB_scales16 (fp16 scales16[32]) + "mv s7, %[pB] \n\t" // s7 = pB_base (for next-block address calc) + + // v31: final FP32 accumulator for N=32 + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v31, v0, v0 \n\t" + + // ---- Preload B scales16[32] and build FP16 scale vector used by vmadot.hp ---- + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s8) \n\t" // load fp16 scales16[32] + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v26, v1, v1, 3 \n\t" // broadcast/pack to match lanes + "vmv.v.v v17, v26 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vf v30, v17, %[q3_step] \n\t" // v30 = scales16 * (1/16) + + // v24-v27: fp16 partial accumulators for a K64 chunk (vmadot.hp outputs) + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v25, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v27, v16, v16 \n\t" + + // HP vmadot: vle*10 vecIns*38 vmadot.hp*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" // loop over K256 superblocks + "K64_LPST%=: \n\t" // loop over 4 x K64 chunks + + // ------------------------------------------------------------ + // K0-15: load B scales + {hmask, qs} + A data; unpack and dot + // ------------------------------------------------------------ + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v2, (s4) \n\t" // B int8 scales for this K16 + "addi s4, s4, 128 \n\t" + + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" // B hmask for this K16 + "addi s5, s5, 64 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v3, (s3) \n\t" // A int8 data for this K16 + "addi s3, s3, 64 \n\t" + + // Convert B int8 scales to FP16 and apply scales16*(1/16) + "vsetvli t0, x0, e8, m1 \n\t" + "vfwcvt.f.x.v v28, v2 \n\t" // int8 -> fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vv v1, v28, v30 \n\t" // v1: FP16 scale vector for vmadot.hp + "vfmul.vv v29, v29, v30 \n\t" + + // Unpack B 2-bit qs + hmask -> signed int8 in v12..v15 + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + // (Next K16 unpack path uses a fresh hmask load) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // Prepare another group from packed qs (bit shifts) + apply sign from hmask + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 4 \n\t" + "vsll.vi v9, v5, 4 \n\t" + "vsll.vi v10, v6, 4 \n\t" + "vsll.vi v11, v7, 4 \n\t" + "vsrl.vi v16, v8, 6 \n\t" + "vsrl.vi v17, v9, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v18, v10, 6 \n\t" + "vsrl.vi v19, v11, 6 \n\t" + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v16, v16, -4, v0.t \n\t" + + // A shift for the second dot within this K64 + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v2, v3, 2 \n\t" + + // Dot products with FP16 scaling (accumulate into v24..v27) + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot.hp v24, v3, v12, v1, 0, i8 \n\t" + "vmadot.hp v25, v3, v13, v1, 1, i8 \n\t" + "vmadot.hp v26, v3, v14, v1, 2, i8 \n\t" + "vmadot.hp v27, v3, v15, v1, 3, i8 \n\t" + "vmadot.hp v24, v2, v16, v1, 4, i8 \n\t" + "vmadot.hp v25, v2, v17, v1, 5, i8 \n\t" + "vmadot.hp v26, v2, v18, v1, 6, i8 \n\t" + "vmadot.hp v27, v2, v19, v1, 7, i8 \n\t" + + // (K32-47 / K48-63 blocks continue unchanged...) + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vmv.v.v v1, v29 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v3, v3, 4 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 2 \n\t" + "vsll.vi v9, v5, 2 \n\t" + "vsll.vi v10, v6, 2 \n\t" + "vsll.vi v11, v7, 2 \n\t" + + "vsrl.vi v20, v8, 6 \n\t" + "vsrl.vi v21, v9, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v22, v10, 6 \n\t" + "vsrl.vi v23, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v20, v20, -4, v0.t \n\t" + + // K48-63 + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v8, v4, 6 \n\t" + "vsrl.vi v9, v5, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v10, v6, 6 \n\t" + "vsrl.vi v11, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v8, v8, -4, v0.t \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v2, v3, 2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot.hp v24, v3, v20, v1, 0, i8 \n\t" + "vmadot.hp v25, v3, v21, v1, 1, i8 \n\t" + "vmadot.hp v26, v3, v22, v1, 2, i8 \n\t" + "vmadot.hp v27, v3, v23, v1, 3, i8 \n\t" + "vmadot.hp v24, v2, v8, v1, 4, i8 \n\t" + "vmadot.hp v25, v2, v9, v1, 5, i8 \n\t" + "vmadot.hp v26, v2, v10, v1, 6, i8 \n\t" + "vmadot.hp v27, v2, v11, v1, 7, i8 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // ---- End of K64 chunk: reduce fp16 accumulators -> fp32 and scale by A ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v24, v25, 1 \n\t" + "vpack.vv v14, v26, v27, 1 \n\t" + "vpack.vv v16, v12, v14, 2 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v26, v16 \n\t" // fp16 -> fp32 vector (qsum * b_scales) + + // Load A scale and advance A pointer to next K256 superblock + "flw f0, (s2) \n\t" + "addi s2, s2, 4+32+256 \n\t" + "add t4, s7, %[B_STR] \n\t" // next B block base + "addi s3, s2, 4+32 \n\t" // reset A data pointer for next block + + // Advance B pointers to next K256 superblock + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s8, s6, 1024 \n\t" + "addi s8, s8, 1024 \n\t" + "addi s7, t4, 0 \n\t" + "addi t2, t2, -1 \n\t" + + // Final per-block scaling: a_scale * 16.0f + "fmul.s f0, f0, %[a_post_mul] \n\t" + // acc += (qsum * b_scales) * (a_scale*16) + "vsetvli t0, x0, e32, m1 \n\t" + "vfmacc.vf v31, f0, v26 \n\t" + + "beqz t2, BLK_LPND%= \n\t" + + // Preload next block's scales16 and rebuild v30 for vmadot.hp + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s8) \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v26, v1, v1, 3 \n\t" + "vmv.v.v v17, v26 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vf v30, v17, %[q3_step] \n\t" + + // Reset fp16 partial accumulators for next K64 loop(s) + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v25, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v27, v16, v16 \n\t" + + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v31, (%[pC]) \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride), [q3_step] "f"(k_q3k_scale_step), + [a_post_mul] "f"(k_a_scale_post_mul) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "f1", "s2", "s3", "s4", "s5", "s6", "s7", "s8"); +#endif + } +} + +void gemm_kernel_i8i3k_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q3_k<32>; + constexpr size_t NB_COLS = 32; //only support 32 in ASM + + const blk_type * b_base = reinterpret_cast(quant_b_data); + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * 4; + int64_t b_ncol_block_stride = sizeof(blk_type); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const blk_type * quant_b_blk_data = b_base + (ni / NB_COLS) * k_blks; + + //------------------------------------------------------------------------------ + // A format + // Ascale fp32 * 1* 4row 128bit + // Asum int16 * 16 4row 1024bit + // A M1K256 int8 4row 8192bit + //------------------------------------------------------------------------------ + // B format + // B_scl uint8*N32*16 4096bit + // B_Hmask N32K16*16 1bit 8192bit + // B_Qs N32K16*16 2bit 16384bit + // B scl16 fp16 * N32 512bit; + //------------------------------------------------------------------------------ + //bias always be nullptr + __asm__ volatile( + // t2 = k_blks (each is K256 superblock) + "mv t2, %[KBLKS] \n\t" + // t3 = 256/64 = 4 (K64 iterations per superblock) + "li t3, 4 \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 16+128 \n\t" // s3 = pAData, (pA+AScl+ASum) + + // B block layout for nrow_block_q3_k<32>: + // scales: 512B, hmask: 1024B, qs: 2048B, scales16: 64B + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask (skip scales) + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs (skip hmask) + "mv s7, %[pB] \n\t" // s7 = pB_base + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v0, v0 \n\t" // v24-v27: K256 temp accumulator + "vxor.vv v25, v0, v0 \n\t" + "vxor.vv v26, v0, v0 \n\t" + "vxor.vv v27, v0, v0 \n\t" + "vxor.vv v28, v0, v0 \n\t" // v28-v31: final accumulator + "vxor.vv v29, v0, v0 \n\t" + "vxor.vv v30, v0, v0 \n\t" + "vxor.vv v31, v0, v0 \n\t" + + // ordinary vmadot: vle*13 vecIns*96 vmadot*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + "K64_LPST%=: \n\t" + + // ========== K0-15: First K16 sub-block ========== + // Load B INT8 scale factors (32 cols × 16 K16 blocks) + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v8, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // Load B quantized data (32 cols × 16 elements × 2bit, stored in 4 groups) + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + // Load B hmask (32 cols × 16bit sign mask) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // Load A data (4 rows × 16 elements × INT8) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v12, (s3) \n\t" + "addi s3, s3, 256 \n\t" // Jump to next row + "vle8.v v13, (s3) \n\t" + "addi s3, s3, 256 \n\t" + "vle8.v v14, (s3) \n\t" + "addi s3, s3, 256 \n\t" + "vle8.v v15, (s3) \n\t" + "addi s3, s3, -768+64 \n\t" // Back to first row, advance 16 elements + + // Pack A data: merge 4 rows into 2 vectors + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v13, 1 \n\t" + "vpack.vv v18, v14, v15, 1 \n\t" + "vpack.vv v2, v16, v18, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v2, v12, i8 \n\t" // 4 rows × cols 0-7 + "vmadot v18, v2, v13, i8 \n\t" // 4 rows × cols 8-15 + "vmadot v20, v2, v14, i8 \n\t" // 4 rows × cols 16-23 + "vmadot v22, v2, v15, i8 \n\t" // 4 rows × cols 24-31 + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" // Merge cols 0-15 + "vpack.vv v14, v20, v22, 2 \n\t" // Merge cols 16-31 + "vpack.vv v16, v12, v14, 3 \n\t" // Inter-row results (INT16) + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // INT8 → INT16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // INT16 → INT32 + + // Accumulate to K256 accumulator: qsum * b_scale + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" // Row 0 + "vmacc.vv v25, v17, v23 \n\t" // Row 1 + "vmacc.vv v26, v18, v23 \n\t" // Row 2 + "vmacc.vv v27, v19, v23 \n\t" + + // ========== K16-31, K32-47, K48-63: Similar processing ========== + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 8 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v12, v4, 4 \n\t" + "vsll.vi v13, v5, 4 \n\t" + "vsll.vi v14, v6, 4 \n\t" + "vsll.vi v15, v7, 4 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v12, 6 \n\t" + "vsrl.vi v13, v13, 6 \n\t" + "vsrl.vi v14, v14, 6 \n\t" + "vsrl.vi v15, v15, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v2, v12, i8 \n\t" + "vmadot v18, v2, v13, i8 \n\t" + "vmadot v20, v2, v14, i8 \n\t" + "vmadot v22, v2, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + //K32-47 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v12, v4, 2 \n\t" + "vsll.vi v13, v5, 2 \n\t" + "vsll.vi v14, v6, 2 \n\t" + "vsll.vi v15, v7, 2 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v12, 6 \n\t" + "vsrl.vi v13, v13, 6 \n\t" + "vsrl.vi v14, v14, 6 \n\t" + "vsrl.vi v15, v15, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v12, i8 \n\t" + "vmadot v18, v3, v13, i8 \n\t" + "vmadot v20, v3, v14, i8 \n\t" + "vmadot v22, v3, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + // K48-63 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v3, v3, 8 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v12, v4, 6 \n\t" + "vsrl.vi v13, v5, 6 \n\t" + "vsrl.vi v14, v6, 6 \n\t" + "vsrl.vi v15, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v12, i8 \n\t" + "vmadot v18, v3, v13, i8 \n\t" + "vmadot v20, v3, v14, i8 \n\t" + "vmadot v22, v3, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // ========== K256 superblock complete, apply scale factors ========== + // Load A's 4 row scale factors (FP32) + "flw f0, (s2) \n\t" + "flw f1, 4(s2) \n\t" + "flw f2, 8(s2) \n\t" + "flw f3, 12(s2) \n\t" + "add s2, s2, %[A_STR] \n\t" // Advance to next superblock + "add t4, s7, %[B_STR] \n\t" // t4 = next B block address + "addi s3, s2, (4+32)*4 \n\t" + + // Load B FP16 global scale factors (32 cols) + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (s6) \n\t" + + // Update B pointers to next block + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s7, t4, 0 \n\t" + + // ========== Type conversion and final scaling ========== + // FP16 → FP32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v9, v8 \n\t" + + // INT32 → FP32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + + // Compute a_scale * b_scale (4 rows) + "vfmul.vf v12, v9, f0 \n\t" + "vfmul.vf v13, v9, f1 \n\t" + "vfmul.vf v14, v9, f2 \n\t" + "vfmul.vf v15, v9, f3 \n\t" + + // Final accumulation: result += qsum * a_scale * b_scale + "vsetvli t0, x0, e32, m1 \n\t" + "vfmacc.vv v28, v12, v24 \n\t" + "vfmacc.vv v29, v13, v25 \n\t" + "vfmacc.vv v30, v14, v26 \n\t" + "vfmacc.vv v31, v15, v27 \n\t" + + // Prepare for next K superblock + "addi t2, t2, -1 \n\t" + "vxor.vv v24, v0, v0 \n\t" // Clear K256 accumulator + "vxor.vv v25, v0, v0 \n\t" + "vxor.vv v26, v0, v0 \n\t" + "vxor.vv v27, v0, v0 \n\t" + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + + // ========== Store results (4 rows × 32 cols) ========== + "mv t5, %[pC] \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v28, (%[pC]) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v29, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v30, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v31, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "FUNC_END%=: \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride), [A_STR] "r"(a_nrow_block_stride), [LDC] "r"(ldc * 4) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "f1", "f2", "f3", "s2", "s3", "s4", "s5", "s6", "s7"); + } +} + +void gemm_kernel_i8i4_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // b data + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 1024bit + // 4VRF, N32K32, 4096bit + // Bscale fp16 * N32 512bit; + // || scl*32..(fp16) | blk0 blk1 ... blk31 || scl*32..(fp16) | blk0 blk1 ... blk31 || ... + // || Element || Element || ... +#if 0 + //bias always be nullptr + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*21 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu v16, v10, v4, i4 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadotsu v18, v10, v5, i4 \n\t" // M0 N8 - N15 + "vmadotsu v20, v10, v6, i4 \n\t" // M0 N16 - N23 + "vmadotsu v22, v10, v7, i4 \n\t" // M0 N24 - N31 + + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + + "vmadotu v16, v8, v4, i4 \n\t" + "vmadotu v18, v8, v5, i4 \n\t" + "vmadotu v20, v8, v6, i4 \n\t" + "vmadotu v22, v8, v7, i4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v28, 8 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + "vwmul.vx v24, v28, t2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v24 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#else + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsll.vi v1, v0, 4 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + + // vmadot hp: vle*7 flw*1 vecIns*14 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v30, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v28, 8 \n\t" // Bzp u8 -> u16 + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmul.vx v26, v28, t2 \n\t" // asum*zp i16*i16 + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vfcvt.f.x.v v16, v26 \n\t" // zp i16 -> fp16 + "vadd.vi v18, v16, 0 \n\t" + "vadd.vi v20, v16, 0 \n\t" + "vadd.vi v22, v16, 0 \n\t" + + "vmadotsu.hp v16, v10, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v18, v10, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v10, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v22, v10, v7, v1, 0, i4 \n\t" + "vmadotu.hp v16, v8, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v18, v8, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v8, v6, v0, 0, i4 \n\t" + "vmadotu.hp v22, v8, v7, v0, 0, i4 \n\t" + + "vpack.vv v24, v16, v18, 1 \n\t" + "vpack.vv v26, v20, v22, 1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + // mac result * b_scale; f16*f16->f32 + "vfwmul.vv v31, v30, v16 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum * b_scale) * a_scale; + "vfmacc.vf v2, f0, v31 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); + +#endif + } + } else { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // b data + n * k_blks * sizeof(uint8_t) + // b zp + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 1024bit + // 4VRF, N32K32, 4096bit + // Bscale fp16 * N32 512bit; + // Bzp uint8_t * N32 256bit; + // || scl*32..(fp16) | zp*32(uint8) | blk0 blk1 ... blk31 || scl*32..(fp16) ... + // || Element || Element ... + + //bias always be nullptr +#if 0 + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBdata, (pB+BScl+Bzp) + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*21 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+96 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu v16, v10, v4, i4 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadotsu v18, v10, v5, i4 \n\t" // M0 N8 - N15 + "vmadotsu v20, v10, v6, i4 \n\t" // M0 N16 - N23 + "vmadotsu v22, v10, v7, i4 \n\t" // M0 N24 - N31 + + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (s4) \n\t" // Bzp + "addi s4, s4, 32+128*4 \n\t" + + "vmadotu v16, v8, v4, i4 \n\t" + "vmadotu v18, v8, v5, i4 \n\t" + "vmadotu v20, v8, v6, i4 \n\t" + "vmadotu v22, v8, v7, i4 \n\t" + + "vwaddu.vx v28, v1, x0 \n\t" // uint8 -> uint16 + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v24, v28, t2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v24 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#else + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBdata, (pB+BScl+Bzp) + "mv s6, %[pC] \n\t" + + "vsll.vi v1, v0, 4 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + + // vmadot hp: vle*6 flw*1 vecIns*14 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+96 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v30, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v31, (s4) \n\t" // B zp 32Row*uint8 = 256bit + "addi s4, s4, 32+128*4 \n\t" + + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu.hp v16, v10, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v18, v10, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v10, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v22, v10, v7, v1, 0, i4 \n\t" + "vmadotu.hp v16, v8, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v18, v8, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v8, v6, v0, 0, i4 \n\t" + "vmadotu.hp v22, v8, v7, v0, 0, i4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v31, x0 \n\t" // Bzp u8 -> u16 + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v24, v16, v18, 1 \n\t" + "vpack.vv v26, v20, v22, 1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vmul.vx v26, v28, t2 \n\t" // asum*zp i16*i16 + "vfwcvt.f.f.v v22, v30 \n\t" // b_scale fp16 -> fp32 + "vfcvt.f.x.v v18, v26 \n\t" // zp i16 -> fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vfwadd.vv v20, v18, v16 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // mac result * b_scale; f32*f32->f32 + "vfmul.vv v31, v22, v20 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum * b_scale) * a_scale; + "vfmacc.vf v2, f0, v31 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#endif + } + } +} + +void gemm_kernel_i8i4_hp_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" // init acc to zero + "mv t4, %[BK] \n\t" + "li t0, 0x4c00 \n\t" // 16 in fp16 + "fmv.h.x fa0, t0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + "li t5, 8 \n\t" + "addi t6, %[A], 288 \n\t" // point to blk scale + "flh ft1, (t6) \n\t" + "addi t6, %[A], 272 \n\t" // point to asum + + // init the acc fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v16, v18, v18 \n\t" + "vxor.vv v17, v18, v18 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v18, v18 \n\t" + + "INNER_BLK_LOOP%=: \n\t" + // load a sum and scale + "flh fa1, (t6) \n\t" + "addi t6, t6, 2 \n\t" + "flh ft0, (%[A]) \n\t" + "addi %[A], %[A], 2 \n\t" + // load A + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (%[A]) \n\t" // 1x32@i8 + "addi %[A], %[A], 32 \n\t" + + // load scale B and B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (%[B]) \n\t" // b_scale fp16 + "addi %[B], %[B], 64 \n\t" + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vfmul.vf v8, v8, ft0 \n\t" // scale b * scale a + "vfmul.vf v9, v8, fa0 \n\t" + "vfmul.vf v10, v8, fa1 \n\t" // scale b * scale a * asm + "vfwmacc.vf v31, ft1, v10 \n\t" // asum * scale a * scale b * blk scale + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v0, v8, v9, 3 \n\t" + "vsrl.vi v28, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v2, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v3, v28, v28, 3 \n\t" // hi4 of A + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v16, v3, v4, v0, 4, i4 \n\t" // high 4 + "vmadotsu.hp v17, v3, v5, v0, 5, i4 \n\t" + "vmadotsu.hp v18, v3, v6, v0, 6, i4 \n\t" + "vmadotsu.hp v19, v3, v7, v0, 7, i4 \n\t" + "vmadotu.hp v16, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v17, v2, v5, v0, 1, i4 \n\t" + "vmadotu.hp v18, v2, v6, v0, 2, i4 \n\t" + "vmadotu.hp v19, v2, v7, v0, 3, i4 \n\t" + + "addi t5, t5, -1 \n\t" + "bgtz t5, INNER_BLK_LOOP%= \n\t" + + "vpack.vv v8, v16, v17, 1 \n\t" + "vpack.vv v12, v18, v19, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "addi t4, t4, -1 \n\t" + "vfwmacc.vf v31, ft1, v20 \n\t" + //"vsetvli t0, x0, e32, m1 \n\t" + //"vfmul.vf v31, v31, ft1 \n\t" // blk scale + + // update A ptr + "addi %[A], t6, 2 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v31, (%[DST]) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "ft0", "ft1"); + } + } else { + // TODO: support quant_b_zp for i8i4 hp kernel + GGML_ABORT("gemm_kernel_i8i4_hp_m1 with quant_b_zp is not supported yet"); + } +} + +void gemm_kernel_i8i4_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_data_stride = + k_blks * (sizeof(ggml_fp16_t) + 16 * sizeof(int8_t) + (quant_b_zp != NULL ? sizeof(int8_t) : 0)); + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; +#if 0 + asm volatile( + "li t1, 8 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + "vsetivli t0, 4, e16, mf2 \n\t" + "vle16.v v8, (%[A]) \n\t" // asum + "addi %[A], %[A], 8 \n\t" + "vwmul.vx v10, v8, t1 \n\t" // 8*asum + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v16, v3, v4, i4 \n\t" // high 4 + "vmadotsu v18, v3, v5, i4 \n\t" + "vmadotsu v20, v3, v6, i4 \n\t" + "vmadotsu v22, v3, v7, i4 \n\t" + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + "vmadotu v16, v2, v4, i4 \n\t" // low 4 + "vmadotu v18, v2, v5, i4 \n\t" + "vmadotu v20, v2, v6, i4 \n\t" + "vmadotu v22, v2, v7, i4 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vrgather.vi v0, v10, 0 \n\t" + "vrgather.vi v1, v10, 1 \n\t" + "vrgather.vi v2, v10, 2 \n\t" + "vrgather.vi v3, v10, 3 \n\t" + + "vadd.vv v16, v16, v0 \n\t" + "vadd.vv v17, v17, v1 \n\t" + "vadd.vv v18, v18, v2 \n\t" + "vadd.vv v19, v19, v3 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc*4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); +#else + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + "vsetivli t0, 4, e16, mf2 \n\t" + "vle16.v v8, (%[A]) \n\t" // asum + "addi %[A], %[A], 8 \n\t" + "vsll.vi v8, v8, 3 \n\t" // asum * 8 + "vfcvt.f.x.v v9, v8 \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vrgather.vi v10, v9, 0 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v16, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v17, v16, 4 \n\t" + "vnpack4.vv v12, v16, v17, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v16, v10, v10,0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v20, v16, v16,0 \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vpack.vv v18, v20, v20, 0 \n\t" + "vor.vv v20, v18, v18 \n\t" + "vor.vv v21, v18, v18 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vfwmul.vv v16, v20, v14 \n\t" + "vfwmul.vv v18, v21, v14 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); +#endif + } + } else { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "li t1, 8 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2\n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + // load zp + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (%[B]) \n\t" + "addi %[B], %[B], 32 \n\t" + "vwaddu.vx v10, v8, x0 \n\t" + + // load a sum + "lh s1, (%[A]) \n\t" + "lh s2, 2(%[A]) \n\t" + "lh s3, 4(%[A]) \n\t" + "lh s4, 6(%[A]) \n\t" + "addi %[A], %[A], 8 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v16, v3, v4, i4 \n\t" // high 4 + "vmadotsu v18, v3, v5, i4 \n\t" + "vmadotsu v20, v3, v6, i4 \n\t" + "vmadotsu v22, v3, v7, i4 \n\t" + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + "vmadotu v16, v2, v4, i4 \n\t" // low 4 + "vmadotu v18, v2, v5, i4 \n\t" + "vmadotu v20, v2, v6, i4 \n\t" + "vmadotu v22, v2, v7, i4 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v0, v10, s1 \n\t" + "vwmul.vx v2, v10, s2 \n\t" + "vwmul.vx v4, v10, s3 \n\t" + "vwmul.vx v6, v10, s4 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v0 \n\t" + "vadd.vv v17, v17, v2 \n\t" + "vadd.vv v18, v18, v4 \n\t" + "vadd.vv v19, v19, v6 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST]\n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3", "s1", "s2", "s3", "s4"); + } + } +} + +void gemm_kernel_i8i4_hp_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_SUBBLKS_PER_SUPERBLK = 8; + constexpr size_t K_SUBBLK_LEN = 32; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + GGML_ASSERT(count_m >= 4); + + // Contract: + // - computes a 4-row x 32-col tile per inner invocation + // - A is q8 HP packed in m4 layout, one logical K256 block at a time + // - B is q4 HP packed in N32 tiles, optionally with a separate zp area + // - tail-N is currently not handled here; the caller must provide full N32 tiles + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * K_SUBBLKS_PER_SUPERBLK + + (quant_b_zp ? NB_COLS * K_SUBBLKS_PER_SUPERBLK * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + const size_t a_nrow_block_stride = q8_hp_blk_size(blk_len, true, true) * 4; + const size_t a_subblk_stride = q8_hp_blk_size(K_SUBBLK_LEN, false, false) * 4; + + if (quant_b_zp != nullptr) { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + const size_t nb_real = std::min(NB_COLS, count_n - ni); + if (nb_real != NB_COLS) { + break; + } + + uint8_t * b_tile_base = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_block = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + // Data layout summary for the with-zp path. + // + // A: M4 x K256 q8 HP block + // - split into 8 x K32 subblocks + // - each K32 subblock is 136B: + // 8B = 4 x fp16 row scales + // 128B = 4 x int8[32] row payloads + // - trailer after 8 subblocks is 72B: + // 4 rows x fp16[8] a_sum values, indexed as [row][ksi] + // 4 rows x fp16 scale_avg tail + // + // B: N32 x K256 q4 HP block with explicit zp area + // - each K32 subblock is 576B: + // 64B = fp16 scale[32] + // 512B = packed q4 payload for 32 columns x 32 k-elements + // - zp is stored separately, not interleaved with the 576B payload block + // - one K256 superblock is laid out as: + // 8 x (scale + qs) blocks = 4608B + // 8 x zp[32] = 256B + // + // C: 4 rows x 32 fp32 outputs + // + // ASM pointer convention: + // - t6: current A K32 subblock base + // - t2: current A a_sum base for this ksi + // row1/row2/row3 are at +16/+32/+48 bytes + // - s5: current B (scale + qs) K32 subblock base + // - s6: current B zp[32] base for this ksi + // + // Loop progression: + // - per ksi: A += 136, a_sum += 2, B_data += 576, B_zp += 32 + // - per ki : skip the 72B A trailer and advance B to the next 4864B superblock + + const _Float16 hp_scale_16 = (_Float16) 16.0f; + const _Float16 hp_scale_1 = (_Float16) 1.0f; + const _Float16 hp_scale_0125 = (_Float16) 0.125f; + + // VPR grouping used below: + // - v4-v7 : B q4 payload for N32 split as 4 x N8 groups + // - v8/v10 : zp u8 / widened fp16 + // - v12 : B fp16 scale[32] + // - v14-v15 : packed (Bscale * Ascale) for rows [0,1] / [2,3] + // - v16-v19 : temporary per-row scaled B scales + // - v28-v31 : final fp32 accumulators for rows 0..3 + + asm volatile( + "mv t5, %[BK] \n\t" + "mv t6, %[A] \n\t" + "mv s5, %[B] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t4, 8 \n\t" + "li t1, 4608 \n\t" + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "add s6, s5, t1 \n\t" // 8 * 576B B(scale+qs), zp area starts here + + ".align 4 \n\t" + "_BLK_LPST%=: \n\t" + "flh fa1, 64(t2) \n\t" // a_scale_avg_row[0] + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v18, v30, v30 \n\t" + "vxor.vv v19, v31, v31 \n\t" + "vxor.vv v20, v30, v30 \n\t" + "vxor.vv v21, v31, v31 \n\t" + "_KsubBLK_LPST%=: \n\t" + // load first subblock scales for 4 rows + "flh fa0, 0(t6) \n\t" // ascale_fp16 + + // load B fp16 scales[32] + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (s5) \n\t" + + // load Bzp[32] for the current ksi from the dedicated zp area + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (s6) \n\t" + + "fmul.h fa2, fa0, %[HP16] \n\t" + "vfwcvt.f.xu.v v10, v8 \n\t" // uint8 -> fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v16, v12, fa0 \n\t" // row0: Bscale * Ascale + "vfmul.vf v17, v12, fa2 \n\t" + + // load a_sum[row][ksi] from the trailer; t2 points to row0[ksi] + "flh ft1, 0(t2) \n\t" + "flh ft2, 16(t2) \n\t" + "flh ft3, 32(t2) \n\t" + "flh ft4, 48(t2) \n\t" + + "fmul.h ft1, ft1, %[HP0125] \n\t" + "fmul.h ft2, ft2, %[HP0125] \n\t" + "fmul.h ft3, ft3, %[HP0125] \n\t" + "fmul.h ft4, ft4, %[HP0125] \n\t" + + // load A payload from current K32 subblock and B q4 payload from current 576B block + "addi t3, t6, 8 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (t3) \n\t" //A + "addi t3, s5, 64 \n\t" + "vl4r.v v4, (t3) \n\t" //B + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" + "vpack.vv v0, v17, v16, 3 \n\t" + "vupack.vv v2, v12, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> mf2 + "vfmul.vv v10, v10, v16 \n\t" // zp * ascale * bscale; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> m1 + "vfmul.vf v12, v10, ft1 \n\t" // zp(1:n)* abscale * asum_m0; fp16*fp16 + "vfmul.vf v13, v10, ft2 \n\t" // zp(1:n)* abscale * asum_m1; fp16*fp16 + "vfmul.vf v24, v10, ft3 \n\t" // zp(1:n)* abscale * asum_m2; fp16*fp16 + "vfmul.vf v25, v10, ft4 \n\t" // zp(1:n)* abscale * asum_m3; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v28, fa1, v12 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v29, fa1, v13 \n\t" + "vfwmacc.vf v30, fa1, v24 \n\t" + "vfwmacc.vf v31, fa1, v25 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v0, 0, i4 \n\t" //lo4;n0n7 + "vmadotsu.hp v19, v3, v5, v0, 1, i4 \n\t" //lo4;n8n15 + "vmadotsu.hp v20, v3, v6, v0, 2, i4 \n\t" //lo4;n16n23 + "vmadotsu.hp v21, v3, v7, v0, 3, i4 \n\t" //lo4;n24n31 + "vmadotu.hp v18, v2, v4, v0, 4, i4 \n\t" //hi4;n0n7 + "vmadotu.hp v19, v2, v5, v0, 5, i4 \n\t" //hi4;n8n15 + "vmadotu.hp v20, v2, v6, v0, 6, i4 \n\t" //hi4;n16n23 + "vmadotu.hp v21, v2, v7, v0, 7, i4 \n\t" //hi4;n24n31 + + "addi t4, t4, -1 \n\t" + "addi t6, t6, 8+128 \n\t" // next A K32 subblock + "addi t2, t2, 2 \n\t" // next ksi entry in each a_sum row + "addi s5, s5, 64+512 \n\t" // next B (scale + qs) K32 block + "addi s6, s6, 32 \n\t" // next zp[32] + "bgtz t4, _KsubBLK_LPST%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v8, v18, v19, 1 \n\t" // 128(16*8)->256(16*16) + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v26, v8, v12, 2 \n\t" // 256(16*16)->512(16*32) + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmacc.vf v28, fa1, v26 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v30, fa1, v27 \n\t" + + "li t4, 8 \n\t" + "addi t5, t5, -1 \n\t" + "addi t6, t6, 72 \n\t" // skip A trailer after 8 subblocks and scale_avg tail + "mv s5, s6 \n\t" // s6 already points to next B superblock base + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "add s6, s5, t1 \n\t" // 8 * 576B B(scale+qs), zp area starts here + "bgtz t5, _BLK_LPST%= \n\t" + + "_BLK_LPND%=: \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_block), [B] "+r"(b_tile_base) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks), [HP16] "f"(hp_scale_16), + [HP1] "f"(hp_scale_1), [HP0125] "f"(hp_scale_0125) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s5", "s6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v10", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft1", "ft2", "ft3", "ft4", + "memory"); + } + return; + } else { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + const size_t nb_real = std::min(NB_COLS, count_n - ni); + if (nb_real != NB_COLS) { + break; + } + + uint8_t * b_tile_base = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_block = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + // Data layout summary for the no-zp path. + // + // A layout is identical to the with-zp branch. + // + // B: N32 x K256 q4 HP block without explicit zp storage + // - each K32 subblock is still 576B: + // 64B = fp16 scale[32] + // 512B = packed q4 payload + // - zp is implicit and treated as a constant value 8 in the kernel + // - one K256 superblock therefore contains only: + // 8 x (scale + qs) blocks = 4608B + // + // C: 4 rows x 32 fp32 outputs + // + // ASM pointer convention: + // - t6: current A K32 subblock base + // - t2: current A a_sum base for this ksi + // - s5: current B (scale + qs) K32 subblock base + // + // Loop progression: + // - per ksi: A += 136, a_sum += 2, B_data += 576 + // - per ki : skip the 72B A trailer and advance B to the next 4608B superblock + + const _Float16 hp_scale_16 = (_Float16) 16.0f; + const _Float16 hp_scale_1 = (_Float16) 1.0f; + + // VPR grouping used below matches the with-zp path: + // - v4-v7 : B q4 payload for N32 split as 4 x N8 groups + // - v8/v10 : implicit zp lane / widened fp16 + // - v12 : B fp16 scale[32] + // - v14-v15 : packed (Bscale * Ascale) for rows [0,1] / [2,3] + // - v16-v19 : temporary per-row scaled B scales + // - v28-v31 : final fp32 accumulators for rows 0..3 + + asm volatile( + "mv t5, %[BK] \n\t" + "mv t6, %[A] \n\t" + "mv s5, %[B] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t4, 8 \n\t" + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + + ".align 4 \n\t" + "_BLK_LPST%=: \n\t" + "flh fa1, 64(t2) \n\t" // a_scale_avg_row[0] + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v18, v30, v30 \n\t" + "vxor.vv v19, v31, v31 \n\t" + "vxor.vv v20, v30, v30 \n\t" + "vxor.vv v21, v31, v31 \n\t" + "_KsubBLK_LPST%=: \n\t" + // load first subblock scales for 4 rows + "flh fa0, 0(t6) \n\t" // ascale_fp16 + + // load B fp16 scales[32] + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (s5) \n\t" + + "fmul.h fa2, fa0, %[HP16] \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v16, v12, fa0 \n\t" // row0: Bscale * Ascale + "vfmul.vf v17, v12, fa2 \n\t" + + // load a_sum[row][ksi] from the trailer; t2 points to row0[ksi] + "flh ft1, 0(t2) \n\t" + "flh ft2, 16(t2) \n\t" + "flh ft3, 32(t2) \n\t" + "flh ft4, 48(t2) \n\t" + + // load A payload from current K32 subblock and B q4 payload from current 576B block + "addi t3, t6, 8 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (t3) \n\t" //A + "addi t3, s5, 64 \n\t" + "vl4r.v v4, (t3) \n\t" //B + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" + "vpack.vv v0, v17, v16, 3 \n\t" + "vupack.vv v2, v12, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> m1 + "vfmul.vf v12, v16, ft1 \n\t" // zp(1:n)* abscale * asum_m0; fp16*fp16 + "vfmul.vf v13, v16, ft2 \n\t" // zp(1:n)* abscale * asum_m1; fp16*fp16 + "vfmul.vf v24, v16, ft3 \n\t" // zp(1:n)* abscale * asum_m2; fp16*fp16 + "vfmul.vf v25, v16, ft4 \n\t" // zp(1:n)* abscale * asum_m3; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v28, fa1, v12 \n\t" + "vfwmacc.vf v29, fa1, v13 \n\t" + "vfwmacc.vf v30, fa1, v24 \n\t" + "vfwmacc.vf v31, fa1, v25 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v0, 0, i4 \n\t" //lo4;n0n7 + "vmadotsu.hp v19, v3, v5, v0, 1, i4 \n\t" //lo4;n8n15 + "vmadotsu.hp v20, v3, v6, v0, 2, i4 \n\t" //lo4;n16n23 + "vmadotsu.hp v21, v3, v7, v0, 3, i4 \n\t" //lo4;n24n31 + "vmadotu.hp v18, v2, v4, v0, 4, i4 \n\t" //hi4;n0n7 + "vmadotu.hp v19, v2, v5, v0, 5, i4 \n\t" //hi4;n8n15 + "vmadotu.hp v20, v2, v6, v0, 6, i4 \n\t" //hi4;n16n23 + "vmadotu.hp v21, v2, v7, v0, 7, i4 \n\t" //hi4;n24n31 + + "addi t4, t4, -1 \n\t" + + "addi t6, t6, 8+128 \n\t" // next A K32 subblock + "addi t2, t2, 2 \n\t" // next ksi entry in each a_sum row + "addi s5, s5, 64+512 \n\t" // next B (scale + qs) K32 block + "bgtz t4, _KsubBLK_LPST%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" //N32in1register + "vpack.vv v8, v18, v19, 1 \n\t" // 128(16*8)->256(16*16) + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v26, v8, v12, 2 \n\t" // 256(16*16)->512(16*32) + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmacc.vf v28, fa1, v26 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v30, fa1, v27 \n\t" + + "li t4, 8 \n\t" + "addi t5, t5, -1 \n\t" + "addi t6, t6, 72 \n\t" // skip A trailer after 8 subblocks and scale_avg tail + // s5 already points to next B superblock base + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "bgtz t5, _BLK_LPST%= \n\t" + + "_BLK_LPND%=: \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_block), [B] "+r"(b_tile_base) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks), [HP16] "f"(hp_scale_16), [HP1] "f"(hp_scale_1) + : "t0", "t2", "t3", "t4", "t5", "t6", "s5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v10", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft1", "ft2", "ft3", "ft4", "memory"); + } + return; + } +} + +void gemm_kernel_i8mxfp4_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_TILE = 32; + using blk_type = nrow_block_mxfp4; + + GGML_ASSERT(blk_len == K_TILE); + GGML_ASSERT(count_m == 1); + GGML_UNUSED(quant_b_zp); + + const size_t a_blk_stride = q8_blk_size(blk_len, true); + const size_t b_blk_stride = sizeof(blk_type); + const size_t b_tile_stride = k_blks * b_blk_stride; + + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // MXFP4 no-zp: per column per k-block stride = scale_e8m0(1B) + qs(16B) + qh(4B) = 21B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * (blk_len / 8) + // qh sign/high-bit mask: n×k_blks×4 + n * k_blks * blk_len / 2 + // qs packed 4-bit magnitudes: n×k_blks×16 + n * k_blks * sizeof(uint8_t); // scale: n×k_blks×1 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (q8 block with per-block scale and stored sum field): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (u8×32) + // s5 = pB qh (sign/high-bit mask, 128B) + // s6 = pB qs (packed 4-bit magnitudes, 512B) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales u8 (loaded as bytes; later widened) + // v0 = qh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) + // v8..v15 / v16..v23 = qs unpack/pack temporaries (build signed vmadot lanes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32 \n\t" // s5 = pBh (pB + 32B scale) + "addi s6, %[pB], 32+128 \n\t" // s6 = pBs (pB + 32 + 128 = pB+192) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 672B = 32(scl) + 512(Bs) + 128(Bh) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load qs (512B = 4 VRF) from s6, advance s6 by 672 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = qs N32K32 packed 4-bit magnitudes + "addi s6, s6, 128*4+128+32 \n\t" // s6 += 672 (512+128+32) + + // ---- load B scale (32B = 32×u8) from s4, advance s4 by 672 ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_u8 × 32 + "addi s4, s4, 32+128*4+128 \n\t" // s4 += 672 (32+512+128) + + // ---- load qh (128B = 1 VRF) from s5, advance s5 by 672 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = qh N32K32 sign/high-bit packed + "addi s5, s5, 128+32+128*4 \n\t" // s5 += 672 (128+32+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + // ---- Decode packed MXFP4 payload into a vmadot-friendly signed-lane layout ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vrsub.vi v16, v16, 0, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + // init the accumu to 0 + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- int8 dot products over the decoded MXFP4 lane groups ---- + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "lui t1, 0x00200 \n\t" + "vmv.v.x v30, t1 \n\t" + // b_scale e8m0 -> fp32 + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v2, x0 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v2, v28, x0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vmsle.vi v0, v2, 1 \n\t" + "vadd.vi v28, v2, -1 \n\t" + "vsll.vi v28, v28, 23 \n\t" + "vsll.vv v28, v30, v2, v0.t \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } +} + +void gemm_kernel_i8mxfp4_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_TILE = 32; + using blk_type = nrow_block_mxfp4; + + GGML_ASSERT(blk_len == K_TILE); + GGML_ASSERT(count_m == 4); + GGML_UNUSED(quant_b_zp); + + const size_t a_blk_stride = q8_blk_size(blk_len, true); + const size_t b_blk_stride = sizeof(blk_type); + const size_t b_tile_stride = k_blks * b_blk_stride; + + if (quant_b_zp == NULL) { + // MXFP4 block layout per K32/N32 tile: + // [scale_e8m0 x 32][qh sign/high-bit mask x 128B][qs packed 4-bit magnitudes x 512B] + // There is no explicit zp stream; qh is combined with qs to reconstruct signed MXFP4 values. + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> qh bitmask stream, t2 -> qs low-nibble stream. + "addi t1, %[B], 32 \n\t" + "addi t2, %[B], 160 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "addi %[B], %[B], 672 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t1) \n\t" + "vl4r.v v8, (t2) \n\t" + + // Decode the packed MXFP4 payload once for the whole tile and expand it + // into a vmadot-friendly layout. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vrsub.vi v16, v16, 0, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "lui t1, 0x00200 \n\t" + "vmv.v.x v30, t1 \n\t" + // b_scale e8m0 -> fp32 + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v2, x0 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v26, v28, x0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vmsle.vi v0, v26, 1 \n\t" + "vadd.vi v24, v26, -1 \n\t" + "vsll.vi v18, v24, 23 \n\t" + "vsll.vv v18, v30, v26, v0.t \n\t" + + // Row 0: dot(A0, decoded MXFP4 lane groups), accumulate in int32 and + // then apply A/B scaling. + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + "vfmul.vv v24, v24, v18 \n\t" + "vfmul.vv v25, v25, v18 \n\t" + "vfmul.vv v26, v26, v18 \n\t" + "vfmul.vv v27, v27, v18 \n\t" + "vfmacc.vf v4, fa0, v24 \n\t" + "vfmacc.vf v5, fa1, v25 \n\t" + "vfmacc.vf v6, fa2, v26 \n\t" + "vfmacc.vf v7, fa3, v27 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "fa0", "fa1", "fa2", "fa3"); + } + } +} + +void gemm_kernel_i8i5_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // ========================================================================= + // i8i5: 8-bit activation × 5-bit weight (4-bit low + 1-bit high mask) + // + // B layout per N32K32 k-block (no-zp): + // [0 .. 63 ] : scale_fp16 × 32 (64B) + // [64 .. 191] : Bh i1-high-bit × 32N × 32K (128B = 1 VRF) + // [192.. 703] : Bs i4-low-nibble × 32N × 32K (512B = 4 VRF) + // Total: 704B per k-block stride + // + // B layout per N32K32 k-block (with-zp): + // [0 .. 63 ] : scale_fp16 × 32 (64B) + // [64 .. 95 ] : zp_uint8 × 32 (32B) + // [96 .. 223] : Bh i1-high-bit × 32N × 32K (128B = 1 VRF) + // [224.. 735] : Bs i4-low-nibble × 32N × 32K (512B = 4 VRF) + // Total: 736B per k-block stride + // + // Bh format per N8K32 sub-block (32B): + // K rows × N cols × 1bit packed as bytes (8 cols per byte, K groups of 4B) + // Byte k gives 8 mask bits for columns N7..N0 at k-th K-element. + // + // Computation: + // B5bit_signed = (Bs | (Bh << 4)) - zp + // dot(A, B5) = dot(A, Bs_u4) + 16*dot(A, Bh_u1) - zp*asum + // No-zp: implicit zp = 16 (unsigned [0..31] centered at 16) + // With-zp: explicit zp from data + // + // ========================================================================= + + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // i8i5 no-zp: per column per k-block stride = fp16(2B) + i4(16B) + i1(4B) = 22B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * (blk_len / 8) + // Bh i1 mask: n×k_blks×4 + n * k_blks * blk_len / 2 + // Bs i4 data: n×k_blks×16 + n * k_blks * sizeof(_Float16); // scale: n×k_blks×2 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (same as i8i4): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // t2 = A asum (int16) << 4 f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (fp16×32) + // s5 = pB Bh (i1 mask, 128B) + // s6 = pB Bs (i4 packed, 512B) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales fp16 (loaded as bytes; later widened) + // v0 = Bh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) + // v8..v15 / v16..v23 = Bs unpack/pack temporaries (build b5bit bytes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBh (pB + 64B scale) + "addi s6, %[pB], 32*2+128 \n\t" // s6 = pBs (pB + 64 + 128 = pB+192) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 704B = 64(scl) + 512(Bs) + 128(Bh) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load Bs (512B = 4 VRF) from s6, advance s6 by 704 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s6, s6, 128*4+128+64 \n\t" // s6 += 704 (512+128+64) + + // ---- load B scale (64B = 32×fp16) from s4, advance s4 by 704 ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_fp16 × 32 + "addi s4, s4, 64+128*4+128 \n\t" // s4 += 704 (64+512+128) + + // ---- load Bh (128B = 1 VRF) from s5, advance s5 by 704 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = Bh N32K32 1-bit packed + "addi s5, s5, 128+64+128*4 \n\t" // s5 += 704 (128+64+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "lh t2, 4(s2) \n\t" // t2 = A asum (int16) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "slli t2, t2, 4 \n\t" // a_sum * 16; + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + //// vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vadd.vx v24, v24, t2 \n\t" + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v2 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } else { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // i8i5 with-zp: per column per k-block stride = fp16(2B)+zp(1B)+i4(16B)+i1(4B)=23B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // Bs i4: n×k_blks×16 + n * k_blks * (blk_len / 8) + // Bh i1: n×k_blks×4 + n * k_blks * sizeof(uint8_t) + // zp: n×k_blks×1 + n * k_blks * sizeof(_Float16); // scale: n×k_blks×2 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (same as i8i4): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // t2 = A asum (int16) << 4 f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (fp16×32); 每个 k-block 先 +64 指向 zp,再 +672 到下一个 block + // s5 = pB Bh (i1 mask, 128B) (offset +96) + // s6 = pB Bs (i4 packed, 512B) (offset +224) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales fp16 (loaded as bytes; later widened) + // v0 = Bh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) / later reused to hold Bzp bytes + // v8..v15 / v16..v23 = Bs unpack/pack temporaries (build b5bit bytes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBh (pB + 64B scale + 32B zp = pB+96) + "addi s6, %[pB], 32*3+128 \n\t" // s6 = pBs (pB + 96 + 128 = pB+224) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 736B = 64(scale) + 32(zp) + 128(Bh) + 512(Bs) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load Bs (512B = 4 VRF) from s6, advance s6 by 736 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s6, s6, 128*4+128+96 \n\t" // s6 += 736 (512+128+96) + + // ---- load B scale (64B = 32×fp16) from s4; then s4 points to zp[32] ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_fp16 × 32 + "addi s4, s4, 64 \n\t" // s4 += 64 (now points to zp) + + // ---- load Bh (128B = 1 VRF) from s5, advance s5 by 736 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = Bh N32K32 1-bit packed + "addi s5, s5, 128+96+128*4 \n\t" // s5 += 736 (128+96+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "lh t2, 4(s2) \n\t" // t2 = A asum (int16) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (s4) \n\t" // Bzp + "addi s4, s4, 32+128*4+128 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vwaddu.vx v28, v1, x0 \n\t" // uint8 -> uint16 + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v30, v28, t2 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v30 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } +} + +void gemm_kernel_i8i5_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + + GGML_UNUSED(count_m); + GGML_UNUSED(blk_len); + + // This kernel computes a 4x32 output tile. For each K32 block we decode the + // packed Q5 weights once and reuse the decoded vectors across the 4 A rows. + constexpr size_t B_Q50_BLK_STRIDE = sizeof(nrow_block_q5_0); + constexpr size_t B_Q51_BLK_STRIDE = sizeof(nrow_block_q5_1); + + if (quant_b_zp) { + // Q5_1 block layout per K32/N32 tile: + // [scale_fp16 x 32][zp_u8 x 32][qh high-bit mask x 128B][qs low nibbles x 512B] + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q51_BLK_STRIDE; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales/sums for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "lh s1, 16(%[A]) \n\t" + "lh s2, 18(%[A]) \n\t" + "lh s3, 20(%[A]) \n\t" + "lh s4, 22(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> zp stream, t2 -> qh bitmask stream, s5 -> qs low-nibble stream. + "addi t1, %[B], 64 \n\t" + "addi t2, %[B], 96 \n\t" + "addi s5, %[B], 224 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t2) \n\t" + "vl4r.v v8, (s5) \n\t" + "addi %[B], %[B], 736 \n\t" + + // Decode Q5 payload once for the whole tile: + // 1) split `qs` low/high nibbles, + // 2) repack into bytes, + // 3) use the `qh` mask to inject bit4 (+16) where needed, + // 4) expand into the vmadot-friendly layout reused by all 4 rows. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "li t2, 16 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t2, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + // Convert per-column fp16 scales once; the same scale vector is shared by all 4 rows. + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v18, v2 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v3, (t1) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + + // Row 0: dot(A0, decoded_q5) + a_sum0 * zp, then scale by A/B scales. + // The widen/mul correction sequence intentionally matches the proven m1 Q5_1 path. + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vwaddu.vx v28, v3, x0 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v12, v28, s1 \n\t" + "vwmul.vx v14, v28, s2 \n\t" + "vwmul.vx v20, v28, s3 \n\t" + "vwmul.vx v22, v28, s4 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v12 \n\t" + "vadd.vv v25, v25, v14 \n\t" + "vadd.vv v26, v26, v20 \n\t" + "vadd.vv v27, v27, v22 \n\t" + "vfcvt.f.x.v v12, v24 \n\t" + "vfcvt.f.x.v v14, v25 \n\t" + "vfcvt.f.x.v v20, v26 \n\t" + "vfcvt.f.x.v v22, v27 \n\t" + "vfmul.vv v12, v12, v18 \n\t" + "vfmul.vv v14, v14, v18 \n\t" + "vfmul.vv v20, v20, v18 \n\t" + "vfmul.vv v22, v22, v18 \n\t" + "vfmacc.vf v4, fa0, v12 \n\t" + "vfmacc.vf v5, fa1, v14 \n\t" + "vfmacc.vf v6, fa2, v20 \n\t" + "vfmacc.vf v7, fa3, v22 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "s5", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "fa0", "fa1", "fa2", "fa3"); + } + } else { + // Q5_0 block layout per K32/N32 tile: + // [scale_fp16 x 32][qh high-bit mask x 128B][qs low nibbles x 512B] + // There is no explicit zp stream; the implicit midpoint correction is +16. + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q50_BLK_STRIDE; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales/sums for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "lh s1, 16(%[A]) \n\t" + "lh s2, 18(%[A]) \n\t" + "lh s3, 20(%[A]) \n\t" + "lh s4, 22(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> qh bitmask stream, t2 -> qs low-nibble stream. + "addi t1, %[B], 64 \n\t" + "addi t2, %[B], 192 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t1) \n\t" + "vl4r.v v8, (t2) \n\t" + "addi %[B], %[B], 704 \n\t" + + // Decode Q5 payload once for the whole tile and expand it into the vmadot layout. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "li t2, 16 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t2, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + // Convert per-column fp16 scales once; the same scale vector is shared by all 4 rows. + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v18, v2 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + + // Row 0: dot(A0, decoded_q5) + a_sum0 * 16 (implicit Q5_0 midpoint correction). + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "slli s1, s1, 4 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "slli s2, s2, 4 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "slli s3, s3, 4 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "slli s4, s4, 4 \n\t" + "vadd.vx v24, v24, s1 \n\t" + "vadd.vx v25, v25, s2 \n\t" + "vadd.vx v26, v26, s3 \n\t" + "vadd.vx v27, v27, s4 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + "vfmul.vv v24, v24, v18 \n\t" + "vfmul.vv v25, v25, v18 \n\t" + "vfmul.vv v26, v26, v18 \n\t" + "vfmul.vv v27, v27, v18 \n\t" + "vfmacc.vf v4, fa0, v24 \n\t" + "vfmacc.vf v5, fa1, v25 \n\t" + "vfmacc.vf v6, fa2, v26 \n\t" + "vfmacc.vf v7, fa3, v27 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "fa0", "fa1", "fa2", "fa3"); + } + } +} + +void gemm_kernel_i8i8_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len + // b data + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 2048bit + // 4VRF, N32K32, 8192bit + // Bscale fp16 * N32 512bit; + // || scl*32..(fp16) | blk0 blk1 ... blk31 || scl*32..(fp16) | blk0 blk1 ... blk31 || ... + // || Element || Element || ... + + //bias always be nullptr + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*64 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4 \n\t" + "vl4r.v v8, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*8 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "addi s2, s2, 6+32 \n\t" // AScale + Asum(FP32+i16) + + "vsetvli t0, zero, e32, m1 \n\t" + "vupack.vv v24, v4, v5, 1 \n\t" + "vupack.vv v26, v6, v7, 1 \n\t" + "vupack.vv v28, v8, v9, 1 \n\t" + "vupack.vv v30, v10, v11, 1 \n\t" + + "vslidedown.vi v4, v3, 4 \n\t" + + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v24, i8 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadot v18, v3, v26, i8 \n\t" // M0 N8 - N15 + "vmadot v20, v3, v28, i8 \n\t" // M0 N16 - N23 + "vmadot v22, v3, v30, i8 \n\t" // M0 N24 - N31 + + "vmadot v16, v4, v25, i8 \n\t" + "vmadot v18, v4, v27, i8 \n\t" + "vmadot v20, v4, v29, i8 \n\t" + "vmadot v22, v4, v31, i8 \n\t" + + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); + } +} + +void gemm_kernel_i8i8_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_data_stride = k_blks * sizeof(ggml_fp16_t) + k_blks * blk_len; + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16+8 \n\t" // Ascl+Asum; FP32*4+i16*4 + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i8 + "addi %[B], %[B], 512 \n\t" + "vl4r.v v8, (%[B]) \n\t" // 32*32@i8 + "addi %[B], %[B], 512 \n\t" + + "vsetvli t0, zero, e32, m1 \n\t" + "vupack.vv v2, v0, v0, 1 \n\t" + + "vupack.vv v24, v4, v5, 1 \n\t" + "vupack.vv v26, v6, v7, 1 \n\t" + "vupack.vv v4, v8, v9, 1 \n\t" + "vupack.vv v6, v10, v11, 1 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot v16, v2, v24, i8 \n\t" + "vmadot v18, v2, v26, i8 \n\t" + "vmadot v20, v2, v4, i8 \n\t" + "vmadot v22, v2, v6, i8 \n\t" + "vmadot v16, v3, v25, i8 \n\t" + "vmadot v18, v3, v27, i8 \n\t" + "vmadot v20, v3, v5, i8 \n\t" + "vmadot v22, v3, v7, i8 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz %[BK], BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } +} + +void moe_m2_gemm_kernel_i8i4_impl(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { +#if 0 + moe_gemm_kernel_i8i4_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, + ldc); +#else + int64_t b_data_stride = + k_blks * (sizeof(ggml_fp16_t) + 16 * sizeof(int8_t) + (quant_b_zp != NULL ? sizeof(int8_t) : 0)); + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t3, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A0 + "flw fa0, (%[A0]) \n\t" // A0 scale + "lh t1, 4(%[A0]) \n\t" // A0 asum + "addi %[A0], %[A0], 6 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + // load scale A1 + "flw fa1, (%[A1]) \n\t" // A1 scale + "lh t2, 4(%[A1]) \n\t" // A1 asum + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.x v10, t1 \n\t" + "vmv.v.x v11, t2 \n\t" + + "vpack.vv v18, v10, v11, 1 \n\t" + "vsll.vi v18, v18, 3 \n\t" // mul 8 + "vfcvt.f.x.v v18, v18 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" // A0 data + "vle8.v v16, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" // 1*32@i8 + "vle8.v v20, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" // 1*32@i8 + + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + + "vsrl.vi v17, v16, 4 \n\t" + "vsrl.vi v21, v20, 4 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnpack4.vv v2, v16, v20, 2 \n\t" // low u4 + "vnpack4.vv v3, v17, v21, 2 \n\t" // high s4 + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vor.vv v19, v18, v18 \n\t" + "vor.vv v20, v18, v18 \n\t" + "vor.vv v21, v18, v18 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vfwmul.vv v16, v20, v14 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "addi t3, t3, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + + "bgtz t3, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v28, (%[DST0]) \n\t" + "vse32.v v29, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } + } else { +# if 0 + moe_gemm_kernel_i8i4_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +# else + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t3, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A0 + "flw fa0, (%[A0]) \n\t" // A0 scale + "lh t1, 4(%[A0]) \n\t" // A0 asum + "addi %[A0], %[A0], 6 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + // load scale A1 + "flw fa1, (%[A1]) \n\t" // A1 scale + "lh t2, 4(%[A1]) \n\t" // A1 asum + "addi %[A1], %[A1], 6 \n\t" + + // load zp + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (%[B]) \n\t" + "addi %[B], %[B], 32 \n\t" + "vwaddu.vx v10, v8, x0 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" // A0 data + "vle8.v v16, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" // 1*32@i8 + "vle8.v v20, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" // 1*32@i8 + + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + + "vsrl.vi v17, v16, 4 \n\t" + "vsrl.vi v21, v20, 4 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnpack4.vv v2, v16, v20, 2 \n\t" // low u4 + "vnpack4.vv v3, v17, v21, 2 \n\t" // high s4 + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v19, v19 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v21, v21, v21 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + // asum*zp + "vsetvli t0, x0, e16, mf2 \n\t" + "vwmul.vx v2, v10, t1 \n\t" + "vwmul.vx v4, v10, t2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "vfcvt.f.x.v v2, v2 \n\t" + "vfcvt.f.x.v v4, v4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwcvt.f.f.v v16, v20 \n\t" + + "vfwcvt.f.f.v v18, v14 \n\t" + + // +asum*zp + "vsetvli t0, x0, e32, m1 \n\t" + "vfadd.vv v16, v16, v2 \n\t" + "vfadd.vv v17, v17, v4 \n\t" + "vfmul.vv v16, v16, v18 \n\t" + "vfmul.vv v17, v17, v18 \n\t" + + "addi t3, t3, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + + "bgtz t3, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v28, (%[DST0]) \n\t" + "vse32.v v29, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } +# endif + } +#endif +} + +void moe_m2_gemm_kernel_i8i5_impl(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t B_Q50_BLK_STRIDE = sizeof(nrow_block_q5_0); + constexpr size_t B_Q51_BLK_STRIDE = sizeof(nrow_block_q5_1); + + GGML_UNUSED(blk_len); + GGML_UNUSED(count_m); + GGML_UNUSED(ldc); + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q50_BLK_STRIDE; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "mv t4, %[BK] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" + "vxor.vv v3, v0, v0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // ---- load B scale/Bh/Bs and advance to the next q5_0 k-block ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (%[B]) \n\t" // v1 = scale_fp16 × 32 + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (%[B]) \n\t" // v0 = Bh N32K32 1-bit packed + "addi %[B], %[B], 128 \n\t" + "vl4r.v v8, (%[B]) \n\t" // v8..v11 = Bs N32K32 i4 + "addi %[B], %[B], 512 \n\t" + + // ---- load A0/A1 header then payload, each block stride = 38B ---- + "flw f0, (%[A0]) \n\t" // f0 = A0 scale (fp32) + "lh t2, 4(%[A0]) \n\t" // t2 = A0 asum (int16) + "addi %[A0], %[A0], 6 \n\t" + "flw f1, (%[A1]) \n\t" // f1 = A1 scale (fp32) + "lh t3, 4(%[A1]) \n\t" // t3 = A1 asum (int16) + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (%[A0]) \n\t" // v4 = A0 M1K32 int8 + "addi %[A0], %[A0], 32 \n\t" + "vle8.v v5, (%[A1]) \n\t" // v5 = A1 M1K32 int8 + "addi %[A1], %[A1], 32 \n\t" + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "slli t2, t2, 4 \n\t" // a_sum * 16; + "slli t3, t3, 4 \n\t" + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vpack.vv v6, v4, v5, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vupack.vv v4, v6, v7, 1 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v4, v8, i8 \n\t" // N0..7 + "vmadot v26, v4, v10, i8 \n\t" // N8..15 + "vmadot v28, v4, v12, i8 \n\t" // N16..23 + "vmadot v30, v4, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v5, v9, i8 \n\t" // N0..7 + "vmadot v26, v5, v11, i8 \n\t" // N8..15 + "vmadot v28, v5, v13, i8 \n\t" // N16..23 + "vmadot v30, v5, v15, i8 \n\t" // N24..31 + + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vadd.vx v24, v24, t2 \n\t" + "vadd.vx v25, v25, t3 \n\t" + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v1 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfcvt.f.x.v v27, v25 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfmul.vf v31, v28, f1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v30, v26 \n\t" + "vfmacc.vv v3, v31, v27 \n\t" + + "addi t4, t4, -1 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, %[NR], e32, m1 \n\t" + "vse32.v v2, (%[DST0]) \n\t" + "vse32.v v3, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks), [NR] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "f0", "f1"); + } + } else { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q51_BLK_STRIDE; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "mv t4, %[BK] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" + "vxor.vv v3, v0, v0 \n\t" + "addi t5, %[B], 64 \n\t" // t5 = zp (32B) + "addi t6, %[B], 96 \n\t" // t6 = qh (128B) + "addi s1, %[B], 224 \n\t" // s1 = qs (512B) + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // ---- load B scale/zp/Bh/Bs and advance to the next q5_1 k-block ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (%[B]) \n\t" // v1 = scale_fp16 × 32 + "addi %[B], %[B], 736 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t6) \n\t" // v0 = Bh N32K32 1-bit packed + "addi t6, t6, 736 \n\t" + "vl4r.v v8, (s1) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s1, s1, 736 \n\t" + + // ---- load A0/A1 header then payload, each block stride = 38B ---- + "flw f0, (%[A0]) \n\t" // f0 = A0 scale (fp32) + "lh t2, 4(%[A0]) \n\t" // t2 = A0 asum (int16) + "addi %[A0], %[A0], 6 \n\t" + "flw f1, (%[A1]) \n\t" // f1 = A1 scale (fp32) + "lh t3, 4(%[A1]) \n\t" // t3 = A1 asum (int16) + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (%[A0]) \n\t" // v4 = A0 M1K32 int8 + "addi %[A0], %[A0], 32 \n\t" + "vle8.v v5, (%[A1]) \n\t" // v5 = A1 M1K32 int8 + "addi %[A1], %[A1], 32 \n\t" + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // q5_1 uses explicit zp, so keep a_sum unshifted here. + // [4*32]*2 + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vpack.vv v6, v4, v5, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vupack.vv v4, v6, v7, 1 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v4, v8, i8 \n\t" // N0..7 + "vmadot v26, v4, v10, i8 \n\t" // N8..15 + "vmadot v28, v4, v12, i8 \n\t" // N16..23 + "vmadot v30, v4, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v5, v9, i8 \n\t" // N0..7 + "vmadot v26, v5, v11, i8 \n\t" // N8..15 + "vmadot v28, v5, v13, i8 \n\t" // N16..23 + "vmadot v30, v5, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (t5) \n\t" // v4 = Bzp N32 uint8 + "addi t5, t5, 736 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v4, x0 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwmul.vx v30, v28, t2 \n\t" + "vwmul.vx v31, v28, t3 \n\t" + + // b_scale fp16 -> fp32 + "vfwcvt.f.f.v v28, v1 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v30 \n\t" + "vadd.vv v25, v25, v31 \n\t" + + // a_scale * b_scale; + "vfcvt.f.x.v v26, v24 \n\t" + "vfcvt.f.x.v v27, v25 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfmul.vf v31, v28, f1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v30, v26 \n\t" + "vfmacc.vv v3, v31, v27 \n\t" + + "addi t4, t4, -1 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, %[NR], e32, m1 \n\t" + "vse32.v v2, (%[DST0]) \n\t" + "vse32.v v3, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks), [NR] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "f0", "f1"); + } + } +} + +size_t gemm_kernel_i8i2k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i2k_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i2k_m4(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i2k_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, + ldc); +#else + gemm_kernel_i8i2k_m1(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i3k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i3k_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i3k_m4(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i3k_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i3k_m1(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i4_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i4_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i4_hp(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8i4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + moe_m2_gemm_kernel_i8i4_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); + return 2; +} + +size_t gemm_kernel_i8i8(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i8_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i8_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i8_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i8_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 1 + gemm_kernel_i8mxfp4_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8mxfp4_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 1 + gemm_kernel_i8mxfp4_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8mxfp4_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + //moe_m2_gemm_kernel_i8mxfp4_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); + return 2; +} + +size_t gemm_kernel_i8i5(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i5_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i5_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i5_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i5_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8i5(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { +#if 0 + moe_gemm_kernel_i8i5_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + moe_m2_gemm_kernel_i8i5_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 2; +} + +} // namespace ime2 +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.cpp b/ggml/src/ggml-cpu/spacemit/ime_env.cpp new file mode 100644 index 000000000..a13ba391d --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_env.cpp @@ -0,0 +1,320 @@ +#include "ime_env.h" + +#include "ggml-impl.h" +#include "spine_mem_pool.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { +bool spine_core_info::get_spine_core_info(std::vector & result) { + static std::unordered_map spine_march_mapping_ = { + {0x8000000058000001, spine_core_arch_id::core_arch_x60 }, + { 0x8000000041000001, spine_core_arch_id::core_arch_a60 }, + { 0x8000000058000002, spine_core_arch_id::core_arch_x100}, + { 0x8000000041000002, spine_core_arch_id::core_arch_a100}, + }; + + result.clear(); + std::ifstream file("/proc/cpuinfo"); + std::string line; + + std::vector> cpu_info_list; + + uint64_t current_processor = spine_invalid_core_id; + uint64_t current_marchid = 0; + bool has_processor = false; + bool has_marchid = false; + + if (!file.is_open()) { + return false; + } + + while (std::getline(file, line)) { + if (line.substr(0, 9) == "processor") { + if (has_processor && has_marchid) { + cpu_info_list.push_back({ current_processor, current_marchid }); + } + + size_t colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + current_processor = std::stoi(line.substr(colon_pos + 1)); + has_processor = true; + } + + has_marchid = false; + } else if (line.substr(0, 7) == "marchid") { + size_t colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + std::string marchid_str = line.substr(colon_pos + 1); + marchid_str.erase(std::remove_if(marchid_str.begin(), marchid_str.end(), isspace), marchid_str.end()); + current_marchid = std::stoull(marchid_str, nullptr, 16); + has_marchid = true; + } + } + } + + if (has_processor && has_marchid) { + cpu_info_list.push_back({ current_processor, current_marchid }); + } + + if (has_processor && has_marchid) { + for (auto & cpu_info : cpu_info_list) { + if (cpu_info[0] != spine_invalid_core_id && + spine_march_mapping_.find(cpu_info[1]) != spine_march_mapping_.end()) { + auto core_info = spine_core_info(); + core_info.core_id = cpu_info[0]; + core_info.arch_id = spine_core_arch_id(spine_march_mapping_[cpu_info[1]]); + + result.push_back(core_info); + } + } + } + + return has_processor && has_marchid; +} + +namespace { +uint16_t hex_string_to_u16(const std::string & hex_str) { + try { + size_t pos = 0; + if (hex_str.substr(0, 2) == "0x" || hex_str.substr(0, 2) == "0X") { + pos = 2; + } + unsigned long result = std::stoul(hex_str.substr(pos), nullptr, 16); + if (result > std::numeric_limits::max()) { + throw std::out_of_range("Converted value is out of range for uint16_t"); + } + return static_cast(result); + } catch (const std::invalid_argument & e) { + throw std::invalid_argument("Invalid hexadecimal string"); + } catch (const std::out_of_range & e) { + throw; + } +} + +const char * spine_mem_pool_backend_to_string(spine_mem_pool_backend backend) { + switch (backend) { + case spine_mem_pool_backend::none: + return "NONE"; + case spine_mem_pool_backend::posix_memalign: + return "POSIX"; + case spine_mem_pool_backend::transparent_hugepage: + return "HPAGE"; + case spine_mem_pool_backend::hugetlb_1g: + return "HPAGE1GB"; + } + + return "unknown"; +} + +spine_mem_pool_backend parse_mem_backend(const char * mem_backend_str) { + if (mem_backend_str == nullptr || mem_backend_str[0] == '\0') { + return spine_mem_pool_backend::transparent_hugepage; + } + + std::string value(mem_backend_str); + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char ch) { return static_cast(std::tolower(ch)); }); + + if (value == "none") { + return spine_mem_pool_backend::none; + } + + if (value == "posix") { + return spine_mem_pool_backend::posix_memalign; + } + + if (value == "hpage") { + return spine_mem_pool_backend::transparent_hugepage; + } + + if (value == "hpage1gb") { + return spine_mem_pool_backend::hugetlb_1g; + } + + throw std::runtime_error("invalid SPACEMIT_MEM_BACKEND: " + value + ", expected NONE, POSIX, HPAGE or HPAGE1GB"); +} +} // namespace + +spine_env_info::spine_env_info() { + num_cores = static_cast(std::thread::hardware_concurrency()); + spine_core_info::get_spine_core_info(core_info_list); + + // special for x60 K1 + if (core_info_list.size() == 8 && core_info_list[0].arch_id == spine_core_arch_id::core_arch_x60) { + for (int i = 0; i < 4; i++) { + core_info_list[i].arch_id = spine_core_arch_id::core_arch_a60; + } + } + + // special for qemu + if (core_info_list.size() == 0) { + char * spine_core_arch_str = getenv("SPACEMIT_CORE_ARCH"); + if (spine_core_arch_str != nullptr) { + auto arch_id = hex_string_to_u16(spine_core_arch_str); + for (int i = 0; i < num_cores; i++) { + auto core_info = spine_core_info(); + core_info.core_id = i; + core_info.arch_id = spine_core_arch_id{ arch_id }; + core_info_list.push_back(core_info); + } + } + } + + if (core_info_list.size() == 0) { + throw std::runtime_error( + "Failed to get SPACEMIT_CORE_ARCH from environment or failed to parse it from /proc/cpuinfo"); + } + + char * spine_perfer_core_arch_str = getenv("SPACEMIT_PERFER_CORE_ARCH"); + if (spine_perfer_core_arch_str != nullptr && spine_perfer_core_arch_str != "") { + perfer_core_arch_id = spine_core_arch_id{ hex_string_to_u16(spine_perfer_core_arch_str) }; + } + + char * spine_perfer_core_id_str = getenv("SPACEMIT_PERFER_CORE_ID"); + std::vector perfer_core_id_vec; + if (spine_perfer_core_id_str != nullptr && spine_perfer_core_id_str != "") { + std::string perfer_core_id_str(spine_perfer_core_id_str); + size_t start = 0; + size_t end = 0; + while ((end = perfer_core_id_str.find(',', start)) != std::string::npos) { + std::string core_id_substr = perfer_core_id_str.substr(start, end - start); + perfer_core_id_vec.push_back(std::stoi(core_id_substr)); + start = end + 1; + } + std::string core_id_substr = perfer_core_id_str.substr(start); + perfer_core_id_vec.push_back(std::stoi(core_id_substr)); + } + + perfer_core_ids.reserve(num_cores); + if (perfer_core_arch_id == spine_core_arch_id::core_arch_none) { + for (auto & core_info : core_info_list) { + auto core_arch_id = core_info.arch_id; + auto core_arch_head = (uint16_t) (core_arch_id) >> 12; + if (core_arch_head == 0xA) { + num_perfer_cores++; + perfer_core_arch_id = core_arch_id; + cpu_mask |= (1ULL << core_info.core_id); + perfer_core_ids.push_back(core_info.core_id); + } + } + } else { + for (auto & core_info : core_info_list) { + auto core_arch_id = core_info.arch_id; + if (core_arch_id == perfer_core_arch_id) { + num_perfer_cores++; + cpu_mask |= (1ULL << core_info.core_id); + + auto core_arch_head = (uint16_t) (core_arch_id) >> 12; + if (core_arch_head == 0xA) { + perfer_core_ids.push_back(core_info.core_id); + } + } + } + if (num_perfer_cores == 0) { + GGML_ABORT("can not find core with arch id %x for SPACEMIT_PERFER_CORE_ARCH in core info list\n", + (uint16_t) perfer_core_arch_id); + } + } + + if (perfer_core_id_vec.size() > 0) { + perfer_core_ids.clear(); + cpu_mask = 0; + num_perfer_cores = 0; + for (int core_id : perfer_core_id_vec) { + if (core_id < 0 || core_id >= num_cores) { + GGML_ABORT("invalid core id in SPACEMIT_PERFER_CORE_ID: %d, should be between 0 and %d\n", core_id, + num_cores - 1); + } + auto core_info = core_info_list[core_id]; + auto core_arch_id = core_info.arch_id; + if (core_arch_id == perfer_core_arch_id) { + cpu_mask |= (1ULL << core_id); + perfer_core_ids.push_back(core_id); + } else { + GGML_ABORT( + "core id %d in SPACEMIT_PERFER_CORE_ID has arch id %x which does not match " + "SPACEMIT_PERFER_CORE_ARCH %x\n", + core_id, (uint16_t) core_arch_id, (uint16_t) perfer_core_arch_id); + } + } + std::string perfer_core_id_vec_str; + for (int core_id : perfer_core_id_vec) { + perfer_core_id_vec_str += std::to_string(core_id) + ","; + } + perfer_core_id_vec_str.pop_back(); + GGML_LOG_DEBUG("SPACEMIT_PERFER_CORE_ID is set, perferred core ids: %s\n", perfer_core_id_vec_str.c_str()); + num_perfer_cores = static_cast(perfer_core_id_vec.size()); + } + + use_ime1 = perfer_core_arch_id == spine_core_arch_id::core_arch_a60 || + perfer_core_arch_id == spine_core_arch_id::core_arch_x100; + + use_ime2 = perfer_core_arch_id == spine_core_arch_id::core_arch_a100; + + mem_backend = parse_mem_backend(getenv("SPACEMIT_MEM_BACKEND")); + char * spine_disable_tcm_str = getenv("SPACEMIT_DISABLE_TCM"); + auto user_disable_tcm = spine_disable_tcm_str != nullptr && strcmp(spine_disable_tcm_str, "0") != 0; + + if (!user_disable_tcm) { + spine_mem_pool_tcm_info tcm_info; + if (spine_mem_pool_tcm_init(&tcm_info)) { + use_tcm = tcm_info.available; + tcm_blk_size = tcm_info.blk_size; + GGML_LOG_DEBUG("CPU_RISCV64_SPACEMIT: tcm is available, blk_size: %zu, blk_num: %zu, is_fake_tcm: %d\n", + tcm_info.blk_size, tcm_info.blk_num, tcm_info.is_fake_tcm); + + for (auto & core_info : core_info_list) { + auto core_arch_head = (uint16_t) (core_info.arch_id) >> 12; + if (core_arch_head != 0xA) { + aicpu_id_offset++; + } else { + break; + } + } + } + } + + GGML_LOG_DEBUG( + "CPU_RISCV64_SPACEMIT: num_cores: %d, num_perfer_cores: %d, perfer_core_arch_id: %x, exclude_main_thread: %d, " + "use_ime1: %d, use_ime2: %d, mem_backend: %s, cpu_mask: %lx, aicpu_id_offset: %d\n", + num_cores, num_perfer_cores, (uint16_t) perfer_core_arch_id, exclude_main_thread, use_ime1, use_ime2, + spine_mem_pool_backend_to_string(mem_backend), cpu_mask, aicpu_id_offset); + + const size_t init_barrier_size = sizeof(spine_barrier_t) * spine_init_barrier_count; + init_barrier = + static_cast(spine_mem_pool_shared_mem_alloc(init_barrier_size, alignof(spine_barrier_t))); + if (init_barrier != nullptr) { + init_barrier_is_shared_mem = true; + } else { + GGML_LOG_WARN("CPU_RISCV64_SPACEMIT: failed to allocate init_barrier from shared mem, falling back to heap\n", + __func__); + init_barrier = new spine_barrier_t[spine_init_barrier_count]; + } + + spine_barrier_init(init_barrier, spine_init_barrier_count, 2); +} + +spine_env_info::~spine_env_info() { + if (init_barrier_is_shared_mem) { + spine_mem_pool_shared_mem_free(init_barrier); + } else { + delete[] init_barrier; + } + + init_barrier = nullptr; + init_barrier_is_shared_mem = false; +} + +spine_env_info global_spine_env_info; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.h b/ggml/src/ggml-cpu/spacemit/ime_env.h new file mode 100644 index 000000000..a6ca06d26 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_env.h @@ -0,0 +1,55 @@ +#pragma once + +#include "spine_barrier.h" +#include "spine_mem_pool.h" + +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +constexpr uint64_t spine_invalid_core_id = 0xFFFFFFFF; +constexpr size_t spine_init_barrier_count = 16; + +enum class spine_core_arch_id : uint16_t { + core_arch_none = 0, + core_arch_x60 = 0x503C, + core_arch_x100 = 0x5064, + core_arch_x200 = 0x50C8, + core_arch_a60 = 0xA03C, + core_arch_a100 = 0xA064, + core_arch_a200 = 0xA0C8, +}; + +struct spine_core_info { + uint64_t core_id{ spine_invalid_core_id }; + spine_core_arch_id arch_id{ spine_core_arch_id::core_arch_none }; + + static bool get_spine_core_info(std::vector & result); +}; + +struct spine_env_info { + std::vector core_info_list; + std::vector perfer_core_ids; + int aicpu_id_offset{ 0 }; + int num_cores{ 0 }; + int num_perfer_cores{ 0 }; + spine_core_arch_id perfer_core_arch_id{ spine_core_arch_id::core_arch_none }; + bool exclude_main_thread{ false }; + bool use_ime2{ false }; + bool use_ime1{ false }; + bool use_tcm{ false }; + spine_mem_pool_backend mem_backend{ spine_mem_pool_backend::transparent_hugepage }; + uint64_t tcm_blk_size{ 0 }; + uint64_t cpu_mask{ 0 }; + spine_barrier_t * init_barrier{ nullptr }; + bool init_barrier_is_shared_mem{ false }; + + spine_env_info(); + ~spine_env_info(); +}; + +extern spine_env_info global_spine_env_info; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ime_kernels.h index 757063415..0a1fafffb 100644 --- a/ggml/src/ggml-cpu/spacemit/ime_kernels.h +++ b/ggml/src/ggml-cpu/spacemit/ime_kernels.h @@ -1,26 +1,189 @@ #pragma once +#include #include +#include + +namespace spacemit_kernels { + +#define BLOCK_QNK_LEN 256 + +template struct nrow_block_q2_k { + // [4bit scale + 4bit zp] * N * 16 + uint8_t scales[N * BLOCK_QNK_LEN / 16]; + // [b0, b16, b32, b48] [b1, b17, b33, b49] ... [b15, b31, b47, b63] + // [b64, b80, b96, b112] ...[b79, b95, b111, b127] + // [b128, b144, b160, b176] ...[b143, b159, b175, b191] + // [b192, b208, b224, b240] ...[b207, b223, b239, b255] + uint8_t qs[N * BLOCK_QNK_LEN / 4]; + uint16_t scales16[N]; + uint16_t zeros16[N]; +}; + +template struct nrow_block_q3_k { + // [8bit scale] * N * 16 + int8_t scales[N * 16]; + // [b0, b1, b2, b3, b4, b5, b6, b7] ... [b248, b249, b250, b251, b252, b253, b254, b255] + uint8_t hmask[N * BLOCK_QNK_LEN / 8]; + // [b0, b16, b32, b48] [b1, b17, b33, b49] ... [b15, b31, b47, b63] + // [b64, b80, b96, b112] ...[b79, b95, b111, b127] + // [b128, b144, b160, b176] ...[b143, b159, b175, b191] + // [b192, b208, b224, b240] ...[b207, b223, b239, b255] + uint8_t qs[N * BLOCK_QNK_LEN / 4]; + uint16_t scales16[N]; +}; + +template struct nrow_block_mxfp4 { + uint8_t e[N]; + uint8_t qh[4 * N]; + uint8_t qs[16 * N]; +}; + +template struct __attribute__((packed)) nrow_block_q5_1 { + uint16_t scales16[N]; + uint8_t zp[N]; + // n0 [bh0, bh1, bh2, bh3, bh4, bh5, bh6, bh7] .... + uint8_t qh[4 * N]; + // n0 [b0, b1], [b2, b3] .... [b30, b31] + // n1 [b0, b1], [b2, b3] .... [b30, b31] + uint8_t qs[16 * N]; +}; + +static_assert(sizeof(nrow_block_q5_1<1>) == sizeof(uint8_t) + 22, "wrong nrow_block_q5_1 block size/padding"); + +template struct __attribute__((packed)) nrow_block_q5_0 { + uint16_t scales16[N]; + // n0 [bh0, bh1, bh2, bh3, bh4, bh5, bh6, bh7] .... + uint8_t qh[4 * N]; + // n0 [b0, b1], [b2, b3] .... [b30, b31] + // n1 [b0, b1], [b2, b3] .... [b30, b31] + uint8_t qs[16 * N]; +}; + +static_assert(sizeof(nrow_block_q5_0<1>) == 22, "wrong nrow_block_q5_0 block size/padding"); + +using gemm_kernel_quantize_def = std::function< + size_t(size_t, const uint8_t *, const uint8_t *, const uint8_t *, float *, size_t, size_t, size_t, size_t)>; + +using moe_gemm_kernel_quantize_def = std::function< + size_t(size_t, const uint8_t **, const uint8_t *, const uint8_t *, float **, size_t, size_t, size_t, size_t)>; -namespace sqnbitgemm_spacemit_ime { namespace ime1 { -size_t gemm_kernel_i8i4(size_t blk_len, - const std::byte * quant_a_ptr, - const std::byte * quant_b_data, - const float * quant_b_scale, - const std::byte * quant_b_zp, - float * c_ptr, - size_t count_m, - size_t count_n, - size_t count_k, - size_t block_count_k, - size_t ldc, - const float * bias, - const size_t scale_stride); +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); -void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); -void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); } // namespace ime1 -} // namespace sqnbitgemm_spacemit_ime + +namespace ime2 { +size_t gemm_kernel_i8i2k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i3k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i4_hp(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8i4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i8(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i5(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8i5(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); +} // namespace ime2 +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/repack.cpp b/ggml/src/ggml-cpu/spacemit/repack.cpp new file mode 100644 index 000000000..3c879c4b7 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/repack.cpp @@ -0,0 +1,1795 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP + +#include "repack.h" + +#include "ggml-common.h" +#include "ggml-cpu.h" +#include "ggml-impl.h" +#include "ime_kernels.h" + +#include +#include +#include +#include + +// clang-format off +#if defined(__riscv) + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +#error "riscv v extension or v_intrinsic not enabled" +#else +#include +#endif + +#if !defined(__riscv_zfh) +#error "riscv zfh extension not enabled" +#endif + +#else +#error "riscv not enabled in this build" +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +// clang-format on + +template constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +}; + +template struct block_with_zp { + ggml_half d[N]; // deltas for N qK_1 blocks + uint8_t zp[N]; // zero points for N qK_1 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks +}; + +// control size +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); +static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), + "wrong block_with_zp<4,16> size/padding"); + +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + +static_assert(sizeof(block<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<4,32> size/padding"); +static_assert(sizeof(block_with_zp<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16 + 32 * sizeof(uint8_t), + "wrong block_with_zp<4,32> size/padding"); + +using block_q4_0x16 = block<4, 16>; +using block_q4_1x16 = block_with_zp<4, 16>; +using block_q8_0x16 = block<8, 16>; + +using block_q4_0x32 = block<4, 32>; +using block_q4_1x32 = block_with_zp<4, 32>; +using block_q8_0x32 = block<8, 32>; + +struct block_q4_0x32x256 { + block_q4_0x32 blocks[8]; // [f16 * 32 | i4 * 32 * 32] * 8 +}; + +struct block_q4_1x32x256 { + block_q4_0x32 blocks[8]; + uint8_t zps[32 * 8]; +}; + +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); + } + } + + return out; +} + +static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x16 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); + } + } + + return out; +} + +static int repack_q4_0_to_q4_0_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_0x16 * dst = (block_q4_0x16 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static inline void get_scale_min_k4(int j, + const uint8_t * GGML_RESTRICT q, + uint8_t * GGML_RESTRICT d, + uint8_t * GGML_RESTRICT m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +static int repack_q4_k_to_q4_1_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 16); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static block_q4_0x32 make_block_q4_0x32(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x32 out; + assert(QK4_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 32; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b1] ......... [b14 b15] + out.qs[i * QK4_0 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < 32; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b17] ......... [b30 b31] + out.qs[i * QK4_0 / 2 + QK4_0 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + return out; +} + +static block_q4_1x32 make_block_q4_1x32(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x32 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 32; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b1] ......... [b14 b15] + out.qs[i * QK4_1 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < 32; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[i * QK4_1 / 2 + QK4_1 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + return out; +} + +static block_q8_0x32 make_block_q8_0x32(block_q8_0 * in, unsigned int blck_size_interleave) { + block_q8_0x32 out; + GGML_ASSERT(QK8_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 32; i++) { + memcpy(out.qs + i * QK8_0, in[i].qs, QK8_0); + } + + return out; +} + +static int repack_q2_k_to_q2_k_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K == 256); + + constexpr int nrows_interleaved = 32; + + const block_q2_K * src = (const block_q2_K *) data; + + auto * dst = (spacemit_kernels::nrow_block_q2_k<32> *) t->data; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + uint8_t qs_aux[256] = { 0 }; + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q2_K * src_block = &src[(b + i) * nblocks + x]; + + // scale for [16, N] + for (int j = 0; j < 16; j++) { + auto zp_aux = (dst->scales[j * nrows_interleaved + i]) & 0xF0; + + dst->scales[j * nrows_interleaved + i] = (src_block->scales[j] & 0x0F) | zp_aux; + } + + // zp for [N, 16] + for (int j = 0; j < 16; j++) { + auto scale_aux = (dst->scales[16 * i + j]) & 0x0F; + + dst->scales[16 * i + j] = (src_block->scales[j] & 0xF0) | scale_aux; + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j] = (src_block->qs[j] >> (2 * k)) & 0x03; + } + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j + 128] = (src_block->qs[j + 32] >> (2 * k)) & 0x03; + } + } + + // from nrows_interleaved * [2 * 32byte] + // to 4 * [nrows_interleaved * 16byte] + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 16; j++) { + uint8_t qs0 = qs_aux[j + k * 64]; + uint8_t qs16 = qs_aux[j + 16 + k * 64]; + uint8_t qs32 = qs_aux[j + 32 + k * 64]; + uint8_t qs48 = qs_aux[j + 48 + k * 64]; + + dst->qs[(k * nrows_interleaved + i) * 16 + j] = + (qs0 & 0x03) | ((qs16 & 0x03) << 2) | ((qs32 & 0x03) << 4) | ((qs48 & 0x03) << 6); + } + } + + dst->scales16[i] = src_block->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + dst->zeros16[i] = src_block->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + dst++; + } + } + + return 0; +} + +static int repack_q3_k_to_q3_k_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q3_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K == 256); + + constexpr int nrows_interleaved = 32; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * src = (const block_q3_K *) data; + + auto * dst = (spacemit_kernels::nrow_block_q3_k<32> *) t->data; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q3_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + uint32_t b_scale_aux[4] = { 0 }; + uint8_t qs_aux[256] = { 0 }; + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q3_K * src_block = &src[(b + i) * nblocks + x]; + + uint32_t * auxs = b_scale_aux; + int8_t * scale = (int8_t *) auxs; + memcpy(auxs, src_block->scales, 12); + + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + for (int j = 0; j < 16; j++) { + dst->scales[j * nrows_interleaved + i] = scale[j] - 32; + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j] = (src_block->qs[j] >> (2 * k)) & 0x03; + } + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j + 128] = (src_block->qs[j + 32] >> (2 * k)) & 0x03; + } + } + + // from nrows_interleaved * [2 * 32byte] + // to 4 * [nrows_interleaved * 16byte] + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 16; j++) { + uint8_t qs0 = qs_aux[j + k * 64]; + uint8_t qs16 = qs_aux[j + 16 + k * 64]; + uint8_t qs32 = qs_aux[j + 32 + k * 64]; + uint8_t qs48 = qs_aux[j + 48 + k * 64]; + + dst->qs[(k * nrows_interleaved + i) * 16 + j] = + (qs0 & 0x03) | ((qs16 & 0x03) << 2) | ((qs32 & 0x03) << 4) | ((qs48 & 0x03) << 6); + } + } + + //memcpy(dst->hmask + i * 32, src_block->hmask, 32); + + // from nrows_interleaved * [32byte] + // to 16 * [nrows_interleaved * uint16_t] + uint16_t * dst_mask = ((uint16_t *) dst->hmask) + i; + for (int j = 0; j < 16; j++, dst_mask += nrows_interleaved) { + uint8_t b_shift = j / 2; + uint8_t * b_mask_col = (uint8_t *) (src_block->hmask + (j % 2) * 16); + // b0 - b15 + uint16_t msk_out_0 = 0; + + for (int k = 0; k < 8; k++) { + msk_out_0 |= (uint16_t) ((b_mask_col[k] >> b_shift) & 0x01) << k; + } + for (int k = 8; k < 16; k++) { + msk_out_0 |= (uint16_t) ((b_mask_col[k] >> b_shift) & 0x01) << k; + } + + dst_mask[0] = msk_out_0; + } + + dst->scales16[i] = src_block->d; + } + + dst++; + } + } + + return 0; +} + +static int repack_q4_0_to_q4_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_0x32 * dst = (block_q4_0x32 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_0_256_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_0x32x256 * dst = (block_q4_0x32x256 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + GGML_ASSERT(nblocks % 8 == 0); // for 256-block interleaving + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x += 8) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + j + i * nblocks]; + } + dst->blocks[j] = make_block_q4_0x32(dst_tmp, interleave_block); + } + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_1_256_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_1x32x256 * dst = (block_q4_1x32x256 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + GGML_ASSERT(nblocks % 8 == 0); // for 256-block interleaving + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x += 8) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + j + i * nblocks]; + } + + block_q4_0x32 * dst_block = &dst->blocks[j]; + uint8_t * dst_zp = dst->zps + j * nrows_interleaved; + + for (int i = 0; i < nrows_interleaved; i++) { + float d = GGML_FP16_TO_FP32(dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + + dst_block->d[i] = GGML_FP32_TO_FP16(d); + dst_zp[i] = static_cast(mid); + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int k = 0; k < QK4_1 / 4; k++) { + dst_block->qs[i * QK4_1 / 2 + k] = + (dst_tmp[i].qs[k * 2] & 0x0F) | ((dst_tmp[i].qs[k * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int k = 0; k < QK4_1 / 4; k++) { + dst_block->qs[i * QK4_1 / 2 + QK4_1 / 4 + k] = + ((dst_tmp[i].qs[k * 2] & 0xF0) >> 4) | (dst_tmp[i].qs[k * 2 + 1] & 0xF0); + } + } + } + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q4_0_to_q4_0_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes nibble repack. +static int repack_q4_0_to_q4_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + constexpr int qs_bytes = QK4_0 / 2; // 16 + + block_q4_0x32 * dst = (block_q4_0x32 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q4_0); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q4_0 * col_src = src + x; + + // --- 1) Gather 32 scale values (ggml_half d) with stride load --- + // d is at offset 0 of each block_q4_0, stride between rows = row_stride + { + const uint8_t * d_base = (const uint8_t *) &col_src->d; + ggml_half * d_dst = dst->d; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + vuint16m1_t vd = + __riscv_vlse16_v_u16m1((const uint16_t *) (d_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd, vl); + offset += vl; + remaining -= vl; + } + } + + // --- 2) Nibble repack qs for each of the 32 rows --- + // For each row i: + // src qs[16]: [b0|b16] [b1|b17] ... [b15|b31] (lo nibble = b_j, hi nibble = b_{j+16}) + // dst qs low 8B: (qs[2j] & 0x0F) | ((qs[2j+1] & 0x0F) << 4) for j=0..7 + // dst qs high 8B: ((qs[2j] >> 4)) | (qs[2j+1] & 0xF0) for j=0..7 + { + const size_t vl8 = __riscv_vsetvl_e8m1(8); + for (int i = 0; i < 32; i++) { + const uint8_t * sq = col_src[i * nblocks].qs; + uint8_t * dq = dst->qs + i * qs_bytes; + + // stride-2 load to separate even/odd bytes + vuint8m1_t v_even = __riscv_vlse8_v_u8m1(sq, 2, vl8); // qs[0], qs[2], ..., qs[14] + vuint8m1_t v_odd = __riscv_vlse8_v_u8m1(sq + 1, 2, vl8); // qs[1], qs[3], ..., qs[15] + + // low nibble part: (even & 0x0F) | ((odd & 0x0F) << 4) + vuint8m1_t v_even_lo = __riscv_vand_vx_u8m1(v_even, 0x0F, vl8); + vuint8m1_t v_odd_lo = __riscv_vand_vx_u8m1(v_odd, 0x0F, vl8); + vuint8m1_t v_lo = __riscv_vor_vv_u8m1(v_even_lo, __riscv_vsll_vx_u8m1(v_odd_lo, 4, vl8), vl8); + + // high nibble part: (even >> 4) | (odd & 0xF0) + vuint8m1_t v_even_hi = __riscv_vsrl_vx_u8m1(v_even, 4, vl8); + vuint8m1_t v_odd_hi = __riscv_vand_vx_u8m1(v_odd, 0xF0, vl8); + vuint8m1_t v_hi = __riscv_vor_vv_u8m1(v_even_hi, v_odd_hi, vl8); + + __riscv_vse8_v_u8m1(dq, v_lo, vl8); + __riscv_vse8_v_u8m1(dq + 8, v_hi, vl8); + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q4_1_to_q4_1_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes nibble repack + zp computation. +static int repack_q4_1_to_q4_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + constexpr int qs_bytes = QK4_1 / 2; // 16 + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q4_1); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q4_1 * col_src = src + x; + + // --- 1) Gather d and m, compute zp = clamp(nearbyint(-m/d), 0, 15) --- + // block_q4_1 layout: [d(f16), m(f16), qs[16]] + // d is at byte offset 0, m is at byte offset 2 from each block start + { + const uint8_t * dm_base = (const uint8_t *) &col_src->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + ggml_half * d_dst = dst->d; + uint8_t * zp_dst = dst->zp; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + + // stride load d (f16) from each row + vuint16m1_t vd_raw = + __riscv_vlse16_v_u16m1((const uint16_t *) (dm_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd_raw, vl); + + // stride load m (f16) from each row (offset +2 bytes from d) + vuint16m1_t vm_raw = + __riscv_vlse16_v_u16m1((const uint16_t *) (dm_base + 2 + offset * row_stride), row_stride, vl); + + // convert to f32 for zp computation: zp = nearbyint(-m / d) + vfloat16m1_t vd_f16 = __riscv_vreinterpret_v_u16m1_f16m1(vd_raw); + vfloat16m1_t vm_f16 = __riscv_vreinterpret_v_u16m1_f16m1(vm_raw); + + // -m / d in f16 directly (SpaceMIT X60 supports f16 arithmetic) + vfloat16m1_t v_neg_m = __riscv_vfneg_v_f16m1(vm_f16, vl); + vfloat16m1_t v_ratio = __riscv_vfdiv_vv_f16m1(v_neg_m, vd_f16, vl); + + // Convert to f32 for nearbyint, then clamp + vfloat32m2_t v_ratio_f32 = __riscv_vfwcvt_f_f_v_f32m2(v_ratio, vl); + + // Use integer rounding: convert f32 -> int (rounds to nearest) + vint32m2_t v_zp_i32 = __riscv_vfcvt_x_f_v_i32m2(v_ratio_f32, vl); + + // clamp to [0, 15] + v_zp_i32 = __riscv_vmax_vx_i32m2(v_zp_i32, 0, vl); + v_zp_i32 = __riscv_vmin_vx_i32m2(v_zp_i32, 15, vl); + + // narrow i32 -> u8 + vint16m1_t v_zp_i16 = __riscv_vncvt_x_x_w_i16m1(v_zp_i32, vl); + vint8mf2_t v_zp_i8 = __riscv_vncvt_x_x_w_i8mf2(v_zp_i16, vl); + vuint8mf2_t v_zp_u8 = __riscv_vreinterpret_v_i8mf2_u8mf2(v_zp_i8); + __riscv_vse8_v_u8mf2(zp_dst + offset, v_zp_u8, vl); + + offset += vl; + remaining -= vl; + } + } + + // --- 2) Nibble repack qs for each of the 32 rows --- + { + const size_t vl8 = __riscv_vsetvl_e8m1(8); + for (int i = 0; i < 32; i++) { + const uint8_t * sq = col_src[i * nblocks].qs; + uint8_t * dq = dst->qs + i * qs_bytes; + + // stride-2 load to separate even/odd bytes + vuint8m1_t v_even = __riscv_vlse8_v_u8m1(sq, 2, vl8); + vuint8m1_t v_odd = __riscv_vlse8_v_u8m1(sq + 1, 2, vl8); + + // low nibble part: (even & 0x0F) | ((odd & 0x0F) << 4) + vuint8m1_t v_even_lo = __riscv_vand_vx_u8m1(v_even, 0x0F, vl8); + vuint8m1_t v_odd_lo = __riscv_vand_vx_u8m1(v_odd, 0x0F, vl8); + vuint8m1_t v_lo = __riscv_vor_vv_u8m1(v_even_lo, __riscv_vsll_vx_u8m1(v_odd_lo, 4, vl8), vl8); + + // high nibble part: (even >> 4) | (odd & 0xF0) + vuint8m1_t v_even_hi = __riscv_vsrl_vx_u8m1(v_even, 4, vl8); + vuint8m1_t v_odd_hi = __riscv_vand_vx_u8m1(v_odd, 0xF0, vl8); + vuint8m1_t v_hi = __riscv_vor_vv_u8m1(v_even_hi, v_odd_hi, vl8); + + __riscv_vse8_v_u8m1(dq, v_lo, vl8); + __riscv_vse8_v_u8m1(dq + 8, v_hi, vl8); + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_k_to_q4_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q6_k_to_q8_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q6_K * src = (const block_q6_K *) data; + block_q8_0 dst_tmp[32]; + int8_t aux8[QK4_1]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + int64_t nrow_real = std::min((int64_t) nrow - b, (int64_t) nrows_interleaved); + for (int64_t x = 0; x < nblocks; x++) { + for (int bi = 0; bi < 8; bi++) { + int i = 0; + for (; i < nrow_real; i++) { + const uint8_t * q4 = src[x + i * nblocks].ql; + const uint8_t * qh = src[x + i * nblocks].qh; + const int8_t * scales = src[x + i * nblocks].scales; + float d = GGML_FP16_TO_FP32(src[x + i * nblocks].d); + + q4 += 64 * (bi / 4); + qh += 32 * (bi / 4); + int8_t * GGML_RESTRICT a = aux8; + + int8_t bi_idx = bi % 4; + + if (bi_idx == 0) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + } + } else if (bi_idx == 1) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + } + } else if (bi_idx == 2) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + } + } else if (bi_idx == 3) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + } + a = aux8; + + float a_max_abs = 0.0f; + float scale_0 = scales[bi * 2 + 0] * d; + float scale_1 = scales[bi * 2 + 1] * d; + for (int l = 0; l < 16; ++l) { + a_max_abs = std::max(a_max_abs, std::abs(a[l] * scale_0)); + } + + for (int l = 16; l < 32; ++l) { + a_max_abs = std::max(a_max_abs, std::abs(a[l] * scale_1)); + } + + float reflect_scale = a_max_abs / ((1 << 7) - 1); + float reflect_scale_0 = scale_0 / reflect_scale; + float reflect_scale_1 = scale_1 / reflect_scale; + + for (int l = 0; l < 16; ++l) { + float a_temp = std::clamp(std::nearbyintf(a[l] * reflect_scale_0), -128.0f, 127.0f); + a[l] = (int8_t) (a_temp); + } + + for (int l = 16; l < 32; ++l) { + float a_temp = std::clamp(std::nearbyintf(a[l] * reflect_scale_1), -128.0f, 127.0f); + a[l] = (int8_t) (a_temp); + } + + dst_tmp[i].d = GGML_FP32_TO_FP16(reflect_scale); + + memcpy(dst_tmp[i].qs, a, 32 * sizeof(int8_t)); + } + + for (; i < nrows_interleaved; i++) { + memset(&dst_tmp[i], 0, sizeof(block_q8_0)); + } + + *dst++ = make_block_q8_0x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q6_k_to_q8_0_32_bl +// Vectorizes the Q6_K dequant -> requant pipeline using RVV intrinsics. +// For each sub-block (bi), dequant 32 Q6_K values to int6 -> apply two sub-block scales -> +// find max abs -> compute reflect_scale -> requant to int8 -> gather d with stride load. +static int repack_q6_k_to_q8_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q6_K * src = (const block_q6_K *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q6_K); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int bi = 0; bi < 8; bi++) { + // --- 1) Gather 32 d values with stride load --- + // We need to compute reflect_scale per row first, so gather d later. + // Process each row: dequant Q6_K sub-block -> requant to Q8_0 + for (int i = 0; i < nrows_interleaved; i++) { + const block_q6_K * src_blk = &src[x + i * nblocks]; + const uint8_t * q4 = src_blk->ql + 64 * (bi / 4); + const uint8_t * qh = src_blk->qh + 32 * (bi / 4); + const int8_t * scales = src_blk->scales; + float d = GGML_FP16_TO_FP32(src_blk->d); + + int8_t bi_idx = bi % 4; + + // --- Dequant 32 Q6_K values to int6 (range [-32, 31]) using RVV --- + // vl = 32 for e8m2 (VLEN=256) or loop for smaller VLEN + const size_t vl16 = __riscv_vsetvl_e8m1(16); + + vint8m1_t va_lo, va_hi; // 16 elements each + + if (bi_idx == 0) { + // a[l] = (q4[l] & 0xF) | (((qh[l] >> 0) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 16, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vlo4_lo = __riscv_vand_vx_u8m1(vq4_lo, 0x0F, vl16); + vuint8m1_t vlo4_hi = __riscv_vand_vx_u8m1(vq4_hi, 0x0F, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(vqh_lo, 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(vqh_hi, 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vlo4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vlo4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else if (bi_idx == 1) { + // a[l] = (q4[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4 + 32, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 48, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vlo4_lo = __riscv_vand_vx_u8m1(vq4_lo, 0x0F, vl16); + vuint8m1_t vlo4_hi = __riscv_vand_vx_u8m1(vq4_hi, 0x0F, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 2, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 2, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vlo4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vlo4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else if (bi_idx == 2) { + // a[l] = (q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 16, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vhi4_lo = __riscv_vsrl_vx_u8m1(vq4_lo, 4, vl16); + vuint8m1_t vhi4_hi = __riscv_vsrl_vx_u8m1(vq4_hi, 4, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 4, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 4, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vhi4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vhi4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else { // bi_idx == 3 + // a[l] = (q4[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4 + 32, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 48, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vhi4_lo = __riscv_vsrl_vx_u8m1(vq4_lo, 4, vl16); + vuint8m1_t vhi4_hi = __riscv_vsrl_vx_u8m1(vq4_hi, 4, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 6, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 6, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vhi4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vhi4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } + + // --- Widen to i16 for scaled abs computation --- + float scale_0 = scales[bi * 2 + 0] * d; + float scale_1 = scales[bi * 2 + 1] * d; + + // Widen i8 -> i16 -> f32 for abs*scale computation + vint16m2_t va_lo_w = __riscv_vsext_vf2_i16m2(va_lo, vl16); + vint16m2_t va_hi_w = __riscv_vsext_vf2_i16m2(va_hi, vl16); + + // Compute |a[l] * scale_0| for lo half, |a[l] * scale_1| for hi half + vfloat32m4_t vf_lo = __riscv_vfcvt_f_x_v_f32m4(__riscv_vsext_vf2_i32m4(va_lo_w, vl16), vl16); + vfloat32m4_t vf_hi = __riscv_vfcvt_f_x_v_f32m4(__riscv_vsext_vf2_i32m4(va_hi_w, vl16), vl16); + + vfloat32m4_t vabs_lo = __riscv_vfabs_v_f32m4(__riscv_vfmul_vf_f32m4(vf_lo, scale_0, vl16), vl16); + vfloat32m4_t vabs_hi = __riscv_vfabs_v_f32m4(__riscv_vfmul_vf_f32m4(vf_hi, scale_1, vl16), vl16); + + // Find max abs across both halves + vfloat32m4_t vabs_max = __riscv_vfmax_vv_f32m4(vabs_lo, vabs_hi, vl16); + + // Reduce to scalar max + vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, 1); + vfloat32m1_t vmax_red = __riscv_vfredmax_vs_f32m4_f32m1(vabs_max, vzero, vl16); + float a_max_abs = __riscv_vfmv_f_s_f32m1_f32(vmax_red); + + float reflect_scale = a_max_abs / 127.0f; + float reflect_scale_0 = scale_0 / reflect_scale; + float reflect_scale_1 = scale_1 / reflect_scale; + + // --- Requant: a[l] = clamp(nearbyint(a[l] * reflect_scale_x), -128, 127) --- + vfloat32m4_t vscaled_lo = __riscv_vfmul_vf_f32m4(vf_lo, reflect_scale_0, vl16); + vfloat32m4_t vscaled_hi = __riscv_vfmul_vf_f32m4(vf_hi, reflect_scale_1, vl16); + + // fcvt.x rounds to nearest (using current rounding mode) + vint32m4_t vi_lo = __riscv_vfcvt_x_f_v_i32m4(vscaled_lo, vl16); + vint32m4_t vi_hi = __riscv_vfcvt_x_f_v_i32m4(vscaled_hi, vl16); + + // Clamp to [-128, 127] + vi_lo = __riscv_vmax_vx_i32m4(vi_lo, -128, vl16); + vi_lo = __riscv_vmin_vx_i32m4(vi_lo, 127, vl16); + vi_hi = __riscv_vmax_vx_i32m4(vi_hi, -128, vl16); + vi_hi = __riscv_vmin_vx_i32m4(vi_hi, 127, vl16); + + // Narrow i32 -> i16 -> i8 + vint16m2_t vi16_lo = __riscv_vncvt_x_x_w_i16m2(vi_lo, vl16); + vint16m2_t vi16_hi = __riscv_vncvt_x_x_w_i16m2(vi_hi, vl16); + vint8m1_t vi8_lo = __riscv_vncvt_x_x_w_i8m1(vi16_lo, vl16); + vint8m1_t vi8_hi = __riscv_vncvt_x_x_w_i8m1(vi16_hi, vl16); + + // Store d and qs directly into dst block + dst->d[i] = GGML_FP32_TO_FP16(reflect_scale); + int8_t * dq = (int8_t *) dst->qs + i * QK8_0; + __riscv_vse8_v_i8m1(dq, vi8_lo, vl16); + __riscv_vse8_v_i8m1(dq + 16, vi8_hi, vl16); + } + dst++; + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q8_0_to_q8_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[0] % QK8_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + int64_t nrows_real = std::min((int64_t) nrow - b, (int64_t) nrows_interleaved); + for (int64_t x = 0; x < nblocks; x++) { + int i = 0; + for (; i < nrows_real; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + for (; i < nrows_interleaved; i++) { + memset(&dst_tmp[i], 0, sizeof(block_q8_0)); + } + *dst++ = make_block_q8_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q8_0_to_q8_0_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes scale gather + qs copy. +static int repack_q8_0_to_q8_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK8_0 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q8_0); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q8_0 * col_src = src + x; + + // --- 1) Gather 32 scale values (ggml_half d) with stride load --- + { + const uint8_t * d_base = (const uint8_t *) &col_src->d; + ggml_half * d_dst = dst->d; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + vuint16m1_t vd = + __riscv_vlse16_v_u16m1((const uint16_t *) (d_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd, vl); + offset += vl; + remaining -= vl; + } + } + + // --- 2) Copy qs for each of the 32 rows (32 bytes per row) --- + { + for (int i = 0; i < 32; i++) { + const int8_t * sq = col_src[i * nblocks].qs; + int8_t * dq = (int8_t *) dst->qs + i * QK8_0; + + size_t len = QK8_0; + size_t idx = 0; + while (len > 0) { + size_t vl = __riscv_vsetvl_e8m2(len); + vint8m2_t vs = __riscv_vle8_v_i8m2(sq + idx, vl); + __riscv_vse8_v_i8m2(dq + idx, vs, vl); + idx += vl; + len -= vl; + } + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static void convert_mxfp4_to_5bit(const block_mxfp4 & src, spacemit_kernels::nrow_block_mxfp4<1> & dst) { + dst.e[0] = src.e; + + // Decode all 32 mxfp4 values to signed integers via kvalues_mxfp4 + int8_t vals[32]; + for (int j = 0; j < QK_MXFP4 / 2; j++) { + vals[j] = kvalues_mxfp4[src.qs[j] & 0xF]; + vals[j + QK_MXFP4 / 2] = kvalues_mxfp4[src.qs[j] >> 4]; + } + + // vals [b0, b1, b2, b3, ..., b30, b31] + // Pack abs into qs with reorder: [b0,b1]..[b14,b15]..[b30,b31] + for (int j = 0; j < QK_MXFP4 / 2; j++) { + uint8_t lo0 = static_cast(std::abs(vals[j * 2])); + uint8_t lo1 = static_cast(std::abs(vals[j * 2 + 1])); + dst.qs[j] = (lo0 & 0x0F) | ((lo1 & 0x0F) << 4); + } + + // Pack sign bits into qh[4] (32 bits total, 1 bit per weight) + // reorder: [0,1,2,...,15,16,17,...,31] after the qs reorder above + uint32_t sign_bits = 0; + for (int j = 0; j < 32; j++) { + if (vals[j] < 0) { + sign_bits |= (1u << j); + } + } + memcpy(dst.qh, &sign_bits, 4); +} + +static spacemit_kernels::nrow_block_mxfp4<32> make_block_mxfp4x32(spacemit_kernels::nrow_block_mxfp4<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_mxfp4<32> out; + GGML_ASSERT(QK_MXFP4 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.e[i] = in[i].e[0]; + } + + // qs: copy per-row 16 bytes + for (int i = 0; i < 32; i++) { + memcpy(out.qs + i * 16, in[i].qs, 16); + } + + // qh: copy per-row 4 bytes + for (int i = 0; i < 32; i++) { + memcpy(out.qh + i * 4, in[i].qh, 4); + } + + return out; +} + +static int repack_mxfp4_to_mxfp4_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_mxfp4<32> * dst = (spacemit_kernels::nrow_block_mxfp4<32> *) t->data; + const block_mxfp4 * src = (const block_mxfp4 *) data; + spacemit_kernels::nrow_block_mxfp4<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_MXFP4 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + convert_mxfp4_to_5bit(src[x + i * nblocks], dst_tmp[i]); + } + *dst++ = make_block_mxfp4x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static spacemit_kernels::nrow_block_q5_1<32> make_block_q5_1x32(spacemit_kernels::nrow_block_q5_1<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_q5_1<32> out; + GGML_ASSERT(QK5_1 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.scales16[i] = in[i].scales16[0]; + out.zp[i] = in[i].zp[0]; + } + + // qs: low 4 bits, reorder from [b0,b16],[b1,b17]... to [b0,b1]...[b14,b15] and [b16,b17]...[b30,b31] + for (int i = 0; i < 32; i++) { + // low half [0..15] + for (int j = 0; j < QK5_1 / 4; j++) { + out.qs[i * QK5_1 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + // high half [16..31] + for (int j = 0; j < QK5_1 / 4; j++) { + out.qs[i * QK5_1 / 2 + QK5_1 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + // qh: 5th bit, copy directly + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 4; j++) { + out.qh[i * 4 + j] = in[i].qh[j]; + } + } + + return out; +} + +static spacemit_kernels::nrow_block_q5_0<32> make_block_q5_0x32(spacemit_kernels::nrow_block_q5_0<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_q5_0<32> out; + GGML_ASSERT(QK5_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.scales16[i] = in[i].scales16[0]; + } + + // qs: low 4 bits, reorder from [b0,b16],[b1,b17]... to [b0,b1]...[b14,b15] and [b16,b17]...[b30,b31] + for (int i = 0; i < 32; i++) { + // low half [0..15] + for (int j = 0; j < QK5_0 / 4; j++) { + out.qs[i * QK5_0 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + // high half [16..31] + for (int j = 0; j < QK5_0 / 4; j++) { + out.qs[i * QK5_0 / 2 + QK5_0 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + // qh: 5th bit, copy directly + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 4; j++) { + out.qh[i * 4 + j] = in[i].qh[j]; + } + } + + return out; +} + +static int repack_q5_0_to_q5_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_0<32> * dst = (spacemit_kernels::nrow_block_q5_0<32> *) t->data; + const block_q5_0 * src = (const block_q5_0 *) data; + spacemit_kernels::nrow_block_q5_0<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK5_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK5_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q5_0 & s = src[x + i * nblocks]; + + dst_tmp[i].scales16[0] = s.d; + memcpy(dst_tmp[i].qs, s.qs, sizeof(dst_tmp[i].qs)); + memcpy(dst_tmp[i].qh, s.qh, sizeof(dst_tmp[i].qh)); + } + *dst++ = make_block_q5_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q5_1_to_q5_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_1<32> * dst = (spacemit_kernels::nrow_block_q5_1<32> *) t->data; + const block_q5_1 * src = (const block_q5_1 *) data; + spacemit_kernels::nrow_block_q5_1<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK5_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK5_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q5_1 & s = src[x + i * nblocks]; + + float d = GGML_FP16_TO_FP32(s.GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(s.GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + + if (d == 0.0f) { + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(std::fabs(m)); + dst_tmp[i].zp[0] = m < 0.0f ? 1 : 0; + memset(dst_tmp[i].qh, 0, sizeof(dst_tmp[i].qh)); + memset(dst_tmp[i].qs, m > 0.0f ? 0x11 : 0x00, sizeof(dst_tmp[i].qs)); + continue; + } + + float mid = std::nearbyintf(-m / d); + mid = std::min(31.0f, std::max(0.0f, mid)); + + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(d); + dst_tmp[i].zp[0] = static_cast(mid); + + // qs: copy low 4 bits directly (same nibble packing) + memcpy(dst_tmp[i].qs, s.qs, QK5_1 / 2); + + // qh: copy 5th bit directly + memcpy(dst_tmp[i].qh, s.qh, 4); + } + *dst++ = make_block_q5_1x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q5_k_to_q5_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK5_1 == 8); + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_1<32> * dst = (spacemit_kernels::nrow_block_q5_1<32> *) t->data; + const block_q5_K * src = (const block_q5_K *) data; + spacemit_kernels::nrow_block_q5_1<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + + float d1 = d * sc; + float m1 = min * m; + + float mid = std::nearbyintf(m1 / d1); + mid = std::min(31.0f, std::max(0.0f, mid)); + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(d1); + dst_tmp[i].zp[0] = static_cast(mid); + + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK5_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + + // Extract the 5th bit (qh) for this sub-block + // block_q5_K.qh[32]: for sub-block j, the 5th bit is at bit position j in qh[l] + // qs was reordered: dst_qs maps to src weights [0,16,1,17,...,15,31] + // So qh must follow the same reorder to stay aligned with qs + // dst qh[4] = 32 bits for 32 weights in the reordered layout: + // byte 0: weights 0..7 (from src_qh[0..7]) + // byte 1: weights 8..15 (from src_qh[8..15]) + // byte 2: weights 16..23 (from src_qh[16..23]) + // byte 3: weights 24..31 (from src_qh[24..31]) + const uint8_t * src_qh = src[x + i * nblocks].qh; + for (int bi = 0; bi < 4; bi++) { + uint8_t qh_byte = 0; + for (int k = 0; k < 8; k++) { + int src_idx = bi * 8 + k; + qh_byte |= ((src_qh[src_idx] >> j) & 1) << k; + } + dst_tmp[i].qh[bi] = qh_byte; + } + } + *dst++ = make_block_q5_1x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +namespace ggml::cpu::riscv64_spacemit { + +template int repack(ggml_tensor *, const void *, size_t); + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_k_to_q2_k_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_k_to_q3_k_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 0 + return repack_q4_0_to_q4_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_0_to_q4_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q4_0_to_q4_0_256_32_bl_ref(t, 32, data, data_size); +#else + //return repack_q4_0_to_q4_0_256_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 0 + return repack_q4_1_to_q4_1_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_1_to_q4_1_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q4_0_to_q4_1_256_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_1_to_q4_1_256_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q6_k_to_q8_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q6_k_to_q8_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q8_0_to_q8_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q8_0_to_q8_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_0_to_q5_0_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_1_to_q5_1_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_k_to_q5_1_32_bl(t, 32, data, data_size); +} + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/repack.h b/ggml/src/ggml-cpu/spacemit/repack.h new file mode 100644 index 000000000..950cbde75 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/repack.h @@ -0,0 +1,14 @@ +#pragma once + +#include "ggml-common.h" +#include "ggml.h" + +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +template +int repack(ggml_tensor * t, const void * data, size_t data_size); + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp b/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp new file mode 100644 index 000000000..d2f897436 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp @@ -0,0 +1,3178 @@ +#include "rvv_kernels.h" + +#include "common.h" +#include "ggml.h" +#include "ops.h" +#include "string.h" + +#include +#include +#include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Woverlength-strings" +# pragma GCC diagnostic ignored "-Wcast-qual" +# pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +namespace spacemit_kernels::rvv { + +namespace { + +auto align_up(size_t value, size_t alignment) { + return (value + alignment - 1) / alignment * alignment; +} + +static inline bool flash_attn_ext_supported_d_vlen1024_vf16(int64_t d) { + return d > 0 && d <= 128; +} + +static inline bool flash_attn_ext_supported_shape_vlen1024_vf16(int64_t DK, int64_t DV) { + return flash_attn_ext_supported_d_vlen1024_vf16(DK) && flash_attn_ext_supported_d_vlen1024_vf16(DV); +} + +static inline float reduce_sum_f32m4_vlen1024(vfloat32m4_t v, size_t vl) { + vfloat32m1_t s_v = __riscv_vfmv_v_f_f32m1(0.0f, 1); + s_v = __riscv_vfredusum_vs_f32m4_f32m1(v, s_v, vl); + return __riscv_vfmv_f_s_f32m1_f32(s_v); +} + +static inline float reduce_sum_f32m2_vlen1024(vfloat32m2_t v, size_t vl) { + vfloat32m1_t s_v = __riscv_vfmv_v_f_f32m1(0.0f, 1); + s_v = __riscv_vfredusum_vs_f32m2_f32m1(v, s_v, vl); + return __riscv_vfmv_f_s_f32m1_f32(s_v); +} + +// Adapted from ggml_v_expf_m2 in vec.h. This is accurate enough for softmax. +static inline vfloat32m2_t rvv_expf_approx_f32m2(vfloat32m2_t x, size_t vl) { + const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl); + const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl); + const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl); + const vfloat32m2_t b = + __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl), 0x1.7f7d1cp-20f, n, vl); + const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl); + const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); + const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl); + const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl); + const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2( + __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl), + __riscv_vfmacc_vv_f32m2( + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl), + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl), u, vl), + u, vl); + + if (!__riscv_vcpop_m_b16(c, vl)) { + return __riscv_vfmacc_vv_f32m2(k, j, k, vl); + } + + const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl); + const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl); + const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl)); + const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl)); + const vfloat32m2_t r1 = + __riscv_vmerge_vvm_f32m2(__riscv_vfmacc_vv_f32m2(k, k, j, vl), + __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl), c, vl); + return __riscv_vmerge_vvm_f32m2(r1, __riscv_vfmul_vv_f32m2(s1, s1, vl), + __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl), vl); +} + +static inline vfloat32m2_t rvv_tanh_approx_f32m2(vfloat32m2_t x, size_t vl) { + const vfloat32m2_t abs_x = __riscv_vfabs_v_f32m2(x, vl); + const vfloat32m2_t neg_2_abs = __riscv_vfmul_vf_f32m2(abs_x, -2.0f, vl); + const vfloat32m2_t exp_term = rvv_expf_approx_f32m2(neg_2_abs, vl); + const vfloat32m2_t numerator = __riscv_vfsub_vf_f32m2(exp_term, 1.0f, vl); + const vfloat32m2_t denominator = __riscv_vfadd_vf_f32m2(exp_term, 1.0f, vl); + const vfloat32m2_t tanh_abs = __riscv_vfneg_v_f32m2(__riscv_vfdiv_vv_f32m2(numerator, denominator, vl), vl); + const vbool16_t neg_mask = __riscv_vmflt_vf_f32m2_b16(x, 0.0f, vl); + const vfloat32m2_t tanh_neg = __riscv_vfneg_v_f32m2(tanh_abs, vl); + return __riscv_vmerge_vvm_f32m2(tanh_abs, tanh_neg, neg_mask, vl); +} + +static void rvv_softcap_tanh_inplace_f32(float * dst, int64_t dst_stride, int64_t tile_rows, int64_t n, float softcap) { + for (int tq = 0; tq < tile_rows; ++tq, dst += dst_stride) { + float * dst_row = dst; + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m2(remaining); + vfloat32m2_t v = __riscv_vle32_v_f32m2(dst_row, vl); + v = rvv_tanh_approx_f32m2(v, vl); + v = __riscv_vfmul_vf_f32m2(v, softcap, vl); + __riscv_vse32_v_f32m2(dst_row, v, vl); + dst_row += vl; + remaining -= vl; + } + } +} + +static inline float rvv_softmax_exp_inplace_f32(float * dst, int64_t n, float max_value) { + float row_sum = 0.0f; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m2(n); + vfloat32m2_t v = __riscv_vle32_v_f32m2(dst, vl); + v = __riscv_vfsub_vf_f32m2(v, max_value, vl); + v = rvv_expf_approx_f32m2(v, vl); + __riscv_vse32_v_f32m2(dst, v, vl); + row_sum += reduce_sum_f32m2_vlen1024(v, vl); + dst += vl; + n -= vl; + } + return row_sum; +} + +static inline float rvv_add_max_inplace_f32(float * dst, const float * src, int64_t n) { + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t vdst = __riscv_vle32_v_f32m4(dst, vl); + vfloat32m4_t vsrc = __riscv_vle32_v_f32m4(src, vl); + vdst = __riscv_vfadd_vv_f32m4(vdst, vsrc, vl); + __riscv_vse32_v_f32m4(dst, vdst, vl); + + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m4_f32m1(vdst, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + + dst += vl; + src += vl; + n -= vl; + } + return max_val; +} + +static inline float rvv_softcap_add_max_inplace_f32(float * dst, const float * src, int64_t n, float softcap) { + if (softcap == 0.0f) { + return rvv_add_max_inplace_f32(dst, src, n); + } + + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m2(n); + vfloat32m2_t vdst = __riscv_vle32_v_f32m2(dst, vl); + vfloat32m2_t vsrc = __riscv_vle32_v_f32m2(src, vl); + vdst = rvv_tanh_approx_f32m2(vdst, vl); + vdst = __riscv_vfmul_vf_f32m2(vdst, softcap, vl); + vdst = __riscv_vfadd_vv_f32m2(vdst, vsrc, vl); + __riscv_vse32_v_f32m2(dst, vdst, vl); + + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m2_f32m1(vdst, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + + dst += vl; + src += vl; + n -= vl; + } + return max_val; +} + +static inline void rvv_zero_f32(float * dst, int64_t n) { + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + const vfloat32m4_t z = __riscv_vfmv_v_f_f32m4(0.0f, vl); + __riscv_vse32_v_f32m4(dst, z, vl); + dst += vl; + n -= vl; + } +} + +static inline void rvv_scale_f32(float * dst, float scale, int64_t n) { + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t v = __riscv_vle32_v_f32m4(dst, vl); + v = __riscv_vfmul_vf_f32m4(v, scale, vl); + __riscv_vse32_v_f32m4(dst, v, vl); + dst += vl; + n -= vl; + } +} + +static inline void rvv_add_inplace_f32(float * dst, + int64_t dst_stride, + const float * src, + int64_t src_stride, + int64_t tile_rows, + int64_t n) { + for (int tq = 0; tq < tile_rows; ++tq, dst += dst_stride, src += src_stride) { + int64_t remaining = n; + float * dst_row = dst; + const float * src_row = src; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t vdst = __riscv_vle32_v_f32m4(dst_row, vl); + vfloat32m4_t vsrc = __riscv_vle32_v_f32m4(src_row, vl); + vdst = __riscv_vfadd_vv_f32m4(vdst, vsrc, vl); + __riscv_vse32_v_f32m4(dst_row, vdst, vl); + dst_row += vl; + src_row += vl; + remaining -= vl; + } + } +} + +static inline float rvv_max_f32(const float * src, int64_t n) { + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + const vfloat32m4_t v = __riscv_vle32_v_f32m4(src, vl); + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m4_f32m1(v, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + src += vl; + n -= vl; + } + return max_val; +} + +static void rvv_pack_f32_as_scaled_f16(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const float * row_ptr = (const float *) ((const char *) src + tq * src_row_stride); + _Float16 * dst_row_ptr = (_Float16 *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t v32 = __riscv_vle32_v_f32m4(row_ptr, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale, vl); + const vfloat16m2_t v16 = __riscv_vfncvt_f_f_w_f16m2(v32, vl); + __riscv_vse16_v_f16m2(dst_row_ptr, v16, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static void rvv_pack_scaled_f16_as_f32(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const _Float16 * row_ptr = (const _Float16 *) ((const char *) src + tq * src_row_stride); + float * dst_row_ptr = (float *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e16m2(remaining); + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(row_ptr, vl); + vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale, vl); + __riscv_vse32_v_f32m4(dst_row_ptr, v32, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static void rvv_pack_scaled_f32_as_f32(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float * scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const float * row_ptr = (const float *) ((const char *) src + tq * src_row_stride); + float * dst_row_ptr = (float *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t v32 = __riscv_vle32_v_f32m4(row_ptr, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale[tq], vl); + __riscv_vse32_v_f32m4(dst_row_ptr, v32, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static inline void rvv_transposed_s32_mn_to_nm(int8_t * dst, + int64_t n_dst_stride, + int8_t * src, + int64_t m_src_stride, + int64_t m, + int64_t n) { + int8_t * in = src; + int8_t * out = dst; + + __asm__ volatile( + "vsetvli t0, zero, e32, m1, tu, mu \n\t" + "mul t3, t0, %[os0] \n\t" + "srli t2, %[isz0], 3 \n\t" + "blez t2, M1%= \n\t" + + "LOOP_M8%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "add s2, %[src], %[is0] \n\t" + "add s3, s2, %[is0] \n\t" + "add s4, s3, %[is0] \n\t" + "add s5, s4, %[is0] \n\t" + "add s6, s5, %[is0] \n\t" + "add s7, s6, %[is0] \n\t" + "add s8, s7, %[is0] \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M8N%=: \n\t" + "vsetvli t0, t1, e32, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + "vle32.v v1, (s2) \n\t" + "sh2add s2, t0, s2 \n\t" + "vle32.v v2, (s3) \n\t" + "sh2add s3, t0, s3 \n\t" + "vle32.v v3, (s4) \n\t" + "sh2add s4, t0, s4 \n\t" + "vle32.v v4, (s5) \n\t" + "sh2add s5, t0, s5 \n\t" + "vle32.v v5, (s6) \n\t" + "sh2add s6, t0, s6 \n\t" + "vle32.v v6, (s7) \n\t" + "sh2add s7, t0, s7 \n\t" + "vle32.v v7, (s8) \n\t" + "sh2add s8, t0, s8 \n\t" + "vssseg8e32.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M8N%= \n\t" + "sh3add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 32 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M8%= \n\t" + + "M1%=: \n\t" + "andi t2, %[isz0], 7 \n\t" + "blez t2, END%= \n\t" + + "LOOP_M1%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M1N%=: \n\t" + "vsetvli t0, t1, e32, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + "vsse32.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M1N%= \n\t" + "add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 4 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M1%= \n\t" + "END%=: \n\t" + + : [src] "+r"(in), [dst] "+r"(out), [isz0] "+r"(m) + : [isz1] "r"(n), [is0] "r"(m_src_stride), [os0] "r"(n_dst_stride) + : "cc", "t0", "t1", "t2", "t3", "s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "a1"); +} + +static inline void rvv_transposed_s16_mn_to_nm(int8_t * dst, + int64_t n_dst_stride, + int8_t * src, + int64_t m_src_stride, + int64_t m, + int64_t n) { + int8_t * in = src; + int8_t * out = dst; + + __asm__ volatile( + "vsetvli t0, zero, e16, m1, tu, mu \n\t" + "mul t3, t0, %[os0] \n\t" + "srli t2, %[isz0], 3 \n\t" + "blez t2, M1%= \n\t" + + "LOOP_M8%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "add s2, %[src], %[is0] \n\t" + "add s3, s2, %[is0] \n\t" + "add s4, s3, %[is0] \n\t" + "add s5, s4, %[is0] \n\t" + "add s6, s5, %[is0] \n\t" + "add s7, s6, %[is0] \n\t" + "add s8, s7, %[is0] \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M8N%=: \n\t" + "vsetvli t0, t1, e16, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle16.v v0, (s1) \n\t" + "sh1add s1, t0, s1 \n\t" + "vle16.v v1, (s2) \n\t" + "sh1add s2, t0, s2 \n\t" + "vle16.v v2, (s3) \n\t" + "sh1add s3, t0, s3 \n\t" + "vle16.v v3, (s4) \n\t" + "sh1add s4, t0, s4 \n\t" + "vle16.v v4, (s5) \n\t" + "sh1add s5, t0, s5 \n\t" + "vle16.v v5, (s6) \n\t" + "sh1add s6, t0, s6 \n\t" + "vle16.v v6, (s7) \n\t" + "sh1add s7, t0, s7 \n\t" + "vle16.v v7, (s8) \n\t" + "sh1add s8, t0, s8 \n\t" + "vssseg8e16.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M8N%= \n\t" + "sh3add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 16 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M8%= \n\t" + + "M1%=: \n\t" + "andi t2, %[isz0], 7 \n\t" + "blez t2, END%= \n\t" + + "LOOP_M1%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M1N%=: \n\t" + "vsetvli t0, t1, e16, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle16.v v0, (s1) \n\t" + "sh1add s1, t0, s1 \n\t" + "vsse16.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M1N%= \n\t" + "add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 2 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M1%= \n\t" + "END%=: \n\t" + + : [src] "+r"(in), [dst] "+r"(out), [isz0] "+r"(m) + : [isz1] "r"(n), [is0] "r"(m_src_stride), [os0] "r"(n_dst_stride) + : "cc", "t0", "t1", "t2", "t3", "s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "a1"); +} + +static inline void rvv_qk_dot_tile_f16_x1(float * dst, + const _Float16 * q_row, + const _Float16 * k_pack, + int64_t dk, + int64_t kv_tile) { + const size_t vl = __riscv_vsetvl_e16m1(kv_tile); + vfloat32m2_t acc = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_pack + d * ggml_fa_tile_config::KV, vl); + acc = __riscv_vfwmacc_vf_f32m2(acc, q_row[d], k_vec, vl); + } + + __riscv_vse32_v_f32m2(dst, acc, vl); +} + +static inline void rvv_qk_dot_tile_f16_x4(float * dst0, + float * dst1, + float * dst2, + float * dst3, + const _Float16 * q0, + const _Float16 * q1, + const _Float16 * q2, + const _Float16 * q3, + const _Float16 * k_pack, + int64_t dk, + int64_t kv_tile) { + const size_t vl = __riscv_vsetvl_e16m1(kv_tile); + vfloat32m2_t acc0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc1 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc2 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc3 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_pack + d * ggml_fa_tile_config::KV, vl); + acc0 = __riscv_vfwmacc_vf_f32m2(acc0, q0[d], k_vec, vl); + acc1 = __riscv_vfwmacc_vf_f32m2(acc1, q1[d], k_vec, vl); + acc2 = __riscv_vfwmacc_vf_f32m2(acc2, q2[d], k_vec, vl); + acc3 = __riscv_vfwmacc_vf_f32m2(acc3, q3[d], k_vec, vl); + } + + __riscv_vse32_v_f32m2(dst0, acc0, vl); + __riscv_vse32_v_f32m2(dst1, acc1, vl); + __riscv_vse32_v_f32m2(dst2, acc2, vl); + __riscv_vse32_v_f32m2(dst3, acc3, vl); +} + +static inline void rvv_pv_accumulate_f16_x1(float * dst, + const float * prob, + const _Float16 * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e16m2(d_left); + vfloat32m4_t acc = __riscv_vle32_v_f32m4(dst + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_pack + tk * dv + d_off, vl); + const vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, prob[tk], v32, vl); + } + + __riscv_vse32_v_f32m4(dst + d_off, acc, vl); + d_left -= vl; + d_off += vl; + } +} + +static inline void rvv_pv_accumulate_f16_x4(float * dst0, + float * dst1, + float * dst2, + float * dst3, + const float * prob0, + const float * prob1, + const float * prob2, + const float * prob3, + const _Float16 * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e16m2(d_left); + vfloat32m4_t acc0 = __riscv_vle32_v_f32m4(dst0 + d_off, vl); + vfloat32m4_t acc1 = __riscv_vle32_v_f32m4(dst1 + d_off, vl); + vfloat32m4_t acc2 = __riscv_vle32_v_f32m4(dst2 + d_off, vl); + vfloat32m4_t acc3 = __riscv_vle32_v_f32m4(dst3 + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_pack + tk * dv + d_off, vl); + const vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + acc0 = __riscv_vfmacc_vf_f32m4(acc0, prob0[tk], v32, vl); + acc1 = __riscv_vfmacc_vf_f32m4(acc1, prob1[tk], v32, vl); + acc2 = __riscv_vfmacc_vf_f32m4(acc2, prob2[tk], v32, vl); + acc3 = __riscv_vfmacc_vf_f32m4(acc3, prob3[tk], v32, vl); + } + + __riscv_vse32_v_f32m4(dst0 + d_off, acc0, vl); + __riscv_vse32_v_f32m4(dst1 + d_off, acc1, vl); + __riscv_vse32_v_f32m4(dst2 + d_off, acc2, vl); + __riscv_vse32_v_f32m4(dst3 + d_off, acc3, vl); + d_left -= vl; + d_off += vl; + } +} + +static inline void rvv_qk_dot_tile(float * dst, + const float * q_row, + const float * k_pack, + int64_t dk, + int64_t kv_tile, + float scale) { + const size_t vl = __riscv_vsetvl_e32m4(kv_tile); + vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat32m4_t k_vec = __riscv_vle32_v_f32m4(k_pack + d * kv_tile, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, q_row[d] * scale, k_vec, vl); + } + + __riscv_vse32_v_f32m4(dst, acc, vl); +} + +static inline void rvv_pv_accumulate(float * dst, + const float * prob, + const float * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e32m4(d_left); + vfloat32m4_t acc = __riscv_vle32_v_f32m4(dst + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat32m4_t v_vec = __riscv_vle32_v_f32m4(v_pack + tk * dv + d_off, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, prob[tk], v_vec, vl); + } + + __riscv_vse32_v_f32m4(dst + d_off, acc, vl); + d_left -= vl; + d_off += vl; + } +} + +static void permute_transpose_impl(const ggml_tensor * src0, + ggml_tensor * dst, + int64_t batch, + int64_t m, + int64_t n, + int64_t batch_stride, + int64_t m_src_stride, + int64_t n_src_stride, + int64_t n_dst_stride, + int ith, + int nth) { + GGML_ASSERT(n_src_stride == sizeof(int32_t) || n_src_stride == sizeof(int16_t)); + + if (n_src_stride == sizeof(int32_t)) { + for (int64_t bi = ith; bi < batch; bi += nth) { + rvv_transposed_s32_mn_to_nm((int8_t *) ((char *) dst->data + bi * batch_stride), n_dst_stride, + (int8_t *) ((char *) src0->data + bi * batch_stride), m_src_stride, m, n); + } + } else if (n_src_stride == sizeof(int16_t)) { + for (int64_t bi = ith; bi < batch; bi += nth) { + rvv_transposed_s32_mn_to_nm((int8_t *) ((char *) dst->data + bi * batch_stride), n_dst_stride, + (int8_t *) ((char *) src0->data + bi * batch_stride), m_src_stride, m, n); + } + } else { + GGML_ABORT("not implemented"); + } +} + +template +static void flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow(float ** pq, + const char * k_data_row, + const char * v_data_row, + const ggml_fp16_t * mp, + float ** sinks, + float ** dst, + float scale, + float logit_softcap, + float slope, + int64_t nek1, + int64_t nbk1, + int64_t nbv1, + int64_t DV, + int64_t DK, + void * tcm_buffer, + size_t tcm_buffer_size) { + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + float S[QLEN] = { 0.0f }; // sum + float M[QLEN] = { -INFINITY }; // maximum KQ value + + _Float16 * kq16_buffer = (_Float16 *) tcm_buffer; + _Float16 * qv_buffer = kq16_buffer + QLEN * DV; + const size_t qkv_temp_buffer_size = (QLEN * DV + QLEN * DK) * sizeof(_Float16); + char * kv_tile_buffer = (char *) (qv_buffer + QLEN * DK); + + { + vfloat16m2_t VKQ16_v = __riscv_vfmv_v_f_f16m2(0.0f, DV); + for (int64_t i = 0; i < QLEN; ++i) { + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + vfloat16m2_t Q_q_v = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pq[i], DK), DK); + __riscv_vse16_v_f16m2(qv_buffer + i * DK, Q_q_v, DK); + } + } + + const uintptr_t scratch_addr = reinterpret_cast(kv_tile_buffer); + const size_t scratch_size = tcm_buffer_size > qkv_temp_buffer_size ? tcm_buffer_size - qkv_temp_buffer_size : 0; + const uintptr_t kq_tile_addr = align_up(scratch_addr, alignof(float)); + const size_t scratch_prefix = kq_tile_addr - scratch_addr; + const size_t packed_tile_size = + QLEN * sizeof(float) + DK * sizeof(_Float16) + DV * sizeof(_Float16) + sizeof(float); + const int64_t max_ic_tile_step = ((int64_t) __riscv_vsetvlmax_e16m1()) & ~((int64_t) 7); + const int64_t max_fit_by_tcm = + scratch_size > scratch_prefix ? (int64_t) ((scratch_size - scratch_prefix) / packed_tile_size) : 0; + const int64_t ic_tile_step = std::min(max_ic_tile_step, max_fit_by_tcm) & ~((int64_t) 7); + + const uintptr_t k_tile_addr = kq_tile_addr + QLEN * ic_tile_step * sizeof(float); + const uintptr_t v_tile_addr = k_tile_addr + DK * ic_tile_step * sizeof(_Float16); + const uintptr_t mv_tile_addr = v_tile_addr + ic_tile_step * DV * sizeof(_Float16); + + if (ic_tile_step >= 8) { + float * kq_tile_buffer = reinterpret_cast(kq_tile_addr); + _Float16 * k_tile_pack = reinterpret_cast<_Float16 *>(k_tile_addr); + _Float16 * v_tile_pack = reinterpret_cast<_Float16 *>(v_tile_addr); + float * mv_tile_pack = reinterpret_cast(mv_tile_addr); + + const int64_t k_tile_byte_stride = ic_tile_step * (int64_t) sizeof(_Float16); + + int64_t ic_step = 0; + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + + if (mv != -INFINITY) { + const _Float16 * k_data = (const _Float16 *) (k_data_row + ic * nbk1); + const _Float16 * v_data = (const _Float16 *) (v_data_row + ic * nbv1); + + const vfloat16m2_t k_data_v = __riscv_vle16_v_f16m2(k_data, DK); + const vfloat16m2_t v_data_v = __riscv_vle16_v_f16m2(v_data, DV); + __riscv_vsse16_v_f16m2(k_tile_pack + ic_step, k_tile_byte_stride, k_data_v, DK); + __riscv_vse16_v_f16m2(v_tile_pack + ic_step * DV, v_data_v, DV); + mv_tile_pack[ic_step] = mv; + ic_step++; + } + + if (ic_step > 0 && (ic_step == ic_tile_step || ic == (nek1 - 1))) { + if constexpr (QLEN == 4) { + const size_t qk_vl = __riscv_vsetvl_e16m1(ic_step); + vfloat32m2_t qk_acc0 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc1 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc2 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc3 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + + for (int64_t d = 0; d < DK; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_tile_pack + d * ic_tile_step, qk_vl); + qk_acc0 = __riscv_vfwmacc_vf_f32m2(qk_acc0, qv_buffer[0 * DK + d], k_vec, qk_vl); + qk_acc1 = __riscv_vfwmacc_vf_f32m2(qk_acc1, qv_buffer[1 * DK + d], k_vec, qk_vl); + qk_acc2 = __riscv_vfwmacc_vf_f32m2(qk_acc2, qv_buffer[2 * DK + d], k_vec, qk_vl); + qk_acc3 = __riscv_vfwmacc_vf_f32m2(qk_acc3, qv_buffer[3 * DK + d], k_vec, qk_vl); + } + + qk_acc0 = __riscv_vfmul_vf_f32m2(qk_acc0, scale, qk_vl); + qk_acc1 = __riscv_vfmul_vf_f32m2(qk_acc1, scale, qk_vl); + qk_acc2 = __riscv_vfmul_vf_f32m2(qk_acc2, scale, qk_vl); + qk_acc3 = __riscv_vfmul_vf_f32m2(qk_acc3, scale, qk_vl); + + __riscv_vse32_v_f32m2(kq_tile_buffer + 0 * ic_tile_step, qk_acc0, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 1 * ic_tile_step, qk_acc1, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 2 * ic_tile_step, qk_acc2, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 3 * ic_tile_step, qk_acc3, qk_vl); + } else { + static_assert(QLEN == 2, "unsupported QLEN"); + + const size_t qk_vl = __riscv_vsetvl_e16m1(ic_step); + vfloat32m2_t qk_acc0 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc1 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + + for (int64_t d = 0; d < DK; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_tile_pack + d * ic_tile_step, qk_vl); + qk_acc0 = __riscv_vfwmacc_vf_f32m2(qk_acc0, qv_buffer[0 * DK + d], k_vec, qk_vl); + qk_acc1 = __riscv_vfwmacc_vf_f32m2(qk_acc1, qv_buffer[1 * DK + d], k_vec, qk_vl); + } + + qk_acc0 = __riscv_vfmul_vf_f32m2(qk_acc0, scale, qk_vl); + qk_acc1 = __riscv_vfmul_vf_f32m2(qk_acc1, scale, qk_vl); + + __riscv_vse32_v_f32m2(kq_tile_buffer + 0 * ic_tile_step, qk_acc0, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 1 * ic_tile_step, qk_acc1, qk_vl); + } + + for (int i = 0; i < QLEN; ++i) { + float * row_ptr = kq_tile_buffer + i * ic_tile_step; + const float tile_max = + rvv_softcap_add_max_inplace_f32(row_ptr, mv_tile_pack, ic_step, logit_softcap); + + const float Mold = M[i]; + + if (tile_max > Mold) { + const float ms = expf(Mold - tile_max); + M[i] = tile_max; + S[i] *= ms; + + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, (_Float16) ms, DV); + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + } + + S[i] += rvv_softmax_exp_inplace_f32(row_ptr, ic_step, M[i]); + } + + if constexpr (QLEN == 4) { + vfloat16m2_t pv_acc0 = __riscv_vle16_v_f16m2(kq16_buffer + 0 * DV, DV); + vfloat16m2_t pv_acc1 = __riscv_vle16_v_f16m2(kq16_buffer + 1 * DV, DV); + vfloat16m2_t pv_acc2 = __riscv_vle16_v_f16m2(kq16_buffer + 2 * DV, DV); + vfloat16m2_t pv_acc3 = __riscv_vle16_v_f16m2(kq16_buffer + 3 * DV, DV); + + for (int64_t tk = 0; tk < ic_step; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_tile_pack + tk * DV, DV); + pv_acc0 = + __riscv_vfmacc_vf_f16m2(pv_acc0, (_Float16) kq_tile_buffer[0 * ic_tile_step + tk], v16, DV); + pv_acc1 = + __riscv_vfmacc_vf_f16m2(pv_acc1, (_Float16) kq_tile_buffer[1 * ic_tile_step + tk], v16, DV); + pv_acc2 = + __riscv_vfmacc_vf_f16m2(pv_acc2, (_Float16) kq_tile_buffer[2 * ic_tile_step + tk], v16, DV); + pv_acc3 = + __riscv_vfmacc_vf_f16m2(pv_acc3, (_Float16) kq_tile_buffer[3 * ic_tile_step + tk], v16, DV); + } + + __riscv_vse16_v_f16m2(kq16_buffer + 0 * DV, pv_acc0, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 1 * DV, pv_acc1, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 2 * DV, pv_acc2, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 3 * DV, pv_acc3, DV); + } else { + static_assert(QLEN == 2, "unsupported QLEN"); + vfloat16m2_t pv_acc0 = __riscv_vle16_v_f16m2(kq16_buffer + 0 * DV, DV); + vfloat16m2_t pv_acc1 = __riscv_vle16_v_f16m2(kq16_buffer + 1 * DV, DV); + + for (int64_t tk = 0; tk < ic_step; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_tile_pack + tk * DV, DV); + pv_acc0 = + __riscv_vfmacc_vf_f16m2(pv_acc0, (_Float16) kq_tile_buffer[0 * ic_tile_step + tk], v16, DV); + pv_acc1 = + __riscv_vfmacc_vf_f16m2(pv_acc1, (_Float16) kq_tile_buffer[1 * ic_tile_step + tk], v16, DV); + } + + __riscv_vse16_v_f16m2(kq16_buffer + 0 * DV, pv_acc0, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 1 * DV, pv_acc1, DV); + } + + ic_step = 0; + } + } + } else { + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + + const char * k_data = k_data_row + ic * nbk1; + const char * v_data = v_data_row + ic * nbv1; + + vfloat16m2_t k_data_v; + vfloat16m2_t v_data_v; + + if (mv != -INFINITY) { + k_data_v = __riscv_vle16_v_f16m2((_Float16 *) k_data, DK); + v_data_v = __riscv_vle16_v_f16m2((_Float16 *) v_data, DV); + } else { + continue; + } + + for (int i = 0; i < QLEN; ++i) { + vfloat16m2_t Q_q_v = __riscv_vle16_v_f16m2(qv_buffer + i * DK, DK); + vfloat32m4_t qk_acc_v = __riscv_vfwmul_vv_f32m4(k_data_v, Q_q_v, DK); + float s = reduce_sum_f32m4_vlen1024(qk_acc_v, DK); + s = s * scale; + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); + } + s += mv; + + const float Mold = M[i]; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + if (s > M[i]) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M[i] = s; + ms = expf(Mold - M[i]); + + // V = V*expf(Mold - M) + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, ms, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M[i]); + } + VKQ16_v = __riscv_vfmacc_vf_f16m2(VKQ16_v, vs, v_data_v, DV); + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + S[i] = S[i] * ms + vs; // scale and increment sum with partial sum + } + } + } + + for (int i = 0; i < QLEN; ++i) { + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + vfloat32m4_t VKQ32_v = __riscv_vfwcvt_f_f_v_f32m4(VKQ16_v, DV); + + // sinks + if (sinks[i]) { + const float s = *(sinks[i]); + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[i]) { + ms = expf(M[i] - s); + M[i] = s; + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, ms, DV); + } else { + vs = expf(s - M[i]); + } + + S[i] = S[i] * ms + vs; + } + + // V /= S + const float S_inv = S[i] == 0.0f ? 0.0f : 1.0f / S[i]; + + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, S_inv, DV); + + __riscv_vse32_v_f32m4(dst[i], VKQ32_v, DV); + } +} + +static void flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_m1(const float * pq, + const char * k_data_row, + const char * v_data_row, + const ggml_fp16_t * mp, + const float * sinks, + float * dst, + float scale, + float logit_softcap, + float slope, + int64_t nek1, + int64_t nbk1, + int64_t nbv1, + int64_t DV, + int64_t DK) { + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + vfloat16m2_t VKQ16_v = __riscv_vfmv_v_f_f16m2(0.0f, DV); + + vfloat16m2_t Q_q_v = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pq, DK), DK); + + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + if (mv == -INFINITY) { + continue; + } + + const char * k_data = k_data_row + ic * nbk1; + + vfloat16m2_t k_data_v = __riscv_vle16_v_f16m2((_Float16 *) k_data, DK); + + vfloat32m4_t qk_acc_v = __riscv_vfwmul_vv_f32m4(k_data_v, Q_q_v, DK); + float s = reduce_sum_f32m4_vlen1024(qk_acc_v, DK); + + s = s * scale; // scale KQ value + + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); + } + + s += mv; // apply mask + + const float Mold = M; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + const char * v_data = v_data_row + ic * nbv1; + + vfloat16m2_t v_data_v = __riscv_vle16_v_f16m2((_Float16 *) v_data, DV); + + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, ms, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + VKQ16_v = __riscv_vfmacc_vf_f16m2(VKQ16_v, vs, v_data_v, DV); + + S = S * ms + vs; // scale and increment sum with partial sum + } + + vfloat32m4_t VKQ32_v = __riscv_vfwcvt_f_f_v_f32m4(VKQ16_v, DV); + + // sinks + if (sinks) { + const float s = *sinks; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + M = s; + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, ms, DV); + } else { + vs = expf(s - M); + } + + S = S * ms + vs; + } + + // V /= S + const float S_inv = S == 0.0f ? 0.0f : 1.0f / S; + + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, S_inv, DV); + + __riscv_vse32_v_f32m4(dst, VKQ32_v, DV); +} + +} // namespace + +void memcpy1d(void * dst, const void * src, int64_t size) { + size_t byte_size_all = size; + size_t vlen = __riscv_vlenb() * 8; + if (vlen == 256) { + // 1024 bytes + __asm__ volatile( + // + "srli t0, %[size], 10 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8, tu, mu \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v8, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v16, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v24, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v8, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v16, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v24, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 1023 \n\t" + "blez t1, out%= \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + "out%=: \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1"); + } else if (vlen == 1024) { + // 2048 bytes + __asm__ volatile( + // + "srli t0, %[size], 11 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8, tu, mu \n\t" + "addi t2, %[s], 1024 \n\t" + "addi t3, %[d], 1024 \n\t" + "li t5, 2048 \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t5 \n\t" + "vle8.v v8, (t2) \n\t" + "add t2, t2, t5 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t5 \n\t" + "vse8.v v8, (t3) \n\t" + "add t3, t3, t5 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 2047 \n\t" + "blez t1, out%= \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m2, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + "out%=: \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1", "t2", "t3", "t5"); + } else { + __asm__ volatile( + // + "add t1, %[size], zero \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1", "t2", "t4", "t3"); + } +} + +void memcpy2d(void * dst, int64_t dst_stride, const void * src, int64_t src_stride, int64_t tile_rows, int64_t size) { + for (int64_t i = 0; i < tile_rows; ++i) { + memcpy1d((char *) dst + i * dst_stride, (const char *) src + i * src_stride, size); + } +} + +void forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + // broadcast factors + const int64_t rk2 = neq2 / nek2; + const int64_t rk3 = neq3 / nek3; + + const int64_t rv2 = neq2 / nev2; + const int64_t rv3 = neq3 / nev3; + + // parallelize by q rows using ggml_vec_dot_f32 + + float scale = *((float *) dst->op_params + 0); + float max_bias = *((float *) dst->op_params + 1); + float logit_softcap = *((float *) dst->op_params + 2); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int KV_row_size = DK * sizeof(_Float16) + DV * sizeof(_Float16); + + int ith = params->ith; + int ir_step = 1; + for (int ir = ir0; ir < ir1; ir += ir_step) { + // q indices + const int iq3 = ir / (neq2 * neq1); + const int iq2 = (ir - iq3 * neq2 * neq1) / neq1; + const int iq1 = (ir - iq3 * neq2 * neq1 - iq2 * neq1); + + const int iq3_1 = (ir + 1) / (neq2 * neq1); + const int iq2_1 = (ir + 1 - iq3_1 * neq2 * neq1) / neq1; + const int iq1_1 = (ir + 1 - iq3_1 * neq2 * neq1 - iq2_1 * neq1); + + const int iq3_2 = (ir + 2) / (neq2 * neq1); + const int iq2_2 = (ir + 2 - iq3_2 * neq2 * neq1) / neq1; + const int iq1_2 = (ir + 2 - iq3_2 * neq2 * neq1 - iq2_2 * neq1); + + const int iq3_3 = (ir + 3) / (neq2 * neq1); + const int iq2_3 = (ir + 3 - iq3_3 * neq2 * neq1) / neq1; + const int iq1_3 = (ir + 3 - iq3_3 * neq2 * neq1 - iq2_3 * neq1); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + const ggml_fp16_t * mp = + mask ? (ggml_fp16_t *) ((char *) mask->data + iq1 * mask->nb[1] + (iq2 % mask->ne[2]) * mask->nb[2] + + (iq3 % mask->ne[3]) * mask->nb[3]) : + NULL; + + const bool mp_equal_2 = iq1_1 == iq1 && (iq2 % mask->ne[2]) == (iq2_1 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_1 % mask->ne[3]); + + const bool mp_equal_4 = mp_equal_2 && iq1_2 == iq1 && (iq2 % mask->ne[2]) == (iq2_2 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_2 % mask->ne[3]) && iq1_3 == iq1 && + (iq2 % mask->ne[2]) == (iq2_3 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_3 % mask->ne[3]); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + const int ik3_1 = iq3_1 / rk3; + const int ik2_1 = iq2_1 / rk2; + + const int ik3_2 = iq3_2 / rk3; + const int ik2_2 = iq2_2 / rk2; + + const int ik3_3 = iq3_3 / rk3; + const int ik2_3 = iq2_3 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const int iv3_1 = iq3_1 / rv3; + const int iv2_1 = iq2_1 / rv2; + + const int iv3_2 = iq3_2 / rv3; + const int iv2_2 = iq2_2 / rv2; + + const int iv3_3 = iq3_3 / rv3; + const int iv2_3 = iq2_3 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + + std::array pq_buffer; + std::array sinks_buffer; + std::array dst_buffer; + + if (tcm_buffer != nullptr && 4 * KV_row_size < tcm_buffer_size && ir < (ir1 - 3) && mp_equal_4 && + ik3_3 == ik3 && ik2_3 == ik2 && iv3_3 == iv3 && iv2_3 == iv2 && ik3_2 == ik3 && ik2_2 == ik2 && + iv3_2 == iv3 && iv2_2 == iv2 && ik3_1 == ik3 && ik2_1 == ik2 && iv3_1 == iv3 && iv2_1 == iv2) { + ir_step = 4; + + pq_buffer[0] = (float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + pq_buffer[1] = (float *) ((char *) q->data + (iq1_1 * nbq1 + iq2_1 * nbq2 + iq3_1 * nbq3)); + pq_buffer[2] = (float *) ((char *) q->data + (iq1_2 * nbq1 + iq2_2 * nbq2 + iq3_2 * nbq3)); + pq_buffer[3] = (float *) ((char *) q->data + (iq1_3 * nbq1 + iq2_3 * nbq2 + iq3_3 * nbq3)); + + sinks_buffer[0] = sinks ? ((float *) ((char *) sinks->data)) + iq2 : nullptr; + sinks_buffer[1] = sinks ? ((float *) ((char *) sinks->data)) + iq2_1 : nullptr; + sinks_buffer[2] = sinks ? ((float *) ((char *) sinks->data)) + iq2_2 : nullptr; + sinks_buffer[3] = sinks ? ((float *) ((char *) sinks->data)) + iq2_3 : nullptr; + + dst_buffer[0] = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1); + dst_buffer[1] = (float *) ((char *) dst->data + (iq3_1 * ne2 * ne1 + iq2_1 + iq1_1 * ne1) * nb1); + dst_buffer[2] = (float *) ((char *) dst->data + (iq3_2 * ne2 * ne1 + iq2_2 + iq1_2 * ne1) * nb1); + dst_buffer[3] = (float *) ((char *) dst->data + (iq3_3 * ne2 * ne1 + iq2_3 + iq1_3 * ne1) * nb1); + + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow<4>( // + pq_buffer.data(), // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks_buffer.data(), // + dst_buffer.data(), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK, tcm_buffer, tcm_buffer_size); + } else if (tcm_buffer != nullptr && 2 * KV_row_size < tcm_buffer_size && ir < (ir1 - 1) && mp_equal_2 && + ik3_1 == ik3 && ik2_1 == ik2 && iv3_1 == iv3 && iv2_1 == iv2) { + ir_step = 2; + + pq_buffer[0] = (float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + pq_buffer[1] = (float *) ((char *) q->data + (iq1_1 * nbq1 + iq2_1 * nbq2 + iq3_1 * nbq3)); + + sinks_buffer[0] = sinks ? ((float *) ((char *) sinks->data)) + iq2 : nullptr; + sinks_buffer[1] = sinks ? ((float *) ((char *) sinks->data)) + iq2_1 : nullptr; + + dst_buffer[0] = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1); + dst_buffer[1] = (float *) ((char *) dst->data + (iq3_1 * ne2 * ne1 + iq2_1 + iq1_1 * ne1) * nb1); + + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow<2>( // + pq_buffer.data(), // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks_buffer.data(), // + dst_buffer.data(), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK, tcm_buffer, tcm_buffer_size); + } else { + ir_step = 1; + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_m1( // + pq, // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks ? ((float *) ((char *) sinks->data)) + h : nullptr, // + (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK); + } + } +} + +void forward_flash_attn_ext_f16_tiled_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + GGML_ASSERT(ne0 == DV); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == DK); + GGML_ASSERT(nek0 == DK); + GGML_ASSERT(nev0 == DV); + + GGML_ASSERT(neq1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(k->type == v->type); + const ggml_type kv_type = k->type; + + // broadcast factors + const int64_t rk2 = neq2 / nek2; + const int64_t rk3 = neq3 / nek3; + + const int64_t rv2 = neq2 / nev2; + const int64_t rv3 = neq3 / nev3; + + float * param_list = (float *) dst->op_params; + float scale = param_list[0]; + float max_bias = param_list[1]; + float logit_softcap = param_list[2]; + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + int ith = params->ith; + + static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; + static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; + + // Per-thread scratch layout: + // Q_f32: Q_TILE_SZ * DK + // KQ: Q_TILE_SZ * KV_TILE_SZ + // mask32: Q_TILE_SZ * KV_TILE_SZ + // VKQ32: Q_TILE_SZ * DV + // V32: KV_TILE_SZ * DV + // K_f32: DK * KV_TILE_SZ (transposed K tile) + float * base = (float *) params->wdata + ith * (Q_TILE_SZ * DK + 2 * Q_TILE_SZ * KV_TILE_SZ + Q_TILE_SZ * DV + + KV_TILE_SZ * DV + KV_TILE_SZ * DK + CACHE_LINE_SIZE_F32); + const size_t base_size = + (Q_TILE_SZ * DK + 2 * Q_TILE_SZ * KV_TILE_SZ + Q_TILE_SZ * DV + KV_TILE_SZ * DV + KV_TILE_SZ * DK) * + sizeof(float) + + CACHE_LINE_SIZE_F32; + + if (base_size <= tcm_buffer_size && tcm_buffer != nullptr) { + base = (float *) tcm_buffer; + } + + float S_M_Buf[Q_TILE_SZ * 2]; // buffer to hold S, M, bias for one tile to reduce register pressure in main loop + float * S = S_M_Buf; + float * M = S_M_Buf + Q_TILE_SZ; + + int ir = ir0; + while (ir < ir1) { + // q indices for the start of this tile + const int iq3 = ir / (neq2 * neq1); + const int iq2 = (ir - iq3 * neq2 * neq1) / neq1; + const int iq1 = (ir - iq3 * neq2 * neq1 - iq2 * neq1); + + // Number of valid rows in this tile: + // - limited by tile size (Q_TILE_SZ) + // - limited by chunk boundary (ir1 - ir) + // - limited by head boundary (neq1 - iq1) to avoid crossing into next head + const int tile_rows = MIN(Q_TILE_SZ, MIN((int) (ir1 - ir), (int) (neq1 - iq1))); + GGML_ASSERT(tile_rows > 0); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + for (int i = 0; i < Q_TILE_SZ; ++i) { + S[i] = 0.; + M[i] = -INFINITY; + } + + float * Q_f32 = base; + float * KQ = (float *) ((char *) base + Q_TILE_SZ * DK * sizeof(float)); + float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; + float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; + float * V32 = VKQ32 + Q_TILE_SZ * DV; + float * K_f32 = V32 + KV_TILE_SZ * DV; + _Float16 * Q_f16 = (_Float16 *) Q_f32; + _Float16 * V_f16 = (_Float16 *) V32; + _Float16 * K_f16 = (_Float16 *) K_f32; + + rvv_zero_f32(VKQ32, Q_TILE_SZ * DV); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + if (kv_type == GGML_TYPE_F16) { + rvv_pack_f32_as_scaled_f16((uint8_t *) Q_f16, DK * sizeof(_Float16), (uint8_t *) pq, nbq1, tile_rows, DK, + scale); + } else { + memcpy2d(Q_f32, DK * sizeof(float), pq, nbq1, tile_rows, DK * sizeof(float)); + } + + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + const int kv_tile = (int) std::min((int64_t) KV_TILE_SZ, nek1 - ic); + + rvv_zero_f32(K_f32, DK * KV_TILE_SZ); + rvv_zero_f32(V32, KV_TILE_SZ * DV); + + // skip the tile entirely if all the masks are -inf + if (mask) { + bool can_skip = true; + const ggml_fp16_t * mp_row = + (const ggml_fp16_t *) ((const char *) mask->data + iq1 * mask->nb[1] + + (iq2 % mask->ne[2]) * mask->nb[2] + (iq3 % mask->ne[3]) * mask->nb[3]); + rvv_pack_scaled_f16_as_f32(mask32, KV_TILE_SZ * sizeof(float), mp_row + ic, mask->nb[1], tile_rows, + kv_tile, slope); + + for (int tq = 0; tq < tile_rows; tq++) { + for (int tk = 0; tk < kv_tile; tk++) { + if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { + can_skip = false; + } + } + // Pad remaining mask entries with -inf + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + + if (can_skip) { + continue; + } + } + + if (kv_type == GGML_TYPE_F16) { + rvv_transposed_s16_mn_to_nm((int8_t *) K_f16, KV_TILE_SZ * sizeof(_Float16), + (int8_t *) k->data + ic * nbk1 + ik2 * nbk2 + ik3 * nbk3, nbk1, kv_tile, + DK); + + int tq = 0; + for (; tq + 3 < tile_rows; tq += 4) { + rvv_qk_dot_tile_f16_x4(KQ + (tq + 0) * KV_TILE_SZ, KQ + (tq + 1) * KV_TILE_SZ, + KQ + (tq + 2) * KV_TILE_SZ, KQ + (tq + 3) * KV_TILE_SZ, + Q_f16 + (tq + 0) * DK, Q_f16 + (tq + 1) * DK, Q_f16 + (tq + 2) * DK, + Q_f16 + (tq + 3) * DK, K_f16, DK, kv_tile); + } + for (; tq < tile_rows; ++tq) { + rvv_qk_dot_tile_f16_x1(KQ + tq * KV_TILE_SZ, Q_f16 + tq * DK, K_f16, DK, kv_tile); + } + } else { + for (int tk = 0; tk < kv_tile; tk++) { + const char * k_data = (const char *) k->data + (ic + tk) * nbk1 + ik2 * nbk2 + ik3 * nbk3; + float * k_col = K_f32 + tk; + const float * k_src = (const float *) k_data; + for (int64_t dk = 0; dk < DK; ++dk) { + k_col[dk * KV_TILE_SZ] = k_src[dk]; + } + } + + for (int tq = 0; tq < tile_rows; ++tq) { + rvv_qk_dot_tile(KQ + tq * KV_TILE_SZ, Q_f32 + tq * DK, K_f32, DK, KV_TILE_SZ, scale); + } + } + + // Set padded KQ entries to -inf so softmax gives them zero weight + if (kv_tile < KV_TILE_SZ) { + for (int tq = 0; tq < tile_rows; tq++) { + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + KQ[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + } + + if (logit_softcap != 0.0f) { + rvv_softcap_tanh_inplace_f32(KQ, KV_TILE_SZ, tile_rows, KV_TILE_SZ, logit_softcap); + } + + if (mask) { + rvv_add_inplace_f32(KQ, KV_TILE_SZ, mask32, KV_TILE_SZ, tile_rows, KV_TILE_SZ); + } + + bool skip[Q_TILE_SZ] = {}; + + for (int tq = 0; tq < tile_rows; tq++) { + float * kq_row = KQ + tq * KV_TILE_SZ; + + const float tile_max = rvv_max_f32(kq_row, KV_TILE_SZ); + + if (tile_max == -INFINITY) { + skip[tq] = true; + continue; + } + + const float Mold = M[tq]; + const float Mnew = fmaxf(Mold, tile_max); + + if (Mnew > Mold) { + const float ms = expf(Mold - Mnew); + rvv_scale_f32(VKQ32 + tq * DV, ms, DV); + S[tq] *= ms; + } + M[tq] = Mnew; + + S[tq] += rvv_softmax_exp_inplace_f32(kq_row, KV_TILE_SZ, Mnew); + } + + // Pack V as contiguous [KV_TILE_SZ][DV]. + if (kv_type == GGML_TYPE_F16) { + const char * v_data = (const char *) v->data + ic * nbv1 + iv2 * nbv2 + iv3 * nbv3; + memcpy2d(V_f16, DV * sizeof(_Float16), v_data, nbv1, kv_tile, DV * sizeof(_Float16)); + + int tq = 0; + for (; tq + 3 < tile_rows; tq += 4) { + if (skip[tq + 0] || skip[tq + 1] || skip[tq + 2] || skip[tq + 3]) { + for (int i = 0; i < 4; ++i) { + if (!skip[tq + i]) { + rvv_pv_accumulate_f16_x1(VKQ32 + (tq + i) * DV, KQ + (tq + i) * KV_TILE_SZ, V_f16, + KV_TILE_SZ, DV); + } + } + continue; + } + + rvv_pv_accumulate_f16_x4(VKQ32 + (tq + 0) * DV, VKQ32 + (tq + 1) * DV, VKQ32 + (tq + 2) * DV, + VKQ32 + (tq + 3) * DV, KQ + (tq + 0) * KV_TILE_SZ, + KQ + (tq + 1) * KV_TILE_SZ, KQ + (tq + 2) * KV_TILE_SZ, + KQ + (tq + 3) * KV_TILE_SZ, V_f16, KV_TILE_SZ, DV); + } + for (; tq < tile_rows; ++tq) { + if (!skip[tq]) { + rvv_pv_accumulate_f16_x1(VKQ32 + tq * DV, KQ + tq * KV_TILE_SZ, V_f16, KV_TILE_SZ, DV); + } + } + } else { + const char * v_data = (const char *) v->data + ic * nbv1 + iv2 * nbv2 + iv3 * nbv3; + memcpy2d(V32, DV * sizeof(float), v_data, nbv1, kv_tile, DV * sizeof(float)); + + for (int tq = 0; tq < tile_rows; ++tq) { + if (!skip[tq]) { + rvv_pv_accumulate(VKQ32 + tq * DV, KQ + tq * KV_TILE_SZ, V32, KV_TILE_SZ, DV); + } + } + } + } + + // sinks (apply only to valid rows in the tile) + if (sinks) { + const float s = ((float *) ((char *) sinks->data))[h]; + + for (int tq = 0; tq < tile_rows; tq++) { + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[tq]) { + ms = expf(M[tq] - s); + rvv_scale_f32(VKQ32 + tq * DV, ms, DV); + } else { + vs = expf(s - M[tq]); + } + + float S_temp = S[tq] * ms + vs; + S[tq] = S_temp == 0.0f ? 0.0f : 1.0f / S_temp; + } + } else { + for (int tq = 0; tq < tile_rows; tq++) { + const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq]; + S[tq] = S_inv; + } + } + + float * dst_ptr = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + (iq1) *ne1) * nb1); + rvv_pack_scaled_f32_as_f32(dst_ptr, nb1 * ne1, VKQ32, DV * sizeof(float), tile_rows, DV, S); + + ir += tile_rows; + } +} + +void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon = *((float *) dst->op_params); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (char *) src0->data; + auto * output = (char *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + int64_t i03 = task_idx / (ne02 * ne01); + int64_t i02 = (task_idx - i03 * ne02 * ne01) / ne01; + int64_t i01 = (task_idx - i03 * ne02 * ne01 - i02 * ne01); + + auto * p_input = (float *) (input + i01 * nb01 + i02 * nb02 + i03 * nb03); + auto * p_output = (float *) (output + i01 * nb1 + i02 * nb2 + i03 * nb3); + auto * p_temp_output = p_output; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + + mean_square = sqrt(mean_square + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } +} + +template +void quantize_a_nrow_i8_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS); + + for (size_t row = 0; row < MB_ROWS; row++) { + float max_abs_a = 0.0f; + for (size_t bk = 0; bk < blk_len; bk++) { + max_abs_a = std::max(max_abs_a, std::abs(a_ptr[row * count_k + k + bk])); + } + + float rep_scale_a = ((1 << 7) - 1) / max_abs_a; + scale_a_ptr[row] = 1 / rep_scale_a; + + int16_t a_sum = 0; + for (size_t bk = 0; bk < blk_len; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk] * rep_scale_a), -128.0f, 127.0f)); + quant_a_blk[row * blk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row] = -a_sum; + } + } +} + +template +void quantize_a_nrow_i8_hp_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + const size_t subblk_count = blk_len / k_subblk_len; + + GGML_ASSERT(blk_len == 256); + + float scale_temp[8] = { 0.0f }; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false) * MB_ROWS; + + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + + float scale_avg = 0.0f; + for (size_t kk = 0; kk < subblk_count; kk++) { + float max_abs_a = 0.0f; + for (size_t row = 0; row < MB_ROWS; row++) { + for (size_t bk = 0; bk < k_subblk_len; bk++) { + max_abs_a = std::max(max_abs_a, std::abs(a_ptr[row * count_k + k + bk + kk * k_subblk_len])); + } + } + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + float scale_factor = 1.0f / scale_avg; + + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * MB_ROWS); + scale_avg_ptr[0] = scale_avg; + + for (size_t kk = 0; kk < subblk_count; kk++) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * MB_ROWS); + + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + const float rep_scale_a = 1.0f / scale_temp[kk]; + + for (size_t row = 0; row < MB_ROWS; row++) { + int16_t a_sum = 0; + for (size_t bk = 0; bk < k_subblk_len; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk + kk * k_subblk_len] * rep_scale_a), + -128.0f, 127.0f)); + quant_a_blk[row * k_subblk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row * subblk_count + kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + } + } + } +} + +template +void quantize_a_nrow_i8k_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t a_sum_size = 256 / 16; + + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * a_sum_size * MB_ROWS); + + for (size_t row = 0; row < MB_ROWS; row++) { + float max_a = 0.0f; + float max_abs_a = 0.0f; + for (size_t bk = 0; bk < blk_len; bk++) { + float ax = std::abs(a_ptr[row * count_k + k + bk]); + if (ax > max_abs_a) { + max_abs_a = ax; + max_a = a_ptr[row * count_k + k + bk]; + } + } + + if (!max_abs_a) { + scale_a_ptr[row] = 0; + for (size_t bki = 0; bki < a_sum_size; bki++) { + for (size_t bk = bki * 16; bk < (bki + 1) * 16; bk++) { + quant_a_blk[row * blk_len + bk] = 0; + } + a_sum_ptr[row * a_sum_size + bki] = 0; + } + continue; + } + + float rep_scale_a = ((1 << 7) - 1) / max_abs_a; + scale_a_ptr[row] = 1 / rep_scale_a; + + for (size_t bki = 0; bki < a_sum_size; bki++) { + int16_t a_sum = 0; + for (size_t bk = bki * 16; bk < (bki + 1) * 16; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk] * rep_scale_a), -128.0f, 127.0f)); + quant_a_blk[row * blk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row * a_sum_size + bki] = -a_sum; + } + } + } +} + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 32); + int64_t a_blk_stride = q8_blk_size(blk_len, true); + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t)); + + size_t vl = __riscv_vsetvl_e32m1(blk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[0] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk, v_a_quant_i8, vl); + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t)); + + size_t vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_ptr + k, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[0] = -a_sum; + + __riscv_vse8_v_i8m1(quant_a_blk, v_a_quant_i8, vl); + } + } else { + quantize_a_nrow_i8_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 32); + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * 4; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * 4); + + for (size_t mi = 0; mi < 4; mi++) { + size_t vl = __riscv_vsetvl_e32m1(blk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + mi * blk_len, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * 4); + + for (size_t mi = 0; mi < 4; mi++) { + size_t vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_ptr + mi * count_k + k, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi] = -a_sum; + + __riscv_vse8_v_i8m1(quant_a_blk + mi * blk_len, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + GGML_ASSERT(blk_len == 256); + + constexpr size_t subblk_count = 256 / k_subblk_len; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false); + size_t vlenb = __riscv_vlenb(); + float scale_temp[subblk_count] = { 0.0f }; + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_blk_stride - sizeof(_Float16)); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_src_ptr, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16)); + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_src_ptr, vl); + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8mf4(quant_a_blk, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_blk_stride - sizeof(_Float16)); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_src_ptr, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16)); + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_src_ptr, vl); + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8m1(quant_a_blk, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_hp_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + GGML_ASSERT(blk_len == 256); + + constexpr size_t subblk_count = 256 / k_subblk_len; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_nrow_block_stride = a_blk_stride * 4; + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false) * 4; + size_t vlenb = __riscv_vlenb(); + float scale_temp[subblk_count] = { 0.0f }; + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * 4); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a0 = __riscv_vle32_v_f32m1(a_src_ptr0, vl); + vfloat32m1_t v_a1 = __riscv_vle32_v_f32m1(a_src_ptr1, vl); + vfloat32m1_t v_a2 = __riscv_vle32_v_f32m1(a_src_ptr2, vl); + vfloat32m1_t v_a3 = __riscv_vle32_v_f32m1(a_src_ptr3, vl); + vfloat32m1_t v_a0_abs = __riscv_vfabs_v_f32m1(v_a0, vl); + vfloat32m1_t v_a1_abs = __riscv_vfabs_v_f32m1(v_a1, vl); + vfloat32m1_t v_a2_abs = __riscv_vfabs_v_f32m1(v_a2, vl); + vfloat32m1_t v_a3_abs = __riscv_vfabs_v_f32m1(v_a3, vl); + + vfloat32m1_t v_max_abs = __riscv_vfmax_vv_f32m1(v_a0_abs, v_a1_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_max_abs, v_a2_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_max_abs, v_a3_abs, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * 4); + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a0 = __riscv_vle32_v_f32m1(a_src_ptr0, vl); + vfloat32m1_t v_a1 = __riscv_vle32_v_f32m1(a_src_ptr1, vl); + vfloat32m1_t v_a2 = __riscv_vle32_v_f32m1(a_src_ptr2, vl); + vfloat32m1_t v_a3 = __riscv_vle32_v_f32m1(a_src_ptr3, vl); + + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m1_t v_a0_scale = __riscv_vfmul_vf_f32m1(v_a0, rep_scale_a, vl); + vfloat32m1_t v_a1_scale = __riscv_vfmul_vf_f32m1(v_a1, rep_scale_a, vl); + vfloat32m1_t v_a2_scale = __riscv_vfmul_vf_f32m1(v_a2, rep_scale_a, vl); + vfloat32m1_t v_a3_scale = __riscv_vfmul_vf_f32m1(v_a3, rep_scale_a, vl); + vint16mf2_t v_a0_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a0_scale, vl); + vint16mf2_t v_a1_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a1_scale, vl); + vint16mf2_t v_a2_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a2_scale, vl); + vint16mf2_t v_a3_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a3_scale, vl); + vint8mf4_t v_a0_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a0_quant, vl); + vint8mf4_t v_a1_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a1_quant, vl); + vint8mf4_t v_a2_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a2_quant, vl); + vint8mf4_t v_a3_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a3_quant, vl); + + vint16m1_t tmp_sum0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum3 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a0_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a0_quant_i8, tmp_sum0, vl); + vint16m1_t v_a1_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a1_quant_i8, tmp_sum1, vl); + vint16m1_t v_a2_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a2_quant_i8, tmp_sum2, vl); + vint16m1_t v_a3_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a3_quant_i8, tmp_sum3, vl); + + a_sum_ptr[0 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a0_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[1 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a1_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[2 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a2_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[3 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a3_sum)) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8mf4(quant_a_blk + 0 * k_subblk_len, v_a0_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 1 * k_subblk_len, v_a1_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 2 * k_subblk_len, v_a2_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 3 * k_subblk_len, v_a3_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * 4); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a0 = __riscv_vle32_v_f32m4(a_src_ptr0, vl); + vfloat32m4_t v_a1 = __riscv_vle32_v_f32m4(a_src_ptr1, vl); + vfloat32m4_t v_a2 = __riscv_vle32_v_f32m4(a_src_ptr2, vl); + vfloat32m4_t v_a3 = __riscv_vle32_v_f32m4(a_src_ptr3, vl); + + vfloat32m4_t v_a0_abs = __riscv_vfabs_v_f32m4(v_a0, vl); + vfloat32m4_t v_a1_abs = __riscv_vfabs_v_f32m4(v_a1, vl); + vfloat32m4_t v_a2_abs = __riscv_vfabs_v_f32m4(v_a2, vl); + vfloat32m4_t v_a3_abs = __riscv_vfabs_v_f32m4(v_a3, vl); + + vfloat32m4_t v_max_abs = __riscv_vfmax_vv_f32m4(v_a0_abs, v_a1_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m4(v_max_abs, v_a2_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m4(v_max_abs, v_a3_abs, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * 4); + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a0 = __riscv_vle32_v_f32m4(a_src_ptr0, vl); + vfloat32m4_t v_a1 = __riscv_vle32_v_f32m4(a_src_ptr1, vl); + vfloat32m4_t v_a2 = __riscv_vle32_v_f32m4(a_src_ptr2, vl); + vfloat32m4_t v_a3 = __riscv_vle32_v_f32m4(a_src_ptr3, vl); + + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m4_t v_a0_scale = __riscv_vfmul_vf_f32m4(v_a0, rep_scale_a, vl); + vfloat32m4_t v_a1_scale = __riscv_vfmul_vf_f32m4(v_a1, rep_scale_a, vl); + vfloat32m4_t v_a2_scale = __riscv_vfmul_vf_f32m4(v_a2, rep_scale_a, vl); + vfloat32m4_t v_a3_scale = __riscv_vfmul_vf_f32m4(v_a3, rep_scale_a, vl); + vint16m2_t v_a0_quant = __riscv_vfncvt_x_f_w_i16m2(v_a0_scale, vl); + vint16m2_t v_a1_quant = __riscv_vfncvt_x_f_w_i16m2(v_a1_scale, vl); + vint16m2_t v_a2_quant = __riscv_vfncvt_x_f_w_i16m2(v_a2_scale, vl); + vint16m2_t v_a3_quant = __riscv_vfncvt_x_f_w_i16m2(v_a3_scale, vl); + vint8m1_t v_a0_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a0_quant, vl); + vint8m1_t v_a1_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a1_quant, vl); + vint8m1_t v_a2_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a2_quant, vl); + vint8m1_t v_a3_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a3_quant, vl); + + vint16m1_t tmp_sum0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum3 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a0_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a0_quant_i8, tmp_sum0, vl); + vint16m1_t v_a1_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a1_quant_i8, tmp_sum1, vl); + vint16m1_t v_a2_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a2_quant_i8, tmp_sum2, vl); + vint16m1_t v_a3_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a3_quant_i8, tmp_sum3, vl); + + a_sum_ptr[0 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a0_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[1 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a1_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[2 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a2_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[3 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a3_sum)) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8m1(quant_a_blk + 0 * k_subblk_len, v_a0_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 1 * k_subblk_len, v_a1_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 2 * k_subblk_len, v_a2_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 3 * k_subblk_len, v_a3_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_hp_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 256); + constexpr int64_t a_blk_stride = q8k_blk_size(256); + constexpr int64_t a_sum_size = 256 / 16; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + // vlen = 1024 bits, can process 32 float32 elements with m1 + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t) * a_sum_size); + + // Find max absolute value across all 256 elements + size_t vl = __riscv_vsetvl_e32m1(16); + vfloat32m1_t v_max_abs = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k + bki * 16, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k + bki * 16, vl); + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[bki] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + bki * 16, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + // vlen = 256 bits, can process 8 float32 elements with m1 + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t) * a_sum_size); + + // Find max absolute value across all 256 elements + size_t vl = __riscv_vsetvl_e32m2(16); + vfloat32m2_t v_max_abs = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + k + bki * 16, vl); + vfloat32m2_t v_a_abs = __riscv_vfabs_v_f32m2(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m2(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m2_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + k + bki * 16, vl); + vfloat32m2_t v_a_scale = __riscv_vfmul_vf_f32m2(v_a, rep_scale_a, vl); + vint16m1_t v_a_quant = __riscv_vfncvt_x_f_w_i16m1(v_a_scale, vl); + vint8mf2_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf2(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf2_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[bki] = -a_sum; + + __riscv_vse8_v_i8mf2(quant_a_blk + bki * 16, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8k_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 256); + constexpr int64_t a_blk_stride = q8k_blk_size(256); + constexpr int64_t a_nrow_block_stride = a_blk_stride * 4; + constexpr int64_t a_sum_size = 256 / 16; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + // vlen = 1024 bits + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * a_sum_size * 4); + + for (size_t mi = 0; mi < 4; mi++) { + // Find max absolute value across all 256 elements for this row + size_t vl = __riscv_vsetvl_e32m1(16); + vfloat32m1_t v_max_abs = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi * a_sum_size + bki] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + mi * blk_len + bki * 16, v_a_quant_i8, vl); + } + } + } + } else if (vlenb == 32) { + // vlen = 256 bits + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * a_sum_size * 4); + + for (size_t mi = 0; mi < 4; mi++) { + // Find max absolute value across all 256 elements for this row + size_t vl = __riscv_vsetvl_e32m2(16); + vfloat32m2_t v_max_abs = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m2_t v_a_abs = __riscv_vfabs_v_f32m2(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m2(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m2_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m2_t v_a_scale = __riscv_vfmul_vf_f32m2(v_a, rep_scale_a, vl); + vint16m1_t v_a_quant = __riscv_vfncvt_x_f_w_i16m1(v_a_scale, vl); + vint8mf2_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf2(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf2_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi * a_sum_size + bki] = -a_sum; + + __riscv_vse8_v_i8mf2(quant_a_blk + mi * blk_len + bki * 16, v_a_quant_i8, vl); + } + } + } + } else { + quantize_a_nrow_i8k_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void forward_cpy_with_permute(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + const int ith = params->ith; + const int nth = params->nth; + + // [batch, m, n] -> [batch, n, m] + int64_t batch = src0->ne[2] * src0->ne[3]; + int64_t m = src0->ne[1]; + int64_t n = src0->ne[0]; + + int64_t batch_stride = src0->nb[2]; + int64_t m_src_stride = src0->nb[0]; + int64_t n_src_stride = src0->nb[1]; + int64_t n_dst_stride = n_src_stride * m; + + permute_transpose_impl(src0, dst, batch, m, n, batch_stride, m_src_stride, n_src_stride, n_dst_stride, ith, nth); +} + +void forward_cont_with_permute(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + const int ith = params->ith; + const int nth = params->nth; + + // [batch, m, n] -> [batch, n, m] + int64_t batch = dst->ne[2] * dst->ne[3]; + int64_t n = dst->ne[1]; + int64_t m = dst->ne[0]; + + int64_t batch_stride = dst->nb[2]; + int64_t m_src_stride = src0->nb[0]; + int64_t n_src_stride = src0->nb[1]; + int64_t n_dst_stride = dst->nb[1]; + + permute_transpose_impl(src0, dst, batch, m, n, batch_stride, m_src_stride, n_src_stride, n_dst_stride, ith, nth); +} + +void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon = *((float *) dst->op_params); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (char *) src0->data; + auto * output = (char *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + int64_t i03 = task_idx / (ne02 * ne01); + int64_t i02 = (task_idx - i03 * ne02 * ne01) / ne01; + int64_t i01 = (task_idx - i03 * ne02 * ne01 - i02 * ne01); + + auto * p_input = (float *) (input + i01 * nb01 + i02 * nb02 + i03 * nb03); + auto * p_output = (float *) (output + i01 * nb1 + i02 * nb2 + i03 * nb3); + auto * p_temp_output = p_output; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); + mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); + mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); + mean /= hidden_size; + + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + mean_square = sqrt(mean_square - mean * mean + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } +} + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + auto src0_rows = ggml_nrows(src0); + auto src1_rows = ggml_nrows(src1); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(T)); + GGML_ASSERT(nb00 == sizeof(T)); + + const auto [ir0, ir1] = get_thread_range(params, src0); + + auto compute_func_vv = [&](int64_t blk_len, int64_t r, T * src0_ptr, T * src1_ptr, T * dst_ptr) { + int64_t idx = 0; + if constexpr (op_type == GGML_OP_ADD) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfadd_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfadd_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_SUB) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfsub_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfsub_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_MUL) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfmul_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfmul_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_DIV) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfdiv_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfdiv_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ABORT("fatal error"); + } + }; + + if (src0_rows == src1_rows && src0_rows == 1 && ne00 == ne10) { + int64_t task_per_thread = (ne00 + nth - 1) / nth; + int64_t task_begin = ith * task_per_thread; + int64_t task_end = std::min((ith + 1) * task_per_thread, ne00); + + T * dst_ptr = ((T *) dst->data) + task_begin; + T * src0_ptr = ((T *) src0->data) + task_begin; + T * src1_ptr = ((T *) src1->data) + task_begin; + + compute_func_vv(task_end - task_begin, 0, src0_ptr, src1_ptr, dst_ptr); + } else if (ne10 > 1) { + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + T * dst_ptr = (T *) ((char *) dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + T * src0_ptr = (T *) ((char *) src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + T * src1_ptr = (T *) ((char *) src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + // src1 is broadcastable across src0 and dst in i1, i2, i3 + for (int64_t r = 0; r < ne00; r += ne10) { + compute_func_vv(ne10, r, src0_ptr, src1_ptr, dst_ptr); + } + } + } else { + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + T * dst_ptr = (T *) ((char *) dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + T * src0_ptr = (T *) ((char *) src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + T * src1_ptr = (T *) ((char *) src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + T rhs_scalar = src1_ptr[0]; + int64_t blk_len = ne00; + int64_t r = 0; + + for (size_t vl; blk_len > 0; blk_len -= vl, r += vl) { + if constexpr (op_type == GGML_OP_ADD) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfadd_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfadd_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_SUB) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfsub_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfsub_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_MUL) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfmul_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfmul_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_DIV) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfdiv_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfdiv_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ABORT("fatal error"); + } + } + } + } +} + +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == 1); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + int64_t n_task = ne01 * ne02 * ne03; + int64_t task_per_thread = (n_task + nth - 1) / nth; + int64_t ir_start = ith * task_per_thread; + int64_t ir_end = std::min(ir_start + task_per_thread, n_task); + + for (int64_t ir = ir_start; ir < ir_end; ir++) { + const int64_t i3 = ir / (ne02 * ne01); + const int64_t i2 = (ir - i3 * ne02 * ne01) / ne01; + const int64_t i1 = (ir - i3 * ne02 * ne01 - i2 * ne01); + + T * src_row = (T *) ((char *) src0->data + i1 * nb01 + i2 * nb02 + i3 * nb03); + T * dst_row = (T *) ((char *) op->data + i1 * nb1 + i2 * nb2 + i3 * nb3); + + float row_sum = 0; + + if constexpr (std::is_same_v) { + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t acc_vec = __riscv_vfmv_v_f_f32m4(0.0f, gvl); + int64_t length = ne00; + const float * p_data = src_row; + + while (length > 0) { + size_t vl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t vec = __riscv_vle32_v_f32m4(p_data, vl); + acc_vec = __riscv_vfadd_vv_f32m4(acc_vec, vec, vl); + p_data += vl; + length -= vl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.0f, gvl); + vfloat32m1_t sum_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(acc_vec, 0), + __riscv_vget_v_f32m4_f32m1(acc_vec, 1), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 2), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 3), gvl); + sum_v = __riscv_vfredusum_vs_f32m1_f32m1(sum_v, zero_v, gvl); + row_sum = __riscv_vfmv_f_s_f32m1_f32(sum_v); + } else if constexpr (std::is_same_v) { + size_t gvl = __riscv_vsetvlmax_e16m2(); + vfloat32m4_t acc_vec = __riscv_vfmv_v_f_f32m4(0.0f, gvl); + int64_t length = ne00; + const _Float16 * p_data = src_row; + + while (length > 0) { + size_t vl = __riscv_vsetvl_e16m2(length); + vfloat16m2_t vec_f16 = __riscv_vle16_v_f16m2(p_data, vl); + vfloat32m4_t vec_f32 = __riscv_vfwcvt_f_f_v_f32m4(vec_f16, vl); + acc_vec = __riscv_vfadd_vv_f32m4(acc_vec, vec_f32, vl); + p_data += vl; + length -= vl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.0f, gvl); + vfloat32m1_t sum_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(acc_vec, 0), + __riscv_vget_v_f32m4_f32m1(acc_vec, 1), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 2), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 3), gvl); + sum_v = __riscv_vfredusum_vs_f32m1_f32m1(sum_v, zero_v, gvl); + row_sum = __riscv_vfmv_f_s_f32m1_f32(sum_v); + } else { + GGML_ABORT("fatal error"); + } + + dst_row[0] = row_sum; + } +} + +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + int64_t nrows = ggml_nrows(src0); + int64_t nrows_per_thread = (nrows + nth - 1) / nth; + int64_t ir_start = ith * nrows_per_thread; + int64_t ir_end = std::min(ir_start + nrows_per_thread, nrows); + + if (src0->ne[0] == 1) { + for (int64_t ir = ir_start; ir < ir_end; ir++) { + T * src_row = (T *) ((char *) src0->data + ir * src0->nb[1]); + T * dst_row = (T *) ((char *) dst->data + ir * dst->nb[1]); + + T src_scalar = src_row[0]; + + int64_t length = dst->ne[0]; + int64_t idx = 0; + size_t vl = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vmv_v_x_i32m4(src_scalar, vl); + __riscv_vse32_v_i32m4(dst_row + idx, vec, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vmv_v_x_i16m4(src_scalar, vl); + __riscv_vse16_v_i16m4((dst_row + idx), vec, vl); + } else { + GGML_ABORT("fatal error"); + } + idx += vl; + length -= vl; + } + } + } else if (src0->ne[0] == dst->ne[0]) { + for (int64_t ir = ir_start; ir < ir_end; ir++) { + T * src_row = (T *) ((char *) src0->data + ir * src0->nb[1]); + T * dst_row = (T *) ((char *) dst->data + ir * dst->nb[1]); + + int64_t length = dst->ne[0]; + int64_t idx = 0; + size_t vl = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vle32_v_i32m4(src_row + idx, vl); + __riscv_vse32_v_i32m4(dst_row + idx, vec, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vle16_v_i16m4((src_row + idx), vl); + __riscv_vse16_v_i16m4((dst_row + idx), vec, vl); + } else { + GGML_ABORT("fatal error"); + } + idx += vl; + length -= vl; + } + } + } else { + GGML_ABORT("fatal error"); + } +} + +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int64_t total_batches = ne2 * ne3; + const int64_t batches_per_thread = (total_batches + nth - 1) / nth; + const int64_t batch_start = ith * batches_per_thread; + const int64_t batch_end = std::min(batch_start + batches_per_thread, total_batches); + + for (int64_t b = batch_start; b < batch_end; b++) { + const int64_t i3 = b / ne2; + const int64_t i2 = b % ne2; + + T * src_base = (T *) ((char *) src0->data + i2 * src0->nb[2] + i3 * src0->nb[3]); + T * dst_batch = (T *) ((char *) dst->data + i2 * dst->nb[2] + i3 * dst->nb[3]); + + for (int64_t i1 = 0; i1 < ne1; i1++) { + T * dst_ptr = (T *) ((char *) dst_batch + i1 * dst->nb[1]); + int64_t length = ne0; + int64_t idx = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + size_t vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vle32_v_i32m4(src_base + idx, vl); + __riscv_vse32_v_i32m4(dst_ptr + idx, vec, vl); + idx += vl; + length -= vl; + } else if constexpr (std::is_same_v) { + size_t vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vle16_v_i16m4((src_base + idx), vl); + __riscv_vse16_v_i16m4((dst_ptr + idx), vec, vl); + idx += vl; + length -= vl; + } else { + GGML_ABORT("fatal error"); + } + } + } + } +} + +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(float)); + assert(ggml_nrows(op) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + int rows_nth = nth; + int cols_nth = 1; + + if (nr == 1) { + rows_nth = 1; + cols_nth = nth; + } + + // rows per thread + const int dr = (nr + rows_nth - 1) / rows_nth; + const int dc = (nc + cols_nth - 1) / cols_nth; + + int rows_ith = ith % rows_nth; + int cols_ith = ith % cols_nth; + + // row range for this thread + const int ir0 = dr * rows_ith; + const int ir1 = MIN(ir0 + dr, nr); + + const int cr0 = dc * cols_ith; + const int cr1 = MIN(cr0 + dc, nc); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i / (ne11 * ne10); + const int64_t i11 = (i - i12 * ne11 * ne10) / ne10; + const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); + + GGML_ASSERT(i01 >= 0 && i01 < ne01); + + memcpy1d(((char *) dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3) + cr0 * sizeof(T), + ((char *) src0->data + i01 * nb01 + i11 * nb02 + i12 * nb03) + cr0 * sizeof(T), + (cr1 - cr0) * sizeof(T)); + } +} + +template void forward_concat(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float)); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim == 0 && nb0 == sizeof(float) && nb1 == sizeof(float) * (ne00 + ne10)); + + const int64_t nr = ggml_nrows(dst); + const int64_t nc = ne0; + + const int ith = params->ith; + const int nth = params->nth; + + int rows_nth = nth; + int cols_nth = 1; + + if (nr == 1) { + rows_nth = 1; + cols_nth = nth; + } + + const int dr = (nr + rows_nth - 1) / rows_nth; + const int dc = (nc + cols_nth - 1) / cols_nth; + + int rows_ith = ith % rows_nth; + int cols_ith = ith % cols_nth; + + // row range for this thread + const int ir0 = dr * rows_ith; + const int ir1 = MIN(ir0 + dr, nr); + + const int cr0 = dc * cols_ith; + const int cr1 = MIN(cr0 + dc, nc); + + int64_t o[4] = { 0, 0, 0, 0 }; + o[dim] = src0->ne[dim]; + const float * x; + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i3 = i / (ne02 * ne01); + const int64_t i2 = (i - i3 * ne02 * ne01) / ne01; + const int64_t i1 = (i - i3 * ne02 * ne01 - i2 * ne01); + + for (int i0 = cr0; i0 < cr1; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *) ((const char *) src0->data + (i0) *nb00 + (i1) *nb01 + (i2) *nb02 + (i3) *nb03); + } else { + x = (const float *) ((const char *) src1->data + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13); + } + + float * y = (float *) ((char *) dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + + *y = *x; + } + } +} + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op); +template void forward_sum_rows<_Float16>(const ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); + +} // namespace spacemit_kernels::rvv diff --git a/ggml/src/ggml-cpu/spacemit/rvv_kernels.h b/ggml/src/ggml-cpu/spacemit/rvv_kernels.h new file mode 100644 index 000000000..edddf957c --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/rvv_kernels.h @@ -0,0 +1,95 @@ +#pragma once + +#include "ggml-cpu-impl.h" + +#include +#include +#include +#include + +namespace spacemit_kernels { + +constexpr auto div_round_up(auto up, auto down) { + return (up + down - 1) / down; +} + +// Q8 Blk [f32] [s16] [int8 * blk_len] +// Q8 Blk N [f32 * N] [s16 * N] [int8 * blk_len * N] +constexpr size_t q8_blk_size(size_t blk_len, bool with_blk_sum = false) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t) + (with_blk_sum ? sizeof(int16_t) : 0); + return blk_size; +} + +// Q8 HP row block: K is split into K32 subblocks. +// Each subblock stores [f32 scale] [int8 * 32], with an optional fp16 sum trailer per subblock. +constexpr size_t q8_hp_blk_size(size_t blk_len, bool with_blk_sum = false, bool with_blk_scale = false) { + const size_t subblk_count = div_round_up(blk_len, size_t(32)); + const size_t blk_size = blk_len * sizeof(int8_t) + subblk_count * sizeof(_Float16) + + (with_blk_sum ? subblk_count * sizeof(_Float16) : 0) + + (with_blk_scale ? sizeof(_Float16) : 0); + return blk_size; +} + +// Q8K Blk [f32] [s16 * (blk_len / 16)] [int8 * blk_len] +// Q8K Blk N [f32 * N] [s16 * (blk_len / 16) * N] [int8 * blk_len * N] +constexpr size_t q8k_blk_size(size_t blk_len) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t) + sizeof(int16_t) * blk_len / 16; + return blk_size; +} + +using quantize_a_row_def = std::function; + +namespace rvv { +void memcpy1d(void * dst, const void * src, int64_t size); + +void memcpy2d(void * dst, int64_t dst_stride, const void * src, int64_t src_stride, int64_t tile_rows, int64_t size); + +void forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size); + +void forward_flash_attn_ext_f16_tiled_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size); + +void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op); + +void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op); + +void forward_cont_with_permute(ggml_compute_params * params, ggml_tensor * op); + +void forward_cpy_with_permute(ggml_compute_params * params, ggml_tensor * op); + +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); + +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); + +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op); + +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); + +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +} // namespace rvv + +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/spine_barrier.h b/ggml/src/ggml-cpu/spacemit/spine_barrier.h new file mode 100644 index 000000000..f897dad4b --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_barrier.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +#define SPINE_CACHE_LINE 64 +#define SPINE_CACHE_ALIGN __attribute__((aligned(SPINE_CACHE_LINE))) + +struct spine_barrier_t { + SPINE_CACHE_ALIGN std::atomic pending_; + SPINE_CACHE_ALIGN std::atomic rounds_; + SPINE_CACHE_ALIGN int64_t total_; +}; + +inline void spine_barrier_wait(spine_barrier_t * b) { + auto cur_round = b->rounds_.load(std::memory_order_acquire); + auto cnt = --b->pending_; + if (cnt == 0) { + b->pending_.store(b->total_); + b->rounds_.store(cur_round + 1); + } else { + while (cur_round == b->rounds_.load(std::memory_order_relaxed)) { + __asm__ volatile("pause " ::: "memory"); + } + } +} + +inline void spine_barrier_init(spine_barrier_t * b, int num_barriers, uint64_t thread_count) { + for (int i = 0; i < num_barriers; i++) { + b[i].total_ = thread_count; + b[i].pending_.store(thread_count); + b[i].rounds_.store(0); + } +} diff --git a/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp new file mode 100644 index 000000000..1409423b1 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp @@ -0,0 +1,760 @@ +#include "spine_mem_pool.h" + +#include "common.h" +#include "ime_env.h" +#include "spine_tcm.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { +namespace { + +constexpr size_t SPINE_MEM_POOL_CHUNK_SIZE = 512ull * 1024ull * 1024ull; +constexpr size_t SPINE_SHARE_MEM_POOL_CHUNK_SIZE = 512ull * 1024ull; +constexpr size_t SPINE_MEM_POOL_1G_REGION_SIZE = 1ull << 30; +constexpr uint64_t HUGETLB_1G_FLAG_REQUIRE_PUD = 1ull << 0; +constexpr char SPINE_MEM_POOL_HUGETLB_1G_DEV[] = "/dev/hugetlb_1g"; +constexpr char SPINE_MEM_POOL_TCM_SYNC_MEM_DEV[] = "/dev/tcm_sync_mem"; + +struct hugetlb_1g_region { + uint64_t size{ 0 }; + uint64_t dma_addr{ 0 }; + uint64_t flags{ 0 }; + uint64_t reserved{ 0 }; +}; + +#define HUGETLB_1G_IOC_MAGIC 'M' +#define HUGETLB_1G_IOC_ALLOC _IOWR(HUGETLB_1G_IOC_MAGIC, 0x00, struct hugetlb_1g_region) +#define HUGETLB_1G_IOC_FREE _IO(HUGETLB_1G_IOC_MAGIC, 0x01) + +struct free_block { + size_t offset{ 0 }; + size_t size{ 0 }; +}; + +struct pool_chunk { + uint8_t * base{ nullptr }; + size_t size{ 0 }; + int fd{ -1 }; + std::vector free_blocks; +}; + +struct pool_allocation { + void * chunk_base{ nullptr }; + size_t chunk_size{ 0 }; + void * base{ nullptr }; + size_t size{ 0 }; +}; + +bool is_power_of_two(size_t value) { + return value != 0 && (value & (value - 1)) == 0; +} + +bool align_up(size_t value, size_t alignment, size_t * aligned_value) { + if (aligned_value == nullptr || alignment == 0) { + return false; + } + + const size_t remainder = value % alignment; + if (remainder == 0) { + *aligned_value = value; + return true; + } + + const size_t padding = alignment - remainder; + if (value > std::numeric_limits::max() - padding) { + return false; + } + + *aligned_value = value + padding; + return true; +} + +bool align_up_uintptr(uintptr_t value, size_t alignment, uintptr_t * aligned_value) { + if (aligned_value == nullptr || alignment == 0) { + return false; + } + + const uintptr_t remainder = value % alignment; + if (remainder == 0) { + *aligned_value = value; + return true; + } + + const uintptr_t padding = alignment - remainder; + if (value > std::numeric_limits::max() - padding) { + return false; + } + + *aligned_value = value + padding; + return true; +} + +class spine_mem_pool_manager { + public: + explicit spine_mem_pool_manager(size_t default_chunk_size) : default_chunk_size_(default_chunk_size) {} + + virtual ~spine_mem_pool_manager() = default; + + void * alloc(size_t size, size_t alignment) { + if (size == 0 || !is_power_of_two(alignment)) { + return nullptr; + } + + size_t aligned_size = 0; + if (!align_up(size, alignment, &aligned_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: align_up failed for size %zu alignment %zu\n", __func__, size, + alignment); + return nullptr; + } + + pool_allocation allocation; + + std::lock_guard lock(mutex_); + + if (!try_alloc_locked(aligned_size, alignment, &allocation)) { + if (!add_chunk_locked(aligned_size, alignment)) { + return nullptr; + } + + if (!try_alloc_locked(aligned_size, alignment, &allocation)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation retry failed for size %zu alignment %zu\n", + __func__, aligned_size, alignment); + return nullptr; + } + } + + try { + const auto [allocation_it, inserted] = allocations_.emplace(allocation.base, allocation); + if (!inserted) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: duplicate allocation key %p\n", __func__, allocation.base); + rollback_allocation_locked(allocation); + return nullptr; + } + } catch (const std::bad_alloc &) { + rollback_allocation_locked(allocation); + throw; + } + + return allocation.base; + } + + void free(void * base) { + if (base == nullptr) { + return; + } + + std::lock_guard lock(mutex_); + + auto allocation_it = allocations_.find(base); + if (allocation_it == allocations_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: unknown allocation %p\n", __func__, base); + return; + } + + pool_allocation allocation = allocation_it->second; + allocations_.erase(allocation_it); + + auto chunk_it = find_chunk_locked(allocation); + if (chunk_it == chunks_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: unknown chunk for allocation %p size %zu\n", __func__, + allocation.base, allocation.size); + return; + } + + auto * chunk_base = chunk_it->base; + auto * alloc_base = static_cast(allocation.base); + if (alloc_base < chunk_base || alloc_base >= chunk_base + chunk_it->size) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation %p out of chunk range %p..%p\n", __func__, + allocation.base, chunk_base, chunk_base + chunk_it->size); + return; + } + + const size_t offset = static_cast(alloc_base - chunk_base); + if (offset > chunk_it->size || allocation.size > chunk_it->size - offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation %p size %zu exceeds chunk size %zu\n", __func__, + allocation.base, allocation.size, chunk_it->size); + return; + } + + insert_free_block_locked(*chunk_it, { offset, allocation.size }); + maybe_release_empty_chunk_locked(chunk_it); + } + + protected: + void release_chunks() { + std::lock_guard lock(mutex_); + + allocations_.clear(); + for (auto & chunk : chunks_) { + dealloc_chunk(&chunk); + } + chunks_.clear(); + } + + size_t default_chunk_size() const { return default_chunk_size_; } + + static void clear_chunk(pool_chunk * chunk) { + chunk->base = nullptr; + chunk->size = 0; + chunk->fd = -1; + chunk->free_blocks.clear(); + } + + virtual bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) = 0; + virtual void dealloc_chunk(pool_chunk * chunk) = 0; + + private: + struct alloc_candidate { + size_t chunk_index{ 0 }; + size_t block_index{ 0 }; + size_t aligned_offset{ 0 }; + uintptr_t address{ std::numeric_limits::max() }; + bool valid{ false }; + }; + + std::vector::iterator find_chunk_locked(const pool_allocation & allocation) { + return std::find_if(chunks_.begin(), chunks_.end(), [&](const pool_chunk & chunk) { + return chunk.base == allocation.chunk_base && chunk.size == allocation.chunk_size; + }); + } + + bool add_chunk_locked(size_t min_size, size_t alignment) { + pool_chunk chunk; + const size_t chunk_request = default_chunk_size_ == 0 ? min_size : std::max(min_size, default_chunk_size_); + void * hint_addr = nullptr; + + for (const auto & existing_chunk : chunks_) { + auto * chunk_end = existing_chunk.base + existing_chunk.size; + if (hint_addr == nullptr || chunk_end > hint_addr) { + hint_addr = chunk_end; + } + } + + if (!alloc_chunk(chunk_request, alignment, hint_addr, &chunk)) { + return false; + } + + if (chunk.base == nullptr || chunk.size < min_size) { + GGML_LOG_ERROR( + "CPU_RISCV64_SPACEMIT: %s: invalid chunk returned for request size %zu, chunk_base=%p chunk_size=%zu\n", + __func__, min_size, chunk.base, chunk.size); + dealloc_chunk(&chunk); + return false; + } + + try { + chunk.free_blocks.push_back({ 0, chunk.size }); + chunks_.push_back(std::move(chunk)); + } catch (const std::bad_alloc &) { + dealloc_chunk(&chunk); + throw; + } + + return true; + } + + void rollback_allocation_locked(const pool_allocation & allocation) { + auto chunk_it = find_chunk_locked(allocation); + if (chunk_it == chunks_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p, owning chunk not found\n", + __func__, allocation.base); + return; + } + + auto * chunk_base = chunk_it->base; + auto * alloc_base = static_cast(allocation.base); + if (alloc_base < chunk_base || alloc_base >= chunk_base + chunk_it->size) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p, chunk range is invalid\n", + __func__, allocation.base); + return; + } + + const size_t offset = static_cast(alloc_base - chunk_base); + if (offset > chunk_it->size || allocation.size > chunk_it->size - offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p size %zu\n", __func__, + allocation.base, allocation.size); + return; + } + + insert_free_block_locked(*chunk_it, { offset, allocation.size }); + maybe_release_empty_chunk_locked(chunk_it); + } + + bool try_alloc_locked(size_t size, size_t alignment, pool_allocation * allocation) { + alloc_candidate best; + + for (size_t chunk_index = 0; chunk_index < chunks_.size(); ++chunk_index) { + const auto & chunk = chunks_[chunk_index]; + for (size_t block_index = 0; block_index < chunk.free_blocks.size(); ++block_index) { + const auto & block = chunk.free_blocks[block_index]; + + uintptr_t aligned_addr = 0; + const auto block_addr = reinterpret_cast(chunk.base + block.offset); + if (!align_up_uintptr(block_addr, alignment, &aligned_addr)) { + continue; + } + + if (aligned_addr < block_addr) { + continue; + } + + const size_t aligned_offset = block.offset + static_cast(aligned_addr - block_addr); + const size_t padding = aligned_offset - block.offset; + if (padding > block.size || size > block.size - padding) { + continue; + } + + if (!best.valid || aligned_addr < best.address) { + best.chunk_index = chunk_index; + best.block_index = block_index; + best.aligned_offset = aligned_offset; + best.address = aligned_addr; + best.valid = true; + } + } + } + + if (!best.valid) { + return false; + } + + auto & chunk = chunks_[best.chunk_index]; + const free_block block = chunk.free_blocks[best.block_index]; + const size_t padding = best.aligned_offset - block.offset; + const size_t alloc_end = best.aligned_offset + size; + const size_t block_end = block.offset + block.size; + + chunk.free_blocks.erase(chunk.free_blocks.begin() + best.block_index); + auto insert_it = chunk.free_blocks.begin() + best.block_index; + if (padding != 0) { + insert_it = chunk.free_blocks.insert(insert_it, { block.offset, padding }); + ++insert_it; + } + if (alloc_end < block_end) { + chunk.free_blocks.insert(insert_it, { alloc_end, block_end - alloc_end }); + } + + allocation->chunk_base = chunk.base; + allocation->chunk_size = chunk.size; + allocation->base = chunk.base + best.aligned_offset; + allocation->size = size; + return true; + } + + void maybe_release_empty_chunk_locked(std::vector::iterator chunk_it) { + if (chunk_it->free_blocks.size() != 1) { + return; + } + + const auto & block = chunk_it->free_blocks.front(); + if (block.offset != 0 || block.size != chunk_it->size) { + return; + } + + dealloc_chunk(&*chunk_it); + chunks_.erase(chunk_it); + } + + void insert_free_block_locked(pool_chunk & chunk, free_block block) { + auto it = chunk.free_blocks.begin(); + while (it != chunk.free_blocks.end() && it->offset < block.offset) { + ++it; + } + + if (it != chunk.free_blocks.begin()) { + const auto & prev = *(it - 1); + if (prev.offset + prev.size > block.offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: overlapping free block at offset %zu size %zu\n", __func__, + block.offset, block.size); + return; + } + } + + if (it != chunk.free_blocks.end() && block.offset + block.size > it->offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: overlapping next free block at offset %zu size %zu\n", __func__, + block.offset, block.size); + return; + } + + it = chunk.free_blocks.insert(it, block); + + if (it != chunk.free_blocks.begin()) { + auto prev = it - 1; + if (prev->offset + prev->size == it->offset) { + it->offset = prev->offset; + it->size += prev->size; + it = chunk.free_blocks.erase(prev); + } + } + + if (it + 1 != chunk.free_blocks.end() && it->offset + it->size == (it + 1)->offset) { + it->size += (it + 1)->size; + chunk.free_blocks.erase(it + 1); + } + } + + std::mutex mutex_; + std::vector chunks_; + std::unordered_map allocations_; + size_t default_chunk_size_{ 0 }; +}; + +class spine_mem_pool_posix final : public spine_mem_pool_manager { + public: + spine_mem_pool_posix() : spine_mem_pool_manager(0) {} + + ~spine_mem_pool_posix() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) hint_addr; + + const size_t alloc_alignment = std::max(alignment, sizeof(void *)); + void * base = nullptr; + const int rc = posix_memalign(&base, alloc_alignment, min_size); + if (rc != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: posix_memalign failed for size %zu alignment %zu, rc=%d\n", + __func__, min_size, alloc_alignment, rc); + return false; + } + + chunk->base = static_cast(base); + chunk->size = min_size; + chunk->fd = -1; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + std::free(chunk->base); + clear_chunk(chunk); + } +}; + +class spine_mem_pool_transparent_hugepage final : public spine_mem_pool_manager { + public: + spine_mem_pool_transparent_hugepage() : spine_mem_pool_manager(SPINE_MEM_POOL_CHUNK_SIZE) {} + + ~spine_mem_pool_transparent_hugepage() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + + size_t chunk_size = 0; + if (!align_up(min_size, default_chunk_size(), &chunk_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to round chunk size for %zu\n", __func__, min_size); + return false; + } + + void * map_addr = mmap(hint_addr, chunk_size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for chunk size %zu, errno=%d\n", __func__, chunk_size, + errno); + return false; + } + + if (madvise(map_addr, chunk_size, MADV_HUGEPAGE) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: madvise(MADV_HUGEPAGE) failed for chunk size %zu, errno=%d\n", + __func__, chunk_size, errno); + munmap(map_addr, chunk_size); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = chunk_size; + chunk->fd = -1; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for chunk %p size %zu, errno=%d\n", __func__, + chunk->base, chunk->size, errno); + } + + clear_chunk(chunk); + } +}; + +class spine_mem_pool_hugetlb_1g final : public spine_mem_pool_manager { + public: + spine_mem_pool_hugetlb_1g() : spine_mem_pool_manager(SPINE_MEM_POOL_1G_REGION_SIZE) {} + + ~spine_mem_pool_hugetlb_1g() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + (void) hint_addr; + + size_t region_size = 0; + if (!align_up(min_size, SPINE_MEM_POOL_1G_REGION_SIZE, ®ion_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to round hugetlb_1g size for %zu\n", __func__, min_size); + return false; + } + + const int fd = open(SPINE_MEM_POOL_HUGETLB_1G_DEV, O_RDWR); + if (fd < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: open(%s) failed, errno=%d\n", __func__, + SPINE_MEM_POOL_HUGETLB_1G_DEV, errno); + return false; + } + + hugetlb_1g_region region; + region.size = region_size; + region.flags = HUGETLB_1G_FLAG_REQUIRE_PUD; + if (ioctl(fd, HUGETLB_1G_IOC_ALLOC, ®ion) < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: HUGETLB_1G_IOC_ALLOC failed for size %zu, errno=%d\n", __func__, + region_size, errno); + close(fd); + return false; + } + + void * map_addr = mmap(nullptr, region.size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for hugetlb_1g size %llu, errno=%d\n", __func__, + static_cast(region.size), errno); + ioctl(fd, HUGETLB_1G_IOC_FREE); + close(fd); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = region.size; + chunk->fd = fd; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for hugetlb_1g chunk %p size %zu, errno=%d\n", + __func__, chunk->base, chunk->size, errno); + } + + if (chunk->fd >= 0) { + if (ioctl(chunk->fd, HUGETLB_1G_IOC_FREE) < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: HUGETLB_1G_IOC_FREE failed for chunk %p, errno=%d\n", + __func__, chunk->base, errno); + } + + close(chunk->fd); + } + + clear_chunk(chunk); + } +}; + +class spine_mem_pool_shared_mem final : public spine_mem_pool_manager { + public: + spine_mem_pool_shared_mem() : spine_mem_pool_manager(SPINE_SHARE_MEM_POOL_CHUNK_SIZE) {} + + ~spine_mem_pool_shared_mem() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + + if (hint_addr != nullptr) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: shared_mem does not support multiple active chunks\n", __func__); + return false; + } + + if (min_size > default_chunk_size()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: shared_mem request %zu exceeds chunk size %zu\n", __func__, + min_size, default_chunk_size()); + return false; + } + + const int fd = open(SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, O_RDWR | O_SYNC); + if (fd < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: open(%s) failed, errno=%d\n", __func__, + SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, errno); + return false; + } + + void * map_addr = mmap(nullptr, default_chunk_size(), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for %s size %zu, errno=%d\n", __func__, + SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, default_chunk_size(), errno); + close(fd); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = default_chunk_size(); + chunk->fd = fd; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for shared_mem chunk %p size %zu, errno=%d\n", + __func__, chunk->base, chunk->size, errno); + } + + if (chunk->fd >= 0) { + close(chunk->fd); + } + + clear_chunk(chunk); + } +}; + +spine_mem_pool_manager & get_spine_mem_pool_manager() { + static std::once_flag pool_once; + static std::unique_ptr selected_pool; + static spine_mem_pool_backend selected_backend = spine_mem_pool_backend::none; + + spine_mem_pool_backend backend = global_spine_env_info.mem_backend; + if (backend == spine_mem_pool_backend::none) { + backend = spine_mem_pool_backend::transparent_hugepage; + } + + std::call_once(pool_once, [&]() { + selected_backend = backend; + + switch (selected_backend) { + case spine_mem_pool_backend::posix_memalign: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::transparent_hugepage: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::hugetlb_1g: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::none: + selected_backend = spine_mem_pool_backend::transparent_hugepage; + selected_pool = std::make_unique(); + break; + } + }); + + if (backend != selected_backend) { + GGML_LOG_ERROR( + "CPU_RISCV64_SPACEMIT: %s: mem pool backend is process-global and mutually exclusive, requested=%d but " + "selected=%d\n", + __func__, static_cast(backend), static_cast(selected_backend)); + } + + if (selected_pool) { + return *selected_pool; + } + + throw std::bad_alloc(); +} + +spine_mem_pool_manager & get_spine_mem_pool_shared_mem_manager() { + static std::once_flag shared_mem_pool_once; + static std::unique_ptr shared_mem_pool; + + std::call_once(shared_mem_pool_once, [&]() { shared_mem_pool = std::make_unique(); }); + + if (shared_mem_pool) { + return *shared_mem_pool; + } + + throw std::bad_alloc(); +} + +} // namespace + +bool spine_mem_pool_tcm_init(spine_mem_pool_tcm_info * info) noexcept { + if (info == nullptr) { + return false; + } + + *info = {}; + + if (spine_tcm_open_handle(NULL) != 0 || !spine_tcm_is_available()) { + return false; + } + + spine_tcm_mem_info_t mem_info; + if (spine_tcm_mem_info(&mem_info) != 0) { + return false; + } + + info->available = true; + info->blk_size = mem_info.blk_size; + info->blk_num = mem_info.blk_num; + info->is_fake_tcm = mem_info.is_fake_tcm != 0; + return true; +} + +void * spine_mem_pool_tcm_mem_get(int cpu_id) noexcept { + return spine_tcm_mem_get(cpu_id); +} + +void * spine_mem_pool_tcm_mem_wait(int cpu_id) noexcept { + return spine_tcm_mem_try_wait(cpu_id, 1000 * 1000); +} + +int spine_mem_pool_tcm_mem_release(int cpu_id) noexcept { + return spine_tcm_mem_release(cpu_id); +} + +void * spine_mem_pool_alloc(size_t size, size_t alignment) noexcept { + try { + return get_spine_mem_pool_manager().alloc(size, alignment); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while allocating size %zu\n", __func__, size); + return nullptr; + } +} + +void * spine_mem_pool_shared_mem_alloc(size_t size, size_t alignment) noexcept { + try { + return get_spine_mem_pool_shared_mem_manager().alloc(size, alignment); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while allocating shared memory size %zu\n", __func__, size); + return nullptr; + } +} + +void spine_mem_pool_free(void * base) noexcept { + try { + get_spine_mem_pool_manager().free(base); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while freeing allocation %p\n", __func__, base); + } +} + +void spine_mem_pool_shared_mem_free(void * base) noexcept { + try { + get_spine_mem_pool_shared_mem_manager().free(base); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while freeing shared allocation %p\n", __func__, base); + } +} + +} // namespace ggml::cpu::riscv64_spacemit + +extern "C" { +void * ggml_backend_cpu_riscv64_spacemit_alloc_shared(size_t size, size_t alignment) { + void * result = ggml::cpu::riscv64_spacemit::spine_mem_pool_shared_mem_alloc(size, alignment); + if (result == nullptr) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to allocate shared memory size %zu alignment %zu\n", __func__, + size, alignment); + } + return result; +} + +void ggml_backend_cpu_riscv64_spacemit_free_shared(void * ptr) { + ggml::cpu::riscv64_spacemit::spine_mem_pool_shared_mem_free(ptr); +} +} diff --git a/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h new file mode 100644 index 000000000..8740d2c99 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +enum class spine_mem_pool_backend : uint8_t { + none, + posix_memalign, + transparent_hugepage, + hugetlb_1g, +}; + +struct spine_mem_pool_tcm_info { + bool available{ false }; + size_t blk_size{ 0 }; + size_t blk_num{ 0 }; + bool is_fake_tcm{ false }; +}; + +bool spine_mem_pool_tcm_init(spine_mem_pool_tcm_info * info) noexcept; +void * spine_mem_pool_tcm_mem_get(int cpu_id) noexcept; +void * spine_mem_pool_tcm_mem_wait(int cpu_id) noexcept; +int spine_mem_pool_tcm_mem_release(int cpu_id) noexcept; + +void * spine_mem_pool_alloc(size_t size, size_t alignment) noexcept; +void * spine_mem_pool_shared_mem_alloc(size_t size, size_t alignment) noexcept; +void spine_mem_pool_free(void * base) noexcept; +void spine_mem_pool_shared_mem_free(void * base) noexcept; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/spine_tcm.h b/ggml/src/ggml-cpu/spacemit/spine_tcm.h new file mode 100644 index 000000000..f300d7d5c --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_tcm.h @@ -0,0 +1,409 @@ +#ifndef SPINE_TCM_PUBLIC_H_ +#define SPINE_TCM_PUBLIC_H_ + +/* + * spine_tcm public API + * + * Usage: + * 1. Direct link mode + * Define SPINE_TCM_DIRECT_LINK and link against libspine_tcm.so. + * + * if (spine_tcm_is_available()) { + * void *buffer = spine_tcm_mem_get(0); + * spine_tcm_mem_free(0); + * } + * + * 2. Header-only loader mode + * Include this header without linking libspine_tcm.so. The loader first + * tries to reuse a process-global spine_tcm instance and falls back to + * dlopen("libspine_tcm.so") when needed. + * + * spine_tcm_open_handle(NULL); // optional pre-bind + * if (spine_tcm_is_available()) { + * void *buffer = spine_tcm_mem_get(0); + * spine_tcm_mem_free(0); + * } + */ + +#include +#include +#include + +#if !defined(SPINE_TCM_BUILD_SHARED) && !defined(SPINE_TCM_DIRECT_LINK) +# include +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(_WIN32) +# if defined(SPINE_TCM_BUILD_SHARED) +# define SPINE_TCM_API __declspec(dllexport) +# else +# define SPINE_TCM_API __declspec(dllimport) +# endif +#else +# define SPINE_TCM_API __attribute__((visibility("default"))) +#endif + +typedef struct spine_tcm_mem_info { + size_t blk_size; + size_t blk_num; + int is_fake_tcm; +} spine_tcm_mem_info_t; + +typedef struct spine_tcm_block_info { + int id; + void * va; + size_t size; + uint64_t phys_addr; + uint64_t cpu_affinity_mask; + int owner_tid; + int is_acquired; +} spine_tcm_block_info_t; + +/* Shared-library runtime ABI exported by libspine_tcm.so. */ +SPINE_TCM_API const char * spine_tcm_runtime_version(void); +SPINE_TCM_API int spine_tcm_runtime_is_available(void); +SPINE_TCM_API int spine_tcm_runtime_layout_info(spine_tcm_mem_info_t * info); +SPINE_TCM_API int spine_tcm_runtime_mem_info(int id, spine_tcm_block_info_t * info); +SPINE_TCM_API void * spine_tcm_runtime_mem_get(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_free(int id); +SPINE_TCM_API void * spine_tcm_runtime_mem_try_wait(int id, size_t timeout_us); +SPINE_TCM_API int spine_tcm_runtime_mem_release(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_force_release(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_query(int id); + +#if defined(SPINE_TCM_DIRECT_LINK) +/* Optional no-op in direct-link mode. */ +static inline int spine_tcm_open_handle(const char * so_path) { + (void) so_path; + return 0; +} + +static inline const char * spine_tcm_version(void) { + return spine_tcm_runtime_version(); +} + +/* Returns 1 when the runtime driver is available, otherwise 0. */ +static inline int spine_tcm_is_available(void) { + return spine_tcm_runtime_is_available(); +} + +/* Returns runtime memory geometry and whether the current backend is fake TCM. */ +static inline int spine_tcm_mem_info(spine_tcm_mem_info_t * info) { + return spine_tcm_runtime_layout_info(info); +} + +/* Returns per-block runtime metadata for the given TCM id. */ +static inline int spine_tcm_block_info(int id, spine_tcm_block_info_t * info) { + return spine_tcm_runtime_mem_info(id, info); +} + +/* Returns a cached buffer for the given TCM id, or NULL on failure. */ +static inline void * spine_tcm_mem_get(int id) { + return spine_tcm_runtime_mem_get(id); +} + +/* Releases one reference acquired by spine_tcm_mem_get(id). */ +static inline int spine_tcm_mem_free(int id) { + return spine_tcm_runtime_mem_free(id); +} + +/* Waits for a TCM block handoff and returns the driver-owned buffer when available. */ +static inline void * spine_tcm_mem_try_wait(int id, size_t over_time) { + return spine_tcm_runtime_mem_try_wait(id, over_time); +} + +/* Releases a buffer acquired by spine_tcm_mem_try_wait(id, over_time). */ +static inline int spine_tcm_mem_release(int id) { + return spine_tcm_runtime_mem_release(id); +} + +/* Forces a release for the given TCM id when the backend supports it. */ +static inline int spine_tcm_mem_force_release(int id) { + return spine_tcm_runtime_mem_force_release(id); +} + +/* Returns whether the given TCM id is currently acquired. */ +static inline int spine_tcm_mem_query(int id) { + return spine_tcm_runtime_mem_query(id); +} +#elif !defined(SPINE_TCM_BUILD_SHARED) +typedef struct spine_tcm_handle { + void * module_handle; + int use_global_scope; + int owns_module_handle; + const char * (*runtime_version)(void); + int (*runtime_is_available)(void); + int (*runtime_layout_info)(spine_tcm_mem_info_t * info); + int (*runtime_mem_info)(int id, spine_tcm_block_info_t * info); + void * (*runtime_mem_get)(int id); + int (*runtime_mem_free)(int id); + void * (*runtime_mem_try_wait)(int id, size_t over_time); + int (*runtime_mem_release)(int id); + int (*runtime_mem_force_release)(int id); + int (*runtime_mem_query)(int id); +} spine_tcm_handle_t; + +static inline spine_tcm_handle_t * spine_tcm_default_handle(void) { + static spine_tcm_handle_t handle = { 0 }; + return &handle; +} + +static inline void spine_tcm_handle_reset(spine_tcm_handle_t * handle) { + if (handle != NULL) { + memset(handle, 0, sizeof(*handle)); + } +} + +static inline int spine_tcm_handle_bind(spine_tcm_handle_t * handle) { + void * symbol_scope = handle->use_global_scope ? RTLD_DEFAULT : handle->module_handle; + + handle->runtime_version = (const char * (*) (void) ) dlsym(symbol_scope, "spine_tcm_runtime_version"); + handle->runtime_is_available = (int (*)(void)) dlsym(symbol_scope, "spine_tcm_runtime_is_available"); + handle->runtime_layout_info = + (int (*)(spine_tcm_mem_info_t *)) dlsym(symbol_scope, "spine_tcm_runtime_layout_info"); + handle->runtime_mem_info = + (int (*)(int, spine_tcm_block_info_t *)) dlsym(symbol_scope, "spine_tcm_runtime_mem_info"); + handle->runtime_mem_get = (void * (*) (int) ) dlsym(symbol_scope, "spine_tcm_runtime_mem_get"); + handle->runtime_mem_free = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_free"); + handle->runtime_mem_try_wait = (void * (*) (int, size_t)) dlsym(symbol_scope, "spine_tcm_runtime_mem_try_wait"); + handle->runtime_mem_release = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_release"); + handle->runtime_mem_force_release = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_force_release"); + handle->runtime_mem_query = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_query"); + + return handle->runtime_version != NULL && handle->runtime_is_available != NULL && + handle->runtime_layout_info != NULL && handle->runtime_mem_info != NULL && + handle->runtime_mem_get != NULL && handle->runtime_mem_free != NULL && + handle->runtime_mem_try_wait != NULL && handle->runtime_mem_release != NULL && + handle->runtime_mem_force_release != NULL && handle->runtime_mem_query != NULL ? + 0 : + -1; +} + +/* + * Try to bind against an already-loaded process-global spine_tcm instance. + * The shared library exports spine_tcm_runtime_marker only for this probe. + */ +static inline int spine_tcm_try_bind_global(spine_tcm_handle_t * handle) { + if (dlsym(RTLD_DEFAULT, "spine_tcm_runtime_marker") == NULL) { + return -1; + } + + handle->use_global_scope = 1; + return spine_tcm_handle_bind(handle); +} + +/* + * Optional pre-bind entry point. + * + * Behavior: + * - Reuses an already-loaded global spine_tcm instance when available. + * - Otherwise loads the shared library from so_path or the default soname. + * - Repeated calls are safe and return 0 after the first successful bind. + */ +static inline int spine_tcm_open_handle(const char * so_path) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + const char * library = (so_path != NULL && so_path[0] != '\0') ? so_path : "libspine_tcm.so"; + + if (resolved->module_handle != NULL || resolved->use_global_scope) { + return 0; + } + + if (spine_tcm_try_bind_global(resolved) == 0) { + return 0; + } + + spine_tcm_handle_reset(resolved); + + resolved->module_handle = dlopen(library, RTLD_LAZY | RTLD_GLOBAL); + resolved->owns_module_handle = resolved->module_handle != NULL ? 1 : 0; + + if (resolved->module_handle == NULL) { + spine_tcm_handle_reset(resolved); + return -1; + } + + if (spine_tcm_handle_bind(resolved) != 0) { + if (resolved->owns_module_handle) { + dlclose(resolved->module_handle); + } + spine_tcm_handle_reset(resolved); + return -1; + } + + return 0; +} + +/* Returns 1 when the runtime driver is available, otherwise 0. */ +static inline int spine_tcm_is_available(void) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_is_available == NULL) { + return 0; + } + + return resolved->runtime_is_available(); +} + +/* Returns runtime memory geometry and whether the current backend is fake TCM. */ +static inline int spine_tcm_mem_info(spine_tcm_mem_info_t * info) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_layout_info == NULL) { + return -1; + } + + return resolved->runtime_layout_info(info); +} + +static inline const char * spine_tcm_version(void) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_version == NULL) { + return "unknown"; + } + + return resolved->runtime_version(); +} + +/* Returns per-block runtime metadata for the given TCM id. */ +static inline int spine_tcm_block_info(int id, spine_tcm_block_info_t * info) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_info == NULL) { + return -1; + } + + return resolved->runtime_mem_info(id, info); +} + +/* Returns a cached buffer for the given TCM id, or NULL on failure. */ +static inline void * spine_tcm_mem_get(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + return NULL; + } + + if (resolved->runtime_mem_get == NULL) { + return NULL; + } + + return resolved->runtime_mem_get(id); +} + +/* Releases one reference acquired by spine_tcm_mem_get(id). */ +static inline int spine_tcm_mem_free(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_free == NULL) { + return -1; + } + + return resolved->runtime_mem_free(id); +} + +/* Waits for a TCM block handoff and returns the driver-owned buffer when available. */ +static inline void * spine_tcm_mem_try_wait(int id, size_t over_time) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + return NULL; + } + + if (resolved->runtime_mem_try_wait == NULL) { + return NULL; + } + + return resolved->runtime_mem_try_wait(id, over_time); +} + +/* Releases a buffer acquired by spine_tcm_mem_try_wait(id, over_time). */ +static inline int spine_tcm_mem_release(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_release == NULL) { + return -1; + } + + return resolved->runtime_mem_release(id); +} + +/* Forces a release for the given TCM id when the backend supports it. */ +static inline int spine_tcm_mem_force_release(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || + resolved->runtime_mem_force_release == NULL) { + return -1; + } + + return resolved->runtime_mem_force_release(id); +} + +/* Returns whether the given TCM id is currently acquired. */ +static inline int spine_tcm_mem_query(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_query == NULL) { + return -1; + } + + return resolved->runtime_mem_query(id); +} +#else +static inline const char * spine_tcm_version(void) { + return spine_tcm_runtime_version(); +} +#endif + +#define SPINE_TCM_VERSION (spine_tcm_version()) + +#ifdef __cplusplus +} +#endif + +#endif