feat: add t_dtw_end for accurate DTW word timestamps

Add a t_dtw_end field to whisper_token_data to capture the end time
from the DTW alignment path, not just the start. This enables accurate
word-level timestamps that can detect pauses/silence between words.

Previously only t_dtw (the start time) was stored, causing silence
between words to be absorbed into token durations.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Kirill K 2026-01-18 22:15:00 -08:00
parent 4979e04f5d
commit 75dcf461d7
2 changed files with 29 additions and 6 deletions

View File

@ -144,8 +144,8 @@ extern "C" {
// [EXPERIMENTAL] Token-level timestamps with DTW
// do not use if you haven't computed token-level timestamps with dtw
// Roughly corresponds to the moment in audio in which the token was output
int64_t t_dtw;
int64_t t_dtw; // start time from DTW alignment
int64_t t_dtw_end; // end time from DTW alignment
float vlen; // voice length of the token
} whisper_token_data;

View File

@ -6423,7 +6423,7 @@ static whisper_token_data whisper_sample_token(
const whisper_decoder & decoder,
bool best) {
whisper_token_data result = {
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f,
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, -1, 0.0f,
};
const auto & vocab = ctx.vocab;
@ -6541,7 +6541,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
const auto id = dist(decoder.rng);
//printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, });
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, -1, 0.0f, });
if (result[i].id >= vocab.token_beg) {
result[i].tid = result[i].id;
@ -6796,6 +6796,11 @@ int whisper_full_with_state(
result_all.clear();
if (ctx->params.flash_attn && params.strategy == WHISPER_SAMPLING_GREEDY) {
// Avoid spurious timestamp anchors with flash attention in greedy decoding.
params.thold_ptsum = std::max(params.thold_ptsum, 0.02f);
}
if (n_samples > 0) {
// compute log mel spectrogram
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
@ -8916,14 +8921,23 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
// Place timestamps on segments
// Place timestamps on segments - capture both start AND end times from DTW path
int32_t last_v = 0;
int32_t last_time_index = 0;
auto seg_i = state->result_all.begin() + i_segment;
auto tok_i = seg_i->tokens.begin();
whisper_token_data * prev_tok = nullptr;
for (int i = 0; i < alignment->ne[1]; ++i) {
int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0);
int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
if (v != last_v) {
int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
// End time for PREVIOUS token = last audio frame before text_index changed
if (prev_tok != nullptr) {
prev_tok->t_dtw_end = (last_time_index * 2) + seek;
}
int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
last_v = v;
@ -8936,13 +8950,22 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
}
}
// Start time for current token
tok_i->t_dtw = timestamp;
prev_tok = &(*tok_i);
++tok_i;
if (tok_i == seg_i->tokens.end()) {
++seg_i;
tok_i = seg_i->tokens.begin();
}
}
last_time_index = time_index;
}
// End time for the last token
if (prev_tok != nullptr) {
prev_tok->t_dtw_end = (last_time_index * 2) + seek;
}
// Print DTW timestamps