From 79f88a1104ef95aad161af3d6b912ccb794bccc4 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Tue, 16 Jun 2026 13:34:29 +0800 Subject: [PATCH] Support OP EXPM1, support all UT cases of FLOOR, TRUNC, ROUND (llama/24363) * support OP EXPM1, support all UT cases of FLOOR, TRUNC, ROUND * fix conflict * rebase, support new UT case of repeat, concat --- ggml/src/ggml-sycl/binbcast.cpp | 7 +++ ggml/src/ggml-sycl/concat.cpp | 22 ++++++++- ggml/src/ggml-sycl/element_wise.cpp | 76 ++++++++++------------------- ggml/src/ggml-sycl/element_wise.hpp | 2 + ggml/src/ggml-sycl/ggml-sycl.cpp | 10 ++-- 5 files changed, 60 insertions(+), 57 deletions(-) diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index 92dd18889..ad2e6ca35 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -287,6 +287,13 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); +#ifdef GGML_SYCL_HAS_BF16 + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { + op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const sycl::ext::oneapi::bfloat16 *) src1->data, + (sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, + ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), + ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); +#endif } else { fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/ggml/src/ggml-sycl/concat.cpp b/ggml/src/ggml-sycl/concat.cpp index d16215bc9..93e00d65f 100644 --- a/ggml/src/ggml-sycl/concat.cpp +++ b/ggml/src/ggml-sycl/concat.cpp @@ -10,6 +10,8 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // +#include "ggml.h" + #include "concat.hpp" static inline size_t elem_size(ggml_type t) { @@ -192,11 +194,29 @@ void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { case GGML_TYPE_F32: concat_impl_sycl(ctx, dst); break; + case GGML_TYPE_F16: + concat_impl_sycl(ctx, dst); + break; +#ifdef GGML_SYCL_HAS_BF16 + case GGML_TYPE_BF16: + concat_impl_sycl(ctx, dst); + break; +#endif case GGML_TYPE_I32: concat_impl_sycl(ctx, dst); break; + case GGML_TYPE_I16: + concat_impl_sycl(ctx, dst); + break; + case GGML_TYPE_I64: + concat_impl_sycl(ctx, dst); + break; + case GGML_TYPE_I8: + concat_impl_sycl(ctx, dst); + break; default: - GGML_ASSERT(false && "ggml_sycl_op_concat: unsupported type"); + fprintf(stderr, "%s: unsupported types: dst: %s\n", __func__, ggml_type_name(dst->type)); + GGML_ASSERT(false); break; } } diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 249e80c82..aca68e58e 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -124,6 +124,11 @@ static __dpct_inline__ T op_exp(T x) { return sycl::exp(x); } +template +static __dpct_inline__ T op_expm1(T x) { + return sycl::expm1(x); +} + template static __dpct_inline__ T op_log(T x) { if (x <= static_cast(0)) { @@ -266,13 +271,6 @@ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl: } } -template -static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { - SYCL_GLOBAL_ID_LOOP(k, item_ct1) { - dst[i] = op_floor(x[i]); - } -} - template static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { SYCL_GLOBAL_ID_LOOP(k, item_ct1) { @@ -280,20 +278,6 @@ static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl:: } } -template -static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { - SYCL_GLOBAL_ID_LOOP(k, item_ct1) { - dst[i] = op_round(x[i]); - } -} - -template -static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { - SYCL_GLOBAL_ID_LOOP(k, item_ct1) { - dst[i] = op_trunc(x[i]); - } -} - template static void clamp(const T * x, T * dst, const float min, const float max, const int k, const sycl::nd_item<1> &item_ct1) { @@ -605,6 +589,12 @@ static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor }); } +static inline void ggml_sycl_op_expm1(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { + return op_expm1(x); + }); +} + static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { @@ -728,16 +718,9 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens } static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, - [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { - const int num_blocks = ceil_div(k_elements, 256); - stream->parallel_for( - sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), - sycl::range<1>(256)), - [=](sycl::nd_item<1> item_ct1) { - unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1); - }); - }); + ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { + return op_floor(x); + }); } static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { @@ -747,29 +730,15 @@ static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tenso } static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, - [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { - const int num_blocks = ceil_div(k_elements, 256); - stream->parallel_for( - sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), - sycl::range<1>(256)), - [=](sycl::nd_item<1> item_ct1) { - unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1); - }); - }); + ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { + return op_round(x); + }); } static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, - [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { - const int num_blocks = ceil_div(k_elements, 256); - stream->parallel_for( - sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), - sycl::range<1>(256)), - [=](sycl::nd_item<1> item_ct1) { - unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1); - }); - }); + ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { + return op_trunc(x); + }); } static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { @@ -1018,6 +987,11 @@ void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_exp(ctx, dst); } +void ggml_sycl_expm1(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_expm1(ctx, dst); +} + void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_log(ctx, dst); diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 997132166..3bdc38596 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -59,6 +59,8 @@ void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_expm1(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index fb8665a02..15ee53f7f 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4489,6 +4489,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_UNARY_OP_EXP: ggml_sycl_exp(ctx, dst); break; + case GGML_UNARY_OP_EXPM1: + ggml_sycl_expm1(ctx, dst); + break; case GGML_UNARY_OP_SOFTPLUS: ggml_sycl_softplus(ctx, dst); break; @@ -5138,6 +5141,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_EXPM1: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_CEIL: @@ -5145,11 +5149,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_FLOOR: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: -#if defined (GGML_SYCL_F16) - return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type); -#else - return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); -#endif + return true; default: return false; }