This commit is contained in:
Jamaika1 2026-04-20 13:45:08 +00:00 committed by GitHub
commit 767473186d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 1 deletions

View File

@ -4,6 +4,7 @@
#include "convert.cuh"
#include "vecdotq.cuh"
#include <cmath>
#include <cstdint>
#define FATTN_KQ_STRIDE 256
@ -645,7 +646,7 @@ static __global__ void flash_attn_mask_to_KV_max(
#pragma unroll
for (int j = 0; j < ncols1; ++j) {
const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
all_inf = all_inf && int(std::isinf(tmp.x)) && int(std::isinf(tmp.y));
}
all_inf = warp_reduce_all(all_inf);