refactor(dtw): use named constants and helper lambdas

- Replace magic numbers with DTW_* constants (documented values)
- Extract get_prev_end/get_next_start/get_text_len helpers
- Document phonetic reasoning for onset shift values
- Fix C++14 compatibility (remove structured bindings)
- No behavioral changes, same timestamp output
This commit is contained in:
obvirm 2026-01-01 00:15:47 +07:00
parent 543dabe25a
commit a0fb5a0981
1 changed file with 93 additions and 119 deletions

View File

@@ -8948,164 +8948,138 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
}
}
// adjust timestamps
const int64_t min_dur = 5;
// DTW timestamp refinement constants (in centiseconds, 1 cs = 10ms)
// These values are tuned for natural speech at ~150 WPM
static const int64_t DTW_MIN_TOKEN_DUR = 5; // 50ms absolute minimum
static const int DTW_ONSET_VOWEL = 15; // 150ms for vowels/plosives (anticipate burst)
static const int DTW_ONSET_CONSONANT = 8; // 80ms for other consonants
static const int DTW_DUR_PER_CHAR = 2; // 20ms per character for min duration
static const int DTW_MAX_DUR_PER_CHAR = 15; // 150ms per character for max duration
static const int64_t DTW_MAX_DUR_BASE = 10; // 100ms base max duration
// vowels + plosives benefit from earlier onset to match perceived speech start
static const char * DTW_ONSET_PHONEMES = "aeiouywbcdgkpqt";
// helper: get previous token's end time
auto get_prev_end = [&](size_t seg_idx, int tok_idx) -> int64_t {
// Scan backwards within this segment for the nearest real (non-special)
// token and return its DTW timestamp; special tokens (id >= eot) are skipped.
auto & seg = state->result_all[seg_idx];
for (int t2 = tok_idx - 1; t2 >= 0; --t2) {
if (seg.tokens[t2].id < whisper_token_eot(ctx)) {
return seg.tokens[t2].t_dtw;
}
}
// No earlier token in this segment: fall back to the previous segment's
// last token. NOTE(review): unlike the in-segment loop (and unlike
// get_next_start), this does NOT skip special tokens — if back() is a
// special token its t_dtw may not be meaningful; confirm intent.
if (seg_idx > 0 && !state->result_all[seg_idx - 1].tokens.empty()) {
return state->result_all[seg_idx - 1].tokens.back().t_dtw;
}
// Nothing before this token at all: treat start of audio (t = 0) as the bound.
return 0;
};
// helper: get next token's start time
auto get_next_start = [&](size_t seg_idx, int tok_idx, int64_t fallback) -> int64_t {
// Scan forwards within this segment for the nearest real (non-special)
// token and return its DTW timestamp; special tokens (id >= eot) are skipped.
auto & seg = state->result_all[seg_idx];
const int n = seg.tokens.size();
for (int t2 = tok_idx + 1; t2 < n; ++t2) {
if (seg.tokens[t2].id < whisper_token_eot(ctx)) {
return seg.tokens[t2].t_dtw;
}
}
// No later token in this segment: look at the first real token of the
// next segment (here special tokens ARE skipped, unlike get_prev_end's
// previous-segment fallback).
if (seg_idx + 1 < state->result_all.size()) {
for (const auto & ntok : state->result_all[seg_idx + 1].tokens) {
if (ntok.id < whisper_token_eot(ctx)) {
return ntok.t_dtw;
}
}
}
// No next token found anywhere: caller-supplied bound (e.g. segment.t1).
return fallback;
};
// helper: get token text length (excluding leading space)
auto get_text_len = [&](whisper_token id) -> std::pair<const char*, int> {
// Returns {text, len} for a token: the token string with any single
// leading space skipped, and its length clamped to >= 1 so the
// per-character duration math downstream never multiplies by zero.
const char * text = whisper_token_to_str(ctx, id);
// text may be nullptr for unknown ids; len defaults to 1 in that case
// (callers must still null-check text before dereferencing it).
int len = text ? (int)strlen(text) : 1;
if (len > 0 && text && text[0] == ' ') { text++; len--; }
if (len < 1) len = 1;
return {text, len};
};
// pass 1: onset shift + min duration adjustment
for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
auto & segment = state->result_all[i];
const int n_tokens = segment.tokens.size();
for (int t = 0; t < n_tokens; ++t) {
auto & tok = segment.tokens[t];
if (tok.id >= whisper_token_eot(ctx)) continue;
int len = 1;
const char * text = whisper_token_to_str(ctx, tok.id);
if (text) {
len = (int)strlen(text);
if (len > 0 && text[0] == ' ') {
text++;
len--;
}
}
auto text_pair = get_text_len(tok.id);
const char * text = text_pair.first;
int len = text_pair.second;
// onset shift
{
// onset shift: move start earlier for vowels/plosives
if (len > 0 && text) {
char c = tolower(text[0]);
int shift = 0;
if (len > 0 && text) {
char c = tolower(text[0]);
if (strchr("aeiouywbcdgkpqqt", c)) {
shift = 15;
} else if (c >= 'a' && c <= 'z') {
shift = 8;
}
if (strchr(DTW_ONSET_PHONEMES, c)) {
shift = DTW_ONSET_VOWEL;
} else if (c >= 'a' && c <= 'z') {
shift = DTW_ONSET_CONSONANT;
}
if (shift > 0) {
int64_t prev_end = 0;
if (t > 0) {
for (int t2 = t - 1; t2 >= 0; --t2) {
if (segment.tokens[t2].id < whisper_token_eot(ctx)) {
prev_end = segment.tokens[t2].t_dtw;
break;
}
}
} else if (i > 0 && !state->result_all[i-1].tokens.empty()) {
prev_end = state->result_all[i-1].tokens.back().t_dtw;
}
int64_t prev_end = get_prev_end(i, t);
if (tok.t_dtw - shift > prev_end + 1) {
tok.t_dtw -= shift;
}
}
}
// min duration
{
int64_t next_t_dtw = -1;
for (int t2 = t + 1; t2 < n_tokens; ++t2) {
if (segment.tokens[t2].id < whisper_token_eot(ctx)) {
next_t_dtw = segment.tokens[t2].t_dtw;
break;
}
}
// min duration: extend backward if too short
int64_t next_start = get_next_start(i, t, segment.t1);
int64_t duration = next_start - tok.t_dtw;
int64_t len_based_min = (int64_t)(len * DTW_DUR_PER_CHAR);
int64_t adaptive_min = (DTW_MIN_TOKEN_DUR > len_based_min) ? DTW_MIN_TOKEN_DUR : len_based_min;
if (next_t_dtw < 0) {
next_t_dtw = segment.t1;
}
int64_t duration = next_t_dtw - tok.t_dtw;
const int64_t adaptive_min = std::max((int64_t)5, (int64_t)(len * 2));
if (duration < adaptive_min && duration >= 0) {
int64_t needed = adaptive_min - duration;
int64_t prev_end = 0;
if (t > 0) {
for (int t2 = t - 1; t2 >= 0; --t2) {
if (segment.tokens[t2].id < whisper_token_eot(ctx)) {
prev_end = segment.tokens[t2].t_dtw;
break;
}
}
} else if (i > 0 && !state->result_all[i-1].tokens.empty()) {
prev_end = state->result_all[i-1].tokens.back().t_dtw;
}
int64_t new_start = tok.t_dtw - needed;
if (new_start > prev_end + 2) {
tok.t_dtw = new_start;
}
if (duration >= 0 && duration < adaptive_min) {
int64_t prev_end = get_prev_end(i, t);
int64_t new_start = tok.t_dtw - (adaptive_min - duration);
if (new_start > prev_end + 2) {
tok.t_dtw = new_start;
}
}
}
}
// propagate to t0/t1
// pass 2: propagate t_dtw to t0/t1 with max duration cap
for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
auto & segment = state->result_all[i];
const int n_tokens = segment.tokens.size();
for (int t = 0; t < n_tokens; ++t) {
auto & tok = segment.tokens[t];
if (tok.id >= whisper_token_eot(ctx)) continue;
tok.t0 = tok.t_dtw;
int64_t next_t_dtw = -1;
for (int t2 = t + 1; t2 < n_tokens; ++t2) {
if (segment.tokens[t2].id < whisper_token_eot(ctx)) {
next_t_dtw = segment.tokens[t2].t_dtw;
break;
}
}
auto text_pair2 = get_text_len(tok.id);
int len2 = text_pair2.second;
int64_t next_start = get_next_start(i, t, segment.t1);
int64_t len_based_max = (int64_t)(len2 * DTW_MAX_DUR_PER_CHAR);
int64_t max_dur = (DTW_MAX_DUR_BASE > len_based_max) ? DTW_MAX_DUR_BASE : len_based_max;
if (next_t_dtw < 0 && i + 1 < state->result_all.size()) {
for (const auto & ntok : state->result_all[i + 1].tokens) {
if (ntok.id < whisper_token_eot(ctx)) {
next_t_dtw = ntok.t_dtw;
break;
}
}
}
int64_t raw_t1 = (next_t_dtw >= 0) ? next_t_dtw : segment.t1;
// max duration
{
int len = 1;
const char * text = whisper_token_to_str(ctx, tok.id);
if (text) {
len = (int)strlen(text);
if (len > 0 && text[0] == ' ') len--;
if (len < 1) len = 1;
}
int64_t max_dur = std::max((int64_t)10, (int64_t)(len * 15));
if (raw_t1 < tok.t0 + min_dur) {
raw_t1 = tok.t0 + min_dur;
}
tok.t1 = (raw_t1 - tok.t0 > max_dur) ? tok.t0 + max_dur : raw_t1;
}
int64_t min_t1 = tok.t0 + DTW_MIN_TOKEN_DUR;
int64_t raw_t1 = (next_start > min_t1) ? next_start : min_t1;
tok.t1 = (raw_t1 - tok.t0 > max_dur) ? tok.t0 + max_dur : raw_t1;
}
// segment bounds
{
int64_t first_t0 = -1;
int64_t last_t1 = -1;
for (int t = 0; t < n_tokens; ++t) {
const auto & tok = segment.tokens[t];
if (tok.id >= whisper_token_eot(ctx)) continue;
if (first_t0 < 0) first_t0 = tok.t0;
last_t1 = tok.t1;
}
if (first_t0 >= 0) segment.t0 = first_t0;
if (last_t1 >= 0) segment.t1 = last_t1;
// sync segment boundaries with token bounds
int64_t first_t0 = -1, last_t1 = -1;
for (int t = 0; t < n_tokens; ++t) {
const auto & tok = segment.tokens[t];
if (tok.id >= whisper_token_eot(ctx)) continue;
if (first_t0 < 0) first_t0 = tok.t0;
last_t1 = tok.t1;
}
if (first_t0 >= 0) segment.t0 = first_t0;
if (last_t1 >= 0) segment.t1 = last_t1;
}
// Print DTW timestamps