add fast matmul iquants (llama/22504)
This commit is contained in:
parent
66392cf1a2
commit
d74c56862b
|
|
@ -1806,6 +1806,25 @@ class ggml_webgpu_shader_lib {
|
|||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
defines.push_back("SRC0_INNER_TYPE=u32");
|
||||
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
variant += std::string("_") + src0_name;
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1422,7 +1422,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
|||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
use_fast = is_vec;
|
||||
use_fast = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -740,3 +740,426 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q6_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ4_NL
// Elements per quantized block and bytes occupied by one block.
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 18u;

// Fills the src0 tile in `shmem` for the IQ4_NL variant.
//
// Every invocation starts at slot `thread_id` and strides by
// TOTAL_WORKGROUP_SIZE until TILE_SRC0_SHMEM slots are covered. A slot maps to
// row `offset_m + slot / TILE_K` and column `k_outer + slot % TILE_K`; slots
// outside params.m / params.k are stored as zero.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var slot = thread_id; slot < TILE_SRC0_SHMEM; slot += TOTAL_WORKGROUP_SIZE) {
        let row = offset_m + (slot / TILE_K);
        let col = k_outer + (slot % TILE_K);

        if (row < params.m && col < params.k) {
            let in_blk = col % BLOCK_SIZE;
            let blk_base = (batch_offset + row * params.stride_01 + col / BLOCK_SIZE) * BLOCK_SIZE_BYTES;

            // Block scale is read at the first byte of the block.
            let blk_scale = load_f16_at_src0(blk_base);

            // Elements 0..15 use the low nibble of byte `in_blk % 16`,
            // elements 16..31 use the high nibble of that same byte.
            let byte_idx = in_blk & 15u;
            let shift = select(0u, 4u, in_blk >= 16u);
            let word = load_u32_at_src0(blk_base + 2u + (byte_idx & ~3u));
            let nibble = (get_byte(word, byte_idx & 3u) >> shift) & 0xFu;

            shmem[slot] = blk_scale * f16(kvalues_iq4nl[nibble]);
        } else {
            shmem[slot] = f16(0.0);
        }
    }
}
#endif // INIT_SRC0_SHMEM_IQ4_NL
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ4_XS
// Elements per quantized block and bytes occupied by one block.
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 136u;

// Fills the src0 tile in `shmem` for the IQ4_XS variant.
//
// Every invocation starts at slot `thread_id` and strides by
// TOTAL_WORKGROUP_SIZE. Slots outside params.m / params.k are stored as zero.
// Each 32-element sub-block has a 6-bit scale (4 low bits from the word at
// byte 4, 2 high bits from the upper half of the word at byte 0), biased by 32.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var slot = thread_id; slot < TILE_SRC0_SHMEM; slot += TOTAL_WORKGROUP_SIZE) {
        let row = offset_m + (slot / TILE_K);
        let col = k_outer + (slot % TILE_K);

        if (row < params.m && col < params.k) {
            let in_blk = col % BLOCK_SIZE;
            let blk_base = (batch_offset + row * params.stride_01 + col / BLOCK_SIZE) * BLOCK_SIZE_BYTES;

            // First word: f16 block scale in the low half, packed high scale bits above it.
            let head = load_u32_at_src0(blk_base);
            let blk_scale = bitcast<vec2<f16>>(head).x;
            let hi_bits = head >> 16u;

            let sub = in_blk >> 5u;      // which 32-element sub-block
            let in_sub = in_blk & 31u;   // position inside that sub-block

            // Assemble the 6-bit sub-block scale and remove the bias of 32.
            let lo_word = load_u32_at_src0(blk_base + 4u);
            let lo4 = (get_byte(lo_word, sub >> 1u) >> ((sub & 1u) << 2u)) & 0xFu;
            let hi2 = ((hi_bits >> (sub << 1u)) & 3u) << 4u;
            let sub_scale = blk_scale * f16(i32(lo4 | hi2) - 32);

            // 16 quant bytes per sub-block; low nibbles cover the first 16
            // elements and high nibbles the last 16.
            let q_idx = (sub << 4u) + (in_sub & 15u);
            let shift = select(0u, 4u, in_sub >= 16u);
            let word = load_u32_at_src0(blk_base + 8u + (q_idx & ~3u));
            let nibble = (get_byte(word, q_idx & 3u) >> shift) & 0xFu;

            shmem[slot] = sub_scale * f16(kvalues_iq4nl[nibble]);
        } else {
            shmem[slot] = f16(0.0);
        }
    }
}
#endif // INIT_SRC0_SHMEM_IQ4_XS
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ1_S
// Elements per quantized block and bytes occupied by one block.
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 50u;

// Fills the src0 tile in `shmem` for the IQ1_S variant.
//
// Every invocation starts at slot `thread_id` and strides by
// TOTAL_WORKGROUP_SIZE. Slots outside params.m / params.k are stored as zero.
// Each value is `scale * (grid_entry + delta)`, where the grid entry is a
// sign-extended 2-bit field taken from `iq1_grid`.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var slot = thread_id; slot < TILE_SRC0_SHMEM; slot += TOTAL_WORKGROUP_SIZE) {
        let row = offset_m + (slot / TILE_K);
        let col = k_outer + (slot % TILE_K);

        if (row < params.m && col < params.k) {
            let in_blk = col % BLOCK_SIZE;
            let blk_base = (batch_offset + row * params.stride_01 + col / BLOCK_SIZE) * BLOCK_SIZE_BYTES;
            let blk_scale = load_f16_as_f32_at_src0(blk_base);

            let sub = in_blk / 32u;            // 32-element sub-block
            let grp = (in_blk % 32u) / 8u;     // group of 8 inside the sub-block
            let lane = in_blk % 8u;            // element inside the group

            // 16-bit qh per sub-block: bits 12..14 scale, bit 15 flips delta,
            // bits 3*grp..3*grp+2 extend the grid index.
            let qh = load_u32_at_src0(blk_base + 34u + sub * 2u) & 0xFFFFu;
            let sub_scale = blk_scale * (2.0 * f32((qh >> 12u) & 7u) + 1.0);
            let bias = select(-IQ1_DELTA, IQ1_DELTA, (qh & 0x8000u) == 0u);

            let qs_word = load_u32_at_src0(blk_base + 2u + sub * 4u);
            let hi3 = (qh >> (3u * grp)) & 7u;
            let cell = ((get_byte(qs_word, grp) | (hi3 << 8u)) * 8u) + lane;

            // iq1_grid packs sixteen 2-bit entries per u32; sign-extend the field.
            let packed = iq1_grid[cell / 16u];
            let two_bits = (packed >> ((cell % 16u) * 2u)) & 3u;
            let signed_val = bitcast<i32>(two_bits << 30u) >> 30u;

            shmem[slot] = f16(sub_scale * (f32(signed_val) + bias));
        } else {
            shmem[slot] = f16(0.0);
        }
    }
}
#endif // INIT_SRC0_SHMEM_IQ1_S
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ1_M
// Elements per quantized block and bytes occupied by one block.
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 56u;

