ggml-webgpu: improve i-quants mul_mat performance and speed up prefill (llama/24530)

* Improve prefill speeds for i-quants

* Fix #if defined() usage in preprocessor guards.
This commit is contained in:
Masashi Yoshimura 2026-06-15 10:15:30 +09:00 committed by Georgi Gerganov
parent 1216e0957b
commit 3e0b917514
1 changed files with 392 additions and 314 deletions

View File

@ -98,6 +98,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
#endif // INIT_SRC0_SHMEM_Q1_0
// legacy-quants
#if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4)
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
@ -124,7 +125,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
#ifdef INIT_SRC0_SHMEM_Q4_0
#if defined(INIT_SRC0_SHMEM_Q4_0)
let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u;
let d = load_f16_at_src0(block_byte_base);
@ -134,7 +135,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let q_packed = load_u32_at_src0(q_byte_offset);
dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
}
#elif INIT_SRC0_SHMEM_Q4_1
#endif // INIT_SRC0_SHMEM_Q4_0
#if defined(INIT_SRC0_SHMEM_Q4_1)
let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u;
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
let d = f16(dm[0]);
@ -153,7 +156,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
}
#elif INIT_SRC0_SHMEM_Q5_0
#endif // INIT_SRC0_SHMEM_Q4_1
#if defined(INIT_SRC0_SHMEM_Q5_0)
let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u;
let d = load_f16_at_src0(block_byte_base);
@ -176,7 +181,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
}
#elif INIT_SRC0_SHMEM_Q5_1
#endif // INIT_SRC0_SHMEM_Q5_0
#if defined(INIT_SRC0_SHMEM_Q5_1)
let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u;
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
@ -201,7 +208,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
}
#elif INIT_SRC0_SHMEM_Q8_0
#endif // INIT_SRC0_SHMEM_Q5_1
#if defined(INIT_SRC0_SHMEM_Q8_0)
let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u;
let d = load_f16_at_src0(block_byte_base);
@ -211,7 +220,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let q_packed = load_u32_at_src0(q_byte_offset);
dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
}
#elif INIT_SRC0_SHMEM_Q8_1
#endif // INIT_SRC0_SHMEM_Q8_0
#if defined(INIT_SRC0_SHMEM_Q8_1)
let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u;
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
let d = f16(dm[0]);
@ -227,7 +238,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val;
}
}
#elif INIT_SRC0_SHMEM_MXFP4
#endif // INIT_SRC0_SHMEM_Q8_1
#if defined(INIT_SRC0_SHMEM_MXFP4)
let block_byte_base = src0_idx * 17u;
let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u);
let e = ldexp(1.0, i32(eu8) - 128);
@ -244,11 +257,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi);
}
}
#endif
#endif // INIT_SRC0_SHMEM_MXFP4
}
}
}
#endif
#endif // legacy-quants
// k-quants
#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K)
@ -284,7 +297,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
#ifdef INIT_SRC0_SHMEM_Q2_K
#if defined(INIT_SRC0_SHMEM_Q2_K)
let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u;
let scales_byte_base = block_byte_base;
let qs_byte_base = block_byte_base + 16u;
@ -314,7 +327,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(scale >> 4u);
store_shmem_kquants(qs_vec4 * dl - ml, elem_idx);
#elif INIT_SRC0_SHMEM_Q3_K
#endif // INIT_SRC0_SHMEM_Q2_K
#if defined(INIT_SRC0_SHMEM_Q3_K)
let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u;
let hmask_byte_base = block_byte_base + 0u;
let qs_byte_base = block_byte_base + 32u;
@ -355,7 +370,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0);
store_shmem_kquants(dl * q_vec4, elem_idx);
#elif INIT_SRC0_SHMEM_Q4_K
#endif // INIT_SRC0_SHMEM_Q3_K
#if defined(INIT_SRC0_SHMEM_Q4_K)
let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u;
let dm_byte_base = block_byte_base + 0u;
let scale_byte_base = block_byte_base + 4u;
@ -399,7 +416,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx);
#elif INIT_SRC0_SHMEM_Q5_K
#endif // INIT_SRC0_SHMEM_Q4_K
#if defined(INIT_SRC0_SHMEM_Q5_K)
let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u;
let dm_byte_base = block_byte_base + 0u;
let scale_byte_base = block_byte_base + 4u;
@ -456,7 +475,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4<f16>(ml, ml, ml, ml), elem_idx);
#elif INIT_SRC0_SHMEM_Q6_K
#endif // INIT_SRC0_SHMEM_Q5_K
#if defined(INIT_SRC0_SHMEM_Q6_K)
let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u;
let ql_byte_base = block_byte_base;
let qh_byte_base = block_byte_base + 128u;
@ -497,17 +518,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let scale = get_byte_i32(scale_word, scale_byte & 3u);
store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx);
#endif
#endif // INIT_SRC0_SHMEM_Q6_K
}
}
#endif // k-quants
#ifdef INIT_SRC0_SHMEM_IQ4_NL
#if defined(INIT_SRC0_SHMEM_IQ4_NL)
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 18u;
const NQ = 4u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += NQ * TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
@ -519,408 +541,464 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let d_byte_base = block_byte_base + 0u;
let qs_byte_base = block_byte_base + 2u;
let pos = k_in_block % 16u;
let nib_shift = (k_in_block / 16u) * 4u;
let q_packed = load_u32_at_src0(block_byte_base + 2u + (pos / 4u) * 4u);
let nib = (get_byte(q_packed, pos % 4u) >> nib_shift) & 0xFu;
let d = load_f16_at_src0(d_byte_base);
shmem[elem_idx] = d * f16(kvalues_iq4nl[nib]);
let id_qtr = (k_in_block % 16u) / 4u;
let shift_phase = k_in_block / 16u;
let qs_u32 = load_u32_at_src0(qs_byte_base + 4u * id_qtr);
shmem[elem_idx + 0u] = d * f16(kvalues_iq4nl[(qs_u32 >> ( 0u + 4u * shift_phase)) & 0xFu]);
shmem[elem_idx + 1u] = d * f16(kvalues_iq4nl[(qs_u32 >> ( 8u + 4u * shift_phase)) & 0xFu]);
shmem[elem_idx + 2u] = d * f16(kvalues_iq4nl[(qs_u32 >> (16u + 4u * shift_phase)) & 0xFu]);
shmem[elem_idx + 3u] = d * f16(kvalues_iq4nl[(qs_u32 >> (24u + 4u * shift_phase)) & 0xFu]);
}
}
#endif // INIT_SRC0_SHMEM_IQ4_NL
#ifdef INIT_SRC0_SHMEM_IQ4_XS
// i-quants (super block size: 256)
#if defined(INIT_SRC0_SHMEM_IQ4_XS) || defined(INIT_SRC0_SHMEM_IQ1_S) || defined(INIT_SRC0_SHMEM_IQ1_M) || defined(INIT_SRC0_SHMEM_IQ2_XXS) \
|| defined(INIT_SRC0_SHMEM_IQ2_XS) || defined(INIT_SRC0_SHMEM_IQ2_S) || defined(INIT_SRC0_SHMEM_IQ3_XXS) || defined(INIT_SRC0_SHMEM_IQ3_S)
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 136u;
const NQ = 16u;
// Scatter one vec4<f16> into four consecutive shmem slots.
//   val : the four values to store (component order x, y, z, w)
//   idx : index of the first destination slot; slots idx .. idx + 3 are written
fn store_shmem_iquants(val: vec4<f16>, idx: u32) {
    let base = idx;
    shmem[base + 0u] = val[0];
    shmem[base + 1u] = val[1];
    shmem[base + 2u] = val[2];
    shmem[base + 3u] = val[3];
}
// Fetch a single byte (returned as u32) addressed by `byte_offset`.
// The enclosing 32-bit word is obtained through load_u32_at_src0_aligned and the
// low two bits of the offset select which byte get_byte extracts from it.
// NOTE(review): presumably load_u32_at_src0_aligned rounds the offset down to the
// containing word - confirm against its definition.
fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 {
    let word = load_u32_at_src0_aligned(byte_offset);
    let byte_in_word = byte_offset & 3u;
    return get_byte(word, byte_in_word);
}
#if defined(INIT_SRC0_SHMEM_IQ1_M) || defined(INIT_SRC0_SHMEM_IQ1_S)
// Decode four consecutive 2-bit fields of `gw`, starting at bit `shift_base`,
// into scaled f16 values.
//   dl         : scale applied to every element
//   gw         : packed word holding the 2-bit fields
//   shift_base : bit position of the first field (fields sit at +0, +2, +4, +6)
//   delta      : offset added to each decoded field before scaling
// Each field is moved to the top two bits and arithmetically shifted back down,
// which sign-extends it. The result per lane is f16(dl * (field + delta)).
fn create_iq_gw4(dl: f32, gw: u32, shift_base: u32, delta: f32) -> vec4<f16> {
    let q0 = bitcast<i32>(((gw >> (shift_base + 0u)) & 3u) << 30u) >> 30u;
    let q1 = bitcast<i32>(((gw >> (shift_base + 2u)) & 3u) << 30u) >> 30u;
    let q2 = bitcast<i32>(((gw >> (shift_base + 4u)) & 3u) << 30u) >> 30u;
    let q3 = bitcast<i32>(((gw >> (shift_base + 6u)) & 3u) << 30u) >> 30u;

    let r0 = f16(dl * (f32(q0) + delta));
    let r1 = f16(dl * (f32(q1) + delta));
    let r2 = f16(dl * (f32(q2) + delta));
    let r3 = f16(dl * (f32(q3) + delta));
    return vec4<f16>(r0, r1, r2, r3);
}
#endif
#if defined(INIT_SRC0_SHMEM_IQ4_XS)
// Look up four 4-bit codes from `qs_u32` in kvalues_iq4nl and scale them by `dl`.
//   dl          : scale applied to every element
//   qs_u32      : packed word; one code is taken from each of its four bytes
//   shift_phase : shifts every extraction by a further 4 * shift_phase bits
//                 (NOTE(review): callers appear to pass 0 or 1 to pick the low or
//                 high nibble of each byte - confirm)
fn create_iq_gw4(dl: f16, qs_u32: u32, shift_phase: u32) -> vec4<f16> {
    let nib_shift = 4 * shift_phase;
    let code0 = (qs_u32 >> (nib_shift + 0u)) & 0xFu;
    let code1 = (qs_u32 >> (nib_shift + 8u)) & 0xFu;
    let code2 = (qs_u32 >> (nib_shift + 16u)) & 0xFu;
    let code3 = (qs_u32 >> (nib_shift + 24u)) & 0xFu;

    let v0 = dl * f16(kvalues_iq4nl[code0]);
    let v1 = dl * f16(kvalues_iq4nl[code1]);
    let v2 = dl * f16(kvalues_iq4nl[code2]);
    let v3 = dl * f16(kvalues_iq4nl[code3]);
    return vec4<f16>(v0, v1, v2, v3);
}
#endif
#if defined(INIT_SRC0_SHMEM_IQ2_XXS)
// Read four consecutive bytes of iq2xxs_grid as f32 values.
//   ig         : byte index of the grid entry
//   grid_phase : additional byte offset inside that entry
// Byte b is fetched from word iq2xxs_grid[b / 4] at byte position b % 4, so the
// four reads may straddle a word boundary.
fn create_iq_gw4(ig: u32, grid_phase: u32) -> vec4<f32> {
    let b0 = ig + grid_phase + 0u;
    let b1 = ig + grid_phase + 1u;
    let b2 = ig + grid_phase + 2u;
    let b3 = ig + grid_phase + 3u;

    let g0 = get_byte(iq2xxs_grid[b0 / 4u], b0 % 4u);
    let g1 = get_byte(iq2xxs_grid[b1 / 4u], b1 % 4u);
    let g2 = get_byte(iq2xxs_grid[b2 / 4u], b2 % 4u);
    let g3 = get_byte(iq2xxs_grid[b3 / 4u], b3 % 4u);
    return vec4<f32>(f32(g0), f32(g1), f32(g2), f32(g3));
}
#endif
#if defined(INIT_SRC0_SHMEM_IQ2_XS)
// Read four consecutive bytes of iq2xs_grid as f32 values.
//   ig         : byte index of the grid entry
//   grid_phase : additional byte offset inside that entry
// Byte b is fetched from word iq2xs_grid[b / 4] at byte position b % 4, so the
// four reads may straddle a word boundary.
fn create_iq_gw4(ig: u32, grid_phase: u32) -> vec4<f32> {
    let b0 = ig + grid_phase + 0u;
    let b1 = ig + grid_phase + 1u;
    let b2 = ig + grid_phase + 2u;
    let b3 = ig + grid_phase + 3u;

    let g0 = get_byte(iq2xs_grid[b0 / 4u], b0 % 4u);
    let g1 = get_byte(iq2xs_grid[b1 / 4u], b1 % 4u);
    let g2 = get_byte(iq2xs_grid[b2 / 4u], b2 % 4u);
    let g3 = get_byte(iq2xs_grid[b3 / 4u], b3 % 4u);
    return vec4<f32>(f32(g0), f32(g1), f32(g2), f32(g3));
}
#endif
#if defined(INIT_SRC0_SHMEM_IQ2_S)
// Read four consecutive bytes of iq2s_grid as f32 values.
//   ig         : byte index of the grid entry
//   grid_phase : additional byte offset inside that entry
// Byte b is fetched from word iq2s_grid[b / 4] at byte position b % 4, so the
// four reads may straddle a word boundary.
fn create_iq_gw4(ig: u32, grid_phase: u32) -> vec4<f32> {
    let b0 = ig + grid_phase + 0u;
    let b1 = ig + grid_phase + 1u;
    let b2 = ig + grid_phase + 2u;
    let b3 = ig + grid_phase + 3u;

    let g0 = get_byte(iq2s_grid[b0 / 4u], b0 % 4u);
    let g1 = get_byte(iq2s_grid[b1 / 4u], b1 % 4u);
    let g2 = get_byte(iq2s_grid[b2 / 4u], b2 % 4u);
    let g3 = get_byte(iq2s_grid[b3 / 4u], b3 % 4u);
    return vec4<f32>(f32(g0), f32(g1), f32(g2), f32(g3));
}
#endif
#if defined(INIT_SRC0_SHMEM_IQ3_XXS)
// Expand the four bytes of grid word iq3xxs_grid[ig] into a vec4<f32>,
// byte 0 first.
//   ig : word index into iq3xxs_grid
fn create_iq_gw4(ig: u32) -> vec4<f32> {
    let g0 = f32(get_byte(iq3xxs_grid[ig], 0));
    let g1 = f32(get_byte(iq3xxs_grid[ig], 1));
    let g2 = f32(get_byte(iq3xxs_grid[ig], 2));
    let g3 = f32(get_byte(iq3xxs_grid[ig], 3));
    return vec4<f32>(g0, g1, g2, g3);
}
#endif
#if defined(INIT_SRC0_SHMEM_IQ3_S)
// Expand the four bytes of grid word iq3s_grid[ig] into a vec4<f32>,
// byte 0 first.
//   ig : word index into iq3s_grid
fn create_iq_gw4(ig: u32) -> vec4<f32> {
    let g0 = f32(get_byte(iq3s_grid[ig], 0));
    let g1 = f32(get_byte(iq3s_grid[ig], 1));
    let g2 = f32(get_byte(iq3s_grid[ig], 2));
    let g3 = f32(get_byte(iq3s_grid[ig], 3));
    return vec4<f32>(g0, g1, g2, g3);
}
#endif
#if defined(INIT_SRC0_SHMEM_IQ2_XXS) || defined(INIT_SRC0_SHMEM_IQ2_XS) || defined(INIT_SRC0_SHMEM_IQ2_S) \
|| defined(INIT_SRC0_SHMEM_IQ3_XXS) || defined(INIT_SRC0_SHMEM_IQ3_S)
// Build four +/-1.0 sign multipliers from a packed sign byte.
//   signs      : sign bits for a group of elements
//   mask_phase : index of the kmask_iq2xs word whose four bytes are used as masks
// Lane i is -1.0 when (byte i of kmask_iq2xs[mask_phase]) & signs is non-zero,
// otherwise 1.0.
fn create_iq2_m4(signs: u32, mask_phase: u32) -> vec4<f32> {
    let hit0 = (get_byte(kmask_iq2xs[mask_phase], 0) & signs) != 0u;
    let hit1 = (get_byte(kmask_iq2xs[mask_phase], 1) & signs) != 0u;
    let hit2 = (get_byte(kmask_iq2xs[mask_phase], 2) & signs) != 0u;
    let hit3 = (get_byte(kmask_iq2xs[mask_phase], 3) & signs) != 0u;

    let m0 = select(1.0, -1.0, hit0);
    let m1 = select(1.0, -1.0, hit1);
    let m2 = select(1.0, -1.0, hit2);
    let m3 = select(1.0, -1.0, hit3);
    return vec4<f32>(m0, m1, m2, m3);
}
#endif
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += NQ * TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
let zero_vec4 = vec4<f16>(f16(0.0), f16(0.0), f16(0.0), f16(0.0));
store_shmem_iquants(zero_vec4, elem_idx + 0u);
store_shmem_iquants(zero_vec4, elem_idx + 4u);
store_shmem_iquants(zero_vec4, elem_idx + 8u);
store_shmem_iquants(zero_vec4, elem_idx + 12u);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 16 == 0;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let d_scales_h = load_u32_at_src0(block_byte_base);
#if defined(INIT_SRC0_SHMEM_IQ4_XS)
let block_byte_base = src0_idx * 136u; // BLOCK_SIZE_BYTES = 136u;
let d_byte_base = block_byte_base + 0u;
let scales_l_byte_base = block_byte_base + 4u;
let qs_byte_base = block_byte_base + 8u;
let d_scales_h = load_u32_at_src0_aligned(d_byte_base);
let d = bitcast<vec2<f16>>(d_scales_h).x;
let scales_h = d_scales_h >> 16u;
let ib = k_in_block / 32u;
let pos = k_in_block % 32u;
let sub_block = k_in_block / 32u;
let phase = (k_in_block / NQ) % 2u;
let scales_l_word = load_u32_at_src0(block_byte_base + 4u);
let ls_lo = (get_byte(scales_l_word, ib / 2u) >> ((ib & 1u) * 4u)) & 0xFu;
let ls_hi = ((scales_h >> (2u * ib)) & 3u) << 4u;
let dl = d * f16(i32(ls_lo | ls_hi) - 32);
let scales_l_u32 = load_u32_at_src0_aligned(scales_l_byte_base);
let ls_lo = (get_byte(scales_l_u32, sub_block / 2u) >> (4u * (sub_block % 2u))) & 0xFu;
let ls_hi = ((scales_h >> (2u * sub_block)) & 3u) << 4u;
let dl = d * f16(i32(ls_lo | ls_hi) - 32);
let iqs = ib * 16u + (pos % 16u);
let nib_shift = (pos / 16u) * 4u;
let q_packed = load_u32_at_src0(block_byte_base + 8u + (iqs / 4u) * 4u);
let nib = (get_byte(q_packed, iqs % 4u) >> nib_shift) & 0xFu;
let qs_0_3_u32 = load_u32_at_src0_aligned(qs_byte_base + 16u * sub_block + 0u);
let qs_4_7_u32 = load_u32_at_src0_aligned(qs_byte_base + 16u * sub_block + 4u);
let qs_8_11_u32 = load_u32_at_src0_aligned(qs_byte_base + 16u * sub_block + 8u);
let qs_12_15_u32 = load_u32_at_src0_aligned(qs_byte_base + 16u * sub_block + 12u);
shmem[elem_idx] = dl * f16(kvalues_iq4nl[nib]);
}
}
store_shmem_iquants(create_iq_gw4(dl, qs_0_3_u32, phase), elem_idx + 0u);
store_shmem_iquants(create_iq_gw4(dl, qs_4_7_u32, phase), elem_idx + 4u);
store_shmem_iquants(create_iq_gw4(dl, qs_8_11_u32, phase), elem_idx + 8u);
store_shmem_iquants(create_iq_gw4(dl, qs_12_15_u32, phase), elem_idx + 12u);
#endif // INIT_SRC0_SHMEM_IQ4_XS
#ifdef INIT_SRC0_SHMEM_IQ1_S
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 50u;
#if defined(INIT_SRC0_SHMEM_IQ1_S)
let block_byte_base = src0_idx * 50u; // BLOCK_SIZE_BYTES = 50u;
let d_byte_base = block_byte_base + 0u;
let qs_byte_base = block_byte_base + 2u;
let qh_byte_base = block_byte_base + 34u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let d = load_f16_as_f32_at_src0(d_byte_base);
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let sub_block = k_in_block / 32u;
let phase = (k_in_block / NQ) % 2u;
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let qh_u16 = load_u32_at_src0(qh_byte_base + sub_block * 2u) & 0xFFFFu;
let qs_u16 = load_u32_at_src0(qs_byte_base + sub_block * 4u + phase * 2u) & 0xFFFFu;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_as_f32_at_src0(block_byte_base);
let dl = d * (2.0 * f32((qh_u16 >> 12u) & 7u) + 1.0);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_u16 & 0x8000u) != 0u);
let ib = k_in_block / 32u;
let pos = k_in_block % 32u;
let l = pos / 8u;
let j = pos % 8u;
let gp0_grid_id = ((qs_u16 & 0xFFu) | (((qh_u16 >> (phase * 6u)) & 7u) << 8u)) * 8u;
let gp1_grid_id = (((qs_u16 >> 8) & 0xFFu) | (((qh_u16 >> (phase * 6u + 3u)) & 7u) << 8u)) * 8u;
let qh = load_u32_at_src0(block_byte_base + 34u + ib * 2u) & 0xFFFFu;
let dl = d * (2.0 * f32((qh >> 12u) & 7u) + 1.0);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u);
let gp0_gw = iq1_grid[(gp0_grid_id) / 16u];
let gp1_gw = iq1_grid[(gp1_grid_id) / 16u];
let qs_w = load_u32_at_src0(block_byte_base + 2u + ib * 4u);
let ig = (get_byte(qs_w, l) | (((qh >> (3u * l)) & 7u) << 8u)) * 8u;
let gp0_shift_base = (gp0_grid_id % 16u) * 2u;
let gp1_shift_base = (gp1_grid_id % 16u) * 2u;
let gw = iq1_grid[(ig + j) / 16u];
let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u;
let gs = bitcast<i32>(g << 30u) >> 30u;
shmem[elem_idx] = f16(dl * (f32(gs) + delta));
}
}
store_shmem_iquants(create_iq_gw4(dl, gp0_gw, gp0_shift_base + 0u, delta), elem_idx + 0u);
store_shmem_iquants(create_iq_gw4(dl, gp0_gw, gp0_shift_base + 8u, delta), elem_idx + 4u);
store_shmem_iquants(create_iq_gw4(dl, gp1_gw, gp1_shift_base + 0u, delta), elem_idx + 8u);
store_shmem_iquants(create_iq_gw4(dl, gp1_gw, gp1_shift_base + 8u, delta), elem_idx + 12u);
#endif // INIT_SRC0_SHMEM_IQ1_S
#ifdef INIT_SRC0_SHMEM_IQ1_M
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 56u;
#if defined(INIT_SRC0_SHMEM_IQ1_M)
let block_byte_base = src0_idx * 56u; // BLOCK_SIZE_BYTES = 56u;
let qs_byte_base = block_byte_base + 0u;
let qh_byte_base = block_byte_base + 32u;
let scales_byte_base = block_byte_base + 48u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let scales0 = load_u32_at_src0(block_byte_base + 48u);
let scales1 = load_u32_at_src0(block_byte_base + 52u);
let scales0 = load_u32_at_src0_aligned(scales_byte_base);
let scales1 = load_u32_at_src0_aligned(scales_byte_base + 4u);
let scale_packed = ((scales0 >> 12u) & 0xFu) |
((scales0 >> 24u) & 0x00F0u) |
((scales1 >> 4u) & 0x0F00u) |
((scales1 >> 16u) & 0xF000u);
let d = f32(bitcast<vec2<f16>>(scale_packed).x);
let ib = k_in_block / 32u;
let pos = k_in_block % 32u;
let l = pos / 8u;
let j = pos % 8u;
let sub_block = k_in_block / 32u;
let phase = (k_in_block / NQ) % 2u;
let scales = select(scales0, scales1, ib >= 4u);
let sw = (scales >> (16u * ((ib / 2u) % 2u))) & 0xFFFFu;
let s_pair = (sw >> (6u * (ib % 2u) + 3u * (l / 2u))) & 0x7u;
let dl = d * f32(2u * s_pair + 1u);
let scale_u32 = select(scales0, scales1, sub_block >= 4u);
let scale_u3 = (scale_u32 >> (16u * ((sub_block / 2u) % 2u) + 6u * (sub_block % 2u) + 3u * phase)) & 0x7u;
let dl = d * f32(2u * scale_u3 + 1u);
let qh_word = load_u32_at_src0(block_byte_base + 32u + (ib / 2u) * 4u);
let qh = qh_word >> (16u * (ib % 2u));
let qh_nib = (qh >> (4u * l)) & 0xFu;
let qh_u8 = (load_u32_at_src0_aligned(qh_byte_base + 4u * (sub_block / 2u)) >> (16u * (sub_block % 2u) + 8u * phase)) & 0xFFu;
let qs_u16 = (load_u32_at_src0_aligned(qs_byte_base + 4u * sub_block) >> (16u * phase)) & 0xFFFFu;
let qs_w = load_u32_at_src0(block_byte_base + ib * 4u);
let idx = get_byte(qs_w, l) | ((qh_nib & 7u) << 8u);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_nib & 0x8u) != 0u);
let gp0_grid_id = ((qs_u16 & 0xFFu) | ((qh_u8 & 7u) << 8u)) * 8u;
let gp0_delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_u8 & 0x8u) != 0u);
let ig = idx * 8u;
let gw = iq1_grid[(ig + j) / 16u];
let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u;
let gs = bitcast<i32>(g << 30u) >> 30u;
let gp1_grid_id = (((qs_u16 >> 8u) & 0xFFu) | (((qh_u8 >> 4u) & 7u) << 8u)) * 8u;
let gp1_delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_u8 & 0x80u) != 0u);
shmem[elem_idx] = f16(dl * (f32(gs) + delta));
}
}
let gp0_gw = iq1_grid[(gp0_grid_id) / 16u];
let gp1_gw = iq1_grid[(gp1_grid_id) / 16u];
let gp0_shift_base = (gp0_grid_id % 16u) * 2u;
let gp1_shift_base = (gp1_grid_id % 16u) * 2u;
store_shmem_iquants(create_iq_gw4(dl, gp0_gw, gp0_shift_base + 0u, gp0_delta), elem_idx + 0u);
store_shmem_iquants(create_iq_gw4(dl, gp0_gw, gp0_shift_base + 8u, gp0_delta), elem_idx + 4u);
store_shmem_iquants(create_iq_gw4(dl, gp1_gw, gp1_shift_base + 0u, gp1_delta), elem_idx + 8u);
store_shmem_iquants(create_iq_gw4(dl, gp1_gw, gp1_shift_base + 8u, gp1_delta), elem_idx + 12u);
#endif // INIT_SRC0_SHMEM_IQ1_M
#ifdef INIT_SRC0_SHMEM_IQ2_XXS
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 66u;
#if defined(INIT_SRC0_SHMEM_IQ2_XXS)
let block_byte_base = src0_idx * 66u; // BLOCK_SIZE_BYTES = 66u;
let d_byte_base = block_byte_base + 0u;
let qs_byte_base = block_byte_base + 2u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let d = load_f16_as_f32_at_src0(d_byte_base);
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let sub_block = k_in_block / 32u;
let phase = (k_in_block / NQ) % 2u;
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_as_f32_at_src0(block_byte_base);
let entry_idx = k_in_block / 8u;
let j = k_in_block % 8u;
let ib = entry_idx & ~3u;
let l = entry_idx & 3u;
let aux0 = load_u32_at_src0(block_byte_base + 2u + ib * 2u);
let aux1 = load_u32_at_src0(block_byte_base + 2u + (ib + 2u) * 2u);
let aux0 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 0u);
let aux1 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 4u);
let db = d * (0.5 + f32(aux1 >> 28u)) * 0.25;
let ig = get_byte(aux0, l) * 8u;
let is = (aux1 >> (7u * l)) & 127u;
let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
let gp0_ig = get_byte(aux0, 2u * phase + 0u) * 8u;
let gp1_ig = get_byte(aux0, 2u * phase + 1u) * 8u;
let g = get_byte(iq2xxs_grid[(ig + j) / 4u], (ig + j) % 4u);
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
let gp0_is = (aux1 >> (14u * phase + 0u)) & 127u;
let gp1_is = (aux1 >> (14u * phase + 7u)) & 127u;
shmem[elem_idx] = f16(db * f32(g) * m);
}
}
let gp0_signs = get_byte(ksigns_iq2xs[gp0_is / 4u], gp0_is % 4u);
let gp1_signs = get_byte(ksigns_iq2xs[gp1_is / 4u], gp1_is % 4u);
let m_0_3_val4 = create_iq2_m4(gp0_signs, 0);
let m_4_7_val4 = create_iq2_m4(gp0_signs, 1);
let m_8_11_val4 = create_iq2_m4(gp1_signs, 0);
let m_12_15_val4 = create_iq2_m4(gp1_signs, 1);
let gw_0_3_val4 = create_iq_gw4(gp0_ig, 0);
let gw_4_7_val4 = create_iq_gw4(gp0_ig, 4);
let gw_8_11_val4 = create_iq_gw4(gp1_ig, 0);
let gw_12_15_val4 = create_iq_gw4(gp1_ig, 4);
store_shmem_iquants(vec4<f16>(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u);
store_shmem_iquants(vec4<f16>(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u);
store_shmem_iquants(vec4<f16>(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u);
store_shmem_iquants(vec4<f16>(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u);
#endif // INIT_SRC0_SHMEM_IQ2_XXS
#ifdef INIT_SRC0_SHMEM_IQ2_XS
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 74u;
#if defined(INIT_SRC0_SHMEM_IQ2_XS)
let block_byte_base = src0_idx * 74u; // BLOCK_SIZE_BYTES = 74u;
let d_byte_base = block_byte_base + 0u;
let qs_byte_base = block_byte_base + 2u;
let scales_byte_base = block_byte_base + 66u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let d = load_f16_as_f32_at_src0(d_byte_base);
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let sub_block = k_in_block / 32u;
let phase = (k_in_block / NQ) % 2u;
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let scale = (load_byte_at_src0_aligned(scales_byte_base + 1u * sub_block) >> (4u * phase)) & 0xFu;
let db = d * (0.5 + f32(scale)) * 0.25;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_as_f32_at_src0(block_byte_base);
let qs_u32 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 4u * phase);
let entry_idx = k_in_block / 8u;
let j = k_in_block % 8u;
let gp0_ig = (qs_u32 & 0x1FFu) * 8u;
let gp1_ig = ((qs_u32 >> 16u) & 0x1FFu) * 8u;
let ib = entry_idx & ~3u;
let l = entry_idx & 3u;
let gp0_is = (qs_u32 >> 9u) & 0x7Fu;
let gp1_is = (qs_u32 >> 25u) & 0x7Fu;
let scales_word = load_u32_at_src0(block_byte_base + 66u + (ib / 16u) * 4u);
let s = get_byte(scales_word, (ib % 16u) / 4u);
let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u);
let dl = d * (0.5 + f32(s_nib)) * 0.25;
let gp0_signs = get_byte(ksigns_iq2xs[gp0_is / 4u], gp0_is % 4u);
let gp1_signs = get_byte(ksigns_iq2xs[gp1_is / 4u], gp1_is % 4u);
let qs_word = load_u32_at_src0(block_byte_base + 2u + (ib + l) * 2u);
let qs_val = qs_word & 0xFFFFu;
let ig = (qs_val & 511u) * 8u;
let is = qs_val >> 9u;
let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
let m_0_3_val4 = create_iq2_m4(gp0_signs, 0);
let m_4_7_val4 = create_iq2_m4(gp0_signs, 1);
let m_8_11_val4 = create_iq2_m4(gp1_signs, 0);
let m_12_15_val4 = create_iq2_m4(gp1_signs, 1);
let g = get_byte(iq2xs_grid[(ig + j) / 4u], (ig + j) % 4u);
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
let gw_0_3_val4 = create_iq_gw4(gp0_ig, 0);
let gw_4_7_val4 = create_iq_gw4(gp0_ig, 4);
let gw_8_11_val4 = create_iq_gw4(gp1_ig, 0);
let gw_12_15_val4 = create_iq_gw4(gp1_ig, 4);
shmem[elem_idx] = f16(dl * f32(g) * m);
}
}
store_shmem_iquants(vec4<f16>(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u);
store_shmem_iquants(vec4<f16>(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u);
store_shmem_iquants(vec4<f16>(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u);
store_shmem_iquants(vec4<f16>(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u);
#endif // INIT_SRC0_SHMEM_IQ2_XS
#ifdef INIT_SRC0_SHMEM_IQ2_S
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 82u;
#if defined(INIT_SRC0_SHMEM_IQ2_S)
let block_byte_base = src0_idx * 82u; // BLOCK_SIZE_BYTES = 82u;
let d_byte_base = block_byte_base + 0u;
let qs_byte_base = block_byte_base + 2u;
let qh_byte_base = block_byte_base + 66u;
let scales_byte_base = block_byte_base + 74u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let d = load_f16_as_f32_at_src0(d_byte_base);
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let sub_block = k_in_block / 32u;
let phase = (k_in_block / NQ) % 2u;
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let scale = (load_byte_at_src0_aligned(scales_byte_base + 1u * sub_block) >> (4u * phase)) & 0xFu;
let db = d * (0.5 + f32(scale)) * 0.25;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_as_f32_at_src0(block_byte_base);
let qs_u16 = load_u32_at_src0(qs_byte_base + 4u * sub_block + 2u * phase) & 0xFFFFu;
let signs_u16 = load_u32_at_src0(qs_byte_base + 32u + 4u * sub_block + 2u * phase) & 0xFFFFu;
let qh_u4 = (load_byte_at_src0_aligned(qh_byte_base + 1u * sub_block) >> (4u * phase)) & 0xFu;
let ib = k_in_block / 32u;
let l = (k_in_block % 32u) / 8u;
let j = k_in_block % 8u;
let gp0_ig = ((qs_u16 & 0xFFu) | ((qh_u4 & 0x3u) << 8u)) * 8u;
let gp1_ig = (((qs_u16 >> 8u) & 0xFFu) | ((qh_u4 & 0xCu) << 6u)) * 8u;
let scales_word = load_u32_at_src0(block_byte_base + 74u + (ib / 4u) * 4u);
let s = get_byte(scales_word, ib % 4u);
let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u);
let dl = d * (0.5 + f32(s_nib)) * 0.25;
let gp0_signs = get_byte(signs_u16, 0);
let gp1_signs = get_byte(signs_u16, 1);
let qs_word = load_u32_at_src0(block_byte_base + 2u + ib * 4u);
let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 4u) * 4u);
let qh_b = (get_byte(qh_word, ib % 4u) << (8u - 2u * l)) & 0x300u;
let ig = (get_byte(qs_word, l) | qh_b) * 8u;
let m_0_3_val4 = create_iq2_m4(gp0_signs, 0);
let m_4_7_val4 = create_iq2_m4(gp0_signs, 1);
let m_8_11_val4 = create_iq2_m4(gp1_signs, 0);
let m_12_15_val4 = create_iq2_m4(gp1_signs, 1);
let signs_word = load_u32_at_src0(block_byte_base + 34u + ib * 4u);
let signs = get_byte(signs_word, l);
let gw_0_3_val4 = create_iq_gw4(gp0_ig, 0);
let gw_4_7_val4 = create_iq_gw4(gp0_ig, 4);
let gw_8_11_val4 = create_iq_gw4(gp1_ig, 0);
let gw_12_15_val4 = create_iq_gw4(gp1_ig, 4);
let g = get_byte(iq2s_grid[(ig + j) / 4u], (ig + j) % 4u);
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
shmem[elem_idx] = f16(dl * f32(g) * m);
}
}
store_shmem_iquants(vec4<f16>(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u);
store_shmem_iquants(vec4<f16>(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u);
store_shmem_iquants(vec4<f16>(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u);
store_shmem_iquants(vec4<f16>(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u);
#endif // INIT_SRC0_SHMEM_IQ2_S
#ifdef INIT_SRC0_SHMEM_IQ3_XXS
// IQ3_XXS variant of the src0 loader: fills `shmem` with f16 values built from
// a block scale, a grid lookup and a sign factor.
// NOTE(review): this span declares `block_byte_base`, `d` and `db` twice and
// reads `k_in_block` before the line that declares it. It looks like two
// revisions of the loader (one writing a single `shmem[elem_idx]`, one issuing
// `store_shmem_iquants` calls) interleaved -- confirm against the upstream
// file before relying on the statement order here.
// One block covers 256 k-indices (see `global_k / BLOCK_SIZE`) and 98 bytes.
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 98u;
#if defined(INIT_SRC0_SHMEM_IQ3_XXS)
let block_byte_base = src0_idx * 98u; // BLOCK_SIZE_BYTES = 98u;
// The scale is read from byte 0 of the block; the `qs` region starts at byte 2.
let d_byte_base = block_byte_base + 0u;
let qs_byte_base = block_byte_base + 2u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
// Each invocation starts at `thread_id` and advances by TOTAL_WORKGROUP_SIZE
// until every index below TILE_SRC0_SHMEM has been visited.
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
// Split the flat tile index into an m index and a k index, then shift both
// into global coordinates.
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let d = load_f16_as_f32_at_src0(d_byte_base);
// Positions outside params.m / params.k are written as zero.
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let sub_block = k_in_block / 32u;
// NOTE(review): NQ is defined outside this chunk -- confirm its value.
let phase = (k_in_block / NQ) % 2u;
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
// 8 bytes of `qs` per sub_block, 4 bytes per phase; the per-sub_block sign
// words start 64 bytes after `qs_byte_base`.
let qs_u32 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 4u * phase);
let sign_u32 = load_u32_at_src0(qs_byte_base + 64u + 4u * sub_block);
// The top 4 bits of the sign word scale `d`: d * (0.5 + bits) * 0.5.
let db = d * (0.5 + f32(sign_u32 >> 28u)) * 0.5;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_as_f32_at_src0(block_byte_base);
// One grid index per byte of qs_u32.
let ig_0_3 = get_byte(qs_u32, 0);
let ig_4_7 = get_byte(qs_u32, 1);
let ig_8_11 = get_byte(qs_u32, 2);
let ig_12_15 = get_byte(qs_u32, 3);
// Per-element decomposition of k_in_block: 32-wide group, 8-wide group `l`,
// 4-wide half `k2`, and position `j` inside that half.
let ib_pair = k_in_block / 32u;
let in_pair = k_in_block % 32u;
let l = in_pair / 8u;
let in_l = in_pair % 8u;
let k2 = in_l / 4u;
let j = in_l % 4u;
// Two 7-bit fields of sign_u32 at bit offsets 14 * phase and 14 * phase + 7.
let gp0_is = (sign_u32 >> (14u * phase + 0u)) & 0x7Fu;
let gp1_is = (sign_u32 >> (14u * phase + 7u)) & 0x7Fu;
let ib = ib_pair * 2u;
let sc_sign_off = block_byte_base + 2u + (ib + 32u) * 2u;
let sc_sign = load_u32_at_src0(sc_sign_off);
// Same scale formula as above, taken from the top 4 bits of sc_sign.
let db = d * (0.5 + f32(sc_sign >> 28u)) * 0.5;
// 7-bit field number `l` of sc_sign, used as an index into ksigns_iq2xs
// (word = index / 4, byte = index % 4).
let is = (sc_sign >> (7u * l)) & 127u;
let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
let gp0_signs = get_byte(ksigns_iq2xs[gp0_is / 4u], gp0_is % 4u);
let gp1_signs = get_byte(ksigns_iq2xs[gp1_is / 4u], gp1_is % 4u);
// Keep the low 16 bits of the loaded word and pick byte `k2` as the grid index.
let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 2u + l) * 2u) & 0xFFFFu;
let ig_byte = get_byte(ig_word, k2);
let g = get_byte(iq3xxs_grid[ig_byte], j);
// m is -1.0 when the kmask_iq2xs byte shares a set bit with `signs`, else 1.0.
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u);
// NOTE(review): create_iq2_m4 / create_iq_gw4 are defined outside this chunk;
// presumably they produce 4-wide sign and grid-value vectors -- confirm.
let m_0_3_val4 = create_iq2_m4(gp0_signs, 0);
let m_4_7_val4 = create_iq2_m4(gp0_signs, 1);
let m_8_11_val4 = create_iq2_m4(gp1_signs, 0);
let m_12_15_val4 = create_iq2_m4(gp1_signs, 1);
// Per-element path: one f16 written at elem_idx.
shmem[elem_idx] = f16(db * f32(g) * m);
}
}
let gw_0_3_val4 = create_iq_gw4(ig_0_3);
let gw_4_7_val4 = create_iq_gw4(ig_4_7);
let gw_8_11_val4 = create_iq_gw4(ig_8_11);
let gw_12_15_val4 = create_iq_gw4(ig_12_15);
// Vector path: four vec4<f16> stores at offsets 0, 4, 8 and 12 from elem_idx.
store_shmem_iquants(vec4<f16>(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u);
store_shmem_iquants(vec4<f16>(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u);
store_shmem_iquants(vec4<f16>(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u);
store_shmem_iquants(vec4<f16>(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u);
#endif // INIT_SRC0_SHMEM_IQ3_XXS
#ifdef INIT_SRC0_SHMEM_IQ3_S
// IQ3_S variant of the src0 loader: fills `shmem` with f16 values built from a
// block scale, a per-sub-block 4-bit scale, a grid lookup and a sign factor.
// NOTE(review): this span declares `block_byte_base` and `d` twice, reads
// `k_in_block` before the line that declares it, and closes the same guard
// twice. It looks like two revisions of the loader (one writing a single
// `shmem[elem_idx]`, one issuing `store_shmem_iquants` calls) interleaved --
// confirm against the upstream file before relying on the statement order.
// One block covers 256 k-indices (see `global_k / BLOCK_SIZE`) and 110 bytes.
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 110u;
#if defined(INIT_SRC0_SHMEM_IQ3_S)
let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u;
// Byte offsets inside one block: d at 0, qs at 2, qh at 66, signs at 74,
// scales at 106.
let d_byte_base = block_byte_base + 0u;
let qs_byte_base = block_byte_base + 2u;
let qh_byte_base = block_byte_base + 66u;
let signs_byte_base = block_byte_base + 74u;
let scales_byte_base = block_byte_base + 106u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
// Each invocation starts at `thread_id` and advances by TOTAL_WORKGROUP_SIZE
// until every index below TILE_SRC0_SHMEM has been visited.
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
// Split the flat tile index into an m index and a k index, then shift both
// into global coordinates.
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let d = load_f16_as_f32_at_src0(d_byte_base);
// Positions outside params.m / params.k are written as zero.
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let sub_block = k_in_block / 32u;
// NOTE(review): NQ is defined outside this chunk -- confirm its value.
let phase = (k_in_block / NQ) % 2u;
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
// Two sub_blocks share one scales byte: even sub_blocks use the low nibble,
// odd sub_blocks the high nibble.
let scale = (load_byte_at_src0_aligned(scales_byte_base + 1u * (sub_block / 2u)) >> (4u * (sub_block % 2u))) & 0xFu;
// Effective scale: d * (1 + 2 * scale).
let db = d * (1.0 + 2.0 * f32(scale));
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_as_f32_at_src0(block_byte_base);
// 8 bytes of `qs` per sub_block, 4 per phase; one qh byte per sub_block with
// a 4-bit half selected by phase; 4 sign bytes per sub_block, of which the
// low 16 bits of the word loaded at the phase offset are kept.
let qs_u32 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 4u * phase);
let qh_u4 = (load_byte_at_src0_aligned(qh_byte_base + 1u * sub_block) >> (4u * phase)) & 0xFu;
let signs_u16 = (load_u32_at_src0(signs_byte_base + 4u * sub_block + 2u * phase)) & 0xFFFFu;
// Per-element decomposition of k_in_block: 64-wide group `ib`, 32-wide half
// `k`, 8-wide group `l`, 4-wide half `k2`, and position `j` inside that half.
let ib = k_in_block / 64u;
let rest = k_in_block % 64u;
let k = rest / 32u;
let in_k = rest % 32u;
let l = in_k / 8u;
let in_l = in_k % 8u;
let k2 = in_l / 4u;
let j = in_l % 4u;
// 9-bit grid indices: the low 8 bits come from one byte of qs_u32 and bit 8
// from the matching bit of qh_u4.
let ig_0_3 = ((qs_u32 >> 0u) & 0xFFu) | ((qh_u4 & 0x1u) << 8u);
let ig_4_7 = ((qs_u32 >> 8u) & 0xFFu) | ((qh_u4 & 0x2u) << 7u);
let ig_8_11 = ((qs_u32 >> 16u) & 0xFFu) | ((qh_u4 & 0x4u) << 6u);
let ig_12_15 = ((qs_u32 >> 24u) & 0xFFu) | ((qh_u4 & 0x8u) << 5u);
let scales_word = load_u32_at_src0(block_byte_base + 106u);
let s = get_byte(scales_word, ib);
// Low nibble of `s` when k == 0, high nibble otherwise.
let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, k != 0u);
// Same scale formula as `db`, using the nibble chosen above.
let dl = d * (1.0 + 2.0 * f32(s_nib));
let gp0_signs = get_byte(signs_u16, 0);
let gp1_signs = get_byte(signs_u16, 1);
let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 2u) * 4u);
let qh_byte = get_byte(qh_word, (ib % 2u) * 2u + k);
// NOTE(review): create_iq2_m4 / create_iq_gw4 are defined outside this chunk;
// presumably they produce 4-wide sign and grid-value vectors -- confirm.
let m_0_3_val4 = create_iq2_m4(gp0_signs, 0);
let m_4_7_val4 = create_iq2_m4(gp0_signs, 1);
let m_8_11_val4 = create_iq2_m4(gp1_signs, 0);
let m_12_15_val4 = create_iq2_m4(gp1_signs, 1);
// Keep the low 16 bits of the loaded word; each of its two bytes is combined
// with one qh bit shifted into bit 8 (the `& 256u` mask), and `k2` picks
// which of the two resulting indices is used.
let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 8u + k * 4u + l) * 2u) & 0xFFFFu;
let ig_lo = get_byte(ig_word, 0u) | ((qh_byte << (8u - 2u * l)) & 256u);
let ig_hi = get_byte(ig_word, 1u) | ((qh_byte << (7u - 2u * l)) & 256u);
let ig = select(ig_lo, ig_hi, k2 != 0u);
let gw_0_3_val4 = create_iq_gw4(ig_0_3);
let gw_4_7_val4 = create_iq_gw4(ig_4_7);
let gw_8_11_val4 = create_iq_gw4(ig_8_11);
let gw_12_15_val4 = create_iq_gw4(ig_12_15);
let signs_word = load_u32_at_src0(block_byte_base + 74u + (ib * 2u + k) * 4u);
let signs = get_byte(signs_word, l);
let g = get_byte(iq3s_grid[ig], j);
// m is -1.0 when the kmask_iq2xs byte shares a set bit with `signs`, else 1.0.
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u);
// Per-element path: one f16 written at elem_idx.
shmem[elem_idx] = f16(dl * f32(g) * m);
// Vector path: four vec4<f16> stores at offsets 0, 4, 8 and 12 from elem_idx.
store_shmem_iquants(vec4<f16>(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u);
store_shmem_iquants(vec4<f16>(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u);
store_shmem_iquants(vec4<f16>(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u);
store_shmem_iquants(vec4<f16>(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u);
#endif // INIT_SRC0_SHMEM_IQ3_S
}
}
#endif // INIT_SRC0_SHMEM_IQ3_S
#endif // i-quants (super block size: 256)