From 0d5e4cdc36d8e80fcb9858300eac034fdddc7b60 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 14 Sep 2023 17:28:13 +0300
Subject: [PATCH] whisper : clean-up ggml_mul_mat_pad

---
 ggml-metal.h |  2 +-
 ggml-metal.m |  7 +++---
 whisper.cpp  | 69 ++++++++++++++++++++++++++++------------------------
 3 files changed, 42 insertions(+), 36 deletions(-)

diff --git a/ggml-metal.h b/ggml-metal.h
index fca28d37..4e36cc12 100644
--- a/ggml-metal.h
+++ b/ggml-metal.h
@@ -77,7 +77,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
 
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
-void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool concurrent);
 
 #ifdef __cplusplus
 }
diff --git a/ggml-metal.m b/ggml-metal.m
index 059da6ee..7ec31c21 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -610,14 +610,15 @@ void ggml_metal_graph_find_concurrency(
 
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
-        struct ggml_cgraph * gf) {
+        struct ggml_cgraph * gf,
+        bool concurrent) {
     @autoreleasepool {
 
     // if there is ctx->concur_list, dispatch concurrently
     // else fallback to serial dispatch
     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
 
-    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
+    const bool has_concur = concurrent && ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
 
     const int n_nodes  = has_concur ? ctx->concur_list_len : gf->n_nodes;
     edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
@@ -927,7 +928,7 @@ void ggml_metal_graph_compute(
                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
                         //} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                         } else if (false) {
-                            // TODO: with the ggml_cont(ctx0, Q), this kernel is no longer useful
+                            // TODO: with ggml_mul_mat_pad this kernel no longer seems to be needed
                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
                             nrows = ne11;
                         } else {
diff --git a/whisper.cpp b/whisper.cpp
index a8b64c1e..a79689d0 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -141,20 +141,20 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
 //
 //   Z = X @ Y
 //
-// with two matrix multiplications:
+// with the sum of two matrix multiplications:
 //
-//   Z = [X_0; X_1] @ [Y_0; Y_1]
+//   Z = (X_0 @ Y_0) + (X_1 @ Y_1)
 //
 // here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
 // and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
 // general-purpose kernels
 //
 static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
-//#if !defined(GGML_USE_METAL)
-//    return ggml_mul_mat(ctx, x, y);
-//#endif
+    // use padding only if dimension 0 is at least 8 times larger than the padding
+    // else we won't get much benefit from the optimization
+    const int n_pad_req = 8;
 
-    if (x->ne[0] % pad == 0 || x->ne[0] / pad < 2) {
+    if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
         return ggml_mul_mat(ctx, x, y);
     }
 
@@ -169,6 +169,11 @@ static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct g
             ggml_mul_mat(ctx, x_1, y_1));
 }
 
+// TODO: check if other platforms can benefit from this optimization
+#if defined(GGML_USE_METAL)
+#define ggml_mul_mat ggml_mul_mat_pad
+#endif
+
 // available whisper models
 enum e_model {
     MODEL_UNKNOWN,
@@ -1659,7 +1664,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
+            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.attn_q_w,
                     cur);
 
@@ -1668,13 +1673,13 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
             //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
             // note: no bias for Key
-            struct ggml_tensor * Kcur = ggml_mul_mat_pad(ctx0,
+            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                     layer.attn_k_w,
                     cur);
 
             //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
-            struct ggml_tensor * Vcur = ggml_mul_mat_pad(ctx0,
+            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
                     layer.attn_v_w,
                     cur);
 
@@ -1723,7 +1728,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                         0, 2, 1, 3);
 
             // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, K, Q);
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
             struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
 
@@ -1739,7 +1744,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                         ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
                         );
 
-            struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 #endif
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -1750,7 +1755,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
         // projection
         {
-            cur = ggml_mul_mat_pad(ctx0,
+            cur = ggml_mul_mat(ctx0,
                     layer.attn_ln_1_w,
                     cur);
 
@@ -1780,7 +1785,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                     layer.mlp_0_w, layer.mlp_0_b,
                     layer.mlp_1_w, layer.mlp_1_b);
 #else
             // fully connected
-            cur = ggml_mul_mat_pad(ctx0,
+            cur = ggml_mul_mat(ctx0,
                     layer.mlp_0_w,
                     cur);
@@ -1790,7 +1795,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
             cur = ggml_gelu(ctx0, cur);
 
             // projection
-            cur = ggml_mul_mat_pad(ctx0,
+            cur = ggml_mul_mat(ctx0,
                     layer.mlp_1_w,
                     cur);
@@ -1868,13 +1873,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
     for (int il = 0; il < model.hparams.n_text_layer; ++il) {
         auto & layer = model.layers_decoder[il];
 
-        struct ggml_tensor* Kcross = ggml_mul_mat_pad(ctx0,
+        struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
                 layer.cross_attn_k_w,
                 cur);
 
         Kcross = ggml_scale(ctx0, Kcross, Kscale);
 
-        struct ggml_tensor* Vcross = ggml_mul_mat_pad(ctx0,
+        struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
                 layer.cross_attn_v_w,
                 cur);
 
@@ -1948,7 +1953,7 @@ static bool whisper_encode_internal(
 #ifdef GGML_USE_METAL
         if (wstate.ctx_metal) {
             ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
-            ggml_metal_graph_compute(wstate.ctx_metal, gf);
+            ggml_metal_graph_compute(wstate.ctx_metal, gf, false);
         } else {
             ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
         }
@@ -1970,7 +1975,7 @@ static bool whisper_encode_internal(
 #ifdef GGML_USE_METAL
     if (wstate.ctx_metal) {
         ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
-        ggml_metal_graph_compute(wstate.ctx_metal, gf);
+        ggml_metal_graph_compute(wstate.ctx_metal, gf, false);
     } else {
         ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
     }
@@ -2071,7 +2076,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
+            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.attn_q_w,
                     cur);
 
@@ -2082,7 +2087,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
             Qcur = ggml_scale(ctx0, Qcur, KQscale);
 
             // note: no bias for Key
-            struct ggml_tensor * Kcur = ggml_mul_mat_pad(ctx0,
+            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                     layer.attn_k_w,
                     cur);
 
@@ -2090,7 +2095,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             // store key and value to memory
             {
-                struct ggml_tensor * Vcur = ggml_mul_mat_pad(ctx0,
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
                         layer.attn_v_w,
                         cur);
 
@@ -2124,7 +2129,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         ggml_element_size(kv_self.k)*n_state*n_ctx*il);
 
             // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, K, Q);
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
             //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
 
@@ -2139,7 +2144,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
                         il*n_ctx*ggml_element_size(kv_self.v)*n_state);
 
-            struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
@@ -2150,7 +2155,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
         // projection
         {
-            cur = ggml_mul_mat_pad(ctx0,
+            cur = ggml_mul_mat(ctx0,
                     layer.attn_ln_1_w,
                     cur);
 
@@ -2176,7 +2181,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
         // cross-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
+            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.cross_attn_q_w,
                     cur);
 
@@ -2219,7 +2224,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         0, 2, 1, 3);
 
             // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, Kcross, Q);
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
 
             //struct ggml_tensor * KQ_scaled =
             //    ggml_scale(ctx0,
@@ -2232,7 +2237,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
-            struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
@@ -2244,7 +2249,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
         // projection
         {
-            cur = ggml_mul_mat_pad(ctx0,
+            cur = ggml_mul_mat(ctx0,
                     layer.cross_attn_ln_1_w,
                     cur);
 
@@ -2273,7 +2278,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
         }
 
         // fully connected
-        cur = ggml_mul_mat_pad(ctx0,
+        cur = ggml_mul_mat(ctx0,
                 layer.mlp_0_w,
                 cur);
 
@@ -2285,7 +2290,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
         cur = ggml_gelu(ctx0, cur);
 
         // projection
-        cur = ggml_mul_mat_pad(ctx0,
+        cur = ggml_mul_mat(ctx0,
                 layer.mlp_1_w,
                 cur);
 
@@ -2315,7 +2320,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     // might be useful in the future
     cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
 
-    struct ggml_tensor * logits = ggml_mul_mat_pad(ctx0, model.d_te, cur);
+    struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
     ggml_build_forward_expand(gf, logits);
 
@@ -2368,7 +2373,7 @@ static bool whisper_decode_internal(
 #ifdef GGML_USE_METAL
     if (wstate.ctx_metal) {
         ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
-        ggml_metal_graph_compute(wstate.ctx_metal, gf);
+        ggml_metal_graph_compute(wstate.ctx_metal, gf, false);
     } else {
         ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
     }
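
Note: the identity behind ggml_mul_mat_pad can be sanity-checked outside of
ggml. Below is a minimal standalone C sketch (the names, sizes, and tolerance
are illustrative only and not part of the patch or of ggml's API) that splits
the shared dimension 0 of length K into a pad-aligned part of length
(K/pad)*pad and a remainder, and verifies that the sum of the two partial
products matches the direct product:

    // check: Z = X @ Y  ==  (X_0 @ Y_0) + (X_1 @ Y_1)
    // following ggml's convention that dimension 0 is the shared (contracted)
    // dimension, i.e. X is [K, M] and Y is [K, N] with K stored contiguously
    #include <math.h>
    #include <stdio.h>

    #define M   7   // dimension 1 of X
    #define N   5   // dimension 1 of Y
    #define K   70  // dimension 0 (shared), deliberately not divisible by PAD
    #define PAD 32

    int main(void) {
        static float x[M][K], y[N][K];

        // fill X and Y with arbitrary deterministic values
        for (int i = 0; i < M; i++)
            for (int k = 0; k < K; k++)
                x[i][k] = 0.50f*(float)(i + k % 3);
        for (int j = 0; j < N; j++)
            for (int k = 0; k < K; k++)
                y[j][k] = 0.25f*(float)(j - k % 5);

        const int k0 = (K/PAD)*PAD; // X_0, Y_0 cover [0, k0); X_1, Y_1 cover [k0, K)

        for (int j = 0; j < N; j++) {
            for (int i = 0; i < M; i++) {
                float z = 0.0f, z_0 = 0.0f, z_1 = 0.0f;
                for (int k = 0;  k < K;  k++) z   += x[i][k]*y[j][k]; // Z = X @ Y
                for (int k = 0;  k < k0; k++) z_0 += x[i][k]*y[j][k]; // X_0 @ Y_0
                for (int k = k0; k < K;  k++) z_1 += x[i][k]*y[j][k]; // X_1 @ Y_1
                // the two orderings only differ by float rounding
                if (fabsf(z - (z_0 + z_1)) > 1e-2f) {
                    printf("mismatch at (%d, %d)\n", i, j);
                    return 1;
                }
            }
        }

        printf("ok: split matmul matches direct matmul\n");
        return 0;
    }

In the patch, this split is what lets the pad-aligned views X_0 and Y_0 go
through the faster Metal matrix-multiplication kernels, while the small
remainder views X_1 and Y_1 fall back to the more general-purpose kernels.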