Merge 640443521f into fc674574ca
commit 0e3b3116ac
@@ -2317,7 +2317,11 @@ static struct ggml_cgraph * whisper_build_graph_cross(
         struct ggml_tensor * k;
         struct ggml_tensor * v;
 
-        if (wctx.params.flash_attn) {
+        // Use non-flash layout for cross-attention KV when DTW is enabled,
+        // since DTW needs the explicit cross-attention weights (KQ_soft_max)
+        const bool flash_cross = wctx.params.flash_attn && !wctx.params.dtw_token_timestamps;
+
+        if (flash_cross) {
             k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
                     (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
 
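A note on the view arithmetic kept in this hunk: wstate.kv_cross.k is one flat tensor holding every layer's keys, and layer il's slice starts il*n_ctx_pad rows of n_state elements into it, while the view itself only spans the used n_ctx rows. A standalone sketch of that addressing, with hypothetical sizes standing in for the real hparams (no ggml dependency):

// Hypothetical dimensions in place of the model hparams; only the
// addressing scheme from the hunk above is being illustrated.
#include <stdio.h>
#include <stddef.h>

enum { n_state = 512, n_ctx = 1500, n_ctx_pad = 1536, n_layer = 6, elt_size = 2 /* f16 */ };

int main(void) {
    for (int il = 0; il < n_layer; ++il) {
        // Same formula as the ggml_view_1d offset: layer il starts
        // il*n_ctx_pad rows of n_state elements into the flat cache,
        // but the view covers only the n_ctx rows actually in use.
        size_t offs = (size_t)elt_size * n_state * ((size_t)il * n_ctx_pad);
        printf("layer %d: view of %d elements at byte offset %zu\n",
               il, n_state * n_ctx, offs);
    }
    return 0;
}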
@@ -2677,7 +2681,11 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                             ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
                             0, 2, 1, 3);
 
-                if (wctx.params.flash_attn) {
+                // Use non-flash path for cross-attention when DTW is enabled,
+                // since DTW needs the explicit cross-attention weights (KQ_soft_max)
+                const bool flash_cross = wctx.params.flash_attn && !wctx.params.dtw_token_timestamps;
+
+                if (flash_cross) {
                     struct ggml_tensor * Kcross =
                         ggml_view_3d(ctx0, wstate.kv_cross.k,
                                 n_state_head, n_audio_ctx_pad, n_head,
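For reviewers wondering why DTW forces the non-flash path here: flash attention fuses QK^T, the softmax, and the V product into a single op, so the attention matrix never exists as a graph node, while DTW timestamps are derived from exactly those per-head cross-attention weights (KQ_soft_max). A minimal ggml sketch of the standard path that does materialize them, with toy shapes rather than the actual decoder graph:

// Sketch only (links against ggml): the standard path creates KQ_soft_max
// as a real graph node that a DTW pass could read; a fused flash-attention
// op never writes it out.
#include "ggml.h"
#include <math.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx0 = ggml_init(ip);

    // Hypothetical sizes: one 64-dim head, 8 audio positions, 4 tokens.
    struct ggml_tensor * K = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 64, 8);
    struct ggml_tensor * Q = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 64, 4);
    ggml_set_f32(K, 0.1f);
    ggml_set_f32(Q, 0.2f);

    struct ggml_tensor * KQ          = ggml_mul_mat(ctx0, K, Q);             // [8, 4]
    struct ggml_tensor * KQ_scaled   = ggml_scale(ctx0, KQ, 1.0f/sqrtf(64.0f));
    struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);       // explicit weights

    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
    ggml_build_forward_expand(gf, KQ_soft_max);
    ggml_graph_compute_with_ctx(ctx0, gf, 1);

    // DTW-style code could read KQ_soft_max->data at this point.
    ggml_free(ctx0);
    return 0;
}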
@@ -3706,8 +3714,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
     ggml_time_init();
 
     if (params.flash_attn && params.dtw_token_timestamps) {
-        WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
-        params.dtw_token_timestamps = false;
+        WHISPER_LOG_INFO("%s: flash_attn with dtw - disabling flash attention for cross-attention only\n", __func__);
     }
 
     WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
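With this change, callers no longer have dtw_token_timestamps silently forced off: both flags can stay set, and only the cross-attention falls back to the non-flash path. A sketch of the calling side under that assumption (model path is a placeholder):

#include "whisper.h"

int main(void) {
    struct whisper_context_params cparams = whisper_context_default_params();
    cparams.flash_attn           = true;  // still applies to self-attention
    cparams.dtw_token_timestamps = true;  // no longer force-disabled at init
    cparams.dtw_aheads_preset    = WHISPER_AHEADS_BASE_EN; // match the model

    struct whisper_context * ctx =
        whisper_init_from_file_with_params("models/ggml-base.en.bin", cparams);
    if (ctx == NULL) return 1;

    // ... run whisper_full() with token timestamps enabled, then read the
    //     per-token t_dtw values from whisper_full_get_token_data() ...
    whisper_free(ctx);
    return 0;
}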