From aa42b48312f28f31a84840d51bcc783380b00d03 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Tue, 9 Jun 2026 07:19:56 +0900 Subject: [PATCH] ggml-webgpu: Improve prefill speeds for k-quants + refactor matmul for Q4/Q5/Q8 and k-quants (llama/24225) * ggml-webgpu: Improve prefill speeds + refactor matmul for quants * Fixes for editroconfig checker --- .../wgsl-shaders/mul_mat_decls.tmpl | 794 ++++++------------ 1 file changed, 259 insertions(+), 535 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 72991504d..ed4a6b13b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -98,72 +98,50 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } #endif // INIT_SRC0_SHMEM_Q1_0 -#ifdef INIT_SRC0_SHMEM_Q4_0 +#if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4) const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 18u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; +#if defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) +const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q +#else const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +#endif const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; + let block_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; + let shmem_idx = block_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - let tile_m = blck_idx / BLOCKS_K; + let tile_m = block_idx / BLOCKS_K; let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; + let block_k = block_idx % BLOCKS_K; let global_block_k = k_outer / BLOCK_SIZE + block_k; if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + +#ifdef INIT_SRC0_SHMEM_Q4_0 + let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u; let d = load_f16_at_src0(block_byte_base); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { - let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } - } - } -} -#endif // INIT_SRC0_SHMEM_Q4_0 +#elif INIT_SRC0_SHMEM_Q4_1 + let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u; + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); -#ifdef INIT_SRC0_SHMEM_Q4_1 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 20u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at_src0(block_byte_base); - let m = load_f16_at_src0(block_byte_base + 2u); - - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { - let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); @@ -175,41 +153,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q4_1 - -#ifdef INIT_SRC0_SHMEM_Q5_0 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 22u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -// tile_k is defined as 32u, so blocks_k ends up being 1 always -override BLOCKS_K = TILE_K / BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; +#elif INIT_SRC0_SHMEM_Q5_0 + let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u; let d = load_f16_at_src0(block_byte_base); let qh_packed = load_u32_at_src0(block_byte_base + 2u); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); @@ -226,44 +176,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q5_0 +#elif INIT_SRC0_SHMEM_Q5_1 + let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u; -#ifdef INIT_SRC0_SHMEM_Q5_1 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 24u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K / BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); + let qh_packed = load_u32_at_src0_aligned(block_byte_base + 4u); -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base); - let m = load_f16_at_src0(block_byte_base + 2u); - let qh_packed = load_u32_at_src0(block_byte_base + 4u); - - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; - let q_packed = load_u32_at_src0(q_byte_offset); + let q_packed = load_u32_at_src0_aligned(q_byte_offset); for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); @@ -277,461 +201,306 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q5_1 - -#ifdef INIT_SRC0_SHMEM_Q8_0 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 34u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; +#elif INIT_SRC0_SHMEM_Q8_0 + let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u; let d = load_f16_at_src0(block_byte_base); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } - } - } -} -#endif // INIT_SRC0_SHMEM_Q8_0 +#elif INIT_SRC0_SHMEM_Q8_1 + let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u; + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); -#ifdef INIT_SRC0_SHMEM_Q8_1 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 36u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at_src0(block_byte_base); - let m = load_f16_at_src0(block_byte_base + 2u); - - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d + m; shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; } } +#elif INIT_SRC0_SHMEM_MXFP4 + let block_byte_base = src0_idx * 17u; + let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u); + let e = ldexp(1.0, i32(eu8) - 128); + + // load NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0(q_byte_offset); + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e; + let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo); + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi); + } + } +#endif } } } -#endif // INIT_SRC0_SHMEM_Q8_1 +#endif + +// k-quants +#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K) +const BLOCK_SIZE = 256u; +const NQ = 4u; + +fn store_shmem_kquants(val: vec4, idx: u32) { + shmem[idx] = val.x; + shmem[idx + 1] = val.y; + shmem[idx + 2] = val.z; + shmem[idx + 3] = val.w; +} + +fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 { + return get_byte(load_u32_at_src0_aligned(byte_offset), byte_offset % 4u); +} + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * NQ) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + store_shmem_kquants(vec4(f16(0.0), f16(0.0), f16(0.0), f16(0.0)), elem_idx); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; #ifdef INIT_SRC0_SHMEM_Q2_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 84u; + let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u; + let scales_byte_base = block_byte_base; + let qs_byte_base = block_byte_base + 16u; + let dm_byte_base = block_byte_base + 80u; -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - // Use standard thread layout instead of lane/row_group - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; + let d_packed = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(d_packed[0]); + let dmin = f16(d_packed[1]); - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; + let chunk = k_in_block / 128u; + let pos_in_chunk = k_in_block % 32u; + let sub_block = k_in_block / 16u; + let shift_phase = (k_in_block % 128u) / 32u; - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } + // whole 2 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_vec4 = vec4( + f16((qs_word >> (2u * shift_phase + 0u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 8u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 16u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 24u)) & 0x3u), + ); - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; + let scale = load_byte_at_src0_aligned(scales_byte_base + sub_block); - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let dl = d * f16(scale & 0xFu); + let ml = dmin * f16(scale >> 4u); - let d = load_f16_at_src0(block_byte_base + 80u); - let dmin = load_f16_at_src0(block_byte_base + 82u); + store_shmem_kquants(qs_vec4 * dl - ml, elem_idx); +#elif INIT_SRC0_SHMEM_Q3_K + let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u; + let hmask_byte_base = block_byte_base + 0u; + let qs_byte_base = block_byte_base + 32u; + let scales_byte_base = block_byte_base + 96u; - // Decode the element at position k_in_block - let block_of_32 = k_in_block / 32u; - let pos_in_32 = k_in_block % 32u; + let d_all = load_f16_at_src0(block_byte_base + 108u); - let q_b_idx = (block_of_32 / 4u) * 32u; - let shift = (block_of_32 % 4u) * 2u; - let k = (pos_in_32 / 16u) * 16u; - let l = pos_in_32 % 16u; + let chunk = k_in_block / 128u; + let pos_in_chunk = k_in_block % 32u; + let sub_block = k_in_block / 16u; + let shift_phase = (k_in_block % 128u) / 32u; - let is = k_in_block / 16u; + let hmask_block = pos_in_chunk; + let hmask_shift_phase = k_in_block / 32u; - let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u)); - let sc = get_byte(sc_packed, is % 4u); + // low 2 bits (4 elems) + let q_lo2_word = load_u32_at_src0(qs_byte_base + 32u * chunk + 1u * hmask_block); + let q_lo2_vec4 = vec4( + f16((q_lo2_word >> (2u * shift_phase + 0u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 8u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 16u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 24u)) & 3u) + ); - let dl = d * f16(sc & 0xFu); - let ml = dmin * f16(sc >> 4u); + // high 1 bit (4 elems) + let q_hi1_word = load_u32_at_src0(hmask_byte_base + pos_in_chunk); + let q_hi1_vec4 = vec4( + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 0u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 8u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 16u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 24u)) & 1u) == 1u)) + ); - let q_idx = q_b_idx + k + l; - let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); - let q_byte = get_byte(q_packed, q_idx % 4u); - let qs_val = (q_byte >> shift) & 3u; + let q_vec4 = q_lo2_vec4 - q_hi1_vec4; - let q_val = f16(qs_val) * dl - ml; - shmem[elem_idx] = q_val; - } -} -#endif // INIT_SRC0_SHMEM_Q2_K + let scale_low4 = (load_byte_at_src0_aligned(scales_byte_base + (sub_block % 8u)) >> (4u * (sub_block / 8u))) & 0xFu; + let scale_hi2 = (load_byte_at_src0_aligned(scales_byte_base + 8u + (sub_block % 4u)) >> (2u * (sub_block / 4u))) & 3u; + let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0); -#ifdef INIT_SRC0_SHMEM_Q3_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 110u; + store_shmem_kquants(dl * q_vec4, elem_idx); +#elif INIT_SRC0_SHMEM_Q4_K + let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u; + let dm_byte_base = block_byte_base + 0u; + let scale_byte_base = block_byte_base + 4u; + let qs_byte_base = block_byte_base + 16u; -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; + let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(dm[0]); + let dmin = f16(dm[1]); - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; + let chunk = k_in_block / 64u; + let pos_in_chunk = (k_in_block % 64u) % 32u; + let sub_block = k_in_block / 32u; + let shift_phase = sub_block & 1u; - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base + 108u); - - // Load and unpack scales - let kmask1: u32 = 0x03030303u; - let kmask2: u32 = 0x0f0f0f0fu; - - var scale_vals: array; - for (var i: u32 = 0u; i < 4u; i++) { - scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i); - } - - var tmp: u32 = scale_vals[2]; - scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u); - scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u); - scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u); - scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u); - - // Load hmask and qs arrays - var hmask_vals: array; - for (var i: u32 = 0u; i < 8u; i++) { - hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i); - } - - var qs_vals: array; - for (var i: u32 = 0u; i < 16u; i++) { - qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i); - } - - let half = k_in_block / 128u; // 0 or 1 - let pos_in_half = k_in_block % 128u; // 0-127 - let shift_group = pos_in_half / 32u; // 0-3 - let pos_in_32 = pos_in_half % 32u; // 0-31 - let k_group = pos_in_32 / 16u; // 0 or 1 - let l = pos_in_32 % 16u; // 0-15 - - let q_b_idx = half * 32u; // 0 or 32 - let shift = shift_group * 2u; // 0, 2, 4, 6 - let k = k_group * 16u; // 0 or 16 - let is = k_in_block / 16u; // 0-15 - - // m increments every 32 elements across entire 256 element block - let m_shift = k_in_block / 32u; // 0-7 - let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128 - - let sc = get_byte(scale_vals[is / 4u], is % 4u); - let dl = d * (f16(sc) - 32.0); - - let q_idx = q_b_idx + k + l; - let hm_idx = k + l; - - let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u); - let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u); - - let hm = select(4.0, 0.0, (hmask_byte & m) != 0); - let qs_val = (q_byte >> shift) & 3u; - - let q_val = (f16(qs_val) - f16(hm)) * dl; - shmem[elem_idx] = q_val; - } -} - -#endif // INIT_SRC0_SHMEM_Q3_K - -#ifdef INIT_SRC0_SHMEM_Q4_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 144u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base); - let dmin = load_f16_at_src0(block_byte_base + 2u); - - // Map k_in_block to loop structure: - // Outer loop over 64-element groups (alternating q_b_idx) - // Inner loop over 2 shifts per group - let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx) - let pos_in_64 = k_in_block % 64u; // 0-63 - let shift_group = pos_in_64 / 32u; // 0 or 1 - let l = pos_in_64 % 32u; // 0-31 - - let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 - let shift = shift_group * 4u; // 0 or 4 - let is = k_in_block / 32u; // 0-7 + // whole 4 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_vec4 = vec4( + f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu) + ); var sc: u32; var mn: u32; - let scale_base = block_byte_base + 4u; - - if (is < 4u) { - let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); - let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - sc = sc_byte & 63u; - mn = min_byte & 63u; + if (sub_block < 4u) { + let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u); + let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - - sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); - mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); } let dl = d * f16(sc); let ml = dmin * f16(mn); - let q_idx = q_b_idx + l; - let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); + store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx); +#elif INIT_SRC0_SHMEM_Q5_K + let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u; + let dm_byte_base = block_byte_base + 0u; + let scale_byte_base = block_byte_base + 4u; + let qh_byte_base = block_byte_base + 16u; + let qs_byte_base = block_byte_base + 48u; - let q_byte = get_byte(q_packed, q_idx % 4u); - let qs_val = (q_byte >> shift) & 0xFu; + let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(dm[0]); + let dmin = f16(dm[1]); - let q_val = f16(qs_val) * dl - ml; - shmem[elem_idx] = q_val; - } -} -#endif // INIT_SRC0_SHMEM_Q4_K + let chunk = k_in_block / 64u; + let pos_in_chunk = (k_in_block % 64u) % 32u; + let sub_block = k_in_block / 32u; + let shift_phase = sub_block & 1u; -#ifdef INIT_SRC0_SHMEM_Q5_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 176u; + let qh_block = k_in_block % 32u; + let qh_shift_phase = sub_block; -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; + // low 4 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_lo4_vec4 = vec4( + f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu) + ); - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base); - let dmin = load_f16_at_src0(block_byte_base + 2u); - - - // The original loop processes elements in groups of 64 - // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4] - // But u increments EVERY 32 elements (after each l loop) - let group_of_64 = k_in_block / 64u; // 0-3 - let pos_in_64 = k_in_block % 64u; // 0-63 - let shift_group = pos_in_64 / 32u; // 0 or 1 - let l = pos_in_64 % 32u; // 0-31 - - let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 - let shift = shift_group * 4u; // 0 or 4 - let is = k_in_block / 32u; // 0-7 - - // u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128) - let u_shift = k_in_block / 32u; // 0-7 - let u: u32 = 1u << u_shift; + // high 1 bit (4 elems) + let qh_word = load_u32_at_src0_aligned(qh_byte_base + qh_block); + let qh_vec4 = vec4( + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 0u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 8u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 16u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 24u)) & 1u) == 1u)) + ); var sc: u32; var mn: u32; - let scale_base = block_byte_base + 4u; - - if (is < 4u) { - let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); - let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - sc = sc_byte & 63u; - mn = min_byte & 63u; + if (sub_block < 4u) { + let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u); + let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - - sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); - mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); } let dl = d * f16(sc); let ml = dmin * f16(mn); - let q_idx = q_b_idx + l; - let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u)); + store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4(ml, ml, ml, ml), elem_idx); +#elif INIT_SRC0_SHMEM_Q6_K + let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u; + let ql_byte_base = block_byte_base; + let qh_byte_base = block_byte_base + 128u; + let scales_byte_base = block_byte_base + 192u; + let d_byte_base = block_byte_base + 208u; - let q_byte = get_byte(q_packed, q_idx % 4u); + let d = load_f16_at_src0(d_byte_base); - let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u)); + let chunk = k_in_block / 128u; + let ql_pos_in_chunk = (k_in_block % 128u) % 64u; + let qh_pos_in_chunk = (k_in_block % 128u) % 32u; + let sub_block = k_in_block / 16u; + let ql_shift_phase = (k_in_block % 128u) / 64u; + let qh_shift_phase = (k_in_block % 128u) / 32u; - let qh_byte = get_byte(qh_packed, l % 4u); + // low 4 bits (4 elems) + let ql_word = load_u32_at_src0(ql_byte_base + 64u * chunk + 1u * ql_pos_in_chunk); + let ql_lo4_vec4 = vec4( + (ql_word >> (4u * ql_shift_phase + 0u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 8u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 16u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 24u)) & 0xFu + ); - let qs_val = (q_byte >> shift) & 0xFu; - let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); + // hi 2 bits (4 elems) + let qh_word = load_u32_at_src0(qh_byte_base + 32u * chunk + 1u * qh_pos_in_chunk); + let qh_hi2_vec4 = vec4( + ((qh_word >> (2u * qh_shift_phase + 0u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 8u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 16u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 24u)) & 0x3u) << 4u, + ); - let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml; - shmem[elem_idx] = q_val; + let q_vec4 = vec4(qh_hi2_vec4 | ql_lo4_vec4) - vec4(32.0, 32.0, 32.0, 32.0); + + let scale_byte = scales_byte_base + 1u * sub_block; + let scale_word = load_u32_at_src0_aligned(scale_byte); + let scale = get_byte_i32(scale_word, scale_byte & 3u); + + store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx); +#endif } } - -#endif // INIT_SRC0_SHMEM_Q5_K - -#ifdef INIT_SRC0_SHMEM_Q6_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 210u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let half = k_in_block / 128u; - let pos_in_half = k_in_block % 128u; - let quarter = pos_in_half / 32u; - let l = pos_in_half % 32u; - - let ql_b_idx = half * 64u; - let qh_b_idx = half * 32u; - let sc_b_idx = half * 8u; - - // Load only ql13 word needed - let ql13_flat = ql_b_idx + l; - let ql13 = load_u32_at_src0(block_byte_base + ql13_flat); - let ql13_b = get_byte(ql13, 0u); - - // Load only ql24 word needed - let ql24_flat = ql_b_idx + l + 32u; - let ql24 = load_u32_at_src0(block_byte_base + ql24_flat); - let ql24_b = get_byte(ql24, 0u); - - // Load only qh word needed - let qh_flat = qh_b_idx + l; - let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat); - let qh_b = get_byte(qh, 0u); - - let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); - let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0); - let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0); - let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0); - - // Load only the scale word needed - let is = l / 16u; - let sc_idx = sc_b_idx + is + quarter * 2u; - let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx); - let sc_val = get_byte_i32(sc, 0u); - - let d = load_f16_at_src0(block_byte_base + 208u); - - var q_val: f16; - if (quarter == 0u) { - q_val = q1; - } else if (quarter == 1u) { - q_val = q2; - } else if (quarter == 2u) { - q_val = q3; - } else { - q_val = q4; - } - - shmem[elem_idx] = d * f16(sc_val) * q_val; - } -} -#endif // INIT_SRC0_SHMEM_Q6_K +#endif // k-quants #ifdef INIT_SRC0_SHMEM_IQ4_NL const BLOCK_SIZE = 32u; @@ -1155,48 +924,3 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } #endif // INIT_SRC0_SHMEM_IQ3_S - -#ifdef INIT_SRC0_SHMEM_MXFP4 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 17u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights uses 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); - let e = ldexp(1.0, i32(eu8) - 128); - - // store NQ(16) weights - for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { - - let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; - let q_packed = load_u32_at_src0(q_byte_offset); - - for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e; - let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo); - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi); - } - } - } - } -} -#endif // INIT_SRC0_SHMEM_MXFP4