Add conv_3d (llama/24691)
* add conv_3d * optimize * update ops.md * restore test script * rm unused code * rm copyright notes
This commit is contained in:
parent
ae09736e83
commit
01b1c3ada7
|
|
@ -17,6 +17,7 @@
|
|||
#include "common.hpp"
|
||||
#include "concat.hpp"
|
||||
#include "conv.hpp"
|
||||
#include "conv3d.hpp"
|
||||
#include "convert.hpp"
|
||||
#include "count-equal.hpp"
|
||||
#include "cpy.hpp"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,218 @@
|
|||
#include "conv3d.hpp"
|
||||
|
||||
static inline int64_t ggml_sycl_conv3d_calc_patch_total(const ggml_tensor * dst, int32_t n) {
|
||||
return (int64_t) n * dst->ne[0] * dst->ne[1] * dst->ne[2];
|
||||
}
|
||||
|
||||
static inline int64_t ggml_sycl_conv3d_calc_knl_n_total(const ggml_tensor * src0, int32_t c) {
|
||||
return (int64_t) src0->ne[0] * src0->ne[1] * src0->ne[2] * c;
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_conv3d_write_output(
|
||||
const ggml_tensor * dst,
|
||||
const float * src, float * dst_data,
|
||||
int64_t patch_total, int64_t oc,
|
||||
int64_t dst_w, int64_t dst_h, int64_t dst_d,
|
||||
dpct::queue_ptr stream) {
|
||||
const int64_t dst_nb0 = dst->nb[0];
|
||||
const int64_t dst_nb1 = dst->nb[1];
|
||||
const int64_t dst_nb2 = dst->nb[2];
|
||||
const int64_t dst_nb3 = dst->nb[3];
|
||||
const int64_t total = patch_total * oc;
|
||||
const int64_t block_size = 256;
|
||||
const int64_t num_work_items = ((total + block_size - 1) / block_size) * block_size;
|
||||
|
||||
stream->parallel_for(sycl::range<1>(num_work_items), [=](sycl::id<1> id) {
|
||||
const int64_t i = id[0];
|
||||
if (i >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t patch_idx = i / oc;
|
||||
const int64_t out_ch = i % oc;
|
||||
const int64_t p_in_batch = patch_idx % (dst_w * dst_h * dst_d);
|
||||
const int64_t batch_idx = patch_idx / (dst_w * dst_h * dst_d);
|
||||
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
||||
const int64_t dst_y = (p_in_batch % (dst_w * dst_h)) / dst_w;
|
||||
const int64_t dst_x = p_in_batch % dst_w;
|
||||
const int64_t ocn_idx = batch_idx * oc + out_ch;
|
||||
|
||||
const int64_t dst_offset = dst_x * dst_nb0 + dst_y * dst_nb1 + dst_z * dst_nb2 + ocn_idx * dst_nb3;
|
||||
// `src` is a column-major (m x n) GEMM output where m == patch_total, n == oc.
|
||||
// GEMM stores element (row, col) at index `row + col*m`, so compute index accordingly.
|
||||
const int64_t src_index = patch_idx + out_ch * patch_total;
|
||||
const float value = src[src_index];
|
||||
*(float *)((char *)dst_data + dst_offset) = value;
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
|
||||
const int32_t * opts = (const int32_t *) dst->op_params;
|
||||
const int32_t s0 = opts[0];
|
||||
const int32_t s1 = opts[1];
|
||||
const int32_t s2 = opts[2];
|
||||
const int32_t p0 = opts[3];
|
||||
const int32_t p1 = opts[4];
|
||||
const int32_t p2 = opts[5];
|
||||
const int32_t d0 = opts[6];
|
||||
const int32_t d1 = opts[7];
|
||||
const int32_t d2 = opts[8];
|
||||
const int32_t c = opts[9];
|
||||
const int32_t n = opts[10];
|
||||
const int32_t oc = opts[11];
|
||||
|
||||
const int64_t knl_w = src0->ne[0];
|
||||
const int64_t knl_h = src0->ne[1];
|
||||
const int64_t knl_d = src0->ne[2];
|
||||
|
||||
const int64_t patch_total = ggml_sycl_conv3d_calc_patch_total(dst, n);
|
||||
const int64_t knl_n_total = ggml_sycl_conv3d_calc_knl_n_total(src0, c);
|
||||
|
||||
const size_t kernel_type_size = ggml_element_size(src0);
|
||||
|
||||
ggml_sycl_pool_alloc<float> gemm_output(ctx.pool());
|
||||
gemm_output.alloc((size_t) patch_total * oc);
|
||||
|
||||
ggml_tensor dst_mat = {};
|
||||
dst_mat.type = GGML_TYPE_F32;
|
||||
dst_mat.ne[0] = patch_total;
|
||||
dst_mat.ne[1] = oc;
|
||||
dst_mat.ne[2] = 1;
|
||||
dst_mat.ne[3] = 1;
|
||||
dst_mat.nb[0] = sizeof(float);
|
||||
dst_mat.nb[1] = dst_mat.nb[0] * dst_mat.ne[0];
|
||||
dst_mat.nb[2] = dst_mat.nb[1];
|
||||
dst_mat.nb[3] = dst_mat.nb[2];
|
||||
dst_mat.data = gemm_output.get();
|
||||
dst_mat.buffer = dst->buffer;
|
||||
dst_mat.extra = dst->extra;
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
// allocate packed arrays: A_packed (k x m), B_packed (k x n)
|
||||
ggml_sycl_pool_alloc<float> A_packed_alloc(ctx.pool());
|
||||
ggml_sycl_pool_alloc<float> B_packed_alloc(ctx.pool());
|
||||
A_packed_alloc.alloc((size_t) knl_n_total * patch_total * sizeof(float));
|
||||
B_packed_alloc.alloc((size_t) knl_n_total * oc * sizeof(float));
|
||||
|
||||
float * A_packed = A_packed_alloc.get();
|
||||
float * B_packed = B_packed_alloc.get();
|
||||
|
||||
const int m = (int) patch_total;
|
||||
const int n_gemm = (int) oc;
|
||||
const int k = (int) knl_n_total;
|
||||
|
||||
// Combined kernel: im2col -> pack A, and pack B simultaneously
|
||||
const char * src1_base = (const char *) src1->data;
|
||||
const int64_t src1_nb0 = src1->nb[0];
|
||||
const int64_t src1_nb1 = src1->nb[1];
|
||||
const int64_t src1_nb2 = src1->nb[2];
|
||||
const int64_t src1_nb3 = src1->nb[3];
|
||||
|
||||
// Compute correct strides for src0 as (knl_n_total, oc) matrix
|
||||
const int64_t src0_packed_nb0 = kernel_type_size;
|
||||
const int64_t src0_packed_nb1 = kernel_type_size * knl_n_total;
|
||||
|
||||
const int64_t KW = knl_w;
|
||||
const int64_t KH = knl_h;
|
||||
const int64_t KD = knl_d;
|
||||
const int64_t PW = dst->ne[0];
|
||||
const int64_t PH = dst->ne[1];
|
||||
const int64_t PD = dst->ne[2];
|
||||
|
||||
// Pack A (with inline im2col): for each (row, col) in k x m matrix
|
||||
const int64_t A_total = (int64_t)k * m;
|
||||
const int64_t A_block_size = 256;
|
||||
const int64_t A_num_work = ((A_total + A_block_size - 1) / A_block_size) * A_block_size;
|
||||
|
||||
stream->parallel_for(sycl::range<1>(A_num_work), [=](sycl::id<1> id) {
|
||||
const int64_t t = id[0];
|
||||
if (t >= A_total) return;
|
||||
|
||||
const int64_t row = t % k;
|
||||
const int64_t col = t / k;
|
||||
|
||||
// Inline im2col for this element
|
||||
const int64_t k_index = row;
|
||||
const int64_t patch_idx = col;
|
||||
|
||||
const int64_t ic = k_index / (KD * KH * KW);
|
||||
const int64_t rem = k_index - ic * (KD * KH * KW);
|
||||
const int64_t kz = rem / (KH * KW);
|
||||
const int64_t rem2 = rem - kz * (KH * KW);
|
||||
const int64_t ky = rem2 / KW;
|
||||
const int64_t kx = rem2 % KW;
|
||||
|
||||
const int64_t p_in_batch = patch_idx % (PW * PH * PD);
|
||||
const int64_t batch_idx = patch_idx / (PW * PH * PD);
|
||||
const int64_t dst_z = p_in_batch / (PW * PH);
|
||||
const int64_t dst_y = (p_in_batch % (PW * PH)) / PW;
|
||||
const int64_t dst_x = p_in_batch % PW;
|
||||
|
||||
const int64_t sx = dst_x * s0 + kx * d0 - p0;
|
||||
const int64_t sy = dst_y * s1 + ky * d1 - p1;
|
||||
const int64_t sz = dst_z * s2 + kz * d2 - p2;
|
||||
|
||||
float val = 0.0f;
|
||||
if (sx >= 0 && sx < src1->ne[0] && sy >= 0 && sy < src1->ne[1] && sz >= 0 && sz < src1->ne[2]) {
|
||||
const int64_t channel_idx = batch_idx * c + ic;
|
||||
const char * ptr = src1_base + sx * src1_nb0 + sy * src1_nb1 + sz * src1_nb2 + channel_idx * src1_nb3;
|
||||
val = *(const float *) ptr;
|
||||
}
|
||||
A_packed[row + col * (int64_t)k] = val;
|
||||
});
|
||||
|
||||
// Pack B: for each (row, col) in k x n_gemm matrix
|
||||
const int64_t B_total = (int64_t)k * n_gemm;
|
||||
const int64_t B_block_size = 256;
|
||||
const int64_t B_num_work = ((B_total + B_block_size - 1) / B_block_size) * B_block_size;
|
||||
|
||||
stream->parallel_for(sycl::range<1>(B_num_work), [=](sycl::id<1> id) {
|
||||
const int64_t t = id[0];
|
||||
if (t >= B_total) return;
|
||||
|
||||
const int64_t row = t % k;
|
||||
const int64_t col = t / k;
|
||||
const char * src_ptr = (const char *) src0->data + row * src0_packed_nb0 + col * src0_packed_nb1;
|
||||
float v;
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
v = *(const float *) src_ptr;
|
||||
} else {
|
||||
v = sycl::vec<sycl::half, 1>(*(const sycl::half *) src_ptr).convert<float, sycl::rounding_mode::automatic>()[0];
|
||||
}
|
||||
B_packed[row + col * (int64_t)k] = v;
|
||||
});
|
||||
|
||||
// GEMM: C = A^T * B where A is (k x m), B is (k x n), C is (m x n)
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
const int lda = k;
|
||||
const int ldb = k;
|
||||
const int ldc = m;
|
||||
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
||||
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans,
|
||||
m, n_gemm, k,
|
||||
dpct::get_value(&alpha, *stream),
|
||||
(const float *) A_packed, lda,
|
||||
(const float *) B_packed, ldb,
|
||||
dpct::get_value(&beta, *stream),
|
||||
(float *) dst_mat.data, ldc)));
|
||||
|
||||
const float * gemm_data = (const float *) dst_mat.data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
|
||||
ggml_sycl_conv3d_write_output(dst, gemm_data, dst_data, patch_total, oc,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], stream);
|
||||
}
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#ifndef GGML_SYCL_CONV3D_HPP
|
||||
#define GGML_SYCL_CONV3D_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_CONV3D_HPP
|
||||
|
|
@ -4572,6 +4572,11 @@ static void ggml_sycl_im2col_3d(ggml_backend_sycl_context & ctx, ggml_tensor * d
|
|||
ggml_sycl_op_im2col_3d(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_conv_3d(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||
|
|
@ -4638,6 +4643,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
ggml_sycl_op_conv_transpose_1d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONV_3D:
|
||||
ggml_sycl_conv_3d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_REPEAT:
|
||||
ggml_sycl_repeat(ctx, dst);
|
||||
break;
|
||||
|
|
@ -5615,6 +5623,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_IM2COL_3D:
|
||||
case GGML_OP_UPSCALE:
|
||||
return true;
|
||||
case GGML_OP_CONV_3D:
|
||||
return op->type == GGML_TYPE_F32 &&
|
||||
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op->src[1]);
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
|
|
|
|||
Loading…
Reference in New Issue