diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3357a0d99..41566d41a 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2541,6 +2541,11 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 + // + // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs): + // K == 1: output carries the final state only. + // K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K) + // per-token snapshots into the trailing slots GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index c0ffd9a04..df0f405ed 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -753,7 +753,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); - GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2); + // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, + // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; }; @@ -2140,4 +2142,3 @@ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, siz const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; return backend_ctx->backend_configs[index].backend; } - diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7b05edf6b..cd5c61a81 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2943,7 +2943,9 @@ struct ggml_cplan ggml_graph_plan( case GGML_OP_GATED_DELTA_NET: { const int64_t S_v = node->src[2]->ne[0]; - cur = S_v * sizeof(float) * n_tasks; + const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs) + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); + cur = per_thread * sizeof(float) * n_tasks; } break; case GGML_OP_COUNT: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6bc8dc150..7485ba4fc 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10513,19 +10513,30 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const bool kda = (neg0 == S_v); - // scratch layout per thread: [delta(S_v)] - const int64_t scratch_per_thread = S_v; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int64_t K = src_state->ne[1]; + GGML_ASSERT(K >= 1); + // per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[2] / sizeof(float); + + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); const int ith = params->ith; - float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32; + float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32; + float * state_work = K > 1 ? (delta + S_v) : nullptr; // output layout: [attn_scores | new_states] - // attn_scores: S_v * H * n_tokens * n_seqs floats - // new_states: S_v * S_v * H * n_seqs floats - const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K)) + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H * n_seqs; float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int64_t shift = n_tokens - K; + const float * state_in_base = (const float *)src_state->data; //const int64_t rq1 = nev1 / neq1; @@ -10545,10 +10556,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const int64_t iq3 = iv3 / rq3; const int64_t ik3 = iv3 / rk3; - float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v; + // For K=1, write directly to the single output slot to avoid an extra memcpy at the end. + // For K>1, work in scratch and copy out per-token when the slot is in range. + float * s_out = (K > 1) + ? state_work + : state_out_base + (iv3 * H + iv1) * S_v * S_v; - // copy input state into output buffer and operate in-place - const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; + // copy input state into the working buffer and operate in-place + // state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride. + const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); // attn output pointer for first token of this (head, seq) @@ -10598,6 +10614,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( } attn_data += S_v * H; // advance to next token + + if (K > 1) { + const int64_t target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state_o = state_out_base + target_slot * state_size_per_snap + + (iv3 * H + iv1) * S_v * S_v; + memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float)); + } + } } } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 6b44bec73..b4c9845e7 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,6 +1,6 @@ #include "gated_delta_net.cuh" -template +template __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) gated_delta_net_cuda(const float * q, const float * k, @@ -23,7 +23,8 @@ gated_delta_net_cuda(const float * q, int64_t sb3, const uint3 neqk1_magic, const uint3 rq3_magic, - float scale) { + float scale, + int K) { const uint32_t h_idx = blockIdx.x; const uint32_t sequence = blockIdx.y; // each warp owns one column, using warp-level primitives to reduce across rows @@ -37,9 +38,13 @@ gated_delta_net_cuda(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; - state += state_offset; - curr_state += state_offset + col * S_v; + // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + state += state_out_offset; + curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; @@ -54,6 +59,10 @@ gated_delta_net_cuda(const float * q, s_shard[r] = curr_state[i]; } + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). + const int shift = (int) n_tokens - K; + for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -135,17 +144,30 @@ gated_delta_net_cuda(const float * q, } attn_data += S_v * H; + + if constexpr (keep_rs_t) { + const int target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } } - // Write state back to global memory (transposed layout) + if constexpr (!keep_rs_t) { #pragma unroll - for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - state[col * S_v + i] = s_shard[r]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } } } -template +template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, @@ -155,7 +177,7 @@ static void launch_gated_delta_net( int64_t sv1, int64_t sv2, int64_t sv3, int64_t sb1, int64_t sb2, int64_t sb3, int64_t neqk1, int64_t rq3, - float scale, cudaStream_t stream) { + float scale, int K, cudaStream_t stream) { //TODO: Add chunked kernel for even faster pre-fill const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const int num_warps = 4; @@ -169,29 +191,29 @@ static void launch_gated_delta_net( switch (S_v) { case 16: - gated_delta_net_cuda<16, KDA><<>>( + gated_delta_net_cuda<16, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 32: - gated_delta_net_cuda<32, KDA><<>>( + gated_delta_net_cuda<32, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 64: { - gated_delta_net_cuda<64, KDA><<>>( + gated_delta_net_cuda<64, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } case 128: { - gated_delta_net_cuda<128, KDA><<>>( + gated_delta_net_cuda<128, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } default: @@ -261,13 +283,29 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = (int) src_state->ne[1]; + const bool keep_rs = K > 1; + if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index f0147af84..e288a27f9 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -590,6 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( const int ne20 = op->src[2]->ne[0]; // S_v const int ne21 = op->src[2]->ne[1]; // H const int ne30 = op->src[3]->ne[0]; // G + // state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = op->src[5]->ne[1]; const int nsg = op->src[2]->ne[0]/32; @@ -598,7 +600,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( GGML_ASSERT(ne20 % 32 == 0); snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); - snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30); + snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -606,6 +608,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2d45de8cc..f6ffb2b3a 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2531,6 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32( constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; +constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]]; #if 1 template @@ -2548,21 +2549,24 @@ kernel void kernel_gated_delta_net_impl( uint3 ntg[[threads_per_threadgroup]]) { #define S_v FC_gated_delta_net_ne20 #define G FC_gated_delta_net_ne30 +#define K FC_gated_delta_net_K const uint tx = tpitg.x; const uint ty = tpitg.y; - const uint i23 = tgpig.z; // B - const uint i21 = tgpig.y; // H - const uint i20 = tgpig.x*NSG + ty; + const uint i23 = tgpig.z; // B (n_seqs) + const uint i21 = tgpig.y; // H (head) + const uint i20 = tgpig.x*NSG + ty; // row within S_v const uint i01 = i21 % args.ne01; const uint i11 = i21 % args.ne11; const float scale = 1.0f / sqrt((float)S_v); + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous - device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; + device const float * s_ptr = (device const float *) (s) + state_in_base; float ls[NSG]; @@ -2580,6 +2584,17 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = (int)args.ne22 - (int)K; + + // output state base offset: after attention scores + const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; + // output state per-slot size: S_v * S_v * H * n_seqs + const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23; + // per-(seq,head) offset within a slot + const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + for (short t = 0; t < args.ne22; t++) { float s_k = 0.0f; @@ -2627,17 +2642,30 @@ kernel void kernel_gated_delta_net_impl( b_ptr += args.ne21; g_ptr += args.ne21*G; + + if (K > 1u) { + const int target_slot = (int)t - shift; + if (target_slot >= 0 && target_slot < (int)K) { + device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } + } + } } - device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; - - FOR_UNROLL (short j = 0; j < NSG; j++) { - const short is = tx*NSG + j; - dst_state[is] = ls[j]; + if (K == 1u) { + device float * dst_state = (device float *) (dst) + attn_size + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } } #undef S_v #undef G +#undef K } typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8c4cf9ef1..d29a4bab2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1506,6 +1506,7 @@ struct vk_op_gated_delta_net_push_constants { uint32_t sb1, sb2, sb3; uint32_t neq1, rq3; float scale; + uint32_t K; }; struct vk_op_ssm_scan_push_constants { @@ -10767,6 +10768,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const ggml_tensor * src_q = dst->src[0]; const ggml_tensor * src_v = dst->src[2]; const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; GGML_ASSERT(dst->buffer != nullptr); @@ -10775,6 +10777,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_tokens = (uint32_t)src_v->ne[2]; const uint32_t n_seqs = (uint32_t)src_v->ne[3]; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const uint32_t K = (uint32_t)src_state->ne[1]; + const uint32_t s_off = S_v * H * n_tokens * n_seqs; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); @@ -10808,7 +10813,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s sv1, sv2, sv3, sb1, sb2, sb3, neq1, rq3, - scale + scale, + K }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 5e9f8308c..33c3202db 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -31,6 +31,7 @@ layout(push_constant) uniform Parameters { uint sb1, sb2, sb3; uint neq1, rq3; float scale; + uint K; }; layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; }; @@ -101,13 +102,21 @@ void main() { const uint iq3 = seq_id / rq3; const uint state_size = S_V * S_V; - const uint state_base = (seq_id * H + head_id) * state_size; + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. + const uint state_in_base = (seq_id * K * H + head_id) * state_size; + // output state layout per slot: same per-(seq,head) offset as the single-slot case. + const uint state_out_base = (seq_id * H + head_id) * state_size; + const uint state_size_per_snap = state_size * H * n_seqs; FLOAT_TYPE s_shard[ROWS_PER_LANE]; [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]); + s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); } + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = int(n_tokens) - int(K); + uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; for (uint t = 0; t < n_tokens; t++) { @@ -161,9 +170,21 @@ void main() { } attn_off += S_V * H; + + if (K > 1u) { + const int target_slot = int(t) - shift; + if (target_slot >= 0 && target_slot < int(K)) { + const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } + } + } } - [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + if (K == 1u) { + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_out_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 191cf2fa1..476c30797 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6210,11 +6210,13 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); GGML_ASSERT(beta->ne[0] == 1); - GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); - - // concat output and new_state into a single tensor - // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs - const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 }; + // state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count. + GGML_ASSERT(state->ne[0] == S_v * S_v * H); + GGML_ASSERT(state->ne[2] == n_seqs); + GGML_ASSERT(state->ne[3] == 1); + const int64_t K = state->ne[1]; + const int64_t state_rows = K * S_v * n_seqs; + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); result->op = GGML_OP_GATED_DELTA_NET;