From b3bd3bd4cccede6490e88aae6ef362b78c553140 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Sat, 9 May 2026 14:09:45 +0300 Subject: [PATCH] cont : clean-up --- common/speculative.cpp | 74 ++++++++++++------- common/speculative.h | 26 +++---- .../speculative-simple/speculative-simple.cpp | 5 +- tools/server/server-context.cpp | 56 ++++++-------- 4 files changed, 84 insertions(+), 77 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index da87b7aeac..e23c8467d1 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -116,11 +116,13 @@ static bool common_speculative_are_compatible( return true; } +using common_speculative_draft_params_vec = std::vector<common_speculative_draft_params>; + // state of an implementation of speculative decoding // // each implementation has a unique type and a state that is implementation-specific -// in a subclass of common_speculative_state -struct common_speculative_state { +// in a subclass of common_speculative_impl +struct common_speculative_impl { const common_speculative_type type; uint32_t n_seq; @@ -141,9 +143,9 @@ struct common_speculative_state { int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds. int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds.
- common_speculative_state(common_speculative_type type, uint32_t n_seq) : type(type), n_seq(n_seq) {} + common_speculative_impl(common_speculative_type type, uint32_t n_seq) : type(type), n_seq(n_seq) {} - virtual ~common_speculative_state() = default; + virtual ~common_speculative_impl() = default; virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0; @@ -152,7 +154,7 @@ struct common_speculative_state { virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; }; -struct common_speculative_state_draft : public common_speculative_state { +struct common_speculative_state_draft : public common_speculative_impl { common_params_speculative_draft params; llama_batch batch; @@ -160,7 +162,7 @@ struct common_speculative_state_draft : public common_speculative_state { std::vector smpls; common_speculative_state_draft(const common_params_speculative & params, uint32_t n_seq) - : common_speculative_state(COMMON_SPECULATIVE_TYPE_DRAFT, n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT, n_seq) , params(params.draft) { auto * ctx_dft = this->params.ctx_dft; @@ -333,11 +335,11 @@ struct common_speculative_state_draft : public common_speculative_state { } }; -struct common_speculative_state_eagle3 : public common_speculative_state { +struct common_speculative_state_eagle3 : public common_speculative_impl { //common_params_speculative_eagle3 params; common_speculative_state_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq) - : common_speculative_state(COMMON_SPECULATIVE_TYPE_EAGLE3, n_seq) {} + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_EAGLE3, n_seq) {} void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop @@ -353,7 +355,7 @@ struct common_speculative_state_eagle3 : public common_speculative_state { }; // state of self-speculation (simple implementation, not ngram-map) -struct common_speculative_state_ngram_simple : public common_speculative_state { +struct 
common_speculative_state_ngram_simple : public common_speculative_impl { common_params_speculative_ngram_map params; // shared across all sequences @@ -362,7 +364,7 @@ struct common_speculative_state_ngram_simple : public common_speculative_state { common_speculative_state_ngram_simple( const common_params_speculative & params, uint32_t n_seq, common_ngram_simple_config config) - : common_speculative_state(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq) , params(params.ngram_simple) , config(config) {} @@ -388,7 +390,7 @@ struct common_speculative_state_ngram_simple : public common_speculative_state { } }; -struct common_speculative_state_ngram_map_k : public common_speculative_state { +struct common_speculative_state_ngram_map_k : public common_speculative_impl { common_params_speculative_ngram_map params; // n_seq configs @@ -398,7 +400,7 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state { const common_params_speculative & params, const common_ngram_map & config, uint32_t n_seq) - : common_speculative_state(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq) , params(params.ngram_map_k) { for (uint32_t i = 0; i < n_seq; i++) { this->config.push_back(config); @@ -431,7 +433,7 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state { } }; -struct common_speculative_state_ngram_mod : public common_speculative_state { +struct common_speculative_state_ngram_mod : public common_speculative_impl { common_params_speculative_ngram_mod params; // shared across all sequences @@ -456,7 +458,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { common_speculative_state_ngram_mod( const common_params_speculative & params, uint32_t n_seq) - : common_speculative_state(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq) + : 
common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq) , params(params.ngram_mod) , mod(params.ngram_mod.n_match, 4*1024*1024) , verbose(std::getenv("LLAMA_TRACE") != nullptr) { @@ -507,9 +509,11 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { auto & sinfo = sinfos[seq_id]; auto & result = *dparams.result; + const auto & prompt = *dparams.prompt; + sinfo.n_draft_last = 0; - const size_t cur_len = dparams.prompt->size(); + const size_t cur_len = prompt.size(); if (cur_len < mod.get_n()) { return; } @@ -519,7 +523,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { // add new ngrams in chunks if (sinfo.i_last + 32 < cur_len) { for (size_t i = sinfo.i_last; i < cur_len - n; ++i) { - mod.add(dparams.prompt->data() + i); + mod.add(prompt.data() + i); } sinfo.i_last = cur_len - n; @@ -527,7 +531,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { result.resize(n + params.n_max); for (size_t i = 0; i < n - 1; ++i) { - result[i] = dparams.prompt->at(cur_len - n + 1 + i); + result[i] = prompt.at(cur_len - n + 1 + i); } result[n - 1] = dparams.id_last; @@ -592,7 +596,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { } }; -struct common_speculative_state_ngram_cache : public common_speculative_state { +struct common_speculative_state_ngram_cache : public common_speculative_impl { common_params_speculative_ngram_cache params; uint16_t n_draft; @@ -618,7 +622,7 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { const std::string & path_dynamic, bool save_dynamic, bool save_static) - : common_speculative_state(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, n_seq) , params(params.ngram_cache) , n_draft(n_draft) , save_dynamic(save_dynamic) @@ -721,11 +725,13 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { }; struct 
common_speculative { + common_speculative_draft_params_vec dparams; + // list of implementations to use and their states - std::vector<std::unique_ptr<common_speculative_state>> impls; + std::vector<std::unique_ptr<common_speculative_impl>> impls; // which implementaion was used for a given seq_id - std::vector<common_speculative_state *> impl_last; + std::vector<common_speculative_impl *> impl_last; }; static common_ngram_map get_common_ngram_map( @@ -830,7 +836,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, } } - std::vector<std::unique_ptr<common_speculative_state>> impls = {}; + std::vector<std::unique_ptr<common_speculative_impl>> impls = {}; for (const common_speculative_config & config : configs) { LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); @@ -894,8 +900,9 @@ common_speculative * common_speculative_init(common_params_speculative & params, } auto * result = new common_speculative { + /* .dparams = */ common_speculative_draft_params_vec(n_seq), /* .impls = */ std::move(impls), - /* .impl_last = */ std::vector<common_speculative_state *>(n_seq, nullptr) + /* .impl_last = */ std::vector<common_speculative_impl *>(n_seq, nullptr) }; return result; @@ -909,6 +916,15 @@ void common_speculative_free(common_speculative * spec) { delete spec; } +common_speculative_draft_params & common_speculative_get_draft_params( + common_speculative * spec, + llama_seq_id seq_id) { + GGML_ASSERT(spec); + GGML_ASSERT(seq_id < (llama_seq_id) spec->dparams.size()); + + return spec->dparams[seq_id]; +} + void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt) { if (spec == nullptr) { return; } @@ -921,9 +937,13 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co } } -void common_speculative_draft( - common_speculative * spec, - common_speculative_draft_params_vec & dparams) { +void common_speculative_draft(common_speculative * spec) { + if (!spec) { + return; + } + + auto & dparams = spec->dparams; + { int n_drafting = 0; @@ -994,7 +1014,7 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u return; } - common_speculative_state * impl =
spec->impl_last[seq_id]; + common_speculative_impl * impl = spec->impl_last[seq_id]; GGML_ASSERT(impl); diff --git a/common/speculative.h b/common/speculative.h index 981c3e5697..a46fe2aad4 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -18,19 +18,11 @@ common_speculative * common_speculative_init(common_params_speculative & params, void common_speculative_free(common_speculative * spec); -// optionally call once at the beginning of a new generation -// TODO: when common_speculative_process() is implemented, we can remove this _begin() function and -// implement all the logic within common_speculative_process() -void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt); - -// TODO: implement [TAG_COMMON_SPECULATIVE_PROCESS] -//bool common_speculative_process(common_speculative * spec, const llama_batch & batch); - struct common_speculative_draft_params { - // this flag helps chain the drafts through all the implementations + // this flag is used to chain the drafts through all the available implementations // after the first successful draft from an implementation, we set it // to false to prevent further drafts for that sequence - bool drafting = true; + bool drafting = false; // overrides individual configurations (-1 disabled) // can be used to constraint the max draft based on the remaining context size @@ -39,18 +31,24 @@ struct common_speculative_draft_params { llama_pos n_past; llama_token id_last; + // TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls const llama_tokens * prompt; + // the generated draft from the last _draft() call llama_tokens * result; }; -using common_speculative_draft_params_vec = std::vector<common_speculative_draft_params>; +common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id); + +// optionally call once at the beginning of a new generation +void
common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt); + +// TODO: implement [TAG_COMMON_SPECULATIVE_PROCESS] +//bool common_speculative_process(common_speculative * spec, const llama_batch & batch); // generate drafts for the sequences specified in dparams // requires that `dparams.size() == n_seq` using during common_speculative_init() -void common_speculative_draft( - common_speculative * spec, - common_speculative_draft_params_vec & dparams); +void common_speculative_draft(common_speculative * spec); // informs the speculative decoder that n_accepted tokens were accepted by the target model void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 4c81f3fd90..5325bcc9e3 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -179,8 +179,7 @@ int main(int argc, char ** argv) { } // generate a new draft - common_speculative_draft_params_vec dparams(1); - dparams[seq_id] = { + common_speculative_get_draft_params(spec, seq_id) = { /* .drafting = */ true, /* .n_max = */ -1, /* .n_past = */ n_past, @@ -188,7 +187,7 @@ int main(int argc, char ** argv) { /* .prompt = */ &prompt_tgt, /* .result = */ &draft, // output }; - common_speculative_draft(spec, dparams); + common_speculative_draft(spec); // save the original draft size n_draft = draft.size(); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index f2be6d451c..ae1de2c04a 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -616,7 +616,6 @@ private: common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; common_speculative_ptr spec; - common_speculative_draft_params_vec spec_dparams; bool add_bos_token = true; @@ -838,7 +837,6 @@ private: if 
(ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { try { spec.reset(common_speculative_init(params_base.speculative, params_base.n_parallel)); - spec_dparams.resize(params_base.n_parallel); } catch (const std::exception & e) { SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what()); } @@ -2186,10 +2184,11 @@ private: // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; - // first, process slots that are speculative decoding - for (auto & slot : slots) { - spec_dparams[slot.id].drafting = false; + std::vector<server_slot *> generating; + std::vector<server_slot *> drafting; + // determine which slots are generating and drafting + for (auto & slot : slots) { if (slot.state != SLOT_STATE_GENERATING) { continue; } @@ -2201,6 +2200,12 @@ private: continue; } + generating.push_back(&slot); + + if (spec) { + common_speculative_get_draft_params(spec.get(), slot.id).drafting = false; + } + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; @@ -2226,34 +2231,30 @@ private: slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); } - spec_dparams[slot.id] = common_speculative_draft_params { + slot.spec_prompt = slot.prompt.tokens.get_text_tokens(); + + common_speculative_get_draft_params(spec.get(), slot.id) = { /* .drafting = */ true, /* .n_max = */ n_draft_max, /* .n_past = */ slot.prompt.n_tokens(), /* .id_last = */ slot.sampled, - /* .prompt = */ nullptr, + /* .prompt = */ &slot.spec_prompt, /* .result = */ &slot.spec_draft, }; + + drafting.push_back(&slot); } } } // generate the actual drafts (if any) { - for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) spec_dparams.size(); seq_id++) { - auto & slot = slots[seq_id]; - auto & dp = spec_dparams[seq_id]; - - slot.spec_prompt = slot.prompt.tokens.get_text_tokens(); - - dp.prompt =
&slot.spec_prompt; - } - - common_speculative_draft(spec.get(), spec_dparams); + common_speculative_draft(spec.get()); } - for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) spec_dparams.size(); seq_id++) { - auto & slot = slots[seq_id]; + // make checkpoints if needed + for (auto * slot_ptr : drafting) { + auto & slot = *slot_ptr; slot.n_draft_total += slot.spec_draft.size(); @@ -2282,20 +2283,9 @@ private: } } - slot_batched = nullptr; - - // add the speculative drafts to the batch, or simply add the sampled tokens - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING) { - continue; - } - - // check if we can batch this slot with the previous one - if (!slot_batched) { - slot_batched = &slot; - } else if (!slot_batched->can_batch_with(slot)) { - continue; - } + // update the batch with the sampled/drafted tokens + for (auto * slot_ptr : generating) { + auto & slot = *slot_ptr; slot.update_batch(batch); }