ggml-zendnn : adaptive fallback to CPU backend for small batch sizes (llama/22681)
* ggml-zendnn : add runtime env var GGML_ZENDNN_ADAPTIVE_FALLBACK to control adaptive fallback (default: enabled) * ggml-zendnn : restore original fallback logic when adaptive fallback is disabled
This commit is contained in:
parent
bcaf449826
commit
8b288f5d96
|
|
@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
|
|||
ExternalProject_Add(
|
||||
zendnn
|
||||
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
|
||||
GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13
|
||||
GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17
|
||||
PREFIX ${ZENDNN_PREFIX}
|
||||
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
|
||||
BINARY_DIR ${ZENDNN_BUILD_DIR}
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
|
|||
params.dtypes.dst = ggml_to_zendnn_type<TC>();
|
||||
params.num_threads = ctx->n_threads;
|
||||
|
||||
zendnnl::lowoha::matmul::matmul_batch_params_t batch_params;
|
||||
zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct(
|
||||
'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major)
|
||||
n, // M: rows of B and C
|
||||
|
|
@ -59,7 +60,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
|
|||
0.0f, // beta
|
||||
C, ldc, // output C[n,m]
|
||||
true, // is_weights_const
|
||||
{}, // batch_params
|
||||
batch_params, // batch_params
|
||||
params // params
|
||||
);
|
||||
|
||||
|
|
@ -520,6 +521,12 @@ static ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggm
|
|||
GGML_UNUSED(max_tensor_size);
|
||||
}
|
||||
|
||||
static bool ggml_zendnn_adaptive_fallback_enabled() {
|
||||
static const bool enabled = std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK") == nullptr ||
|
||||
std::atoi(std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK")) != 0;
|
||||
return enabled;
|
||||
}
|
||||
|
||||
static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
||||
switch (op->op) {
|
||||
case GGML_OP_NONE:
|
||||
|
|
@ -538,12 +545,24 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
|
|||
const int64_t ne10 = inputs->ne[0];
|
||||
const int64_t ne0 = op->ne[0];
|
||||
const int64_t ne1 = op->ne[1];
|
||||
|
||||
const int64_t min_batch = 1;
|
||||
if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) ||
|
||||
ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
|
||||
return false;
|
||||
|
||||
if(!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ggml_zendnn_adaptive_fallback_enabled()) {
|
||||
const int64_t K = inputs->ne[0];
|
||||
const int64_t N = (inputs->ne[1]*inputs->ne[2]*inputs->ne[3]);
|
||||
const int64_t M = weights->ne[1];
|
||||
if(K <= 256 || N <= 128 || M <= 96) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// MUL_MAT_ID performs best with a moderate number of experts due to its
|
||||
// gather + batched matmul + scatter approach. Future versions will leverage
|
||||
// ZenDNN's grouped_gemm for better scalability with larger expert counts:
|
||||
|
|
|
|||
Loading…
Reference in New Issue