Fixed VAD to work when using whisper_full_with_state

This commit is contained in:
Martin Destagnol 2025-09-17 10:23:49 -10:00
parent edea8a9c3c
commit b91fe0f390
1 changed file with 26 additions and 35 deletions

View File

@ -6652,8 +6652,8 @@ static bool whisper_vad(
if (vad_segments->data.size() > 0) {
state->has_vad_segments = true;
ctx->state->vad_segments.clear();
ctx->state->vad_segments.reserve(vad_segments->data.size());
state->vad_segments.clear();
state->vad_segments.reserve(vad_segments->data.size());
// Initialize the time mapping table
state->vad_mapping_table.clear();
@ -6749,7 +6749,7 @@ static bool whisper_vad(
WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
__func__, segment.orig_start/100.0, segment.orig_end/100.0, segment.vad_start/100.0, segment.vad_end/100.0);
ctx->state->vad_segments.push_back(segment);
state->vad_segments.push_back(segment);
// Copy this speech segment
memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
@ -6820,6 +6820,24 @@ int whisper_full_with_state(
}
}
std::vector<float> vad_samples;
if (params.vad)
{
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples))
{
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
return -1;
}
if (vad_samples.empty())
{
state->result_all.clear();
return 0;
}
samples = vad_samples.data();
n_samples = vad_samples.size();
}
// auto-detect language if not specified
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
@ -7720,25 +7738,11 @@ int whisper_full_with_state(
}
int whisper_full(
struct whisper_context * ctx,
struct whisper_full_params params,
const float * samples,
int n_samples) {
std::vector<float> vad_samples;
if (params.vad) {
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
return -1;
}
if (vad_samples.empty()) {
ctx->state->result_all.clear();
return 0;
}
samples = vad_samples.data();
n_samples = vad_samples.size();
}
struct whisper_context *ctx,
struct whisper_full_params params,
const float *samples,
int n_samples)
{
return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
}
@ -7753,19 +7757,6 @@ int whisper_full_parallel(
return whisper_full(ctx, params, samples, n_samples);
}
std::vector<float> vad_samples;
if (params.vad) {
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
return -1;
}
if (vad_samples.empty()) {
return 0;
}
samples = vad_samples.data();
n_samples = vad_samples.size();
}
int ret = 0;
// prepare separate states for each thread