From 640443521ffe040c623414b974c113e3ce35355e Mon Sep 17 00:00:00 2001
From: Acelogic
Date: Wed, 11 Mar 2026 15:30:19 -0400
Subject: [PATCH] fix: allow flash attention in encoder when DTW is enabled

Previously, enabling DTW token timestamps with flash attention caused
DTW to be disabled entirely. DTW only needs the explicit
cross-attention weights (KQ_soft_max) from the decoder, so flash
attention can remain enabled for:

- encoder self-attention
- decoder self-attention

Only the cross-attention path in both the encoder (KV storage) and
decoder (KQ computation) needs to fall back to the non-flash path when
DTW is active, since flash attention fuses the entire attention
operation and never materializes KQ_soft_max.

This allows DTW timestamps to work alongside flash attention with no
encoder performance penalty.

Fixes #3662
---
 src/whisper.cpp | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/whisper.cpp b/src/whisper.cpp
index 796bccfb..41a29e86 100644
--- a/src/whisper.cpp
+++ b/src/whisper.cpp
@@ -2317,7 +2317,11 @@ static struct ggml_cgraph * whisper_build_graph_cross(
         struct ggml_tensor * k;
         struct ggml_tensor * v;
 
-        if (wctx.params.flash_attn) {
+        // Use non-flash layout for cross-attention KV when DTW is enabled,
+        // since DTW needs the explicit cross-attention weights (KQ_soft_max)
+        const bool flash_cross = wctx.params.flash_attn && !wctx.params.dtw_token_timestamps;
+
+        if (flash_cross) {
             k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
                     (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
 
@@ -2677,7 +2681,11 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
-            if (wctx.params.flash_attn) {
+            // Use non-flash path for cross-attention when DTW is enabled,
+            // since DTW needs the explicit cross-attention weights (KQ_soft_max)
+            const bool flash_cross = wctx.params.flash_attn && !wctx.params.dtw_token_timestamps;
+
+            if (flash_cross) {
                 struct ggml_tensor * Kcross =
                     ggml_view_3d(ctx0, wstate.kv_cross.k,
                             n_state_head, n_audio_ctx_pad, n_head,
@@ -3706,8 +3714,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
     ggml_time_init();
 
     if (params.flash_attn && params.dtw_token_timestamps) {
-        WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
-        params.dtw_token_timestamps = false;
+        WHISPER_LOG_INFO("%s: flash_attn with dtw - disabling flash attention for cross-attention only\n", __func__);
    }
 
     WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
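
Note (not part of the patch): a minimal sketch of the gating pattern both
graph-building hunks introduce. The attention calls are simplified relative
to whisper.cpp (the real code also handles padding, masks, and per-head
views), so treat this as an illustration of the branch, not the exact
implementation:

    const bool flash_cross = wctx.params.flash_attn && !wctx.params.dtw_token_timestamps;

    if (flash_cross) {
        // fused kernel: softmax(K^T Q) is computed internally and never
        // materialized as a graph node, so DTW has nothing to read back
        cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQ_scale, 0.0f, 0.0f);
    } else {
        // explicit path: KQ_soft_max exists as a tensor in the graph and
        // can be collected from the alignment heads for the DTW cost matrix
        struct ggml_tensor * KQ          = ggml_mul_mat (ctx0, K, Q);
        struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
        cur = ggml_mul_mat(ctx0, V, KQ_soft_max);
    }

Because only the cross-attention sites branch on flash_cross, encoder and
decoder self-attention keep the fused flash-attention kernel unchanged,
which is why there is no encoder performance penalty.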