From 0239f4c611c5eaf00c1d4defe0f463e61343ecdb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 7 May 2026 21:10:03 +0300 Subject: [PATCH] cont : handle non-ckpt models --- tools/server/server-context.cpp | 67 ++++++++++++++++++++++----------- tools/server/server-task.cpp | 2 +- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 1d32559768..5d528f576c 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -364,7 +364,7 @@ struct server_slot { return n_draft_max; } - void update_batch(llama_batch & batch, bool use_ckpt) { + void update_batch(llama_batch & batch, bool use_ckpt_main, bool use_ckpt_drft) { const int n_draft_max = get_n_draft_max(); if (n_draft_max > 0) { GGML_ASSERT(can_speculate()); @@ -378,13 +378,15 @@ struct server_slot { if (!spec_draft.empty()) { // we have a previous (partial) draft to reuse - if (use_ckpt) { + if (use_ckpt_main) { GGML_ASSERT(!spec_ckpt.empty()); } } else { GGML_ASSERT(spec_i_batch.empty()); - if (use_ckpt) { + spec_ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_main), id); + + if (use_ckpt_drft) { const auto n_tokens = prompt.tokens.size(); server_prompt_checkpoint_update(spec_ckpt, nullptr, ctx_drft, this->id, n_tokens, true); @@ -402,26 +404,30 @@ struct server_slot { spec_draft.resize(n_draft_max); } - if (!spec_draft.empty() && use_ckpt) { - const auto n_tokens = prompt.tokens.size(); + if (!spec_draft.empty()) { + if (use_ckpt_main) { + const auto n_tokens = prompt.tokens.size(); - //const int64_t t_start = ggml_time_us(); + //const int64_t t_start = ggml_time_us(); - server_prompt_checkpoint_update(spec_ckpt, ctx_main, nullptr, this->id, n_tokens, true); + server_prompt_checkpoint_update(spec_ckpt, ctx_main, nullptr, this->id, n_tokens, true); - //const int64_t t_total = ggml_time_us() - t_start; - //printf("checkpoint total: %f ms\n", t_total / 1000.0); + //const int64_t t_total = ggml_time_us() - t_start; + //printf("checkpoint total: %f ms\n", t_total / 1000.0); - SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n", - spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024); + SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n", + spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024); + } } // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] if (ctx_drft) { - const size_t n = llama_state_seq_set_data_ext(ctx_drft, spec_ckpt.data_drft.data(), spec_ckpt.data_drft.size(), this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != spec_ckpt.data_drft.size()) { - GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", - __func__, spec_ckpt.pos_min, spec_ckpt.pos_max, spec_ckpt.data_drft.size(), spec_ckpt.data_drft.size(), n); + if (use_ckpt_drft) { + const size_t n = llama_state_seq_set_data_ext(ctx_drft, spec_ckpt.data_drft.data(), spec_ckpt.data_drft.size(), this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + if (n != spec_ckpt.data_drft.size()) { + GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", + __func__, spec_ckpt.pos_min, spec_ckpt.pos_max, spec_ckpt.data_drft.size(), spec_ckpt.data_drft.size(), n); + } } llama_memory_seq_rm(llama_get_memory(ctx_drft), this->id, spec_ckpt.pos_max + 1, -1); @@ -958,6 +964,8 @@ private: slots.emplace_back(); } + bool no_drft = false; + for (int i = 0; i < params_base.n_parallel; i++) { server_slot & slot = slots[i]; @@ -975,6 +983,8 @@ private: slot.spec.reset(common_speculative_init(params_base.speculative, slot.id)); } catch (const std::exception & e) { SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what()); + + no_drft = true; } if (slot.spec) { @@ -991,6 +1001,12 @@ private: slot.reset(); } + if (no_drft && ctx_drft) { + SRV_WRN("%s", "destroying the draft model as it is not going to be used\n"); + + ctx_drft.reset(); + } + { const char * LLAMA_TRACE = getenv("LLAMA_TRACE"); trace = LLAMA_TRACE ? atoi(LLAMA_TRACE) : 0; @@ -2320,7 +2336,9 @@ private: continue; } - slot.update_batch(batch, ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + slot.update_batch(batch, + ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL, + ctx_drft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); } // process in chunks of params.n_batch @@ -3100,11 +3118,12 @@ private: // verify and try to accept the draft { - const bool use_ckpt = ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_main = ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_drft = ctx_drft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; // only save the sampler sampler state if we use checkpoints common_sampler_ptr smpl_save; - if (use_ckpt) { + if (use_ckpt_main) { smpl_save.reset(common_sampler_clone(slot.smpl.get())); } @@ -3116,7 +3135,7 @@ private: // check for partial draft acceptance if (accepted.size() < slot.spec_draft.size() + 1) { - if (use_ckpt) { + if (use_ckpt_main) { if (trace > 0) { SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); } @@ -3140,10 +3159,12 @@ private: } if (slot.ctx_drft) { - const size_t n = llama_state_seq_set_data_ext(slot.ctx_drft, ckpt.data_drft.data(), ckpt.data_drft.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != ckpt.data_drft.size()) { - GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", - __func__, ckpt.pos_min, ckpt.pos_max, ckpt.data_drft.size(), ckpt.data_drft.size(), n); + if (use_ckpt_drft) { + const size_t n = llama_state_seq_set_data_ext(slot.ctx_drft, ckpt.data_drft.data(), ckpt.data_drft.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + if (n != ckpt.data_drft.size()) { + GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", + __func__, ckpt.pos_min, ckpt.pos_max, ckpt.data_drft.size(), ckpt.data_drft.size(), n); + } } llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, ckpt.pos_max + 1, -1); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index b002eab98f..809ae345e9 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -2101,7 +2101,7 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok data.clear(); data.shrink_to_fit(); } - } + } prompt = std::move(*it_best);