#include "llama-kv-cache-dsa.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-model.h" #include #include // // llama_kv_cache_dsa // llama_kv_cache_dsa::llama_kv_cache_dsa( const llama_model & model, ggml_type type_k, ggml_type type_v, bool v_trans, bool offload, bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, const layer_filter_cb & filter, const layer_reuse_cb & reuse) : hparams_lid(model.hparams), n_stream(unified ? 1 : n_seq_max) { LLAMA_LOG_INFO("%s: creating main KV cache, size = %u cells\n", __func__, kv_size); kv_mla = std::make_unique( model, model.hparams, type_k, type_v, v_trans, offload, unified, kv_size, n_seq_max, n_pad, n_swa, swa_type, nullptr, filter, reuse, nullptr); // we use llama_kv_cache for caching indexer keys // by hand-tweaking some hparams we fool it to create // indexer key cache tensors with correct dimensions // https://github.com/ggml-org/llama.cpp/pull/21149#discussion_r3015940823 // DSA lightning indexer uses MQA with single key head std::fill(hparams_lid.n_head_kv_arr.begin(), hparams_lid.n_head_kv_arr.end(), 1); hparams_lid.n_embd_head_k_full = model.hparams.indexer_head_size; hparams_lid.rope_type = LLAMA_ROPE_TYPE_NEOX; LLAMA_LOG_INFO("%s: creating indexer KV cache, size = %u cells\n", __func__, kv_size); kv_lid = std::make_unique( model, hparams_lid, type_k, type_v, v_trans, offload, unified, kv_size, n_seq_max, n_pad, n_swa, swa_type, nullptr, filter, reuse, nullptr); } void llama_kv_cache_dsa::clear(bool data) { kv_mla->clear(data); kv_lid->clear(data); } bool llama_kv_cache_dsa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { bool res = true; res = res & kv_mla->seq_rm(seq_id, p0, p1); res = res & kv_lid->seq_rm(seq_id, p0, p1); return res; } void llama_kv_cache_dsa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { kv_mla->seq_cp(seq_id_src, seq_id_dst, p0, p1); kv_lid->seq_cp(seq_id_src, seq_id_dst, p0, p1); } void llama_kv_cache_dsa::seq_keep(llama_seq_id seq_id) { kv_mla->seq_keep(seq_id); kv_lid->seq_keep(seq_id); } void llama_kv_cache_dsa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { kv_mla->seq_add(seq_id, p0, p1, shift); kv_lid->seq_add(seq_id, p0, p1, shift); } void llama_kv_cache_dsa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { kv_mla->seq_div(seq_id, p0, p1, d); kv_lid->seq_div(seq_id, p0, p1, d); } llama_pos llama_kv_cache_dsa::seq_pos_min(llama_seq_id seq_id) const { return kv_mla->seq_pos_min(seq_id); } llama_pos llama_kv_cache_dsa::seq_pos_max(llama_seq_id seq_id) const { return kv_mla->seq_pos_max(seq_id); } std::map llama_kv_cache_dsa::memory_breakdown() const { std::map mb = kv_mla->memory_breakdown(); for (const auto & buft_size : kv_lid->memory_breakdown()) { mb[buft_size.first] += buft_size.second; } return mb; } llama_memory_context_ptr llama_kv_cache_dsa::init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { GGML_UNUSED(embd_all); do { balloc.split_reset(); std::vector ubatches; while (true) { auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); if (ubatch.n_tokens == 0) { break; } ubatches.push_back(std::move(ubatch)); // NOLINT } if (balloc.get_n_used() < balloc.get_n_tokens()) { // failed to find a suitable split break; } auto sinfos_mla = kv_mla->prepare(ubatches); if (sinfos_mla.empty()) { break; } auto sinfos_lid = kv_lid->prepare(ubatches); if (sinfos_lid.empty()) { break; } assert(sinfos_mla.size() == sinfos_lid.size()); return std::make_unique( this, std::move(sinfos_mla), std::move(sinfos_lid), std::move(ubatches)); } while (false); return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } llama_memory_context_ptr llama_kv_cache_dsa::init_full() { return std::make_unique(this); } llama_memory_context_ptr llama_kv_cache_dsa::init_update(llama_context * lctx, bool optimize) { return std::make_unique(this, lctx, optimize); } bool llama_kv_cache_dsa::get_can_shift() const { return kv_mla->get_can_shift() && kv_lid->get_can_shift() && kv_mla->get_size() == kv_lid->get_size(); } void llama_kv_cache_dsa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { kv_mla->state_write(io, seq_id, flags); kv_lid->state_write(io, seq_id, flags); } void llama_kv_cache_dsa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { kv_mla->state_read(io, seq_id, flags); kv_lid->state_read(io, seq_id, flags); } llama_kv_cache * llama_kv_cache_dsa::get_mla() const { return kv_mla.get(); } llama_kv_cache * llama_kv_cache_dsa::get_lid() const { return kv_lid.get(); } // // llama_kv_cache_dsa_context // llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(llama_memory_status status) : status(status) {} llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( llama_kv_cache_dsa * kv) : ctx_mla(kv->get_mla()->init_full()), ctx_lid(kv->get_lid()->init_full()), status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { } llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( llama_kv_cache_dsa * kv, llama_context * lctx, bool optimize) : ctx_mla(kv->get_mla()->init_update(lctx, optimize)), ctx_lid(kv->get_lid()->init_update(lctx, optimize)), status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { } llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( llama_kv_cache_dsa * kv, slot_info_vec_t sinfos_mla, slot_info_vec_t sinfos_lid, std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal ctx_mla(new llama_kv_cache_context(kv->get_mla(), std::move(sinfos_mla), this->ubatches)), ctx_lid(new llama_kv_cache_context(kv->get_lid(), std::move(sinfos_lid), this->ubatches)), status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { } llama_kv_cache_dsa_context:: ~llama_kv_cache_dsa_context() = default; bool llama_kv_cache_dsa_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); ctx_mla->next(); ctx_lid->next(); if (++i_next >= ubatches.size()) { return false; } return true; } bool llama_kv_cache_dsa_context::apply() { assert(!llama_memory_status_is_fail(status)); bool res = true; res = res & ctx_mla->apply(); res = res & ctx_lid->apply(); return res; } llama_memory_status llama_kv_cache_dsa_context::get_status() const { return status; } const llama_ubatch & llama_kv_cache_dsa_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_next]; } const llama_kv_cache_context * llama_kv_cache_dsa_context::get_mla() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return static_cast(ctx_mla.get()); } const llama_kv_cache_context * llama_kv_cache_dsa_context::get_lid() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return static_cast(ctx_lid.get()); }