From d1d0dc2348f6b294598725ef8cd3d40652fb674d Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Sun, 3 May 2026 23:52:53 -0400 Subject: [PATCH] ggml-webgpu: add layer norm ops (llama/22406) * shader(norm): add layer norm ops * shader(norm): stablize floating point computation with Kahan summation and handle mixed types * shader(norm): remove the non-contiguous strides * shader(norm): use the original implementation rather than the kahan summation --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 32 +++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 + .../ggml-webgpu/wgsl-shaders/row_norm.wgsl | 97 +++++++++++++++---- 3 files changed, 107 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index cff93b8d1..c6dc2c211 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -228,11 +228,13 @@ struct ggml_webgpu_get_rows_pipeline_key_hash { /** Row Norm **/ struct ggml_webgpu_row_norm_pipeline_key { - ggml_op op; - bool inplace; + ggml_op op; + ggml_type src_type; + ggml_type dst_type; + bool inplace; bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const { - return op == other.op && inplace == other.inplace; + return op == other.op && src_type == other.src_type && dst_type == other.dst_type && inplace == other.inplace; } }; @@ -240,6 +242,8 @@ struct ggml_webgpu_row_norm_pipeline_key_hash { size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.inplace); return seed; } @@ -1097,6 +1101,8 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_row_norm_pipeline_key key = {}; key.op = context.dst->op; + key.src_type = context.src0->type; + key.dst_type = context.dst->type; key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); auto it = row_norm_pipelines.find(key); @@ -1111,6 +1117,10 @@ class ggml_webgpu_shader_lib { defines.push_back("RMS_NORM"); variant = "rms_norm"; break; + case GGML_OP_NORM: + defines.push_back("NORM"); + variant = "norm"; + break; case GGML_OP_L2_NORM: defines.push_back("L2_NORM"); variant = "l2_norm"; @@ -1124,6 +1134,22 @@ class ggml_webgpu_shader_lib { variant += "_inplace"; } + if (key.src_type == GGML_TYPE_F32) { + defines.push_back("SRC_F32"); + variant += "_src_f32"; + } else if (key.src_type == GGML_TYPE_F16) { + defines.push_back("SRC_F16"); + variant += "_src_f16"; + } + + if (key.dst_type == GGML_TYPE_F32) { + defines.push_back("DST_F32"); + variant += "_dst_f32"; + } else if (key.dst_type == GGML_TYPE_F16) { + defines.push_back("DST_F16"); + variant += "_dst_f16"; + } + const uint32_t row_norm_wg_size = 128u; uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size); defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index cab0aead1..12f60a990 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2927,6 +2927,7 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, } else { return ggml_webgpu_row_norm(ctx, src0, node); } + case GGML_OP_NORM: case GGML_OP_L2_NORM: return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -4071,6 +4072,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_RMS_NORM: + case GGML_OP_NORM: case GGML_OP_L2_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl index bd8d32bde..5eaf5e7bb 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl @@ -1,20 +1,17 @@ -#ifdef INPLACE -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - src[dst_offset] = scale * src[src_offset]; -} +#if defined(SRC_F16) || defined(DST_F16) +enable f16; +#endif -@group(0) @binding(1) -var params: Params; +#ifdef SRC_F16 +#define SRC_TYPE f16 #else -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - dst[dst_offset] = scale * src[src_offset]; -} +#define SRC_TYPE f32 +#endif -@group(0) @binding(1) -var dst: array; - -@group(0) @binding(2) -var params: Params; +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 #endif struct Params { @@ -40,9 +37,20 @@ struct Params { }; @group(0) @binding(0) -var src: array; +var src: array; -var scratch: array; +#ifdef INPLACE +@group(0) @binding(1) +var params: Params; +#else +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; +#endif + +var scratch: array; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wid: vec3, @@ -65,34 +73,81 @@ fn main(@builtin(workgroup_id) wid: vec3, if (col >= params.ne0) { break; } - sum += pow(src[i_src_row + col], 2.0); + let v = f32(src[i_src_row + col]); +#ifdef NORM + sum += v; +#else + sum += v * v; +#endif col += WG_SIZE; } scratch[lid.x] = sum; workgroupBarrier(); - var offset: u32 = WG_SIZE / 2; + + var offset: u32 = WG_SIZE / 2u; while (offset > 0) { if (lid.x < offset) { scratch[lid.x] += scratch[lid.x + offset]; } - offset = offset / 2; + offset /= 2u; workgroupBarrier(); } sum = scratch[0]; -#ifdef RMS_NORM +#ifdef NORM + let mean = sum / f32(params.ne0); + var sq_sum = 0.0f; + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + let v = f32(src[i_src_row + col]); + let d = v - mean; + sq_sum += d * d; + col += WG_SIZE; + } + + workgroupBarrier(); + scratch[lid.x] = sq_sum; + workgroupBarrier(); + offset = WG_SIZE / 2u; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset /= 2u; + workgroupBarrier(); + } + + let variance = scratch[0] / f32(params.ne0); + let scale = 1.0 / sqrt(variance + params.eps); +#elif defined(RMS_NORM) let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); #elif defined(L2_NORM) let scale = 1.0/max(sqrt(sum), params.eps); #endif +#ifdef NORM + let mean_val = mean; +#else + let mean_val = 0.0f; +#endif + col = lid.x; for (var j: u32 = 0; j < elems; j++) { if (col >= params.ne0) { break; } - update(i_src_row + col, i_dst_row + col, scale); + let i_src = i_src_row + col; + let i_dst = i_dst_row + col; + let v = src[i_src]; +#ifdef INPLACE + src[i_dst] = scale * (v - mean_val); +#else + dst[i_dst] = scale * (v - mean_val); +#endif col += WG_SIZE; } }