From 1a1900f90c165a078d66b4958db46fe82d14ff27 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Wed, 10 Jun 2026 23:21:16 +0530 Subject: [PATCH] Remove padding and multiple D2D copies for MTP (llama/24086) * Make ggml_gated_delta_net take only the initial recurrent state (D, 1, n_seqs) and passes the snapshot count K as an op parameter instead of inferring it from state->ne[1]. Remove the padding hack and copy all emitted snapshots into the recurrent cache with a single strided ggml_cpy * Make GDN changes in all backends. Address review comments. * Fix CI build errors --- ggml/include/ggml.h | 17 +++++++---- ggml/src/ggml-backend-meta.cpp | 4 +-- ggml/src/ggml-cpu/ggml-cpu.c | 2 +- ggml/src/ggml-cpu/ops.cpp | 17 +++++------ ggml/src/ggml-cuda/gated_delta_net.cu | 16 +++++----- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 5 ++-- .../ggml-hexagon/htp/gated-delta-net-ops.c | 29 ++++++++++--------- ggml/src/ggml-metal/ggml-metal-device.cpp | 4 +-- ggml/src/ggml-metal/ggml-metal.metal | 11 ++++--- ggml/src/ggml-opencl/ggml-opencl.cpp | 2 +- .../ggml-opencl/kernels/gated_delta_net.cl | 8 +++-- ggml/src/ggml-sycl/gated_delta_net.cpp | 15 +++++----- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 ++--- .../vulkan-shaders/gated_delta_net.comp | 11 ++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- .../wgsl-shaders/gated_delta_net.wgsl | 7 +++-- ggml/src/ggml.c | 16 ++++++---- 17 files changed, 93 insertions(+), 81 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 374934aac..d6807b6dd 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2553,10 +2553,16 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 // - // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs): - // K == 1: output carries the final state only. - // K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K) - // per-token snapshots into the trailing slots + // tensor shapes (S_k == S_v, H_v % H_k == 0): + // q, k : [S_k, H_k, n_tokens, n_seqs] + // v : [S_v, H_v, n_tokens, n_seqs] + // g : [1, H_v, n_tokens, n_seqs] (scalar gate) or [S_v, H_v, n_tokens, n_seqs] (KDA) + // beta : [1, H_v, n_tokens, n_seqs] + // state : [S_v, S_v, H_v, n_seqs] -- initial recurrent state s0 + // + // the output packs the attention scores [S_v, H_v, n_tokens, n_seqs] followed by K state + // snapshots, most-recent first (slot 0 = final state, slot s = state s tokens back). K == 1 + // keeps only the final state; when n_tokens < K only slots 0..n_tokens-1 are written. GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, @@ -2564,7 +2570,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state); + struct ggml_tensor * state, + int64_t K); // custom operators diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 8c44c3e44..0a36f0990 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -776,8 +776,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); - // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, - // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). + // state shape is [S_v, S_v, H_v, n_seqs] (s0 only); the heads dim is its own axis 2, + // so a head-aligned split on the input cache lands on axis 2 here. GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; }; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index af7827aec..eb8341c9a 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2948,7 +2948,7 @@ struct ggml_cplan ggml_graph_plan( case GGML_OP_GATED_DELTA_NET: { const int64_t S_v = node->src[2]->ne[0]; - const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs) + const int64_t K = ggml_get_op_params_i32(node, 0); const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); cur = per_thread * sizeof(float) * n_tasks; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 86842e554..74611dce7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10624,11 +10624,11 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const bool kda = (neg0 == S_v); - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int64_t K = src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int64_t K = ggml_get_op_params_i32(dst, 0); GGML_ASSERT(K >= 1); - // per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride) - const int64_t state_seq_stride = src_state->nb[2] / sizeof(float); + // per-seq stride in floats (seq s starts at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[3] / sizeof(float); const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); const int ith = params->ith; @@ -10644,9 +10644,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; - // snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last - // n_tokens slots are written; earlier slots are left untouched (caller-owned). - const int64_t shift = n_tokens - K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. const float * state_in_base = (const float *)src_state->data; @@ -10674,7 +10673,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( : state_out_base + (iv3 * H + iv1) * S_v * S_v; // copy input state into the working buffer and operate in-place - // state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride. + // state layout [S_v, S_v, H, n_seqs]: seq iv3 starts at iv3 * state_seq_stride. const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); @@ -10727,7 +10726,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( attn_data += S_v * H; // advance to next token if (K > 1) { - const int64_t target_slot = t - shift; + const int64_t target_slot = n_tokens - 1 - t; if (target_slot >= 0 && target_slot < K) { float * curr_state_o = state_out_base + target_slot * state_size_per_snap + (iv3 * H + iv1) * S_v * S_v; diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 7cfda6523..a547360eb 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -39,9 +39,9 @@ gated_delta_net_cuda(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v. // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. - const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v; const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; state += state_out_offset; curr_state += state_in_offset + col * S_v; @@ -143,12 +143,10 @@ gated_delta_net_cuda(const float * q, attn_data += S_v * H; if constexpr (keep_rs_t) { - // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots - // are written; earlier slots are left untouched (caller-owned). - const int shift = (int) n_tokens - K; - + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output - const int target_slot = t - shift; + const int target_slot = (int) n_tokens - 1 - t; if (target_slot >= 0 && target_slot < K) { float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; #pragma unroll @@ -286,8 +284,8 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int K = (int) src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int K = ggml_get_op_params_i32(dst, 0); const bool keep_rs = K > 1; if (kda) { diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index d550841a2..49bd7e433 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2538,7 +2538,7 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses const int64_t H = v->ne[1]; const int64_t n_tokens = v->ne[2]; const int64_t n_seqs = v->ne[3]; - const int64_t K = state->ne[1]; + const int64_t K = ggml_get_op_params_i32(op, 0); if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) { return false; @@ -2551,7 +2551,8 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { return false; } - if (ggml_nelements(state) != S_v * S_v * H * n_seqs * K) { + // state holds s0 only [S_v, S_v, H, n_seqs]; K is op param 0. + if (ggml_nelements(state) != S_v * S_v * H * n_seqs) { return false; } if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c index 3b092d744..35518e611 100644 --- a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -584,7 +584,7 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; - const uint32_t K = state->ne[1]; + const uint32_t K = octx->op_params[0]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -618,9 +618,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); - const uint64_t state_seq_stride = state->nb[2] / sizeof(float); + const uint64_t state_seq_stride = state->nb[3] / sizeof(float); const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; - const int64_t shift = (int64_t) n_tokens - (int64_t) K; uint32_t ir_prefetch = ith; int spad_idx = 0; @@ -630,7 +629,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; - float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v; // Push dummy write-back dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), @@ -661,7 +661,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t iq3 = fastdiv(iv3, &fd_rq3); const uint32_t ik3 = fastdiv(iv3, &fd_rk3); - float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v; @@ -792,7 +793,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo } if (K > 1) { - const int64_t target_slot = (int64_t) t - shift; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + const int64_t target_slot = (int64_t) n_tokens - 1 - (int64_t) t; if (target_slot >= 0 && target_slot < (int64_t) K) { float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; if (curr_state_o != s_out) { @@ -844,7 +846,6 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t S_v = v->ne[0]; const uint32_t H = v->ne[1]; const uint32_t n_seqs = v->ne[3]; - const uint32_t K = state->ne[1]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -878,8 +879,7 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); - const uint64_t state_seq_stride = state->nb[2] / sizeof(float); - const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + const uint64_t state_seq_stride = state->nb[3] / sizeof(float); uint32_t ir_prefetch = ith; int spad_idx = 0; @@ -889,7 +889,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; - float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v; // Push dummy write-back dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), @@ -920,7 +921,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t iq3 = fastdiv(iv3, &fd_rq3); const uint32_t ik3 = fastdiv(iv3, &fd_rk3); - float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v; @@ -1097,7 +1099,7 @@ int op_gated_delta_net(struct htp_ops_context * octx) { const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; - const uint32_t K = state->ne[1]; + const uint32_t K = octx->op_params[0]; if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) { return HTP_STATUS_NO_SUPPORT; @@ -1110,7 +1112,8 @@ int op_gated_delta_net(struct htp_ops_context * octx) { (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { return HTP_STATUS_NO_SUPPORT; } - if (state->ne[0] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { + // state holds s0 only: [S_v, S_v, H, n_seqs] + if (state->ne[0] != S_v || state->ne[1] != S_v || state->ne[2] != H || state->ne[3] != n_seqs) { return HTP_STATUS_NO_SUPPORT; } if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index ce847dd8b..4f4f073cb 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -590,8 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( const int ne20 = op->src[2]->ne[0]; // S_v const int ne21 = op->src[2]->ne[1]; // H const int ne30 = op->src[3]->ne[0]; // G - // state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int K = op->src[5]->ne[1]; + // state is src[5], 4D [S_v, S_v, H_v, n_seqs] (s0 only); K is op param 0. + const int K = ggml_get_op_params_i32(op, 0); const int nsg = op->src[2]->ne[0]/32; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2bd310d94..0aea68455 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2599,9 +2599,9 @@ kernel void kernel_gated_delta_net_impl( const float scale = 1.0f / sqrt((float)S_v); - // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. + // input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D. // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous - const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; + const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; device const float * s_ptr = (device const float *) (s) + state_in_base; float ls[NSG]; @@ -2620,9 +2620,8 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; - // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last - // n_tokens slots are written; earlier slots are left untouched (caller-owned). - const int shift = (int)args.ne22 - (int)K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned. // output state base offset: after attention scores const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; @@ -2680,7 +2679,7 @@ kernel void kernel_gated_delta_net_impl( g_ptr += args.ne21*G; if (K > 1) { - const int target_slot = (int)t - shift; + const int target_slot = (int)args.ne22 - 1 - (int)t; if (target_slot >= 0 && target_slot < (int)K) { device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; FOR_UNROLL (short j = 0; j < NSG; j++) { diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 2a41215fd..d30579b94 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -17750,7 +17750,7 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, ggml_tensor * dst) { const cl_uint H_v = (cl_uint) src_v->ne[1]; const cl_uint n_tokens = (cl_uint) src_v->ne[2]; const cl_uint n_seqs = (cl_uint) src_v->ne[3]; - const cl_uint K = (cl_uint) src_state->ne[1]; + const cl_uint K = (cl_uint) ggml_get_op_params_i32(dst, 0); int si; switch (S_v) { diff --git a/ggml/src/ggml-opencl/kernels/gated_delta_net.cl b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl index d11192f58..319c98295 100644 --- a/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +++ b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl @@ -123,7 +123,8 @@ kernel void kernel_gated_delta_net( const uint iq3 = seq_id / rq3; // seq index for Q and K const uint state_size = S_V * S_V; - const uint state_base = (seq_id * K * H_v + head_id) * state_size; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + const uint state_base = (seq_id * H_v + head_id) * state_size; const uint q_off_base = iq3 * sq3 + iq1 * sq1; const uint v_off_base = seq_id * sv3 + head_id * sv1; const uint gb_off_base = seq_id * sb3 + head_id * sb1; @@ -143,7 +144,8 @@ kernel void kernel_gated_delta_net( } } - const int shift = (int)n_tokens - (int)K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. uint attn_off = (seq_id * n_tokens * H_v + head_id) * S_V; for (uint t = 0; t < n_tokens; t++) { @@ -219,7 +221,7 @@ kernel void kernel_gated_delta_net( attn_off += S_V * H_v; if (K > 1u) { - const int target_slot = (int)t - shift; + const int target_slot = (int)n_tokens - 1 - (int)t; if (target_slot >= 0 && target_slot < (int)K) { #pragma unroll for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index 9c2449aba..239e00bd7 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -44,9 +44,9 @@ void gated_delta_net_sycl(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // input state holds s0 only [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v. // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. - const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v; const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output state += state_out_offset; @@ -63,9 +63,8 @@ void gated_delta_net_sycl(const float * q, s_shard[r] = curr_state[i]; } - // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots - // are written; earlier slots are left untouched (caller-owned). - const int shift = (int) n_tokens - K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -144,7 +143,7 @@ void gated_delta_net_sycl(const float * q, // Write state back to global memory if constexpr (keep_rs_t) { - const int target_slot = t - shift; + const int target_slot = (int) n_tokens - 1 - t; if (target_slot >= 0 && target_slot < K) { float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; #pragma unroll @@ -315,8 +314,8 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dpct::queue_ptr stream = ctx.stream(); - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int K = (int) src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int K = ggml_get_op_params_i32(dst, 0); const bool keep_rs = K > 1; if (kda) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 22405f234..387826b6d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -11528,7 +11528,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const ggml_tensor * src_q = dst->src[0]; const ggml_tensor * src_v = dst->src[2]; const ggml_tensor * src_beta = dst->src[4]; - const ggml_tensor * src_state = dst->src[5]; GGML_ASSERT(dst->buffer != nullptr); @@ -11537,8 +11536,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_tokens = (uint32_t)src_v->ne[2]; const uint32_t n_seqs = (uint32_t)src_v->ne[3]; - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const uint32_t K = (uint32_t)src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const uint32_t K = (uint32_t)ggml_get_op_params_i32(dst, 0); const uint32_t s_off = S_v * H * n_tokens * n_seqs; @@ -17954,7 +17953,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * src_clone[4], src_clone[5], src_clone[6]); } else if (tensor->op == GGML_OP_GATED_DELTA_NET) { tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1], - src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + src_clone[2], src_clone[3], src_clone[4], src_clone[5], + ggml_get_op_params_i32(tensor, 0)); } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { src_clone[0]->flags = tensor->src[0]->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 33c3202db..0e384330b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -102,8 +102,8 @@ void main() { const uint iq3 = seq_id / rq3; const uint state_size = S_V * S_V; - // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. - const uint state_in_base = (seq_id * K * H + head_id) * state_size; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + const uint state_in_base = (seq_id * H + head_id) * state_size; // output state layout per slot: same per-(seq,head) offset as the single-slot case. const uint state_out_base = (seq_id * H + head_id) * state_size; const uint state_size_per_snap = state_size * H * n_seqs; @@ -113,9 +113,8 @@ void main() { s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); } - // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last - // n_tokens slots are written; earlier slots are left untouched (caller-owned). - const int shift = int(n_tokens) - int(K); + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned. uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; @@ -172,7 +171,7 @@ void main() { attn_off += S_V * H; if (K > 1u) { - const int target_slot = int(t) - shift; + const int target_slot = int(n_tokens) - 1 - int(t); if (target_slot >= 0 && target_slot < int(K)) { const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 538e587bb..0b605fa86 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1245,7 +1245,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, const uint32_t h = (uint32_t) src2->ne[1]; const uint32_t n_tokens = (uint32_t) src2->ne[2]; const uint32_t n_seqs = (uint32_t) src2->ne[3]; - const uint32_t K = (uint32_t) src5->ne[1]; + const uint32_t K = (uint32_t) ggml_get_op_params_i32(dst, 0); const float scale = 1.0f / sqrtf((float) s_v); uint32_t scale_u32; memcpy(&scale_u32, &scale, sizeof(scale_u32)); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl index d68520f82..7d7b34755 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -63,10 +63,10 @@ fn main( let iq3 = seq_id / params.rq3; let state_size = S_V * S_V; - let state_in_base = (seq_id * params.K * params.h + head_id) * state_size; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + let state_in_base = (seq_id * params.h + head_id) * state_size; let state_out_base = (seq_id * params.h + head_id) * state_size; let state_size_per_snap = state_size * params.h * params.n_seqs; - let shift = i32(params.n_tokens) - i32(params.K); var state: array; for (var i = 0u; i < S_V; i++) { @@ -128,7 +128,8 @@ fn main( attn_off += S_V * params.h; if (params.K > 1u) { - let target_slot = i32(t) - shift; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + let target_slot = i32(params.n_tokens) - 1 - i32(t); if (target_slot >= 0 && target_slot < i32(params.K)) { let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base; for (var i = 0u; i < S_V; i++) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 18a5ebd2a..b43016c87 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6223,7 +6223,8 @@ struct ggml_tensor * ggml_gated_delta_net( struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state) { + struct ggml_tensor * state, + int64_t K) { GGML_ASSERT(ggml_is_contiguous_rows(q)); GGML_ASSERT(ggml_is_contiguous_rows(k)); GGML_ASSERT(ggml_is_contiguous_rows(v)); @@ -6247,15 +6248,18 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); GGML_ASSERT(beta->ne[0] == 1); - // state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count. - GGML_ASSERT(state->ne[0] == S_v * S_v * H); - GGML_ASSERT(state->ne[2] == n_seqs); - GGML_ASSERT(state->ne[3] == 1); - const int64_t K = state->ne[1]; + // state holds the initial state s0 only: [S_v, S_v, H, n_seqs]. K (snapshot slot count) is an op param. + GGML_ASSERT(state->ne[0] == S_v); + GGML_ASSERT(state->ne[1] == S_v); + GGML_ASSERT(state->ne[2] == H); + GGML_ASSERT(state->ne[3] == n_seqs); + GGML_ASSERT(K >= 1); const int64_t state_rows = K * S_v * n_seqs; const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + ggml_set_op_params_i32(result, 0, (int32_t) K); + result->op = GGML_OP_GATED_DELTA_NET; result->src[0] = q; result->src[1] = k;