diff --git a/include/whisper.h b/include/whisper.h index f4cc6bf7..b9bf454c 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -144,8 +144,8 @@ extern "C" { // [EXPERIMENTAL] Token-level timestamps with DTW // do not use if you haven't computed token-level timestamps with dtw - // Roughly corresponds to the moment in audio in which the token was output - int64_t t_dtw; + int64_t t_dtw; // start time from DTW alignment + int64_t t_dtw_end; // end time from DTW alignment float vlen; // voice length of the token } whisper_token_data; diff --git a/src/whisper.cpp b/src/whisper.cpp index 33e556c4..18337b1f 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -6423,7 +6423,7 @@ static whisper_token_data whisper_sample_token( const whisper_decoder & decoder, bool best) { whisper_token_data result = { - 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f, + 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, -1, 0.0f, }; const auto & vocab = ctx.vocab; @@ -6541,7 +6541,7 @@ static std::vector whisper_sample_token_topk( const auto id = dist(decoder.rng); //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum); - result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, }); + result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, -1, 0.0f, }); if (result[i].id >= vocab.token_beg) { result[i].tid = result[i].id; @@ -6796,6 +6796,11 @@ int whisper_full_with_state( result_all.clear(); + if (ctx->params.flash_attn && params.strategy == WHISPER_SAMPLING_GREEDY) { + // Avoid spurious timestamp anchors with flash attention in greedy decoding. + params.thold_ptsum = std::max(params.thold_ptsum, 0.02f); + } + if (n_samples > 0) { // compute log mel spectrogram if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { @@ -8916,14 +8921,23 @@ static void whisper_exp_compute_token_level_timestamps_dtw( ggml_tensor * alignment = dtw_and_backtrace(gctx, w); - // Place timestamps on segments + // Place timestamps on segments - capture both start AND end times from DTW path int32_t last_v = 0; + int32_t last_time_index = 0; auto seg_i = state->result_all.begin() + i_segment; auto tok_i = seg_i->tokens.begin(); + whisper_token_data * prev_tok = nullptr; + for (int i = 0; i < alignment->ne[1]; ++i) { int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0); + int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0); + if (v != last_v) { - int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0); + // End time for PREVIOUS token = last audio frame before text_index changed + if (prev_tok != nullptr) { + prev_tok->t_dtw_end = (last_time_index * 2) + seek; + } + int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio last_v = v; @@ -8936,13 +8950,22 @@ static void whisper_exp_compute_token_level_timestamps_dtw( } } + // Start time for current token tok_i->t_dtw = timestamp; + prev_tok = &(*tok_i); + ++tok_i; if (tok_i == seg_i->tokens.end()) { ++seg_i; tok_i = seg_i->tokens.begin(); } } + last_time_index = time_index; + } + + // End time for the last token + if (prev_tok != nullptr) { + prev_tok->t_dtw_end = (last_time_index * 2) + seek; } // Print DTW timestamps