vulkan/cuda: fix topk_moe with exp_probs_b (llama/18071)

I updated test_topk_moe to more closely match llm_graph_context::build_moe_ffn
and added coverage for exp_probs_b and some other missing combinations. This
exposed a bug in both CUDA and Vulkan backends where they were assuming the
input to argsort and the input to get_rows are the same. I'd like to optimize
this graph in another change, but for now just get it functional.

CUDA also had a bug where it got n_experts from the wrong place, leading to
GGML_ASSERT failures in some of the new tests.
This commit is contained in:
Jeff Bolz 2025-12-21 03:27:34 -06:00 committed by Georgi Gerganov
parent ad6ee3865d
commit f407c5e562
4 changed files with 55 additions and 6 deletions

View File

@ -3076,8 +3076,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
@ -3085,7 +3088,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
@ -3094,8 +3101,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}

View File

@ -268,7 +268,23 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
}
}
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert) {
ggml_tensor * probs = get_rows->src[0];
if (probs->op != GGML_OP_RESHAPE) {
return false;
}
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs) {
return false;
}
float scale = 1.0f;
float max_bias = 0.0f;
@ -288,7 +304,6 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
return false;
}
const int n_expert = softmax->ne[0];
// n_expert must be a power of 2
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
return false;

View File

@ -11,6 +11,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const bool delayed_softmax = false,
ggml_tensor * weight_clamp = nullptr);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);

View File

@ -12941,24 +12941,43 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
const ggml_tensor * softmax;
const ggml_tensor * weights;
const ggml_tensor * get_rows;
const ggml_tensor * argsort;
switch (mode) {
case TOPK_MOE_EARLY_SOFTMAX_NORM:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 9];
get_rows = cgraph->nodes[node_idx + 4];
argsort = cgraph->nodes[node_idx + 2];
break;
case TOPK_MOE_EARLY_SOFTMAX:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 4];
get_rows = cgraph->nodes[node_idx + 4];
argsort = cgraph->nodes[node_idx + 2];
break;
case TOPK_MOE_LATE_SOFTMAX:
softmax = cgraph->nodes[node_idx + 4];
weights = cgraph->nodes[node_idx + 5];
get_rows = cgraph->nodes[node_idx + 2];
argsort = cgraph->nodes[node_idx + 0];
break;
default:
return false;
}
ggml_tensor * probs = get_rows->src[0];
if (probs->op != GGML_OP_RESHAPE) {
return false;
}
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs) {
return false;
}
const float * op_params = (const float *)softmax->op_params;
float scale = op_params[0];