// Fills the src0 tile in `shmem` for the IQ1_M variant.
//
// Every invocation starts at slot `thread_id` and strides by
// TOTAL_WORKGROUP_SIZE. Slots outside params.m / params.k are stored as zero.
// The f16 block scale is rebuilt from the top nibble of each of the four
// 16-bit fields in the two scale words at byte offsets 48 and 52.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var slot = thread_id; slot < TILE_SRC0_SHMEM; slot += TOTAL_WORKGROUP_SIZE) {
        let row = offset_m + (slot / TILE_K);
        let col = k_outer + (slot % TILE_K);

        if (row < params.m && col < params.k) {
            let in_blk = col % BLOCK_SIZE;
            let blk_base = (batch_offset + row * params.stride_01 + col / BLOCK_SIZE) * BLOCK_SIZE_BYTES;

            let sc_lo = load_u32_at_src0(blk_base + 48u);
            let sc_hi = load_u32_at_src0(blk_base + 52u);

            // Gather bits 12..15 of every 16-bit field into one 16-bit pattern.
            let d_bits = ((sc_lo >> 12u) & 0xFu) |
                         ((sc_lo >> 24u) & 0x00F0u) |
                         ((sc_hi >> 4u) & 0x0F00u) |
                         ((sc_hi >> 16u) & 0xF000u);
            let blk_scale = f32(bitcast<vec2<f16>>(d_bits).x);

            let sub = in_blk / 32u;            // 32-element sub-block
            let grp = (in_blk % 32u) / 8u;     // group of 8 inside the sub-block
            let lane = in_blk % 8u;            // element inside the group

            // Two sub-blocks share one 16-bit scale field; each sub-block owns
            // 6 bits, split into 3 bits per pair of groups.
            let sc_word = select(sc_lo, sc_hi, sub >= 4u);
            let sc16 = (sc_word >> (16u * ((sub / 2u) % 2u))) & 0xFFFFu;
            let sc3 = (sc16 >> (6u * (sub % 2u) + 3u * (grp / 2u))) & 0x7u;
            let sub_scale = blk_scale * f32(2u * sc3 + 1u);

            // One qh nibble per group: low 3 bits extend the grid index, bit 3 flips delta.
            let qh_word = load_u32_at_src0(blk_base + 32u + (sub / 2u) * 4u);
            let qh_nib = ((qh_word >> (16u * (sub % 2u))) >> (4u * grp)) & 0xFu;

            let qs_word = load_u32_at_src0(blk_base + sub * 4u);
            let cell = ((get_byte(qs_word, grp) | ((qh_nib & 7u) << 8u)) * 8u) + lane;
            let bias = select(-IQ1_DELTA, IQ1_DELTA, (qh_nib & 0x8u) == 0u);

            // iq1_grid packs sixteen 2-bit entries per u32; sign-extend the field.
            let packed = iq1_grid[cell / 16u];
            let two_bits = (packed >> ((cell % 16u) * 2u)) & 3u;
            let signed_val = bitcast<i32>(two_bits << 30u) >> 30u;

            shmem[slot] = f16(sub_scale * (f32(signed_val) + bias));
        } else {
            shmem[slot] = f16(0.0);
        }
    }
}
#endif // INIT_SRC0_SHMEM_IQ1_M
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ2_XXS
// Elements per quantized block and bytes occupied by one block.
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 66u;

// Fills the src0 tile in `shmem` for the IQ2_XXS variant.
//
// Every invocation starts at slot `thread_id` and strides by
// TOTAL_WORKGROUP_SIZE. Slots outside params.m / params.k are stored as zero.
// Every run of four 8-element entries shares two words: the first carries one
// grid index byte per entry, the second carries four 7-bit sign selectors and
// a 4-bit scale in its top bits.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var slot = thread_id; slot < TILE_SRC0_SHMEM; slot += TOTAL_WORKGROUP_SIZE) {
        let row = offset_m + (slot / TILE_K);
        let col = k_outer + (slot % TILE_K);

        if (row < params.m && col < params.k) {
            let in_blk = col % BLOCK_SIZE;
            let blk_base = (batch_offset + row * params.stride_01 + col / BLOCK_SIZE) * BLOCK_SIZE_BYTES;
            let blk_scale = load_f16_as_f32_at_src0(blk_base);

            let entry = in_blk / 8u;          // 8-element entry index
            let lane = in_blk % 8u;           // element inside the entry
            let grp_first = entry & ~3u;      // first entry of the run of four
            let grp_slot = entry & 3u;        // entry position inside the run

            let q_off = blk_base + 2u + grp_first * 2u;
            let idx_word = load_u32_at_src0(q_off);
            let meta_word = load_u32_at_src0(q_off + 4u);
            let sub_scale = blk_scale * (0.5 + f32(meta_word >> 28u)) * 0.25;

            let cell = get_byte(idx_word, grp_slot) * 8u + lane;
            let sign_idx = (meta_word >> (7u * grp_slot)) & 127u;
            let sign_bits = get_byte(ksigns_iq2xs[sign_idx / 4u], sign_idx % 4u);

            // Grid and mask tables are byte arrays packed four per u32.
            let mag = get_byte(iq2xxs_grid[cell / 4u], cell % 4u);
            let lane_bit = get_byte(kmask_iq2xs[lane / 4u], lane % 4u);
            let factor = select(1.0, -1.0, (lane_bit & sign_bits) != 0u);

            shmem[slot] = f16(sub_scale * f32(mag) * factor);
        } else {
            shmem[slot] = f16(0.0);
        }
    }
}
#endif // INIT_SRC0_SHMEM_IQ2_XXS
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ2_XS
// Elements per quantized block and bytes occupied by one block.
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 74u;

// Fills the src0 tile in `shmem` for the IQ2_XS variant.
//
// Every invocation starts at slot `thread_id` and strides by
// TOTAL_WORKGROUP_SIZE. Slots outside params.m / params.k are stored as zero.
// Each 8-element entry has a 16-bit code (9-bit grid index, 7-bit sign
// selector); each run of four entries shares one scale byte whose low nibble
// serves the first two entries and high nibble the last two.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var slot = thread_id; slot < TILE_SRC0_SHMEM; slot += TOTAL_WORKGROUP_SIZE) {
        let row = offset_m + (slot / TILE_K);
        let col = k_outer + (slot % TILE_K);

        if (row < params.m && col < params.k) {
            let in_blk = col % BLOCK_SIZE;
            let blk_base = (batch_offset + row * params.stride_01 + col / BLOCK_SIZE) * BLOCK_SIZE_BYTES;
            let blk_scale = load_f16_as_f32_at_src0(blk_base);

            let entry = in_blk / 8u;          // 8-element entry index
            let lane = in_blk % 8u;           // element inside the entry
            let grp_first = entry & ~3u;      // first entry of the run of four
            let grp_slot = entry & 3u;        // entry position inside the run

            let sc_word = load_u32_at_src0(blk_base + 66u + (grp_first / 16u) * 4u);
            let sc_byte = get_byte(sc_word, (grp_first % 16u) / 4u);
            let sc_nib = select(sc_byte & 0xFu, (sc_byte >> 4u) & 0xFu, grp_slot >= 2u);
            let sub_scale = blk_scale * (0.5 + f32(sc_nib)) * 0.25;

            // grp_first + grp_slot is just `entry`; two bytes per entry after the scale.
            let q16 = load_u32_at_src0(blk_base + 2u + entry * 2u) & 0xFFFFu;
            let cell = (q16 & 511u) * 8u + lane;
            let sign_idx = q16 >> 9u;
            let sign_bits = get_byte(ksigns_iq2xs[sign_idx / 4u], sign_idx % 4u);

            // Grid and mask tables are byte arrays packed four per u32.
            let mag = get_byte(iq2xs_grid[cell / 4u], cell % 4u);
            let lane_bit = get_byte(kmask_iq2xs[lane / 4u], lane % 4u);
            let factor = select(1.0, -1.0, (lane_bit & sign_bits) != 0u);

            shmem[slot] = f16(sub_scale * f32(mag) * factor);
        } else {
            shmem[slot] = f16(0.0);
        }
    }
}
#endif // INIT_SRC0_SHMEM_IQ2_XS
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ2_S
// Elements per quantized block and bytes occupied by one block.
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 82u;

