ggml-webgpu : extend GDN for K>1 (llama/23299)

This commit is contained in:
Reese Levine 2026-05-18 23:45:41 -07:00 committed by Georgi Gerganov
parent 28edd0cb36
commit 6090f39f36
2 changed files with 22 additions and 4 deletions

View File

@@ -1234,6 +1234,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
const uint32_t h = (uint32_t) src2->ne[1];
const uint32_t n_tokens = (uint32_t) src2->ne[2];
const uint32_t n_seqs = (uint32_t) src2->ne[3];
const uint32_t K = (uint32_t) src5->ne[1];
const float scale = 1.0f / sqrtf((float) s_v);
uint32_t scale_u32;
memcpy(&scale_u32, &scale, sizeof(scale_u32));
@@ -1258,6 +1259,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
(uint32_t) src0->ne[1],
(uint32_t) (src2->ne[3] / src0->ne[3]),
K,
scale_u32,
};

View File

@@ -39,6 +39,7 @@ struct Params {
neq1: u32,
rq3: u32,
K: u32,
scale: f32,
};
@@ -62,11 +63,14 @@ fn main(
let iq3 = seq_id / params.rq3;
let state_size = S_V * S_V;
let state_base = (seq_id * params.h + head_id) * state_size;
let state_in_base = (seq_id * params.K * params.h + head_id) * state_size;
let state_out_base = (seq_id * params.h + head_id) * state_size;
let state_size_per_snap = state_size * params.h * params.n_seqs;
let shift = i32(params.n_tokens) - i32(params.K);
var state: array<f32, S_V>;
for (var i = 0u; i < S_V; i++) {
state[i] = src_state[state_base + col * S_V + i];
state[i] = src_state[state_in_base + col * S_V + i];
}
var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V;
@@ -123,10 +127,22 @@ fn main(
dst[attn_off + col] = attn_col * params.scale;
attn_off += S_V * params.h;
if (params.K > 1u) {
let target_slot = i32(t) - shift;
if (target_slot >= 0 && target_slot < i32(params.K)) {
let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base;
for (var i = 0u; i < S_V; i++) {
dst[slot_base + col * S_V + i] = state[i];
}
}
}
workgroupBarrier();
}
for (var i = 0u; i < S_V; i++) {
dst[params.s_off + state_base + col * S_V + i] = state[i];
if (params.K == 1u) {
for (var i = 0u; i < S_V; i++) {
dst[params.s_off + state_out_base + col * S_V + i] = state[i];
}
}
}