From 3b54460d2cbe48adcabffda378f43ac9e6336a08 Mon Sep 17 00:00:00 2001 From: Daniel Worthington-Bodart Date: Thu, 12 Feb 2026 12:59:39 +0000 Subject: [PATCH 1/2] Add cross-attention accessor functions for AlignAtt streaming Add whisper_decode_with_state_and_aheads() which saves alignment head cross-attention data during decode, and whisper_state_get_aheads_cross_qks() to read the resulting tensor from state. Co-Authored-By: Claude Opus 4.6 --- include/whisper.h | 21 +++++++++++++++++++++ src/whisper.cpp | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/include/whisper.h b/include/whisper.h index f4cc6bf7..3343b27b 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -339,6 +339,27 @@ extern "C" { int n_past, int n_threads); + // Same as whisper_decode_with_state, but saves alignment head cross-attention data. + // Requires context created with dtw_token_timestamps=true and flash_attn=false. + WHISPER_API int whisper_decode_with_state_and_aheads( + struct whisper_context * ctx, + struct whisper_state * state, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads); + + // Get cross-attention data from alignment heads after a decode call with aheads enabled. + // Returns pointer to float array of shape [n_tokens x n_audio_ctx x n_heads]. + // Copies data from GPU/backend to CPU on each call. + // Returns NULL if DTW is not enabled or no attention data is available. + // The pointer is valid until the next call to this function or whisper_free_state. + WHISPER_API const float * whisper_state_get_aheads_cross_qks( + struct whisper_state * state, + int * n_tokens, + int * n_audio_ctx, + int * n_heads); + // Convert the provided text into tokens. // The tokens pointer must be large enough to hold the resulting tokens. // Returns the number of tokens on success, no more than n_max_tokens diff --git a/src/whisper.cpp b/src/whisper.cpp index 796bccfb..0729f777 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -3954,6 +3954,38 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i return whisper_decode_with_state(ctx, ctx->state, tokens, n_tokens, n_past, n_threads); } +int whisper_decode_with_state_and_aheads(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { + whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0); + whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1); + + if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) { + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); + return 1; + } + + return 0; +} + +const float * whisper_state_get_aheads_cross_qks(struct whisper_state * state, int * out_n_tokens, int * out_n_audio_ctx, int * out_n_heads) { + if (state->aheads_cross_QKs == nullptr) { + return nullptr; + } + + const int n_tokens = state->aheads_cross_QKs->ne[0]; + const int n_audio_ctx = state->aheads_cross_QKs->ne[1]; + const int n_heads = state->aheads_cross_QKs->ne[2]; + + auto & data = state->aheads_cross_QKs_data; + data.resize(n_tokens * n_audio_ctx * n_heads); + ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads); + + if (out_n_tokens) *out_n_tokens = n_tokens; + if (out_n_audio_ctx) *out_n_audio_ctx = n_audio_ctx; + if (out_n_heads) *out_n_heads = n_heads; + + return data.data(); +} + int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) { const auto res = tokenize(ctx->vocab, text); From 9305251037974c142458b792b1b0cb9c86f517bb Mon Sep 17 00:00:00 2001 From: Daniel Worthington-Bodart Date: Fri, 13 Feb 2026 06:43:25 +0000 Subject: [PATCH 2/2] Fix VAD timing log to show per-call and cumulative time The log previously showed cumulative time labeled as just "vad time", which was misleading when called multiple times. Now shows both the per-call time and the cumulative total. Co-Authored-By: Claude Opus 4.6 --- src/whisper.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/whisper.cpp b/src/whisper.cpp index 0729f777..a1bf4b0b 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -5189,8 +5189,9 @@ bool whisper_vad_detect_speech( //WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]); } - vctx->t_vad_us += ggml_time_us() - t_start_vad_us; - WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples); + const int64_t t_vad_this_us = ggml_time_us() - t_start_vad_us; + vctx->t_vad_us += t_vad_this_us; + WHISPER_LOG_INFO("%s: vad time = %.2f ms (cumulative %.2f ms) processing %d samples\n", __func__, 1e-3f * t_vad_this_us, 1e-3f * vctx->t_vad_us, n_samples); ggml_backend_sched_reset(sched);