From a5a8496d31ef1690ff2addc65b555916c3cf8895 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 08:49:06 +0300 Subject: [PATCH] ggml : remove obsoloete wgsl templates (ggml/0) --- .../ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl | 107 ------ .../ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl | 323 ---------------- .../ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl | 295 --------------- .../wgsl-shaders/soft_max.tmpl.wgsl | 345 ------------------ 4 files changed, 1070 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl deleted file mode 100644 index b5e93b81..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +++ /dev/null @@ -1,107 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "f32" - } - }, - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "i32" - } - }, - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "f16" - } - }, - { - "REPLS": { - "SRC_TYPE": "f16", - "DST_TYPE": "f16" - } - }, - { - "REPLS": { - "SRC_TYPE": "f16", - "DST_TYPE": "f32" - } - } -] - -#end(VARIANTS) - -#define(SHADER) -enable f16; - -@group(0) @binding(0) -var src: array<{{SRC_TYPE}}>; - -@group(0) @binding(1) -var dst: array<{{DST_TYPE}}>; - -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shapes - src_ne0: u32, - src_ne1: u32, - src_ne2: u32, - - dst_ne0: u32, - dst_ne1: u32, - dst_ne2: u32 -}; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); - i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); - let i2 = i / (params.src_ne1 * params.src_ne0); - i = i % (params.src_ne1 * params.src_ne0); - let i1 = i / params.src_ne0; - let i0 = i % params.src_ne0; - - var j = gid.x; - let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - let j2 = j / (params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne1 * params.dst_ne0); - let j1 = j / params.dst_ne0; - let j0 = j % params.dst_ne0; - - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; - - let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + - j2 * params.stride_dst2 + j3 * params.stride_dst3; - - dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx])); -} -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl deleted file mode 100644 index 03fcd548..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +++ /dev/null @@ -1,323 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "reglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "geglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "swiglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_oai_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "SWIGLU_OAI"] - }, - { - "SHADER_NAME": "swiglu_oai_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "SWIGLU_OAI"] - }, - { - "SHADER_NAME": "geglu_erf_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_quick_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU_QUICK"] - }, -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(REGLU) -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return max(a, 0) * b; -} -#enddecl(REGLU) - -#decl(GEGLU) -const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876; -const GELU_COEF_A: {{TYPE}} = 0.044715; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); - return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b; -} -#enddecl(GEGLU) - -#decl(SWIGLU) -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return a / (1.0 + exp(-a)) * b; -} -#enddecl(SWIGLU) - -#decl(SWIGLU_OAI) -fn op(a: f32, b: f32) -> f32 { - let xi = min(a, params.limit); - let gi = max(min(b, params.limit), -params.limit); - var out_glu = xi / (1.0 + exp(-xi * params.alpha)); - out_glu = out_glu * (1.0 + gi); - return out_glu; -} -#enddecl(SWIGLU_OAI) - -#decl(GEGLU_ERF) -const p_erf: {{TYPE}} = 0.3275911; -const a1_erf: {{TYPE}} = 0.254829592; -const a2_erf: {{TYPE}} = -0.284496736; -const a3_erf: {{TYPE}} = 1.421413741; -const a4_erf: {{TYPE}} = -1.453152027; -const a5_erf: {{TYPE}} = 1.061405429; -const SQRT_2_INV: {{TYPE}} = 0.7071067811865476; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - let a_div_sqr2 = a * SQRT_2_INV; - let sign_x = sign(a_div_sqr2); - let x = abs(a_div_sqr2); - let t = 1.0 / (1.0 + p_erf * x); - let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); - let erf_approx = sign_x * y; - return 0.5 * a * (1.0 + erf_approx) * b; -} -#enddecl(GEGLU_ERF) - -#decl(GEGLU_QUICK) -const GELU_QUICK_COEF: {{TYPE}} = -1.702; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; -} -#enddecl(GEGLU_QUICK) - -#decl(NO_SPLIT) -@group(0) @binding(1) -var dst: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -fn a_value(base: u32) -> {{TYPE}} { - let offset: u32 = select(0, params.ne0, params.swapped != 0); - return src0[base + offset]; -} - -fn b_value(base: u32) -> {{TYPE}} { - let offset: u32 = select(params.ne0, 0, params.swapped != 0); - return src0[base + offset]; -} -#enddecl(NO_SPLIT) - -#decl(SPLIT) -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -fn a_value(base: u32) -> {{TYPE}} { - return src0[base]; -} - -fn b_value(base: u32) -> {{TYPE}} { - return src1[base]; -} -#enddecl(SPLIT) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_src11: u32, - stride_src12: u32, - stride_src13: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // shape of dst - ne: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - swapped: u32, - alpha: f32, - limit: f32, -} - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; - let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; - let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; - - dst[i_dst] = op(a_value(i_a), b_value(i_b)); -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl deleted file mode 100644 index 84dc8dbf..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +++ /dev/null @@ -1,295 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f32_inplace", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] - }, - { - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f16_inplace", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] - }, - { - "SHADER_SUFFIX": "f32_ff", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f32_ff_inplace", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] - }, - { - "SHADER_SUFFIX": "f16_ff", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f16_ff_inplace", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(ROTATE) -fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { - dst[i_dst0] = {{TYPE}}(out0); - dst[i_dst1] = {{TYPE}}(out1); -} -#enddecl(ROTATE) - -#decl(ROTATE_INPLACE) -fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { - src0[i_dst0] = {{TYPE}}(out0); - src0[i_dst1] = {{TYPE}}(out1); -} -#enddecl(ROTATE_INPLACE) - -#decl(NO_FF_FUNC) -fn freq_factor(i: u32) -> f32 { - return 1.0f; -} -#enddecl(NO_FF_FUNC) - -#decl(FF_FUNC) -fn freq_factor(i: u32) -> f32 { - return src2[params.offset_src2 + i/2]; -} -#enddecl(FF_FUNC) - -#decl(NO_FF_BINDINGS) - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -#enddecl(NO_FF_BINDINGS) - -#decl(NO_FF_BINDINGS_INPLACE) - -@group(0) @binding(2) -var params: Params; - -#enddecl(NO_FF_BINDINGS_INPLACE) - -#decl(FF_BINDINGS) - -@group(0) @binding(2) -var src2: array; - -@group(0) @binding(3) -var dst: array<{{TYPE}}>; - -@group(0) @binding(4) -var params: Params; - -#enddecl(FF_BINDINGS) - -#decl(FF_BINDINGS_INPLACE) - -@group(0) @binding(2) -var src2: array; - -@group(0) @binding(3) -var params: Params; - -#enddecl(FF_BINDINGS_INPLACE) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_src2: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - n_threads: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - n_dims: u32, - mode: u32, - theta_scale: f32, - attn_factor: f32, - freq_scale: f32, - ext_factor: f32, - corr_dim0: f32, - corr_dim1: f32, - sections0: u32, - sections1: u32, - sections2: u32, - sections3: u32 -}; - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array; - -DECLS - -fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { - let y = (f32(i / 2) - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -// returns vector of (cos_theta, sin_theta) -// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row -fn rope_yarn(theta_extrap: f32, i: u32) -> vec2 { - var mscale = params.attn_factor; - var theta = params.freq_scale * theta_extrap; - if (params.ext_factor != 0.0f) { - let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor; - theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix; - mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale); - } - return vec2(cos(theta) * mscale, sin(theta) * mscale); -} - -fn pair_base(i0: u32, div_2: bool) -> u32 { - if (div_2) { - return i0 / 2; - } else { - return i0; - } -} - -fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 { - if (is_vision) { - return params.n_dims; - } else if (is_neox || is_mrope) { - return params.n_dims / 2; - } else { - return 1; - } -} - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - // two elements per thread - if (gid.x >= params.n_threads) { - return; - } - - let is_neox = bool(params.mode & 2); - let is_mrope = bool(params.mode & 8); - let is_imrope = params.mode == 40; - let is_vision = params.mode == 24; - - var i = gid.x * 2; // start index for this thread - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; - let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - - if (i0 >= params.n_dims && !is_vision) { - let i_src = i_src_row + i0; - let i_dst = i_dst_row + i0; - rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1])); - return; - } - - var theta_base_mult: u32 = 0; - var theta_scale_pwr: u32 = i0 / 2; - if (is_mrope) { - let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3; - let sec_w = params.sections1 + params.sections0; - let sec_e = params.sections2 + sec_w; - let sector = (i0 / 2) % sect_dims; - if (is_imrope) { - if (sector % 3 == 1 && sector < 3 * params.sections1) { - theta_base_mult = 1; - } else if (sector % 3 == 2 && sector < 3 * params.sections2) { - theta_base_mult = 2; - } else if (sector % 3 == 0 && sector < 3 * params.sections0) { - theta_base_mult = 0; - } else { - theta_base_mult = 3; - } - } else { - if (sector >= params.sections0 && sector < sec_w) { - theta_base_mult = 1; - if (is_vision) { - theta_scale_pwr = sector - params.sections0; - } - } else if (sector >= sec_w && sector < sec_e) { - theta_base_mult = 2; - if (is_vision) { - theta_scale_pwr = sector - sec_w; - } - } else if (sector >= sec_e) { - if (is_vision) { - theta_scale_pwr = sector - sec_e; - theta_scale_pwr = (i0 / 2) % sec_e; - } - theta_base_mult = 3; - } else if (is_vision) { - theta_scale_pwr = sector; - } - } - } - let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr)); - let thetas = rope_yarn(theta_base/freq_factor(i0), i0); - - let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision); - let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision); - - let x0 = f32(src0[i_src]); - let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]); - rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x); -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl deleted file mode 100644 index c74dc4cc..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +++ /dev/null @@ -1,345 +0,0 @@ -#define(VARIANTS) -[ - { - "SHADER_NAME": "soft_max_f32", - "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_inplace", - "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_sink", - "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_sink_inplace", - "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_inplace", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_inplace", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_sink", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_sink", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] - } -] -#end(VARIANTS) - -#define(DECLS) - -#decl(BASE_BINDINGS) -@group(0) @binding(1) -var dst: array; - -@group(0) @binding(2) -var params: Params; -#enddecl(BASE_BINDINGS) - -#decl(BASE_BINDINGS_INPLACE) -@group(0) @binding(1) -var params: Params; -#enddecl(BASE_BINDINGS_INPLACE) - -#decl(SINK_BINDINGS) -@group(0) @binding(1) -var sinks: array; - -@group(0) @binding(2) -var dst: array; - -@group(0) @binding(3) -var params: Params; -#enddecl(SINK_BINDINGS) - -#decl(SINK_BINDINGS_INPLACE) -@group(0) @binding(1) -var sinks: array; - -@group(0) @binding(2) -var params: Params; -#enddecl(SINK_BINDINGS_INPLACE) - -#decl(MASK_BINDINGS) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var dst: array; - -@group(0) @binding(3) -var params: Params; -#enddecl(MASK_BINDINGS) - -#decl(MASK_BINDINGS_INPLACE) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var params: Params; -#enddecl(MASK_BINDINGS_INPLACE) - -#decl(MASK_SINK_BINDINGS) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var sinks: array; - -@group(0) @binding(3) -var dst: array; - -@group(0) @binding(4) -var params: Params; -#enddecl(MASK_SINK_BINDINGS) - -#decl(MASK_SINK_BINDINGS_INPLACE) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var sinks: array; - -@group(0) @binding(3) -var params: Params; -#enddecl(MASK_SINK_BINDINGS_INPLACE) - -#decl(NOT_INPLACE) -fn inter_value(i: u32) -> f32 { - return dst[i]; -} - -fn update(i: u32, val: f32) { - dst[i] = val; -} -#enddecl(NOT_INPLACE) - -#decl(INPLACE) -fn inter_value(i: u32) -> f32 { - return src[i]; -} - -fn update(i: u32, val: f32) { - src[i] = val; -} -#enddecl(INPLACE) - -#decl(NO_MASK) -fn mask_val(i: u32) -> f32 { - return 0.0; -} -#enddecl(NO_MASK) - -#decl(MASK) -fn mask_val(i: u32) -> f32 { - return f32(mask[i]); -} -#enddecl(MASK) - -#decl(NO_SINK) -fn lower_max_bound(i2: u32) -> f32 { - return -1e30; -} - -fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { - return val; -} -#enddecl(NO_SINK) - -#decl(SINK) -fn lower_max_bound(i2: u32) -> f32 { - return sinks[params.offset_sinks + i2]; -} - -fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { - return val + exp(sinks[params.offset_sinks + i2] - max_val); -} -#enddecl(SINK) - -#end(DECLS) - -#define(SHADER) -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_sinks: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_src11: u32, - stride_src12: u32, - stride_src13: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // shape of src0/dst - ne: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - // shape of src1 - ne12: u32, - ne13: u32, - - scale: f32, - max_bias: f32, - n_head_log2: f32, - m0: f32, - m1: f32, -}; - -@group(0) @binding(0) -var src: array; - -DECLS - -const CACHE_SIZE: u32 = 16; - -override wg_size: u32; -var scratch: array; - -@compute @workgroup_size(wg_size) -fn main(@builtin(workgroup_id) wid: vec3, - @builtin(local_invocation_id) lid: vec3) { - - var i = wid.x; - let i3 = i / (params.ne2 * params.ne1); - i = i % (params.ne2 * params.ne1); - let i2 = i / params.ne1; - let i1 = i % params.ne1; - let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; - let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11; - let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - let elems = (params.ne0 + wg_size - 1) / wg_size; - - let head = f32(i2); - let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0); - - var cache: array; - - var max_val = lower_max_bound(i2); - var col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col); - max_val = max(max_val, val); - if (col < CACHE_SIZE) { - cache[col] = val; - } - col += wg_size; - } - - scratch[lid.x] = max_val; - workgroupBarrier(); - var offset = wg_size / 2; - while (offset > 0) { - if (lid.x < offset) { - scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]); - } - offset = offset / 2; - workgroupBarrier(); - } - let row_max = scratch[0]; - workgroupBarrier(); - - var sum = 0.0f; - col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col), - cache[col], col < CACHE_SIZE); - let ex = exp(val - row_max); - sum += ex; - if (col < CACHE_SIZE) { - cache[col] = ex; - } else { - update(i_dst_row + col, ex); - } - col += wg_size; - } - - scratch[lid.x] = sum; - workgroupBarrier(); - offset = wg_size / 2; - while (offset > 0) { - if (lid.x < offset) { - scratch[lid.x] += scratch[lid.x + offset]; - } - offset = offset / 2; - workgroupBarrier(); - } - let row_sum = add_sinks(scratch[0], i2, row_max); - - let sum_recip = 1.0 / row_sum; - col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); - col += wg_size; - } -} -#end(SHADER)