kleidiai : dynamic chunck-based scheduling for hybrid execution (llama/23819)

This commit is contained in:
Charles Xu 2026-06-05 09:11:47 +02:00 committed by Georgi Gerganov
parent 4fa1e0687e
commit facb02c4c3
1 changed files with 139 additions and 129 deletions

View File

@ -38,6 +38,7 @@
#include "kleidiai.h"
#include "ggml-cpu.h"
#include "ggml-cpu-impl.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-threading.h"
@ -61,7 +62,8 @@ struct ggml_kleidiai_context {
ggml_kleidiai_kernels * kernels_q8;
int sme_thread_cap; // <= 0 means “SME disabled/unknown”;
int thread_hint; // <= 0 means “no hint”
} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 };
int chunk_multiplier;
} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1, 4 };
static const char* cpu_feature_to_string(cpu_feature f) {
if (f == CPU_FEATURE_NONE) {
@ -186,8 +188,9 @@ static void init_kleidiai_context(void) {
if (!initialized) {
initialized = true;
const char *env_sme = getenv("GGML_KLEIDIAI_SME");
const char *env_threads = getenv("GGML_TOTAL_THREADS");
const char *env_sme = getenv("GGML_KLEIDIAI_SME");
const char *env_threads = getenv("GGML_TOTAL_THREADS");
const char *env_chunk_mult = getenv("GGML_KLEIDIAI_CHUNK_MULTIPLIER");
const bool cpu_has_sme = ggml_cpu_has_sme();
size_t detected_smcus = 0;
@ -204,6 +207,14 @@ static void init_kleidiai_context(void) {
}
}
if (env_chunk_mult) {
bool ok = false;
int multiplier = parse_uint_env(env_chunk_mult, "GGML_KLEIDIAI_CHUNK_MULTIPLIER", &ok);
if (ok && multiplier > 0) {
ctx.chunk_multiplier = multiplier;
}
}
// SME policy:
// - If CPU doesn't support SME: SME always off.
// - Else:
@ -296,6 +307,50 @@ static inline size_t align_up(size_t value, size_t alignment) {
return remainder == 0 ? value : value + (alignment - remainder);
}
static inline size_t gcd_size(size_t a, size_t b) {
while (b != 0) {
const size_t t = a % b;
a = b;
b = t;
}
return a;
}
static inline bool lcm_size(size_t a, size_t b, size_t & result) {
if (a == 0 || b == 0) {
result = 0;
return false;
}
const size_t g = gcd_size(a, b);
const size_t q = a / g;
if (q > SIZE_MAX / b) {
return false;
}
result = q * b;
return true;
}
static inline size_t ceil_div_size(size_t a, size_t b) {
return b == 0 ? 0 : (a + b - 1) / b;
}
struct kleidiai_block_args {
size_t lhs_bl;
size_t rhs_bl;
size_t pack_bl;
};
static inline kleidiai_block_args kleidiai_get_block_args(ggml_type rhs_type) {
switch (rhs_type) {
case GGML_TYPE_Q4_0:
return { QK4_0, QK4_0, QK4_0 };
case GGML_TYPE_Q8_0:
return { 0, 0, QK8_0 };
default:
return { 0, 0, 0 };
}
}
static inline bool kleidiai_pack_fallback_allowed() {
if (ctx.sme_thread_cap <= 0) {
return false;
@ -746,8 +801,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
size_t n_step;
size_t lhs_packed_size;
size_t lhs_offset;
size_t n_offset;
size_t n_cols;
size_t lhs_bl;
size_t rhs_bl;
size_t pack_bl;
size_t lhs_packed_offset0;
int assigned_threads;
int thread_begin;
int thread_end;
@ -772,6 +829,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
continue;
}
const kleidiai_block_args block_args = kleidiai_get_block_args(kernels->rhs_type);
runtime[runtime_count] = {
slot,
kernels,
@ -784,7 +843,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
kinfo->get_n_step(),
0,
0,
0,
block_args.lhs_bl,
block_args.rhs_bl,
block_args.pack_bl,
0,
0,
0,
@ -795,45 +856,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
}
if (runtime_count == 0) {
ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst);
if (!fallback) {
return false;
}
kernel_info * kinfo = is_gemv ? &fallback->gemv : &fallback->gemm;
lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info;
rhs_packing_info * rinfo = &fallback->rhs_info;
if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex ||
!kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset ||
!rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) {
return false;
}
kernel_chain[0] = fallback;
runtime[0] = {
0,
fallback,
kinfo,
linfo,
kinfo->get_mr(),
kinfo->get_nr(),
kinfo->get_kr(),
kinfo->get_sr(),
kinfo->get_n_step(),
0,
0,
0,
0,
0,
0,
0,
nullptr
};
size_t rhs_size_fallback = 0;
const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback);
if (!rhs_base) {
rhs_base = static_cast<const uint8_t *>(src0->data);
}
runtime[0].rhs_base = rhs_base;
runtime_count = 1;
GGML_LOG_WARN("kleidiai: no runtime kernel slot available for supported op %s\n", dst->name);
return false;
}
const int nth_total = params->nth > 0 ? params->nth : 1;
@ -846,6 +870,13 @@ class tensor_traits : public ggml::cpu::tensor_traits {
break;
}
}
int non_sme_slot = -1;
for (int i = 0; i < runtime_count; ++i) {
if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) != CPU_FEATURE_SME) {
non_sme_slot = i;
break;
}
}
const int sme_cap_limit = ctx.sme_thread_cap;
const bool use_hybrid = sme_cap_limit > 0 &&
@ -864,12 +895,15 @@ class tensor_traits : public ggml::cpu::tensor_traits {
if (!hybrid_enabled) {
int chosen_slot = 0;
if (too_small_for_hybrid && sme_slot != -1) {
chosen_slot = sme_slot;
chosen_slot = nth_total > sme_cap_limit && non_sme_slot != -1 ? non_sme_slot : sme_slot;
} else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) {
chosen_slot = 1;
}
if (chosen_slot != 0 && chosen_slot < runtime_count) {
runtime[0] = runtime[chosen_slot];
runtime[0].assigned_threads = 0;
runtime[0].thread_begin = 0;
runtime[0].thread_end = 0;
}
runtime_count = runtime_count > 0 ? 1 : 0;
@ -896,6 +930,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
int fallback_count = 0;
// The current hybrid chain is bounded to SME + one non-SME fallback slot.
GGML_ASSERT(GGML_KLEIDIAI_MAX_KERNEL_SLOTS == 2);
for (int i = 0; i < runtime_count; ++i) {
if (i == sme_slot) {
continue;
@ -952,73 +988,67 @@ class tensor_traits : public ggml::cpu::tensor_traits {
size_t cursor = 0;
for (int i = 0; i < runtime_count; ++i) {
const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type;
const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr);
runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, runtime[i].pack_bl, runtime[i].mr, runtime[i].kr, runtime[i].sr);
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
runtime[i].lhs_offset = cursor;
runtime[i].lhs_packed_offset0 = runtime[i].lhs_info->get_packed_offset_ex(0, k, runtime[i].lhs_bl, runtime[i].mr, runtime[i].kr, runtime[i].sr);
cursor += runtime[i].lhs_packed_size;
}
GGML_ASSERT(cursor <= params->wsize);
uint8_t * scratch = static_cast<uint8_t *>(params->wdata);
size_t assigned_cols = 0;
uint64_t weighted_total = 0;
if (runtime_count > 1 && sme_slot != -1) {
for (int i = 0; i < runtime_count; ++i) {
const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
weighted_total += (uint64_t)runtime[i].assigned_threads * weight;
}
}
size_t common_step = 1;
for (int i = 0; i < runtime_count; ++i) {
runtime[i].n_offset = assigned_cols;
if (runtime[i].assigned_threads == 0) {
runtime[i].n_cols = 0;
continue;
}
const size_t remaining_cols = n - assigned_cols;
if (remaining_cols == 0) {
runtime[i].n_cols = 0;
continue;
size_t next_step = 0;
if (!lcm_size(common_step, runtime[i].n_step ? runtime[i].n_step : 1, next_step)) {
return false;
}
const size_t step = runtime[i].n_step ? runtime[i].n_step : 1;
size_t target = 0;
if (weighted_total > 0) {
const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total);
} else {
target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total);
}
target = std::min(target, remaining_cols);
size_t aligned = round_down(target, step);
if (aligned == 0 && remaining_cols >= step) {
aligned = step;
}
runtime[i].n_cols = aligned;
assigned_cols += aligned;
common_step = next_step;
}
GGML_ASSERT(common_step > 0);
if (assigned_cols < n) {
for (int i = runtime_count - 1; i >= 0; --i) {
if (runtime[i].assigned_threads > 0) {
runtime[i].n_cols += n - assigned_cols;
break;
}
}
const bool disable_chunking = ggml_is_numa();
const size_t chunk_multiplier = std::max(1, ctx.chunk_multiplier);
const size_t chunk_divisor = (nth_total == 1 || disable_chunking) ? (size_t)nth_total : (size_t)nth_total * chunk_multiplier;
size_t chunk_cols = align_up(std::max<size_t>(1, ceil_div_size(n, chunk_divisor)), common_step);
if (chunk_cols == 0) {
chunk_cols = common_step;
}
// If common_step is larger than n, the loop below runs one valid tail chunk
// with cols == n.
const size_t nchunk_size = std::max<size_t>(1, ceil_div_size(n, chunk_cols));
GGML_ASSERT(nchunk_size <= (size_t)INT_MAX);
const int nchunk = (int)nchunk_size;
const size_t dst_stride = dst->nb[1];
auto run_chunk = [&](runtime_slot & slot, size_t global_start, size_t cols, uint8_t * dst_batch_base) {
const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot.rhs_bl);
const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride);
const uint8_t * lhs_ptr = scratch + slot.lhs_offset + slot.lhs_packed_offset0;
const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;
float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
slot.kernel->run_kernel_ex(m, cols, k, slot.rhs_bl,
lhs_ptr,
rhs_ptr,
dst_ptr,
dst_stride,
sizeof(float),
-FLT_MAX,
FLT_MAX);
};
for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) {
const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
if (runtime[local_slot].assigned_threads > 0) {
runtime_slot & slot = runtime[local_slot];
const ggml_type slot_rhs_type = slot.kernels->rhs_type;
const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr);
int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads;
max_threads = std::max<int64_t>(1, max_threads);
@ -1031,8 +1061,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const int64_t m_start = (int64_t)local_ith * num_m_per_thread0;
const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr);
const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr);
const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0;
int64_t remaining = m_count;
@ -1049,7 +1079,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
void * dst_ptr = lhs_packed + dst_off;
slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);
slot.lhs_info->pack_func_ex(take, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);
cur += take;
remaining -= take;
@ -1057,49 +1087,29 @@ class tensor_traits : public ggml::cpu::tensor_traits {
}
}
if (ith_total == 0) {
ggml_threadpool_chunk_set(params->threadpool, nth_total);
}
// Publishes both LHS packing and the initialized dynamic chunk queue.
ggml_barrier(params->threadpool);
runtime_slot & slot = runtime[local_slot];
if (slot.n_cols > 0 && slot.assigned_threads > 0) {
int64_t active_threads = slot.assigned_threads;
const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads;
if (max_threads > 0) {
active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads));
int current_chunk = ith_total;
while (current_chunk < nchunk) {
const size_t global_start = (size_t)current_chunk * chunk_cols;
if (global_start >= n) {
break;
}
active_threads = std::max<int64_t>(1, active_threads);
if (local_ith < active_threads) {
const size_t step = slot.n_step ? slot.n_step : 1;
const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step);
const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0;
const size_t local_start = (size_t)local_ith * chunk0;
const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0;
if (cols > 0) {
const ggml_type slot_rhs_type = slot.kernels->rhs_type;
const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
const size_t global_start = slot.n_offset + local_start;
const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg);
const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride);
const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset;
const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;
float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg,
lhs_ptr,
rhs_ptr,
dst_ptr,
dst_stride,
sizeof(float),
-FLT_MAX,
FLT_MAX);
}
const size_t cols = std::min(chunk_cols, n - global_start);
if (cols > 0) {
// KleidiAI GEMM/GEMV kernels accept arbitrary final tail widths;
// only non-tail chunks are guaranteed to be n_step-aligned.
run_chunk(slot, global_start, cols, dst_batch_base);
}
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
if (batch_idx != ne12 - 1) {