// Fills the src0 tile in `shmem` for the IQ2_S variant.
//
// Every invocation starts at slot `thread_id` and strides by
// TOTAL_WORKGROUP_SIZE. Slots outside params.m / params.k are stored as zero.
// Byte offsets used inside a block: 2 = grid index low bytes, 34 = sign
// bytes, 66 = high index bits, 74 = scale nibbles.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var slot = thread_id; slot < TILE_SRC0_SHMEM; slot += TOTAL_WORKGROUP_SIZE) {
        let row = offset_m + (slot / TILE_K);
        let col = k_outer + (slot % TILE_K);

        if (row < params.m && col < params.k) {
            let in_blk = col % BLOCK_SIZE;
            let blk_base = (batch_offset + row * params.stride_01 + col / BLOCK_SIZE) * BLOCK_SIZE_BYTES;
            let blk_scale = load_f16_as_f32_at_src0(blk_base);

            let sub = in_blk / 32u;            // 32-element sub-block
            let grp = (in_blk % 32u) / 8u;     // group of 8 inside the sub-block
            let lane = in_blk % 8u;            // element inside the group

            // One scale byte per sub-block: low nibble for groups 0-1, high for 2-3.
            let sc_word = load_u32_at_src0(blk_base + 74u + (sub / 4u) * 4u);
            let sc_byte = get_byte(sc_word, sub % 4u);
            let sc_nib = select(sc_byte & 0xFu, (sc_byte >> 4u) & 0xFu, grp >= 2u);
            let sub_scale = blk_scale * (0.5 + f32(sc_nib)) * 0.25;

            // 10-bit grid index: 8 bits from qs, 2 bits from the sub-block's qh byte.
            let qs_word = load_u32_at_src0(blk_base + 2u + sub * 4u);
            let qh_word = load_u32_at_src0(blk_base + 66u + (sub / 4u) * 4u);
            let hi2 = (get_byte(qh_word, sub % 4u) << (8u - 2u * grp)) & 0x300u;
            let cell = (get_byte(qs_word, grp) | hi2) * 8u + lane;

            let sign_word = load_u32_at_src0(blk_base + 34u + sub * 4u);
            let sign_bits = get_byte(sign_word, grp);

            // Grid and mask tables are byte arrays packed four per u32.
            let mag = get_byte(iq2s_grid[cell / 4u], cell % 4u);
            let lane_bit = get_byte(kmask_iq2xs[lane / 4u], lane % 4u);
            let factor = select(1.0, -1.0, (lane_bit & sign_bits) != 0u);

            shmem[slot] = f16(sub_scale * f32(mag) * factor);
        } else {
            shmem[slot] = f16(0.0);
        }
    }
}
#endif // INIT_SRC0_SHMEM_IQ2_S
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ3_XXS
// Elements per quantized block and bytes occupied by one block.
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 98u;

// Fills the src0 tile in `shmem` for the IQ3_XXS variant.
//
// Every invocation starts at slot `thread_id` and strides by
// TOTAL_WORKGROUP_SIZE. Slots outside params.m / params.k are stored as zero.
// Each 32-element sub-block owns 8 grid index bytes plus one word (starting
// 64 bytes after the indices begin) holding four 7-bit sign selectors and a
// 4-bit scale in its top bits.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var slot = thread_id; slot < TILE_SRC0_SHMEM; slot += TOTAL_WORKGROUP_SIZE) {
        let row = offset_m + (slot / TILE_K);
        let col = k_outer + (slot % TILE_K);

        if (row < params.m && col < params.k) {
            let in_blk = col % BLOCK_SIZE;
            let blk_base = (batch_offset + row * params.stride_01 + col / BLOCK_SIZE) * BLOCK_SIZE_BYTES;
            let blk_scale = load_f16_as_f32_at_src0(blk_base);

            let sub = in_blk / 32u;                 // 32-element sub-block
            let in_sub = in_blk % 32u;
            let grp = in_sub / 8u;                  // group of 8 inside the sub-block
            let half_sel = (in_sub % 8u) / 4u;      // which 4-element half of the group
            let lane = in_sub % 4u;                 // element inside the half

            let meta_word = load_u32_at_src0(blk_base + 2u + (sub * 2u + 32u) * 2u);
            let sub_scale = blk_scale * (0.5 + f32(meta_word >> 28u)) * 0.5;
            let sign_idx = (meta_word >> (7u * grp)) & 127u;
            let sign_bits = get_byte(ksigns_iq2xs[sign_idx / 4u], sign_idx % 4u);

            // Two index bytes per group, one for each 4-element half.
            let idx_pair = load_u32_at_src0(blk_base + 2u + (sub * 4u + grp) * 2u) & 0xFFFFu;
            let grid_idx = get_byte(idx_pair, half_sel);

            // Each grid word holds the four magnitudes of one half.
            let mag = get_byte(iq3xxs_grid[grid_idx], lane);
            let lane_bit = get_byte(kmask_iq2xs[half_sel], lane);
            let factor = select(1.0, -1.0, (lane_bit & sign_bits) != 0u);

            shmem[slot] = f16(sub_scale * f32(mag) * factor);
        } else {
            shmem[slot] = f16(0.0);
        }
    }
}
#endif // INIT_SRC0_SHMEM_IQ3_XXS
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ3_S
// Elements per quantized block and bytes occupied by one block.
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 110u;

// Fills the src0 tile in `shmem` for the IQ3_S variant.
//
// Every invocation starts at slot `thread_id` and strides by
// TOTAL_WORKGROUP_SIZE. Slots outside params.m / params.k are stored as zero.
// Byte offsets used inside a block: 2 = grid index low bytes, 66 = high index
// bits, 74 = sign bytes, 106 = scale nibbles (one byte per 64 elements).
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var slot = thread_id; slot < TILE_SRC0_SHMEM; slot += TOTAL_WORKGROUP_SIZE) {
        let row = offset_m + (slot / TILE_K);
        let col = k_outer + (slot % TILE_K);

        if (row < params.m && col < params.k) {
            let in_blk = col % BLOCK_SIZE;
            let blk_base = (batch_offset + row * params.stride_01 + col / BLOCK_SIZE) * BLOCK_SIZE_BYTES;
            let blk_scale = load_f16_as_f32_at_src0(blk_base);

            let sub = in_blk / 64u;                 // 64-element sub-block
            let in_sub = in_blk % 64u;
            let part = in_sub / 32u;                // 32-element half of the sub-block
            let grp = (in_sub % 32u) / 8u;          // group of 8 inside the half
            let half_sel = (in_sub % 8u) / 4u;      // which 4-element half of the group
            let lane = in_sub % 4u;                 // element inside that half

            // One scale byte per sub-block: low nibble for part 0, high for part 1.
            let sc_byte = get_byte(load_u32_at_src0(blk_base + 106u), sub);
            let sc_nib = select(sc_byte & 0xFu, (sc_byte >> 4u) & 0xFu, part != 0u);
            let sub_scale = blk_scale * (1.0 + 2.0 * f32(sc_nib));

            let qh_word = load_u32_at_src0(blk_base + 66u + (sub / 2u) * 4u);
            let qh_byte = get_byte(qh_word, (sub % 2u) * 2u + part);

            // 9-bit grid index: byte `half_sel` of the pair plus one qh bit.
            // The qh bit is moved to bit 8 with shift 8 - 2*grp for the first
            // half and 7 - 2*grp for the second (never below 1, so no underflow).
            let idx_pair = load_u32_at_src0(blk_base + 2u + (sub * 8u + part * 4u + grp) * 2u) & 0xFFFFu;
            let hi_shift = (8u - half_sel) - 2u * grp;
            let grid_idx = get_byte(idx_pair, half_sel) | ((qh_byte << hi_shift) & 256u);

            let sign_word = load_u32_at_src0(blk_base + 74u + (sub * 2u + part) * 4u);
            let sign_bits = get_byte(sign_word, grp);

            // Each grid word holds the four magnitudes of one half.
            let mag = get_byte(iq3s_grid[grid_idx], lane);
            let lane_bit = get_byte(kmask_iq2xs[half_sel], lane);
            let factor = select(1.0, -1.0, (lane_bit & sign_bits) != 0u);

            shmem[slot] = f16(sub_scale * f32(mag) * factor);
        } else {
            shmem[slot] = f16(0.0);
        }
    }
}
#endif // INIT_SRC0_SHMEM_IQ3_S
|
||||
|
|
|
|||
Loading…
Reference in New Issue