ggml-webgpu: add q4_0/q8_0 SET_ROWS (llama/23760)
* Add q8_0 and q4_0 set_rows
* Add fast(er) quantization set_rows path
* formatting/naming
* a little more naming
* Remove unused constant
* Don't override other override
* Avoid bitcast
* Narrow relaxation
This commit is contained in:
parent
f7aad4ed7e
commit
acd91d2c38
|
|
@ -84,16 +84,16 @@ struct ggml_webgpu_shader_lib_context {
|
|||
ggml_tensor * src5;
|
||||
ggml_tensor * dst;
|
||||
|
||||
uint32_t max_wg_size;
|
||||
size_t wg_mem_limit_bytes = 0;
|
||||
bool supports_subgroups = false;
|
||||
bool supports_subgroup_matrix = false;
|
||||
uint32_t sg_mat_m = 0;
|
||||
uint32_t sg_mat_n = 0;
|
||||
uint32_t sg_mat_k = 0;
|
||||
uint32_t min_subgroup_size = 0;
|
||||
uint32_t max_subgroup_size = 0;
|
||||
bool supports_dot_product = false;
|
||||
uint32_t max_wg_size;
|
||||
size_t wg_mem_limit_bytes = 0;
|
||||
bool supports_subgroups = false;
|
||||
bool supports_subgroup_matrix = false;
|
||||
uint32_t sg_mat_m = 0;
|
||||
uint32_t sg_mat_n = 0;
|
||||
uint32_t sg_mat_k = 0;
|
||||
uint32_t min_subgroup_size = 0;
|
||||
uint32_t max_subgroup_size = 0;
|
||||
bool supports_dot_product = false;
|
||||
std::string vendor;
|
||||
};
|
||||
|
||||
|
|
@ -166,9 +166,11 @@ struct ggml_webgpu_set_rows_pipeline_key {
|
|||
int dst_type;
|
||||
int vec4;
|
||||
int i64_idx;
|
||||
int pair_blocks;
|
||||
|
||||
bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
|
||||
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
|
||||
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx &&
|
||||
pair_blocks == other.pair_blocks;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -178,6 +180,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash {
|
|||
ggml_webgpu_hash_combine(seed, key.dst_type);
|
||||
ggml_webgpu_hash_combine(seed, key.vec4);
|
||||
ggml_webgpu_hash_combine(seed, key.i64_idx);
|
||||
ggml_webgpu_hash_combine(seed, key.pair_blocks);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
|
@ -185,6 +188,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash {
|
|||
// Choices made while building a SET_ROWS pipeline. The values are copied from
// the pipeline key and the shader-lib context when the pipeline is created and
// are read back when the op is dispatched to size the launch.
//
// Every member has a default so that an instance that is default-initialized
// (rather than value-initialized) never exposes indeterminate values.
struct ggml_webgpu_set_rows_shader_decisions {
    bool     vec4        = false;  // shader variant processes 4 elements per invocation
    bool     i64_idx     = false;  // row indices are 64-bit (I64) instead of 32-bit
    bool     pair_blocks = false;  // quantized variant handles two blocks per invocation
    uint32_t wg_size     = 0;      // workgroup size the shader was compiled with
};
|
||||
|
||||
|
|
@ -772,31 +776,30 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
|||
(v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
|
||||
const bool kv_vec_type_supported =
|
||||
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
|
||||
const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
|
||||
(uint32_t) ggml_blck_size(K->type);
|
||||
const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 &&
|
||||
context.src2->ne[0] % kv_vec_head_align == 0;
|
||||
const uint32_t kv_vec_head_align =
|
||||
K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type);
|
||||
const bool kv_vec_head_dims_aligned =
|
||||
context.src0->ne[0] % kv_vec_head_align == 0 && context.src2->ne[0] % kv_vec_head_align == 0;
|
||||
// Compile with enough invocations to cover the largest reported subgroup.
|
||||
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) &&
|
||||
kv_vec_head_dims_aligned && kv_vec_type_supported &&
|
||||
(K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
|
||||
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && kv_vec_head_dims_aligned &&
|
||||
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
|
||||
(context.src2->type == K->type);
|
||||
const bool tile_can_dispatch_all_q_rows =
|
||||
context.max_subgroup_size > 0 &&
|
||||
context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size;
|
||||
const bool use_subgroup_matrix =
|
||||
context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 &&
|
||||
context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0;
|
||||
const bool use_subgroup_matrix = context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 &&
|
||||
context.src0->ne[0] % context.sg_mat_k == 0 &&
|
||||
context.src2->ne[0] % context.sg_mat_n == 0;
|
||||
const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 &&
|
||||
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
|
||||
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||
tile_can_dispatch_all_q_rows && !use_vec;
|
||||
|
||||
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
|
||||
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
|
||||
use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
|
||||
GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
|
||||
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
|
||||
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
|
||||
use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
|
||||
GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
|
||||
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {
|
||||
return decisions;
|
||||
|
|
@ -1131,9 +1134,9 @@ class ggml_webgpu_shader_lib {
|
|||
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
|
||||
flash_attn_blk_pipelines;
|
||||
std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
|
||||
mul_mat_vec_pipelines; // fast mat-vec (n==1)
|
||||
mul_mat_vec_pipelines; // fast mat-vec (n==1)
|
||||
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
|
||||
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
|
||||
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
|
||||
std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash>
|
||||
quantize_q8_pipelines;
|
||||
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
|
||||
|
|
@ -1264,10 +1267,13 @@ class ggml_webgpu_shader_lib {
|
|||
}
|
||||
|
||||
webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_set_rows_pipeline_key key = {};
|
||||
key.dst_type = context.dst->type;
|
||||
key.vec4 = context.src0->ne[0] % 4 == 0;
|
||||
key.i64_idx = context.src1->type == GGML_TYPE_I64;
|
||||
const bool quantized = ggml_is_quantized(context.dst->type);
|
||||
ggml_webgpu_set_rows_pipeline_key key = {};
|
||||
key.dst_type = context.dst->type;
|
||||
key.vec4 =
|
||||
(context.dst->type == GGML_TYPE_F32 || context.dst->type == GGML_TYPE_F16) && context.src0->ne[0] % 4 == 0;
|
||||
key.i64_idx = context.src1->type == GGML_TYPE_I64;
|
||||
key.pair_blocks = quantized && ((context.src0->ne[0] / ggml_blck_size(context.dst->type)) % 2 == 0);
|
||||
|
||||
auto it = set_rows_pipelines.find(key);
|
||||
if (it != set_rows_pipelines.end()) {
|
||||
|
|
@ -1286,6 +1292,14 @@ class ggml_webgpu_shader_lib {
|
|||
defines.push_back("DST_F16");
|
||||
variant += "_dstf16";
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
defines.push_back("DST_Q8_0");
|
||||
variant += "_dstq8_0";
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
defines.push_back("DST_Q4_0");
|
||||
variant += "_dstq4_0";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported dst type for set_rows shader");
|
||||
}
|
||||
|
|
@ -1298,13 +1312,19 @@ class ggml_webgpu_shader_lib {
|
|||
defines.push_back("I64_IDX");
|
||||
variant += "_i64idx";
|
||||
}
|
||||
if (key.pair_blocks) {
|
||||
defines.push_back("PAIR_BLOCKS");
|
||||
variant += "_pair_blocks";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_set_rows, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
|
||||
const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows;
|
||||
auto processed = preprocessor.preprocess(shader_source, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
|
||||
decisions->vec4 = key.vec4;
|
||||
decisions->i64_idx = key.i64_idx;
|
||||
decisions->pair_blocks = key.pair_blocks;
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
set_rows_pipelines[key].context = decisions;
|
||||
|
|
@ -1660,7 +1680,7 @@ class ggml_webgpu_shader_lib {
|
|||
key.type = context.dst->type;
|
||||
key.d_state = (int) context.src0->ne[0];
|
||||
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
|
||||
ggml_webgpu_tensor_overlap(context.src1, context.src5);
|
||||
ggml_webgpu_tensor_overlap(context.src1, context.src5);
|
||||
|
||||
auto it = ssm_scan_pipelines.find(key);
|
||||
if (it != ssm_scan_pipelines.end()) {
|
||||
|
|
@ -1819,7 +1839,7 @@ class ggml_webgpu_shader_lib {
|
|||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.use_mmvq =
|
||||
key.use_mmvq =
|
||||
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
|
||||
|
||||
auto it = mul_mat_vec_pipelines.find(key);
|
||||
|
|
|
|||
|
|
@ -1331,7 +1331,11 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_set_rows(webgpu_context & ct
|
|||
}
|
||||
|
||||
uint32_t threads;
|
||||
if (decisions->vec4) {
|
||||
if (ggml_is_quantized(dst->type)) {
|
||||
const uint32_t blocks_per_row = src->ne[0] / ggml_blck_size(dst->type);
|
||||
threads =
|
||||
(src->ne[1] * src->ne[2] * src->ne[3]) * (decisions->pair_blocks ? (blocks_per_row / 2) : blocks_per_row);
|
||||
} else if (decisions->vec4) {
|
||||
threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
|
||||
} else {
|
||||
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
|
||||
|
|
@ -4046,8 +4050,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32);
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
|
||||
(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
|
||||
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_Q8_0 ||
|
||||
op->type == GGML_TYPE_Q4_0) &&
|
||||
src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
|
||||
break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {
|
||||
|
|
|
|||
|
|
@ -71,7 +71,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
return;
|
||||
}
|
||||
|
||||
// getting the row from gid
|
||||
let elems_per_row = params.ne0 / VEC_SIZE;
|
||||
var i = gid.x / elems_per_row;
|
||||
|
||||
|
|
@ -104,6 +103,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
|
||||
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
|
||||
|
||||
let col_idx = (gid.x % elems_per_row);
|
||||
dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]);
|
||||
let col_idx = gid.x % elems_per_row;
|
||||
dst[i_dst_row / VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row / VEC_SIZE + col_idx]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,224 @@
|
|||
#ifdef DST_Q8_0
|
||||
#define BLOCK_SIZE 32u
|
||||
#define BLOCK_BYTES 34u
|
||||
#define QS_WORDS 8u
|
||||
#elif defined(DST_Q4_0)
|
||||
#define BLOCK_SIZE 32u
|
||||
#define BLOCK_BYTES 18u
|
||||
#define QS_WORDS 4u
|
||||
#endif
|
||||
|
||||
// Source values to quantize, read as f32.
@group(0) @binding(0) var<storage, read_write> src: array<f32>;

// Row indices, addressed as 32-bit words (two words per index when I64_IDX is defined).
@group(0) @binding(1) var<storage, read_write> idx: array<u32>;

// Destination, addressed as 32-bit words. Without PAIR_BLOCKS the words are
// atomics so that partial-word updates can go through a compare-exchange.
#ifdef PAIR_BLOCKS
@group(0) @binding(2) var<storage, read_write> dst: array<u32>;
#else
@group(0) @binding(2) var<storage, read_write> dst: array<atomic<u32>>;
#endif

// With 64-bit indices an extra error flag occupies binding 3, which moves the
// uniform parameters to binding 4.
#ifdef I64_IDX
@group(0) @binding(3) var<storage, read_write> error: atomic<u32>;
#define PARAMS_BINDING 4
#else
#define PARAMS_BINDING 3
#endif
|
||||
|
||||
// Uniform block describing the SET_ROWS launch. The field order is part of the
// host/shader contract and must not be changed.
struct Params {
    // Start offsets
    offset_src: u32,  // in elements
    offset_idx: u32,  // in elements
    offset_dst: u32,  // in blocks

    // Strides of src (in elements)
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // Strides of idx (in elements)
    stride_idx0: u32,
    stride_idx1: u32,
    stride_idx2: u32,

    // Strides of dst (in blocks)
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Shape of src
    ne0: u32,
    n_rows: u32,
    ne2: u32,
    ne3: u32,

    // Shape of idx (used to wrap the src dim-2 / dim-3 coordinates)
    idx1: u32,
    idx2: u32,
};

@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
|
||||
|
||||
// Replaces the bits of dst[word_idx] selected by `mask` with the matching bits
// of `bits`, leaving all other bits of the word untouched.
//
// Block sizes (34 / 18 bytes) are not multiples of 4, so a block can begin or
// end in the middle of a 32-bit word. Without PAIR_BLOCKS each invocation
// writes a single block and `dst` is an array of atomics, so the merge is done
// with a compare-exchange retry loop. With PAIR_BLOCKS `dst` is a plain u32
// array and a simple read-modify-write is used.
#ifdef PAIR_BLOCKS
fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) {
    let current = dst[word_idx];
    dst[word_idx] = (bits & mask) | (current & ~mask);
}
#else
fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) {
    loop {
        let current = atomicLoad(&dst[word_idx]);
        let updated = (current & ~mask) | (bits & mask);
        // Retry until no other write slipped in between the load and the exchange.
        let outcome = atomicCompareExchangeWeak(&dst[word_idx], current, updated);
        if (outcome.exchanged) {
            return;
        }
    }
}
#endif
|
||||
|
||||
// Writes the low 16 bits of `value` into dst at byte position
// (block_byte_offset + byte_offset), measured from word `dst_word_idx`.
// Only bit 1 of the byte position selects the half-word, i.e. the position is
// treated as 2-byte aligned.
fn store_u16(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) {
    let byte_pos = block_byte_offset + byte_offset;
    let half_shift = (byte_pos & 2u) * 8u;
    let target_word = dst_word_idx + byte_pos / 4u;
    merge_store_dst_word(target_word, 0xFFFFu << half_shift, (value & 0xFFFFu) << half_shift);
}
|
||||
|
||||
// Writes the 32-bit `value` into dst at byte position
// (block_byte_offset + byte_offset), measured from word `dst_word_idx`.
// A word-aligned position is written with a single store; otherwise the value
// is split across two neighbouring words and merged into each of them.
fn store_u32(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) {
    let byte_pos = block_byte_offset + byte_offset;
    let first_word = dst_word_idx + byte_pos / 4u;
    let bit_shift = (byte_pos & 3u) * 8u;

    if (bit_shift != 0u) {
        // Low part of `value` goes to the upper bits of the first word ...
        merge_store_dst_word(first_word, 0xFFFFFFFFu << bit_shift, value << bit_shift);
        // ... and the remaining high part to the lower bits of the next word.
        merge_store_dst_word(first_word + 1u, (1u << bit_shift) - 1u, value >> (32u - bit_shift));
        return;
    }

#ifdef PAIR_BLOCKS
    dst[first_word] = value;
#else
    atomicStore(&dst[first_word], value);
#endif
}
|
||||
|
||||
// Computes the scale for the BLOCK_SIZE source values starting at `src_block`.
// Returns vec2(d, id): d is the block scale and id its reciprocal, or 0 when
// d is zero.
fn quantize_block_params(src_block: u32) -> vec2<f32> {
#ifdef DST_Q8_0
    // Scale so that the largest magnitude in the block maps to 127.
    var abs_max = 0.0;
    for (var k: u32 = 0u; k < BLOCK_SIZE; k++) {
        abs_max = max(abs_max, abs(src[src_block + k]));
    }

    let scale = abs_max / 127.0;
    let inv_scale = select(0.0, 1.0 / scale, scale > 0.0);
    return vec2(scale, inv_scale);
#elif defined(DST_Q4_0)
    // Track the signed value with the largest magnitude; it maps to -8.
    var abs_max = 0.0;
    var signed_max = 0.0;
    for (var k: u32 = 0u; k < BLOCK_SIZE; k++) {
        let value = src[src_block + k];
        let magnitude = abs(value);
        if (magnitude > abs_max) {
            abs_max = magnitude;
            signed_max = value;
        }
    }

    let scale = signed_max / -8.0;
    let inv_scale = select(0.0, 1.0 / scale, scale != 0.0);
    return vec2(scale, inv_scale);
#endif
}
|
||||
|
||||
// Produces the j-th 32-bit word of quantized values for the block starting at
// `src_block`, using the reciprocal scale `id`.
fn quantize_block_word(src_block: u32, j: u32, id: f32) -> u32 {
#ifdef DST_Q8_0
    // Four consecutive values, one signed byte each, lowest byte first.
    // NOTE(review): WGSL round() resolves ties to even; confirm this matches
    // the rounding of the reference quantizer.
    let first = src_block + j * 4u;
    var packed_q = 0u;
    for (var k: u32 = 0u; k < 4u; k++) {
        let q = i32(round(src[first + k] * id));
        packed_q |= u32(q & 0xFF) << (8u * k);
    }
    return packed_q;
#elif defined(DST_Q4_0)
    // Each byte holds two 4-bit values: the low nibble comes from the first
    // half of the block and the high nibble from the element 16 positions later.
    let first = src_block + j * 4u;
    var packed_q = 0u;
    for (var k: u32 = 0u; k < 4u; k++) {
        let scaled_lo = src[first + k] * id;
        let scaled_hi = src[src_block + 16u + j * 4u + k] * id;
        let nibble_lo = u32(clamp(i32(scaled_lo + 8.5), 0, 15));
        let nibble_hi = u32(clamp(i32(scaled_hi + 8.5), 0, 15));
        packed_q |= (nibble_lo & 0xFu) << (8u * k);
        packed_q |= (nibble_hi & 0xFu) << (8u * k + 4u);
    }
    return packed_q;
#endif
}
|
||||
|
||||
// Quantizes one block of BLOCK_SIZE source values starting at `src_block` and
// writes it to dst: a 16-bit float scale followed by QS_WORDS 32-bit words of
// quantized values. The block starts `block_byte_offset` bytes after word
// `dst_word_idx`.
fn quantize_block(src_block: u32, dst_word_idx: u32, block_byte_offset: u32) {
    // x = scale, y = reciprocal scale (named to avoid shadowing the `params` uniform).
    let scale_pair = quantize_block_params(src_block);

    // Keep only the low half of the packed pair, which holds the scale as f16.
    let scale_f16 = pack2x16float(vec2(scale_pair.x, 0.0)) & 0xFFFFu;
    store_u16(dst_word_idx, block_byte_offset, 0u, scale_f16);

    // Quantized words follow the 2-byte scale.
    for (var w: u32 = 0u; w < QS_WORDS; w++) {
        let q_word = quantize_block_word(src_block, w, scale_pair.y);
        store_u32(dst_word_idx, block_byte_offset, 2u + w * 4u, q_word);
    }
}
|
||||
|
||||
// Entry point: each invocation quantizes one block (or two adjacent blocks with
// PAIR_BLOCKS) of one source row and writes it to the destination row selected
// by the index tensor.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
#ifdef PAIR_BLOCKS
    let blocks_per_invocation = 2u;
#else
    let blocks_per_invocation = 1u;
#endif
    let invocations_per_row = (params.ne0 / BLOCK_SIZE) / blocks_per_invocation;

    // Surplus invocations of the last workgroup have nothing to do.
    if (gid.x >= params.ne3 * params.ne2 * params.n_rows * invocations_per_row) {
        return;
    }

    // Split the flat invocation id into (row, first block within the row).
    let flat_row = gid.x / invocations_per_row;
    let block_in_row = (gid.x % invocations_per_row) * blocks_per_invocation;

    // Split the flat row id into src coordinates along dims 3, 2 and 1.
    let rows_per_i3 = params.ne2 * params.n_rows;
    let i_src3 = flat_row / rows_per_i3;
    let row_in_i3 = flat_row % rows_per_i3;
    let i_src2 = row_in_i3 / params.n_rows;
    let i_src1 = row_in_i3 % params.n_rows;

    // The index tensor is wrapped along its outer dimensions.
    let idx_elem = params.offset_idx + i_src1 * params.stride_idx0 + (i_src2 % params.idx1) * params.stride_idx1 + (i_src3 % params.idx2) * params.stride_idx2;

#ifdef I64_IDX
    // A 64-bit index occupies two u32 words: the first is used as the row
    // index, and a non-zero second word raises the error flag.
    let idx_val = idx[idx_elem * 2u];
    if (idx[idx_elem * 2u + 1u] != 0u) {
        atomicStore(&error, 1u);
        return;
    }
#else
    let idx_val = idx[idx_elem];
#endif

    let src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
    let dst_row_blocks = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;

    let src_block = src_row + block_in_row * BLOCK_SIZE;
    let dst_block_byte = (dst_row_blocks + block_in_row) * BLOCK_BYTES;
    let dst_word_idx = dst_block_byte / 4u;

#ifdef PAIR_BLOCKS
    // NOTE(review): the byte offset inside the first word is taken to be 0
    // here, i.e. dst_block_byte is assumed to be a multiple of 4 -- confirm
    // that the host only selects PAIR_BLOCKS when that holds.
    quantize_block(src_block, dst_word_idx, 0u);
    quantize_block(src_block + BLOCK_SIZE, dst_word_idx, BLOCK_BYTES);
#else
    // A single block may start mid-word; pass the remainder as a byte offset.
    quantize_block(src_block, dst_word_idx, dst_block_byte & 3u);
#endif
}
|
||||
Loading…
Reference in New Issue