diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 3f01e858d..43e22c5e5 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -61,6 +61,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 4, 64, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 4, 32, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 32, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 32, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); @@ -1561,6 +1566,10 @@ static __global__ void flash_attn_ext_f16( NO_DEVICE_CODE; return; } + if (DKQ == 192 && ncols2 != 8 && ncols2 != 16) { + NO_DEVICE_CODE; + return; + } #ifdef VOLTA_MMA_AVAILABLE if (ncols1*ncols2 < 32) { NO_DEVICE_CODE; diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index d60634cc0..c8281497d 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -34,6 +34,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst); } break; + case 192: { + GGML_ASSERT(V->ne[0] == 128); + ggml_cuda_flash_attn_ext_tile_case<192, 128>(ctx, dst); + } break; case 256: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 585f2c228..7b0a5e5cf 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -62,6 +62,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64) @@ -124,6 +130,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256) @@ -193,6 +205,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) @@ -264,6 +282,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 6, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 128, 6, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 5, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) @@ -1250,7 +1274,20 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm } } - if constexpr (DKQ <= 512 && DKQ != 320) { + if constexpr (DKQ == 192) { + // MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn) + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("flash-attn tile (192/128): expected GQA ratio multiple of 8"); + } + + if constexpr (DKQ <= 512 && DKQ != 320 && DKQ != 192) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1303,6 +1340,7 @@ extern DECL_FATTN_TILE_CASE( 80, 80); extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); +extern DECL_FATTN_TILE_CASE(192, 128); extern DECL_FATTN_TILE_CASE(256, 256); extern DECL_FATTN_TILE_CASE(320, 256); extern DECL_FATTN_TILE_CASE(512, 512); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 8256591b2..e045b04f7 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -139,6 +139,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(V->ne[0] == 128); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst); break; + case 192: { + // MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn) + GGML_ASSERT(V->ne[0] == 128); + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + if (gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 16>(ctx, dst); + } else { + GGML_ASSERT(gqa_ratio % 8 == 0); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 8>(ctx, dst); + } + } break; case 256: GGML_ASSERT(V->ne[0] == 256); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); @@ -368,6 +384,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } break; + case 192: + if (V->ne[0] != 128 || !gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + if (gqa_ratio % 8 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 320: if (V->ne[0] != 256 || !gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; @@ -425,7 +449,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + // 192 satisfies % 64 == 0 but has no vec instance (DKQ != DV); force it onto the MMA path. + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && Q->ne[0] != 192 && K->ne[1] % FATTN_KQ_STRIDE == 0; // If Turing tensor cores are available, use them: if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { @@ -454,7 +479,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { int gqa_ratio_eff = 1; - const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8; while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { gqa_ratio_eff *= 2; } @@ -468,7 +493,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) { + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 512 && Q->ne[0] != 576) { if (can_use_vector_kernel && Q->ne[1] <= 2) { return BEST_FATTN_KERNEL_VEC; } @@ -501,7 +526,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use MFMA flash attention for CDNA (MI100+): - if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { + if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); // MMA vs tile crossover benchmarked on MI300X @ d32768: // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu index fb26abeb0..b2661b931 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(192, 128, 1, 16); DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index 22d383173..6ae77bec8 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 1, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu index f011a208c..fd41e71b1 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(192, 128, 2, 16); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 84b674cd0..9f4bef11a 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 2, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu index f5fd0e236..cc41fa52f 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(192, 128, 4, 16); DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index 5906398db..859bea5c5 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 4, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 4bc60d62f..c975ce6b9 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 8, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu new file mode 100644 index 000000000..b571cca0d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(192, 128); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 5e9a1cb2e..af05a9eff 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,7 +3,10 @@ from glob import glob import os -HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576] +HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 192, 256, 320, 512, 576] + +# DKQ -> DV override for asymmetric head dims. +HEAD_SIZES_V_OVERRIDE = {576: 512, 320: 256, 192: 128} TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] @@ -62,7 +65,7 @@ for filename in glob("*.cu"): os.remove(filename) for head_size_kq in HEAD_SIZES_KQ: - head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) + head_size_v = HEAD_SIZES_V_OVERRIDE.get(head_size_kq, head_size_kq) with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) @@ -85,15 +88,17 @@ for ncols in [8, 16, 32, 64]: if head_size_kq == 72: continue # Skip compilation of unused ncols2 values for niche head sizes: + if head_size_kq == 192 and ncols2 not in (8, 16): # MiMo-V2.5 + continue if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4 continue if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4 continue if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash continue - if head_size_kq not in (320, 576) and ncols2 in (16, 32): + if head_size_kq not in (192, 320, 576) and ncols2 in (16, 32): continue - head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) + head_size_v = HEAD_SIZES_V_OVERRIDE.get(head_size_kq, head_size_kq) f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: