ggml-zendnn : add Q8_0 quantization support (llama/23414)
* ggml-zendnn : add Q8_0 quantization support
* ggml-zendnn : sync with latest ZenDNN
* ggml-zendnn : address review comments for Q8_0
This commit is contained in:
parent
ec183556c6
commit
2d629533a5
|
|
@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
|
|||
ExternalProject_Add(
|
||||
zendnn
|
||||
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
|
||||
GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17
|
||||
GIT_TAG 253b94ce0d7e9284c265fefb485714944caff9d3 # ZenDNN-2026-WW19
|
||||
PREFIX ${ZENDNN_PREFIX}
|
||||
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
|
||||
BINARY_DIR ${ZENDNN_BUILD_DIR}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,10 @@
|
|||
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#define GGML_COMMON_DECL_CPP
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include "zendnnl.hpp"
|
||||
|
||||
#include <cstring>
|
||||
|
|
@ -19,6 +23,8 @@ zendnnl::common::data_type_t ggml_to_zendnn_type() {
|
|||
return zendnnl::common::data_type_t::f32;
|
||||
} else if constexpr (std::is_same_v<T, ggml_bf16_t>) {
|
||||
return zendnnl::common::data_type_t::bf16;
|
||||
} else if constexpr (std::is_same_v<T, block_q8_0>) {
|
||||
return zendnnl::common::data_type_t::s8;
|
||||
} else {
|
||||
return zendnnl::common::data_type_t::none;
|
||||
}
|
||||
|
|
@ -48,6 +54,17 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
|
|||
params.num_threads = ctx->n_threads;
|
||||
|
||||
zendnnl::lowoha::matmul::matmul_batch_params_t batch_params;
|
||||
|
||||
if constexpr (std::is_same_v<TA, block_q8_0>) {
|
||||
params.dtypes.compute = zendnnl::common::data_type_t::s8;
|
||||
const int64_t num_groups = k / QK8_0;
|
||||
params.dynamic_quant = true;
|
||||
params.quant_params.src_scale.buff = nullptr;
|
||||
params.quant_params.src_scale.dt = zendnnl::common::data_type_t::bf16;
|
||||
params.quant_params.src_scale.dims = {n, num_groups};
|
||||
params.packing.pack_format_b = 1;
|
||||
}
|
||||
|
||||
zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct(
|
||||
'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major)
|
||||
n, // M: rows of B and C
|
||||
|
|
@ -108,6 +125,14 @@ static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int6
|
|||
(const ggml_bf16_t *)B, ldb,
|
||||
(float *)C, ldc);
|
||||
return false;
|
||||
case GGML_TYPE_Q8_0:
|
||||
if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32)
|
||||
return false;
|
||||
return ggml_zendnn_matmul<block_q8_0, float, float>(
|
||||
ctx, m, n, k,
|
||||
(const block_q8_0 *)A, lda,
|
||||
(const float *)B, ldb,
|
||||
(float *)C, ldc);
|
||||
default:
|
||||
return false; // unsupported type
|
||||
}
|
||||
|
|
@ -145,7 +170,9 @@ static void ggml_zendnn_compute_forward_mul_mat(
|
|||
const int64_t r3 = ne13/ne03;
|
||||
|
||||
void * work_data = ctx->work_data.get();
|
||||
if (src1->type != vec_dot_type) {
|
||||
|
||||
// ZenDNN requires FP32 for dynamic quantization, so conversion is skipped
|
||||
if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) {
|
||||
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
|
||||
const size_t nbw2 = nbw1 * ne11;
|
||||
const size_t nbw3 = nbw2 * ne12;
|
||||
|
|
@ -171,7 +198,7 @@ static void ggml_zendnn_compute_forward_mul_mat(
|
|||
|
||||
for (int64_t i13 = 0; i13 < ne13; i13++) {
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||
const void* wdata = src1->type == vec_dot_type ? src1->data : work_data;
|
||||
const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data;
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
if (!ggml_zendnn_sgemm(ctx,
|
||||
ne01, // m
|
||||
|
|
@ -184,7 +211,7 @@ static void ggml_zendnn_compute_forward_mul_mat(
|
|||
static_cast<char *>(dst->data) + i12*nb2 + i13*nb3,
|
||||
ne01, // ldc
|
||||
src0->type,
|
||||
vec_dot_type,
|
||||
src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type,
|
||||
dst->type))
|
||||
GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
|
||||
}
|
||||
|
|
@ -261,10 +288,15 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
|
|||
const size_t nbw1 = row_size;
|
||||
const size_t nbw2 = nbw1 * ne11;
|
||||
const size_t nbw3 = nbw2 * ne12;
|
||||
const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0;
|
||||
const size_t src1_conv_size = (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) ? ne13 * nbw3 : 0;
|
||||
|
||||
// For Q8_0, src1 is always F32; the gather buffer must hold F32 rows (ne10*4 bytes),
|
||||
// not Q8_0-encoded rows (row_size ≈ ne10/32*34 bytes) — they differ by ~4x.
|
||||
const size_t f32_row_size = (size_t)ne10 * sizeof(float);
|
||||
const size_t gather_row_size = (src0->type == GGML_TYPE_Q8_0) ? f32_row_size : row_size;
|
||||
|
||||
// size for MoE gather/scatter buffers
|
||||
const size_t wdata_cur_size = max_rows * row_size;
|
||||
const size_t wdata_cur_size = max_rows * gather_row_size;
|
||||
const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01);
|
||||
|
||||
// allocate single buffer for all needs
|
||||
|
|
@ -279,7 +311,8 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
|
|||
char * wdata_cur = work_data + src1_conv_size;
|
||||
char * dst_cur = wdata_cur + wdata_cur_size;
|
||||
|
||||
if (src1->type != vec_dot_type) {
|
||||
// ZenDNN requires FP32 for dynamic quantization, so conversion is skipped
|
||||
if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) {
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
#pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
|
||||
|
|
@ -294,7 +327,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
|
|||
}
|
||||
}
|
||||
|
||||
const void * wdata = src1->type == vec_dot_type ? src1->data : work_data;
|
||||
const void * wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data;
|
||||
|
||||
// process each expert with gather -> gemm -> scatter pattern
|
||||
for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) {
|
||||
|
|
@ -315,9 +348,9 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
|
|||
const int64_t i12 = row_mapping.i2;
|
||||
|
||||
std::memcpy(
|
||||
wdata_cur + ir1 * row_size,
|
||||
(const char *) wdata + (i11 + i12*ne11) * row_size,
|
||||
row_size
|
||||
wdata_cur + ir1 * gather_row_size,
|
||||
(const char *) wdata + (i11 + i12*ne11) * gather_row_size,
|
||||
gather_row_size
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -333,7 +366,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
|
|||
dst_cur,
|
||||
ne01, // ldc
|
||||
src0->type,
|
||||
vec_dot_type,
|
||||
src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type,
|
||||
dst->type)) {
|
||||
GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
|
||||
}
|
||||
|
|
@ -577,6 +610,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
|
|||
switch (weights->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q8_0:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
|
|
|||
Loading…
Reference in New Issue