ggml : check cuda and metal argsort limits and add test (llama/16323)

* check cuda argsort limits and add test

* add metal check
This commit is contained in:
Sigbjørn Skjæret 2025-09-29 11:09:00 +02:00 committed by Georgi Gerganov
parent 7ce0a7bcd0
commit 112e10f2e4
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 6 additions and 2 deletions

View File

@ -3639,9 +3639,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;
case GGML_OP_ARGSORT:
// TODO: Support arbitrary column width
return op->src[0]->ne[0] <= 1024;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM:

View File

@ -683,9 +683,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ARGSORT:
// TODO: Support arbitrary column width
return op->src[0]->ne[0] <= 1024;
case GGML_OP_ARANGE:
return true;
case GGML_OP_FLASH_ATTN_EXT: