fix: allow flash attention in encoder when DTW is enabled

Previously, enabling DTW token timestamps together with flash attention
caused DTW to be disabled entirely (init logged a warning and cleared
the flag). DTW only needs the explicit cross-attention weights
(KQ_soft_max) from the decoder, so flash attention can remain enabled for:
- encoder self-attention
- decoder self-attention

Only the cross-attention path in both the encoder (KV storage) and
decoder (KQ computation) needs to fall back to the non-flash path
when DTW is active, since flash attention fuses the entire attention
operation and never materializes KQ_soft_max.

This allows DTW timestamps to work alongside flash attention with
no encoder performance penalty.

Fixes #3662
Acelogic 2026-03-11 15:30:19 -04:00
parent 30c5194c96
commit 640443521f
1 changed file with 11 additions and 4 deletions


@@ -2317,7 +2317,11 @@ static struct ggml_cgraph * whisper_build_graph_cross(
         struct ggml_tensor * k;
         struct ggml_tensor * v;
-        if (wctx.params.flash_attn) {
+        // Use non-flash layout for cross-attention KV when DTW is enabled,
+        // since DTW needs the explicit cross-attention weights (KQ_soft_max)
+        const bool flash_cross = wctx.params.flash_attn && !wctx.params.dtw_token_timestamps;
+        if (flash_cross) {
             k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
                     (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
@@ -2677,7 +2681,11 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
                         0, 2, 1, 3);
-            if (wctx.params.flash_attn) {
+            // Use non-flash path for cross-attention when DTW is enabled,
+            // since DTW needs the explicit cross-attention weights (KQ_soft_max)
+            const bool flash_cross = wctx.params.flash_attn && !wctx.params.dtw_token_timestamps;
+            if (flash_cross) {
                 struct ggml_tensor * Kcross =
                         ggml_view_3d(ctx0, wstate.kv_cross.k,
                                 n_state_head, n_audio_ctx_pad, n_head,
@@ -3706,8 +3714,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
     ggml_time_init();
     if (params.flash_attn && params.dtw_token_timestamps) {
-        WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
-        params.dtw_token_timestamps = false;
+        WHISPER_LOG_INFO("%s: flash_attn with dtw - disabling flash attention for cross-attention only\n", __func__);
     }
     WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
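
With this change both options can be requested together; a minimal
usage sketch (model path and alignment-heads preset are examples,
error handling omitted):

    struct whisper_context_params cparams = whisper_context_default_params();
    cparams.flash_attn           = true;                    // kept for self-attention
    cparams.dtw_token_timestamps = true;                    // cross-attention falls back to non-flash
    cparams.dtw_aheads_preset    = WHISPER_AHEADS_BASE_EN;  // must match the loaded model

    struct whisper_context * wctx =
            whisper_init_from_file_with_params("models/ggml-base.en.bin", cparams);

    // after whisper_full(), per-token DTW timestamps are available in
    // whisper_full_get_token_data(wctx, i_segment, i_token).t_dtw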