From ec8bc44854651ee05faf6b2b88a2df24c2886223 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 May 2026 15:28:29 +0300 Subject: [PATCH] cont : minor --- common/speculative.cpp | 2 +- common/speculative.h | 5 +- src/llama-context.cpp | 1 + tools/server/server-context.cpp | 88 +++++++++++++++++---------------- 4 files changed, 50 insertions(+), 46 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index e23c8467d1..8c94de212f 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -938,7 +938,7 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co } void common_speculative_draft(common_speculative * spec) { - if (!spec) { + if (spec == nullptr) { return; } diff --git a/common/speculative.h b/common/speculative.h index a46fe2aad4..a5337d4dc6 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -46,11 +46,10 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co // TODO: implement [TAG_COMMON_SPECULATIVE_PROCESS] //bool common_speculative_process(common_speculative * spec, const llama_batch & batch); -// generate drafts for the sequences specified in dparams -// requires that `dparams.size() == n_seq` using during common_speculative_init() +// generate drafts for the sequences specified with `common_speculative_get_draft_params` void common_speculative_draft(common_speculative * spec); -// informs the speculative decoder that n_accepted tokens were accepted by the target model +// informs the speculative context that n_accepted tokens were accepted by the target model void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted); // print statistics about the speculative decoding diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a8e1e6ba66..3d9714ab16 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2484,6 +2484,7 @@ public: } else { //LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + // save the old buffer and allocate the new tensors in it auto buf = std::move(mbuf_cur.buf); mbuf_cur = std::move(mbuf); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index ae1de2c04a..aefaa16ad7 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2204,45 +2204,45 @@ private: if (spec) { common_speculative_get_draft_params(spec.get(), slot.id).drafting = false; - } - const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; - const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; - const int n_draft_max = slot.get_n_draft_max(); + const int n_draft_max = slot.get_n_draft_max(); - if (n_draft_max > 0) { - GGML_ASSERT(slot.can_speculate()); + if (n_draft_max > 0) { + GGML_ASSERT(slot.can_speculate()); - if (!slot.spec_draft.empty()) { - // we have a previous (partial) draft to reuse - if (use_ckpt_tgt) { - GGML_ASSERT(!slot.spec_ckpt.empty()); + if (!slot.spec_draft.empty()) { + // we have a previous (partial) draft to reuse + if (use_ckpt_tgt) { + GGML_ASSERT(!slot.spec_ckpt.empty()); + } + } else { + GGML_ASSERT(slot.spec_i_batch.empty()); + + slot.spec_ckpt.update_pos( + slot.prompt.n_tokens(), + llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id), + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id)); + + if (use_ckpt_dft) { + slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } + + slot.spec_prompt = slot.prompt.tokens.get_text_tokens(); + + common_speculative_get_draft_params(spec.get(), slot.id) = { + /* .drafting = */ true, + /* .n_max = */ n_draft_max, + /* .n_past = */ slot.prompt.n_tokens(), + /* .id_last = */ slot.sampled, + /* .prompt = */ &slot.spec_prompt, + /* .result = */ &slot.spec_draft, + }; + + drafting.push_back(&slot); } - } else { - GGML_ASSERT(slot.spec_i_batch.empty()); - - slot.spec_ckpt.update_pos( - slot.prompt.n_tokens(), - llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id), - llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id)); - - if (use_ckpt_dft) { - slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - } - - slot.spec_prompt = slot.prompt.tokens.get_text_tokens(); - - common_speculative_get_draft_params(spec.get(), slot.id) = { - /* .drafting = */ true, - /* .n_max = */ n_draft_max, - /* .n_past = */ slot.prompt.n_tokens(), - /* .id_last = */ slot.sampled, - /* .prompt = */ &slot.spec_prompt, - /* .result = */ &slot.spec_draft, - }; - - drafting.push_back(&slot); } } } @@ -2256,29 +2256,33 @@ private: for (auto * slot_ptr : drafting) { auto & slot = *slot_ptr; - slot.n_draft_total += slot.spec_draft.size(); + auto & draft = slot.spec_draft; + auto & ckpt = slot.spec_ckpt; + + slot.n_draft_total += draft.size(); // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] if (ctx_dft) { - slot.spec_ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, slot.spec_ckpt.pos_max + 1, -1); + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, ckpt.pos_max + 1, -1); } - const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + if (!draft.empty()) { + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; - if (!slot.spec_draft.empty()) { if (use_ckpt_tgt) { //const int64_t t_start = ggml_time_us(); - slot.spec_ckpt.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + ckpt.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); //const int64_t t_total = ggml_time_us() - t_start; //printf("checkpoint total: %f ms\n", t_total / 1000.0); SLT_DBG(slot, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n", - slot.spec_ckpt.pos_min, slot.spec_ckpt.pos_max, slot.prompt.n_tokens(), - (float) slot.spec_ckpt.size() / 1024 / 1024, (float) slot.spec_ckpt.data_dft.size() / 1024 / 1024); + ckpt.pos_min, ckpt.pos_max, slot.prompt.n_tokens(), + (float) ckpt.size() / 1024 / 1024, + (float) ckpt.data_dft.size() / 1024 / 1024); } } }