add to support pool_1d, move pool_1d/2d code to pool.cpp/hpp (llama/24584)

* add to support pool_1d, move pool_1d/2d code to pool.cpp/hpp

* update ops.md
This commit is contained in:
Neo Zhang 2026-06-15 15:01:07 +08:00 committed by Georgi Gerganov
parent 5832e734d4
commit 3cb087c42a
4 changed files with 218 additions and 102 deletions

View File

@ -70,6 +70,7 @@
#include "ggml-sycl/diag.hpp"
#include "ggml-sycl/solve_tri.hpp"
#include "ggml-sycl/gated_delta_net.hpp"
#include "ggml-sycl/pool.hpp"
static bool g_sycl_loaded = false;
int g_ggml_sycl_debug = 0;
@ -1940,69 +1941,6 @@ static void scale_f32(const float * x, float * dst, const float scale, const flo
}
template <typename Ti, typename To>
static void pool2d_nchw_kernel(
const int ih, const int iw, const int oh, const int ow,
const int kh, const int kw, const int sh, const int sw,
const int ph, const int pw, const int parallel_elements,
const Ti* src, To* dst, const enum ggml_op_pool op,
const sycl::nd_item<3> &item_ct1) {
int idx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
if (idx >= parallel_elements) {
return;
}
const int I_HW = ih * iw;
const int O_HW = oh * ow;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / ow;
const int cur_ow = idx % O_HW % ow;
const Ti* i_ptr = src + nc * I_HW;
To* o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * sh - ph;
const int bh = sycl::max(0, start_h);
const int eh = sycl::min(ih, start_h + kh);
const int start_w = cur_ow * sw - pw;
const int bw = sycl::max(0, start_w);
const int ew = sycl::min(iw, start_w + kw);
To res = 0;
switch (op) {
case GGML_OP_POOL_AVG: res = 0; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
default:
res = (To) sycl::nan(uint32_t(0));
break;
}
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
#if DPCT_COMPATIBILITY_TEMP >= 350
/*
DPCT1098:106: The '*' expression is used instead of the __ldg
call. These two expressions do not provide the exact same
functionality. Check the generated code for potential precision
and/or performance issues.
*/
Ti cur = *(i_ptr + i * iw + j);
#else
Ti cur = i_ptr[i * iw + j];
#endif
switch (op) {
case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
default:
res = (To) sycl::nan(uint32_t(0));
break;
}
}
}
o_ptr[cur_oh * ow + cur_ow] = res;
}
static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
float *dst, const int ncols_x,
const int nrows_x,
@ -2551,45 +2489,6 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int32_t * opts = (const int32_t *)dst->op_params;
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
const int k0 = opts[1];
const int k1 = opts[2];
const int s0 = opts[3];
const int s1 = opts[4];
const int p0 = opts[5];
const int p1 = opts[6];
const int64_t IH = dst->src[0]->ne[1];
const int64_t IW = dst->src[0]->ne[0];
const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
const int64_t OH = dst->ne[1];
const int64_t OW = dst->ne[0];
const int parallel_elements = N * OC * OH * OW;
const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
sycl::range<3> block_nums(1, 1, num_blocks);
main_stream->parallel_for(
sycl::nd_range<3>(block_nums *
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0,
parallel_elements, src0_dd, dst_dd, op,
item_ct1);
});
}
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -4428,6 +4327,11 @@ static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
ggml_sycl_op_pool2d(ctx, dst);
}
static void ggml_sycl_pool1d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_pool1d(ctx, dst);
}
static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_im2col(ctx, dst);
@ -4741,6 +4645,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_POOL_2D:
ggml_sycl_pool2d(ctx, dst);
break;
case GGML_OP_POOL_1D:
ggml_sycl_pool1d(ctx, dst);
break;
case GGML_OP_SUM:
ggml_sycl_sum(ctx, dst);
break;
@ -5495,6 +5402,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
k > 0 && k <= 32;
}
case GGML_OP_POOL_2D:
case GGML_OP_POOL_1D:
return true;
case GGML_OP_ACC:
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);

185
ggml/src/ggml-sycl/pool.cpp Normal file
View File

@ -0,0 +1,185 @@
//
// MIT license
// Copyright (C) 2026 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#include "pool.hpp"
#include <float.h>
template <typename Ti, typename To>
static void pool2d_nchw_kernel(
const int ih, const int iw, const int oh, const int ow,
const int kh, const int kw, const int sh, const int sw,
const int ph, const int pw, const int parallel_elements,
const Ti* src, To* dst, const enum ggml_op_pool op,
const sycl::nd_item<3> &item_ct1) {
int idx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
if (idx >= parallel_elements) {
return;
}
const int I_HW = ih * iw;
const int O_HW = oh * ow;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / ow;
const int cur_ow = idx % O_HW % ow;
const Ti* i_ptr = src + nc * I_HW;
To* o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * sh - ph;
const int bh = sycl::max(0, start_h);
const int eh = sycl::min(ih, start_h + kh);
const int start_w = cur_ow * sw - pw;
const int bw = sycl::max(0, start_w);
const int ew = sycl::min(iw, start_w + kw);
To res = 0;
switch (op) {
case GGML_OP_POOL_AVG: res = 0; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
default:
res = (To) sycl::nan(uint32_t(0));
break;
}
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
Ti cur = i_ptr[i * iw + j];
switch (op) {
case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
default:
res = (To) sycl::nan(uint32_t(0));
break;
}
}
}
o_ptr[cur_oh * ow + cur_ow] = res;
}
template <typename Ti, typename To>
static void pool1d_ncw_kernel(
const int iw, const int ow,
const int k, const int s,
const int p, const int parallel_elements,
const Ti * src, To * dst, const enum ggml_op_pool op,
const sycl::nd_item<3> & item_ct1) {
int idx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
if (idx >= parallel_elements) {
return;
}
const int nc = idx / ow;
const int cur_ow = idx % ow;
const Ti * i_ptr = src + nc * iw;
To * o_ptr = dst + nc * ow;
const int start = cur_ow * s - p;
const int b = sycl::max(0, start);
const int e = sycl::min(iw, start + k);
To res = 0;
switch (op) {
case GGML_OP_POOL_AVG: res = 0; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
default:
res = (To) sycl::nan(uint32_t(0));
break;
}
for (int j = b; j < e; j += 1) {
Ti cur = i_ptr[j];
switch (op) {
case GGML_OP_POOL_AVG: res += cur; break;
case GGML_OP_POOL_MAX: res = sycl::max(res, (To) cur); break;
default:
res = (To) sycl::nan(uint32_t(0));
break;
}
}
const int count = e - b;
if (op == GGML_OP_POOL_AVG) {
res = (count > 0) ? (res / count) : (To) 0;
}
o_ptr[cur_ow] = res;
}
void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int32_t * opts = (const int32_t *)dst->op_params;
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
const int k0 = opts[1];
const int k1 = opts[2];
const int s0 = opts[3];
const int s1 = opts[4];
const int p0 = opts[5];
const int p1 = opts[6];
const int64_t IH = dst->src[0]->ne[1];
const int64_t IW = dst->src[0]->ne[0];
const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
const int64_t OH = dst->ne[1];
const int64_t OW = dst->ne[0];
const int parallel_elements = N * OC * OH * OW;
const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
sycl::range<3> block_nums(1, 1, num_blocks);
main_stream->parallel_for(
sycl::nd_range<3>(block_nums *
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0,
parallel_elements, src0_dd, dst_dd, op,
item_ct1);
});
}
void ggml_sycl_op_pool1d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int32_t * opts = (const int32_t *)dst->op_params;
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
const int k0 = opts[1];
const int s0 = opts[2];
const int p0 = opts[3];
const int64_t IW = dst->src[0]->ne[0];
const int64_t OW = dst->ne[0];
const int64_t NC = dst->ne[3] * dst->ne[2] * dst->ne[1];
const int parallel_elements = NC * OW;
const int num_blocks = (parallel_elements + SYCL_POOL1D_BLOCK_SIZE - 1) / SYCL_POOL1D_BLOCK_SIZE;
sycl::range<3> block_nums(1, 1, num_blocks);
main_stream->parallel_for(
sycl::nd_range<3>(block_nums *
sycl::range<3>(1, 1, SYCL_POOL1D_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_POOL1D_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
pool1d_ncw_kernel(IW, OW, k0, s0, p0,
parallel_elements, src0_dd, dst_dd, op,
item_ct1);
});
}

View File

@ -0,0 +1,22 @@
//
// MIT license
// Copyright (C) 2026 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#ifndef GGML_SYCL_POOL_HPP
#define GGML_SYCL_POOL_HPP
#include "common.hpp"
#include "presets.hpp"
void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_op_pool1d(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_POOL_HPP

View File

@ -46,6 +46,7 @@
#define SYCL_PAD_BLOCK_SIZE 256
#define SYCL_ACC_BLOCK_SIZE 256
#define SYCL_IM2COL_BLOCK_SIZE 256
#define SYCL_POOL1D_BLOCK_SIZE 256
#define SYCL_POOL2D_BLOCK_SIZE 256
#define SYCL_ARGMAX_BLOCK_SIZE 256
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256