Merge 75dcf461d7 into fc674574ca
This commit is contained in:
commit
7436cc41d9
|
|
@ -144,8 +144,8 @@ extern "C" {
|
|||
|
||||
// [EXPERIMENTAL] Token-level timestamps with DTW
|
||||
// Do not use unless token-level timestamps have been computed with DTW
|
||||
// Roughly corresponds to the moment in audio in which the token was output
|
||||
int64_t t_dtw;
|
||||
int64_t t_dtw; // start time from DTW alignment
|
||||
int64_t t_dtw_end; // end time from DTW alignment
|
||||
|
||||
float vlen; // voice length of the token
|
||||
} whisper_token_data;
|
||||
|
|
|
|||
|
|
@ -6449,7 +6449,7 @@ static whisper_token_data whisper_sample_token(
|
|||
const whisper_decoder & decoder,
|
||||
bool best) {
|
||||
whisper_token_data result = {
|
||||
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f,
|
||||
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, -1, 0.0f,
|
||||
};
|
||||
|
||||
const auto & vocab = ctx.vocab;
|
||||
|
|
@ -6567,7 +6567,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|||
const auto id = dist(decoder.rng);
|
||||
//printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
|
||||
|
||||
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, });
|
||||
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, -1, 0.0f, });
|
||||
|
||||
if (result[i].id >= vocab.token_beg) {
|
||||
result[i].tid = result[i].id;
|
||||
|
|
@ -6800,6 +6800,11 @@ int whisper_full_with_state(
|
|||
|
||||
result_all.clear();
|
||||
|
||||
if (ctx->params.flash_attn && params.strategy == WHISPER_SAMPLING_GREEDY) {
|
||||
// Avoid spurious timestamp anchors with flash attention in greedy decoding.
|
||||
params.thold_ptsum = std::max(params.thold_ptsum, 0.02f);
|
||||
}
|
||||
|
||||
if (n_samples > 0) {
|
||||
// compute log mel spectrogram
|
||||
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
||||
|
|
@ -8920,14 +8925,23 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|||
|
||||
ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
|
||||
|
||||
// Place timestamps on segments
|
||||
// Place timestamps on segments - capture both start AND end times from DTW path
|
||||
int32_t last_v = 0;
|
||||
int32_t last_time_index = 0;
|
||||
auto seg_i = state->result_all.begin() + i_segment;
|
||||
auto tok_i = seg_i->tokens.begin();
|
||||
whisper_token_data * prev_tok = nullptr;
|
||||
|
||||
for (int i = 0; i < alignment->ne[1]; ++i) {
|
||||
int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0);
|
||||
int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
|
||||
|
||||
if (v != last_v) {
|
||||
int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
|
||||
// End time for PREVIOUS token = last audio frame before text_index changed
|
||||
if (prev_tok != nullptr) {
|
||||
prev_tok->t_dtw_end = (last_time_index * 2) + seek;
|
||||
}
|
||||
|
||||
int64_t timestamp = (time_index * 2) + seek; // Each index in the DTW result = 20 ms of audio
|
||||
last_v = v;
|
||||
|
||||
|
|
@ -8940,13 +8954,22 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|||
}
|
||||
}
|
||||
|
||||
// Start time for current token
|
||||
tok_i->t_dtw = timestamp;
|
||||
prev_tok = &(*tok_i);
|
||||
|
||||
++tok_i;
|
||||
if (tok_i == seg_i->tokens.end()) {
|
||||
++seg_i;
|
||||
tok_i = seg_i->tokens.begin();
|
||||
}
|
||||
}
|
||||
last_time_index = time_index;
|
||||
}
|
||||
|
||||
// End time for the last token
|
||||
if (prev_tok != nullptr) {
|
||||
prev_tok->t_dtw_end = (last_time_index * 2) + seek;
|
||||
}
|
||||
|
||||
// Print DTW timestamps
|
||||
|
|
|
|||
Loading…
Reference in New Issue