ggml-cpu : use template for argsort (llama/17222)

This commit is contained in:
Diego Devesa 2025-11-13 00:59:05 -08:00 committed by Georgi Gerganov
parent 726912d1cb
commit 6a91780c3b
1 changed file with 22 additions and 8 deletions

View File

@@ -7665,6 +7665,18 @@ void ggml_compute_forward_timestep_embedding(
// ggml_compute_forward_argsort
// Index comparator: compares two indices by the float values they select in
// `data`. The comparison direction is fixed at compile time by `order`:
//   - GGML_SORT_ORDER_ASC  -> data[lhs] <  data[rhs]
//   - any other order value -> data[lhs] >  data[rhs]
// `data` is not owned by the comparator.
template<enum ggml_sort_order order>
struct argsort_cmp {
    const float * data;

    bool operator()(int32_t lhs, int32_t rhs) const {
        const float lhs_value = data[lhs];
        const float rhs_value = data[rhs];

        // the unused branch is discarded at compile time
        if constexpr (order == GGML_SORT_ORDER_ASC) {
            return lhs_value < rhs_value;
        } else {
            return lhs_value > rhs_value;
        }
    }
};
static void ggml_compute_forward_argsort_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -7691,16 +7703,18 @@ static void ggml_compute_forward_argsort_f32(
dst_data[j] = j;
}
std::function<bool(int32_t, int32_t)> cmp;
// note: this might be causing memory allocations? ideally should be avoided if it's the case
switch (order) {
case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
default: GGML_ABORT("invalid sort order");
}
case GGML_SORT_ORDER_ASC:
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
break;
std::sort(dst_data, dst_data + ne0, cmp);
case GGML_SORT_ORDER_DESC:
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
break;
default:
GGML_ABORT("invalid sort order");
}
}
}