ggml-webgpu: add q4_0/q8_0 SET_ROWS (llama/23760)

* Add q8_0 and q4_0 set_rows

* Add fast(er) quantization set_rows path

* formatting/naming

* a little more naming

* Remove unused constant

* Don't override other override

* Avoid bitcast

* Narrow relaxation
This commit is contained in:
Reese Levine 2026-05-29 14:14:11 -07:00 committed by Georgi Gerganov
parent f7aad4ed7e
commit acd91d2c38
4 changed files with 289 additions and 41 deletions

View File

@ -84,16 +84,16 @@ struct ggml_webgpu_shader_lib_context {
ggml_tensor * src5;
ggml_tensor * dst;
uint32_t max_wg_size;
size_t wg_mem_limit_bytes = 0;
bool supports_subgroups = false;
bool supports_subgroup_matrix = false;
uint32_t sg_mat_m = 0;
uint32_t sg_mat_n = 0;
uint32_t sg_mat_k = 0;
uint32_t min_subgroup_size = 0;
uint32_t max_subgroup_size = 0;
bool supports_dot_product = false;
uint32_t max_wg_size;
size_t wg_mem_limit_bytes = 0;
bool supports_subgroups = false;
bool supports_subgroup_matrix = false;
uint32_t sg_mat_m = 0;
uint32_t sg_mat_n = 0;
uint32_t sg_mat_k = 0;
uint32_t min_subgroup_size = 0;
uint32_t max_subgroup_size = 0;
bool supports_dot_product = false;
std::string vendor;
};
@ -166,9 +166,11 @@ struct ggml_webgpu_set_rows_pipeline_key {
int dst_type;
int vec4;
int i64_idx;
int pair_blocks;
bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx &&
pair_blocks == other.pair_blocks;
}
};
@ -178,6 +180,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.dst_type);
ggml_webgpu_hash_combine(seed, key.vec4);
ggml_webgpu_hash_combine(seed, key.i64_idx);
ggml_webgpu_hash_combine(seed, key.pair_blocks);
return seed;
}
};
@ -185,6 +188,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash {
struct ggml_webgpu_set_rows_shader_decisions {
bool vec4;
bool i64_idx;
bool pair_blocks;
uint32_t wg_size;
};
@ -772,31 +776,30 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
(v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
const bool kv_vec_type_supported =
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
(uint32_t) ggml_blck_size(K->type);
const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 &&
context.src2->ne[0] % kv_vec_head_align == 0;
const uint32_t kv_vec_head_align =
K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type);
const bool kv_vec_head_dims_aligned =
context.src0->ne[0] % kv_vec_head_align == 0 && context.src2->ne[0] % kv_vec_head_align == 0;
// Compile with enough invocations to cover the largest reported subgroup.
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) &&
kv_vec_head_dims_aligned && kv_vec_type_supported &&
(K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && kv_vec_head_dims_aligned &&
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
(context.src2->type == K->type);
const bool tile_can_dispatch_all_q_rows =
context.max_subgroup_size > 0 &&
context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size;
const bool use_subgroup_matrix =
context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 &&
context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0;
const bool use_subgroup_matrix = context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 &&
context.src0->ne[0] % context.sg_mat_k == 0 &&
context.src2->ne[0] % context.sg_mat_n == 0;
const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 &&
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
tile_can_dispatch_all_q_rows && !use_vec;
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {
return decisions;
@ -1131,9 +1134,9 @@ class ggml_webgpu_shader_lib {
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
flash_attn_blk_pipelines;
std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
mul_mat_vec_pipelines; // fast mat-vec (n==1)
mul_mat_vec_pipelines; // fast mat-vec (n==1)
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash>
quantize_q8_pipelines;
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
@ -1264,10 +1267,13 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_set_rows_pipeline_key key = {};
key.dst_type = context.dst->type;
key.vec4 = context.src0->ne[0] % 4 == 0;
key.i64_idx = context.src1->type == GGML_TYPE_I64;
const bool quantized = ggml_is_quantized(context.dst->type);
ggml_webgpu_set_rows_pipeline_key key = {};
key.dst_type = context.dst->type;
key.vec4 =
(context.dst->type == GGML_TYPE_F32 || context.dst->type == GGML_TYPE_F16) && context.src0->ne[0] % 4 == 0;
key.i64_idx = context.src1->type == GGML_TYPE_I64;
key.pair_blocks = quantized && ((context.src0->ne[0] / ggml_blck_size(context.dst->type)) % 2 == 0);
auto it = set_rows_pipelines.find(key);
if (it != set_rows_pipelines.end()) {
@ -1286,6 +1292,14 @@ class ggml_webgpu_shader_lib {
defines.push_back("DST_F16");
variant += "_dstf16";
break;
case GGML_TYPE_Q8_0:
defines.push_back("DST_Q8_0");
variant += "_dstq8_0";
break;
case GGML_TYPE_Q4_0:
defines.push_back("DST_Q4_0");
variant += "_dstq4_0";
break;
default:
GGML_ABORT("Unsupported dst type for set_rows shader");
}
@ -1298,13 +1312,19 @@ class ggml_webgpu_shader_lib {
defines.push_back("I64_IDX");
variant += "_i64idx";
}
if (key.pair_blocks) {
defines.push_back("PAIR_BLOCKS");
variant += "_pair_blocks";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_set_rows, defines);
auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows;
auto processed = preprocessor.preprocess(shader_source, defines);
auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
decisions->vec4 = key.vec4;
decisions->i64_idx = key.i64_idx;
decisions->pair_blocks = key.pair_blocks;
decisions->wg_size = context.max_wg_size;
set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
set_rows_pipelines[key].context = decisions;
@ -1660,7 +1680,7 @@ class ggml_webgpu_shader_lib {
key.type = context.dst->type;
key.d_state = (int) context.src0->ne[0];
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
ggml_webgpu_tensor_overlap(context.src1, context.src5);
ggml_webgpu_tensor_overlap(context.src1, context.src5);
auto it = ssm_scan_pipelines.find(key);
if (it != ssm_scan_pipelines.end()) {
@ -1819,7 +1839,7 @@ class ggml_webgpu_shader_lib {
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.use_mmvq =
key.use_mmvq =
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
auto it = mul_mat_vec_pipelines.find(key);

View File

@ -1331,7 +1331,11 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_set_rows(webgpu_context & ct
}
uint32_t threads;
if (decisions->vec4) {
if (ggml_is_quantized(dst->type)) {
const uint32_t blocks_per_row = src->ne[0] / ggml_blck_size(dst->type);
threads =
(src->ne[1] * src->ne[2] * src->ne[3]) * (decisions->pair_blocks ? (blocks_per_row / 2) : blocks_per_row);
} else if (decisions->vec4) {
threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
} else {
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
@ -4046,8 +4050,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32);
break;
case GGML_OP_SET_ROWS:
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_Q8_0 ||
op->type == GGML_TYPE_Q4_0) &&
src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
break;
case GGML_OP_GET_ROWS:
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {

View File

@ -71,7 +71,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
return;
}
// getting the row from gid
let elems_per_row = params.ne0 / VEC_SIZE;
var i = gid.x / elems_per_row;
@ -104,6 +103,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
let col_idx = (gid.x % elems_per_row);
dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]);
let col_idx = gid.x % elems_per_row;
dst[i_dst_row / VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row / VEC_SIZE + col_idx]);
}

View File

@ -0,0 +1,224 @@
#ifdef DST_Q8_0
#define BLOCK_SIZE 32u
#define BLOCK_BYTES 34u
#define QS_WORDS 8u
#elif defined(DST_Q4_0)
#define BLOCK_SIZE 32u
#define BLOCK_BYTES 18u
#define QS_WORDS 4u
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> idx: array<u32>;
@group(0) @binding(2)
#ifdef PAIR_BLOCKS
var<storage, read_write> dst: array<u32>;
#else
var<storage, read_write> dst: array<atomic<u32>>;
#endif
#ifdef I64_IDX
@group(0) @binding(3)
var<storage, read_write> error: atomic<u32>;
#define PARAMS_BINDING 4
#else
#define PARAMS_BINDING 3
#endif
struct Params {
offset_src: u32, // in elements
offset_idx: u32, // in elements
offset_dst: u32, // in blocks
// Strides (in elements / blocks)
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_idx0: u32,
stride_idx1: u32,
stride_idx2: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Shape of src
ne0: u32,
n_rows: u32,
ne2: u32,
ne3: u32,
// Shape of idx
idx1: u32,
idx2: u32,
};
@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;
// if the quantization type is unaligned and there are an odd number of blocks per row, we need to store atomically
#ifndef PAIR_BLOCKS
fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) {
loop {
let old = atomicLoad(&dst[word_idx]);
let merged = (old & ~mask) | (bits & mask);
let result = atomicCompareExchangeWeak(&dst[word_idx], old, merged);
if (result.exchanged) {
return;
}
}
}
#else
fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) {
let old = dst[word_idx];
dst[word_idx] = (old & ~mask) | (bits & mask);
}
#endif
fn store_u16(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) {
let total_byte_offset = block_byte_offset + byte_offset;
let word_idx = dst_word_idx + total_byte_offset / 4u;
let shift = (total_byte_offset & 2u) * 8u;
let mask = 0xFFFFu << shift;
merge_store_dst_word(word_idx, mask, (value & 0xFFFFu) << shift);
}
fn store_u32(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) {
let total_byte_offset = block_byte_offset + byte_offset;
let word_idx = dst_word_idx + total_byte_offset / 4u;
let shift = (total_byte_offset & 3u) * 8u;
if (shift == 0u) {
#ifdef PAIR_BLOCKS
dst[word_idx] = value;
#else
atomicStore(&dst[word_idx], value);
#endif
return;
}
let lo_mask = 0xFFFFFFFFu << shift;
let hi_mask = (1u << shift) - 1u;
merge_store_dst_word(word_idx, lo_mask, value << shift);
merge_store_dst_word(word_idx + 1u, hi_mask, value >> (32u - shift));
}
fn quantize_block_params(src_block: u32) -> vec2<f32> {
#ifdef DST_Q8_0
var amax = 0.0;
for (var j: u32 = 0u; j < BLOCK_SIZE; j++) {
amax = max(amax, abs(src[src_block + j]));
}
let d = amax / 127.0;
let id = select(0.0, 1.0 / d, d > 0.0);
return vec2(d, id);
#elif defined(DST_Q4_0)
var amax = 0.0;
var max_val = 0.0;
for (var j: u32 = 0u; j < BLOCK_SIZE; j++) {
let v = src[src_block + j];
let av = abs(v);
if (amax < av) {
amax = av;
max_val = v;
}
}
let d = max_val / -8.0;
let id = select(0.0, 1.0 / d, d != 0.0);
return vec2(d, id);
#endif
}
fn quantize_block_word(src_block: u32, j: u32, id: f32) -> u32 {
#ifdef DST_Q8_0
let base = src_block + j * 4u;
return (u32(i32(round(src[base + 0u] * id)) & 0xFF) << 0u) |
(u32(i32(round(src[base + 1u] * id)) & 0xFF) << 8u) |
(u32(i32(round(src[base + 2u] * id)) & 0xFF) << 16u) |
(u32(i32(round(src[base + 3u] * id)) & 0xFF) << 24u);
#elif defined(DST_Q4_0)
var packed_q = 0u;
for (var k: u32 = 0u; k < 4u; k++) {
let x0 = src[src_block + j * 4u + k] * id;
let x1 = src[src_block + 16u + j * 4u + k] * id;
let q0 = u32(clamp(i32(x0 + 8.5), 0, 15));
let q1 = u32(clamp(i32(x1 + 8.5), 0, 15));
packed_q |= (q0 & 0xFu) << (8u * k);
packed_q |= (q1 & 0xFu) << (8u * k + 4u);
}
return packed_q;
#endif
}
fn quantize_block(src_block: u32, dst_word_idx: u32, block_byte_offset: u32) {
let params = quantize_block_params(src_block);
let d = params.x;
let id = params.y;
let packed_d = pack2x16float(vec2(d, 0.0)) & 0xFFFFu;
store_u16(dst_word_idx, block_byte_offset, 0u, packed_d);
for (var j: u32 = 0u; j < QS_WORDS; j++) {
store_u32(dst_word_idx, block_byte_offset, 2u + j * 4u, quantize_block_word(src_block, j, id));
}
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let blocks_per_row = params.ne0 / BLOCK_SIZE;
#ifdef PAIR_BLOCKS
let blocks_per_invocation = 2u;
#else
let blocks_per_invocation = 1u;
#endif
let invocations_per_row = blocks_per_row / blocks_per_invocation;
let total_invocations = params.ne3 * params.ne2 * params.n_rows * invocations_per_row;
if (gid.x >= total_invocations) {
return;
}
var i = gid.x / invocations_per_row;
let block_in_row = (gid.x % invocations_per_row) * blocks_per_invocation;
let i_src3 = i / (params.ne2 * params.n_rows);
i = i % (params.ne2 * params.n_rows);
let i_src2 = i / params.n_rows;
let i_src1 = i % params.n_rows;
let i_idx2 = i_src3 % params.idx2;
let i_idx1 = i_src2 % params.idx1;
let i_idx0 = i_src1;
#ifdef I64_IDX
let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2u;
let idx_val = idx[idx_high];
let idx_low_val = idx[idx_high + 1u];
if (idx_low_val != 0u) {
atomicStore(&error, 1u);
return;
}
#else
let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
let idx_val = idx[idx_i];
#endif
let dst_row_blocks = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
let src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
let src_block = src_row + block_in_row * BLOCK_SIZE;
let dst_block_byte = (dst_row_blocks + block_in_row) * BLOCK_BYTES;
let dst_word_idx = dst_block_byte / 4u;
#ifdef PAIR_BLOCKS
quantize_block(src_block, dst_word_idx, 0u);
quantize_block(src_block + BLOCK_SIZE, dst_word_idx, BLOCK_BYTES);
#else
quantize_block(src_block, dst_word_idx, dst_block_byte & 3u);
#endif
}