diff --git a/whisper.cpp b/whisper.cpp index 1054a28b..2d164180 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -445,7 +445,7 @@ static void whisper_batch_free(struct whisper_batch batch) { if (batch.logits) free(batch.logits); } -static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past) { +static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) { batch.n_tokens = n_tokens; for (int i = 0; i < n_tokens; ++i) { if (tokens) { @@ -453,7 +453,7 @@ static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token } batch.pos [i] = n_past + i; batch.n_seq_id[i] = 1; - batch.seq_id [i][0] = 0; + batch.seq_id [i][0] = seq_id; batch.logits [i] = 0; } batch.logits[n_tokens - 1] = 1; @@ -654,11 +654,11 @@ struct whisper_partial_utf8 { }; struct whisper_grammar { - /*const*/ std::vector<std::vector<whisper_grammar_element>> rules; - std::vector<std::vector<const whisper_grammar_element *>> stacks; + /*const*/ std::vector<std::vector<whisper_grammar_element>> rules; + std::vector<std::vector<const whisper_grammar_element *>> stacks; // buffer for partially generated UTF-8 sequence from accepted tokens - whisper_partial_utf8 partial_utf8; + whisper_partial_utf8 partial_utf8; }; struct whisper_grammar_candidate { @@ -682,9 +682,6 @@ struct whisper_sequence { // TAGS: WHISPER_DECODER_INIT struct whisper_decoder { - // each decoder keeps its own KV-cache - whisper_kv_cache kv_self; - // the currently generated sequence of tokens whisper_sequence sequence; @@ -701,8 +698,6 @@ struct whisper_decoder { std::vector<float> probs; std::vector<float> logits; std::vector<float> logprobs; - - std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls }; // replace std::pair by using customized pair struct (reason: std::pair is very slow) @@ -717,12 +712,6 @@ struct whisper_pair { whisper_pair() : first(A()), second(B()) {} }; -// beam-search helpers -struct kv_buf { - std::vector<uint8_t> k; - std::vector<uint8_t> v; -}; - // ggml_allocr wrapper for whisper usage struct whisper_allocr { ggml_allocr * alloc = nullptr; @@ -787,18 +776,19 @@ struct 
whisper_state { int32_t n_fail_p = 0; // number of logprob threshold failures int32_t n_fail_h = 0; // number of entropy threshold failures + // unified self-attention KV cache for all decoders + whisper_kv_cache kv_self; + // cross-attention KV cache for the decoders // shared between all decoders whisper_kv_cache kv_cross; + whisper_mel mel; whisper_batch batch; whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; - // buffer for swapping KV caches between decoders during beam-search - std::vector<kv_buf> kv_swap_bufs; - ggml_backend_t backend = nullptr; // ggml-alloc: @@ -1046,7 +1036,7 @@ static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) } } - return 0; + return 1; } static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) { @@ -1057,6 +1047,36 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) { cache.head = 0; } +static void whisper_kv_cache_seq_rm( + struct whisper_kv_cache & cache, + whisper_seq_id seq_id, + whisper_pos p0, + whisper_pos p1) { + uint32_t new_head = cache.size; + + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max(); + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + if (seq_id < 0) { + cache.cells[i].seq_id.clear(); + } else if (cache.cells[i].has_seq_id(seq_id)) { + cache.cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cache.cells[i].seq_id.empty()) { + cache.cells[i].pos = -1; + if (new_head == cache.size) new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. 
+ if (new_head != cache.size) cache.head = new_head; +} + static void whisper_kv_cache_seq_cp( struct whisper_kv_cache & cache, whisper_seq_id seq_id_src, @@ -2197,13 +2217,12 @@ static bool whisper_encode_internal( static struct ggml_cgraph * whisper_build_graph_decoder( whisper_context & wctx, whisper_state & wstate, - whisper_decoder & decoder, const whisper_batch & batch) { const auto & model = wctx.model; const auto & hparams = model.hparams; // TODO: move to wstate - auto & kv_self = decoder.kv_self; + auto & kv_self = wstate.kv_self; WHISPER_ASSERT(!!kv_self.ctx); @@ -2374,7 +2393,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder( n_kv, n_state/n_head, n_head, n_ctx*ggml_element_size(kv_self.v), n_ctx*ggml_element_size(kv_self.v)*n_state/n_head, - il*n_ctx*ggml_element_size(kv_self.v)*n_state); + n_ctx*ggml_element_size(kv_self.v)*n_state*il); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); @@ -2574,7 +2593,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder( static bool whisper_decode_internal( whisper_context & wctx, whisper_state & wstate, - whisper_decoder & decoder, const whisper_batch & batch, const int n_threads, whisper_abort_callback abort_callback, @@ -2590,13 +2608,15 @@ static bool whisper_decode_internal( struct ggml_tensor * logits; - auto & kv_self = decoder.kv_self; + auto & kv_self = wstate.kv_self; if (!whisper_kv_cache_find_slot(kv_self, batch)) { return 1; } - kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self))); + kv_self.n = whisper_kv_cache_cell_max(kv_self); + //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self))); + //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]); // decoder { @@ -2604,7 +2624,7 @@ static bool whisper_decode_internal( ggml_allocr_reset(alloc); - ggml_cgraph * gf = whisper_build_graph_decoder(wctx, 
wstate, decoder, batch); + ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch); ggml_allocr_alloc_graph(alloc, gf); @@ -3054,14 +3074,14 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { state->backend = whisper_backend_init(ctx->params); - if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) { + if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) { WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); delete state; return nullptr; } { - const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v); + const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v); WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } @@ -3147,9 +3167,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { const int n_tokens = hparams.n_text_ctx; const int n_past = 0; - whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past); + whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0); - return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], state->batch); + return whisper_build_graph_decoder(*ctx, *state, state->batch); }); WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0); @@ -3386,12 +3406,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa void whisper_free_state(struct whisper_state * state) { if (state) { + kv_cache_free(state->kv_self); kv_cache_free(state->kv_cross); - for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { - kv_cache_free(state->decoders[i].kv_self); - } - #ifdef WHISPER_USE_COREML if (state->ctx_coreml != nullptr) { whisper_coreml_free(state->ctx_coreml); @@ -3534,11 +3551,9 @@ int 
whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { } int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { - const int selected_decoder_id = 0; + whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0); - whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past); - - if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], state->batch, n_threads, nullptr, nullptr)) { + if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) { WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } @@ -3547,17 +3562,14 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state } int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { - // TODO: add selected_decoder_id to state - const int selected_decoder_id = 0; - if (ctx->state == nullptr) { WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__); return false; } - whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past); + whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past, 0); - if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], ctx->state->batch, n_threads, nullptr, nullptr)) { + if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->batch, n_threads, nullptr, nullptr)) { WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } @@ -4178,8 +4190,7 @@ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_ if (*tok.code_points == 0) { // reached end of full codepoints in token, reject iff it ended in a partial sequence // that cannot satisfy this position in grammar - if (tok.partial_utf8.n_remain != 0 && 
!whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { rejects.push_back(tok); } } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) { @@ -5006,125 +5017,6 @@ static void whisper_sequence_score( } } -static bool whisper_kv_swap_fast( - std::vector<int> & view, - whisper_decoder src[], - std::vector<kv_buf> & kv_swap_bufs, - const int & n_decoders) { - WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders); - - // (decoder->buffer->decoder or decoder->buffer + decoder->decoder) - std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches - - // (buffer->decoder or decoder->decoder) - std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches - - // (decoder<->decoder) - std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers - std::vector<whisper_pair<int, int>> p_swap_vec; - p_swap_vec.reserve(n_decoders); - - // see https://github.com/ggerganov/whisper.cpp/wiki - for (int i = 0; i < n_decoders; i++) { - // zero-copy (no modification) - if (i == view[i] || view[i] < 0) { - continue; - } - - bool is_one_copy = true; - // since we modify data sequentially, we only consider decoder indices after current index - for (int j = i + 1; j < n_decoders; j++) { - if (i == view[j]) { - // detect symmetric diagram - if (j == view[i]) { - p_swap_set.insert(i); - p_swap_set.insert(j); - p_swap_vec.emplace_back(i, j); - } else { - two_copy.insert(i); - is_one_copy = false; - } - break; - } - } - if (is_one_copy) { - one_copy.insert(i); - } - } - - kv_swap_bufs.resize(n_decoders); - - for (int i = 0; i < n_decoders; i++) { - kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k)); - kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v)); - } - - for (auto & i : two_copy) { - // make a copy of KV caches - WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i); - //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size()); - //memcpy(kv_swap_bufs[i].v.data(), 
src[i].kv_self.v->data, kv_swap_bufs[i].v.size()); - ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size()); - ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size()); - } - - // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first - for (auto & i : two_copy) { - // skip the decoder indices that require pointer swapping - if (p_swap_set.find(i) != p_swap_set.end()) { - continue; - } - - if (two_copy.find(view[i]) != two_copy.end()) { - // modify KV caches of decoder using data from kv_swap_bufs - WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i); - //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size()); - //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size()); - ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size()); - ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size()); - } else { - // modify KV caches of decoder using data from correspond decoder KV caches directly - WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i); - //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k)); - //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v)); - ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k); - ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v); - } - } - - // then modify one-copy decoder KV caches - for (auto & i : one_copy) { - // skip the decoder indices that require pointer swapping - if (p_swap_set.find(i) != p_swap_set.end()) { - continue; - } - - if (two_copy.find(view[i]) != two_copy.end()) { - // modify KV caches of decoder using data from 
kv_swap_bufs - WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i); - //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size()); - //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size()); - ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size()); - ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size()); - } else { - // modify KV caches of decoder using data from correspond decoder KV caches directly - WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i); - //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k)); - //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v)); - ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k); - ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v); - } - } - - // swap the pointers - for (auto & i : p_swap_vec) { - WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second); - std::swap(src[i.first].kv_self, src[i.second].kv_self); - } - - return true; -} - int whisper_full_with_state( struct whisper_context * ctx, struct whisper_state * state, @@ -5218,21 +5110,11 @@ int whisper_full_with_state( for (int j = 1; j < n_decoders; j++) { auto & decoder = state->decoders[j]; - if (decoder.kv_self.ctx == nullptr) { - decoder.kv_self = state->decoders[0].kv_self; - if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) { - WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); - return -4; - } + decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity()); - WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j); - - 
decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity()); - - decoder.probs.resize (ctx->vocab.n_vocab); - decoder.logits.resize (ctx->vocab.n_vocab); - decoder.logprobs.resize(ctx->vocab.n_vocab); - } + decoder.probs.resize (ctx->vocab.n_vocab); + decoder.logits.resize (ctx->vocab.n_vocab); + decoder.logprobs.resize(ctx->vocab.n_vocab); } // the accumulated text context so far @@ -5309,6 +5191,7 @@ int whisper_full_with_state( bool has_ts; whisper_sequence sequence; + whisper_grammar grammar; }; std::vector<beam_candidate> beam_candidates; @@ -5378,8 +5261,6 @@ int whisper_full_with_state( for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = state->decoders[j]; - decoder.kv_self.n = 0; - decoder.sequence.tokens.clear(); decoder.sequence.result_len = 0; decoder.sequence.sum_logprobs_all = 0.0; @@ -5395,15 +5276,14 @@ int whisper_full_with_state( decoder.has_ts = false; if (params.grammar_rules != nullptr) { - decoder.grammar = whisper_grammar_init( - params.grammar_rules, params.n_grammar_rules, params.i_start_rule); + decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule); } else { decoder.grammar = {}; } } // init prompt and kv cache for the current iteration - // run whisper_decoder() only for decoder 0 and copy the results for the other decoders + // TODO: do not recompute the prompt if it is the same as previous time { prompt.clear(); @@ -5425,11 +5305,11 @@ int whisper_full_with_state( } WHISPER_PRINT_DEBUG("\n\n"); - whisper_kv_cache_clear(state->decoders[0].kv_self); + whisper_kv_cache_clear(state->kv_self); - whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0); + whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0); - if (!whisper_decode_internal(*ctx, *state, state->decoders[0], state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { + if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, 
params.abort_callback, params.abort_callback_user_data)) { WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); return -7; } @@ -5439,18 +5319,10 @@ int whisper_full_with_state( whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur); - state->decoders[0].kv_self.n += prompt.size(); - for (int j = 1; j < n_decoders_cur; ++j) { auto & decoder = state->decoders[j]; - // TODO: fix CUDA - //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); - //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); - ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k); - ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v); - - decoder.kv_self.n += prompt.size(); + whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1); memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); @@ -5492,7 +5364,7 @@ int whisper_full_with_state( const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size); for (const auto & token : tokens_new) { - beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence }); + beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, }); beam_candidates.back().sequence.tokens.push_back(token); beam_candidates.back().sequence.sum_logprobs_all += token.plog; @@ -5531,17 +5403,30 @@ int whisper_full_with_state( ++cur_c; } - decoder.sequence = cur.sequence; decoder.seek_delta = cur.seek_delta; decoder.has_ts = cur.has_ts; + decoder.sequence = cur.sequence; + decoder.grammar = cur.grammar; decoder_idx[j] = cur.decoder_idx; + + whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1); + WHISPER_PRINT_DEBUG("%s: beam 
search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); } - // update KV caches - whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur); + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1); + whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1); + whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1); + } } // update the decoder state @@ -5657,14 +5542,14 @@ int whisper_full_with_state( continue; } - decoder.tokens_tmp.resize(1); - decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; + //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta); - //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); + // TODO: use batch + const int n_past = prompt.size() + i; - whisper_batch_prep_legacy(state->batch, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n); + whisper_batch_prep_legacy(state->batch, &decoder.sequence.tokens.back().id, 1, n_past, j); - if (!whisper_decode_internal(*ctx, *state, decoder, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { + if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); return -8; } @@ -5674,8 +5559,6 @@ int whisper_full_with_state( whisper_process_logits(*ctx, *state, params, decoder, t_cur); - ++decoder.kv_self.n; - 
state->t_sample_us += ggml_time_us() - t_start_sample_us; } }