ggml-webgpu: Update the `RMS_NORM` preprocessor and add `L2_NORM` (llama/20665)

* Update the preprocessor of RMS_NORM and add L2_NORM.

* Fix the name of rms_norm to row_norm.
This commit is contained in:
Masashi Yoshimura 2026-03-28 11:47:59 +02:00 committed by Georgi Gerganov
parent 12015a2174
commit 3d004fbf0a
3 changed files with 171 additions and 16 deletions

View File

@ -151,6 +151,26 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
}
};
/** Row Norm **/
struct ggml_webgpu_row_norm_pipeline_key {
ggml_op op;
bool inplace;
bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
return op == other.op && inplace == other.inplace;
}
};
struct ggml_webgpu_row_norm_pipeline_key_hash {
size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
/** Pad **/
struct ggml_webgpu_pad_pipeline_key {
bool circular;
@ -438,6 +458,8 @@ class ggml_webgpu_shader_lib {
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
row_norm_pipelines; // op/inplace
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
get_rows_pipelines; // src_type, vectorized
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
@ -482,6 +504,44 @@ class ggml_webgpu_shader_lib {
return sum_rows_pipelines[1];
}
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_row_norm_pipeline_key key = {
.op = context.dst->op,
.inplace = context.inplace,
};
auto it = row_norm_pipelines.find(key);
if (it != row_norm_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant;
switch (key.op) {
case GGML_OP_RMS_NORM:
defines.push_back("OP_RMS_NORM");
variant = "rms_norm";
break;
case GGML_OP_L2_NORM:
defines.push_back("OP_L2_NORM");
variant = "l2_norm";
break;
default:
GGML_ABORT("Unsupported op for row_norm shader");
}
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
return row_norm_pipelines[key];
}
webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
bool vec4 = context.src0->ne[0] % 4 == 0;

View File

@ -366,7 +366,6 @@ struct webgpu_context_struct {
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
@ -1598,8 +1597,8 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
int inplace = ggml_webgpu_tensor_equal(src, dst);
static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool inplace = ggml_webgpu_tensor_equal(src, dst);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
@ -1630,8 +1629,15 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
entries, ggml_nrows(src));
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.dst = dst,
.max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE,
.inplace = inplace,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src));
}
static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
@ -2192,7 +2198,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_REPEAT:
return ggml_webgpu_repeat(ctx, src0, node);
case GGML_OP_RMS_NORM:
return ggml_webgpu_rms_norm(ctx, src0, node);
case GGML_OP_L2_NORM:
return ggml_webgpu_row_norm(ctx, src0, node);
case GGML_OP_ROPE:
return ggml_webgpu_rope(ctx, src0, src1, src2, node);
case GGML_OP_GLU:
@ -2616,15 +2623,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
webgpu_ctx->rms_norm_pipelines[0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
}
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
@ -2909,7 +2907,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
@ -3120,6 +3117,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
}
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_ROPE:

View File

@ -0,0 +1,97 @@
#ifdef INPLACE
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
src[dst_offset] = scale * src[src_offset];
}
@group(0) @binding(1)
var<uniform> params: Params;
#else
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
dst[dst_offset] = scale * src[src_offset];
}
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#endif
struct Params {
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements)
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Shape of src/dst
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
eps: f32
};
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
var<workgroup> scratch: array<f32, WG_SIZE>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
// one thread per row
var i = wid.x;
let i3 = i / (params.ne2 * params.ne1);
i = i % (params.ne2 * params.ne1);
let i2 = i / params.ne1;
let i1 = i % params.ne1;
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
var sum = 0.0f;
var col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
sum += pow(src[i_src_row + col], 2.0);
col += WG_SIZE;
}
scratch[lid.x] = sum;
workgroupBarrier();
var offset: u32 = WG_SIZE / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
}
offset = offset / 2;
workgroupBarrier();
}
sum = scratch[0];
#ifdef OP_RMS_NORM
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
#elif OP_L2_NORM
let scale = 1.0/max(sqrt(sum), params.eps);
#endif
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
update(i_src_row + col, i_dst_row + col, scale);
col += WG_SIZE;
}
}