Add DTW token timestamps
This commit is contained in:
parent
6114e69213
commit
543dabe25a
169
src/whisper.cpp
169
src/whisper.cpp
|
|
@ -7695,7 +7695,7 @@ int whisper_full_with_state(
|
|||
whisper_exp_compute_token_level_timestamps(
|
||||
*ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
||||
|
||||
if (params.max_len > 0) {
|
||||
if (params.max_len > 0 && !ctx->params.dtw_token_timestamps) {
|
||||
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
|
||||
}
|
||||
}
|
||||
|
|
@ -7708,15 +7708,14 @@ int whisper_full_with_state(
|
|||
// FIXME: will timestamp offsets be correct?
|
||||
// [EXPERIMENTAL] Token-level timestamps with DTW
|
||||
{
|
||||
const int n_segments = state->result_all.size() - n_segments_before;
|
||||
int n_segments = state->result_all.size() - n_segments_before;
|
||||
if (ctx->params.dtw_token_timestamps && n_segments) {
|
||||
const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
|
||||
whisper_exp_compute_token_level_timestamps_dtw(
|
||||
ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
|
||||
|
||||
if (params.new_segment_callback) {
|
||||
for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
|
||||
params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
|
||||
}
|
||||
params.new_segment_callback(ctx, state, n_segments, params.new_segment_callback_user_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -8949,6 +8948,166 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|||
}
|
||||
}
|
||||
|
||||
// adjust timestamps
|
||||
const int64_t min_dur = 5;
|
||||
|
||||
for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
|
||||
auto & segment = state->result_all[i];
|
||||
const int n_tokens = segment.tokens.size();
|
||||
|
||||
for (int t = 0; t < n_tokens; ++t) {
|
||||
auto & tok = segment.tokens[t];
|
||||
|
||||
if (tok.id >= whisper_token_eot(ctx)) continue;
|
||||
|
||||
int len = 1;
|
||||
const char * text = whisper_token_to_str(ctx, tok.id);
|
||||
if (text) {
|
||||
len = (int)strlen(text);
|
||||
if (len > 0 && text[0] == ' ') {
|
||||
text++;
|
||||
len--;
|
||||
}
|
||||
}
|
||||
|
||||
// onset shift
|
||||
{
|
||||
int shift = 0;
|
||||
if (len > 0 && text) {
|
||||
char c = tolower(text[0]);
|
||||
if (strchr("aeiouywbcdgkpqqt", c)) {
|
||||
shift = 15;
|
||||
} else if (c >= 'a' && c <= 'z') {
|
||||
shift = 8;
|
||||
}
|
||||
}
|
||||
|
||||
if (shift > 0) {
|
||||
int64_t prev_end = 0;
|
||||
if (t > 0) {
|
||||
for (int t2 = t - 1; t2 >= 0; --t2) {
|
||||
if (segment.tokens[t2].id < whisper_token_eot(ctx)) {
|
||||
prev_end = segment.tokens[t2].t_dtw;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else if (i > 0 && !state->result_all[i-1].tokens.empty()) {
|
||||
prev_end = state->result_all[i-1].tokens.back().t_dtw;
|
||||
}
|
||||
|
||||
if (tok.t_dtw - shift > prev_end + 1) {
|
||||
tok.t_dtw -= shift;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// min duration
|
||||
{
|
||||
int64_t next_t_dtw = -1;
|
||||
for (int t2 = t + 1; t2 < n_tokens; ++t2) {
|
||||
if (segment.tokens[t2].id < whisper_token_eot(ctx)) {
|
||||
next_t_dtw = segment.tokens[t2].t_dtw;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (next_t_dtw < 0) {
|
||||
next_t_dtw = segment.t1;
|
||||
}
|
||||
|
||||
int64_t duration = next_t_dtw - tok.t_dtw;
|
||||
const int64_t adaptive_min = std::max((int64_t)5, (int64_t)(len * 2));
|
||||
|
||||
if (duration < adaptive_min && duration >= 0) {
|
||||
int64_t needed = adaptive_min - duration;
|
||||
|
||||
int64_t prev_end = 0;
|
||||
if (t > 0) {
|
||||
for (int t2 = t - 1; t2 >= 0; --t2) {
|
||||
if (segment.tokens[t2].id < whisper_token_eot(ctx)) {
|
||||
prev_end = segment.tokens[t2].t_dtw;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else if (i > 0 && !state->result_all[i-1].tokens.empty()) {
|
||||
prev_end = state->result_all[i-1].tokens.back().t_dtw;
|
||||
}
|
||||
|
||||
int64_t new_start = tok.t_dtw - needed;
|
||||
if (new_start > prev_end + 2) {
|
||||
tok.t_dtw = new_start;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// propagate to t0/t1
|
||||
for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
|
||||
auto & segment = state->result_all[i];
|
||||
const int n_tokens = segment.tokens.size();
|
||||
|
||||
for (int t = 0; t < n_tokens; ++t) {
|
||||
auto & tok = segment.tokens[t];
|
||||
|
||||
if (tok.id >= whisper_token_eot(ctx)) continue;
|
||||
|
||||
tok.t0 = tok.t_dtw;
|
||||
|
||||
int64_t next_t_dtw = -1;
|
||||
for (int t2 = t + 1; t2 < n_tokens; ++t2) {
|
||||
if (segment.tokens[t2].id < whisper_token_eot(ctx)) {
|
||||
next_t_dtw = segment.tokens[t2].t_dtw;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (next_t_dtw < 0 && i + 1 < state->result_all.size()) {
|
||||
for (const auto & ntok : state->result_all[i + 1].tokens) {
|
||||
if (ntok.id < whisper_token_eot(ctx)) {
|
||||
next_t_dtw = ntok.t_dtw;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64_t raw_t1 = (next_t_dtw >= 0) ? next_t_dtw : segment.t1;
|
||||
|
||||
// max duration
|
||||
{
|
||||
int len = 1;
|
||||
const char * text = whisper_token_to_str(ctx, tok.id);
|
||||
if (text) {
|
||||
len = (int)strlen(text);
|
||||
if (len > 0 && text[0] == ' ') len--;
|
||||
if (len < 1) len = 1;
|
||||
}
|
||||
|
||||
int64_t max_dur = std::max((int64_t)10, (int64_t)(len * 15));
|
||||
|
||||
if (raw_t1 < tok.t0 + min_dur) {
|
||||
raw_t1 = tok.t0 + min_dur;
|
||||
}
|
||||
|
||||
tok.t1 = (raw_t1 - tok.t0 > max_dur) ? tok.t0 + max_dur : raw_t1;
|
||||
}
|
||||
}
|
||||
|
||||
// segment bounds
|
||||
{
|
||||
int64_t first_t0 = -1;
|
||||
int64_t last_t1 = -1;
|
||||
for (int t = 0; t < n_tokens; ++t) {
|
||||
const auto & tok = segment.tokens[t];
|
||||
if (tok.id >= whisper_token_eot(ctx)) continue;
|
||||
if (first_t0 < 0) first_t0 = tok.t0;
|
||||
last_t1 = tok.t1;
|
||||
}
|
||||
if (first_t0 >= 0) segment.t0 = first_t0;
|
||||
if (last_t1 >= 0) segment.t1 = last_t1;
|
||||
}
|
||||
}
|
||||
|
||||
// Print DTW timestamps
|
||||
/*for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
|
||||
auto & segment = state->result_all[i];
|
||||
|
|
|
|||
Loading…
Reference in New Issue