diff --git a/whisper.cpp b/whisper.cpp index 3083fd83..edd97c7c 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -4521,11 +4521,12 @@ static const std::vector non_speech_tokens = { // process the logits for the selected decoder // - applies logit filters // - computes logprobs and probs +// TODO: optimize static void whisper_process_logits( struct whisper_context & ctx, struct whisper_state & state, - const struct whisper_full_params params, struct whisper_decoder & decoder, + const struct whisper_full_params params, float temperature) { const auto & vocab = ctx.vocab; const auto & tokens_cur = decoder.sequence.tokens; @@ -5297,7 +5298,7 @@ int whisper_full_with_state( state->decoders[0].i_batch = prompt.size() - 1; - whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur); + whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur); for (int j = 1; j < n_decoders_cur; ++j) { auto & decoder = state->decoders[j]; @@ -5322,56 +5323,66 @@ int whisper_full_with_state( } } + // sampling + // TODO: avoid memory allocations, optimize, avoid threads? { std::atomic j_cur(0); + auto process = [&]() { + while (true) { + const int j = j_cur.fetch_add(1); + + if (j >= n_decoders_cur) { + break; + } + + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur < 1e-6f) { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); + } else { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); + } + + decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); + + for (const auto & token : tokens_new) { + bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, }); + bc_per_dec[j].back().sequence.tokens.push_back(token); + bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog; + } + } break; + }; + } + }; + const int n_threads = std::min(params.n_threads, n_decoders_cur); - std::vector threads(n_threads); + if (n_threads == 1) { + process(); + } else { + std::vector threads(n_threads - 1); - for (int t = 0; t < n_threads; ++t) { - threads[t] = std::thread([&]() { - while (true) { - const int j = j_cur.fetch_add(1); + for (int t = 0; t < n_threads - 1; ++t) { + threads[t] = std::thread(process); + } - if (j >= n_decoders_cur) { - break; - } + process(); - auto & decoder = state->decoders[j]; - - if (decoder.completed || decoder.failed) { - continue; - } - - switch (params.strategy) { - case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: - { - if (t_cur < 1e-6f) { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); - } else { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); - } - - decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; - } break; - case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: - { - const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); - - for (const auto & token : tokens_new) { - bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, }); - bc_per_dec[j].back().sequence.tokens.push_back(token); - bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog; - } - } break; - }; - } - }); - } - - for (auto & t : threads) { - t.join(); + for (int t = 0; t < n_threads - 1; ++t) { + threads[t].join(); + } } } @@ -5577,13 +5588,10 @@ int whisper_full_with_state( const int64_t t_start_sample_us = ggml_time_us(); + // TODO: avoid memory allocations, optimize, avoid threads? { std::atomic j_cur(0); - const int n_threads = std::min(params.n_threads, n_decoders_cur); - - std::vector threads(n_threads); - auto process = [&]() { while (true) { const int j = j_cur.fetch_add(1); @@ -5598,18 +5606,26 @@ int whisper_full_with_state( continue; } - whisper_process_logits(*ctx, *state, params, decoder, t_cur); + whisper_process_logits(*ctx, *state, decoder, params, t_cur); } }; - for (int t = 0; t < n_threads - 1; ++t) { - threads[t] = std::thread(process); - } + const int n_threads = std::min(params.n_threads, n_decoders_cur); - process(); + if (n_threads == 1) { + process(); + } else { + std::vector threads(n_threads - 1); - for (int t = 0; t < n_threads - 1; ++t) { - threads[t].join(); + for (int t = 0; t < n_threads - 1; ++t) { + threads[t] = std::thread(process); + } + + process(); + + for (int t = 0; t < n_threads - 1; ++t) { + threads[t].join(); + } } }