Fix: added <cmath>

```
fattn-common.cuh: In function 'void flash_attn_mask_to_KV_max(const half2*, int*, int, int, int)':
fattn-common.cuh:599:38: error: there are no arguments to 'isinf' that depend on a template parameter, so a declaration of 'isinf' must be available [-Wtemplate-body]
  599 |             all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
      |                                      ^~~~~
fattn-common.cuh:599:38: note: (if you use '-fpermissive', G++ will accept your code, but allowing the use of an undeclared name is deprecated)
fattn-common.cuh:599:59: error: there are no arguments to 'isinf' that depend on a template parameter, so a declaration of 'isinf' must be available [-Wtemplate-body]
  599 |             all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
      |                                                           ^~~~~
```
This commit is contained in:
Jamaika1 2026-01-02 12:51:07 +01:00 committed by GitHub
parent e9898ddfb9
commit 116669655c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 1 deletions

View File

@ -4,6 +4,7 @@
#include "convert.cuh"
#include "vecdotq.cuh"
#include <cmath>
#include <cstdint>
#define FATTN_KQ_STRIDE 256
@ -595,7 +596,7 @@ static __global__ void flash_attn_mask_to_KV_max(
#pragma unroll
for (int j = 0; j < ncols1; ++j) {
const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
all_inf = all_inf && int(std::isinf(tmp.x)) && int(std::isinf(tmp.y));
}
all_inf = warp_reduce_all(all_inf);