Merge c323ca2def into fc674574ca
This commit is contained in:
commit
1a5676eaee
|
|
@ -625,6 +625,14 @@ extern "C" {
|
|||
int n_samples,
|
||||
int n_processors);
|
||||
|
||||
// Transcribe several independent audio buffers in one call.
// batches[i] points to size_per_batch[i] float samples; n_batches is the number of buffers.
// The first buffer is processed on the calling thread using the context's default state,
// every other buffer is processed on its own thread with a freshly created state.
// n_processors is clamped to n_batches; the implementation throws std::runtime_error
// when n_batches is larger than n_processors.
// Segments produced for the other buffers are appended to the default state's results.
// Returns the int status reported by whisper_full_with_state.
WHISPER_API int whisper_full_batch_parallel(
|
||||
struct whisper_context * ctx,
|
||||
struct whisper_full_params params,
|
||||
const float * const * batches,
|
||||
const int * size_per_batch,
|
||||
int n_batches,
|
||||
int n_processors);
|
||||
|
||||
// Number of generated text segments
|
||||
// A segment can be a few words, a sentence, or even a paragraph.
|
||||
WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx);
|
||||
|
|
|
|||
123
src/whisper.cpp
123
src/whisper.cpp
|
|
@ -7893,6 +7893,129 @@ int whisper_full_parallel(
|
|||
return ret;
|
||||
}
|
||||
|
||||
|
||||
int whisper_full_batch_parallel(
|
||||
struct whisper_context *ctx,
|
||||
struct whisper_full_params params,
|
||||
const float *const *batches,
|
||||
const int *size_per_batch,
|
||||
int n_batches,
|
||||
int n_processors)
|
||||
{
|
||||
int ret = 0;
|
||||
n_processors = std::min(n_processors, n_batches);
|
||||
if (n_batches > n_processors)
|
||||
{
|
||||
throw std::runtime_error("batch size must be equal to number of processors");
|
||||
}
|
||||
// prepare separate states for each thread
|
||||
std::vector<whisper_state *> states;
|
||||
std::vector<std::vector<float>> batches_vector;
|
||||
batches_vector.reserve(n_batches);
|
||||
for (int i = 0; i < n_batches; ++i)
|
||||
{
|
||||
int batch_size = size_per_batch[i];
|
||||
batches_vector.emplace_back(batches[i], batches[i] + batch_size);
|
||||
}
|
||||
|
||||
// the calling thread will process the first chunk
|
||||
// while the other threads will process the remaining chunks
|
||||
const int n_parallel_processes = n_processors - 1;
|
||||
std::vector<std::thread> workers(n_parallel_processes);
|
||||
for (int i = 0; i < n_parallel_processes; ++i)
|
||||
{
|
||||
if (i + 1 > n_batches - 1)
|
||||
{
|
||||
// break when batch not exist for parallel process
|
||||
break;
|
||||
}
|
||||
const float *samples = batches_vector[i + 1].data();
|
||||
const int n_samples = batches_vector[i + 1].size();
|
||||
// create a new state for each thread
|
||||
states.push_back(whisper_init_state(ctx));
|
||||
|
||||
auto params_cur = params;
|
||||
|
||||
params_cur.offset_ms = 0;
|
||||
params_cur.print_progress = false;
|
||||
params_cur.print_realtime = false;
|
||||
|
||||
params_cur.new_segment_callback = nullptr;
|
||||
params_cur.new_segment_callback_user_data = nullptr;
|
||||
|
||||
params_cur.progress_callback = nullptr;
|
||||
params_cur.progress_callback_user_data = nullptr;
|
||||
|
||||
workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples, n_samples);
|
||||
}
|
||||
|
||||
{
|
||||
auto params_cur = params;
|
||||
|
||||
// We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk.
|
||||
params_cur.print_realtime = false;
|
||||
|
||||
const float *samples = batches_vector[0].data();
|
||||
const int n_samples = batches_vector[0].size();
|
||||
|
||||
// Run the first transformation using default state but only for the first chunk.
|
||||
ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, n_samples);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_parallel_processes; ++i)
|
||||
{
|
||||
workers[i].join();
|
||||
}
|
||||
|
||||
// combine results into result_state->result_all from all other states
|
||||
for (int i = 0; i < n_processors - 1; ++i)
|
||||
{
|
||||
auto &results_i = states[i]->result_all;
|
||||
|
||||
for (auto &result : results_i)
|
||||
{
|
||||
|
||||
// make sure that segments are not overlapping
|
||||
if (!ctx->state->result_all.empty())
|
||||
{
|
||||
result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
|
||||
}
|
||||
|
||||
ctx->state->result_all.push_back(std::move(result));
|
||||
|
||||
// call the new_segment_callback for each segment
|
||||
if (params.new_segment_callback)
|
||||
{
|
||||
params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
|
||||
}
|
||||
}
|
||||
|
||||
ctx->state->t_mel_us += states[i]->t_mel_us;
|
||||
|
||||
ctx->state->t_sample_us += states[i]->t_sample_us;
|
||||
ctx->state->t_encode_us += states[i]->t_encode_us;
|
||||
ctx->state->t_decode_us += states[i]->t_decode_us;
|
||||
ctx->state->t_batchd_us += states[i]->t_batchd_us;
|
||||
ctx->state->t_prompt_us += states[i]->t_prompt_us;
|
||||
|
||||
ctx->state->n_sample += states[i]->n_sample;
|
||||
ctx->state->n_encode += states[i]->n_encode;
|
||||
ctx->state->n_decode += states[i]->n_decode;
|
||||
ctx->state->n_batchd += states[i]->n_batchd;
|
||||
ctx->state->n_prompt += states[i]->n_prompt;
|
||||
|
||||
whisper_free_state(states[i]);
|
||||
}
|
||||
|
||||
// average the timings
|
||||
ctx->state->t_mel_us /= n_processors;
|
||||
ctx->state->t_sample_us /= n_processors;
|
||||
ctx->state->t_encode_us /= n_processors;
|
||||
ctx->state->t_decode_us /= n_processors;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
int whisper_full_n_segments_from_state(struct whisper_state * state) {
|
||||
return state->result_all.size();
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue