diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index 20caaa99..549842a2 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -29,7 +29,7 @@ if (WHISPER_SDL2) llama-model-saver.cpp llama-model.cpp llama-quant.cpp - llama-sampling.cpp + llama-sampler.cpp llama-vocab.cpp unicode.cpp unicode-data.cpp diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index a54bc195..bd78f1e5 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -117,9 +117,11 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, - { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_STEP35, "step35" }, { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, + { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -161,6 +163,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, "%s.expert_chunk_feed_forward_length" }, + { LLM_KV_SWIGLU_CLAMP_EXP, "%s.swiglu_clamp_exp" }, + { LLM_KV_SWIGLU_CLAMP_SHEXP, "%s.swiglu_clamp_shexp" }, { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, @@ -219,21 +223,21 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, - { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, - { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, - { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, - { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, - { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, - { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, - { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, - { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, + { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -246,6 +250,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_KDA_HEAD_DIM, "%s.kda.head_dim" }, + { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, { LLM_KV_POSNET_EMBEDDING_LENGTH, "%s.posnet.embedding_length" }, @@ -371,6 +377,15 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_SSM_CONV1D_Q, "blk.%d.ssm_conv1d_q" }, + { LLM_TENSOR_SSM_CONV1D_K, "blk.%d.ssm_conv1d_k" }, + { LLM_TENSOR_SSM_CONV1D_V, "blk.%d.ssm_conv1d_v" }, + { LLM_TENSOR_SSM_F_A, "blk.%d.ssm_f_a" }, + { LLM_TENSOR_SSM_F_B, "blk.%d.ssm_f_b" }, + { LLM_TENSOR_SSM_BETA, "blk.%d.ssm_beta" }, + { LLM_TENSOR_SSM_G_A, "blk.%d.ssm_g_a" }, + { LLM_TENSOR_SSM_G_B, "blk.%d.ssm_g_b" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, @@ -2267,6 +2282,35 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP_EXPS, LLM_TENSOR_FFN_EXP_PROBS_B, }; + case LLM_ARCH_STEP35: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; case LLM_ARCH_GPTJ: case LLM_ARCH_UNKNOWN: return { @@ -2289,6 +2333,54 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, }; + case LLM_ARCH_KIMI_LINEAR: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + // Dense FFN (layer 0 only) + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + // MoE FFN (layers 1+) + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_EXP_PROBS_B, + // Shared experts + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + // KDA (using SSM_ enum prefix, keeping GGUF names for backward compat) + LLM_TENSOR_SSM_CONV1D_Q, + LLM_TENSOR_SSM_CONV1D_K, + LLM_TENSOR_SSM_CONV1D_V, + LLM_TENSOR_SSM_F_A, + LLM_TENSOR_SSM_F_B, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_G_A, + LLM_TENSOR_SSM_G_B, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_NORM, + // MLA + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_KV_A_NORM, + }; default: GGML_ABORT("unknown architecture for tensor mapping"); } @@ -2392,6 +2484,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // Kimi KDA - Conv tensors are 4D [d_conv, 1, d_inner, 1], reshaped to 2D at runtime + {LLM_TENSOR_SSM_CONV1D_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_F_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_F_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_BETA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2573,6 +2674,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_KIMI_LINEAR: return true; default: return false; diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 270d28b1..e8263369 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -122,8 +122,10 @@ enum llm_arch { LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, LLM_ARCH_MIMO2, + LLM_ARCH_STEP35, LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, + LLM_ARCH_KIMI_LINEAR, LLM_ARCH_UNKNOWN, }; @@ -165,6 +167,8 @@ enum llm_kv { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, + LLM_KV_SWIGLU_CLAMP_EXP, + LLM_KV_SWIGLU_CLAMP_SHEXP, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, @@ -250,6 +254,8 @@ enum llm_kv { LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, + LLM_KV_KDA_HEAD_DIM, + LLM_KV_WKV_HEAD_SIZE, LLM_KV_TOKENIZER_MODEL, @@ -398,6 +404,15 @@ enum llm_tensor { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next + // Kimi Linear KDA (using SSM_ prefix for consistency) + LLM_TENSOR_SSM_CONV1D_Q, // kimi: Q conv1d weight + LLM_TENSOR_SSM_CONV1D_K, // kimi: K conv1d weight + LLM_TENSOR_SSM_CONV1D_V, // kimi: V conv1d weight + LLM_TENSOR_SSM_F_A, // kimi: forget gate projection A + LLM_TENSOR_SSM_F_B, // kimi: forget gate projection B + LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient + LLM_TENSOR_SSM_G_A, // kimi: output gate projection A + LLM_TENSOR_SSM_G_B, // kimi: output gate projection B LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 3c7e0afd..c415a998 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -233,7 +233,7 @@ int32_t llm_chat_apply_template( llm_chat_template tmpl, const std::vector & chat, std::string & dest, bool add_ass) { - // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 + // Taken from the research: https://github.com/ggml-org/llama.cpp/issues/5527 std::stringstream ss; if (tmpl == LLM_CHAT_TEMPLATE_CHATML) { // chatml template diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 10b306a8..a6df893a 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -317,6 +317,7 @@ llama_context::llama_context( auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { // ignore CPU backend + // TODO: should we ignore ACCEL types too? continue; } auto * dev = ggml_backend_get_device(backend.get()); @@ -1026,11 +1027,7 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { llama_sampler_chain_n(sampler) > 0; if (sampler && can_offload) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output()); - auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output()); - if (host_buft) { - buft = host_buft; - } + auto * buft = ggml_backend_dev_buffer_type(model.dev_output()); sampler->iface->backend_init(sampler, buft); @@ -2016,7 +2013,7 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { - if (model.arch == LLM_ARCH_QWEN3NEXT) { + if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } uint32_t res = std::max(1024u, 8u*model.n_tensors()); diff --git a/examples/talk-llama/llama-grammar.cpp b/examples/talk-llama/llama-grammar.cpp index 64ea2fd0..2d55070c 100644 --- a/examples/talk-llama/llama-grammar.cpp +++ b/examples/talk-llama/llama-grammar.cpp @@ -2,7 +2,7 @@ #include "llama-impl.h" #include "llama-vocab.h" -#include "llama-sampling.h" +#include "llama-sampler.h" #include #include diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 16d42c4a..bba747d3 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -13,6 +13,8 @@ #include #include #include +#include +#include #include void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { @@ -533,6 +535,50 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { return res; } +// TODO: Hybrid input classes are a bit redundant. +// Instead of creating a hybrid input, the graph can simply create 2 separate inputs. +// Refactoring is required in the future. +void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) { + mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + + res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { const auto * attn_ctx = mctx->get_attn(); @@ -970,6 +1016,26 @@ ggml_tensor * llm_graph_context::build_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate && type_gate == LLM_FFN_PAR) { + // Step35: HF clamps gate (after SiLU) and up before multiplication + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_shexp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_silu_clamped", il); + + tmp = ggml_clamp(ctx0, tmp, -limit, limit); + cb(tmp, "ffn_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, tmp); + cb(cur, "ffn_swiglu_limited", il); + type_gate = LLM_FFN_SEQ; + break; + } + } + cur = ggml_swiglu_split(ctx0, cur, tmp); cb(cur, "ffn_swiglu", il); type_gate = LLM_FFN_SEQ; @@ -1272,6 +1338,25 @@ ggml_tensor * llm_graph_context::build_moe_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate_exps) { + // Step35: per-layer clamp for routed experts + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_exp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_moe_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_moe_silu_clamped", il); + + up = ggml_clamp(ctx0, up, -limit, limit); + cb(up, "ffn_moe_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, up); + cb(cur, "ffn_moe_swiglu_limited", il); + break; + } + } + cur = ggml_swiglu_split(ctx0, cur, up); cb(cur, "ffn_moe_swiglu", il); } else { @@ -2268,6 +2353,17 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr()); + auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); + + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp)); +} + llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const { const auto * mctx_cur = static_cast(mctx); @@ -2419,6 +2515,9 @@ void llm_graph_context::build_sampling() const { return; } + std::array outs; + outs[0] = res->t_logits; + auto inp_sampling = std::make_unique(samplers); res->add_input(std::move(inp_sampling)); @@ -2439,14 +2538,14 @@ void llm_graph_context::build_sampling() const { // add a dummy row of logits // this trick makes the graph static, regardless of which samplers are activated // this is important in order to minimize graph reallocations - // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550) ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); for (const auto & [seq_id, sampler] : samplers) { const auto it = seq_to_logit_row.find(seq_id); // inactive samplers always work on the first row - const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0; + const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0; + const int i_out = it != seq_to_logit_row.end() ? 1 : 0; ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); ggml_format_name(logits_seq, "logits_seq_%d", seq_id); @@ -2463,22 +2562,26 @@ void llm_graph_context::build_sampling() const { if (data.sampled != nullptr) { res->t_sampled[seq_id] = data.sampled; - ggml_build_forward_expand(gf, data.sampled); + outs[1] = data.sampled; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.probs != nullptr) { res->t_sampled_probs[seq_id] = data.probs; - ggml_build_forward_expand(gf, data.probs); + outs[1] = data.probs; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.logits != nullptr) { res->t_sampled_logits[seq_id] = data.logits; - ggml_build_forward_expand(gf, data.logits); + outs[1] = data.logits; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.candidates != nullptr) { res->t_candidates[seq_id] = data.candidates; - ggml_build_forward_expand(gf, data.candidates); + outs[1] = data.candidates; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } } diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 4090d811..1d69ff1a 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -433,6 +433,34 @@ public: const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_mem_hybrid_k : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_k( + const llama_cparams & cparams, + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_k() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; + + llm_graph_input_attn_k * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_context * mctx; +}; + class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i { public: llm_graph_input_mem_hybrid_iswa( @@ -960,6 +988,7 @@ struct llm_graph_context { // llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; + llm_graph_input_mem_hybrid_k * build_inp_mem_hybrid_k() const; llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const; diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 392f9160..756dda1a 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -139,6 +139,13 @@ uint32_t llama_hparams::n_embd_r() const { return n_embd * (n_shortconv_l_cache - 1); } + if (n_embd_head_kda != 0) { + // for Kimi KDA layers + // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim + const uint32_t d_inner = n_head() * n_embd_head_kda; // 32 * 128 = 4096 + return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner; + } + // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size @@ -151,6 +158,13 @@ uint32_t llama_hparams::n_embd_s() const { return n_embd * wkv_head_size; } + if (n_embd_head_kda != 0) { + // for Kimi KDA layers + // Full recurrent state: head_dim * head_dim * n_head + // h tensor shape for delta attention: [head_dim, head_dim, n_head] + return n_embd_head_kda * n_embd_head_kda * n_head(); // 128 * 128 * 32 = 524288 + } + // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index caed0ec1..6c695bdb 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -137,6 +137,9 @@ struct llama_hparams { uint32_t ssm_dt_rank = 0; uint32_t ssm_n_group = 0; + // for Kimi Linear KDA + uint32_t n_embd_head_kda = 0; + // for hybrid state space models std::array recurrent_layer_arr; @@ -195,7 +198,7 @@ struct llama_hparams { uint32_t n_deepstack_layers = 0; // needed by encoder-decoder models (e.g. T5, FLAN-T5) - // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + // ref: https://github.com/ggml-org/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; uint32_t dec_n_layer = 0; @@ -203,6 +206,11 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + + // Step35: optional per-layer clamps for (Swi)GLU + std::array swiglu_clamp_exp; // clamping for expert FFN + std::array swiglu_clamp_shexp; // shared expert + // this value n_pattern means that every nth layer is dense (i.e. non-SWA) // dense_first means whether the pattern is start with a dense layer // note that if n_pattern == 0, all layers are SWA diff --git a/examples/talk-llama/llama-kv-cache-iswa.cpp b/examples/talk-llama/llama-kv-cache-iswa.cpp index 3a34102a..26e2cb42 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-iswa.cpp @@ -218,7 +218,9 @@ llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, } bool llama_kv_cache_iswa::get_can_shift() const { - return kv_base->get_size() == kv_swa->get_size(); + return kv_base->get_can_shift() && + kv_swa->get_can_shift() && + kv_base->get_size() == kv_swa->get_size(); } void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index f3c9b49f..cb702b2a 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -974,6 +974,10 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & } bool llama_kv_cache::get_can_shift() const { + // Step35 uses per-layer RoPE dims; K-shift assumes a single global n_rot. + if (model.arch == LLM_ARCH_STEP35) { + return false; + } return true; } @@ -1772,8 +1776,6 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t io.write(&v_trans, sizeof(v_trans)); io.write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (const auto & layer : layers) { @@ -1791,7 +1793,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); io.write(&k_size_row, sizeof(k_size_row)); - // Read each range of cells of k_size length each into tmp_buf and write out + // Read each range of cells of k_size length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * k_size_row; @@ -1818,7 +1820,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); io.write(&v_size_row, sizeof(v_size_row)); - // Read each range of cells of v_size length each into tmp_buf and write out + // Read each range of cells of v_size length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * v_size_row; @@ -1852,7 +1854,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Read each range of cells of v_size_el length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * kv_size) * v_size_el; diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 812bf253..f0038036 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -785,23 +785,21 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: io.write(&s_trans, sizeof(s_trans)); io.write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell + // Iterate and write all the R tensors first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) if (r_l[il] == nullptr) continue; - // Write key type + // Write R tensor type const int32_t r_type_i = (int32_t)r_l[il]->type; io.write(&r_type_i, sizeof(r_type_i)); - // Write row size of key + // Write row size of R tensor const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Read each range of cells of k_size length each into tmp_buf and write out + // Write each range of cells of r_size_row length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -814,15 +812,15 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) if (s_l[il] == nullptr) continue; - // Write value type + // Write S tensor type const int32_t s_type_i = (int32_t)s_l[il]->type; io.write(&s_type_i, sizeof(s_type_i)); - // Write row size of value + // Write row size of S tensor const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Read each range of cells of s_size length each into tmp_buf and write out + // Write each range of S tensor rows for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * s_size_row; @@ -830,7 +828,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: } } } else { - // When v is transposed, we also need the element size and get the element ranges from each row + // When S tensor is transposed, we also need the element size and get the element ranges from each row const uint32_t mem_size = size; for (uint32_t il = 0; il < n_layer; ++il) { // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) @@ -838,7 +836,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_s = hparams.n_embd_s(); - // Write value type + // Write S tensor type const int32_t s_type_i = (int32_t)s_l[il]->type; io.write(&s_type_i, sizeof(s_type_i)); @@ -851,7 +849,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Write each range of cells of s_size_el length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 72490a89..674d06c8 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -125,10 +125,12 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; + case LLM_TYPE_48B_A3B: return "48B.A3B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_196B_A11B: return "196B.A11B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; @@ -559,6 +561,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.swiglu_clamp_exp.begin(), hparams.swiglu_clamp_exp.end(), 0.0f); + std::fill(hparams.swiglu_clamp_shexp.begin(), hparams.swiglu_clamp_shexp.end(), 0.0f); ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -2450,6 +2454,66 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_KIMI_LINEAR: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot); + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); + + // MLA qk_rope_head_dim (for reference) + // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 + + // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) + // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent + } + + // MoE parameters - Kimi uses moe_intermediate_size = 1024 + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_STEP35: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + // MoE + SWA parameters + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Step35 uses sigmoid gating by default (if not set in GGUF) + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); + + switch (hparams.n_layer) { + case 45: type = LLM_TYPE_196B_A11B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -6752,6 +6816,141 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); } } break; + case LLM_ARCH_KIMI_LINEAR: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // Check for KDA specific tensors to determine layer type or if it's a mixed model + // Assuming KDA layer if KDA tensors are present + + // KDA uses head_dim = 128 (from linear_attn_config.head_dim) + const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda; + const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; + const int64_t ssm_d_conv = hparams.ssm_d_conv; + + // Try loading KDA specific tensors (using SSM_ prefix) + // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) + // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_q_conv) { + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED); + } + + if (layer.ssm_q_conv) { + // KDA Layer - Conv1d weights may be 3D or 4D + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_k_conv) { + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_v_conv) { + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); + } + + // q, k, v projections + // Python: q_proj, k_proj, v_proj + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_v_kda * n_head}, 0); + + // KDA specific projections + // f_a_proj, f_b_proj + layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim + layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size + + // b_proj (beta mixing coefficient) + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); + + // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) Note: -exp(A_log) is applied in convert_hf_to_gguf.py + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_a) { + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + } + + // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_embd_head_k_kda * n_head}, 0); + + // g_a_proj, g_b_proj (output gate) + layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); + layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); + + // o_norm (reusing SSM_NORM) + layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated + + // o_proj + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); + + } else { + // MLA Layer - use MLA-specific head dimensions + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (layer.attn_q_a_norm) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) + // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 + const int64_t qk_rope_head_dim = hparams.n_rot; // From config: qk_rope_head_dim + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); + // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled) + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED); + if (!layer.wkv_b) { // MLA KV cache enabled + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // MoE intermediate size (different from dense FFN) + const int64_t n_ff_exp = hparams.n_ff_exp; + + // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE + // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE + if (i < (int) hparams.n_layer_dense_lead) { + // Dense FFN layer - use normal n_ff + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared experts use moe_intermediate_size * num_shared_experts + // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 + // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] + const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? hparams.n_expert_shared : 1); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } + } + } break; case LLM_ARCH_COGVLM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6940,6 +7139,72 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_STEP35: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor + // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. + uint32_t n_rot_max = 0; + for (int i = 0; i < n_layer; ++i) { + n_rot_max = std::max(n_rot_max, hparams.n_rot); + } + if (n_rot_max == 0) { + n_rot_max = n_rot; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + // optional rope factors (llama3) / longrope tensors + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_l}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + + // head-wise attention gate (Step35 self_attn.g_proj) + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // dense MLP (leading dense blocks) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + // shared expert MLP + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + } + } break; case LLM_ARCH_MAINCODER: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -8086,6 +8351,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_KIMI_LINEAR: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_STEP35: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -8235,6 +8508,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_NEMOTRON_H_MOE: + case LLM_ARCH_KIMI_LINEAR: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -8330,6 +8604,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_AFMOE: case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_MIMO2: + case LLM_ARCH_STEP35: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index d1de16e3..7b580043 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -118,10 +118,12 @@ enum llm_type { LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, + LLM_TYPE_48B_A3B, // Kimi Linear LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air + LLM_TYPE_196B_A11B, // Step3.5-Flash LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big @@ -411,6 +413,18 @@ struct llama_layer { struct ggml_tensor * ffn_act_beta = nullptr; struct ggml_tensor * ffn_act_eps = nullptr; + // Kimi Linear KDA (using ssm_ prefix for consistency) + // Note: ssm_dt_b already exists above (mamba bias), reused for Kimi dt_bias + struct ggml_tensor * ssm_q_conv = nullptr; + struct ggml_tensor * ssm_k_conv = nullptr; + struct ggml_tensor * ssm_v_conv = nullptr; + struct ggml_tensor * ssm_f_a = nullptr; + struct ggml_tensor * ssm_f_b = nullptr; + struct ggml_tensor * ssm_beta = nullptr; + struct ggml_tensor * ssm_g_a = nullptr; + struct ggml_tensor * ssm_g_b = nullptr; + struct ggml_tensor * ssm_o_norm = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 776222cb..a7891647 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -787,9 +787,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); - // do not quantize Mamba's small yet 2D weights + // do not quantize Mamba /Kimi's small conv1d weights // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ssm_conv1d.weight") == std::string::npos; + quantize &= name.find("ssm_conv1d") == std::string::npos; quantize &= name.find("shortconv.conv.weight") == std::string::npos; // do not quantize RWKV's small yet 2D weights diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampler.cpp similarity index 98% rename from examples/talk-llama/llama-sampling.cpp rename to examples/talk-llama/llama-sampler.cpp index 5dde5130..9bbc5dbd 100644 --- a/examples/talk-llama/llama-sampling.cpp +++ b/examples/talk-llama/llama-sampler.cpp @@ -1,4 +1,4 @@ -#include "llama-sampling.h" +#include "llama-sampler.h" #include "llama-impl.h" #include "llama-vocab.h" @@ -1025,11 +1025,7 @@ struct llama_sampler_dist : public llama_sampler_backend { std::mt19937 rng; - // backend input - struct ggml_tensor * inp_uniform; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; + ggml_tensor * inp_uniform; }; static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) { @@ -1138,37 +1134,10 @@ static bool llama_sampler_dist_backend_init( ggml_backend_buffer_type_t buft) { auto * sctx = (llama_sampler_dist *) smpl->ctx; - // allocate inputs - { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - // Create the uniform random scalar input tensor. This will be set by - // llama_sampler_dist_backend_set_input after this graph is built. - sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); - ggml_set_name (sctx->inp_uniform, "uniform"); - ggml_set_input(sctx->inp_uniform); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - } - const bool res = llama_sampler_backend_support(smpl, buft); sctx->init(res); - if (!res) { - sctx->inp_ctx.reset(nullptr); - sctx->inp_buf.reset(nullptr); - } - return res; } @@ -1178,8 +1147,13 @@ static void llama_sampler_dist_backend_apply( struct ggml_cgraph * gf, struct llama_sampler_data * data) { GGML_UNUSED(gf); + auto * sctx = (llama_sampler_dist *) smpl->ctx; + sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + ggml_set_name (sctx->inp_uniform, "uniform"); + ggml_set_input(sctx->inp_uniform); + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); ggml_set_name(probs, "dist_probs"); @@ -1226,6 +1200,7 @@ static void llama_sampler_dist_backend_apply( static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { auto * sctx = (llama_sampler_dist *) smpl->ctx; + GGML_ASSERT(sctx->inp_uniform != nullptr); // We sample in double precision and cast to float to match rnd numbers of @@ -1262,8 +1237,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { /* .seed_cur = */ seed_cur, /* .rng = */ std::mt19937(seed_cur), /* .inp_uniform = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } @@ -3461,9 +3434,6 @@ struct llama_sampler_logit_bias : public llama_sampler_backend { struct ggml_tensor * inp_logit_bias; struct ggml_tensor * inp_logit_idxs; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; }; static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) { @@ -3526,6 +3496,16 @@ static void llama_sampler_logit_bias_backend_apply( return; } + const size_t n = sctx->logit_bias.size(); + + sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n); + ggml_set_name(sctx->inp_logit_bias, "logit_bias"); + ggml_set_input(sctx->inp_logit_bias); + + sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n); + ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); + ggml_set_input(sctx->inp_logit_idxs); + ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); @@ -3562,6 +3542,8 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm static bool llama_sampler_logit_bias_backend_init( struct llama_sampler * smpl, ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; sctx->init(true); @@ -3570,29 +3552,6 @@ static bool llama_sampler_logit_bias_backend_init( return true; } - ggml_init_params params = { - /*.mem_size =*/ 2*ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - const size_t n = sctx->logit_bias.size(); - - sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n); - ggml_set_name(sctx->inp_logit_bias, "logit_bias"); - ggml_set_input(sctx->inp_logit_bias); - - sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n); - ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); - ggml_set_input(sctx->inp_logit_idxs); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - return true; } @@ -3628,8 +3587,6 @@ struct llama_sampler * llama_sampler_init_logit_bias( /* .to_search = */ {}, /* .inp_logit_bias = */ nullptr, /* .inp_logit_idxs = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } diff --git a/examples/talk-llama/llama-sampling.h b/examples/talk-llama/llama-sampler.h similarity index 92% rename from examples/talk-llama/llama-sampling.h rename to examples/talk-llama/llama-sampler.h index 6a963c0b..b9bfc20d 100644 --- a/examples/talk-llama/llama-sampling.h +++ b/examples/talk-llama/llama-sampler.h @@ -1,7 +1,5 @@ #pragma once -// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? - #include "llama.h" #include diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index a23950d0..6d6bdfa0 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -90,7 +90,7 @@ static_assert(std::is_trivially_copyable::value, "llm_symbol is not // // SPM tokenizer // original implementation: -// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 +// https://github.com/ggml-org/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 // struct llm_bigram_spm { @@ -285,7 +285,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { // original regex from tokenizer.json //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", - // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 + // adapted: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2080233989 "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; @@ -1752,26 +1752,33 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + // Kimi-K2 uses custom tokenization without traditional BPE merges + const bool is_kimi_k2 = (tokenizer_pre == "kimi-k2"); + if (merges_keyidx == -1) { - throw std::runtime_error("cannot find tokenizer merges in model file\n"); - } - - const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { - const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); - //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - - std::string first; - std::string second; - - const size_t pos = word.find(' ', 1); - - if (pos != std::string::npos) { - first = word.substr(0, pos); - second = word.substr(pos + 1); + if (!is_kimi_k2) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); } + // Kimi-K2 doesn't need merges, skip + LLAMA_LOG_INFO("%s: Kimi-K2 tokenizer detected, skipping BPE merges\n", __func__); + } else { + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - bpe_ranks.emplace(std::make_pair(first, second), i); + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + bpe_ranks.emplace(std::make_pair(first, second), i); + } } // default special tokens @@ -2226,6 +2233,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|end_of_text|>" // granite || t.first == "" || t.first == "_" + || t.first == "[EOT]" // Kimi-K2 || t.first == "<|end▁of▁sentence|>" // DeepSeek || t.first == "" // smoldocling ) { @@ -2262,6 +2270,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "
"
                         || t.first == "▁
"          // CodeLlama
                         || t.first == "<|code_prefix|>" // GLM-4.5
+                        || t.first == "<|prefix|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_pre_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2282,6 +2291,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
                         || t.first == "<|code_suffix|>" // GLM-4.5
+                        || t.first == "<|suffix|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_suf_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2302,6 +2312,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
                         || t.first == "<|code_middle|>" // GLM-4.5
+                        || t.first == "<|middle|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_mid_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2319,6 +2330,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == ""   // Granite
                         || t.first == ""
+                        || t.first == "[PAD]" // Kimi-K2
                         ) {
                     special_fim_pad_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2390,7 +2402,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
 
         // maintain a list of tokens that cause end-of-generation
         // this is currently determined based on the token text, which is obviously not ideal
-        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        // ref: https://github.com/ggml-org/llama.cpp/issues/9606
         special_eog_ids.clear();
 
         if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) {
@@ -2421,6 +2433,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|eom_id|>"
                     || t.first == ""
                     || t.first == "_"
+                    || t.first == "[EOT]" // Kimi-K2
+                    || t.first == "[EOS]" // Kimi-K2
                     || t.first == "<|end_of_text|>"
                     || t.first == "" // smoldocling
                ) {
@@ -3079,7 +3093,7 @@ std::vector llama_vocab::impl::tokenize(
 }
 
 int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
-    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
+    // ref: https://github.com/ggml-org/llama.cpp/pull/7587#discussion_r1620983843
     static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
     const llama_token_attr attr = token_get_attr(token);
     if (!special && (attr & attr_special)) {
diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp
index 297dca51..987f4499 100644
--- a/examples/talk-llama/models/deepseek2.cpp
+++ b/examples/talk-llama/models/deepseek2.cpp
@@ -14,7 +14,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
     const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
     // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    // See https://github.com/ggml-org/llama.cpp/discussions/7416 for detailed explanation.
     // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
 
     // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor
diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp
new file mode 100644
index 00000000..0f037d1a
--- /dev/null
+++ b/examples/talk-llama/models/kimi-linear.cpp
@@ -0,0 +1,772 @@
+#include "models.h"
+#include "ggml.h"
+
+#define CHUNK_SIZE 64
+
+// Causal Conv1d function for Q,K,V
+// When qkv is 0, it is Q, 1 is K, 2 is V
+static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) {
+    const int64_t d_inner = head_dim * n_head;
+    const int64_t conv_state_size = (d_conv - 1) * d_inner;
+    const int64_t n_embd_r_total = 3 * conv_state_size;  // Q + K + V
+
+    // conv_state_all is [n_embd_r_total, n_seqs], split into Q, K, V
+    // Each conv state is [(d_conv-1) * d_inner] per sequence, need to reshape to [d_conv-1, d_inner, n_seqs]
+    // Memory layout: for each seq, Q state is first conv_state_size elements, then K, then V
+    // conv_state_all has stride: nb[0] = element_size, nb[1] = n_embd_r_total * element_size
+    // View Q conv state: offset 0, size conv_state_size per seq
+    // conv_state_all is [n_embd_r_total, n_seqs] with memory layout:
+    //   state[i + seq * n_embd_r_total] where i = conv_step + channel * (d_conv-1) + {0, conv_state_size, 2*conv_state_size} for Q/K/V
+    // We want [d_conv-1, d_inner, n_seqs] view:
+    //   nb1 = (d_conv-1) * element_size (stride between channels)
+    //   nb2 = n_embd_r_total * element_size (stride between seqs)
+    ggml_tensor * conv_state_x = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs,
+        (d_conv - 1) * ggml_element_size(conv_state_all),  // nb1: stride between channels
+        n_embd_r_total * ggml_element_size(conv_state_all),  // nb2: stride between seqs
+        qkv * conv_state_size * ggml_element_size(conv_state_all));
+
+// Causal Conv1d function for Q,K,V
+// When qkv is 0, it is Q, 1 is K, 2 is V
+    // Step 1: Q, K, V projections -> [d_inner, n_tokens]
+    ggml_tensor * x_proj = ggml_mul_mat(ctx0, proj_w, x);
+
+    // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs}
+    ggml_tensor * x_3d = ggml_reshape_3d(ctx0, x_proj, d_inner, n_seq_tokens, n_seqs);
+
+    // Concat Q conv state and current input: {d_conv-1 + n_seq_tokens, d_inner, n_seqs}
+    ggml_tensor * conv_x = ggml_concat(ctx0, conv_state_x, ggml_transpose(ctx0, x_3d), 0);
+
+    // Save last (d_conv-1) columns back to Q conv state
+    ggml_tensor * last_conv_x = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
+        conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]);
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0, last_conv_x,
+            ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs,
+                (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all))));
+    // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner]
+    // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv]
+    // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step]
+    // ggml_ssm_conv computes: c[conv_step + channel * d_conv]
+    // GGUF layout: [d_conv, 1, d_inner] or [d_conv, 1, d_inner, 1] -> reshape to [d_conv, d_inner]
+    // Reshape conv weight from [d_conv, 1, d_inner, 1] to [d_conv, d_inner] for ggml_ssm_conv
+    ggml_tensor * conv_weight = ggml_reshape_2d(ctx0, conv_w, d_conv, d_inner);
+
+    // Apply conv1d
+    // ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs}
+    ggml_tensor * Xcur = ggml_ssm_conv(ctx0, conv_x, conv_weight);
+    // Reshape to 2D for bias add: {d_inner, n_tokens}
+    Xcur = ggml_reshape_2d(ctx0, Xcur, d_inner, n_tokens);
+    Xcur = ggml_silu(ctx0, Xcur);
+
+    return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs);
+}
+
+llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context_mamba(params), model(model) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    cb(inpL, "model.embed_tokens", -1);
+
+    // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
+    // So we don't need inp_pos
+
+    auto * inp_kv = !hparams.is_mla() ? build_inp_mem_hybrid() : nullptr;
+    auto * inp_k = hparams.is_mla() ? build_inp_mem_hybrid_k() : nullptr;
+    auto * inp_rs = hparams.is_mla() ? inp_k->get_recr() : inp_kv->get_recr();
+    auto * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr;
+    auto * inp_attn_k = hparams.is_mla() ? inp_k->get_attn() : nullptr;
+
+    // Output ids for selecting which tokens to output
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    ggml_tensor * chunked_causal_mask =
+        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
+                    GGML_TRI_TYPE_LOWER);
+
+    ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
+    ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity);
+
+    ggml_build_forward_expand(gf, chunked_causal_mask);
+    ggml_build_forward_expand(gf, chunked_identity);
+    ggml_build_forward_expand(gf, chunked_diag_mask);
+
+    // Kimi dimension constants
+    const int64_t n_head = hparams.n_head();
+    const int64_t head_dim = hparams.n_embd_head_kda;
+    const int64_t d_conv = hparams.ssm_d_conv;
+    const int64_t d_inner = n_head * head_dim;  // 32 * 128 = 4096
+    const int64_t n_seqs = ubatch.n_seqs;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    // Verify batch consistency for recurrent layers
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    // MLA params
+    const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
+    const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
+    const int64_t kv_lora_rank = hparams.n_lora_kv;
+    // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot
+    // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
+    const int64_t n_embd_head_qk_rope = hparams.n_rot;  // config.qk_rope_head_dim
+    const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;  // 192 - 64 = 128
+    // Attention scale for MLA
+    const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla);
+
+    for (int il = 0; il < n_layer; ++il) {
+        const auto & layer = model.layers[il];
+        ggml_tensor * inpSA = inpL;
+
+        // Attention Norm
+        cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // Check layer type by checking which tensors exist
+        // KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor
+        bool is_kda = (layer.ssm_a != nullptr);
+        bool is_mla = (layer.wkv_a_mqa != nullptr);
+
+        if (is_kda) {
+            // === KDA Layer (Kimi Delta Attention) with Recurrent State ===
+            // Reference: vLLM kda.py
+            const auto * mctx_cur = inp_rs->mctx;
+            const auto kv_head = mctx_cur->get_head();
+
+            // Get conv states from r_l tensor (Q, K, V each have separate state)
+            ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+            cb(conv_states_all, "conv_states_all", il);
+            ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs);
+            ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+
+            // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias)
+            ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur);
+            ggml_tensor * g1 = ggml_mul_mat(ctx0, layer.ssm_f_b, f_a);
+            cb(g1, "g1 f_b(f_a(cur))", il);
+            g1 = ggml_add(ctx0, g1, layer.ssm_dt_b);
+            g1 = ggml_softplus(ctx0, g1);
+            g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens);
+
+            // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]. No need to -exp(a_log) because it was done in convert_hf_to_gguf.py
+            // Reshape to [1, n_head, 1] for broadcasting with g1 [head_dim, n_head, n_tokens]
+            ggml_tensor * A = ggml_reshape_3d(ctx0, layer.ssm_a, 1, n_head, 1);
+            g1 = ggml_mul(ctx0, g1, A);
+            cb(g1, "kda_g1", il);
+
+            // Compute beta (mixing coefficient)
+            ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur);
+            beta = ggml_reshape_4d(ctx0, beta, n_head, 1, n_seq_tokens, n_seqs);
+            cb(beta, "kda_beta", il);
+
+            // Reshape for KDA recurrence
+            // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs}
+            cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+            g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs);
+
+            // Get SSM state and compute KDA recurrence using ggml_kda_scan
+            ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+            ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
+            state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);
+            // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
+            std::pair attn_out = n_seq_tokens == 1 ?
+                build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
+                build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);
+
+            ggml_tensor * output = attn_out.first;
+            ggml_tensor * new_state = attn_out.second;
+            cb(output, "attn_output", il);
+            cb(new_state, "new_state", il);
+
+            // Update the recurrent states
+            ggml_build_forward_expand(gf,
+                                     ggml_cpy(ctx0, new_state,
+                                              ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                                                           kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+            // Output gating g2 = g_b(g_a(x))
+            ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+            ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d);
+            ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a);
+            cb(g2, "g2 g_b(g_a(cur_2d))", il);
+            g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs);
+
+            // Apply o_norm with sigmoid gating
+            // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish)
+            // Formula: output = RMSNorm(x) * sigmoid(g)
+            ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, output, head_dim, n_head,  n_seq_tokens * n_seqs);
+            ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, nullptr, LLM_NORM_RMS, il);
+            cb(normed, "kda_normed", il);
+            ggml_tensor * gate = ggml_sigmoid(ctx0, g2);
+            ggml_tensor * gated = ggml_mul(ctx0, normed, gate);
+
+            // Output projection
+            gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens);
+            cur = ggml_mul_mat(ctx0, layer.wo, gated);
+            cb(cur, "kda_out", il);
+
+        } else if (is_mla) {
+            // === MLA Layer (Multi-head Latent Attention) without KV Cache ===
+            // Reference: vLLM mla.py
+            // Step 1: Q projection and reshape
+            // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim]
+            // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
+            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur);
+
+            // Step 2: KV compression
+            // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens]
+            ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur);
+
+            // Split: kv_cmpr = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:]
+            ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens,
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+            ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens,
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
+            // Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM)
+            // k_pe is used directly without RoPE
+            // Normalize kv_c
+            kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
+
+            if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled
+                // extract q_nope
+                ggml_tensor * q_nope =
+                    ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
+                                 ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_embd_head_qk_rope, n_head, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(
+                    ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
+                    ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd_head_qk_nope, n_tokens, n_head}
+                q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+                cb(q_nope, "q_nope_perm", il);
+
+                // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
+                ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope);
+                cb(q_nope_absorbed, "q_nope_absorbed", il);
+
+                // {kv_lora_rank, n_head, n_tokens}
+                q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
+                cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
+
+                // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
+                // note: rope must go first for in-place context shifting in build_rope_shift()
+                Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
+                cb(Qcur, "Qcur", il);
+
+                kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
+                cb(kv_cmpr, "kv_cmpr_reshape", il);
+
+                // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
+                ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
+                cb(Kcur, "Kcur", il);
+
+                // {kv_lora_rank, 1, n_tokens}
+                ggml_tensor * Vcur = kv_cmpr;
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
+                cb(cur, "mla_out", il);
+            } else { // MLA KV cache disabled. Fall back to MHA KV cache.
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
+                cb(Qcur, "mla_Q", il);
+                // KV decompression: kv = kv_b_proj(kv_c_normed)
+                ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
+                const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;
+
+                // Split kv into k_nope and v
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                    ggml_row_size(kv->type, kv_per_head),
+                    ggml_row_size(kv->type, kv_per_head * n_head), 0);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
+                    ggml_row_size(kv->type, kv_per_head),
+                    ggml_row_size(kv->type, kv_per_head * n_head),
+                    ggml_row_size(kv->type, n_embd_head_qk_nope));
+                Vcur = ggml_cont(ctx0, Vcur);
+                cb(Vcur, "mla_V", il);
+
+                // Concatenate k_nope + k_pe (broadcast k_pe to all heads)
+                // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
+                // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
+                // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
+                ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
+                ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
+                ggml_tensor * Kcur = ggml_concat(ctx0, k_pe_repeated, k_nope, 0);
+                cb(Kcur, "mla_K", il);
+
+                // Direct softmax attention (with MHA KV cache)
+                // Use build_attn with inp_attn for proper mask handling
+                cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
+                cb(cur, "mla_out", il);
+            }
+        } else {
+            // Unknown layer type - this should not happen
+            GGML_ABORT("Kimi layer is neither KDA nor MLA - missing required tensors");
+        }
+
+        // On last layer, select only the output tokens
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        // Residual
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // FFN Norm
+        cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        if ((uint32_t) il < hparams.n_layer_dense_lead) {
+            // Dense FFN layer
+            cur = build_ffn(cur,
+                layer.ffn_up, NULL, NULL,
+                layer.ffn_gate, NULL, NULL,
+                layer.ffn_down, NULL, NULL,
+                NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE layer
+            // Kimi uses moe_renormalize=True and routed_scaling_factor (stored as expert_weights_scale) = 2.446
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                layer.ffn_gate_inp,
+                layer.ffn_up_exps,
+                layer.ffn_gate_exps,
+                layer.ffn_down_exps,
+                layer.ffn_exp_probs_b,
+                hparams.n_expert,
+                hparams.n_expert_used,
+                LLM_FFN_SILU, true,
+                true, hparams.expert_weights_scale,
+                (llama_expert_gating_func_type) hparams.expert_gating_func,
+                il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // Shared expert
+            {
+                ggml_tensor * ffn_shexp = build_ffn(cur,
+                        layer.ffn_up_shexp, NULL, NULL,
+                        layer.ffn_gate_shexp, NULL, NULL,
+                        layer.ffn_down_shexp, NULL, NULL,
+                        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(ffn_shexp, "ffn_shexp", il);
+
+                cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                cb(cur, "ffn_out", il);
+            }
+        }
+        // Residual
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // Final Norm
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // Output
+    cur = ggml_mul_mat(ctx0, model.output, cur);
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
+/*
+    This is a ggml implementation of the naive_chunk_kda function of
+    https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
+*/
+std::pair llm_build_kimi_linear::build_kda_chunking(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * gk,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        ggml_tensor * diag_mask,
+        int           il) {
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(gk->ne[0] == S_v && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
+
+    // TODO: can this ever be false?
+    const bool use_qk_l2norm = true;
+
+    if (use_qk_l2norm) {
+        const float eps_norm = hparams.f_norm_rms_eps;
+
+        q = ggml_l2_norm(ctx0, q, eps_norm);
+        k = ggml_l2_norm(ctx0, k, eps_norm);
+    }
+
+    const float scale = 1.0f / sqrtf(S_v);
+
+    beta = ggml_sigmoid(ctx0, beta);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(gk, "gk_in", il);
+
+    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+
+    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
+    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
+
+    cb(q, "q_perm", il);
+    cb(k, "k_perm", il);
+    cb(v, "v_perm", il);
+    cb(beta, "beta_perm", il);
+    cb(gk, "gk_perm", il);
+    cb(state, "state_in", il);
+
+    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
+
+    // Do padding
+    const int64_t chunk_size = CHUNK_SIZE;
+
+    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
+    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
+
+    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
+    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
+    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
+    gk = ggml_pad(ctx0, gk, 0, pad, 0, 0);
+    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
+
+    cb(q, "q_pad", il);
+    cb(k, "k_pad", il);
+    cb(v, "v_pad", il);
+    cb(beta, "beta_pad", il);
+    cb(gk, "gk_pad", il);
+
+    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
+    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
+
+    cb(v_beta, "v_beta", il);
+    cb(k_beta, "k_beta", il);
+
+    const int64_t HB = H_k * n_seqs;
+
+    q      = ggml_cont_4d(ctx0, q,      S_k, chunk_size, n_chunks, HB);
+    k      = ggml_cont_4d(ctx0, k,      S_k, chunk_size, n_chunks, HB);
+    k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, HB);
+    v      = ggml_cont_4d(ctx0, v,      S_v, chunk_size, n_chunks, HB);
+    v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, HB);
+
+    gk    = ggml_cont_4d(ctx0, gk, S_k, chunk_size, n_chunks, HB);
+    beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, HB);
+
+    // switch for cumsum
+    gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 0, 2, 3), chunk_size, S_k, n_chunks, HB);
+    cb(gk, "gk", il);
+    ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk);
+    cb(gk_cumsum, "gk_cumsum", il);
+
+/*
+    Compute Akk and Aqk loop together
+    Akk loop:
+    for i in range(BT):
+        k_i = k[..., i, :] # k_i [B,H,NT,S]
+        g_i = g[..., i:i+1, :] # g_i [B,H,NT,1,S]
+        A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i)
+    Aqk loop:
+    for j in range(BT):
+        k_j = k[:, :, i, j]
+        g_j = g[:, :, i, j:j+1, :]
+        A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
+*/
+    const int64_t CHB = n_chunks * H_k * n_seqs;
+    ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB);  // [chunk_size, 1, S_k, CHB]
+    ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB);  // [1, chunk_size, S_k, CHB]
+
+    ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB);  // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
+    // decay_mask [chunk_size,chunk_size,S_k,CHB]
+    ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
+    cb(decay_mask, "decay_mask", il);
+
+    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+    cb(decay_mask, "decay_masked", il);
+    decay_mask = ggml_exp(ctx0, decay_mask);
+    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+
+    // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
+    decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
+
+    ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB);
+    ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
+    ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
+
+    ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
+    ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);
+
+    // decay_k_i [S.BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
+    ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
+    ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
+    Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
+    Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));
+    cb(Akk, "Akk", il);
+    cb(Aqk, "Aqk", il);
+
+    Akk = ggml_mul(ctx0, Akk, beta);
+    Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
+    cb(Akk, "attn_pre_solve", il);
+
+    Aqk = ggml_mul(ctx0, Aqk, diag_mask);
+    Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
+    cb(Aqk, "Aqk_masked", il);
+
+    // for i in range(1, chunk_size):
+    //          row = attn[..., i, :i].clone()
+    //          sub = attn[..., :i, :i].clone()
+    //          attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+    // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+    //
+    // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
+    ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, causal_mask);
+    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
+
+    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, Akk, true, true, false);
+    Akk                      = ggml_mul(ctx0, lin_solve, causal_mask);
+    Akk                      = ggml_add(ctx0, Akk, identity);
+
+    cb(Akk, "attn_solved", il);
+
+    // switch back for downstream
+    gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB);
+    ggml_tensor * gkexp      = ggml_exp(ctx0, gk_cumsum);
+    cb(gk_cumsum, "gk_cumsum", il);
+
+    // u = (A*beta[..., None, :]) @ v  aka U_[t]
+    ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk);
+
+    ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp);
+    cb(kbeta_gkexp, "kbeta_gkexp", il);
+
+    ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gkexp)), Akk);
+    cb(k_cumdecay, "k_cumdecay", il);
+
+    ggml_tensor * core_attn_out = nullptr;
+    ggml_tensor * new_state = ggml_dup(ctx0, state);
+
+    cb(new_state, "new_state", il);
+
+    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+// extract one chunk worth of data
+        auto chunkify = [=](ggml_tensor * t) {
+                    return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
+                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
+        };
+        auto chunkify_A = [=](ggml_tensor * t) {
+                    return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, chunk_size, 1, t->ne[3],
+                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
+        };
+
+
+// k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B]
+        ggml_tensor * k_chunk = chunkify(k);
+        ggml_tensor * q_chunk = chunkify(q);
+        ggml_tensor * vb_chunk = chunkify(vb);
+
+// gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B]
+        ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum);
+        ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
+        ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk);
+        ggml_tensor * Aqk_chunk = chunkify_A(Aqk);
+
+        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
+
+        // new_state [S,S,1,H*B] k_cumdecay_chunk [S,BT,1,H*B]
+        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state or W_[t] @ S_[t]
+        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
+
+        // v_new = v_i - v_prime or U_[t] - W_[t]*S_[t]
+        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, vb_chunk, v_prime), v_prime);
+        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
+
+        // q_chunk [S,BT,1,H*B] gkexp_chunk [S,BT,1,H*B]
+        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+        // or Gamma_[t]*Q_]t] @ S
+        ggml_tensor * q_gk_exp   = ggml_mul(ctx0, q_chunk, gkexp_chunk);
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp);
+        attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q
+
+        // v_new_t [S,BT,1,H*B] Aqk [BT,BT,1,H*B]
+        // core_attn_out[:, :, i] = attn_inter + attn @ v_new or A' @ (U_[t] - W_[t]*S_[t])
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk_chunk);
+
+        // o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i
+        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
+
+        core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
+
+        ggml_tensor * gk_cum_last =
+            ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cs_chunk, gk_cs_chunk->ne[0], 1, gk_cs_chunk->ne[2], gk_cs_chunk->ne[3],
+                                        gk_cs_chunk->nb[1], gk_cs_chunk->nb[2], gk_cs_chunk->nb[3],
+                                        gk_cs_chunk->nb[1] * (gk_cs_chunk->ne[1] - 1)));
+
+        ggml_tensor * gkexp_last = ggml_exp(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, gk_cum_last)));
+
+        ggml_tensor * gk_diff = ggml_neg(ctx0, ggml_sub(ctx0, gk_cs_chunk, gk_cum_last));
+
+        ggml_tensor * gk_diff_exp = ggml_exp(ctx0, gk_diff);
+
+        ggml_tensor * key_gkdiff = ggml_mul(ctx0, k_chunk, gk_diff_exp);
+
+        // rearrange((g_i[:,:,-1:] - g_i).exp()*k_i, 'b h c k -> b h k c') @ (U_[t] - W_[t] @ S)
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff)));
+
+        new_state = ggml_add(ctx0,
+            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gkexp_last, gkexp_last->ne[0], gkexp_last->ne[1], H_v, n_seqs)),
+            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
+    }
+
+    core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
+
+    // truncate padded tokens
+    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
+            S_v, n_tokens, H_v, n_seqs,
+            ggml_row_size(core_attn_out->type, S_v),
+            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
+            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
+    output_tokens = ggml_cont(ctx0, output_tokens);
+    // permute back to (S_v, H_v, n_tokens, n_seqs)
+    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
+    output_tokens = ggml_cont(ctx0, output_tokens);
+
+    cb(new_state, "output_state", il);
+
+    return {output_tokens, new_state};
+}
+
+std::pair llm_build_kimi_linear::build_kda_autoregressive(
+    ggml_tensor * q,
+    ggml_tensor * k,
+    ggml_tensor * v,
+    ggml_tensor * gk,
+    ggml_tensor * beta,
+    ggml_tensor * state,
+    int il) {
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(gk));
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(n_tokens == 1);
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(gk->ne[0] == S_k && gk->ne[1] == H_k && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_k && state->ne[2] == H_v && state->ne[3] == n_seqs);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
+
+    const float eps_norm = hparams.f_norm_rms_eps;
+
+    q = ggml_l2_norm(ctx0, q, eps_norm);
+    k = ggml_l2_norm(ctx0, k, eps_norm);
+
+    const float scale = 1.0f / sqrtf(S_v);
+
+    q    = ggml_scale(ctx0, q, scale);
+    beta = ggml_sigmoid(ctx0, beta);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(gk, "gk_in", il);
+
+// g [H,1,B,1] g_t [1,H,B,1] => [1,1,H,B]
+// gk [S,H,1,B] => [S,1,H,B] gk_t [1,S,H,B]
+// beta [H,1,1,B] beta_t [1,H,1,B] => [1,1,H,B]
+    gk = ggml_reshape_4d(ctx0, gk, S_k, 1, H_k, n_seqs);
+    ggml_tensor * gk_t = ggml_cont(ctx0, ggml_transpose(ctx0, gk));
+    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
+
+    // Apply exponential to gk_t
+    gk_t = ggml_exp(ctx0, gk_t);
+    // Apply the gated delta rule for the single timestep
+    // last_recurrent_state = last_recurrent_state * gk_t
+    // S = S * g_i[..., None].exp()
+    state = ggml_mul(ctx0, state, gk_t);
+
+    ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+
+// state [S,S,H,B] k [S,1,H,B] k_state [S_v,1,H,B]
+    k = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs);
+    ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k);
+
+    // v_i - (k_i[..., None] * S).sum(-2)
+    v = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
+    ggml_tensor * v_diff = ggml_sub(ctx0, v, k_state);
+
+    // b_i[..., None] * k_i
+    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta_t);
+
+    // S = S + torch.einsum('b h k, b h v -> b h k v', b_i[..., None] * k_i, v_i - (k_i[..., None] * S).sum(-2))
+    // v_diff_t [1,S_v,H,B] k_beta_t [1,S_k,H,B] state [S_v,S_k,H,B]
+    state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta))));
+
+    q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs);
+    state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+    ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q);
+    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
+    cb(core_attn_out, "output_tokens", il);
+    cb(state, "new_state", il);
+
+    return {core_attn_out, state};
+}
+
diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h
index 3a44f7f1..cfcbb9aa 100644
--- a/examples/talk-llama/models/models.h
+++ b/examples/talk-llama/models/models.h
@@ -288,6 +288,33 @@ struct llm_build_jamba : public llm_graph_context_mamba {
     llm_build_jamba(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_kimi_linear : public llm_graph_context_mamba {
+    llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
+
+    std::pair build_kda_autoregressive(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * gk,
+                ggml_tensor * beta,
+                ggml_tensor * state,
+                        int   il);
+
+    std::pair build_kda_chunking(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * gk,
+                ggml_tensor * beta,
+                ggml_tensor * state,
+                ggml_tensor * causal_mask,
+                ggml_tensor * identity,
+                ggml_tensor * diag_mask,
+                        int   il);
+
+    const llama_model & model;
+};
+
 struct llm_build_lfm2 : public llm_graph_context {
     const llama_model & model;
 
@@ -556,6 +583,10 @@ struct llm_build_starcoder : public llm_graph_context {
     llm_build_starcoder(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_step35_iswa : public llm_graph_context {
+    llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_t5_dec : public llm_graph_context {
     llm_build_t5_dec(const llama_model & model, const llm_graph_params & params);
 };
diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp
index ee46a337..fbf682ec 100644
--- a/examples/talk-llama/models/openelm.cpp
+++ b/examples/talk-llama/models/openelm.cpp
@@ -43,7 +43,7 @@ llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_
             ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head);
             cb(Kcur, "Kcur", il);
 
-            ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
+            ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv));
             cb(Vcur, "Vcur", il);
 
             Qcur = build_norm(Qcur,
diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp
index 57b6659b..99b1a76a 100644
--- a/examples/talk-llama/models/qwen3next.cpp
+++ b/examples/talk-llama/models/qwen3next.cpp
@@ -265,9 +265,15 @@ std::pair llm_build_qwen3next::build_delta_net_chu
     cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
 
     ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
+    ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
+                                                 1, chunk_size, n_chunks, g_diff_exp->ne[3]);
+
+    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
     cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
 
+    ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
+    cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
+
 
     // state to be updated per chunk
     ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
@@ -322,9 +328,9 @@ std::pair llm_build_qwen3next::build_delta_net_chu
             : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
 
         // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
+        ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
         //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
 
         // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
         ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
diff --git a/examples/talk-llama/models/step35-iswa.cpp b/examples/talk-llama/models/step35-iswa.cpp
new file mode 100644
index 00000000..f8737815
--- /dev/null
+++ b/examples/talk-llama/models/step35-iswa.cpp
@@ -0,0 +1,168 @@
+#include "models.h"
+
+llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    ggml_tensor * inp_pos     = build_inp_pos();
+    auto        * inp_attn    = build_attn_inp_kv_iswa();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        const uint32_t n_head_l    = hparams.n_head(il);
+        const uint32_t n_head_kv_l = hparams.n_head_kv(il);
+
+        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+        cur = inpL;
+
+        // dump pre-attn RMSNorm input to pinpoint layer boundary issues
+        cb(cur, "attn_norm_in", il);
+
+        // self-attention
+        {
+            cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens);
+
+            // Q/K per-head RMSNorm (Step35 q_norm / k_norm)
+            if (model.layers[il].attn_q_norm) {
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+            }
+            if (model.layers[il].attn_k_norm) {
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+            }
+
+            // RoPE (partial rotary factors per layer)
+            const bool is_swa = hparams.is_swa(il);
+            ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il);
+            const int64_t n_rot_l = is_swa ? hparams.n_rot : (hparams.n_rot / 2);
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, rope_factors,
+                n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, rope_factors,
+                n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            cb(Qcur, "Qcur_pos", il);
+            cb(Kcur, "Kcur_pos", il);
+
+            const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k));
+            ggml_tensor * attn_out = build_attn(inp_attn,
+                    nullptr, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+            cb(attn_out, "attn_out", il);
+            // head-wise attention gate: sigmoid(g_proj(x)) in torch
+            if (model.layers[il].wqkv_gate) {
+                ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, cur); // [n_head_l, n_tokens]
+                cb(gate, "attn_gate", il);
+
+                gate = ggml_sigmoid(ctx0, gate);
+                cb(gate, "attn_gate_sigmoid", il);
+
+                // reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens]
+                ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens);
+                ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate,       1,          n_head_l, n_tokens);
+                cb(gate_3d, "attn_gate_3d", il);
+
+                attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
+                cb(attn_3d, "attn_gated_3d", il);
+
+                attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
+                cb(attn_out, "attn_gated", il);
+            }
+
+            // output projection
+            cur = build_lora_mm(model.layers[il].wo, attn_out);
+            cb(cur, "attn_proj", il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        cur = build_norm(ffn_inp, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        // feed-forward
+        if (model.layers[il].ffn_gate_inp == nullptr) {
+            // dense MLP
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   nullptr,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr,
+                    nullptr,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE routed experts
+            const bool  norm_w  = hparams.expert_weights_norm;
+            const float w_scale = hparams.expert_weights_scale;
+            const bool  scale_w = w_scale != 0.0f;
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    model.layers[il].ffn_exp_probs_b,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU,
+                    norm_w, scale_w, w_scale,
+                    (llama_expert_gating_func_type) hparams.expert_gating_func,
+                    il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // shared expert MLP (always added on MoE layers in Step35)
+            ggml_tensor * sh_out = build_ffn(cur,
+                    model.layers[il].ffn_up_shexp,   nullptr, nullptr,
+                    model.layers[il].ffn_gate_shexp, nullptr, nullptr,
+                    model.layers[il].ffn_down_shexp, nullptr, nullptr,
+                    nullptr,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(sh_out, "ffn_shared_out", il);
+
+            cur = ggml_add(ctx0, moe_out, sh_out);
+            cb(cur, "ffn_out", il);
+        }
+        cur = ggml_add(ctx0, cur, ffn_inp);
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    cur = build_lora_mm(model.output, cur);
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp
index b47dcbe6..adfc489d 100644
--- a/examples/talk-llama/unicode.cpp
+++ b/examples/talk-llama/unicode.cpp
@@ -497,49 +497,26 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
     return bpe_offsets;
 }
 
-// use std::wregex to split the text
-static std::vector unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector & offsets) {
-    std::wregex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs);
+template 
+static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) {
+    using BidirIt = typename std::basic_string::const_iterator;
+#ifdef _MSC_VER
+    // Bypass bug in MSVC: https://github.com/ggml-org/llama.cpp/issues/17830
+    constexpr auto regex_flags = std::regex_constants::ECMAScript;
+#else
+    constexpr auto regex_flags = std::regex_constants::optimize | std::regex_constants::nosubs;
+#endif
+    std::basic_regex expr(regex, regex_flags);
     std::vector bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
     size_t start = 0;
     for (auto offset : offsets) {
-        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
-        std::wcregex_iterator end;
+        std::regex_iterator it(text.begin() + start, text.begin() + start + offset, expr);
+        std::regex_iterator end;
 
         int64_t start_idx = 0;
         while (it != end) {
-            std::wcmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t) offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
-// use std::regex to split the text
-static std::vector unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector & offsets) {
-    std::regex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs);
-    std::vector bpe_offsets; // store the offset of each word
-    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
-        std::cregex_iterator end;
-
-        int64_t start_idx = 0;
-        while (it != end) {
-            std::cmatch match = *it;
+            std::match_results match = *it;
             if (match.position() > start_idx) {
                 bpe_offsets.emplace_back(match.position() - start_idx);
             }