revert to using global_invocation_id for cpy shader (llama/23955)

This commit is contained in:
Masashi Yoshimura 2026-06-02 08:59:06 +09:00 committed by Georgi Gerganov
parent e728bae159
commit db2a39507c
1 changed files with 4 additions and 5 deletions

View File

@ -50,13 +50,13 @@ var<uniform> params: Params;
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(global_invocation_index) gindex: u32,
@builtin(global_invocation_id) gid: vec3<u32>,
) {
if (gindex >= params.ne) {
if (gid.x >= params.ne) {
return;
}
var i = gindex;
var i = gid.x;
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
let i2 = i / (params.src_ne1 * params.src_ne0);
@ -64,7 +64,7 @@ fn main(
let i1 = i / params.src_ne0;
let i0 = i % params.src_ne0;
var j = gindex;
var j = gid.x;
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
let j2 = j / (params.dst_ne1 * params.dst_ne0);
@ -80,4 +80,3 @@ fn main(
dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx]));
}