ggml-zendnn : adaptive fallback to CPU backend for small batch sizes (llama/22681)

* ggml-zendnn : add runtime env var GGML_ZENDNN_ADAPTIVE_FALLBACK to control adaptive fallback (default: enabled)

* ggml-zendnn : restore original fallback logic when adaptive fallback is disabled
This commit is contained in:
Sachin Sharma 2026-05-13 11:43:47 +05:30 committed by Georgi Gerganov
parent bcaf449826
commit 8b288f5d96
2 changed files with 25 additions and 6 deletions

View File

@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
ExternalProject_Add(
zendnn
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13
GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17
PREFIX ${ZENDNN_PREFIX}
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
BINARY_DIR ${ZENDNN_BUILD_DIR}

View File

@ -47,6 +47,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
params.dtypes.dst = ggml_to_zendnn_type<TC>();
params.num_threads = ctx->n_threads;
zendnnl::lowoha::matmul::matmul_batch_params_t batch_params;
zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct(
'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major)
n, // M: rows of B and C
@ -59,7 +60,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
0.0f, // beta
C, ldc, // output C[n,m]
true, // is_weights_const
{}, // batch_params
batch_params, // batch_params
params // params
);
@ -520,6 +521,12 @@ static ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggm
GGML_UNUSED(max_tensor_size);
}
static bool ggml_zendnn_adaptive_fallback_enabled() {
static const bool enabled = std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK") == nullptr ||
std::atoi(std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK")) != 0;
return enabled;
}
static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
switch (op->op) {
case GGML_OP_NONE:
@ -538,12 +545,24 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
const int64_t ne10 = inputs->ne[0];
const int64_t ne0 = op->ne[0];
const int64_t ne1 = op->ne[1];
const int64_t min_batch = 1;
if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) ||
ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
return false;
if(!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs)) {
return false;
}
if (ggml_zendnn_adaptive_fallback_enabled()) {
const int64_t K = inputs->ne[0];
const int64_t N = (inputs->ne[1]*inputs->ne[2]*inputs->ne[3]);
const int64_t M = weights->ne[1];
if(K <= 256 || N <= 128 || M <= 96) {
return false;
}
}
else if (ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
return false;
}
// MUL_MAT_ID performs best with a moderate number of experts due to its
// gather + batched matmul + scatter approach. Future versions will leverage
// ZenDNN's grouped_gemm for better scalability with larger expert counts: