ggml-webgpu : extend GDN for K>1 (llama/23299)
This commit is contained in:
parent
28edd0cb36
commit
6090f39f36
|
|
@ -1234,6 +1234,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
|
|||
const uint32_t h = (uint32_t) src2->ne[1];
|
||||
const uint32_t n_tokens = (uint32_t) src2->ne[2];
|
||||
const uint32_t n_seqs = (uint32_t) src2->ne[3];
|
||||
const uint32_t K = (uint32_t) src5->ne[1];
|
||||
const float scale = 1.0f / sqrtf((float) s_v);
|
||||
uint32_t scale_u32;
|
||||
memcpy(&scale_u32, &scale, sizeof(scale_u32));
|
||||
|
|
@ -1258,6 +1259,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
|
|||
|
||||
(uint32_t) src0->ne[1],
|
||||
(uint32_t) (src2->ne[3] / src0->ne[3]),
|
||||
K,
|
||||
scale_u32,
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ struct Params {
|
|||
|
||||
neq1: u32,
|
||||
rq3: u32,
|
||||
K: u32,
|
||||
scale: f32,
|
||||
};
|
||||
|
||||
|
|
@ -62,11 +63,14 @@ fn main(
|
|||
let iq3 = seq_id / params.rq3;
|
||||
|
||||
let state_size = S_V * S_V;
|
||||
let state_base = (seq_id * params.h + head_id) * state_size;
|
||||
let state_in_base = (seq_id * params.K * params.h + head_id) * state_size;
|
||||
let state_out_base = (seq_id * params.h + head_id) * state_size;
|
||||
let state_size_per_snap = state_size * params.h * params.n_seqs;
|
||||
let shift = i32(params.n_tokens) - i32(params.K);
|
||||
|
||||
var state: array<f32, S_V>;
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
state[i] = src_state[state_base + col * S_V + i];
|
||||
state[i] = src_state[state_in_base + col * S_V + i];
|
||||
}
|
||||
|
||||
var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V;
|
||||
|
|
@ -123,10 +127,22 @@ fn main(
|
|||
dst[attn_off + col] = attn_col * params.scale;
|
||||
attn_off += S_V * params.h;
|
||||
|
||||
if (params.K > 1u) {
|
||||
let target_slot = i32(t) - shift;
|
||||
if (target_slot >= 0 && target_slot < i32(params.K)) {
|
||||
let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base;
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
dst[slot_base + col * S_V + i] = state[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
dst[params.s_off + state_base + col * S_V + i] = state[i];
|
||||
if (params.K == 1u) {
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
dst[params.s_off + state_out_base + col * S_V + i] = state[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue