From 4eec5542cefdb949dbe3cefe926a45f54b56f25a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 7 May 2026 10:06:32 +0300 Subject: [PATCH] spec : update common_speculative_init() [no ci] --- common/common.h | 4 ++-- common/speculative.cpp | 11 +++-------- common/speculative.h | 4 +--- examples/speculative-simple/speculative-simple.cpp | 5 +++-- tools/server/server-context.cpp | 5 +++-- 5 files changed, 12 insertions(+), 17 deletions(-) diff --git a/common/common.h b/common/common.h index a8db51cd1a..587f00d785 100644 --- a/common/common.h +++ b/common/common.h @@ -307,8 +307,8 @@ struct common_params_speculative_draft { common_params_model mparams; - // the draft context - llama_context * ctx = nullptr; + llama_context * ctx_tgt = nullptr; + llama_context * ctx_dft = nullptr; bool use_ckpt = false; diff --git a/common/speculative.cpp b/common/speculative.cpp index 36519f0ce9..a56f15d03e 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -282,7 +282,6 @@ struct common_speculative_state_draft : public common_speculative_state { auto * spec = this; auto & batch = spec->batch; - auto & ctx_tgt = spec->ctx_tgt; auto & ctx_dft = spec->ctx_dft; auto & smpl = spec->smpl; auto & prompt_dft = spec->prompt_dft; @@ -897,11 +896,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // initialization of the speculative decoding system // -common_speculative * common_speculative_init( - common_params_speculative & params, - llama_context * ctx_tgt) { - llama_context * ctx_dft = params.draft.ctx; - +common_speculative * common_speculative_init(common_params_speculative & params) { // Compute the implementations to use based on the config and their order of preference std::vector configs = {}; // list of speculative configs to try { @@ -964,8 +959,8 @@ common_speculative * common_speculative_init( break; case COMMON_SPECULATIVE_TYPE_DRAFT: { impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, + /* .ctx_tgt = */ params.draft.ctx_tgt, + /* .ctx_dft = */ params.draft.ctx_dft, /* .use_ckpt = */ params.draft.use_ckpt )); break; diff --git a/common/speculative.h b/common/speculative.h index 1474476317..c900ef6e12 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -14,9 +14,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // convert type to string std::string common_speculative_type_to_str(enum common_speculative_type type); -common_speculative * common_speculative_init( - common_params_speculative & params, - llama_context * ctx_tgt); +common_speculative * common_speculative_init(common_params_speculative & params); void common_speculative_free(common_speculative * spec); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index d3a95f2311..110dd4774f 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -105,7 +105,8 @@ int main(int argc, char ** argv) { auto cparams = common_context_params_to_llama(params_dft); ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); - params.speculative.draft.ctx = ctx_dft.get(); + params.speculative.draft.ctx_tgt = ctx_tgt; + params.speculative.draft.ctx_dft = ctx_dft.get(); } // Tokenize the prompt @@ -161,7 +162,7 @@ int main(int argc, char ** argv) { // init the speculator const auto & params_spec = params.speculative; - struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt); + struct common_speculative * spec = common_speculative_init(params.speculative); common_speculative_begin(spec, prompt_tgt); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index e922d6667a..486073ee92 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -806,7 +806,8 @@ private: auto cparams = common_context_params_to_llama(params_dft); ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); - params_base.speculative.draft.ctx = ctx_dft.get(); + params_base.speculative.draft.ctx_tgt = ctx; + params_base.speculative.draft.ctx_dft = ctx_dft.get(); params_base.speculative.draft.use_ckpt = common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; } @@ -911,7 +912,7 @@ private: // try speculative decoding if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { try { - slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx)); + slot.spec.reset(common_speculative_init(params_base.speculative)); } catch (const std::exception & e) { SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what()); }