CUDA: fix negative KV_max values in FA (llama/15321)

This commit is contained in:
Johannes Gäßler 2025-08-14 23:21:24 +02:00 committed by Georgi Gerganov
parent 0e15332255
commit ba32f5df0a
1 changed files with 5 additions and 1 deletions

View File

@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max(
all_inf = warp_reduce_all(all_inf);
if (!all_inf) {
KV_max_sj += FATTN_KQ_STRIDE;
break;
}
}
// If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
// If the break was triggered it's the lower edge of the tile with the first non-masked values.
// In either case, walk back the decrementation by FATTN_KQ_STRIDE.
KV_max_sj += FATTN_KQ_STRIDE;
if (threadIdx.x != 0) {
return;
}