ggml : extend bin bcast for permuted src1 (llama/19484)
* tests : extend bin bcast for permuted src1 * cont : extend bin support * cont : s0 is always 1 * tests : simplify
This commit is contained in:
parent
de949fb1db
commit
3504358056
|
|
@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
|
|||
GGML_ASSERT(nb00 == sizeof(src0_t));
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
|
||||
|
||||
if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
}
|
||||
const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1);
|
||||
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
vDSP_fn_t vDSP_op = nullptr;
|
||||
|
|
@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
|
|||
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
||||
|
||||
if (is_src1_contiguous) {
|
||||
if (is_src1_contiguous_rows) {
|
||||
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
||||
const int64_t nr0 = ne00 / ne10;
|
||||
|
||||
|
|
|
|||
|
|
@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
|||
const uint3 ne11,
|
||||
const uint3 ne12,
|
||||
const uint3 ne13,
|
||||
/*int s0, */ const int s1,
|
||||
/*const int s0,*/
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
/*int s00,*/ const int s01,
|
||||
const int s00,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
/*int s10,*/ const int s11,
|
||||
const int s10,
|
||||
const int s11,
|
||||
const int s12,
|
||||
const int s13,
|
||||
src1_ptrs... src1s) {
|
||||
|
|
@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
|||
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
|
||||
const uint32_t i10 = fastmodulo(i0, ne10);
|
||||
|
||||
float result = src0_row ? (float) src0_row[i0] : 0.0f;
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10]);
|
||||
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
|
|
@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
|||
const uint3 ne11,
|
||||
const uint3 ne12,
|
||||
const uint3 ne13,
|
||||
/*int s0, */ const int s1,
|
||||
/*const int s0,*/
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
/*int s00,*/ const int s01,
|
||||
const int s00,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
/*int s10,*/ const int s11,
|
||||
const int s10,
|
||||
const int s11,
|
||||
const int s12,
|
||||
const int s13,
|
||||
src1_ptrs... src1s) {
|
||||
|
|
@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
|||
|
||||
const int i10 = fastmodulo(i0, ne10);
|
||||
|
||||
float result = src0_row ? (float) src0_row[i0] : 0.0f;
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10]);
|
||||
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
|
|
@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
|||
cnb[3] *= cne[3];
|
||||
};
|
||||
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (nr[i] != 1) {
|
||||
break;
|
||||
|
|
@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
|||
size_t nb12 = cnb1[2];
|
||||
size_t nb13 = cnb1[3];
|
||||
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
//size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
|
|
@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
|||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
GGML_ASSERT(s0 == 1);
|
||||
GGML_ASSERT(s00 == 1);
|
||||
GGML_ASSERT(s10 == 1);
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
|
||||
|
|
@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
|||
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
|
||||
ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
} else {
|
||||
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
|
||||
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
|
||||
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13);
|
||||
}
|
||||
} else {
|
||||
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
|
||||
if constexpr (sizeof...(I) > 0) {
|
||||
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00 ,s01, s02, s03,
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
} else {
|
||||
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue