llama : separate compute buffer reserve from fattn check (llama/15696)

Exposes ggml_backend_sched_split_graph() to allow splitting the graph without allocating compute buffers and uses it to split the graph for the automatic Flash Attention check.
This commit is contained in:
Diego Devesa 2025-08-31 06:49:03 -07:00 committed by Georgi Gerganov
parent db7ecfb61d
commit b11c972b88
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 6 additions and 1 deletions

View File

@ -307,6 +307,9 @@ extern "C" {
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
// Split graph without allocating it
GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
// Allocate and compute graph on the backend scheduler
GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);

View File

@ -902,7 +902,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
}
// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
// reset splits
sched->n_splits = 0;
sched->n_graph_inputs = 0;
@ -1687,6 +1687,8 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
GGML_ASSERT(sched);
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
ggml_backend_sched_reset(sched);
ggml_backend_sched_synchronize(sched);
ggml_backend_sched_split_graph(sched, measure_graph);