ggml-webgpu: add layer norm ops (llama/22406)

* shader(norm): add layer norm ops

* shader(norm): stablize floating point computation with Kahan summation and handle mixed types

* shader(norm): remove the non-contiguous strides

* shader(norm): use the original implementation rather than the kahan summation
This commit is contained in:
Chen Yuan 2026-05-03 23:52:53 -04:00 committed by Georgi Gerganov
parent 3bcac0a0c7
commit d1d0dc2348
3 changed files with 107 additions and 24 deletions

View File

@ -228,11 +228,13 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
/** Row Norm **/
struct ggml_webgpu_row_norm_pipeline_key {
ggml_op op;
bool inplace;
ggml_op op;
ggml_type src_type;
ggml_type dst_type;
bool inplace;
bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
return op == other.op && inplace == other.inplace;
return op == other.op && src_type == other.src_type && dst_type == other.dst_type && inplace == other.inplace;
}
};
@ -240,6 +242,8 @@ struct ggml_webgpu_row_norm_pipeline_key_hash {
size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.src_type);
ggml_webgpu_hash_combine(seed, key.dst_type);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
@ -1097,6 +1101,8 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_row_norm_pipeline_key key = {};
key.op = context.dst->op;
key.src_type = context.src0->type;
key.dst_type = context.dst->type;
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
auto it = row_norm_pipelines.find(key);
@ -1111,6 +1117,10 @@ class ggml_webgpu_shader_lib {
defines.push_back("RMS_NORM");
variant = "rms_norm";
break;
case GGML_OP_NORM:
defines.push_back("NORM");
variant = "norm";
break;
case GGML_OP_L2_NORM:
defines.push_back("L2_NORM");
variant = "l2_norm";
@ -1124,6 +1134,22 @@ class ggml_webgpu_shader_lib {
variant += "_inplace";
}
if (key.src_type == GGML_TYPE_F32) {
defines.push_back("SRC_F32");
variant += "_src_f32";
} else if (key.src_type == GGML_TYPE_F16) {
defines.push_back("SRC_F16");
variant += "_src_f16";
}
if (key.dst_type == GGML_TYPE_F32) {
defines.push_back("DST_F32");
variant += "_dst_f32";
} else if (key.dst_type == GGML_TYPE_F16) {
defines.push_back("DST_F16");
variant += "_dst_f16";
}
const uint32_t row_norm_wg_size = 128u;
uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

View File

@ -2927,6 +2927,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
} else {
return ggml_webgpu_row_norm(ctx, src0, node);
}
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
return ggml_webgpu_row_norm(ctx, src0, node);
case GGML_OP_ROPE:
@ -4071,6 +4072,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
}
case GGML_OP_RMS_NORM:
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;

View File

@ -1,20 +1,17 @@
#ifdef INPLACE
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
src[dst_offset] = scale * src[src_offset];
}
#if defined(SRC_F16) || defined(DST_F16)
enable f16;
#endif
@group(0) @binding(1)
var<uniform> params: Params;
#ifdef SRC_F16
#define SRC_TYPE f16
#else
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
dst[dst_offset] = scale * src[src_offset];
}
#define SRC_TYPE f32
#endif
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#ifdef DST_F16
#define DST_TYPE f16
#else
#define DST_TYPE f32
#endif
struct Params {
@ -40,9 +37,20 @@ struct Params {
};
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
var<storage, read_write> src: array<SRC_TYPE>;
var<workgroup> scratch: array<f32, WG_SIZE>;
#ifdef INPLACE
@group(0) @binding(1)
var<uniform> params: Params;
#else
@group(0) @binding(1)
var<storage, read_write> dst: array<DST_TYPE>;
@group(0) @binding(2)
var<uniform> params: Params;
#endif
var<workgroup> scratch: array<f32, WG_SIZE * 2u>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@ -65,34 +73,81 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
if (col >= params.ne0) {
break;
}
sum += pow(src[i_src_row + col], 2.0);
let v = f32(src[i_src_row + col]);
#ifdef NORM
sum += v;
#else
sum += v * v;
#endif
col += WG_SIZE;
}
scratch[lid.x] = sum;
workgroupBarrier();
var offset: u32 = WG_SIZE / 2;
var offset: u32 = WG_SIZE / 2u;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
}
offset = offset / 2;
offset /= 2u;
workgroupBarrier();
}
sum = scratch[0];
#ifdef RMS_NORM
#ifdef NORM
let mean = sum / f32(params.ne0);
var sq_sum = 0.0f;
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
let v = f32(src[i_src_row + col]);
let d = v - mean;
sq_sum += d * d;
col += WG_SIZE;
}
workgroupBarrier();
scratch[lid.x] = sq_sum;
workgroupBarrier();
offset = WG_SIZE / 2u;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
}
offset /= 2u;
workgroupBarrier();
}
let variance = scratch[0] / f32(params.ne0);
let scale = 1.0 / sqrt(variance + params.eps);
#elif defined(RMS_NORM)
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
#elif defined(L2_NORM)
let scale = 1.0/max(sqrt(sum), params.eps);
#endif
#ifdef NORM
let mean_val = mean;
#else
let mean_val = 0.0f;
#endif
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
update(i_src_row + col, i_dst_row + col, scale);
let i_src = i_src_row + col;
let i_dst = i_dst_row + col;
let v = src[i_src];
#ifdef INPLACE
src[i_dst] = scale * (v - mean_val);
#else
dst[i_dst] = scale * (v - mean_val);
#endif
col += WG_SIZE;
}
}