whisper : clean-up ggml_mul_mat_pad
This commit is contained in:
parent
2b4160af29
commit
0d5e4cdc36
|
|
@ -77,7 +77,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
|
||||||
|
|
||||||
// same as ggml_graph_compute but uses Metal
|
// same as ggml_graph_compute but uses Metal
|
||||||
// creates gf->n_threads command buffers in parallel
|
// creates gf->n_threads command buffers in parallel
|
||||||
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
|
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool concurrent);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -610,14 +610,15 @@ void ggml_metal_graph_find_concurrency(
|
||||||
|
|
||||||
void ggml_metal_graph_compute(
|
void ggml_metal_graph_compute(
|
||||||
struct ggml_metal_context * ctx,
|
struct ggml_metal_context * ctx,
|
||||||
struct ggml_cgraph * gf) {
|
struct ggml_cgraph * gf,
|
||||||
|
bool concurrent) {
|
||||||
@autoreleasepool {
|
@autoreleasepool {
|
||||||
|
|
||||||
// if there is ctx->concur_list, dispatch concurrently
|
// if there is ctx->concur_list, dispatch concurrently
|
||||||
// else fallback to serial dispatch
|
// else fallback to serial dispatch
|
||||||
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
||||||
|
|
||||||
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
|
const bool has_concur = concurrent && ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
|
||||||
|
|
||||||
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
|
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
|
||||||
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
|
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
|
||||||
|
|
@ -927,7 +928,7 @@ void ggml_metal_graph_compute(
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
||||||
//} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
//} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||||
} else if (false) {
|
} else if (false) {
|
||||||
// TODO: with the ggml_cont(ctx0, Q), this kernel is no longer useful
|
// TODO: with ggml_mul_mat_pad this kernel no longer seems to be needed
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
||||||
nrows = ne11;
|
nrows = ne11;
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
69
whisper.cpp
69
whisper.cpp
|
|
@ -141,20 +141,20 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
|
||||||
//
|
//
|
||||||
// Z = X @ Y
|
// Z = X @ Y
|
||||||
//
|
//
|
||||||
// with two matrix multiplications:
|
// with the sum of two matrix multiplications:
|
||||||
//
|
//
|
||||||
// Z = [X_0; X_1] @ [Y_0; Y_1]
|
// Z = (X_0 @ Y_0) + (X_1 @ Y_1)
|
||||||
//
|
//
|
||||||
// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
|
// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
|
||||||
// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
|
// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
|
||||||
// general-purpose kernels
|
// general-purpose kernels
|
||||||
//
|
//
|
||||||
static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
|
static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
|
||||||
//#if !defined(GGML_USE_METAL)
|
// use padding only if dimension 0 is at least 8 times larger than the padding
|
||||||
// return ggml_mul_mat(ctx, x, y);
|
// else we won't get much benefit from the optimization
|
||||||
//#endif
|
const int n_pad_req = 8;
|
||||||
|
|
||||||
if (x->ne[0] % pad == 0 || x->ne[0] / pad < 2) {
|
if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
|
||||||
return ggml_mul_mat(ctx, x, y);
|
return ggml_mul_mat(ctx, x, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -169,6 +169,11 @@ static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct g
|
||||||
ggml_mul_mat(ctx, x_1, y_1));
|
ggml_mul_mat(ctx, x_1, y_1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: check if other platforms can benefit from this optimization
|
||||||
|
#if defined(GGML_USE_METAL)
|
||||||
|
#define ggml_mul_mat ggml_mul_mat_pad
|
||||||
|
#endif
|
||||||
|
|
||||||
// available whisper models
|
// available whisper models
|
||||||
enum e_model {
|
enum e_model {
|
||||||
MODEL_UNKNOWN,
|
MODEL_UNKNOWN,
|
||||||
|
|
@ -1659,7 +1664,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||||
|
|
||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
|
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_q_w,
|
layer.attn_q_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -1668,13 +1673,13 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||||
//Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
//Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
||||||
|
|
||||||
// note: no bias for Key
|
// note: no bias for Key
|
||||||
struct ggml_tensor * Kcur = ggml_mul_mat_pad(ctx0,
|
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_k_w,
|
layer.attn_k_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
//Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
//Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
||||||
|
|
||||||
struct ggml_tensor * Vcur = ggml_mul_mat_pad(ctx0,
|
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_v_w,
|
layer.attn_v_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -1723,7 +1728,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
|
||||||
// K * Q
|
// K * Q
|
||||||
struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, K, Q);
|
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||||
|
|
||||||
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
|
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
|
||||||
|
|
||||||
|
|
@ -1739,7 +1744,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||||
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
|
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
|
||||||
);
|
);
|
||||||
|
|
||||||
struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||||
#endif
|
#endif
|
||||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
|
||||||
|
|
@ -1750,7 +1755,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||||
|
|
||||||
// projection
|
// projection
|
||||||
{
|
{
|
||||||
cur = ggml_mul_mat_pad(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_ln_1_w,
|
layer.attn_ln_1_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -1780,7 +1785,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||||
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
||||||
#else
|
#else
|
||||||
// fully connected
|
// fully connected
|
||||||
cur = ggml_mul_mat_pad(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.mlp_0_w,
|
layer.mlp_0_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -1790,7 +1795,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||||
cur = ggml_gelu(ctx0, cur);
|
cur = ggml_gelu(ctx0, cur);
|
||||||
|
|
||||||
// projection
|
// projection
|
||||||
cur = ggml_mul_mat_pad(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.mlp_1_w,
|
layer.mlp_1_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -1868,13 +1873,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
||||||
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
||||||
auto & layer = model.layers_decoder[il];
|
auto & layer = model.layers_decoder[il];
|
||||||
|
|
||||||
struct ggml_tensor* Kcross = ggml_mul_mat_pad(ctx0,
|
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
|
||||||
layer.cross_attn_k_w,
|
layer.cross_attn_k_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
Kcross = ggml_scale(ctx0, Kcross, Kscale);
|
Kcross = ggml_scale(ctx0, Kcross, Kscale);
|
||||||
|
|
||||||
struct ggml_tensor* Vcross = ggml_mul_mat_pad(ctx0,
|
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
|
||||||
layer.cross_attn_v_w,
|
layer.cross_attn_v_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -1948,7 +1953,7 @@ static bool whisper_encode_internal(
|
||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_METAL
|
||||||
if (wstate.ctx_metal) {
|
if (wstate.ctx_metal) {
|
||||||
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||||
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
ggml_metal_graph_compute(wstate.ctx_metal, gf, false);
|
||||||
} else {
|
} else {
|
||||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||||
}
|
}
|
||||||
|
|
@ -1970,7 +1975,7 @@ static bool whisper_encode_internal(
|
||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_METAL
|
||||||
if (wstate.ctx_metal) {
|
if (wstate.ctx_metal) {
|
||||||
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||||
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
ggml_metal_graph_compute(wstate.ctx_metal, gf, false);
|
||||||
} else {
|
} else {
|
||||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||||
}
|
}
|
||||||
|
|
@ -2071,7 +2076,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
|
|
||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
|
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_q_w,
|
layer.attn_q_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -2082,7 +2087,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
Qcur = ggml_scale(ctx0, Qcur, KQscale);
|
Qcur = ggml_scale(ctx0, Qcur, KQscale);
|
||||||
|
|
||||||
// note: no bias for Key
|
// note: no bias for Key
|
||||||
struct ggml_tensor * Kcur = ggml_mul_mat_pad(ctx0,
|
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_k_w,
|
layer.attn_k_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -2090,7 +2095,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
|
|
||||||
// store key and value to memory
|
// store key and value to memory
|
||||||
{
|
{
|
||||||
struct ggml_tensor * Vcur = ggml_mul_mat_pad(ctx0,
|
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_v_w,
|
layer.attn_v_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -2124,7 +2129,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
||||||
|
|
||||||
// K * Q
|
// K * Q
|
||||||
struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, K, Q);
|
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||||
|
|
||||||
//struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
|
//struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
|
||||||
|
|
||||||
|
|
@ -2139,7 +2144,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
|
n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
|
||||||
il*n_ctx*ggml_element_size(kv_self.v)*n_state);
|
il*n_ctx*ggml_element_size(kv_self.v)*n_state);
|
||||||
|
|
||||||
struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||||
|
|
||||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
|
||||||
|
|
@ -2150,7 +2155,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
|
|
||||||
// projection
|
// projection
|
||||||
{
|
{
|
||||||
cur = ggml_mul_mat_pad(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_ln_1_w,
|
layer.attn_ln_1_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -2176,7 +2181,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
|
|
||||||
// cross-attention
|
// cross-attention
|
||||||
{
|
{
|
||||||
struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
|
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
||||||
layer.cross_attn_q_w,
|
layer.cross_attn_q_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -2219,7 +2224,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
|
||||||
// K * Q
|
// K * Q
|
||||||
struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, Kcross, Q);
|
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
|
||||||
|
|
||||||
//struct ggml_tensor * KQ_scaled =
|
//struct ggml_tensor * KQ_scaled =
|
||||||
// ggml_scale(ctx0,
|
// ggml_scale(ctx0,
|
||||||
|
|
@ -2232,7 +2237,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
|
|
||||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
|
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
|
||||||
|
|
||||||
struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||||
|
|
||||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
|
||||||
|
|
@ -2244,7 +2249,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
|
|
||||||
// projection
|
// projection
|
||||||
{
|
{
|
||||||
cur = ggml_mul_mat_pad(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.cross_attn_ln_1_w,
|
layer.cross_attn_ln_1_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -2273,7 +2278,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
}
|
}
|
||||||
|
|
||||||
// fully connected
|
// fully connected
|
||||||
cur = ggml_mul_mat_pad(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.mlp_0_w,
|
layer.mlp_0_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -2285,7 +2290,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
cur = ggml_gelu(ctx0, cur);
|
cur = ggml_gelu(ctx0, cur);
|
||||||
|
|
||||||
// projection
|
// projection
|
||||||
cur = ggml_mul_mat_pad(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.mlp_1_w,
|
layer.mlp_1_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
|
|
@ -2315,7 +2320,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
// might be useful in the future
|
// might be useful in the future
|
||||||
cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
|
cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
|
||||||
|
|
||||||
struct ggml_tensor * logits = ggml_mul_mat_pad(ctx0, model.d_te, cur);
|
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, logits);
|
ggml_build_forward_expand(gf, logits);
|
||||||
|
|
||||||
|
|
@ -2368,7 +2373,7 @@ static bool whisper_decode_internal(
|
||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_METAL
|
||||||
if (wstate.ctx_metal) {
|
if (wstate.ctx_metal) {
|
||||||
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||||
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
ggml_metal_graph_compute(wstate.ctx_metal, gf, false);
|
||||||
} else {
|
} else {
|
||||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue