Merge 116669655c into fc674574ca
This commit is contained in:
commit
767473186d
|
|
@ -4,6 +4,7 @@
|
|||
#include "convert.cuh"
|
||||
#include "vecdotq.cuh"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
|
||||
#define FATTN_KQ_STRIDE 256
|
||||
|
|
@ -645,7 +646,7 @@ static __global__ void flash_attn_mask_to_KV_max(
|
|||
#pragma unroll
|
||||
for (int j = 0; j < ncols1; ++j) {
|
||||
const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
|
||||
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
|
||||
all_inf = all_inf && int(std::isinf(tmp.x)) && int(std::isinf(tmp.y));
|
||||
}
|
||||
|
||||
all_inf = warp_reduce_all(all_inf);
|
||||
|
|
|
|||
Loading…
Reference in New Issue