This commit is contained in:
Deep khurana 2026-04-23 18:09:09 +00:00 committed by GitHub
commit 1a5676eaee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 131 additions and 0 deletions

View File

@ -625,6 +625,14 @@ extern "C" {
int n_samples,
int n_processors);
WHISPER_API int whisper_full_batch_parallel(
struct whisper_context * ctx,
struct whisper_full_params params,
const float * const * batches,
const int * size_per_batch,
int n_batches,
int n_processors);
// Number of generated text segments
// A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx);

View File

@ -7893,6 +7893,129 @@ int whisper_full_parallel(
return ret;
}
int whisper_full_batch_parallel(
struct whisper_context *ctx,
struct whisper_full_params params,
const float *const *batches,
const int *size_per_batch,
int n_batches,
int n_processors)
{
int ret = 0;
n_processors = std::min(n_processors, n_batches);
if (n_batches > n_processors)
{
throw std::runtime_error("batch size must be equal to number of processors");
}
// prepare separate states for each thread
std::vector<whisper_state *> states;
std::vector<std::vector<float>> batches_vector;
batches_vector.reserve(n_batches);
for (int i = 0; i < n_batches; ++i)
{
int batch_size = size_per_batch[i];
batches_vector.emplace_back(batches[i], batches[i] + batch_size);
}
// the calling thread will process the first chunk
// while the other threads will process the remaining chunks
const int n_parallel_processes = n_processors - 1;
std::vector<std::thread> workers(n_parallel_processes);
for (int i = 0; i < n_parallel_processes; ++i)
{
if (i + 1 > n_batches - 1)
{
// break when batch not exist for parallel process
break;
}
const float *samples = batches_vector[i + 1].data();
const int n_samples = batches_vector[i + 1].size();
// create a new state for each thread
states.push_back(whisper_init_state(ctx));
auto params_cur = params;
params_cur.offset_ms = 0;
params_cur.print_progress = false;
params_cur.print_realtime = false;
params_cur.new_segment_callback = nullptr;
params_cur.new_segment_callback_user_data = nullptr;
params_cur.progress_callback = nullptr;
params_cur.progress_callback_user_data = nullptr;
workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples, n_samples);
}
{
auto params_cur = params;
// We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk.
params_cur.print_realtime = false;
const float *samples = batches_vector[0].data();
const int n_samples = batches_vector[0].size();
// Run the first transformation using default state but only for the first chunk.
ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, n_samples);
}
for (int i = 0; i < n_parallel_processes; ++i)
{
workers[i].join();
}
// combine results into result_state->result_all from all other states
for (int i = 0; i < n_processors - 1; ++i)
{
auto &results_i = states[i]->result_all;
for (auto &result : results_i)
{
// make sure that segments are not overlapping
if (!ctx->state->result_all.empty())
{
result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
}
ctx->state->result_all.push_back(std::move(result));
// call the new_segment_callback for each segment
if (params.new_segment_callback)
{
params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
}
}
ctx->state->t_mel_us += states[i]->t_mel_us;
ctx->state->t_sample_us += states[i]->t_sample_us;
ctx->state->t_encode_us += states[i]->t_encode_us;
ctx->state->t_decode_us += states[i]->t_decode_us;
ctx->state->t_batchd_us += states[i]->t_batchd_us;
ctx->state->t_prompt_us += states[i]->t_prompt_us;
ctx->state->n_sample += states[i]->n_sample;
ctx->state->n_encode += states[i]->n_encode;
ctx->state->n_decode += states[i]->n_decode;
ctx->state->n_batchd += states[i]->n_batchd;
ctx->state->n_prompt += states[i]->n_prompt;
whisper_free_state(states[i]);
}
// average the timings
ctx->state->t_mel_us /= n_processors;
ctx->state->t_sample_us /= n_processors;
ctx->state->t_encode_us /= n_processors;
ctx->state->t_decode_us /= n_processors;
return ret;
}
int whisper_full_n_segments_from_state(struct whisper_state * state) {
return state->result_all.size();
}