diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt
index 9bdb4e83..4f321a25 100644
--- a/ggml/src/ggml-zendnn/CMakeLists.txt
+++ b/ggml/src/ggml-zendnn/CMakeLists.txt
@@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
     ExternalProject_Add(
         zendnn
         GIT_REPOSITORY https://github.com/amd/ZenDNN.git
-        GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf # ZenDNN-2026-WW08
+        GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13
         PREFIX ${ZENDNN_PREFIX}
         SOURCE_DIR ${ZENDNN_SOURCE_DIR}
         BINARY_DIR ${ZENDNN_BUILD_DIR}
diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp
index c8760304..37730372 100644
--- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp
+++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp
@@ -190,6 +190,170 @@ static void ggml_zendnn_compute_forward_mul_mat(
     }
 }
 
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+static void ggml_zendnn_compute_forward_mul_mat_id(
+        ggml_backend_zendnn_context * ctx,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0]; // expert weights
+    const ggml_tensor * src1 = dst->src[1]; // inputs
+    const ggml_tensor * ids  = dst->src[2]; // expert ids
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    // nothing to do when there are no tokens to process
+    if (ne2 == 0 || ne11 == 0) {
+        return;
+    }
+
+    ggml_type const vec_dot_type = src0->type;
+    ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(src0->type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne03 == 1);
+    GGML_ASSERT(ne13 == 1);
+    GGML_ASSERT(ne3  == 1);
+
+    // row groups
+    const int n_ids = ids->ne[0]; // n_expert_used
+    const int n_as  = ne02;       // n_experts
+
+    std::vector<int64_t> matrix_row_counts(n_as, 0);
+    std::vector<std::vector<mmid_row_mapping>> matrix_rows(n_as);
+
+    int64_t max_rows = 0;
+    // group rows by expert (preprocessing step)
+    for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+        for (int id = 0; id < n_ids; ++id) {
+            const int32_t i02 = *(const int32_t *)((const char *)ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
+
+            GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+            matrix_rows[i02].push_back({(int32_t) id, (int32_t) iid1});
+            matrix_row_counts[i02]++;
+            if (matrix_row_counts[i02] > max_rows) {
+                max_rows = matrix_row_counts[i02];
+            }
+        }
+    }
+
+    if (max_rows == 0) {
+        return; // no rows to process
+    }
+
+    const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+    // size for converting src1 rows to vec_dot_type if needed
+    const size_t nbw1 = row_size;
+    const size_t nbw2 = nbw1 * ne11;
+    const size_t nbw3 = nbw2 * ne12;
+    const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0;
+
+    // size for MoE gather/scatter buffers
+    const size_t wdata_cur_size = max_rows * row_size;
+    const size_t dst_cur_size   = max_rows * ggml_row_size(dst->type, ne01);
+
+    // allocate a single buffer for all needs
+    const size_t total_size = src1_conv_size + wdata_cur_size + dst_cur_size;
+    if (ctx->work_size < total_size) {
+        ctx->work_data.reset(new char[total_size]);
+        ctx->work_size = total_size;
+    }
+
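+    // layout of the work buffer (byte offsets, in allocation order):
+    //   [0, src1_conv_size)      src1 rows converted to vec_dot_type (only when conversion is needed)
+    //   [.., +wdata_cur_size)    input rows gathered for the expert currently being processed
+    //   [.., +dst_cur_size)      GEMM output rows, scattered back into dst afterwards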
+    // partition the buffer
+    char * work_data = ctx->work_data.get();
+    char * wdata_cur = work_data + src1_conv_size;
+    char * dst_cur   = wdata_cur + wdata_cur_size;
+
+    if (src1->type != vec_dot_type) {
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    const float * src1_f32 = (const float *)((const char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+                    void * src1_conv = work_data + i11*nbw1 + i12*nbw2 + i13*nbw3;
+                    from_float(src1_f32, src1_conv, ne10);
+                }
+            }
+        }
+    }
+
+    const void * wdata = src1->type == vec_dot_type ? src1->data : work_data;
+
+    // process each expert with a gather -> gemm -> scatter pattern
+    for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) {
+        const int64_t cne1 = matrix_row_counts[cur_a];
+
+        if (cne1 == 0) {
+            continue;
+        }
+
+        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
+
+        // gather the input rows for this expert into a contiguous block
+        #pragma omp parallel for num_threads(ctx->n_threads) schedule(static)
+        for (int64_t ir1 = 0; ir1 < cne1; ++ir1) {
+            const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1];
+            const int64_t id  = row_mapping.i1;
+            const int64_t i11 = id % ne11;
+            const int64_t i12 = row_mapping.i2;
+
+            std::memcpy(
+                wdata_cur + ir1 * row_size,
+                (const char *) wdata + (i11 + i12*ne11) * row_size,
+                row_size
+            );
+        }
+
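+        // the GEMM below treats the expert weight matrix as ne01 rows of
+        // length ne10 and the gathered block as cne1 rows of length ne10,
+        // producing cne1 output rows of ne01 values in dst_cur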
+        // batched gemm for all tokens routed to this expert
+        if (!ggml_zendnn_sgemm(ctx,
+                ne01,       // m
+                cne1,       // n
+                ne10,       // k
+                src0_cur,
+                ne00,       // lda
+                wdata_cur,
+                ne10,       // ldb
+                dst_cur,
+                ne01,       // ldc
+                src0->type,
+                vec_dot_type,
+                dst->type)) {
+            GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
+        }
+
+        // scatter the output rows back to their positions in dst
+        #pragma omp parallel for num_threads(ctx->n_threads) schedule(static)
+        for (int64_t ir1 = 0; ir1 < cne1; ++ir1) {
+            const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1];
+            const int64_t i1 = row_mapping.i1;
+            const int64_t i2 = row_mapping.i2;
+
+            std::memcpy(
+                (char *) dst->data + i1*nb1 + i2*nb2,
+                dst_cur + ir1 * ggml_row_size(dst->type, ne01),
+                ggml_row_size(dst->type, ne01)
+            );
+        }
+    }
+}
+
 // backend interface
 
 static const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) {
@@ -218,6 +382,9 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm
         case GGML_OP_MUL_MAT:
             ggml_zendnn_compute_forward_mul_mat(ctx, node);
             break;
+        case GGML_OP_MUL_MAT_ID:
+            ggml_zendnn_compute_forward_mul_mat_id(ctx, node);
+            break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -361,6 +528,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
             return true;
 
         case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
             {
                 const ggml_tensor * weights = op->src[0];
                 const ggml_tensor * inputs = op->src[1];
@@ -374,6 +542,17 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
                     ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
                     return false;
                 }
+                // MUL_MAT_ID performs best with a moderate number of experts, since each
+                // expert costs a separate gather + batched matmul + scatter pass. Future
+                // versions can use ZenDNN's grouped_gemm to scale to larger expert counts:
+                // https://github.com/amd/ZenDNN/blob/main/docs/operator/lowoha_group_gemm_operator.md
+                if (op->op == GGML_OP_MUL_MAT_ID) {
+                    const int64_t n_experts   = weights->ne[2];
+                    const int64_t max_experts = 32;
+                    if (n_experts > max_experts) {
+                        return false;
+                    }
+                }
                 switch (weights->type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_BF16: