diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index a526d8e58..1f5a91272 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -17,6 +17,7 @@ #include "common.hpp" #include "concat.hpp" #include "conv.hpp" +#include "conv3d.hpp" #include "convert.hpp" #include "count-equal.hpp" #include "cpy.hpp" diff --git a/ggml/src/ggml-sycl/conv3d.cpp b/ggml/src/ggml-sycl/conv3d.cpp new file mode 100644 index 000000000..2fa29f930 --- /dev/null +++ b/ggml/src/ggml-sycl/conv3d.cpp @@ -0,0 +1,218 @@ +#include "conv3d.hpp" + +static inline int64_t ggml_sycl_conv3d_calc_patch_total(const ggml_tensor * dst, int32_t n) { + return (int64_t) n * dst->ne[0] * dst->ne[1] * dst->ne[2]; +} + +static inline int64_t ggml_sycl_conv3d_calc_knl_n_total(const ggml_tensor * src0, int32_t c) { + return (int64_t) src0->ne[0] * src0->ne[1] * src0->ne[2] * c; +} + +static inline void ggml_sycl_conv3d_write_output( + const ggml_tensor * dst, + const float * src, float * dst_data, + int64_t patch_total, int64_t oc, + int64_t dst_w, int64_t dst_h, int64_t dst_d, + dpct::queue_ptr stream) { + const int64_t dst_nb0 = dst->nb[0]; + const int64_t dst_nb1 = dst->nb[1]; + const int64_t dst_nb2 = dst->nb[2]; + const int64_t dst_nb3 = dst->nb[3]; + const int64_t total = patch_total * oc; + const int64_t block_size = 256; + const int64_t num_work_items = ((total + block_size - 1) / block_size) * block_size; + + stream->parallel_for(sycl::range<1>(num_work_items), [=](sycl::id<1> id) { + const int64_t i = id[0]; + if (i >= total) { + return; + } + + const int64_t patch_idx = i / oc; + const int64_t out_ch = i % oc; + const int64_t p_in_batch = patch_idx % (dst_w * dst_h * dst_d); + const int64_t batch_idx = patch_idx / (dst_w * dst_h * dst_d); + const int64_t dst_z = p_in_batch / (dst_w * dst_h); + const int64_t dst_y = (p_in_batch % (dst_w * dst_h)) / dst_w; + const int64_t dst_x = p_in_batch % dst_w; + const int64_t ocn_idx = batch_idx * oc + out_ch; + + const int64_t dst_offset = dst_x * dst_nb0 + dst_y * dst_nb1 + dst_z * dst_nb2 + ocn_idx * dst_nb3; + // `src` is a column-major (m x n) GEMM output where m == patch_total, n == oc. + // GEMM stores element (row, col) at index `row + col*m`, so compute index accordingly. + const int64_t src_index = patch_idx + out_ch * patch_total; + const float value = src[src_index]; + *(float *)((char *)dst_data + dst_offset) = value; + }); +} + +void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const int32_t * opts = (const int32_t *) dst->op_params; + const int32_t s0 = opts[0]; + const int32_t s1 = opts[1]; + const int32_t s2 = opts[2]; + const int32_t p0 = opts[3]; + const int32_t p1 = opts[4]; + const int32_t p2 = opts[5]; + const int32_t d0 = opts[6]; + const int32_t d1 = opts[7]; + const int32_t d2 = opts[8]; + const int32_t c = opts[9]; + const int32_t n = opts[10]; + const int32_t oc = opts[11]; + + const int64_t knl_w = src0->ne[0]; + const int64_t knl_h = src0->ne[1]; + const int64_t knl_d = src0->ne[2]; + + const int64_t patch_total = ggml_sycl_conv3d_calc_patch_total(dst, n); + const int64_t knl_n_total = ggml_sycl_conv3d_calc_knl_n_total(src0, c); + + const size_t kernel_type_size = ggml_element_size(src0); + + ggml_sycl_pool_alloc gemm_output(ctx.pool()); + gemm_output.alloc((size_t) patch_total * oc); + + ggml_tensor dst_mat = {}; + dst_mat.type = GGML_TYPE_F32; + dst_mat.ne[0] = patch_total; + dst_mat.ne[1] = oc; + dst_mat.ne[2] = 1; + dst_mat.ne[3] = 1; + dst_mat.nb[0] = sizeof(float); + dst_mat.nb[1] = dst_mat.nb[0] * dst_mat.ne[0]; + dst_mat.nb[2] = dst_mat.nb[1]; + dst_mat.nb[3] = dst_mat.nb[2]; + dst_mat.data = gemm_output.get(); + dst_mat.buffer = dst->buffer; + dst_mat.extra = dst->extra; + + dpct::queue_ptr stream = ctx.stream(); + + // allocate packed arrays: A_packed (k x m), B_packed (k x n) + ggml_sycl_pool_alloc A_packed_alloc(ctx.pool()); + ggml_sycl_pool_alloc B_packed_alloc(ctx.pool()); + A_packed_alloc.alloc((size_t) knl_n_total * patch_total * sizeof(float)); + B_packed_alloc.alloc((size_t) knl_n_total * oc * sizeof(float)); + + float * A_packed = A_packed_alloc.get(); + float * B_packed = B_packed_alloc.get(); + + const int m = (int) patch_total; + const int n_gemm = (int) oc; + const int k = (int) knl_n_total; + + // Combined kernel: im2col -> pack A, and pack B simultaneously + const char * src1_base = (const char *) src1->data; + const int64_t src1_nb0 = src1->nb[0]; + const int64_t src1_nb1 = src1->nb[1]; + const int64_t src1_nb2 = src1->nb[2]; + const int64_t src1_nb3 = src1->nb[3]; + + // Compute correct strides for src0 as (knl_n_total, oc) matrix + const int64_t src0_packed_nb0 = kernel_type_size; + const int64_t src0_packed_nb1 = kernel_type_size * knl_n_total; + + const int64_t KW = knl_w; + const int64_t KH = knl_h; + const int64_t KD = knl_d; + const int64_t PW = dst->ne[0]; + const int64_t PH = dst->ne[1]; + const int64_t PD = dst->ne[2]; + + // Pack A (with inline im2col): for each (row, col) in k x m matrix + const int64_t A_total = (int64_t)k * m; + const int64_t A_block_size = 256; + const int64_t A_num_work = ((A_total + A_block_size - 1) / A_block_size) * A_block_size; + + stream->parallel_for(sycl::range<1>(A_num_work), [=](sycl::id<1> id) { + const int64_t t = id[0]; + if (t >= A_total) return; + + const int64_t row = t % k; + const int64_t col = t / k; + + // Inline im2col for this element + const int64_t k_index = row; + const int64_t patch_idx = col; + + const int64_t ic = k_index / (KD * KH * KW); + const int64_t rem = k_index - ic * (KD * KH * KW); + const int64_t kz = rem / (KH * KW); + const int64_t rem2 = rem - kz * (KH * KW); + const int64_t ky = rem2 / KW; + const int64_t kx = rem2 % KW; + + const int64_t p_in_batch = patch_idx % (PW * PH * PD); + const int64_t batch_idx = patch_idx / (PW * PH * PD); + const int64_t dst_z = p_in_batch / (PW * PH); + const int64_t dst_y = (p_in_batch % (PW * PH)) / PW; + const int64_t dst_x = p_in_batch % PW; + + const int64_t sx = dst_x * s0 + kx * d0 - p0; + const int64_t sy = dst_y * s1 + ky * d1 - p1; + const int64_t sz = dst_z * s2 + kz * d2 - p2; + + float val = 0.0f; + if (sx >= 0 && sx < src1->ne[0] && sy >= 0 && sy < src1->ne[1] && sz >= 0 && sz < src1->ne[2]) { + const int64_t channel_idx = batch_idx * c + ic; + const char * ptr = src1_base + sx * src1_nb0 + sy * src1_nb1 + sz * src1_nb2 + channel_idx * src1_nb3; + val = *(const float *) ptr; + } + A_packed[row + col * (int64_t)k] = val; + }); + + // Pack B: for each (row, col) in k x n_gemm matrix + const int64_t B_total = (int64_t)k * n_gemm; + const int64_t B_block_size = 256; + const int64_t B_num_work = ((B_total + B_block_size - 1) / B_block_size) * B_block_size; + + stream->parallel_for(sycl::range<1>(B_num_work), [=](sycl::id<1> id) { + const int64_t t = id[0]; + if (t >= B_total) return; + + const int64_t row = t % k; + const int64_t col = t / k; + const char * src_ptr = (const char *) src0->data + row * src0_packed_nb0 + col * src0_packed_nb1; + float v; + if (src0->type == GGML_TYPE_F32) { + v = *(const float *) src_ptr; + } else { + v = sycl::vec(*(const sycl::half *) src_ptr).convert()[0]; + } + B_packed[row + col * (int64_t)k] = v; + }); + + // GEMM: C = A^T * B where A is (k x m), B is (k x n), C is (m x n) + const float alpha = 1.0f; + const float beta = 0.0f; + const int lda = k; + const int ldb = k; + const int ldc = m; + + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, + m, n_gemm, k, + dpct::get_value(&alpha, *stream), + (const float *) A_packed, lda, + (const float *) B_packed, ldb, + dpct::get_value(&beta, *stream), + (float *) dst_mat.data, ldc))); + + const float * gemm_data = (const float *) dst_mat.data; + float * dst_data = (float *) dst->data; + + ggml_sycl_conv3d_write_output(dst, gemm_data, dst_data, patch_total, oc, + dst->ne[0], dst->ne[1], dst->ne[2], stream); +} diff --git a/ggml/src/ggml-sycl/conv3d.hpp b/ggml/src/ggml-sycl/conv3d.hpp new file mode 100644 index 000000000..5852f393f --- /dev/null +++ b/ggml/src/ggml-sycl/conv3d.hpp @@ -0,0 +1,8 @@ +#ifndef GGML_SYCL_CONV3D_HPP +#define GGML_SYCL_CONV3D_HPP + +#include "common.hpp" + +void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_CONV3D_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 057e0dccd..43c7e0a93 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4572,6 +4572,11 @@ static void ggml_sycl_im2col_3d(ggml_backend_sycl_context & ctx, ggml_tensor * d ggml_sycl_op_im2col_3d(ctx, dst); } +static void ggml_sycl_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_conv_3d(ctx, dst); +} + static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); @@ -4638,6 +4643,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_CONV_TRANSPOSE_1D: ggml_sycl_op_conv_transpose_1d(ctx, dst); break; + case GGML_OP_CONV_3D: + ggml_sycl_conv_3d(ctx, dst); + break; case GGML_OP_REPEAT: ggml_sycl_repeat(ctx, dst); break; @@ -5615,6 +5623,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_IM2COL_3D: case GGML_OP_UPSCALE: return true; + case GGML_OP_CONV_3D: + return op->type == GGML_TYPE_F32 && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->src[1]->type == GGML_TYPE_F32 && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]); case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: