ggml-webgpu: add layer norm ops (llama/22406)
* shader(norm): add layer norm ops * shader(norm): stablize floating point computation with Kahan summation and handle mixed types * shader(norm): remove the non-contiguous strides * shader(norm): use the original implementation rather than the kahan summation
This commit is contained in:
parent
3bcac0a0c7
commit
d1d0dc2348
|
|
@ -228,11 +228,13 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
|
|||
/** Row Norm **/
|
||||
|
||||
struct ggml_webgpu_row_norm_pipeline_key {
|
||||
ggml_op op;
|
||||
bool inplace;
|
||||
ggml_op op;
|
||||
ggml_type src_type;
|
||||
ggml_type dst_type;
|
||||
bool inplace;
|
||||
|
||||
bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
|
||||
return op == other.op && inplace == other.inplace;
|
||||
return op == other.op && src_type == other.src_type && dst_type == other.dst_type && inplace == other.inplace;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -240,6 +242,8 @@ struct ggml_webgpu_row_norm_pipeline_key_hash {
|
|||
size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.op);
|
||||
ggml_webgpu_hash_combine(seed, key.src_type);
|
||||
ggml_webgpu_hash_combine(seed, key.dst_type);
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
return seed;
|
||||
}
|
||||
|
|
@ -1097,6 +1101,8 @@ class ggml_webgpu_shader_lib {
|
|||
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_row_norm_pipeline_key key = {};
|
||||
key.op = context.dst->op;
|
||||
key.src_type = context.src0->type;
|
||||
key.dst_type = context.dst->type;
|
||||
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
||||
|
||||
auto it = row_norm_pipelines.find(key);
|
||||
|
|
@ -1111,6 +1117,10 @@ class ggml_webgpu_shader_lib {
|
|||
defines.push_back("RMS_NORM");
|
||||
variant = "rms_norm";
|
||||
break;
|
||||
case GGML_OP_NORM:
|
||||
defines.push_back("NORM");
|
||||
variant = "norm";
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
defines.push_back("L2_NORM");
|
||||
variant = "l2_norm";
|
||||
|
|
@ -1124,6 +1134,22 @@ class ggml_webgpu_shader_lib {
|
|||
variant += "_inplace";
|
||||
}
|
||||
|
||||
if (key.src_type == GGML_TYPE_F32) {
|
||||
defines.push_back("SRC_F32");
|
||||
variant += "_src_f32";
|
||||
} else if (key.src_type == GGML_TYPE_F16) {
|
||||
defines.push_back("SRC_F16");
|
||||
variant += "_src_f16";
|
||||
}
|
||||
|
||||
if (key.dst_type == GGML_TYPE_F32) {
|
||||
defines.push_back("DST_F32");
|
||||
variant += "_dst_f32";
|
||||
} else if (key.dst_type == GGML_TYPE_F16) {
|
||||
defines.push_back("DST_F16");
|
||||
variant += "_dst_f16";
|
||||
}
|
||||
|
||||
const uint32_t row_norm_wg_size = 128u;
|
||||
uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
|
|
|
|||
|
|
@ -2927,6 +2927,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
|
|||
} else {
|
||||
return ggml_webgpu_row_norm(ctx, src0, node);
|
||||
}
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_webgpu_row_norm(ctx, src0, node);
|
||||
case GGML_OP_ROPE:
|
||||
|
|
@ -4071,6 +4072,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
break;
|
||||
}
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -1,20 +1,17 @@
|
|||
#ifdef INPLACE
|
||||
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
|
||||
src[dst_offset] = scale * src[src_offset];
|
||||
}
|
||||
#if defined(SRC_F16) || defined(DST_F16)
|
||||
enable f16;
|
||||
#endif
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
#ifdef SRC_F16
|
||||
#define SRC_TYPE f16
|
||||
#else
|
||||
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
|
||||
dst[dst_offset] = scale * src[src_offset];
|
||||
}
|
||||
#define SRC_TYPE f32
|
||||
#endif
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> dst: array<f32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
#ifdef DST_F16
|
||||
#define DST_TYPE f16
|
||||
#else
|
||||
#define DST_TYPE f32
|
||||
#endif
|
||||
|
||||
struct Params {
|
||||
|
|
@ -40,9 +37,20 @@ struct Params {
|
|||
};
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<f32>;
|
||||
var<storage, read_write> src: array<SRC_TYPE>;
|
||||
|
||||
var<workgroup> scratch: array<f32, WG_SIZE>;
|
||||
#ifdef INPLACE
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
#else
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> dst: array<DST_TYPE>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
#endif
|
||||
|
||||
var<workgroup> scratch: array<f32, WG_SIZE * 2u>;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
||||
|
|
@ -65,34 +73,81 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
|||
if (col >= params.ne0) {
|
||||
break;
|
||||
}
|
||||
sum += pow(src[i_src_row + col], 2.0);
|
||||
let v = f32(src[i_src_row + col]);
|
||||
#ifdef NORM
|
||||
sum += v;
|
||||
#else
|
||||
sum += v * v;
|
||||
#endif
|
||||
col += WG_SIZE;
|
||||
}
|
||||
|
||||
scratch[lid.x] = sum;
|
||||
workgroupBarrier();
|
||||
var offset: u32 = WG_SIZE / 2;
|
||||
|
||||
var offset: u32 = WG_SIZE / 2u;
|
||||
while (offset > 0) {
|
||||
if (lid.x < offset) {
|
||||
scratch[lid.x] += scratch[lid.x + offset];
|
||||
}
|
||||
offset = offset / 2;
|
||||
offset /= 2u;
|
||||
workgroupBarrier();
|
||||
}
|
||||
sum = scratch[0];
|
||||
|
||||
#ifdef RMS_NORM
|
||||
#ifdef NORM
|
||||
let mean = sum / f32(params.ne0);
|
||||
var sq_sum = 0.0f;
|
||||
col = lid.x;
|
||||
for (var j: u32 = 0; j < elems; j++) {
|
||||
if (col >= params.ne0) {
|
||||
break;
|
||||
}
|
||||
let v = f32(src[i_src_row + col]);
|
||||
let d = v - mean;
|
||||
sq_sum += d * d;
|
||||
col += WG_SIZE;
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
scratch[lid.x] = sq_sum;
|
||||
workgroupBarrier();
|
||||
offset = WG_SIZE / 2u;
|
||||
while (offset > 0) {
|
||||
if (lid.x < offset) {
|
||||
scratch[lid.x] += scratch[lid.x + offset];
|
||||
}
|
||||
offset /= 2u;
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
let variance = scratch[0] / f32(params.ne0);
|
||||
let scale = 1.0 / sqrt(variance + params.eps);
|
||||
#elif defined(RMS_NORM)
|
||||
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
|
||||
#elif defined(L2_NORM)
|
||||
let scale = 1.0/max(sqrt(sum), params.eps);
|
||||
#endif
|
||||
|
||||
#ifdef NORM
|
||||
let mean_val = mean;
|
||||
#else
|
||||
let mean_val = 0.0f;
|
||||
#endif
|
||||
|
||||
col = lid.x;
|
||||
for (var j: u32 = 0; j < elems; j++) {
|
||||
if (col >= params.ne0) {
|
||||
break;
|
||||
}
|
||||
update(i_src_row + col, i_dst_row + col, scale);
|
||||
let i_src = i_src_row + col;
|
||||
let i_dst = i_dst_row + col;
|
||||
let v = src[i_src];
|
||||
#ifdef INPLACE
|
||||
src[i_dst] = scale * (v - mean_val);
|
||||
#else
|
||||
dst[i_dst] = scale * (v - mean_val);
|
||||
#endif
|
||||
col += WG_SIZE;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue