ggml-webgpu: improve i-quants mul_mat performance and speed up prefill (llama/24530)
* Improve prefill speeds for i-quants * Fix #if defined() usage in preprocessor guards.
This commit is contained in:
parent
1216e0957b
commit
3e0b917514
|
|
@ -98,6 +98,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q1_0
|
||||
|
||||
// legacy-quants
|
||||
#if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4)
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
|
|
@ -124,7 +125,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_0
|
||||
#if defined(INIT_SRC0_SHMEM_Q4_0)
|
||||
let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
|
||||
|
|
@ -134,7 +135,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
|
||||
}
|
||||
#elif INIT_SRC0_SHMEM_Q4_1
|
||||
#endif // INIT_SRC0_SHMEM_Q4_0
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_Q4_1)
|
||||
let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u;
|
||||
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
|
||||
let d = f16(dm[0]);
|
||||
|
|
@ -153,7 +156,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
#elif INIT_SRC0_SHMEM_Q5_0
|
||||
#endif // INIT_SRC0_SHMEM_Q4_1
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_Q5_0)
|
||||
let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u;
|
||||
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
|
|
@ -176,7 +181,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
#elif INIT_SRC0_SHMEM_Q5_1
|
||||
#endif // INIT_SRC0_SHMEM_Q5_0
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_Q5_1)
|
||||
let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u;
|
||||
|
||||
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
|
||||
|
|
@ -201,7 +208,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
#elif INIT_SRC0_SHMEM_Q8_0
|
||||
#endif // INIT_SRC0_SHMEM_Q5_1
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_Q8_0)
|
||||
let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
|
||||
|
|
@ -211,7 +220,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
|
||||
}
|
||||
#elif INIT_SRC0_SHMEM_Q8_1
|
||||
#endif // INIT_SRC0_SHMEM_Q8_0
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_Q8_1)
|
||||
let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u;
|
||||
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
|
||||
let d = f16(dm[0]);
|
||||
|
|
@ -227,7 +238,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val;
|
||||
}
|
||||
}
|
||||
#elif INIT_SRC0_SHMEM_MXFP4
|
||||
#endif // INIT_SRC0_SHMEM_Q8_1
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_MXFP4)
|
||||
let block_byte_base = src0_idx * 17u;
|
||||
let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u);
|
||||
let e = ldexp(1.0, i32(eu8) - 128);
|
||||
|
|
@ -244,11 +257,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // INIT_SRC0_SHMEM_MXFP4
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // legacy-quants
|
||||
|
||||
// k-quants
|
||||
#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K)
|
||||
|
|
@ -284,7 +297,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q2_K
|
||||
#if defined(INIT_SRC0_SHMEM_Q2_K)
|
||||
let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u;
|
||||
let scales_byte_base = block_byte_base;
|
||||
let qs_byte_base = block_byte_base + 16u;
|
||||
|
|
@ -314,7 +327,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let ml = dmin * f16(scale >> 4u);
|
||||
|
||||
store_shmem_kquants(qs_vec4 * dl - ml, elem_idx);
|
||||
#elif INIT_SRC0_SHMEM_Q3_K
|
||||
#endif // INIT_SRC0_SHMEM_Q2_K
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_Q3_K)
|
||||
let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u;
|
||||
let hmask_byte_base = block_byte_base + 0u;
|
||||
let qs_byte_base = block_byte_base + 32u;
|
||||
|
|
@ -355,7 +370,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0);
|
||||
|
||||
store_shmem_kquants(dl * q_vec4, elem_idx);
|
||||
#elif INIT_SRC0_SHMEM_Q4_K
|
||||
#endif // INIT_SRC0_SHMEM_Q3_K
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_Q4_K)
|
||||
let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u;
|
||||
let dm_byte_base = block_byte_base + 0u;
|
||||
let scale_byte_base = block_byte_base + 4u;
|
||||
|
|
@ -399,7 +416,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let ml = dmin * f16(mn);
|
||||
|
||||
store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx);
|
||||
#elif INIT_SRC0_SHMEM_Q5_K
|
||||
#endif // INIT_SRC0_SHMEM_Q4_K
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_Q5_K)
|
||||
let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u;
|
||||
let dm_byte_base = block_byte_base + 0u;
|
||||
let scale_byte_base = block_byte_base + 4u;
|
||||
|
|
@ -456,7 +475,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let ml = dmin * f16(mn);
|
||||
|
||||
store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4<f16>(ml, ml, ml, ml), elem_idx);
|
||||
#elif INIT_SRC0_SHMEM_Q6_K
|
||||
#endif // INIT_SRC0_SHMEM_Q5_K
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_Q6_K)
|
||||
let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u;
|
||||
let ql_byte_base = block_byte_base;
|
||||
let qh_byte_base = block_byte_base + 128u;
|
||||
|
|
@ -497,17 +518,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let scale = get_byte_i32(scale_word, scale_byte & 3u);
|
||||
|
||||
store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx);
|
||||
#endif
|
||||
#endif // INIT_SRC0_SHMEM_Q6_K
|
||||
}
|
||||
}
|
||||
#endif // k-quants
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ4_NL
|
||||
#if defined(INIT_SRC0_SHMEM_IQ4_NL)
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 18u;
|
||||
const NQ = 4u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += NQ * TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
|
|
@ -519,408 +541,464 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let d_byte_base = block_byte_base + 0u;
|
||||
let qs_byte_base = block_byte_base + 2u;
|
||||
|
||||
let pos = k_in_block % 16u;
|
||||
let nib_shift = (k_in_block / 16u) * 4u;
|
||||
let q_packed = load_u32_at_src0(block_byte_base + 2u + (pos / 4u) * 4u);
|
||||
let nib = (get_byte(q_packed, pos % 4u) >> nib_shift) & 0xFu;
|
||||
let d = load_f16_at_src0(d_byte_base);
|
||||
|
||||
shmem[elem_idx] = d * f16(kvalues_iq4nl[nib]);
|
||||
let id_qtr = (k_in_block % 16u) / 4u;
|
||||
let shift_phase = k_in_block / 16u;
|
||||
|
||||
let qs_u32 = load_u32_at_src0(qs_byte_base + 4u * id_qtr);
|
||||
|
||||
shmem[elem_idx + 0u] = d * f16(kvalues_iq4nl[(qs_u32 >> ( 0u + 4u * shift_phase)) & 0xFu]);
|
||||
shmem[elem_idx + 1u] = d * f16(kvalues_iq4nl[(qs_u32 >> ( 8u + 4u * shift_phase)) & 0xFu]);
|
||||
shmem[elem_idx + 2u] = d * f16(kvalues_iq4nl[(qs_u32 >> (16u + 4u * shift_phase)) & 0xFu]);
|
||||
shmem[elem_idx + 3u] = d * f16(kvalues_iq4nl[(qs_u32 >> (24u + 4u * shift_phase)) & 0xFu]);
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ4_NL
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ4_XS
|
||||
// i-quants (super block size: 256)
|
||||
#if defined(INIT_SRC0_SHMEM_IQ4_XS) || defined(INIT_SRC0_SHMEM_IQ1_S) || defined(INIT_SRC0_SHMEM_IQ1_M) || defined(INIT_SRC0_SHMEM_IQ2_XXS) \
|
||||
|| defined(INIT_SRC0_SHMEM_IQ2_XS) || defined(INIT_SRC0_SHMEM_IQ2_S) || defined(INIT_SRC0_SHMEM_IQ3_XXS) || defined(INIT_SRC0_SHMEM_IQ3_S)
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 136u;
|
||||
const NQ = 16u;
|
||||
|
||||
// Copies the four lanes of `val` into `shmem` at indices idx, idx + 1, idx + 2, idx + 3.
// `shmem` is declared outside this snippet; no bounds check is performed here.
fn store_shmem_iquants(val: vec4<f16>, idx: u32) {
    shmem[idx + 0u] = val[0];
    shmem[idx + 1u] = val[1];
    shmem[idx + 2u] = val[2];
    shmem[idx + 3u] = val[3];
}
|
||||
|
||||
// Returns get_byte(load_u32_at_src0_aligned(byte_offset), byte_offset % 4).
// NOTE(review): presumably the aligned load yields the 32-bit word containing
// `byte_offset` and get_byte picks one of its four bytes - both helpers are
// defined elsewhere, confirm against their definitions.
fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 {
    let lane = byte_offset % 4u;
    let word = load_u32_at_src0_aligned(byte_offset);
    return get_byte(word, lane);
}
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_IQ1_M) || defined(INIT_SRC0_SHMEM_IQ1_S)
|
||||
// Decodes four consecutive 2-bit fields of `gw`, starting at bit `shift_base`,
// into a vec4<f16>. Each field is sign-extended to an i32 (left shift by 30,
// bitcast, arithmetic right shift by 30), then mapped to dl * (value + delta)
// in f32 before the final conversion to f16.
fn create_iq_gw4(dl: f32, gw: u32, shift_base: u32, delta: f32) -> vec4<f16> {
    // Raw 2-bit codes for the four lanes.
    let code0 = (gw >> (shift_base + 0u)) & 3u;
    let code1 = (gw >> (shift_base + 2u)) & 3u;
    let code2 = (gw >> (shift_base + 4u)) & 3u;
    let code3 = (gw >> (shift_base + 6u)) & 3u;

    // Sign-extend each code from 2 bits to a full i32.
    let s0 = bitcast<i32>(code0 << 30u) >> 30u;
    let s1 = bitcast<i32>(code1 << 30u) >> 30u;
    let s2 = bitcast<i32>(code2 << 30u) >> 30u;
    let s3 = bitcast<i32>(code3 << 30u) >> 30u;

    return vec4<f16>(
        f16(dl * (f32(s0) + delta)),
        f16(dl * (f32(s1) + delta)),
        f16(dl * (f32(s2) + delta)),
        f16(dl * (f32(s3) + delta))
    );
}
|
||||
#endif
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_IQ4_XS)
|
||||
// Builds a vec4<f16> from one 4-bit field out of each byte of `qs_u32`.
// The field in byte b starts at bit (4 * shift_phase + 8 * b); each 4-bit
// value indexes `kvalues_iq4nl` (declared elsewhere) and the table entry,
// converted to f16, is multiplied by `dl`.
fn create_iq_gw4(dl: f16, qs_u32: u32, shift_phase: u32) -> vec4<f16> {
    let nibble_shift = 4 * shift_phase;

    let k0 = (qs_u32 >> (nibble_shift + 0u)) & 0xFu;
    let k1 = (qs_u32 >> (nibble_shift + 8u)) & 0xFu;
    let k2 = (qs_u32 >> (nibble_shift + 16u)) & 0xFu;
    let k3 = (qs_u32 >> (nibble_shift + 24u)) & 0xFu;

    return vec4<f16>(
        dl * f16(kvalues_iq4nl[k0]),
        dl * f16(kvalues_iq4nl[k1]),
        dl * f16(kvalues_iq4nl[k2]),
        dl * f16(kvalues_iq4nl[k3])
    );
}
|
||||
#endif
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_IQ2_XXS)
|
||||
// Reads four consecutive byte-sized entries from `iq2xxs_grid` (declared
// elsewhere as an array of packed u32 words), starting at byte position
// ig + grid_phase, and returns them as f32 lanes.
// For byte position p the word is iq2xxs_grid[p / 4] and the byte selector
// handed to get_byte is p % 4.
fn create_iq_gw4(ig: u32, grid_phase: u32) -> vec4<f32> {
    let first = ig + grid_phase;
    let p0 = first + 0u;
    let p1 = first + 1u;
    let p2 = first + 2u;
    let p3 = first + 3u;

    return vec4<f32>(
        f32(get_byte(iq2xxs_grid[p0 / 4u], p0 % 4u)),
        f32(get_byte(iq2xxs_grid[p1 / 4u], p1 % 4u)),
        f32(get_byte(iq2xxs_grid[p2 / 4u], p2 % 4u)),
        f32(get_byte(iq2xxs_grid[p3 / 4u], p3 % 4u))
    );
}
|
||||
#endif
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_IQ2_XS)
|
||||
// Reads four consecutive byte-sized entries from `iq2xs_grid` (declared
// elsewhere as an array of packed u32 words), starting at byte position
// ig + grid_phase, and returns them as f32 lanes.
// For byte position p the word is iq2xs_grid[p / 4] and the byte selector
// handed to get_byte is p % 4.
fn create_iq_gw4(ig: u32, grid_phase: u32) -> vec4<f32> {
    let first = ig + grid_phase;
    let p0 = first + 0u;
    let p1 = first + 1u;
    let p2 = first + 2u;
    let p3 = first + 3u;

    return vec4<f32>(
        f32(get_byte(iq2xs_grid[p0 / 4u], p0 % 4u)),
        f32(get_byte(iq2xs_grid[p1 / 4u], p1 % 4u)),
        f32(get_byte(iq2xs_grid[p2 / 4u], p2 % 4u)),
        f32(get_byte(iq2xs_grid[p3 / 4u], p3 % 4u))
    );
}
|
||||
#endif
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_IQ2_S)
|
||||
// Reads four consecutive byte-sized entries from `iq2s_grid` (declared
// elsewhere as an array of packed u32 words), starting at byte position
// ig + grid_phase, and returns them as f32 lanes.
// For byte position p the word is iq2s_grid[p / 4] and the byte selector
// handed to get_byte is p % 4.
fn create_iq_gw4(ig: u32, grid_phase: u32) -> vec4<f32> {
    let first = ig + grid_phase;
    let p0 = first + 0u;
    let p1 = first + 1u;
    let p2 = first + 2u;
    let p3 = first + 3u;

    return vec4<f32>(
        f32(get_byte(iq2s_grid[p0 / 4u], p0 % 4u)),
        f32(get_byte(iq2s_grid[p1 / 4u], p1 % 4u)),
        f32(get_byte(iq2s_grid[p2 / 4u], p2 % 4u)),
        f32(get_byte(iq2s_grid[p3 / 4u], p3 % 4u))
    );
}
|
||||
#endif
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_IQ3_XXS)
|
||||
// Expands the word iq3xxs_grid[ig] into four f32 lanes: lane n is
// get_byte(word, n) converted to f32. `iq3xxs_grid` and `get_byte` are
// defined elsewhere in the shader.
fn create_iq_gw4(ig: u32) -> vec4<f32> {
    let grid_word = iq3xxs_grid[ig];
    let b0 = get_byte(grid_word, 0);
    let b1 = get_byte(grid_word, 1);
    let b2 = get_byte(grid_word, 2);
    let b3 = get_byte(grid_word, 3);
    return vec4<f32>(f32(b0), f32(b1), f32(b2), f32(b3));
}
|
||||
#endif
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_IQ3_S)
|
||||
// Expands the word iq3s_grid[ig] into four f32 lanes: lane n is
// get_byte(word, n) converted to f32. `iq3s_grid` and `get_byte` are
// defined elsewhere in the shader.
fn create_iq_gw4(ig: u32) -> vec4<f32> {
    let grid_word = iq3s_grid[ig];
    let b0 = get_byte(grid_word, 0);
    let b1 = get_byte(grid_word, 1);
    let b2 = get_byte(grid_word, 2);
    let b3 = get_byte(grid_word, 3);
    return vec4<f32>(f32(b0), f32(b1), f32(b2), f32(b3));
}
|
||||
#endif
|
||||
|
||||
#if defined(INIT_SRC0_SHMEM_IQ2_XXS) || defined(INIT_SRC0_SHMEM_IQ2_XS) || defined(INIT_SRC0_SHMEM_IQ2_S) \
|
||||
|| defined(INIT_SRC0_SHMEM_IQ3_XXS) || defined(INIT_SRC0_SHMEM_IQ3_S)
|
||||
// Produces four +1.0 / -1.0 multipliers from the bit set `signs`.
// Lane n is -1.0 when (get_byte(kmask_iq2xs[mask_phase], n) & signs) is
// non-zero and +1.0 otherwise. `kmask_iq2xs` and `get_byte` are defined
// elsewhere in the shader.
fn create_iq2_m4(signs: u32, mask_phase: u32) -> vec4<f32> {
    let mask_word = kmask_iq2xs[mask_phase];

    let neg0 = (get_byte(mask_word, 0) & signs) != 0u;
    let neg1 = (get_byte(mask_word, 1) & signs) != 0u;
    let neg2 = (get_byte(mask_word, 2) & signs) != 0u;
    let neg3 = (get_byte(mask_word, 3) & signs) != 0u;

    return vec4<f32>(
        select(1.0, -1.0, neg0),
        select(1.0, -1.0, neg1),
        select(1.0, -1.0, neg2),
        select(1.0, -1.0, neg3)
    );
}
|
||||
#endif
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += NQ * TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
let zero_vec4 = vec4<f16>(f16(0.0), f16(0.0), f16(0.0), f16(0.0));
|
||||
store_shmem_iquants(zero_vec4, elem_idx + 0u);
|
||||
store_shmem_iquants(zero_vec4, elem_idx + 4u);
|
||||
store_shmem_iquants(zero_vec4, elem_idx + 8u);
|
||||
store_shmem_iquants(zero_vec4, elem_idx + 12u);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 16 == 0;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
|
||||
let d_scales_h = load_u32_at_src0(block_byte_base);
|
||||
#if defined(INIT_SRC0_SHMEM_IQ4_XS)
|
||||
let block_byte_base = src0_idx * 136u; // BLOCK_SIZE_BYTES = 136u;
|
||||
let d_byte_base = block_byte_base + 0u;
|
||||
let scales_l_byte_base = block_byte_base + 4u;
|
||||
let qs_byte_base = block_byte_base + 8u;
|
||||
|
||||
let d_scales_h = load_u32_at_src0_aligned(d_byte_base);
|
||||
let d = bitcast<vec2<f16>>(d_scales_h).x;
|
||||
let scales_h = d_scales_h >> 16u;
|
||||
|
||||
let ib = k_in_block / 32u;
|
||||
let pos = k_in_block % 32u;
|
||||
let sub_block = k_in_block / 32u;
|
||||
let phase = (k_in_block / NQ) % 2u;
|
||||
|
||||
let scales_l_word = load_u32_at_src0(block_byte_base + 4u);
|
||||
let ls_lo = (get_byte(scales_l_word, ib / 2u) >> ((ib & 1u) * 4u)) & 0xFu;
|
||||
let ls_hi = ((scales_h >> (2u * ib)) & 3u) << 4u;
|
||||
let dl = d * f16(i32(ls_lo | ls_hi) - 32);
|
||||
let scales_l_u32 = load_u32_at_src0_aligned(scales_l_byte_base);
|
||||
let ls_lo = (get_byte(scales_l_u32, sub_block / 2u) >> (4u * (sub_block % 2u))) & 0xFu;
|
||||
let ls_hi = ((scales_h >> (2u * sub_block)) & 3u) << 4u;
|
||||
let dl = d * f16(i32(ls_lo | ls_hi) - 32);
|
||||
|
||||
let iqs = ib * 16u + (pos % 16u);
|
||||
let nib_shift = (pos / 16u) * 4u;
|
||||
let q_packed = load_u32_at_src0(block_byte_base + 8u + (iqs / 4u) * 4u);
|
||||
let nib = (get_byte(q_packed, iqs % 4u) >> nib_shift) & 0xFu;
|
||||
let qs_0_3_u32 = load_u32_at_src0_aligned(qs_byte_base + 16u * sub_block + 0u);
|
||||
let qs_4_7_u32 = load_u32_at_src0_aligned(qs_byte_base + 16u * sub_block + 4u);
|
||||
let qs_8_11_u32 = load_u32_at_src0_aligned(qs_byte_base + 16u * sub_block + 8u);
|
||||
let qs_12_15_u32 = load_u32_at_src0_aligned(qs_byte_base + 16u * sub_block + 12u);
|
||||
|
||||
shmem[elem_idx] = dl * f16(kvalues_iq4nl[nib]);
|
||||
}
|
||||
}
|
||||
store_shmem_iquants(create_iq_gw4(dl, qs_0_3_u32, phase), elem_idx + 0u);
|
||||
store_shmem_iquants(create_iq_gw4(dl, qs_4_7_u32, phase), elem_idx + 4u);
|
||||
store_shmem_iquants(create_iq_gw4(dl, qs_8_11_u32, phase), elem_idx + 8u);
|
||||
store_shmem_iquants(create_iq_gw4(dl, qs_12_15_u32, phase), elem_idx + 12u);
|
||||
#endif // INIT_SRC0_SHMEM_IQ4_XS
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ1_S
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 50u;
|
||||
#if defined(INIT_SRC0_SHMEM_IQ1_S)
|
||||
let block_byte_base = src0_idx * 50u; // BLOCK_SIZE_BYTES = 50u;
|
||||
let d_byte_base = block_byte_base + 0u;
|
||||
let qs_byte_base = block_byte_base + 2u;
|
||||
let qh_byte_base = block_byte_base + 34u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
let d = load_f16_as_f32_at_src0(d_byte_base);
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
let sub_block = k_in_block / 32u;
|
||||
let phase = (k_in_block / NQ) % 2u;
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
let qh_u16 = load_u32_at_src0(qh_byte_base + sub_block * 2u) & 0xFFFFu;
|
||||
let qs_u16 = load_u32_at_src0(qs_byte_base + sub_block * 4u + phase * 2u) & 0xFFFFu;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
let dl = d * (2.0 * f32((qh_u16 >> 12u) & 7u) + 1.0);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_u16 & 0x8000u) != 0u);
|
||||
|
||||
let ib = k_in_block / 32u;
|
||||
let pos = k_in_block % 32u;
|
||||
let l = pos / 8u;
|
||||
let j = pos % 8u;
|
||||
let gp0_grid_id = ((qs_u16 & 0xFFu) | (((qh_u16 >> (phase * 6u)) & 7u) << 8u)) * 8u;
|
||||
let gp1_grid_id = (((qs_u16 >> 8) & 0xFFu) | (((qh_u16 >> (phase * 6u + 3u)) & 7u) << 8u)) * 8u;
|
||||
|
||||
let qh = load_u32_at_src0(block_byte_base + 34u + ib * 2u) & 0xFFFFu;
|
||||
let dl = d * (2.0 * f32((qh >> 12u) & 7u) + 1.0);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u);
|
||||
let gp0_gw = iq1_grid[(gp0_grid_id) / 16u];
|
||||
let gp1_gw = iq1_grid[(gp1_grid_id) / 16u];
|
||||
|
||||
let qs_w = load_u32_at_src0(block_byte_base + 2u + ib * 4u);
|
||||
let ig = (get_byte(qs_w, l) | (((qh >> (3u * l)) & 7u) << 8u)) * 8u;
|
||||
let gp0_shift_base = (gp0_grid_id % 16u) * 2u;
|
||||
let gp1_shift_base = (gp1_grid_id % 16u) * 2u;
|
||||
|
||||
let gw = iq1_grid[(ig + j) / 16u];
|
||||
let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u;
|
||||
let gs = bitcast<i32>(g << 30u) >> 30u;
|
||||
|
||||
shmem[elem_idx] = f16(dl * (f32(gs) + delta));
|
||||
}
|
||||
}
|
||||
store_shmem_iquants(create_iq_gw4(dl, gp0_gw, gp0_shift_base + 0u, delta), elem_idx + 0u);
|
||||
store_shmem_iquants(create_iq_gw4(dl, gp0_gw, gp0_shift_base + 8u, delta), elem_idx + 4u);
|
||||
store_shmem_iquants(create_iq_gw4(dl, gp1_gw, gp1_shift_base + 0u, delta), elem_idx + 8u);
|
||||
store_shmem_iquants(create_iq_gw4(dl, gp1_gw, gp1_shift_base + 8u, delta), elem_idx + 12u);
|
||||
#endif // INIT_SRC0_SHMEM_IQ1_S
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ1_M
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 56u;
|
||||
#if defined(INIT_SRC0_SHMEM_IQ1_M)
|
||||
let block_byte_base = src0_idx * 56u; // BLOCK_SIZE_BYTES = 56u;
|
||||
let qs_byte_base = block_byte_base + 0u;
|
||||
let qh_byte_base = block_byte_base + 32u;
|
||||
let scales_byte_base = block_byte_base + 48u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let scales0 = load_u32_at_src0(block_byte_base + 48u);
|
||||
let scales1 = load_u32_at_src0(block_byte_base + 52u);
|
||||
let scales0 = load_u32_at_src0_aligned(scales_byte_base);
|
||||
let scales1 = load_u32_at_src0_aligned(scales_byte_base + 4u);
|
||||
let scale_packed = ((scales0 >> 12u) & 0xFu) |
|
||||
((scales0 >> 24u) & 0x00F0u) |
|
||||
((scales1 >> 4u) & 0x0F00u) |
|
||||
((scales1 >> 16u) & 0xF000u);
|
||||
let d = f32(bitcast<vec2<f16>>(scale_packed).x);
|
||||
|
||||
let ib = k_in_block / 32u;
|
||||
let pos = k_in_block % 32u;
|
||||
let l = pos / 8u;
|
||||
let j = pos % 8u;
|
||||
let sub_block = k_in_block / 32u;
|
||||
let phase = (k_in_block / NQ) % 2u;
|
||||
|
||||
let scales = select(scales0, scales1, ib >= 4u);
|
||||
let sw = (scales >> (16u * ((ib / 2u) % 2u))) & 0xFFFFu;
|
||||
let s_pair = (sw >> (6u * (ib % 2u) + 3u * (l / 2u))) & 0x7u;
|
||||
let dl = d * f32(2u * s_pair + 1u);
|
||||
let scale_u32 = select(scales0, scales1, sub_block >= 4u);
|
||||
let scale_u3 = (scale_u32 >> (16u * ((sub_block / 2u) % 2u) + 6u * (sub_block % 2u) + 3u * phase)) & 0x7u;
|
||||
let dl = d * f32(2u * scale_u3 + 1u);
|
||||
|
||||
let qh_word = load_u32_at_src0(block_byte_base + 32u + (ib / 2u) * 4u);
|
||||
let qh = qh_word >> (16u * (ib % 2u));
|
||||
let qh_nib = (qh >> (4u * l)) & 0xFu;
|
||||
let qh_u8 = (load_u32_at_src0_aligned(qh_byte_base + 4u * (sub_block / 2u)) >> (16u * (sub_block % 2u) + 8u * phase)) & 0xFFu;
|
||||
let qs_u16 = (load_u32_at_src0_aligned(qs_byte_base + 4u * sub_block) >> (16u * phase)) & 0xFFFFu;
|
||||
|
||||
let qs_w = load_u32_at_src0(block_byte_base + ib * 4u);
|
||||
let idx = get_byte(qs_w, l) | ((qh_nib & 7u) << 8u);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_nib & 0x8u) != 0u);
|
||||
let gp0_grid_id = ((qs_u16 & 0xFFu) | ((qh_u8 & 7u) << 8u)) * 8u;
|
||||
let gp0_delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_u8 & 0x8u) != 0u);
|
||||
|
||||
let ig = idx * 8u;
|
||||
let gw = iq1_grid[(ig + j) / 16u];
|
||||
let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u;
|
||||
let gs = bitcast<i32>(g << 30u) >> 30u;
|
||||
let gp1_grid_id = (((qs_u16 >> 8u) & 0xFFu) | (((qh_u8 >> 4u) & 7u) << 8u)) * 8u;
|
||||
let gp1_delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_u8 & 0x80u) != 0u);
|
||||
|
||||
shmem[elem_idx] = f16(dl * (f32(gs) + delta));
|
||||
}
|
||||
}
|
||||
let gp0_gw = iq1_grid[(gp0_grid_id) / 16u];
|
||||
let gp1_gw = iq1_grid[(gp1_grid_id) / 16u];
|
||||
|
||||
let gp0_shift_base = (gp0_grid_id % 16u) * 2u;
|
||||
let gp1_shift_base = (gp1_grid_id % 16u) * 2u;
|
||||
|
||||
store_shmem_iquants(create_iq_gw4(dl, gp0_gw, gp0_shift_base + 0u, gp0_delta), elem_idx + 0u);
|
||||
store_shmem_iquants(create_iq_gw4(dl, gp0_gw, gp0_shift_base + 8u, gp0_delta), elem_idx + 4u);
|
||||
store_shmem_iquants(create_iq_gw4(dl, gp1_gw, gp1_shift_base + 0u, gp1_delta), elem_idx + 8u);
|
||||
store_shmem_iquants(create_iq_gw4(dl, gp1_gw, gp1_shift_base + 8u, gp1_delta), elem_idx + 12u);
|
||||
#endif // INIT_SRC0_SHMEM_IQ1_M
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ2_XXS
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 66u;
|
||||
#if defined(INIT_SRC0_SHMEM_IQ2_XXS)
|
||||
let block_byte_base = src0_idx * 66u; // BLOCK_SIZE_BYTES = 66u;
|
||||
let d_byte_base = block_byte_base + 0u;
|
||||
let qs_byte_base = block_byte_base + 2u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
let d = load_f16_as_f32_at_src0(d_byte_base);
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
let sub_block = k_in_block / 32u;
|
||||
let phase = (k_in_block / NQ) % 2u;
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
|
||||
let entry_idx = k_in_block / 8u;
|
||||
let j = k_in_block % 8u;
|
||||
|
||||
let ib = entry_idx & ~3u;
|
||||
let l = entry_idx & 3u;
|
||||
|
||||
let aux0 = load_u32_at_src0(block_byte_base + 2u + ib * 2u);
|
||||
let aux1 = load_u32_at_src0(block_byte_base + 2u + (ib + 2u) * 2u);
|
||||
let aux0 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 0u);
|
||||
let aux1 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 4u);
|
||||
let db = d * (0.5 + f32(aux1 >> 28u)) * 0.25;
|
||||
|
||||
let ig = get_byte(aux0, l) * 8u;
|
||||
let is = (aux1 >> (7u * l)) & 127u;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
|
||||
let gp0_ig = get_byte(aux0, 2u * phase + 0u) * 8u;
|
||||
let gp1_ig = get_byte(aux0, 2u * phase + 1u) * 8u;
|
||||
|
||||
let g = get_byte(iq2xxs_grid[(ig + j) / 4u], (ig + j) % 4u);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
|
||||
let gp0_is = (aux1 >> (14u * phase + 0u)) & 127u;
|
||||
let gp1_is = (aux1 >> (14u * phase + 7u)) & 127u;
|
||||
|
||||
shmem[elem_idx] = f16(db * f32(g) * m);
|
||||
}
|
||||
}
|
||||
let gp0_signs = get_byte(ksigns_iq2xs[gp0_is / 4u], gp0_is % 4u);
|
||||
let gp1_signs = get_byte(ksigns_iq2xs[gp1_is / 4u], gp1_is % 4u);
|
||||
|
||||
let m_0_3_val4 = create_iq2_m4(gp0_signs, 0);
|
||||
let m_4_7_val4 = create_iq2_m4(gp0_signs, 1);
|
||||
let m_8_11_val4 = create_iq2_m4(gp1_signs, 0);
|
||||
let m_12_15_val4 = create_iq2_m4(gp1_signs, 1);
|
||||
|
||||
let gw_0_3_val4 = create_iq_gw4(gp0_ig, 0);
|
||||
let gw_4_7_val4 = create_iq_gw4(gp0_ig, 4);
|
||||
let gw_8_11_val4 = create_iq_gw4(gp1_ig, 0);
|
||||
let gw_12_15_val4 = create_iq_gw4(gp1_ig, 4);
|
||||
|
||||
store_shmem_iquants(vec4<f16>(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u);
|
||||
#endif // INIT_SRC0_SHMEM_IQ2_XXS
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ2_XS
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 74u;
|
||||
#if defined(INIT_SRC0_SHMEM_IQ2_XS)
|
||||
let block_byte_base = src0_idx * 74u; // BLOCK_SIZE_BYTES = 74u;
|
||||
let d_byte_base = block_byte_base + 0u;
|
||||
let qs_byte_base = block_byte_base + 2u;
|
||||
let scales_byte_base = block_byte_base + 66u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
let d = load_f16_as_f32_at_src0(d_byte_base);
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
let sub_block = k_in_block / 32u;
|
||||
let phase = (k_in_block / NQ) % 2u;
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
let scale = (load_byte_at_src0_aligned(scales_byte_base + 1u * sub_block) >> (4u * phase)) & 0xFu;
|
||||
let db = d * (0.5 + f32(scale)) * 0.25;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
let qs_u32 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 4u * phase);
|
||||
|
||||
let entry_idx = k_in_block / 8u;
|
||||
let j = k_in_block % 8u;
|
||||
let gp0_ig = (qs_u32 & 0x1FFu) * 8u;
|
||||
let gp1_ig = ((qs_u32 >> 16u) & 0x1FFu) * 8u;
|
||||
|
||||
let ib = entry_idx & ~3u;
|
||||
let l = entry_idx & 3u;
|
||||
let gp0_is = (qs_u32 >> 9u) & 0x7Fu;
|
||||
let gp1_is = (qs_u32 >> 25u) & 0x7Fu;
|
||||
|
||||
let scales_word = load_u32_at_src0(block_byte_base + 66u + (ib / 16u) * 4u);
|
||||
let s = get_byte(scales_word, (ib % 16u) / 4u);
|
||||
let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u);
|
||||
let dl = d * (0.5 + f32(s_nib)) * 0.25;
|
||||
let gp0_signs = get_byte(ksigns_iq2xs[gp0_is / 4u], gp0_is % 4u);
|
||||
let gp1_signs = get_byte(ksigns_iq2xs[gp1_is / 4u], gp1_is % 4u);
|
||||
|
||||
let qs_word = load_u32_at_src0(block_byte_base + 2u + (ib + l) * 2u);
|
||||
let qs_val = qs_word & 0xFFFFu;
|
||||
let ig = (qs_val & 511u) * 8u;
|
||||
let is = qs_val >> 9u;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
|
||||
let m_0_3_val4 = create_iq2_m4(gp0_signs, 0);
|
||||
let m_4_7_val4 = create_iq2_m4(gp0_signs, 1);
|
||||
let m_8_11_val4 = create_iq2_m4(gp1_signs, 0);
|
||||
let m_12_15_val4 = create_iq2_m4(gp1_signs, 1);
|
||||
|
||||
let g = get_byte(iq2xs_grid[(ig + j) / 4u], (ig + j) % 4u);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
|
||||
let gw_0_3_val4 = create_iq_gw4(gp0_ig, 0);
|
||||
let gw_4_7_val4 = create_iq_gw4(gp0_ig, 4);
|
||||
let gw_8_11_val4 = create_iq_gw4(gp1_ig, 0);
|
||||
let gw_12_15_val4 = create_iq_gw4(gp1_ig, 4);
|
||||
|
||||
shmem[elem_idx] = f16(dl * f32(g) * m);
|
||||
}
|
||||
}
|
||||
store_shmem_iquants(vec4<f16>(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u);
|
||||
#endif // INIT_SRC0_SHMEM_IQ2_XS
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ2_S
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 82u;
|
||||
#if defined(INIT_SRC0_SHMEM_IQ2_S)
|
||||
let block_byte_base = src0_idx * 82u; // BLOCK_SIZE_BYTES = 82u;
|
||||
let d_byte_base = block_byte_base + 0u;
|
||||
let qs_byte_base = block_byte_base + 2u;
|
||||
let qh_byte_base = block_byte_base + 66u;
|
||||
let scales_byte_base = block_byte_base + 74u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
let d = load_f16_as_f32_at_src0(d_byte_base);
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
let sub_block = k_in_block / 32u;
|
||||
let phase = (k_in_block / NQ) % 2u;
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
let scale = (load_byte_at_src0_aligned(scales_byte_base + 1u * sub_block) >> (4u * phase)) & 0xFu;
|
||||
let db = d * (0.5 + f32(scale)) * 0.25;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
let qs_u16 = load_u32_at_src0(qs_byte_base + 4u * sub_block + 2u * phase) & 0xFFFFu;
|
||||
let signs_u16 = load_u32_at_src0(qs_byte_base + 32u + 4u * sub_block + 2u * phase) & 0xFFFFu;
|
||||
let qh_u4 = (load_byte_at_src0_aligned(qh_byte_base + 1u * sub_block) >> (4u * phase)) & 0xFu;
|
||||
|
||||
let ib = k_in_block / 32u;
|
||||
let l = (k_in_block % 32u) / 8u;
|
||||
let j = k_in_block % 8u;
|
||||
let gp0_ig = ((qs_u16 & 0xFFu) | ((qh_u4 & 0x3u) << 8u)) * 8u;
|
||||
let gp1_ig = (((qs_u16 >> 8u) & 0xFFu) | ((qh_u4 & 0xCu) << 6u)) * 8u;
|
||||
|
||||
let scales_word = load_u32_at_src0(block_byte_base + 74u + (ib / 4u) * 4u);
|
||||
let s = get_byte(scales_word, ib % 4u);
|
||||
let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u);
|
||||
let dl = d * (0.5 + f32(s_nib)) * 0.25;
|
||||
let gp0_signs = get_byte(signs_u16, 0);
|
||||
let gp1_signs = get_byte(signs_u16, 1);
|
||||
|
||||
let qs_word = load_u32_at_src0(block_byte_base + 2u + ib * 4u);
|
||||
let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 4u) * 4u);
|
||||
let qh_b = (get_byte(qh_word, ib % 4u) << (8u - 2u * l)) & 0x300u;
|
||||
let ig = (get_byte(qs_word, l) | qh_b) * 8u;
|
||||
let m_0_3_val4 = create_iq2_m4(gp0_signs, 0);
|
||||
let m_4_7_val4 = create_iq2_m4(gp0_signs, 1);
|
||||
let m_8_11_val4 = create_iq2_m4(gp1_signs, 0);
|
||||
let m_12_15_val4 = create_iq2_m4(gp1_signs, 1);
|
||||
|
||||
let signs_word = load_u32_at_src0(block_byte_base + 34u + ib * 4u);
|
||||
let signs = get_byte(signs_word, l);
|
||||
let gw_0_3_val4 = create_iq_gw4(gp0_ig, 0);
|
||||
let gw_4_7_val4 = create_iq_gw4(gp0_ig, 4);
|
||||
let gw_8_11_val4 = create_iq_gw4(gp1_ig, 0);
|
||||
let gw_12_15_val4 = create_iq_gw4(gp1_ig, 4);
|
||||
|
||||
let g = get_byte(iq2s_grid[(ig + j) / 4u], (ig + j) % 4u);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
|
||||
|
||||
shmem[elem_idx] = f16(dl * f32(g) * m);
|
||||
}
|
||||
}
|
||||
store_shmem_iquants(vec4<f16>(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u);
|
||||
#endif // INIT_SRC0_SHMEM_IQ2_S
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ3_XXS
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 98u;
|
||||
#if defined(INIT_SRC0_SHMEM_IQ3_XXS)
|
||||
let block_byte_base = src0_idx * 98u; // BLOCK_SIZE_BYTES = 98u;
|
||||
let d_byte_base = block_byte_base + 0u;
|
||||
let qs_byte_base = block_byte_base + 2u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
let d = load_f16_as_f32_at_src0(d_byte_base);
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
let sub_block = k_in_block / 32u;
|
||||
let phase = (k_in_block / NQ) % 2u;
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
let qs_u32 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 4u * phase);
|
||||
let sign_u32 = load_u32_at_src0(qs_byte_base + 64u + 4u * sub_block);
|
||||
let db = d * (0.5 + f32(sign_u32 >> 28u)) * 0.5;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
let ig_0_3 = get_byte(qs_u32, 0);
|
||||
let ig_4_7 = get_byte(qs_u32, 1);
|
||||
let ig_8_11 = get_byte(qs_u32, 2);
|
||||
let ig_12_15 = get_byte(qs_u32, 3);
|
||||
|
||||
let ib_pair = k_in_block / 32u;
|
||||
let in_pair = k_in_block % 32u;
|
||||
let l = in_pair / 8u;
|
||||
let in_l = in_pair % 8u;
|
||||
let k2 = in_l / 4u;
|
||||
let j = in_l % 4u;
|
||||
let gp0_is = (sign_u32 >> (14u * phase + 0u)) & 0x7Fu;
|
||||
let gp1_is = (sign_u32 >> (14u * phase + 7u)) & 0x7Fu;
|
||||
|
||||
let ib = ib_pair * 2u;
|
||||
let sc_sign_off = block_byte_base + 2u + (ib + 32u) * 2u;
|
||||
let sc_sign = load_u32_at_src0(sc_sign_off);
|
||||
let db = d * (0.5 + f32(sc_sign >> 28u)) * 0.5;
|
||||
let is = (sc_sign >> (7u * l)) & 127u;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
|
||||
let gp0_signs = get_byte(ksigns_iq2xs[gp0_is / 4u], gp0_is % 4u);
|
||||
let gp1_signs = get_byte(ksigns_iq2xs[gp1_is / 4u], gp1_is % 4u);
|
||||
|
||||
let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 2u + l) * 2u) & 0xFFFFu;
|
||||
let ig_byte = get_byte(ig_word, k2);
|
||||
let g = get_byte(iq3xxs_grid[ig_byte], j);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u);
|
||||
let m_0_3_val4 = create_iq2_m4(gp0_signs, 0);
|
||||
let m_4_7_val4 = create_iq2_m4(gp0_signs, 1);
|
||||
let m_8_11_val4 = create_iq2_m4(gp1_signs, 0);
|
||||
let m_12_15_val4 = create_iq2_m4(gp1_signs, 1);
|
||||
|
||||
shmem[elem_idx] = f16(db * f32(g) * m);
|
||||
}
|
||||
}
|
||||
let gw_0_3_val4 = create_iq_gw4(ig_0_3);
|
||||
let gw_4_7_val4 = create_iq_gw4(ig_4_7);
|
||||
let gw_8_11_val4 = create_iq_gw4(ig_8_11);
|
||||
let gw_12_15_val4 = create_iq_gw4(ig_12_15);
|
||||
|
||||
store_shmem_iquants(vec4<f16>(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u);
|
||||
#endif // INIT_SRC0_SHMEM_IQ3_XXS
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ3_S
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 110u;
|
||||
#if defined(INIT_SRC0_SHMEM_IQ3_S)
|
||||
let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u;
|
||||
let d_byte_base = block_byte_base + 0u;
|
||||
let qs_byte_base = block_byte_base + 2u;
|
||||
let qh_byte_base = block_byte_base + 66u;
|
||||
let signs_byte_base = block_byte_base + 74u;
|
||||
let scales_byte_base = block_byte_base + 106u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
let d = load_f16_as_f32_at_src0(d_byte_base);
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
let sub_block = k_in_block / 32u;
|
||||
let phase = (k_in_block / NQ) % 2u;
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
let scale = (load_byte_at_src0_aligned(scales_byte_base + 1u * (sub_block / 2u)) >> (4u * (sub_block % 2u))) & 0xFu;
|
||||
let db = d * (1.0 + 2.0 * f32(scale));
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_as_f32_at_src0(block_byte_base);
|
||||
let qs_u32 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 4u * phase);
|
||||
let qh_u4 = (load_byte_at_src0_aligned(qh_byte_base + 1u * sub_block) >> (4u * phase)) & 0xFu;
|
||||
let signs_u16 = (load_u32_at_src0(signs_byte_base + 4u * sub_block + 2u * phase)) & 0xFFFFu;
|
||||
|
||||
let ib = k_in_block / 64u;
|
||||
let rest = k_in_block % 64u;
|
||||
let k = rest / 32u;
|
||||
let in_k = rest % 32u;
|
||||
let l = in_k / 8u;
|
||||
let in_l = in_k % 8u;
|
||||
let k2 = in_l / 4u;
|
||||
let j = in_l % 4u;
|
||||
let ig_0_3 = ((qs_u32 >> 0u) & 0xFFu) | ((qh_u4 & 0x1u) << 8u);
|
||||
let ig_4_7 = ((qs_u32 >> 8u) & 0xFFu) | ((qh_u4 & 0x2u) << 7u);
|
||||
let ig_8_11 = ((qs_u32 >> 16u) & 0xFFu) | ((qh_u4 & 0x4u) << 6u);
|
||||
let ig_12_15 = ((qs_u32 >> 24u) & 0xFFu) | ((qh_u4 & 0x8u) << 5u);
|
||||
|
||||
let scales_word = load_u32_at_src0(block_byte_base + 106u);
|
||||
let s = get_byte(scales_word, ib);
|
||||
let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, k != 0u);
|
||||
let dl = d * (1.0 + 2.0 * f32(s_nib));
|
||||
let gp0_signs = get_byte(signs_u16, 0);
|
||||
let gp1_signs = get_byte(signs_u16, 1);
|
||||
|
||||
let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 2u) * 4u);
|
||||
let qh_byte = get_byte(qh_word, (ib % 2u) * 2u + k);
|
||||
let m_0_3_val4 = create_iq2_m4(gp0_signs, 0);
|
||||
let m_4_7_val4 = create_iq2_m4(gp0_signs, 1);
|
||||
let m_8_11_val4 = create_iq2_m4(gp1_signs, 0);
|
||||
let m_12_15_val4 = create_iq2_m4(gp1_signs, 1);
|
||||
|
||||
let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 8u + k * 4u + l) * 2u) & 0xFFFFu;
|
||||
let ig_lo = get_byte(ig_word, 0u) | ((qh_byte << (8u - 2u * l)) & 256u);
|
||||
let ig_hi = get_byte(ig_word, 1u) | ((qh_byte << (7u - 2u * l)) & 256u);
|
||||
let ig = select(ig_lo, ig_hi, k2 != 0u);
|
||||
let gw_0_3_val4 = create_iq_gw4(ig_0_3);
|
||||
let gw_4_7_val4 = create_iq_gw4(ig_4_7);
|
||||
let gw_8_11_val4 = create_iq_gw4(ig_8_11);
|
||||
let gw_12_15_val4 = create_iq_gw4(ig_12_15);
|
||||
|
||||
let signs_word = load_u32_at_src0(block_byte_base + 74u + (ib * 2u + k) * 4u);
|
||||
let signs = get_byte(signs_word, l);
|
||||
|
||||
let g = get_byte(iq3s_grid[ig], j);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u);
|
||||
|
||||
shmem[elem_idx] = f16(dl * f32(g) * m);
|
||||
store_shmem_iquants(vec4<f16>(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u);
|
||||
store_shmem_iquants(vec4<f16>(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u);
|
||||
#endif // INIT_SRC0_SHMEM_IQ3_S
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ3_S
|
||||
#endif // i-quants (super block size: 256)
|
||||
|
|
|
|||
Loading…
Reference in New Issue