From 0c2fbd4703a7a64a71dc07e60a17d89dac81d57b Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Mon, 6 Apr 2026 11:55:21 -0700 Subject: [PATCH] ggml: add Q1_0 1-bit quantization support (CPU) (llama/21273) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml: add Q1_0 and Q1_0_g128 1-bit quantization support (CPU) * add generic fallback for x86 * remove Q1_0 (group size 32) * rename Q1_0_g128 => Q1_0 * fix Q1_0 LlamaFileType Enum * Fix trailing spaces; add generic fallback for othre backends * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret * fix /r/n spacing + arch-fallback --------- Co-authored-by: Sigbjørn Skjæret --- ggml/include/ggml.h | 4 +- ggml/src/ggml-common.h | 11 +++ ggml/src/ggml-cpu/arch-fallback.h | 7 ++ ggml/src/ggml-cpu/arch/arm/quants.c | 103 ++++++++++++++++++++++ ggml/src/ggml-cpu/arch/loongarch/quants.c | 1 - ggml/src/ggml-cpu/arch/powerpc/quants.c | 1 - ggml/src/ggml-cpu/arch/s390/quants.c | 1 - ggml/src/ggml-cpu/arch/wasm/quants.c | 1 - ggml/src/ggml-cpu/ggml-cpu.c | 6 ++ ggml/src/ggml-cpu/ops.cpp | 2 + ggml/src/ggml-cpu/quants.c | 49 ++++++++++ ggml/src/ggml-cpu/quants.h | 3 + ggml/src/ggml-quants.c | 75 ++++++++++++++++ ggml/src/ggml-quants.h | 3 + ggml/src/ggml.c | 10 +++ 15 files changed, 272 insertions(+), 5 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 669f66b6..3bb2faa2 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -428,7 +428,8 @@ extern "C" { // GGML_TYPE_IQ4_NL_8_8 = 38, GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) - GGML_TYPE_COUNT = 41, + GGML_TYPE_Q1_0 = 41, + GGML_TYPE_COUNT = 42, }; // precision @@ -465,6 +466,7 @@ extern "C" { GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors + GGML_FTYPE_MOSTLY_Q1_0 = 27, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 92cf739e..f05683b4 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -93,6 +93,10 @@ typedef sycl::half2 ggml_half2; // QR = QK / number of values before dequantization // QI = number of 32 bit integers before dequantization +#define QI1_0 (QK1_0 / 32) +#define QR1_0 1 + + #define QI4_0 (QK4_0 / (4 * QR4_0)) #define QR4_0 2 @@ -170,6 +174,13 @@ typedef sycl::half2 ggml_half2; #define GGML_EXTENSION __extension__ #endif // _MSC_VER +#define QK1_0 128 +typedef struct { + ggml_half d; // delta + uint8_t qs[QK1_0 / 8]; // bits / quants +} block_q1_0; +static_assert(sizeof(block_q1_0) == sizeof(ggml_half) + QK1_0 / 8, "wrong q1_0 block size/padding"); + #define QK4_0 32 typedef struct { ggml_half d; // delta diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 41da8293..c589a213 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -16,6 +16,7 @@ #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -82,6 +83,7 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 @@ -112,6 +114,7 @@ // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K @@ -160,6 +163,7 @@ #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -200,6 +204,7 @@ #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -240,6 +245,7 @@ // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -303,6 +309,7 @@ #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 82b048bb..e09db59c 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -137,6 +137,109 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in //===================================== Dot products ================================= +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; // 128 + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + float sumf = 0.0f; + +#if defined(__ARM_NEON) + float32x4_t sumv = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); + + // Process 4 Q8_0 blocks (each has 32 elements) + for (int k = 0; k < 4; k++) { + const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k]; + const float d1 = GGML_CPU_FP16_TO_FP32(yb->d); + + // Get the 4 bytes of bits for this Q8_0 block (32 bits = 4 bytes) + // Bits are at offset k*4 bytes in x[i].qs + const uint8_t * bits = &x[i].qs[k * 4]; + + // Load 32 int8 values from y + const int8x16_t y0 = vld1q_s8(yb->qs); + const int8x16_t y1 = vld1q_s8(yb->qs + 16); + + // Byte 0-1: bits for y0[0..15] + const uint64_t expand0 = table_b2b_0[bits[0]]; + const uint64_t expand1 = table_b2b_0[bits[1]]; + // Byte 2-3: bits for y1[0..15] + const uint64_t expand2 = table_b2b_0[bits[2]]; + const uint64_t expand3 = table_b2b_0[bits[3]]; + + // Build the sign vectors by reinterpreting the table values + uint8x8_t e0 = vcreate_u8(expand0); + uint8x8_t e1 = vcreate_u8(expand1); + uint8x8_t e2 = vcreate_u8(expand2); + uint8x8_t e3 = vcreate_u8(expand3); + + // Shift right by 4 to get 0 or 1 + int8x8_t s0 = vreinterpret_s8_u8(vshr_n_u8(e0, 4)); + int8x8_t s1 = vreinterpret_s8_u8(vshr_n_u8(e1, 4)); + int8x8_t s2 = vreinterpret_s8_u8(vshr_n_u8(e2, 4)); + int8x8_t s3 = vreinterpret_s8_u8(vshr_n_u8(e3, 4)); + + // Convert 0/1 to -1/+1: sign = 2*val - 1 + int8x8_t one = vdup_n_s8(1); + s0 = vsub_s8(vadd_s8(s0, s0), one); // 2*s0 - 1 + s1 = vsub_s8(vadd_s8(s1, s1), one); + s2 = vsub_s8(vadd_s8(s2, s2), one); + s3 = vsub_s8(vadd_s8(s3, s3), one); + + // Combine into 16-element vectors + int8x16_t signs0 = vcombine_s8(s0, s1); + int8x16_t signs1 = vcombine_s8(s2, s3); + + // Multiply signs with y values and accumulate + // dot(signs, y) where signs are +1/-1 + int32x4_t p0 = ggml_vdotq_s32(vdupq_n_s32(0), signs0, y0); + int32x4_t p1 = ggml_vdotq_s32(p0, signs1, y1); + + // Scale by d1 and accumulate + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(p1), d0 * d1); + } + } + + sumf = vaddvq_f32(sumv); +#else + // Scalar fallback + for (int i = 0; i < nb; i++) { + const float d0 = GGML_FP16_TO_FP32(x[i].d); + + // Process 4 Q8_0 blocks + for (int k = 0; k < 4; k++) { + const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); + + int sumi = 0; + for (int j = 0; j < QK8_0; j++) { + const int bit_index = k * QK8_0 + j; + const int byte_index = bit_index / 8; + const int bit_offset = bit_index % 8; + + const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; + sumi += xi * y[i*4 + k].qs[j]; + } + sumf += d0 * d1 * sumi; + } + } +#endif + + *s = sumf; +} + + void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index f531e916..74e0c086 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -2156,4 +2156,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c index d3dfd049..644c380c 100644 --- a/ggml/src/ggml-cpu/arch/powerpc/quants.c +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -2302,4 +2302,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index 34184ed8..50085757 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -1463,4 +1463,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index 74a359e6..648c6fca 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -1218,4 +1218,3 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7486acc2..2b3eb5b5 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -217,6 +217,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_F16, .nrows = 1, }, + [GGML_TYPE_Q1_0] = { + .from_float = quantize_row_q1_0, + .vec_dot = ggml_vec_dot_q1_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_Q4_0] = { .from_float = quantize_row_q4_0, .vec_dot = ggml_vec_dot_q4_0_q8_0, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 765ce07f..0b5d6c6d 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4829,6 +4829,7 @@ void ggml_compute_forward_get_rows( const ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5554,6 +5555,7 @@ void ggml_compute_forward_clamp( ggml_compute_forward_clamp_f16(params, dst); } break; case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 7ebbb9c6..f66127c2 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -22,6 +22,10 @@ #define UNUSED GGML_UNUSED +void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q1_0_ref(x, y, k); +} + void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { quantize_row_q4_0_ref(x, y, k); } @@ -116,6 +120,51 @@ void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI //===================================== Dot products ================================= +void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_FP16_TO_FP32(x[i].d); + + float sumi = 0.0f; + + for (int k = 0; k < 4; k++) { + const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); + + int sumi_block = 0; + + for (int j = 0; j < QK8_0; j++) { + const int bit_index = k * QK8_0 + j; + const int byte_index = bit_index / 8; + const int bit_offset = bit_index % 8; + + const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; + sumi_block += xi * y[i*4 + k].qs[j]; + } + + sumi += d1 * sumi_block; + } + + sumf += d0 * sumi; + } + + *s = sumf; +} + + void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 3584aaa4..d4bc87a1 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -12,6 +12,7 @@ extern "C" { #endif // Quantization +void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -36,6 +37,7 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dot product +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -68,6 +70,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 48695a61..15443aa5 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -32,6 +32,41 @@ static inline int best_index_int8(int n, const int8_t * val, float x) { return x - val[mu-1] < val[mu] - x ? mu-1 : mu; } +// reference implementation for deterministic creation of model files +void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k) { + static const int qk = QK1_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float sum_abs = 0.0f; + for (int j = 0; j < qk; j++) { + sum_abs += fabsf(x[i*qk + j]); + } + const float d = sum_abs / qk; + + y[i].d = GGML_FP32_TO_FP16(d); + + // Clear all bits first + for (int j = 0; j < qk / 8; ++j) { + y[i].qs[j] = 0; + } + + // Just store sign of each weight directly (no normalization) + for (int j = 0; j < qk; ++j) { + const int bit_index = j; + const int byte_index = bit_index / 8; + const int bit_offset = bit_index % 8; + + if (x[i*qk + j] >= 0.0f) { + y[i].qs[byte_index] |= (1 << bit_offset); + } + } + } +} + // reference implementation for deterministic creation of model files void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -339,6 +374,26 @@ void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RE } } +void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK1_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float neg_d = -d; + + for (int j = 0; j < qk; ++j) { + const int byte_index = j / 8; + const int bit_offset = j % 8; + const uint8_t bit = (x[i].qs[byte_index] >> bit_offset) & 1; + y[i*qk + j] = bit ? d : neg_d; + } + } +} + void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -1978,6 +2033,22 @@ static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * G } } +size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q1_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q1_0, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q1_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q1_0_ref(src, (block_q1_0*)qrow, n_per_row); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + + size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row); @@ -5286,6 +5357,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte } } } break; + case GGML_TYPE_Q1_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q1_0, data, nb); + } break; case GGML_TYPE_Q4_0: { VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 00604f75..d56c86da 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -14,6 +14,7 @@ extern "C" { // NOTE: these functions are defined as GGML_API because they used by the CPU backend // Quantization +GGML_API void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); @@ -41,6 +42,7 @@ GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_ GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); // Dequantization +GGML_API void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -90,6 +92,7 @@ GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e9b6720c..0142498d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -651,6 +651,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row, }, + [GGML_TYPE_Q1_0] = { + .type_name = "q1_0", + .blck_size = QK1_0, + .type_size = sizeof(block_q1_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q1_0, + .from_float_ref = (ggml_from_float_t) quantize_row_q1_0_ref, + }, [GGML_TYPE_Q4_0] = { .type_name = "q4_0", .blck_size = QK4_0, @@ -1384,6 +1392,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; + case GGML_FTYPE_MOSTLY_Q1_0: wtype = GGML_TYPE_Q1_0; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; @@ -7652,6 +7661,7 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { + case GGML_TYPE_Q1_0: result = quantize_q1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;