This commit is contained in:
Daniel Worthington-Bodart 2026-04-20 10:16:42 +00:00 committed by GitHub
commit 4067ba6b3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 56 additions and 2 deletions

View File

@ -339,6 +339,27 @@ extern "C" {
int n_past,
int n_threads);
// Same as whisper_decode_with_state, but saves alignment head cross-attention data.
// Requires context created with dtw_token_timestamps=true and flash_attn=false.
WHISPER_API int whisper_decode_with_state_and_aheads(
struct whisper_context * ctx,
struct whisper_state * state,
const whisper_token * tokens,
int n_tokens,
int n_past,
int n_threads);
// Get cross-attention data from alignment heads after a decode call with aheads enabled.
// Returns pointer to float array of shape [n_tokens x n_audio_ctx x n_heads].
// Copies data from GPU/backend to CPU on each call.
// Returns NULL if DTW is not enabled or no attention data is available.
// The pointer is valid until the next call to this function or whisper_free_state.
WHISPER_API const float * whisper_state_get_aheads_cross_qks(
struct whisper_state * state,
int * n_tokens,
int * n_audio_ctx,
int * n_heads);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens

View File

@ -3954,6 +3954,38 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
return whisper_decode_with_state(ctx, ctx->state, tokens, n_tokens, n_past, n_threads);
}
int whisper_decode_with_state_and_aheads(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) {
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
return 1;
}
return 0;
}
const float * whisper_state_get_aheads_cross_qks(struct whisper_state * state, int * out_n_tokens, int * out_n_audio_ctx, int * out_n_heads) {
if (state->aheads_cross_QKs == nullptr) {
return nullptr;
}
const int n_tokens = state->aheads_cross_QKs->ne[0];
const int n_audio_ctx = state->aheads_cross_QKs->ne[1];
const int n_heads = state->aheads_cross_QKs->ne[2];
auto & data = state->aheads_cross_QKs_data;
data.resize(n_tokens * n_audio_ctx * n_heads);
ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads);
if (out_n_tokens) *out_n_tokens = n_tokens;
if (out_n_audio_ctx) *out_n_audio_ctx = n_audio_ctx;
if (out_n_heads) *out_n_heads = n_heads;
return data.data();
}
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
const auto res = tokenize(ctx->vocab, text);
@ -5158,8 +5190,9 @@ bool whisper_vad_detect_speech_no_reset(
//WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
}
vctx->t_vad_us += ggml_time_us() - t_start_vad_us;
WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
const int64_t t_vad_this_us = ggml_time_us() - t_start_vad_us;
vctx->t_vad_us += t_vad_this_us;
WHISPER_LOG_INFO("%s: vad time = %.2f ms (cumulative %.2f ms) processing %d samples\n", __func__, 1e-3f * t_vad_this_us, 1e-3f * vctx->t_vad_us, n_samples);
ggml_backend_sched_reset(sched);