whisper : clean-up ggml_mul_mat_pad
This commit is contained in:
parent
2b4160af29
commit
0d5e4cdc36
|
|
@ -77,7 +77,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
|
|||
|
||||
// same as ggml_graph_compute but uses Metal
|
||||
// creates gf->n_threads command buffers in parallel
|
||||
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
|
||||
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool concurrent);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -610,14 +610,15 @@ void ggml_metal_graph_find_concurrency(
|
|||
|
||||
void ggml_metal_graph_compute(
|
||||
struct ggml_metal_context * ctx,
|
||||
struct ggml_cgraph * gf) {
|
||||
struct ggml_cgraph * gf,
|
||||
bool concurrent) {
|
||||
@autoreleasepool {
|
||||
|
||||
// if `concurrent` is requested and ctx->concur_list is non-empty (and within GGML_MAX_CONCUR), dispatch concurrently
|
||||
// else fallback to serial dispatch
|
||||
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
||||
|
||||
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
|
||||
const bool has_concur = concurrent && ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
|
||||
|
||||
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
|
||||
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
|
||||
|
|
@ -927,7 +928,7 @@ void ggml_metal_graph_compute(
|
|||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
||||
//} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
} else if (false) {
|
||||
// TODO: with the ggml_cont(ctx0, Q), this kernel is no longer useful
|
||||
// TODO: with ggml_mul_mat_pad this kernel no longer seems to be needed
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
||||
nrows = ne11;
|
||||
} else {
|
||||
|
|
|
|||
69
whisper.cpp
69
whisper.cpp
|
|
@ -141,20 +141,20 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
|
|||
//
|
||||
// Z = X @ Y
|
||||
//
|
||||
// with two matrix multiplications:
|
||||
// with the sum of two matrix multiplications:
|
||||
//
|
||||
// Z = [X_0; X_1] @ [Y_0; Y_1]
|
||||
// Z = (X_0 @ Y_0) + (X_1 @ Y_1)
|
||||
//
|
||||
// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
|
||||
// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
|
||||
// general-purpose kernels
|
||||
//
|
||||
static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
|
||||
//#if !defined(GGML_USE_METAL)
|
||||
// return ggml_mul_mat(ctx, x, y);
|
||||
//#endif
|
||||
// use padding only if dimension 0 is not already a multiple of the padding and is at least 8 times larger than it
|
||||
// else we won't get much benefit from the optimization
|
||||
const int n_pad_req = 8;
|
||||
|
||||
if (x->ne[0] % pad == 0 || x->ne[0] / pad < 2) {
|
||||
if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
|
||||
return ggml_mul_mat(ctx, x, y);
|
||||
}
|
||||
|
||||
|
|
@ -169,6 +169,11 @@ static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct g
|
|||
ggml_mul_mat(ctx, x_1, y_1));
|
||||
}
|
||||
|
||||
// TODO: check if other platforms can benefit from this optimization
|
||||
#if defined(GGML_USE_METAL)
|
||||
#define ggml_mul_mat ggml_mul_mat_pad
|
||||
#endif
|
||||
|
||||
// available whisper models
|
||||
enum e_model {
|
||||
MODEL_UNKNOWN,
|
||||
|
|
@ -1659,7 +1664,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
||||
layer.attn_q_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -1668,13 +1673,13 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||
//Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
||||
|
||||
// note: no bias for Key
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat_pad(ctx0,
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
||||
layer.attn_k_w,
|
||||
cur);
|
||||
|
||||
//Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat_pad(ctx0,
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
||||
layer.attn_v_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -1723,7 +1728,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||
0, 2, 1, 3);
|
||||
|
||||
// K * Q
|
||||
struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, K, Q);
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
|
||||
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
|
||||
|
||||
|
|
@ -1739,7 +1744,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
|
||||
);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||
#endif
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
||||
|
|
@ -1750,7 +1755,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||
|
||||
// projection
|
||||
{
|
||||
cur = ggml_mul_mat_pad(ctx0,
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.attn_ln_1_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -1780,7 +1785,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
||||
#else
|
||||
// fully connected
|
||||
cur = ggml_mul_mat_pad(ctx0,
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.mlp_0_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -1790,7 +1795,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
// projection
|
||||
cur = ggml_mul_mat_pad(ctx0,
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.mlp_1_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -1868,13 +1873,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|||
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
||||
auto & layer = model.layers_decoder[il];
|
||||
|
||||
struct ggml_tensor* Kcross = ggml_mul_mat_pad(ctx0,
|
||||
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
|
||||
layer.cross_attn_k_w,
|
||||
cur);
|
||||
|
||||
Kcross = ggml_scale(ctx0, Kcross, Kscale);
|
||||
|
||||
struct ggml_tensor* Vcross = ggml_mul_mat_pad(ctx0,
|
||||
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
|
||||
layer.cross_attn_v_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -1948,7 +1953,7 @@ static bool whisper_encode_internal(
|
|||
#ifdef GGML_USE_METAL
|
||||
if (wstate.ctx_metal) {
|
||||
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
||||
ggml_metal_graph_compute(wstate.ctx_metal, gf, false);
|
||||
} else {
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
}
|
||||
|
|
@ -1970,7 +1975,7 @@ static bool whisper_encode_internal(
|
|||
#ifdef GGML_USE_METAL
|
||||
if (wstate.ctx_metal) {
|
||||
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
||||
ggml_metal_graph_compute(wstate.ctx_metal, gf, false);
|
||||
} else {
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
}
|
||||
|
|
@ -2071,7 +2076,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
||||
layer.attn_q_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -2082,7 +2087,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
Qcur = ggml_scale(ctx0, Qcur, KQscale);
|
||||
|
||||
// note: no bias for Key
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat_pad(ctx0,
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
||||
layer.attn_k_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -2090,7 +2095,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
|
||||
// store key and value to memory
|
||||
{
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat_pad(ctx0,
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
||||
layer.attn_v_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -2124,7 +2129,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
||||
|
||||
// K * Q
|
||||
struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, K, Q);
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
|
||||
//struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
|
||||
|
||||
|
|
@ -2139,7 +2144,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
|
||||
il*n_ctx*ggml_element_size(kv_self.v)*n_state);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
||||
|
|
@ -2150,7 +2155,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
|
||||
// projection
|
||||
{
|
||||
cur = ggml_mul_mat_pad(ctx0,
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.attn_ln_1_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -2176,7 +2181,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
|
||||
// cross-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
||||
layer.cross_attn_q_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -2219,7 +2224,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
0, 2, 1, 3);
|
||||
|
||||
// K * Q
|
||||
struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, Kcross, Q);
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
|
||||
|
||||
//struct ggml_tensor * KQ_scaled =
|
||||
// ggml_scale(ctx0,
|
||||
|
|
@ -2232,7 +2237,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
||||
|
|
@ -2244,7 +2249,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
|
||||
// projection
|
||||
{
|
||||
cur = ggml_mul_mat_pad(ctx0,
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.cross_attn_ln_1_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -2273,7 +2278,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
}
|
||||
|
||||
// fully connected
|
||||
cur = ggml_mul_mat_pad(ctx0,
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.mlp_0_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -2285,7 +2290,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
// projection
|
||||
cur = ggml_mul_mat_pad(ctx0,
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.mlp_1_w,
|
||||
cur);
|
||||
|
||||
|
|
@ -2315,7 +2320,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||
// might be useful in the future
|
||||
cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
|
||||
|
||||
struct ggml_tensor * logits = ggml_mul_mat_pad(ctx0, model.d_te, cur);
|
||||
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
|
||||
|
||||
ggml_build_forward_expand(gf, logits);
|
||||
|
||||
|
|
@ -2368,7 +2373,7 @@ static bool whisper_decode_internal(
|
|||
#ifdef GGML_USE_METAL
|
||||
if (wstate.ctx_metal) {
|
||||
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
||||
ggml_metal_graph_compute(wstate.ctx_metal, gf, false);
|
||||
} else {
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue