diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 59dde99e3..c9eead18a 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -757,14 +757,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - // NextN/MTP tensors are currently ignored (reserved for future MTP support) - // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the + // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so + // the model loader doesn't fault on the block index. + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -877,6 +878,16 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { } } +bool llm_arch_supports_rs_rollback(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: + return true; + default: + return false; + } +} + bool llm_arch_supports_sm_tensor(const llm_arch & arch) { switch (arch) { case LLM_ARCH_GROK: diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index e37d548c9..89cf16cc3 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -637,3 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch); bool llm_arch_is_hybrid (const llm_arch & arch); bool llm_arch_is_diffusion (const llm_arch & arch); bool llm_arch_supports_sm_tensor(const llm_arch & arch); +bool llm_arch_supports_rs_rollback(const llm_arch & arch); diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 6554a89b2..f10397747 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -73,7 +73,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, - { "hunyuan-ocr", LLM_CHAT_TEMPLATE_HUNYUAN_OCR }, + { "hunyuan-vl", LLM_CHAT_TEMPLATE_HUNYUAN_VL }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, @@ -218,7 +218,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) { return LLM_CHAT_TEMPLATE_OPENAI_MOE; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_begin▁of▁sentence|>")) { - return LLM_CHAT_TEMPLATE_HUNYUAN_OCR; + return LLM_CHAT_TEMPLATE_HUNYUAN_VL; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { @@ -825,8 +825,8 @@ int32_t llm_chat_apply_template( ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>"; } } - } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_OCR) { - // tencent/HunyuanOCR + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_VL) { + // tencent/HunyuanOCR & tencent/HunyuanVL ss << "<|hy_begin▁of▁sentence|>"; for (size_t i = 0; i < chat.size(); i++) { std::string role(chat[i]->role); diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index 13f936a94..ea6540c0b 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -53,7 +53,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, - LLM_CHAT_TEMPLATE_HUNYUAN_OCR, + LLM_CHAT_TEMPLATE_HUNYUAN_VL, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 3d9714ab1..ad36c0666 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2,6 +2,7 @@ #include "ggml.h" #include "llama-arch.h" +#include "llama-graph.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-io.h" @@ -21,6 +22,14 @@ // llama_context // +static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) { + switch (ctx_type) { + case LLAMA_CONTEXT_TYPE_DEFAULT: return LLM_GRAPH_TYPE_DEFAULT; + case LLAMA_CONTEXT_TYPE_MTP : return LLM_GRAPH_TYPE_DECODER_MTP; + } + throw std::runtime_error("Unsupported ctx type"); +} + llama_context::llama_context( const llama_model & model, llama_context_params params) : @@ -42,13 +51,22 @@ llama_context::llama_context( throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); } + cparams.n_rs_seq = params.n_rs_seq; + if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) { + LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n", + __func__, cparams.n_rs_seq); + cparams.n_rs_seq = 0; + } + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; - cparams.embeddings = params.embeddings; + cparams.embeddings = params.embeddings; + cparams.embeddings_pre_norm = false; + cparams.embeddings_pre_norm_masked = false; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; @@ -65,6 +83,8 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + cparams.ctx_type = params.ctx_type; + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -206,6 +226,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq); if (cparams.n_ctx_seq < hparams.n_ctx_train) { LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", @@ -278,6 +299,7 @@ llama_context::llama_context( /*.type_k =*/ params.type_k, /*.type_v =*/ params.type_v, /*.swa_full =*/ params.swa_full, + /*.ctx_type= */ cparams.ctx_type, }; memory.reset(model.create_memory(params_mem, cparams)); @@ -860,6 +882,42 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +float * llama_context::get_embeddings_pre_norm() { + output_reorder(); + + return embd_pre_norm.data; +} + +float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { + output_reorder(); + + try { + if (embd_pre_norm.data == nullptr) { + throw std::runtime_error("no pre-norm embeddings"); + } + + const uint32_t n_embd = model.hparams.n_embd; + + if (!cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm rows are stored densely, indexed by raw token position. + if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) { + throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd)); + } + return embd_pre_norm.data + (size_t) i * n_embd; + } + + const int64_t j = output_resolve_row(i); + return embd_pre_norm.data + j*n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1040,6 +1098,13 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } +void llama_context::set_embeddings_pre_norm(bool value, bool masked) { + LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked); + + cparams.embeddings_pre_norm = value; + cparams.embeddings_pre_norm_masked = masked; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1072,6 +1137,19 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); + if (sampler && model.split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + static bool warned = false; + if (!warned) { + LLAMA_LOG_WARN("%s: backend sampling not supported with SPLIT_MODE_TENSOR; using CPU\n", __func__); + warned = true; + } + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); + return false; + } + const bool can_offload = sampler && sampler->iface->backend_init && @@ -1241,7 +1319,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } int llama_context::encode(const llama_batch & batch_inp) { - GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // so accept either present rather than requiring exactly one. + GGML_ASSERT(batch_inp.token || batch_inp.embd); if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -1312,8 +1392,9 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - auto * t_logits = res->get_logits(); - auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_logits = res->get_logits(); + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; // extract logits if (logits.data && t_logits) { @@ -1379,6 +1460,16 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); + } + // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -1531,7 +1622,9 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::mapget_ubatch(); @@ -1689,7 +1783,8 @@ int llama_context::decode(const llama_batch & batch_inp) { } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + + const auto * res = process_ubatch(ubatch, ctx_type_to_graph_type(cparams.ctx_type), mctx.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module @@ -1727,8 +1822,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = res->get_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1809,6 +1905,25 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. + { + const bool masked = cparams.embeddings_pre_norm_masked; + const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens; + const int64_t offset = masked ? n_outputs_prev : n_tokens_prev; + + if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd; + + GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float)); + } + } + // Copy backend sampling output if this ubatch produced any sampling tensors. if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); @@ -1823,6 +1938,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } n_outputs_prev += n_outputs; + n_tokens_prev += ubatch.n_tokens; } while (mctx->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith @@ -1893,10 +2009,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); - bool has_logits = true; - bool has_embd = cparams.embeddings; + bool has_logits = true; + bool has_embd = cparams.embeddings; + bool has_embd_pre_norm = cparams.embeddings_pre_norm; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -1908,8 +2026,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits.size = has_logits ? n_vocab*n_outputs_max : 0; - embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + + if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm row exists for every token in the batch, not just + // those flagged via batch.logits[i] -> size by token count instead. + embd_pre_norm.size = (size_t) n_embd * n_batch; + } // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); @@ -1925,8 +2050,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1942,6 +2067,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { buf_output = nullptr; logits.data = nullptr; embd.data = nullptr; + embd_pre_norm.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1970,6 +2096,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; offset += embd.size * sizeof(float); + embd_pre_norm = has_embd_pre_norm ? buffer_view{(float *) (base + offset), embd_pre_norm.size} : buffer_view{nullptr, 0}; + offset += embd_pre_norm.size * sizeof(float); + if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -2034,6 +2163,12 @@ void llama_context::output_reorder() { } } + if (embd_pre_norm.size > 0) { + for (uint64_t k = 0; k < n_embd; k++) { + std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]); + } + } + if (!sampling.samplers.empty()) { assert(sampling.logits.size > 0); assert(sampling.probs.size > 0); @@ -2121,7 +2256,7 @@ ggml_cgraph * llama_context::graph_reserve( auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3100,7 +3235,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3201,8 +3336,10 @@ llama_context_params llama_context_default_params() { /*.n_batch =*/ 2048, /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, + /*.n_rs_seq =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, @@ -3306,6 +3443,13 @@ llama_context * llama_init_from_model( model->hparams.pooling_type, params.pooling_type); } + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + model->hparams.nextn_predict_layers == 0) { + LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__); + return nullptr; + } + + try { auto * ctx = new llama_context(*model, params); return ctx; @@ -3347,6 +3491,10 @@ uint32_t llama_n_seq_max(const llama_context * ctx) { return ctx->n_seq_max(); } +uint32_t llama_n_rs_seq(const llama_context * ctx) { + return ctx->get_cparams().n_rs_seq; +} + const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } @@ -3436,6 +3584,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) { + ctx->set_embeddings_pre_norm(value, masked); +} + +float * llama_get_embeddings_pre_norm(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm(); +} + +float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm_ith(i); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 92d1b0cf9..d03f681d4 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -84,6 +84,9 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + float * get_embeddings_pre_norm(); + float * get_embeddings_pre_norm_ith(int32_t i); + llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -107,6 +110,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); + void set_embeddings_pre_norm(bool value, bool masked); void set_causal_attn(bool value); void set_warmup(bool value); @@ -278,6 +282,11 @@ private: // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE buffer_view embd = {nullptr, 0}; + // hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd]) + // populated only when cparams.embeddings_pre_norm is enabled and the model graph + // sets llm_graph_result::t_h_pre_norm + buffer_view embd_pre_norm = {nullptr, 0}; + struct sampling_info { // !samplers.empty() to check if any samplers are active std::map samplers; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 9d3594741..20ec59fe3 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -12,6 +12,7 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -27,6 +28,8 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; + bool embeddings_pre_norm; // also extract the hidden state before the final output norm + bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 bool causal_attn; bool offload_kqv; bool flash_attn; @@ -40,6 +43,7 @@ struct llama_cparams { bool kv_unified; bool pipeline_parallel; + enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index 8ce29d217..edfa71c20 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -88,3 +88,19 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); + +// +// pre-norm embeddings (hidden state before the final output norm) +// + +// Set whether the context outputs pre-norm embeddings or not +// If masked == true, output the embeddings only for the tokens with batch.logits != 0 +// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits +LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); + +// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); +LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index fe155c92d..fc027de8b 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -500,15 +500,21 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { } void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { - mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); - mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); + // base tensors may not be allocated if there are no non-SWA attention layers + if (self_k_idxs && self_k_idxs->buffer) { + mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } - mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); - mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + // swa tensors may not be allocated if there are no SWA attention layers + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); + mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + } if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); @@ -534,14 +540,21 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { bool res = true; - res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; - //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + // base tensors may not be allocated if there are no non-SWA attention layers + if (self_k_idxs && self_k_idxs->buffer) { + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; - //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + } - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + // swa tensors may not be allocated if there are no SWA attention layers + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + } return res; } @@ -848,6 +861,9 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } + if (t_h_pre_norm != nullptr) { + ggml_set_output(t_h_pre_norm); + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); @@ -2528,7 +2544,8 @@ ggml_tensor * llm_graph_context::build_rs( int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows) const { - ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size); + GGML_UNUSED(rs_size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]); // Clear a single state which will then be copied to the other cleared states. // Note that this is a no-op when the view is zero-sized. diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 5cb1756c6..bf6778237 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -32,6 +32,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DECODER_MTP, }; enum llm_ffn_op_type { @@ -580,7 +581,8 @@ struct llm_graph_params { ubatch.n_seqs_unq == other.ubatch.n_seqs_unq && ( (!ubatch.token && !other.ubatch.token) || - (!ubatch.embd && !other.ubatch.embd) + (!ubatch.embd && !other.ubatch.embd) || + (ubatch.token && other.ubatch.token && ubatch.embd && other.ubatch.embd) ); // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same @@ -644,6 +646,7 @@ public: ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -672,6 +675,7 @@ public: ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm std::map t_sampled_logits; std::map t_candidates; diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 002d15d41..2239309c8 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -229,6 +229,12 @@ uint32_t llama_hparams::n_embd_head_v_mla() const { } bool llama_hparams::has_kv(uint32_t il) const { + if (kv_only_nextn) { + // MTP head: only the trailing nextn_predict_layers blocks own a KV cache; + // the leading trunk blocks are not executed in this graph. + return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers); + } + if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { return true; diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 0160a89ca..e2d051edc 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -92,6 +92,8 @@ struct llama_hparams { uint32_t moe_latent_size = 0; uint32_t nextn_predict_layers = 0; + bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches) + float f_norm_eps; float f_norm_rms_eps; float f_norm_group_eps; diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp index 10e6b4597..72f5c2fea 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.cpp +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr @@ -73,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) - const bool unified = (mem_attn->get_base()->get_n_stream() == 1); - ubatch = balloc.split_equal(n_ubatch, !unified); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_base()->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.h b/examples/talk-llama/llama-memory-hybrid-iswa.h index 807c8aac9..c9d3f9f57 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.h +++ b/examples/talk-llama/llama-memory-hybrid-iswa.h @@ -34,6 +34,7 @@ public: uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index 4ce1af592..33b3b395e 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid::llama_memory_hybrid( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr @@ -73,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) - const bool unified = (mem_attn->get_n_stream() == 1); - ubatch = balloc.split_equal(n_ubatch, !unified); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid.h b/examples/talk-llama/llama-memory-hybrid.h index 558cafdf9..484eafb74 100644 --- a/examples/talk-llama/llama-memory-hybrid.h +++ b/examples/talk-llama/llama-memory-hybrid.h @@ -34,6 +34,7 @@ public: uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index c07f1d969..ec5dc5835 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -24,6 +24,7 @@ llama_memory_recurrent::llama_memory_recurrent( bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; @@ -31,6 +32,9 @@ llama_memory_recurrent::llama_memory_recurrent( size = mem_size; used = 0; + this->n_rs_seq = n_rs_seq; + rs_idx.assign(n_seq_max, 0); + cells.clear(); cells.resize(mem_size); @@ -92,8 +96,9 @@ llama_memory_recurrent::llama_memory_recurrent( throw std::runtime_error("failed to create ggml context for rs cache"); } - ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size); - ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size); + const uint32_t n_rows = mem_size * (1 + n_rs_seq); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows); + ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); r_l[i] = r; @@ -115,8 +120,8 @@ llama_memory_recurrent::llama_memory_recurrent( const size_t memory_size_r = size_r_bytes(); const size_t memory_size_s = size_s_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, - (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs %2u rs_seq), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, n_rs_seq, ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f)); } @@ -138,10 +143,11 @@ void llama_memory_recurrent::clear(bool data) { ggml_backend_buffer_clear(buf.get(), 0); } } + + std::fill(rs_idx.begin(), rs_idx.end(), 0); } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -152,6 +158,15 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } + const bool rm_all = p0 == 0 && p1 == std::numeric_limits::max(); + if (rm_all) { + if (seq_id >= 0) { + set_rs_idx(seq_id, 0); + } else { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } + } + // models like Mamba or RWKV can't have a state partially erased at the end // of the sequence because their state isn't preserved for previous tokens if (seq_id >= (int64_t) size) { @@ -161,10 +176,16 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { - const auto & cell = cells[tail_id]; - // partial intersection is invalid if it includes the final pos + auto & cell = cells[tail_id]; + + // partial rollback via per-token snapshot index (bounded by n_rs_seq) if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); + const llama_pos rollback = cell.pos - (p0 - 1); + if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) { + set_rs_idx(seq_id, (uint32_t) rollback); + cell.pos = p0 - 1; + return true; + } return false; } // invalidate tails which will be cleared @@ -368,6 +389,13 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) { + if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) { + return; + } + rs_idx[seq_id] = (idx > n_rs_seq) ? n_rs_seq : idx; +} + std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; for (const auto & [_, buf] : ctxs_bufs) { @@ -388,9 +416,15 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + if (n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } } if (ubatch.n_tokens == 0) { @@ -703,6 +737,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq GGML_UNUSED(flags); std::vector> cell_ranges; // ranges, from inclusive, to exclusive + std::vector> cell_ranges_data; // logical source row ranges uint32_t cell_count = 0; // Count the number of cells with the specified seq_id @@ -712,6 +747,35 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq const auto & cell = cells[i]; if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { ++cell_count; + uint32_t rs_idx_cur = 0; + + if (n_rs_seq != 0) { + if (seq_id != -1) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < rs_idx.size()); + rs_idx_cur = rs_idx[seq_id]; + } else { + bool has_rs_idx = false; + for (const llama_seq_id cell_seq_id : cell.seq_id) { + GGML_ASSERT(cell_seq_id >= 0 && (size_t) cell_seq_id < rs_idx.size()); + + const uint32_t seq_rs_idx = rs_idx[cell_seq_id]; + if (!has_rs_idx) { + rs_idx_cur = seq_rs_idx; + has_rs_idx = true; + } else if (rs_idx_cur != seq_rs_idx) { + GGML_ABORT("cannot write shared recurrent state with different rollback indices"); + } + } + } + } + + const uint32_t cell_id = rs_idx_cur * size + (cell.src >= 0 ? cell.src : (int32_t) i); + if (cell_ranges_data.empty() || cell_ranges_data.back().second != cell_id) { + cell_ranges_data.emplace_back(cell_id, cell_id + 1); + } else { + cell_ranges_data.back().second++; + } + if (cell_range_begin == size) { cell_range_begin = i; } @@ -726,7 +790,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq cell_ranges.emplace_back(cell_range_begin, size); } - if (flags % LLAMA_STATE_SEQ_FLAGS_ON_DEVICE && cell_ranges.size() > 1) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) { GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n"); } @@ -737,10 +801,16 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq } GGML_ASSERT(cell_count == cell_count_check); + cell_count_check = 0; + for (const auto & range : cell_ranges_data) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + io.write(&cell_count, sizeof(cell_count)); state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); + state_write_data(io, cell_ranges_data); } void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { @@ -762,6 +832,14 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i } throw std::runtime_error("failed to restore kv cache"); } + + if (n_rs_seq != 0) { + if (seq_id == -1) { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } else { + set_rs_idx(seq_id, 0); + } + } } void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { @@ -804,7 +882,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Write each range of cells of r_size_row length + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -825,7 +904,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Write each range of S tensor rows + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * s_size_row; @@ -852,9 +932,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // Write GQA embedding size io.write(&n_embd_s, sizeof(n_embd_s)); - // For each row, we get the element values of each cell + // For each row, we get the element values of each logical cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Write each range of cells of s_size_el length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; @@ -1163,5 +1242,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + mem->head].src0; + const uint32_t cell_idx = i + mem->head; + const int32_t src0 = mem->cells[cell_idx].src0; + + if (mem->n_rs_seq == 0) { + return src0; + } + + uint32_t idx = 0; + if (!mem->cells[cell_idx].seq_id.empty()) { + const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin(); + if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) { + idx = mem->rs_idx[seq]; + // reset rollback idx + mem->rs_idx[seq] = 0; + } + } + return (int32_t)(idx * mem->size) + src0; } diff --git a/examples/talk-llama/llama-memory-recurrent.h b/examples/talk-llama/llama-memory-recurrent.h index 47f01d739..b13b7b748 100644 --- a/examples/talk-llama/llama-memory-recurrent.h +++ b/examples/talk-llama/llama-memory-recurrent.h @@ -23,6 +23,7 @@ public: bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter); ~llama_memory_recurrent() = default; @@ -69,6 +70,14 @@ public: uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) + // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups + uint32_t n_rs_seq = 0; + + // per-seq rollback index + std::vector rs_idx; + + void set_rs_idx(llama_seq_id seq_id, uint32_t idx); + // computed before each graph build uint32_t n = 0; diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index 4a157b91f..4ad1612e4 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-graph.h" #include #include @@ -20,6 +21,8 @@ struct llama_memory_params { // use full-size SWA cache bool swa_full; + + llama_context_type ctx_type; }; enum llama_memory_status { diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 4e65a45a5..c645d0785 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -1312,9 +1312,16 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte return tensor; } -void llama_model_loader::done_getting_tensors() const { - if (n_created != n_tensors) { - throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); +void llama_model_loader::done_getting_tensors(bool partial) const { + if (n_created > n_tensors) { + throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created)); + } + if (n_created < n_tensors) { + if (!partial) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } + LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n", + __func__, n_created, n_tensors); } if (n_tensors_moved > 0) { LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n", diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index 7b3d6703c..c476026d3 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -184,7 +184,7 @@ struct llama_model_loader { struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); - void done_getting_tensors() const; + void done_getting_tensors(bool partial = false) const; void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr); diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index e83056557..528e4c9c0 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -393,6 +393,8 @@ void llama_model_saver::add_tensors_from_model() { add_tensor(model->output); add_tensor(model->output_b); add_tensor(model->output_norm_enc); + add_tensor(model->output_s); + add_tensor(model->output_in_s); add_tensor(model->cls); add_tensor(model->cls_b); add_tensor(model->cls_out); diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index ff30a2ae7..0d21b2a53 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1334,6 +1334,12 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_s && layer.ssm_beta) { layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.nextn.eh_proj_s && layer.nextn.eh_proj) { + layer.nextn.eh_proj_s = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.shared_head_head_s && layer.nextn.shared_head_head) { + layer.nextn.shared_head_head_s = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } // input scales if (!layer.wq_in_s && layer.wq) { @@ -1393,11 +1399,30 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_in_s && layer.ssm_beta) { layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.nextn.eh_proj_in_s && layer.nextn.eh_proj) { + layer.nextn.eh_proj_in_s = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.shared_head_head_in_s && layer.nextn.shared_head_head) { + layer.nextn.shared_head_head_in_s = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + } + // output scales + if (output && output->type == GGML_TYPE_NVFP4) { + // weight scale + if (!output_s) { + output_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "scale"), {1}, TENSOR_NOT_REQUIRED); + } + // input scale + if (!output_in_s) { + output_in_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "input_scale"), {1}, TENSOR_NOT_REQUIRED); + } } } - ml.done_getting_tensors(); + GGML_ASSERT(!(output && tok_embd && + strcmp(output->name, tok_embd->name) == 0 && + output->type == GGML_TYPE_NVFP4)); // populate tensors_by_name for (auto & [_, ctx_ptr] : ml.ctx_map) { for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { @@ -1934,6 +1959,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, // checks default: { + // The MTP head is dense-attention only on hybrid Qwen3.5/3.6, so use a plain + // attention KV cache for the MTP context instead of the hybrid wrapper. + const bool mtp_on_hybrid_qwen35 = + params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE); + if (llm_arch_is_recurrent(arch)) { res = new llama_memory_recurrent( *this, @@ -1942,8 +1973,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.offload_kqv, std::max((uint32_t) 1, cparams.n_seq_max), cparams.n_seq_max, + cparams.n_rs_seq, nullptr); - } else if (llm_arch_is_hybrid(arch)) { + } else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) { // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -1958,6 +1990,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, filter_recr = [&](int32_t il) { return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; }; + } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter_attn = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && !hparams.is_recurrent(il); + }; + filter_recr = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && hparams.is_recurrent(il); + }; } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { @@ -1975,6 +2015,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_s */ GGML_TYPE_F32, /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -1993,6 +2034,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_v */ GGML_TYPE_F32, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -2000,6 +2042,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_filter_cb filter = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](int32_t il) { @@ -2011,6 +2054,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } + if (mtp_on_hybrid_qwen35) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; }; + } + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); @@ -2026,7 +2074,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, cparams.n_ubatch, 1, - nullptr, + filter, reuse); } else { GGML_ASSERT(!hparams.is_swa_any()); @@ -2043,7 +2091,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, - nullptr, + filter, nullptr); } } @@ -2146,6 +2194,7 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } + uint32_t llama_model_n_cls_out(const struct llama_model * model) { return model->hparams.n_cls_out; } diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index d63c68918..398a0aa72 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -202,12 +202,16 @@ struct llama_layer_shortconv { }; struct llama_layer_nextn { - struct ggml_tensor * eh_proj = nullptr; - struct ggml_tensor * embed_tokens = nullptr; - struct ggml_tensor * enorm = nullptr; - struct ggml_tensor * hnorm = nullptr; - struct ggml_tensor * shared_head_head = nullptr; - struct ggml_tensor * shared_head_norm = nullptr; + struct ggml_tensor * eh_proj = nullptr; + struct ggml_tensor * eh_proj_s = nullptr; + struct ggml_tensor * eh_proj_in_s = nullptr; + struct ggml_tensor * embed_tokens = nullptr; + struct ggml_tensor * enorm = nullptr; + struct ggml_tensor * hnorm = nullptr; + struct ggml_tensor * shared_head_head = nullptr; + struct ggml_tensor * shared_head_head_s = nullptr; + struct ggml_tensor * shared_head_head_in_s = nullptr; + struct ggml_tensor * shared_head_norm = nullptr; }; struct llama_layer { @@ -533,6 +537,11 @@ struct llama_model { struct ggml_tensor * output_b = nullptr; struct ggml_tensor * output_norm_enc = nullptr; + + // NVFP4 per-tensor scale2, input_scale for LM head + struct ggml_tensor * output_s = nullptr; + struct ggml_tensor * output_in_s = nullptr; + // classifier struct ggml_tensor * cls = nullptr; struct ggml_tensor * cls_b = nullptr; diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index f43cf546c..a5cf148b2 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -530,6 +530,8 @@ struct llm_tokenizer_bpe : llm_tokenizer { struct llm_tokenizer_bpe_session { llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + virtual ~llm_tokenizer_bpe_session() = default; + static void append(const llama_token token_id, std::vector & output) { output.push_back(token_id); } @@ -567,7 +569,7 @@ struct llm_tokenizer_bpe_session { } } - void tokenize(const std::string & text, std::vector & output) { + virtual void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode); @@ -1579,6 +1581,88 @@ private: const llm_tokenizer_plamo2 & tokenizer; }; +// reserved suffix (U+E000) that keeps DNA k-mers distinct from identical +// base-vocab BPE tokens (e.g. CCCCCC) in token_to_id; erased from id_to_token +// text at load +static const std::string dna_kmer_marker = "\xee\x80\x80"; + +struct llm_tokenizer_hybriddna_session : llm_tokenizer_bpe_session { + llm_tokenizer_hybriddna_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {} + + void tokenize(const std::string & text, std::vector & output) override { + static const std::string open_tag = ""; + static const std::string close_tag = ""; + + const auto dna_begin_id = vocab.text_to_token(open_tag); + const auto dna_end_id = vocab.text_to_token(close_tag); + const auto dna_oov_id = vocab.text_to_token(""); + + // Fall back to plain BPE if the DNA pieces aren't in the vocab. + if (dna_begin_id == LLAMA_TOKEN_NULL || dna_end_id == LLAMA_TOKEN_NULL || dna_oov_id == LLAMA_TOKEN_NULL) { + llm_tokenizer_bpe_session::tokenize(text, output); + return; + } + + const size_t k = 6; + size_t pos = 0; + + while (pos < text.size()) { + const size_t start = text.find(open_tag, pos); + if (start == std::string::npos) { + if (pos < text.size()) { + llm_tokenizer_bpe_session::tokenize(text.substr(pos), output); + } + break; + } + if (start > pos) { + llm_tokenizer_bpe_session::tokenize(text.substr(pos, start - pos), output); + } + output.push_back(dna_begin_id); + + const size_t content_start = start + open_tag.size(); + const size_t end = text.find(close_tag, content_start); + const size_t content_end = (end == std::string::npos) ? text.size() : end; + + emit_dna_kmers(text.substr(content_start, content_end - content_start), k, dna_oov_id, output); + + if (end == std::string::npos) { + break; + } + output.push_back(dna_end_id); + pos = end + close_tag.size(); + } + } + +private: + void emit_dna_kmers(const std::string & raw, size_t k, llama_token oov_id, std::vector & output) { + std::string seq = raw; + for (char & c : seq) { + if (c >= 'a' && c <= 'z') { + c = char(c - 32); + } + } + + // k-mers carry the reserved marker suffix; a non-ACGT k-mer simply + // isn't in the vocab and falls back to + auto kmer_token = [&](const std::string & kmer) { + const auto tok = vocab.text_to_token(kmer + dna_kmer_marker); + return tok != LLAMA_TOKEN_NULL ? tok : oov_id; + }; + + size_t i = 0; + for (; i + k <= seq.size(); i += k) { + output.push_back(kmer_token(seq.substr(i, k))); + } + if (i < seq.size()) { + std::string kmer = seq.substr(i); + kmer.append(k - kmer.size(), 'A'); + output.push_back(kmer_token(kmer)); + } + } + + const llama_vocab & vocab; +}; + // // impl // @@ -1808,7 +1892,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_mask_id = 103; add_sep = true; - } else if (tokenizer_model == "gpt2") { + } else if (tokenizer_model == "gpt2" || tokenizer_model == "hybriddna") { type = LLAMA_VOCAB_TYPE_BPE; // read bpe merges and populate bpe ranks @@ -2266,6 +2350,23 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } GGML_ASSERT(id_to_token.size() == token_to_id.size()); + // hybriddna: the marker suffix kept k-mer ids distinct in token_to_id; erase + // it from id_to_token so the k-mers detokenize to the bare DNA sequence. The + // k-mers are the block right after , so only scan from there. + if (tokenizer_model == "hybriddna") { + const auto idx = token_to_id.find(""); + if (idx != token_to_id.end()) { + auto it = id_to_token.begin() + idx->second + 1; + for (; it != id_to_token.end(); ++it) { + std::string & text = it->text; + if (text.size() > dna_kmer_marker.size() + && text.compare(text.size() - dna_kmer_marker.size(), dna_kmer_marker.size(), dna_kmer_marker) == 0) { + text.erase(text.size() - dna_kmer_marker.size()); + } + } + } + } + init_tokenizer(type); // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' @@ -3144,11 +3245,19 @@ std::vector llama_vocab::impl::tokenize( } break; case LLAMA_VOCAB_TYPE_BPE: { - llm_tokenizer_bpe_session session(vocab, *static_cast(tokenizer.get())); // it calls some other methods that are not exist in llm_tokenizer, // here just cast it to bpe tokenizer object + const llm_tokenizer_bpe * tok_bpe = static_cast(tokenizer.get()); + + std::unique_ptr session; + if (vocab.get_tokenizer_model() == "hybriddna") { + session = std::make_unique(vocab, *tok_bpe); + } else { + session = std::make_unique(vocab, *tok_bpe); + } + if (add_special) { - session.append_bos(output); + session->append_bos(output); } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -3161,15 +3270,15 @@ std::vector llama_vocab::impl::tokenize( #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); #endif - session.tokenize(text, output); + session->tokenize(text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - session.append(fragment.token, output); + session->append(fragment.token, output); } } if (add_special) { - session.append_eos(output); - session.check_double_bos_eos(output); + session->append_eos(output); + session->check_double_bos_eos(output); } } break; case LLAMA_VOCAB_TYPE_WPM: diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 308e8ba9d..e8374c53b 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -198,6 +198,11 @@ extern "C" { LLAMA_SPLIT_MODE_TENSOR = 3, }; + enum llama_context_type { + LLAMA_CONTEXT_TYPE_DEFAULT = 0, + LLAMA_CONTEXT_TYPE_MTP = 1, + }; + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id @@ -333,9 +338,11 @@ extern "C" { uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + enum llama_context_type ctx_type; // set the context type (e.g. MTP) enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings @@ -530,6 +537,7 @@ extern "C" { LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); @@ -866,7 +874,8 @@ extern "C" { // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) #define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 -// keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load) +// Keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load). +// Getting the state for a seq_id with this flag invalidates all prior states gotten for that seq_id with this flag. #define LLAMA_STATE_SEQ_FLAGS_ON_DEVICE 2 typedef uint32_t llama_state_seq_flags; diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index 602e3176a..a7c77ee5d 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -277,7 +277,7 @@ llama_model_afmoe::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index 136ff7029..bec713652 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -160,7 +160,7 @@ llama_model_apertus::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index 70e86d411..d086c4717 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -148,7 +148,7 @@ llama_model_arcee::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index d8653a446..27deadffe 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -171,7 +171,7 @@ llama_model_arctic::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arwkv7.cpp b/examples/talk-llama/models/arwkv7.cpp index 79aa8c908..9bd04127b 100644 --- a/examples/talk-llama/models/arwkv7.cpp +++ b/examples/talk-llama/models/arwkv7.cpp @@ -193,7 +193,7 @@ llama_model_arwkv7::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index 4e55290e4..4d26081cd 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -146,7 +146,7 @@ llama_model_baichuan::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index 030dd4f42..fe1ae1086 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -171,7 +171,7 @@ llama_model_bailingmoe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index e7fe3d5b4..2f0d44a62 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -210,7 +210,7 @@ llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index b600fb0c9..30b0f3d07 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -142,7 +142,7 @@ llama_model_bloom::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 8510b9e29..4bceaefd6 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -181,7 +181,7 @@ llama_model_chameleon::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output_with_img_logits", -1); // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index e898eff79..6766fa71c 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -151,7 +151,7 @@ llama_model_chatglm::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index e9e85d967..274dd3342 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -143,7 +143,7 @@ llama_model_codeshell::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 79236121b..2e231bb3f 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -150,7 +150,7 @@ llama_model_cogvlm::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; ggml_build_forward_expand(gf, cur); diff --git a/examples/talk-llama/models/cohere2.cpp b/examples/talk-llama/models/cohere2.cpp index 12edbae10..a514cf88f 100644 --- a/examples/talk-llama/models/cohere2.cpp +++ b/examples/talk-llama/models/cohere2.cpp @@ -146,7 +146,7 @@ llama_model_cohere2::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index decb89f54..adf7fcaa2 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -131,7 +131,7 @@ llama_model_command_r::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index bce6b04bc..af71c7753 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -145,7 +145,7 @@ llama_model_dbrx::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index 9f1a959c3..567e35352 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -181,7 +181,7 @@ llama_model_deci::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp index c79460596..f52ec9518 100644 --- a/examples/talk-llama/models/deepseek.cpp +++ b/examples/talk-llama/models/deepseek.cpp @@ -185,7 +185,7 @@ llama_model_deepseek::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/delta-net-base.cpp b/examples/talk-llama/models/delta-net-base.cpp index 6bc989c95..4f4c7cac7 100644 --- a/examples/talk-llama/models/delta-net-base.cpp +++ b/examples/talk-llama/models/delta-net-base.cpp @@ -1,6 +1,7 @@ #include "models.h" #include "llama-impl.h" +#include "llama-memory-recurrent.h" // utility to get one slice from the third dimension // input dim: [x, y, c, b] @@ -397,7 +398,9 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + // K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net. + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs); + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d); if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -443,3 +446,162 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_chunking(q, k, v, g, b, s, il); } + +ggml_tensor * llm_build_delta_net_base::build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il) { + const auto * mctx_cur = inp->mctx; + + const auto kv_head = mctx_cur->get_head(); + const auto mem_size = mctx_cur->get_size(); + + const int64_t n_seqs = ubatch.n_seqs; + + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + const int64_t row_count = (conv_kernel_size - 1) * conv_channels; + + const size_t row_size = ggml_row_size(conv_states_all->type, row_count); + + if (cparams.n_rs_seq == 0) { + const int64_t s_idx = conv_input->ne[0] - conv_states->ne[0]; + const int64_t s_slot = 0; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + cb(conv_state_last, "conv_state_last", il); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, conv_states_all, + row_count, n_seqs, conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + cb(conv_state_update, "conv_state_update", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); + } else { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: this logic incorrectly assumes that the last (n_rs_seq + 1) tokens of a sequence in a batch are + // inside the same ubatch. currently with `split_equal()` this is not correct + + const int64_t K = (int64_t) cparams.n_rs_seq + 1; + + for (int64_t t = 1; t <= K; ++t) { + const int64_t s_idx = std::max(0, conv_input->ne[0] - conv_states->ne[0] - K + t); + const int64_t s_slot = K - t; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, + conv_states_all, row_count, n_seqs, + conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); + } + } + + return conv_input; +} + +ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const auto * mctx_cur = inp->mctx; + const auto kv_head = mctx_cur->get_head(); + const uint32_t mem_size = mctx_cur->get_size(); + + const int64_t S_v = s->ne[0]; + const int64_t H_v = s->ne[2]; + const int64_t n_seqs = s->ne[3]; + const int64_t n_seq_tokens = q->ne[2]; + + const bool keep = cparams.n_rs_seq > 0; + + if (!keep) { + auto attn_out = build_delta_net(q, k, v, g, b, s, il); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + return output; + } + + const int64_t D = S_v * S_v * H_v; + const int64_t K = cparams.n_rs_seq + 1; + + // TODO: remove pad + simplify + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); + ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0); + + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad); + if (n_seq_tokens > 1) { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); + } else { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_AR, il); + } + + const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs; + + ggml_tensor * output = ggml_view_4d(ctx0, gdn_out, + S_v, H_v, n_seq_tokens, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * H_v), + ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens), + 0); + cb(output, "attn_output", il); + + const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); + for (int64_t k_i = 0; k_i < K; ++k_i) { + const uint32_t cache_slot = (uint32_t) (K - 1 - k_i); + ggml_tensor * src = ggml_view_4d(ctx0, gdn_out, + S_v, S_v, H_v, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * S_v), + ggml_row_size(gdn_out->type, S_v * S_v * H_v), + ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap)); + + ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all, + hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + ((size_t) cache_slot * mem_size + kv_head) * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + } + + return output; +} diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 93cbcf9d9..435d27281 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -183,7 +183,7 @@ llama_model_dots1::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 60a3f0ec2..12ac6f1ce 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -128,7 +128,7 @@ llama_model_dream::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp index 2bd01a2c5..8d9ff1386 100644 --- a/examples/talk-llama/models/ernie4-5-moe.cpp +++ b/examples/talk-llama/models/ernie4-5-moe.cpp @@ -124,7 +124,7 @@ llama_model_ernie4_5_moe::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index fa989fe92..9b39c605e 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -155,7 +155,7 @@ llama_model_ernie4_5::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index 54bb3ca86..76d91982f 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -237,7 +237,7 @@ llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index 75d5f6063..c7e9960d7 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -127,7 +127,7 @@ llama_model_exaone::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 5506e7642..499e22dde 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -163,7 +163,7 @@ llama_model_exaone4::graph::graph(const llama_model & model, const llm_gra res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index d353befdb..94b65a3c7 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -200,7 +200,7 @@ llama_model_falcon_h1::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index 75f2cfef5..ad546ef2d 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -152,7 +152,7 @@ llama_model_falcon::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 067316700..1519682fd 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -130,7 +130,7 @@ llama_model_gemma::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gemma2.cpp b/examples/talk-llama/models/gemma2.cpp index 6255bf740..ae3f9ffb5 100644 --- a/examples/talk-llama/models/gemma2.cpp +++ b/examples/talk-llama/models/gemma2.cpp @@ -163,7 +163,7 @@ llama_model_gemma2::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // final logit soft-capping cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index ee510fe38..63a2b380e 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -207,7 +207,7 @@ llama_model_gemma3::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (hparams.f_final_logit_softcapping) { cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/gemma3n.cpp b/examples/talk-llama/models/gemma3n.cpp index 881499b0c..6ec3a0060 100644 --- a/examples/talk-llama/models/gemma3n.cpp +++ b/examples/talk-llama/models/gemma3n.cpp @@ -296,7 +296,7 @@ llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); { // final logit soft-capping diff --git a/examples/talk-llama/models/gemma4.cpp b/examples/talk-llama/models/gemma4.cpp index f45ae4cad..4f9d8b18b 100644 --- a/examples/talk-llama/models/gemma4.cpp +++ b/examples/talk-llama/models/gemma4.cpp @@ -380,7 +380,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (hparams.f_final_logit_softcapping) { cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 45886b51a..27654b8cb 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -275,7 +275,7 @@ llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index d6ef76e26..7c242fed2 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -185,7 +185,7 @@ llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // Output projection - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index ba49c31b5..e2dcc8b15 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -138,7 +138,7 @@ llama_model_gpt2::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 33ebe2d88..443e35add 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -209,7 +209,7 @@ llama_model_gptneox::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index 12e4790ae..27f6706ea 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -186,7 +186,7 @@ llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_gr res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // For Granite architectures - scale logits if (hparams.f_logit_scale) { diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index 5e7c7b681..cda4aa231 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -145,7 +145,7 @@ llama_model_granite::graph::graph( res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // For Granite architectures - scale logits cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 0bc49d002..7c46ec1c0 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -206,7 +206,7 @@ llama_model_grok::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index feef81516..1cab75adc 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -184,7 +184,7 @@ llama_model_grovemoe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index 44af42412..deb3c9671 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -179,7 +179,7 @@ llama_model_hunyuan_moe::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/hunyuan-vl.cpp b/examples/talk-llama/models/hunyuan-vl.cpp index 5fb9154be..da9bb74de 100644 --- a/examples/talk-llama/models/hunyuan-vl.cpp +++ b/examples/talk-llama/models/hunyuan-vl.cpp @@ -181,7 +181,7 @@ llama_model_hunyuan_vl::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index f0c5580a6..f9ee37a24 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -129,7 +129,7 @@ llama_model_internlm2::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index a6451dca0..2ba162605 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -123,7 +123,7 @@ llama_model_jais::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index ad59b953e..896613144 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -152,7 +152,7 @@ llama_model_jais2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // Output projection - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index e1b8d137e..84ea63c31 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -189,7 +189,7 @@ llama_model_jamba::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index df6a80287..29081344b 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -262,7 +262,7 @@ llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index b60f67f6c..9722dde9f 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -153,7 +153,7 @@ llama_model_llada_moe::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index fa21c5fe3..58b2c466e 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -147,7 +147,7 @@ llama_model_llada::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index 8ddb59368..cef66d054 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -235,7 +235,7 @@ llama_model_llama::graph::graph(const llama_model & model, const llm_grap if constexpr (!embed) { // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llama4.cpp b/examples/talk-llama/models/llama4.cpp index 899611d53..0ff5376d5 100644 --- a/examples/talk-llama/models/llama4.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -260,7 +260,7 @@ llama_model_llama4::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index 3dbd82fd3..84cfe3990 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -141,7 +141,7 @@ llama_model_maincoder::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp index b7708d7fd..887a1fa50 100644 --- a/examples/talk-llama/models/mamba.cpp +++ b/examples/talk-llama/models/mamba.cpp @@ -128,7 +128,7 @@ llama_model_mamba::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mimo2.cpp b/examples/talk-llama/models/mimo2.cpp index 719966166..d0295ec11 100644 --- a/examples/talk-llama/models/mimo2.cpp +++ b/examples/talk-llama/models/mimo2.cpp @@ -231,7 +231,7 @@ llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index ff5eb6ffa..1ffc54fa7 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -251,7 +251,7 @@ llama_model_minicpm3::graph::graph(const llama_model & model, const llm_graph_pa cb(cur, "lmhead_scaling", -1); // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index 0dee89346..22e291d73 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -158,7 +158,7 @@ llama_model_minimax_m2::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 708da49af..4e6ebef82 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -222,7 +222,7 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 6d5f18a8e..7e551eb96 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); - // use the ggml_gated_delta_net fused operator + // use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs)) std::pair build_delta_net_fused( ggml_tensor * q, ggml_tensor * k, @@ -65,6 +65,29 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token) + // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs) + ggml_tensor * build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il); + + // run delta-net attention and write the new recurrent state(s) back to ssm_states_all + // s: (head_v_dim, head_v_dim, num_v_heads, n_seqs); returns output: (head_v_dim, num_v_heads, n_seq_tokens, n_seqs) + ggml_tensor * build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); }; struct llm_build_rwkv6_base : public llm_graph_context { @@ -1739,6 +1762,10 @@ struct llama_model_qwen35 : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; @@ -1781,6 +1808,10 @@ struct llama_model_qwen35moe : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index cfc60e8de..0229d20ed 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -161,7 +161,7 @@ llama_model_mpt::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index 865461f61..a82f9c170 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -174,7 +174,7 @@ llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 0c72ed297..5d4a3b5c6 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -140,7 +140,7 @@ llama_model_nemotron::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index 161035e72..cfcf17bcb 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -133,7 +133,7 @@ llama_model_olmo::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 9633f2699..7cc262f55 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -198,7 +198,7 @@ llama_model_olmo2::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index 4bb901305..7976ae44a 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -164,7 +164,7 @@ llama_model_olmoe::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/openai-moe.cpp b/examples/talk-llama/models/openai-moe.cpp index 13a590ce6..15b6c8c12 100644 --- a/examples/talk-llama/models/openai-moe.cpp +++ b/examples/talk-llama/models/openai-moe.cpp @@ -160,7 +160,7 @@ llama_model_openai_moe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index b4128e116..9f76350fd 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -162,7 +162,7 @@ llama_model_openelm::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index 7ace0a513..bcb4bbba4 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -132,7 +132,7 @@ llama_model_orion::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp index 1c0eadefa..d39220bd7 100644 --- a/examples/talk-llama/models/paddleocr.cpp +++ b/examples/talk-llama/models/paddleocr.cpp @@ -98,7 +98,7 @@ llama_model_paddleocr::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/pangu-embed.cpp b/examples/talk-llama/models/pangu-embed.cpp index 41b7e2ac2..7593f879b 100644 --- a/examples/talk-llama/models/pangu-embed.cpp +++ b/examples/talk-llama/models/pangu-embed.cpp @@ -148,7 +148,7 @@ llama_model_pangu_embed::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index a333602c7..8f3ed5f7b 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -130,7 +130,7 @@ llama_model_phi2::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output_no_bias", -1); cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index 0a65e91fe..f8a4a4d5a 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -179,7 +179,7 @@ llama_model_phi3::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cb(cur, "result_output_no_bias", -1); diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index 4c16c20a0..c7ed1211c 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -127,7 +127,7 @@ llama_model_plamo::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index 29c870260..b713889fe 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -185,7 +185,7 @@ llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); // Explicitly mark as output tensor to ensure proper backend assignment diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 849f1579e..29f3e803d 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -186,7 +186,7 @@ llama_model_plamo3::graph::graph(const llama_model & model, const llm_grap cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); res->t_logits = cur; ggml_build_forward_expand(gf, cur); diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index 57f599510..ce050919e 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -204,7 +204,7 @@ llama_model_plm::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index cdc076cdf..00467dbad 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -131,7 +131,7 @@ llama_model_qwen::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index 6320458a1..a5147460b 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -141,7 +141,7 @@ llama_model_qwen2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 7587c802c..7cb03859d 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -184,7 +184,7 @@ llama_model_qwen2moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp index 1a40fa89b..d79db682c 100644 --- a/examples/talk-llama/models/qwen2vl.cpp +++ b/examples/talk-llama/models/qwen2vl.cpp @@ -134,7 +134,7 @@ llama_model_qwen2vl::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index fa656c84e..41b97fed9 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -147,7 +147,7 @@ llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index f276be61b..04ecc18fc 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -12,16 +12,22 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; case 64: type = LLM_TYPE_27B; break; @@ -29,9 +35,14 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -43,50 +54,85 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + // MTP block looks like a full-attention Qwen3.5 decoder block. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -111,7 +157,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -128,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -160,6 +208,13 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -167,7 +222,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -297,8 +352,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -328,41 +381,14 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -413,7 +439,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -424,18 +450,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -471,3 +486,151 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series +llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // hparams.n_layer includes both main model layers and MTP layers. The MTP + // layer is stored immediately after the main layers in model.layers[]. + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat, layer.nextn.eh_proj_s); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, nullptr, layer.ffn_up_s, + layer.ffn_gate, nullptr, layer.ffn_gate_s, + layer.ffn_down, nullptr, layer.ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + ggml_tensor * head_s = layer.nextn.shared_head_head ? layer.nextn.shared_head_head_s : model.output_s; + GGML_ASSERT(head_w && "QWEN35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur, head_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index cf05dc9d6..dc24f6ed5 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -15,16 +15,22 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 40: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_122B_A10B; break; case 60: type = LLM_TYPE_397B_A17B; break; @@ -32,9 +38,14 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -46,60 +57,105 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, flags); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, flags); // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, flags); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -124,7 +180,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -141,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -173,6 +231,13 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -180,7 +245,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -310,8 +375,6 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -341,41 +404,14 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -426,7 +462,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -437,18 +473,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -525,3 +550,183 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE +llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto * inp_attn = build_attn_inp_kv(); + + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat, layer.nextn.eh_proj_s); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). + ggml_tensor * moe_out = + build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, layer.ffn_gate_up_exps, + layer.ffn_up_exps_s, + layer.ffn_gate_exps_s, + layer.ffn_down_exps_s); + cb(moe_out, "mtp_ffn_moe_out", il); + + if (layer.ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, + layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, + layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "mtp_ffn_shexp", il); + + ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); + + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "mtp_ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + ggml_tensor * head_s = layer.nextn.shared_head_head ? layer.nextn.shared_head_head_s : model.output_s; + GGML_ASSERT(head_w && "QWEN35MOE MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur, head_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index 4440b83aa..a4f8e1379 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -168,7 +168,7 @@ llama_model_qwen3moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index cb1b4814c..1d873427d 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -176,7 +176,7 @@ llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -378,8 +378,6 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -429,41 +427,14 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -540,18 +511,7 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index 7871f8f79..5defd8939 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -163,7 +163,7 @@ llama_model_qwen3vl::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3vlmoe.cpp b/examples/talk-llama/models/qwen3vlmoe.cpp index b99143c89..5b77df571 100644 --- a/examples/talk-llama/models/qwen3vlmoe.cpp +++ b/examples/talk-llama/models/qwen3vlmoe.cpp @@ -180,7 +180,7 @@ llama_model_qwen3vlmoe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index f14f10917..bf3949a90 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -150,7 +150,7 @@ llama_model_refact::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index 325ee73ba..ca8e00961 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -167,7 +167,7 @@ llama_model_rnd1::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index 2944711ac..ba2a9dfa0 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -176,7 +176,7 @@ llama_model_rwkv6::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv6qwen2.cpp b/examples/talk-llama/models/rwkv6qwen2.cpp index 6f7d1f572..566b8cdcb 100644 --- a/examples/talk-llama/models/rwkv6qwen2.cpp +++ b/examples/talk-llama/models/rwkv6qwen2.cpp @@ -158,7 +158,7 @@ llama_model_rwkv6qwen2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index b205e3935..7574b2526 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -202,7 +202,7 @@ llama_model_rwkv7::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index 83e114740..806cba574 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -141,7 +141,7 @@ llama_model_seed_oss::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index 3214e7cba..4231cccc6 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -178,7 +178,7 @@ llama_model_smallthinker::graph::graph(const llama_model & model, const ll res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index 7adaf34c5..90e7d473e 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -143,7 +143,7 @@ llama_model_smollm3::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index 8f613e559..4da7f7aef 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -163,7 +163,7 @@ llama_model_stablelm::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index 58cf0ac0e..e131af058 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -135,7 +135,7 @@ llama_model_starcoder::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index 45dae0602..9c207c028 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -148,7 +148,7 @@ llama_model_starcoder2::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/step35.cpp b/examples/talk-llama/models/step35.cpp index c4789752d..3b68e6870 100644 --- a/examples/talk-llama/models/step35.cpp +++ b/examples/talk-llama/models/step35.cpp @@ -261,7 +261,7 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/t5.cpp b/examples/talk-llama/models/t5.cpp index 27a0711ba..73e327414 100644 --- a/examples/talk-llama/models/t5.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -265,7 +265,7 @@ llama_model_t5::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/wavtokenizer-dec.cpp b/examples/talk-llama/models/wavtokenizer-dec.cpp index a873e5d2e..214fed99b 100644 --- a/examples/talk-llama/models/wavtokenizer-dec.cpp +++ b/examples/talk-llama/models/wavtokenizer-dec.cpp @@ -253,7 +253,7 @@ llama_model_wavtokenizer_dec::graph::graph(const llama_model & model, const llm_ LLM_NORM, -1); // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index e4d111e62..d6d1c7a2e 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -126,7 +126,7 @@ llama_model_xverse::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index dc13e53f0..b02ecdc93 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -605,6 +605,136 @@ static std::vector unicode_regex_split_custom_qwen2(const std::string & return bpe_offsets; } +// Qwen3.5 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +// Compared to Qwen2, letter-runs also consume Unicode combining marks (\p{M}): [\p{L}\p{M}]+ instead of \p{L}+ +static std::vector unicode_regex_split_custom_qwen35(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive + if (cpt == '\'' && pos+1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); + continue; + } + if (pos+2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; + } + } + } + + // regex: [^\r\n\p{L}\p{N}]?[\p{L}\p{M}]+ + if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) { + if (flags.is_letter || flags.is_accent_mark || _get_flags(pos + 1).is_accent_mark || _get_flags(pos+1).is_letter) { + pos++; + while (_get_flags(pos).is_letter || _get_flags(pos).is_accent_mark) { + pos++; + } + _add_token(pos); + continue; + } + } + + // regex: \p{N} + if (flags.is_number) { + pos++; + _add_token(pos); + continue; + } + + // regex: ?[^\s\p{L}\p{M}\p{N}]+[\r\n]* + auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); + if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_accent_mark | flags2.is_number) && flags.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_accent_mark | flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos+num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos+num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // regex: \s*[\r\n]+ + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); + } + } + + return bpe_offsets; +} + template static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) { using BidirIt = typename std::basic_string::const_iterator; @@ -929,6 +1059,9 @@ static std::vector unicode_regex_split_custom(const std::string & text, } else if ( regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { bpe_offsets = unicode_regex_split_custom_qwen2(text, offsets); + } else if ( + regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { + bpe_offsets = unicode_regex_split_custom_qwen35(text, offsets); } else if (regex_expr == "\\p{Han}+") { // K2's first pattern - handle all K2 patterns together bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);