ggml-zendnn : fixed naming of matmul function (llama/20964)
* ggml-zendnn: fixed naming of matmul function
* ggml-zendnn: fixed naming of mul_mat_id function
* ggml-zendnn: fixed print in mul_mat_id

---------

Co-authored-by: plotnikov.v10 <plotnikov.v10@wb.ru>
This commit is contained in:
parent
a0efd13f0f
commit
6a249cd640
|
|
@ -88,7 +88,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k,
|
static bool ggml_zendnn_gemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k,
|
||||||
const void * A, int64_t lda, const void * B, int64_t ldb, void * C,
|
const void * A, int64_t lda, const void * B, int64_t ldb, void * C,
|
||||||
int64_t ldc, int Atype, int Btype, int Ctype) {
|
int64_t ldc, int Atype, int Btype, int Ctype) {
|
||||||
|
|
||||||
|
|
@ -200,7 +200,7 @@ static void ggml_zendnn_compute_forward_mul_mat(
|
||||||
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||||
const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data;
|
const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data;
|
||||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||||
if (!ggml_zendnn_sgemm(ctx,
|
if (!ggml_zendnn_gemm(ctx,
|
||||||
ne01, // m
|
ne01, // m
|
||||||
ne11, // n
|
ne11, // n
|
||||||
ne10, // k
|
ne10, // k
|
||||||
|
|
@ -213,7 +213,7 @@ static void ggml_zendnn_compute_forward_mul_mat(
|
||||||
src0->type,
|
src0->type,
|
||||||
src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type,
|
src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type,
|
||||||
dst->type))
|
dst->type))
|
||||||
GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
|
GGML_ABORT("%s: ZenDNN gemm failed\n", __func__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -355,7 +355,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
|
||||||
}
|
}
|
||||||
|
|
||||||
// batched gemm for all tokens in this expert
|
// batched gemm for all tokens in this expert
|
||||||
if (!ggml_zendnn_sgemm(ctx,
|
if (!ggml_zendnn_gemm(ctx,
|
||||||
ne01, // m
|
ne01, // m
|
||||||
cne1, // n
|
cne1, // n
|
||||||
ne10, // k
|
ne10, // k
|
||||||
|
|
@ -368,7 +368,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
|
||||||
src0->type,
|
src0->type,
|
||||||
src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type,
|
src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type,
|
||||||
dst->type)) {
|
dst->type)) {
|
||||||
GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
|
GGML_ABORT("%s: ZenDNN gemm failed\n", __func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
// scatter output rows to destination
|
// scatter output rows to destination
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue