From 78faa2b79f930fcbd9c75d86dfadd5aab9b25ac4 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 7 May 2026 17:57:59 +0300
Subject: [PATCH] server, spec : transition to unified spec context

---
 common/common.h                 |   2 -
 common/speculative.cpp          | 210 ++------------------------------
 tools/server/server-context.cpp |  42 ++++---
 3 files changed, 37 insertions(+), 217 deletions(-)

diff --git a/common/common.h b/common/common.h
index 587f00d785..5a41ff9f03 100644
--- a/common/common.h
+++ b/common/common.h
@@ -310,8 +310,6 @@ struct common_params_speculative_draft {
     llama_context * ctx_tgt = nullptr;
     llama_context * ctx_dft = nullptr;
 
-    bool use_ckpt = false;
-
     int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
 
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 657008ac58..4209ba8c32 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -156,43 +156,24 @@ struct common_speculative_state {
     virtual int32_t n_min(const common_params_speculative & params) const = 0;
 };
 
-struct common_speculative_checkpoint {
-    llama_pos pos_min = 0;
-    llama_pos pos_max = 0;
-
-    int64_t n_tokens = 0;
-
-    std::vector<uint8_t> data;
-
-    size_t size() const {
-        return data.size();
-    }
-};
-
 struct common_speculative_state_draft : public common_speculative_state {
     llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
     llama_context * ctx_dft;
 
-    bool use_ckpt = false;
-
-    common_speculative_checkpoint ckpt;
-
     llama_seq_id seq_id;
 
     common_sampler * smpl;
 
-    llama_batch  batch;
-    llama_tokens prompt_dft;
+    llama_batch batch;
 
     common_speculative_state_draft(
             enum common_speculative_type type,
             llama_context * ctx_tgt,
             llama_context * ctx_dft,
-            bool use_ckpt,
             llama_seq_id seq_id) : common_speculative_state(type)
         , ctx_tgt(ctx_tgt)
         , ctx_dft(ctx_dft)
-        , use_ckpt(use_ckpt)
         , seq_id(seq_id) {
         batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
@@ -244,37 +225,6 @@ struct common_speculative_state_draft : public common_speculative_state {
     void begin(const llama_tokens & /*prompt*/) override {
     }
 
-    size_t create_checkpoint(int n_tokens_prompt) {
-        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
-
-        ckpt.pos_min  = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), seq_id);
-        ckpt.pos_max  = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
-        ckpt.n_tokens = n_tokens_prompt;
-        ckpt.data.resize(checkpoint_size);
-
-        const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
-        if (n != checkpoint_size) {
-            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
-        }
-
-        LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__,
-                ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024);
-
-        return n;
-    }
-
-    size_t restore_checkpoint() {
-        int seq_id = 0;
-
-        LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
-
-        const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
-        if (n != ckpt.size()) {
-            GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
-                    __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
-        }
-
-        llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, ckpt.pos_max + 1, -1);
-
-        return n;
-    }
-
     void draft(
             const common_params_speculative & params,
             const llama_tokens & prompt_tgt,
@@ -284,152 +234,20 @@ struct common_speculative_state_draft : public common_speculative_state {
         auto * spec = this;
 
-        auto & batch      = spec->batch;
-        auto & ctx_dft    = spec->ctx_dft;
-        auto & smpl       = spec->smpl;
-        auto & prompt_dft = spec->prompt_dft;
+        auto & batch   = spec->batch;
+        auto & ctx_dft = spec->ctx_dft;
+        auto & smpl    = spec->smpl;
 
-        auto * mem_dft = llama_get_memory(ctx_dft);
-
-        int reuse_i = 0; // index of part to be reused in prompt_dft
-        int reuse_n = 0; // length of part to be reused in prompt_dft
-
-        const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max;
-
-        const llama_tokens & prompt_cur = prompt_tgt;
-
-        const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx);
-
-        if (use_ckpt && i_start > 0) {
-            LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__);
-            return;
-        }
-
-        // reuse as much as possible from the old draft context
-        // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
-        for (int i = 0; i < (int) prompt_dft.size(); ++i) {
-            int cur = 0;
-            while (i_start + cur < (int) prompt_cur.size() &&
-                   i       + cur < (int) prompt_dft.size() &&
-                   prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
-                cur++;
-            }
-
-            if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) {
-                reuse_i = i;
-                reuse_n = cur;
-            }
-
-            if (use_ckpt) {
-                break;
-            }
-        }
-
-        LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
-                __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
-
-        if (use_ckpt && ckpt.n_tokens > reuse_n) {
-            LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens);
-
-            reuse_i = 0;
-            reuse_n = 0;
-
-            ckpt = {};
-        }
-
-        result.clear();
-        result.reserve(sparams.n_max);
-
-        if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
-            llama_memory_clear(mem_dft, false);
-
-            prompt_dft.clear();
-        } else {
-            // this happens when a previous draft has been discarded (for example, due to being too small), but the
-            // target model agreed with it. in this case, we simply pass back the previous results to save compute
-            if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
-                for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
-                    result.push_back(prompt_dft[i]);
-
-                    if (sparams.n_max <= (int) result.size()) {
-                        break;
-                    }
-                }
-
-                return;
-            }
-
-            if (reuse_i > 0) {
-                GGML_ASSERT(!use_ckpt);
-
-                bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
-                if (!is_removed) {
-                    LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
-                    return;
-                }
-                llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
-
-                prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
-            }
-
-            if (reuse_n < (int) prompt_dft.size()) {
-                if (use_ckpt) {
-                    if (ckpt.n_tokens > 0) {
-                        LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
-                        restore_checkpoint();
-                        reuse_n = ckpt.n_tokens;
-                        prompt_dft.resize(reuse_n);
-                    }
-                } else {
-                    const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1);
-                    if (!is_removed) {
-                        LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
-                        return;
-                    }
-                    prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
-                }
-            }
-        }
-
-        // prepare a batch to evaluate any new tokens in the prompt
-        common_batch_clear(batch);
-
-        for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) {
-            //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]);
-            common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false);
-
-            prompt_dft.push_back(prompt_cur[i]);
-        }
-
-        // we should rarely end-up here during normal decoding
-        if (batch.n_tokens > 0) {
-            //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
-            LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens);
-
-            int ret = llama_decode(ctx_dft, batch);
-            if (ret != 0 && ret != 1) {
-                LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
-                        __func__, ret, prompt_cur.size());
-            }
-
-            if (use_ckpt) {
-                create_checkpoint(prompt_dft.size());
-            }
-        }
-
-        const llama_pos n_past = prompt_dft.size();
+        const llama_pos n_past = prompt_tgt.size();
 
         LOG_DBG("%s: n_past = %d\n", __func__, n_past);
 
         common_batch_clear(batch);
-        common_batch_add (batch, id_last, n_past, { 0 }, true);
-
-        prompt_dft.push_back(id_last);
-
-        //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
+        common_batch_add (batch, id_last, n_past, { seq_id }, true);
 
         int ret = llama_decode(ctx_dft, batch);
         if (ret != 0 && ret != 1) {
-            LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
-                    __func__, ret, prompt_cur.size(), prompt_dft.size());
+            LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
        }
 
        common_sampler_reset(smpl);
@@ -463,16 +281,13 @@ struct common_speculative_state_draft : public common_speculative_state {
                break;
            }
 
-            common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+            common_batch_add(batch, id, n_past + i + 1, { seq_id }, true);
 
            // evaluate the drafted tokens on the draft model
            ret = llama_decode(ctx_dft, batch);
            if (ret != 0) {
-                LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
-                        __func__, i, ret, prompt_cur.size(), prompt_dft.size());
+                LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
            }
-
-            prompt_dft.push_back(id);
        }
 
        if (result.size() < (size_t) sparams.n_min) {
@@ -962,10 +777,9 @@ common_speculative * common_speculative_init(common_params_speculative & params,
                break;
            case COMMON_SPECULATIVE_TYPE_DRAFT: {
                impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
-                            /* .ctx_tgt  = */ params.draft.ctx_tgt,
-                            /* .ctx_dft  = */ params.draft.ctx_dft,
-                            /* .use_ckpt = */ params.draft.use_ckpt,
-                            /* .seq_id   = */ seq_id
+                            /* .ctx_tgt = */ params.draft.ctx_tgt,
+                            /* .ctx_dft = */ params.draft.ctx_dft,
+                            /* .seq_id  = */ seq_id
                        ));
                break;
            }
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 48f020898e..17e39cb5bf 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -380,16 +380,7 @@ struct server_slot {
            } else {
                GGML_ASSERT(spec_i_batch.empty());
 
-                // generate a new draft
-                spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
-                n_draft_total += spec_draft.size();
-
-                if (spec_draft.size() > (size_t) n_draft_max) {
-                    SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max);
-                    spec_draft.resize(n_draft_max);
-                }
-
-                if (!spec_draft.empty() && use_ckpt) {
+                if (use_ckpt) {
                    const auto n_tokens = prompt.tokens.size();
 
                    //const int64_t t_start = ggml_time_us();
@@ -402,6 +393,25 @@
                    SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
                            spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
                }
+
+                // generate a new draft
+                spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
+                n_draft_total += spec_draft.size();
+
+                if (spec_draft.size() > (size_t) n_draft_max) {
+                    SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max);
+                    spec_draft.resize(n_draft_max);
+                }
+
+                if (ctx_drft) {
+                    const size_t n = llama_state_seq_set_data_ext(ctx_drft, spec_ckpt.data_drft.data(), spec_ckpt.data_drft.size(), this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
+                    if (n != spec_ckpt.data_drft.size()) {
+                        GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
+                                __func__, spec_ckpt.pos_min, spec_ckpt.pos_max, spec_ckpt.data_drft.size(), spec_ckpt.data_drft.size(), n);
+                    }
+
+                    llama_memory_seq_rm(llama_get_memory(ctx_drft), this->id, spec_ckpt.pos_max + 1, -1);
+                }
            }
 
            GGML_ASSERT(spec_draft.size() <= (size_t) n_draft_max);
@@ -841,7 +851,6 @@ private:
        params_base.speculative.draft.ctx_tgt = ctx_main;
        params_base.speculative.draft.ctx_dft = ctx_drft.get();
 
-        params_base.speculative.draft.use_ckpt = ctx_drft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
    }
 
    std::string & mmproj_path = params_base.mmproj.path;
@@ -935,6 +944,7 @@ private:
            slot.id       = i;
            slot.ctx_main = ctx_main;
+            slot.ctx_drft = ctx_drft.get();
            slot.n_ctx    = n_ctx_slot;
            slot.mctx     = mctx;
@@ -2899,8 +2909,6 @@ private:
        }
 
        if (ctx_drft) {
-            SRV_WRN("%s", "processing the batch using the draft context\n");
-
            // note: for now, to keep things simple, synchronize the target context
            //       TODO: revisit later on
            llama_synchronize(ctx_main);
@@ -3086,9 +3094,9 @@ private:
        {
            const size_t n = llama_state_seq_set_data_ext(slot.ctx_main, ckpt.data_main.data(), ckpt.data_main.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
-            if (n != ckpt.size()) {
+            if (n != ckpt.data_main.size()) {
                GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
-                        __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
+                        __func__, ckpt.pos_min, ckpt.pos_max, ckpt.data_main.size(), ckpt.data_main.size(), n);
            }
 
            llama_memory_seq_rm(llama_get_memory(slot.ctx_main), slot.id, ckpt.pos_max + 1, -1);
@@ -3096,9 +3104,9 @@ private:
        if (slot.ctx_drft) {
            const size_t n = llama_state_seq_set_data_ext(slot.ctx_drft, ckpt.data_drft.data(), ckpt.data_drft.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
-            if (n != ckpt.size()) {
+            if (n != ckpt.data_drft.size()) {
                GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
-                        __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
+                        __func__, ckpt.pos_min, ckpt.pos_max, ckpt.data_drft.size(), ckpt.data_drft.size(), n);
            }
 
            llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, ckpt.pos_max + 1, -1);
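
---

Note (commentary, not part of the patch): the change removes the draft-side checkpoint
bookkeeping from common/speculative.cpp and makes the server slot the single owner of
checkpoints for both the target and the draft context - the "unified spec context" of the
subject line. The save/restore pattern the server now applies symmetrically to ctx_main and
ctx_drft is sketched below. The helper names seq_checkpoint, checkpoint_save and
checkpoint_restore are illustrative only and do not exist in the patch; the llama_* calls
and the flag combination are taken verbatim from the hunks above.

    #include <cstdint>
    #include <vector>

    #include "llama.h"
    #include "ggml.h" // GGML_ASSERT

    // illustrative container, mirroring the pos_min/pos_max/data fields used by spec_ckpt
    struct seq_checkpoint {
        llama_pos pos_min = 0;
        llama_pos pos_max = 0;

        std::vector<uint8_t> data;
    };

    // serialize the state of one sequence, as the server does before generating a draft
    static seq_checkpoint checkpoint_save(llama_context * ctx, llama_seq_id seq_id) {
        seq_checkpoint ckpt;

        ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), seq_id);
        ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), seq_id);

        // query the required buffer size, then copy the sequence state into it
        ckpt.data.resize(llama_state_seq_get_size_ext(ctx, seq_id,
                LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE));

        const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), ckpt.data.size(), seq_id,
                LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
        GGML_ASSERT(n == ckpt.data.size());

        return ckpt;
    }

    // restore the saved state, then drop any cells past the checkpointed position -
    // the same set_data_ext + llama_memory_seq_rm(..., pos_max + 1, -1) pairing as in the patch
    static void checkpoint_restore(llama_context * ctx, llama_seq_id seq_id, const seq_checkpoint & ckpt) {
        const size_t n = llama_state_seq_set_data_ext(ctx, ckpt.data.data(), ckpt.data.size(), seq_id,
                LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
        GGML_ASSERT(n == ckpt.data.size());

        llama_memory_seq_rm(llama_get_memory(ctx), seq_id, ckpt.pos_max + 1, -1);
    }

Because the restore trims everything past ckpt.pos_max, a slot can roll back a rejected
draft on both contexts with two symmetric checkpoint_restore calls, which is exactly what
the two hunks around @@ -3086 and @@ -3096 implement for ctx_main and ctx_drft.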