whisper.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp

1028 lines
49 KiB
C++

#include "ggml-impl.h"
#include "ggml.h"
#include "ime_kernels.h"
#include "rvv_kernels.h"
#include <algorithm>
#include <cmath>
#include <stdexcept>
#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic)
# error "riscv v extension or v_intrinsic not enabled"
#else
# include <riscv_vector.h>
#endif
#if !defined(__riscv_zfh)
# error "riscv zfh extension not enabled"
#endif
#if defined(RISCV64_SPACEMIT_IME1)
#else
# error "RISCV64_SPACEMIT_IME1 not defined"
#endif
// clang-format off
#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Woverlength-strings"
#pragma GCC diagnostic ignored "-Wcast-qual"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif
// clang-format on
namespace spacemit_kernels {
#define QUANTIZEM4ROW_KERNEL \
"vmv.s.x v16, zero \n\t" \
"vfabs.v v8, v0 \n\t" \
"vfredmax.vs v16, v8, v16 \n\t" \
"vfmv.f.s f10, v16 \n\t" \
"fmul.s f10, f10, %[RMAXREC] \n\t" \
"fsw f10, (a1) \n\t" \
"fdiv.s f11, %[FONE], f10 \n\t" \
"vfmul.vf v16, v0, f11 \n\t" \
"vfcvt.x.f.v v16, v16 \n\t" \
"vsetvli t0, zero, e16, mf2 \n\t" \
"vnclip.wx v16, v16, zero \n\t" \
"vnclip.wx v17, v17, zero \n\t" \
"vnclip.wx v18, v18, zero \n\t" \
"vnclip.wx v19, v19, zero \n\t" \
"vnclip.wx v20, v20, zero \n\t" \
"vnclip.wx v21, v21, zero \n\t" \
"vnclip.wx v22, v22, zero \n\t" \
"vnclip.wx v23, v23, zero \n\t" \
"vsetvli t0, zero, e8, mf4 \n\t" \
"vnclip.wx v24, v16, zero \n\t" \
"vnclip.wx v25, v17, zero \n\t" \
"vnclip.wx v26, v18, zero \n\t" \
"vnclip.wx v27, v19, zero \n\t" \
"vnclip.wx v28, v20, zero \n\t" \
"vnclip.wx v29, v21, zero \n\t" \
"vnclip.wx v30, v22, zero \n\t" \
"vnclip.wx v31, v23, zero \n\t"
#define QUANTIZEM4ROW_STORE \
"addi t1, %[BlkLen], 0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v24, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v25, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v26, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v27, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v28, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v29, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v30, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v31, (s1) \n\t"
namespace ime1 {
void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) {
constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
const float fone = 1.0f;
for (size_t row_index = 0; row_index < 4; ++row_index) {
const float * SRC = A + row_index * CountK;
uint8_t * DST = QuantA + row_index * sizeof(float);
const size_t offset = (4 - row_index) * 4 + row_index * 8;
const size_t stride = 4 * (sizeof(float) + BlkLen);
__asm__ volatile(
"vsetvli t0, zero, e32, m8 \n\t"
"addi t2, %[CountK], 0 \n\t"
"addi a1, %[DST], 0 \n\t"
"blt t2, %[BlkLen], TAIL%= \n\t"
"LOOP%=: \n\t"
"vsetvli t0, %[BlkLen], e32, m8 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"sub t2, t2, t0 \n\t"
"slli t1, t0, 2 \n\t"
"add %[SRC], %[SRC], t1 \n\t"
"add s1, a1, %[OFFSET] \n\t"
QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE
"add a1, a1, %[STRIDE] \n\t"
"bge t2, %[BlkLen], LOOP%= \n\t"
"TAIL%=: \n\t"
"blez t2, QUIT%= \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v24, v24, v24 \n\t"
"vsetvli t0, t2, e32, m8 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"add s1, a1, %[OFFSET] \n\t"
QUANTIZEM4ROW_KERNEL
"addi t3, %[BlkLen], 0 \n\t"
"addi s2, s1, 0 \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vxor.vv v8, v8, v8 \n\t"
"SET_ZERO%=: \n\t"
"vse8.v v8, (s2) \n\t"
"addi s2, s2, 32 \n\t"
"addi t3, t3, -8 \n\t"
"bnez t3, SET_ZERO%= \n\t"
QUANTIZEM4ROW_STORE
"QUIT%=: \n\t"
: [SRC] "+r"(SRC)
: [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), [CountK] "r"(CountK),
[FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
: "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
}
}
void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) {
const float * SRC = A;
uint8_t * DST = QuantA;
constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
const float fone = 1.0f;
uint8_t * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;
__asm__ volatile(
"addi t3, zero, 32*4 \n\t"
"addi t2, zero, 32 \n\t"
"addi a1, %[SRC], 0 \n\t"
"addi a2, %[SRC], 128 \n\t"
"addi a3, %[SRC], 256 \n\t"
"addi a4, %[SRC], 384 \n\t"
"addi s1, %[DST], 0 \n\t"
"addi s2, %[DST], 36 \n\t"
"addi s3, %[DST], 72 \n\t"
"addi s4, %[DST], 108 \n\t"
"blt %[K], t3, LOOP_K%= \n\t"
"blt %[K], t2, TAIL%= \n\t"
"LOOP_MAIN%=: \n\t"
"vsetvli t1, zero, e32, m4 \n\t"
"addi %[K], %[K], -128 \n\t"
"vle32.v v0, (a1) \n\t"
"addi a1, a1, 512 \n\t"
"vle32.v v4, (a2) \n\t"
"addi a2, a2, 512 \n\t"
"vle32.v v8, (a3) \n\t"
"addi a3, a3, 512 \n\t"
"vle32.v v12, (a4) \n\t"
"addi a4, a4, 512 \n\t"
"vfabs.v v16, v0 \n\t"
"vfabs.v v20, v4 \n\t"
"vfabs.v v24, v8 \n\t"
"vfabs.v v28, v12 \n\t"
"vsetvli t0, zero, e32, m2 \n\t"
"vfmax.vv v16, v16, v18 \n\t"
"vfmax.vv v20, v20, v22 \n\t"
"vfmax.vv v24, v24, v26 \n\t"
"vfmax.vv v28, v28, v30 \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vfmax.vv v16, v16, v17 \n\t"
"vfmax.vv v20, v20, v21 \n\t"
"vfmax.vv v24, v24, v25 \n\t"
"vfmax.vv v28, v28, v29 \n\t"
"vfredmax.vs v17, v16, v17 \n\t"
"vfredmax.vs v21, v20, v21 \n\t"
"vfredmax.vs v25, v24, v25 \n\t"
"vfredmax.vs v29, v28, v29 \n\t"
"vfmv.f.s f10, v17 \n\t"
"vfmv.f.s f11, v21 \n\t"
"vfmv.f.s f12, v25 \n\t"
"vfmv.f.s f13, v29 \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fmul.s f11, f11, %[RMAXREC] \n\t"
"fmul.s f12, f12, %[RMAXREC] \n\t"
"fmul.s f13, f13, %[RMAXREC] \n\t"
"fsw f10, (s1) \n\t"
"addi s1, s1, 4 \n\t"
"fsw f11, (s2) \n\t"
"addi s2, s2, 4 \n\t"
"fsw f12, (s3) \n\t"
"addi s3, s3, 4 \n\t"
"fsw f13, (s4) \n\t"
"addi s4, s4, 4 \n\t"
"fdiv.s f10, %[FONE], f10 \n\t"
"fdiv.s f11, %[FONE], f11 \n\t"
"fdiv.s f12, %[FONE], f12 \n\t"
"fdiv.s f13, %[FONE], f13 \n\t"
"vsetvli t0, zero, e32, m4 \n\t"
"vfmul.vf v16, v0, f10 \n\t"
"vfmul.vf v20, v4, f11 \n\t"
"vfmul.vf v24, v8, f12 \n\t"
"vfmul.vf v28, v12, f13 \n\t"
"vfcvt.x.f.v v16, v16 \n\t"
"vfcvt.x.f.v v20, v20 \n\t"
"vfcvt.x.f.v v24, v24 \n\t"
"vfcvt.x.f.v v28, v28 \n\t"
"vsetvli t0, zero, e16, m2 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vnclip.wx v20, v20, zero \n\t"
"vnclip.wx v24, v24, zero \n\t"
"vnclip.wx v28, v28, zero \n\t"
"vsetvli t0, t1, e8, m1 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vnclip.wx v20, v20, zero \n\t"
"vnclip.wx v24, v24, zero \n\t"
"vnclip.wx v28, v28, zero \n\t"
"vse8.v v16, (s1) \n\t"
"addi s1, s1, 140 \n\t"
"vse8.v v20, (s2) \n\t"
"addi s2, s2, 140 \n\t"
"vse8.v v24, (s3) \n\t"
"addi s3, s3, 140 \n\t"
"vse8.v v28, (s4) \n\t"
"addi s4, s4, 140 \n\t"
"bge %[K], t3, LOOP_MAIN%= \n\t"
"blt %[K], t2, TAIL%= \n\t"
"LOOP_K%=: \n\t"
"vsetvli t1, %[K], e32, m4 \n\t"
"vle32.v v0, (a1) \n\t"
"addi a1, a1, 128 \n\t"
"sub %[K], %[K], t1 \n\t"
"vfabs.v v16, v0 \n\t"
"vsetvli t0, zero, e32, m2 \n\t"
"vfmax.vv v16, v16, v18 \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vfmax.vv v16, v16, v17 \n\t"
"vfredmax.vs v17, v16, v17 \n\t"
"vfmv.f.s f10, v17 \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fsw f10, (s1) \n\t"
"addi s1, s1, 4 \n\t"
"fdiv.s f11, %[FONE], f10 \n\t"
"vsetvli t0, zero, e32, m4 \n\t"
"vfmul.vf v16, v0, f11 \n\t"
"vfcvt.x.f.v v16, v16 \n\t"
"vsetvli t0, zero, e16, m2 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vse8.v v16, (s1) \n\t"
"addi s1, s1, 32 \n\t"
"bge %[K], t2, LOOP_K%= \n\t"
"TAIL%=: \n\t"
"blez %[K], END%= \n\t"
"vsetvli t0, t3, e32, m4 \n\t"
"vxor.vv v0, v0, v0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"jal x0, LOOP_K%= \n\t"
"END%=: \n\t"
: [K] "+r"(CountK)
: [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
: "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
}
} // namespace ime1
namespace {
#define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 \
"vmadot v16, v14, v0 \n\t" \
"vmadot v18, v14, v1 \n\t" \
"vmadot v20, v14, v2 \n\t" \
"vmadot v22, v14, v3 \n\t" \
"vmadot v16, v15, v4 \n\t" \
"vmadot v18, v15, v5 \n\t" \
"vmadot v20, v15, v6 \n\t" \
"vmadot v22, v15, v7 \n\t"
#define SQ4BIT_KERNEL_ACC_1X4X4 \
"vfcvt.f.x.v v16, v16 \n\t" \
"vfcvt.f.x.v v18, v18 \n\t" \
"vfcvt.f.x.v v20, v20 \n\t" \
"vfcvt.f.x.v v22, v22 \n\t" \
"addi s2, s1, 16 \n\t" \
"addi s3, s1, 32 \n\t" \
"addi s4, s1, 48 \n\t" \
"addi s6, s5, 12 \n\t" \
"vfmacc.vv v28, v16, v24 \n\t" \
"vfmacc.vv v29, v18, v25 \n\t" \
"vfmacc.vv v30, v20, v26 \n\t" \
"vfmacc.vv v31, v22, v27 \n\t"
#define SQ4BIT_KERNEL_ACC_F16_1X4X4 \
"vfcvt.f.x.v v16, v16 \n\t" \
"vfcvt.f.x.v v18, v18 \n\t" \
"vfcvt.f.x.v v20, v20 \n\t" \
"vfcvt.f.x.v v22, v22 \n\t" \
"addi s2, s1, 8 \n\t" \
"addi s3, s1, 16 \n\t" \
"addi s4, s1, 24 \n\t" \
"addi s6, s5, 12 \n\t" \
"vfmacc.vv v28, v16, v24 \n\t" \
"vfmacc.vv v29, v18, v25 \n\t" \
"vfmacc.vv v30, v20, v26 \n\t" \
"vfmacc.vv v31, v22, v27 \n\t"
#define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 \
"vle8.v v4, (s1) \n\t" \
"addi s1, s1, 128 \n\t" \
"vle8.v v5, (s2) \n\t" \
"addi s2, s2, 128 \n\t" \
"vle8.v v6, (s3) \n\t" \
"addi s3, s3, 128 \n\t" \
"vle8.v v7, (s4) \n\t" \
"addi s4, s4, 128 \n\t" \
"vsetvli t0, zero, e8, mf4 \n\t" \
"vle8.v v14, (s5) \n\t" \
"addi s5, s5, 16 \n\t" \
"vle8.v v15, (s6) \n\t" \
"addi s6, s6, 16 \n\t" \
"addi t5, t5, -1 \n\t" \
"vsetvli t0, zero, e8, m1 \n\t" \
"vand.vi v0, v4, 15 \n\t" \
"vand.vi v1, v5, 15 \n\t" \
"vand.vi v2, v6, 15 \n\t" \
"vand.vi v3, v7, 15 \n\t" \
"vsrl.vi v4, v4, 4 \n\t" \
"vsrl.vi v5, v5, 4 \n\t" \
"vsrl.vi v6, v6, 4 \n\t" \
"vsrl.vi v7, v7, 4 \n\t"
#define SQ4BIT_KERNEL_LOAD_ZP_16X1 \
"vsetvli t0, zero, e8, mf2 \n\t" \
"vle8.v v1, (s7) \n\t" \
"vsetvli t0, zero, e8, m1 \n\t" \
"vrgather.vv v8, v1, v13 \n\t" \
"vadd.vi v13, v13, 4 \n\t" \
"vrgather.vv v9, v1, v13 \n\t" \
"vadd.vi v13, v13, 4 \n\t" \
"vrgather.vv v10, v1, v13 \n\t" \
"vadd.vi v13, v13, 4 \n\t" \
"vrgather.vv v11, v1, v13 \n\t" \
"vadd.vi v13, v13, -12 \n\t"
// using for M4Kernel
#define LOAD_B_16x8x2 \
"vsetvli t0, zero, e8, m1 \n\t" \
"vle8.v v6, (s1) \n\t" \
"addi s1, s1, 32*4 \n\t" \
"vle8.v v7, (s2) \n\t" \
"addi s2, s2, 32*4 \n\t" \
"vle8.v v8, (s3) \n\t" \
"addi s3, s3, 32*4 \n\t" \
"vle8.v v9, (s4) \n\t" \
"addi s4, s4, 32*4 \n\t" \
\
"vand.vi v2, v6, 15 \n\t" \
"vand.vi v3, v7, 15 \n\t" \
"vand.vi v4, v8, 15 \n\t" \
"vand.vi v5, v9, 15 \n\t" \
\
"vsrl.vi v6, v6, 4 \n\t" \
"vsrl.vi v7, v7, 4 \n\t" \
"vsrl.vi v8, v8, 4 \n\t" \
"vsrl.vi v9, v9, 4 \n\t"
// [s2|s5, s3, s4, s6]
#define LOAD_SCALE_4x16_FP16 \
"addi s2, s5, -8 \n\t" \
"addi s3, s5, 8 \n\t" \
"addi s4, s5, 16 \n\t" \
"addi s6, s5, 24 \n\t" \
"li t1, 0xf0 \n\t" \
"vmv.s.x v0, t1 \n\t" \
"vsetvli t0, zero, e16, mf4 \n\t" \
"vle16.v v9, (s5) \n\t" \
"vle16.v v11, (s3) \n\t" \
"vle16.v v13, (s4) \n\t" \
"vle16.v v15, (s6) \n\t" \
"vsetvli t0, zero, e16, mf2 \n\t" \
"vle16.v v9, (s2), v0.t \n\t" \
"vle16.v v11, (s5), v0.t \n\t" \
"vle16.v v13, (s3), v0.t \n\t" \
"vle16.v v15, (s4), v0.t \n\t" \
"vfwcvt.f.f.v v8, v9 \n\t" \
"vfwcvt.f.f.v v10, v11 \n\t" \
"vfwcvt.f.f.v v12, v13 \n\t" \
"vfwcvt.f.f.v v14, v15 \n\t" \
"vsetvli t0, zero, e32, m1 \n\t" \
"vmv.v.v v9, v8 \n\t" \
"vmv.v.v v11, v10 \n\t" \
"vmv.v.v v13, v12 \n\t" \
"vmv.v.v v15, v14 \n\t" \
"li t1, 0xf0 \n\t" \
"vmv.s.x v0, t1 \n\t" \
"vsetvli t0, zero, e32, mf2 \n\t" \
"vfmul.vf v8, v8, f1 \n\t" \
"vfmul.vf v10, v10, f1 \n\t" \
"vfmul.vf v12, v12, f1 \n\t" \
"vfmul.vf v14, v14, f1 \n\t" \
"vfmul.vf v9, v9, f3 \n\t" \
"vfmul.vf v11, v11, f3 \n\t" \
"vfmul.vf v13, v13, f3 \n\t" \
"vfmul.vf v15, v15, f3 \n\t" \
"vsetvli t0, zero, e32, m1 \n\t" \
"vfmul.vf v8, v8, f2, v0.t \n\t" \
"vfmul.vf v10, v10, f2, v0.t \n\t" \
"vfmul.vf v12, v12, f2, v0.t \n\t" \
"vfmul.vf v14, v14, f2, v0.t \n\t" \
"vfmul.vf v9, v9, f4, v0.t \n\t" \
"vfmul.vf v11, v11, f4, v0.t \n\t" \
"vfmul.vf v13, v13, f4, v0.t \n\t" \
"vfmul.vf v15, v15, f4, v0.t \n\t"
// [s2|s5, s3, s4, s6]
#define LOAD_SCALE_4x16 \
"addi s2, s5, -16 \n\t" \
"addi s3, s5, 16 \n\t" \
"addi s4, s5, 32 \n\t" \
"addi s6, s5, 48 \n\t" \
"li t1, 0xf0 \n\t" \
"vmv.s.x v0, t1 \n\t" \
"vsetvli t0, zero, e32, mf2 \n\t" \
"vle32.v v8, (s5) \n\t" \
"vle32.v v10, (s3) \n\t" \
"vle32.v v12, (s4) \n\t" \
"vle32.v v14, (s6) \n\t" \
"vsetvli t0, zero, e32, m1 \n\t" \
"vle32.v v8, (s2), v0.t \n\t" \
"vle32.v v10, (s5), v0.t \n\t" \
"vle32.v v12, (s3), v0.t \n\t" \
"vle32.v v14, (s4), v0.t \n\t" \
"vmv.v.v v9, v8 \n\t" \
"vmv.v.v v11, v10 \n\t" \
"vmv.v.v v13, v12 \n\t" \
"vmv.v.v v15, v14 \n\t" \
"vsetvli t0, zero, e32, mf2 \n\t" \
"vfmul.vf v8, v8, f1 \n\t" \
"vfmul.vf v10, v10, f1 \n\t" \
"vfmul.vf v12, v12, f1 \n\t" \
"vfmul.vf v14, v14, f1 \n\t" \
"vfmul.vf v9, v9, f3 \n\t" \
"vfmul.vf v11, v11, f3 \n\t" \
"vfmul.vf v13, v13, f3 \n\t" \
"vfmul.vf v15, v15, f3 \n\t" \
"vsetvli t0, zero, e32, m1 \n\t" \
"vfmul.vf v8, v8, f2, v0.t \n\t" \
"vfmul.vf v10, v10, f2, v0.t \n\t" \
"vfmul.vf v12, v12, f2, v0.t \n\t" \
"vfmul.vf v14, v14, f2, v0.t \n\t" \
"vfmul.vf v9, v9, f4, v0.t \n\t" \
"vfmul.vf v11, v11, f4, v0.t \n\t" \
"vfmul.vf v13, v13, f4, v0.t \n\t" \
"vfmul.vf v15, v15, f4, v0.t \n\t"
//[s1| BIAS, s2, s3, s4]
#define LOAD_BIAS \
"vsetvli t0, zero, e32, mf2 \n\t" \
"li t1, 0xf0 \n\t" \
"vmv.s.x v0, t1 \n\t" \
"addi s1, %[BIAS], -16 \n\t" \
"addi s2, %[BIAS], 16 \n\t" \
"addi s3, %[BIAS], 32 \n\t" \
"addi s4, %[BIAS], 48 \n\t" \
\
"vle32.v v24, (%[BIAS]) \n\t" \
"vle32.v v26, (s2) \n\t" \
"vle32.v v28, (s3) \n\t" \
"vle32.v v30, (s4) \n\t" \
"vsetvli t0, zero, e32, m1 \n\t" \
"vle32.v v24, (s1), v0.t \n\t" \
"vle32.v v26, (%[BIAS]), v0.t \n\t" \
"vle32.v v28, (s2), v0.t \n\t" \
"vle32.v v30, (s3), v0.t \n\t" \
"vmv.v.v v25, v24 \n\t" \
"vmv.v.v v27, v26 \n\t" \
"vmv.v.v v29, v28 \n\t" \
"vmv.v.v v31, v30 \n\t"
#define SQ4BIT_KERNEL_COMP_4x16x16 \
"vmadot v16, v10, v2 \n\t" \
"vmadot v18, v10, v3 \n\t" \
"vmadot v20, v10, v4 \n\t" \
"vmadot v22, v10, v5 \n\t" \
"vmadot v16, v11, v6 \n\t" \
"vmadot v18, v11, v7 \n\t" \
"vmadot v20, v11, v8 \n\t" \
"vmadot v22, v11, v9 \n\t"
#define SAVE_RESULT_4x16 \
"addi a1, %[C], 0 \n\t" \
"add a2, %[C], %[LDC] \n\t" \
"add a3, a2, %[LDC] \n\t" \
"add a4, a3, %[LDC] \n\t" \
"addi a2, a2, -16 \n\t" \
"addi a4, a4, -16 \n\t" \
"li t1, 0xf0 \n\t" \
"vmv.s.x v0, t1 \n\t" \
"vsetvli t0, zero, e32, mf2 \n\t" \
\
"vse32.v v24, (a1) \n\t" \
"addi a1, a1, 16 \n\t" \
"vse32.v v25, (a3) \n\t" \
"addi a3, a3, 16 \n\t" \
\
"vse32.v v26, (a1) \n\t" \
"addi a1, a1, 16 \n\t" \
"vse32.v v27, (a3) \n\t" \
"addi a3, a3, 16 \n\t" \
\
"vse32.v v28, (a1) \n\t" \
"addi a1, a1, 16 \n\t" \
"vse32.v v29, (a3) \n\t" \
"addi a3, a3, 16 \n\t" \
\
"vse32.v v30, (a1) \n\t" \
"vse32.v v31, (a3) \n\t" \
"vsetvli t0, zero, e32, m1 \n\t" \
\
"vse32.v v24, (a2), v0.t \n\t" \
"addi a2, a2, 16 \n\t" \
"vse32.v v25, (a4), v0.t \n\t" \
"addi a4, a4, 16 \n\t" \
\
"vse32.v v26, (a2), v0.t \n\t" \
"addi a2, a2, 16 \n\t" \
"vse32.v v27, (a4), v0.t \n\t" \
"addi a4, a4, 16 \n\t" \
\
"vse32.v v28, (a2), v0.t \n\t" \
"addi a2, a2, 16 \n\t" \
"vse32.v v29, (a4), v0.t \n\t" \
"addi a4, a4, 16 \n\t" \
\
"vse32.v v30, (a2), v0.t \n\t" \
"vse32.v v31, (a4), v0.t \n\t"
#define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 \
"vsetvli t0, zero, e8, mf2 \n\t" \
"vle8.v v11, (s6) \n\t" \
"vsetvli t0, zero, e8, m1 \n\t" \
"vrgather.vv v12, v11, v1 \n\t" \
"vadd.vi v1, v1, 4 \n\t" \
"vrgather.vv v13, v11, v1 \n\t" \
"vadd.vi v1, v1, 4 \n\t" \
"vrgather.vv v14, v11, v1 \n\t" \
"vadd.vi v1, v1, 4 \n\t" \
"vrgather.vv v15, v11, v1 \n\t" \
"vadd.vi v1, v1, -12 \n\t"
template <bool HasZeroPoint>
void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
const uint8_t * QuantA,
const uint8_t * QuantBData,
float * C,
size_t CountN,
size_t BlockCountK,
const size_t ldc) {
size_t LDC = ldc * sizeof(float);
const size_t INNER = BlkLen / 16;
float tmp[4 * 16];
if constexpr (HasZeroPoint) {
for (size_t n = 0; n < CountN; n += 16) {
size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + //
n * BlockCountK * BlkLen / 2 + // b data
n * BlockCountK * sizeof(uint8_t) + // zp
n * BlockCountK * sizeof(_Float16); // scale
float * CPtr = C + n;
if (NBLKS < 16) {
CPtr = tmp;
LDC = 16 * sizeof(float);
}
__asm__ volatile(
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v24, v24, v24 \n\t"
"addi t3, %[BlockCountK], 0 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"li s1, 24 \n\t"
"vmv.v.i v1, 3 \n\t"
"vsetvli t0, s1, e8, m1 \n\t"
"vmv.v.i v1, 2 \n\t"
"vsetvli t0, zero, e8, mf2 \n\t"
"vmv.v.i v1, 1 \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vmv.v.i v1, 0 \n\t"
"addi a1, %[A], 0 \n\t"
"addi s1, %[B], 0 \n\t"
"BLOCK_COUNTK_LOOP%=: \n\t"
// scale offset
"addi s5, s1, 0 \n\t"
// zp offset
"addi s6, s1, 32 \n\t"
"addi s1, s6, 16 \n\t"
"addi s2, s1, 32 \n\t"
"addi s3, s1, 32*2 \n\t"
"addi s4, s1, 32*3 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
// load a scale
"flw f1, (a1) \n\t"
"flw f2, 4(a1) \n\t"
"flw f3, 8(a1) \n\t"
"flw f4, 12(a1) \n\t"
"addi a1, a1, 16 \n\t"
"addi t2, %[INNER], 0 \n\t"
SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
"BLOCK_INNER_LOOP%=: \n\t"
LOAD_B_16x8x2
"vle8.v v10, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vle8.v v11, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vsub.vv v2, v2, v12 \n\t"
"vsub.vv v6, v6, v12 \n\t"
"vsub.vv v3, v3, v13 \n\t"
"vsub.vv v7, v7, v13 \n\t"
"vsub.vv v4, v4, v14 \n\t"
"vsub.vv v8, v8, v14 \n\t"
"vsub.vv v5, v5, v15 \n\t"
"vsub.vv v9, v9, v15 \n\t"
SQ4BIT_KERNEL_COMP_4x16x16
"addi t2, t2, -1 \n\t"
"bnez t2, BLOCK_INNER_LOOP%= \n\t"
LOAD_SCALE_4x16_FP16
"vsetvli t0, zero, e32, m8 \n\t"
"vfcvt.f.x.v v16, v16 \n\t"
"vfmacc.vv v24, v16, v8 \n\t"
"addi t3, t3, -1 \n\t"
"bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
"RESULT_SAVE%=: \n\t"
SAVE_RESULT_4x16
:
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
[BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
: "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4",
"s5", "s6");
}
} else {
for (size_t n = 0; n < CountN; n += 16) {
size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + //
n * BlockCountK * BlkLen / 2 + // b data
n * BlockCountK * sizeof(_Float16); // scale
float * CPtr = C + n;
if (NBLKS < 16) {
CPtr = tmp;
LDC = 16 * sizeof(float);
}
__asm__ volatile(
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v24, v24, v24 \n\t"
"addi t3, %[BlockCountK], 0 \n\t"
"addi a1, %[A], 0 \n\t"
"addi s1, %[B], 0 \n\t"
"BLOCK_COUNTK_LOOP%=: \n\t"
"addi s5, s1, 0 \n\t"
"addi s1, s5, 32 \n\t"
"addi s2, s1, 32 \n\t"
"addi s3, s1, 32*2 \n\t"
"addi s4, s1, 32*3 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
// load a scale
"flw f1, (a1) \n\t"
"flw f2, 4(a1) \n\t"
"flw f3, 8(a1) \n\t"
"flw f4, 12(a1) \n\t"
"addi a1, a1, 16 \n\t"
"addi t2, %[INNER], 0 \n\t"
"BLOCK_INNER_LOOP%=: \n\t"
LOAD_B_16x8x2
"vsetvli t0, zero, e8, m1 \n\t"
"vle8.v v10, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vle8.v v11, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vadd.vi v2, v2, -8 \n\t"
"vadd.vi v3, v3, -8 \n\t"
"vadd.vi v4, v4, -8 \n\t"
"vadd.vi v5, v5, -8 \n\t"
"vadd.vi v6, v6, -8 \n\t"
"vadd.vi v7, v7, -8 \n\t"
"vadd.vi v8, v8, -8 \n\t"
"vadd.vi v9, v9, -8 \n\t"
SQ4BIT_KERNEL_COMP_4x16x16
"addi t2, t2, -1 \n\t"
"bnez t2, BLOCK_INNER_LOOP%= \n\t"
LOAD_SCALE_4x16_FP16
"vsetvli t0, zero, e32, m8 \n\t"
"vfcvt.f.x.v v16, v16 \n\t"
"vfmacc.vv v24, v16, v8 \n\t"
"addi t3, t3, -1 \n\t"
"bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
"RESULT_SAVE%=: \n\t"
SAVE_RESULT_4x16
:
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
[BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
: "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4",
"s5", "s6");
}
}
}
template <bool HasZeroPoint>
void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
const uint8_t * QuantA,
const uint8_t * QuantBData,
float * C,
size_t CountN,
size_t BlockCountK,
const size_t ldc) {
GGML_UNUSED(ldc);
size_t INNER = BlkLen / 16;
if constexpr (HasZeroPoint) {
for (size_t n = 0; n < CountN; n += 16) {
size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + //
n * BlockCountK * BlkLen / 2 + // b data
n * BlockCountK * sizeof(uint8_t) + // zp
n * BlockCountK * sizeof(_Float16); // scale
float * CPtr = C + n;
size_t cnt = BlockCountK;
__asm__ volatile(
"vsetvli t0, zero, e32, m4 \n\t"
"vxor.vv v28, v28, v28 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"vmv.v.i v13, 3 \n\t"
"li s1, 24 \n\t"
"vsetvli t0, s1, e8, m1 \n\t"
"vmv.v.i v13, 2 \n\t"
"vsetvli t0, zero, e8, mf2 \n\t"
"vmv.v.i v13, 1 \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vmv.v.i v13, 0 \n\t"
"addi s1, %[B], 0 \n\t"
"addi s2, %[B], 8 \n\t"
"addi s3, %[B], 16 \n\t"
"addi s4, %[B], 24 \n\t"
"addi s7, %[B], 32 \n\t"
"addi s5, %[A], 0 \n\t"
"addi s6, %[A], 12 \n\t"
"LOOP_K%=: \n\t"
"vsetvli t0, zero, e16, mf4 \n\t"
"vle16.v v4, (s1) \n\t"
"addi s1, s1, 48 \n\t"
"vle16.v v5, (s2) \n\t"
"addi s2, s2, 72 \n\t"
"vle16.v v6, (s3) \n\t"
"addi s3, s3, 96 \n\t"
"vle16.v v7, (s4) \n\t"
"addi s4, s4, 120 \n\t"
"flw f1, (s5) \n\t"
"addi s5, s5, 4 \n\t"
"vfwcvt.f.f.v v8, v4 \n\t"
"vfwcvt.f.f.v v9, v5 \n\t"
"vfwcvt.f.f.v v10, v6 \n\t"
"vfwcvt.f.f.v v11, v7 \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
"addi t5, %[INNER], 0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v18, v18, v18 \n\t"
"vxor.vv v20, v20, v20 \n\t"
"vxor.vv v22, v22, v22 \n\t"
"vfmul.vf v24, v8, f1 \n\t"
"vfmul.vf v25, v9, f1 \n\t"
"vfmul.vf v26, v10, f1 \n\t"
"vfmul.vf v27, v11, f1 \n\t"
"addi %[CNT], %[CNT], -1 \n\t"
SQ4BIT_KERNEL_LOAD_ZP_16X1
"LOOP_INNER%=: \n\t"
SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
"vsub.vv v0, v0, v8 \n\t"
"vsub.vv v4, v4, v8 \n\t"
"vsub.vv v1, v1, v9 \n\t"
"vsub.vv v5, v5, v9 \n\t"
"vsub.vv v2, v2, v10 \n\t"
"vsub.vv v6, v6, v10 \n\t"
"vsub.vv v3, v3, v11 \n\t"
"vsub.vv v7, v7, v11 \n\t"
SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
"bnez t5, LOOP_INNER%= \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
SQ4BIT_KERNEL_ACC_F16_1X4X4
"addi s7, s1, 32 \n\t"
"bnez %[CNT], LOOP_K%= \n\t"
"addi t3, zero, 16 \n\t"
"addi s1, %[C], 16 \n\t"
"addi s2, %[C], 32 \n\t"
"addi s3, %[C], 48 \n\t"
"blt %[NBLKS], t3, ST_TAIL%= \n\t"
"vse32.v v28, (%[C]) \n\t"
"vse32.v v29, (s1) \n\t"
"vse32.v v30, (s2) \n\t"
"vse32.v v31, (s3) \n\t"
"jal x0, END%= \n\t"
"ST_TAIL%=: \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v28, (%[C]) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v29, (s1) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v30, (s2) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v31, (s3) \n\t"
"END%=: \n\t"
: [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
: "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
}
} else {
for (size_t n = 0; n < CountN; n += 16) {
size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + //
n * BlockCountK * BlkLen / 2 + // b data
n * BlockCountK * sizeof(_Float16); // scale
float * CPtr = C + n;
size_t cnt = BlockCountK;
__asm__ volatile(
"vsetvli t0, zero, e32, m4 \n\t"
"vxor.vv v28, v28, v28 \n\t"
"addi s1, %[B], 0 \n\t"
"addi s2, %[B], 8 \n\t"
"addi s3, %[B], 16 \n\t"
"addi s4, %[B], 24 \n\t"
"addi s5, %[A], 0 \n\t"
"addi s6, %[A], 12 \n\t"
"LOOP_K%=: \n\t"
"vsetvli t0, zero, e16, mf4 \n\t"
"vle16.v v4, (s1) \n\t"
"addi s1, s1, 32 \n\t"
"vle16.v v5, (s2) \n\t"
"addi s2, s2, 56 \n\t"
"vle16.v v6, (s3) \n\t"
"addi s3, s3, 80 \n\t"
"vle16.v v7, (s4) \n\t"
"addi s4, s4, 104 \n\t"
"flw f1, (s5) \n\t"
"addi s5, s5, 4 \n\t"
"vfwcvt.f.f.v v8, v4 \n\t"
"vfwcvt.f.f.v v9, v5 \n\t"
"vfwcvt.f.f.v v10, v6 \n\t"
"vfwcvt.f.f.v v11, v7 \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
"addi t5, %[INNER], 0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v18, v18, v18 \n\t"
"vxor.vv v20, v20, v20 \n\t"
"vxor.vv v22, v22, v22 \n\t"
"vfmul.vf v24, v8, f1 \n\t"
"vfmul.vf v25, v9, f1 \n\t"
"vfmul.vf v26, v10, f1 \n\t"
"vfmul.vf v27, v11, f1 \n\t"
"addi %[CNT], %[CNT], -1 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"LOOP_INNER%=: \n\t"
SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
"vadd.vi v0, v0, -8 \n\t"
"vadd.vi v1, v1, -8 \n\t"
"vadd.vi v2, v2, -8 \n\t"
"vadd.vi v3, v3, -8 \n\t"
"vadd.vi v4, v4, -8 \n\t"
"vadd.vi v5, v5, -8 \n\t"
"vadd.vi v6, v6, -8 \n\t"
"vadd.vi v7, v7, -8 \n\t"
SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
"bnez t5, LOOP_INNER%= \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
SQ4BIT_KERNEL_ACC_F16_1X4X4
"bnez %[CNT], LOOP_K%= \n\t"
"addi t3, zero, 16 \n\t"
"addi s1, %[C], 16 \n\t"
"addi s2, %[C], 32 \n\t"
"addi s3, %[C], 48 \n\t"
"blt %[NBLKS], t3, ST_TAIL%= \n\t"
"vse32.v v28, (%[C]) \n\t"
"vse32.v v29, (s1) \n\t"
"vse32.v v30, (s2) \n\t"
"vse32.v v31, (s3) \n\t"
"jal x0, END%= \n\t"
"ST_TAIL%=: \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v28, (%[C]) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v29, (s1) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v30, (s2) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v31, (s3) \n\t"
"END%=: \n\t"
: [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
: "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
}
}
}
} // namespace
namespace ime1 {
size_t gemm_kernel_i8i4(size_t blk_len,
const uint8_t * quant_a_ptr,
const uint8_t * quant_b_data,
const uint8_t * quant_b_zp,
float * c_ptr,
size_t count_m,
size_t count_n,
size_t k_blks,
size_t ldc) {
if (count_m >= 4) {
if (quant_b_zp != nullptr) {
SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<true>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks,
ldc);
} else {
SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<false>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n,
k_blks, ldc);
}
return 4;
} else {
if (quant_b_zp != nullptr) {
SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<true>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks,
ldc);
} else {
SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<false>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n,
k_blks, ldc);
}
return 1;
}
}
} // namespace ime1
} // namespace spacemit_kernels