From ab8875567ce8bfa11ebccfc6567ad60cc07efc9c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 7 May 2026 10:14:18 +0300
Subject: [PATCH] cont : pass seq_id

[no ci]
---
 common/speculative.cpp                        | 25 ++++++++++++---------
 common/speculative.h                          |  2 +-
 .../speculative-simple/speculative-simple.cpp |  2 +-
 tools/server/server-context.cpp               |  2 +-
 4 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index a56f15d03e..657008ac58 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -176,6 +176,8 @@ struct common_speculative_state_draft : public common_speculative_state {
     bool use_ckpt = false;
     common_speculative_checkpoint ckpt;
 
+    llama_seq_id seq_id;
+
     common_sampler * smpl;
 
     llama_batch batch;
@@ -185,11 +187,13 @@ struct common_speculative_state_draft : public common_speculative_state {
             enum common_speculative_type type,
             llama_context * ctx_tgt,
             llama_context * ctx_dft,
-            bool use_ckpt)
+            bool use_ckpt,
+            llama_seq_id seq_id)
         : common_speculative_state(type)
         , ctx_tgt(ctx_tgt)
         , ctx_dft(ctx_dft)
         , use_ckpt(use_ckpt)
+        , seq_id(seq_id)
     {
         batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
         smpl = nullptr;
@@ -241,15 +245,14 @@ struct common_speculative_state_draft : public common_speculative_state {
     }
 
     size_t create_checkpoint(int n_tokens_prompt) {
-        int slot_id = 0;
-        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
+        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
 
-        ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
-        ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
+        ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), seq_id);
+        ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
         ckpt.n_tokens = n_tokens_prompt;
 
         ckpt.data.resize(checkpoint_size);
-        const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
+        const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
         if (n != checkpoint_size) {
             GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
         }
@@ -260,14 +263,13 @@ struct common_speculative_state_draft : public common_speculative_state {
     }
 
     size_t restore_checkpoint() {
-        int slot_id = 0;
         LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
 
-        const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
+        const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
         if (n != ckpt.size()) {
             GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu", __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
         }
 
-        llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);
+        llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, ckpt.pos_max + 1, -1);
 
         return n;
@@ -896,7 +898,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
 // initialization of the speculative decoding system
 //
 
-common_speculative * common_speculative_init(common_params_speculative & params) {
+common_speculative * common_speculative_init(common_params_speculative & params, llama_seq_id seq_id) {
     // Compute the implementations to use based on the config and their order of preference
     std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
     {
@@ -961,7 +963,8 @@ common_speculative * common_speculative_init(common_params_speculative & params)
                 impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
                     /* .ctx_tgt  = */ params.draft.ctx_tgt,
                     /* .ctx_dft  = */ params.draft.ctx_dft,
-                    /* .use_ckpt = */ params.draft.use_ckpt
+                    /* .use_ckpt = */ params.draft.use_ckpt,
+                    /* .seq_id   = */ seq_id
                 ));
                 break;
             }
diff --git a/common/speculative.h b/common/speculative.h
index c900ef6e12..3b6211a223 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -14,7 +14,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
 // convert type to string
 std::string common_speculative_type_to_str(enum common_speculative_type type);
 
-common_speculative * common_speculative_init(common_params_speculative & params);
+common_speculative * common_speculative_init(common_params_speculative & params, llama_seq_id seq_id);
 
 void common_speculative_free(common_speculative * spec);
 
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index 110dd4774f..da1fb56d19 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
     // init the speculator
     const auto & params_spec = params.speculative;
 
-    struct common_speculative * spec = common_speculative_init(params.speculative);
+    struct common_speculative * spec = common_speculative_init(params.speculative, 0);
 
     common_speculative_begin(spec, prompt_tgt);
 
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 486073ee92..79f0e1df75 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -912,7 +912,7 @@ private:
         // try speculative decoding
         if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
             try {
-                slot.spec.reset(common_speculative_init(params_base.speculative));
+                slot.spec.reset(common_speculative_init(params_base.speculative, slot.id));
             } catch (const std::exception & e) {
                 SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
             }