#include "speculative.h"

#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
#include "ngram-mod.h"
#include "sampling.h"

#include
#include
#include
#include
#include
#include

#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

// all known speculative decoding types, in the order they are listed to the user
const std::vector common_speculative_types = {
    COMMON_SPECULATIVE_TYPE_NONE,
    COMMON_SPECULATIVE_TYPE_DRAFT,
    COMMON_SPECULATIVE_TYPE_EAGLE3,
    COMMON_SPECULATIVE_TYPE_MTP,
    COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
    COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
    COMMON_SPECULATIVE_TYPE_NGRAM_CACHE
};

// lookup table: user-facing name -> speculative decoding type
const std::map common_speculative_type_from_name_map = {
    {"none",          COMMON_SPECULATIVE_TYPE_NONE},
    {"draft",         COMMON_SPECULATIVE_TYPE_DRAFT},
    {"eagle3",        COMMON_SPECULATIVE_TYPE_EAGLE3},
    {"mtp",           COMMON_SPECULATIVE_TYPE_MTP},
    {"ngram_simple",  COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
    {"ngram_map_k",   COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
    {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
    {"ngram_mod",     COMMON_SPECULATIVE_TYPE_NGRAM_MOD},
    {"ngram_cache",   COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
};

// a speculative implementation type together with the parameters it is created from
struct common_speculative_config {
    common_speculative_type type;
    common_params_speculative params;

    common_speculative_config(common_speculative_type t,
            const common_params_speculative & p = common_params_speculative{}) : type(t), params(p) {}
};

// Check whether a draft model can be used to speculate for a target model.
//
// The two vocabularies are considered compatible when:
//   - their vocab types are equal
//   - their add-BOS / add-EOS settings are equal (and, when enabled, the BOS / EOS ids are equal)
//   - their sizes differ by at most SPEC_VOCAB_MAX_SIZE_DIFFERENCE tokens
//   - the token texts are equal for every shared id >= SPEC_VOCAB_CHECK_START_TOKEN_ID
//
// Returns true when all checks pass; otherwise logs the reason and returns false.
static bool common_speculative_are_compatible(
        const llama_model * model_tgt,
        const llama_model * model_dft) {
    const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
    const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);

    // keep the full enum value: storing it in a bool collapses every non-zero
    // vocab type to `true`, so two different non-zero types would compare equal
    const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
    LOG_DBG("%s: vocab_type tgt: %d\n", __func__, (int) vocab_type_tgt);

    const auto vocab_type_dft = llama_vocab_type(vocab_dft);
    LOG_DBG("%s: vocab_type dft: %d\n", __func__, (int) vocab_type_dft);

    if (vocab_type_tgt != vocab_type_dft) {
        LOG_WRN("%s: draft model vocab type must match target model to use speculation but "
                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, (int) vocab_type_dft, (int) vocab_type_tgt);
        return false;
    }

    if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
        (llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
        LOG_WRN("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
                __func__,
                llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
                llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
        return false;
    }

    if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
        (llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
        LOG_WRN("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
                __func__,
                llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
                llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
        return false;
    }

    {
        const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
        const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);

        const int vocab_diff = n_vocab_tgt > n_vocab_dft
            ? n_vocab_tgt - n_vocab_dft
            : n_vocab_dft - n_vocab_tgt;

        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
            LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
            LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
                    n_vocab_tgt, n_vocab_dft, vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
            return false;
        }

        // only the ids present in both vocabs can be compared
        const int n_vocab_cmp = std::min(n_vocab_tgt, n_vocab_dft);

        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < n_vocab_cmp; ++i) {
            const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
            const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);

            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
                LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
                LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
                        common_token_to_piece(vocab_tgt, i).c_str(),
                        common_token_to_piece(vocab_dft, i).c_str());
                return false;
            }
        }
    }

    return true;
}

using common_speculative_draft_params_vec = std::vector;

// state of an implementation of speculative decoding
//
// each implementation has a unique type and a state that is implementation-specific
// in a subclass of common_speculative_impl
struct common_speculative_impl {
    const common_speculative_type type;

    uint32_t n_seq;

    size_t n_call_begin  = 0; // number of times this implementation was called for refresh.
    size_t n_call_draft  = 0; // number of times this implementation was called for generation.
    size_t n_call_accept = 0; // number of times this implementation was called for accumulation.

    size_t n_gen_drafts = 0; // number of times a draft or part was generated by this implementation.
    size_t n_acc_drafts = 0; // number of times a draft or part was accepted by the target model.
    size_t n_gen_tokens = 0; // number of tokens generated by this implementation.
    size_t n_acc_tokens = 0; // number of tokens accepted by the target model.
// TODO: track performance of most recent calls const bool gen_perf = true; // whether to generate performance stats. int64_t t_begin_us = 0; // total time spent in refresh of this implementation in microseconds. int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds. int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds. common_speculative_impl(common_speculative_type type, uint32_t n_seq) : type(type), n_seq(n_seq) {} virtual ~common_speculative_impl() = default; virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0; virtual bool process(const llama_batch & batch) = 0; virtual void draft(common_speculative_draft_params_vec & dparams) = 0; virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; }; struct common_speculative_state_draft : public common_speculative_impl { common_params_speculative_draft params; llama_batch batch; std::vector smpls; common_speculative_state_draft(const common_params_speculative & params, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT, n_seq) , params(params.draft) { auto * ctx_dft = this->params.ctx_dft; auto * ctx_tgt = this->params.ctx_tgt; batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); // TODO: optimize or pass from outside? 
// { // common_params_sampling params; // params.no_perf = false; // // params.top_k = 40; // params.top_p = 0.9; // // params.samplers = { // COMMON_SAMPLER_TYPE_TOP_K, // COMMON_SAMPLER_TYPE_TOP_P, // COMMON_SAMPLER_TYPE_INFILL, // }; // // result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); // } smpls.resize(n_seq); for (auto & smpl : smpls) { common_params_sampling params; params.no_perf = false; params.top_k = 10; params.samplers = { COMMON_SAMPLER_TYPE_TOP_K, }; smpl.reset(common_sampler_init(llama_get_model(ctx_dft), params)); } const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt); if (!vocab_cmpt) { LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__); throw std::runtime_error("draft model vocab type must match target model to use speculation"); } if (n_seq != llama_n_seq_max(ctx_dft)) { LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft)); throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq"); } } ~common_speculative_state_draft() override { llama_batch_free(batch); } void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop } bool process(const llama_batch & batch) override { auto * ctx_dft = params.ctx_dft; const int ret = llama_decode(ctx_dft, batch); if (ret != 0) { LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret); return false; } return true; } void draft(common_speculative_draft_params_vec & dparams) override { auto & ctx_dft = params.ctx_dft; common_batch_clear(batch); // keep track of which sequences are still drafting int n_drafting = 0; std::vector drafting(n_seq); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } n_drafting++; drafting[seq_id] = true; 
common_sampler_reset(smpls[seq_id].get()); common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true); } int ret = llama_decode(ctx_dft, batch); if (ret != 0) { LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); return; } int i = 0; while (n_drafting > 0) { int i_batch = 0; common_batch_clear(batch); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { if (!drafting[seq_id]) { continue; } auto * smpl = smpls[seq_id].get(); common_sampler_sample(smpl, ctx_dft, i_batch, true); ++i_batch; const auto * cur_p = common_sampler_get_candidates(smpl, true); for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); } // add drafted token for each sequence const llama_token id = cur_p->data[0].id; // only collect very high-confidence draft tokens if (cur_p->data[0].p < params.p_min) { drafting[seq_id] = false; n_drafting--; continue; } common_sampler_accept(smpl, id, true); auto & dp = dparams.at(seq_id); auto & result = *dp.result; result.push_back(id); if ((params.n_max <= (int) result.size()) || (dp.n_max > 0 && dp.n_max <= (int) result.size())) { drafting[seq_id] = false; n_drafting--; continue; } common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); } if (batch.n_tokens == 0) { break; } // evaluate the drafted tokens on the draft model ret = llama_decode(ctx_dft, batch); if (ret != 0) { LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); break; } ++i; } for (auto & dp : dparams) { if (!dp.drafting) { continue; } if (dp.result->size() < (size_t) params.n_min) { dp.result->clear(); } } } void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } }; struct common_speculative_state_eagle3 : public common_speculative_impl { //common_params_speculative_eagle3 params; common_speculative_state_eagle3(const 
common_params_speculative & /*params*/, uint32_t n_seq)
        : common_speculative_impl(COMMON_SPECULATIVE_TYPE_EAGLE3, n_seq) {}

    void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
        // noop
    }

    bool process(const llama_batch & /*batch*/) override {
        // TODO: implement
        return true;
    }

    void draft(common_speculative_draft_params_vec & /*dparams*/) override {
        // TODO: implement
    }

    void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
        // noop
    }
};

// MTP speculation: the draft context (params.ctx_dft) is fed (token, pre-norm hidden row)
// pairs, where the hidden rows come from the target context (params.ctx_tgt) or, during
// the auto-regressive draft steps, from the draft context itself.
struct common_speculative_state_mtp : public common_speculative_impl {
    common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)

    // batch carrying both token ids and embeddings (see the constructor)
    llama_batch batch;

    // one sampler per sequence
    std::vector smpls;

    // read from the draft model in the constructor (llama_model_n_embd)
    int32_t n_embd = 0;

    // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
    // The last h-row of one process() call needs the first token of the NEXT
    // call to pair with, so it's stashed here until that next call fires.
    std::vector> pending_h; // [n_seq][n_embd]
    std::vector pending_pos; // [n_seq]

    // number of tokens produced by the last draft() call; reset to 0 by begin() and accept()
    std::vector last_n_drafted;
    // n_accepted from the last accept() call; -1 after begin() until the first draft()/accept()
    std::vector last_n_accepted;

    // Number of trunk output rows produced by the most recent process() call.
    // Used by draft() for the first AR step (when last_n_accepted is -1) to
    // pick the last prefill row out of ctx_tgt's pre-norm buffer.
    std::vector last_trunk_n_outputs;

    // Asserts n_seq == 1 and that both contexts are set, allocates the batch and the
    // samplers, and enables pre-norm embeddings on both contexts.
    common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq)
        : common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq)
        , params(params.draft)
    {
        GGML_ASSERT(n_seq == 1 && "MTP currently supports only single-sequence speculation");

        auto * ctx_tgt = this->params.ctx_tgt;
        auto * ctx_dft = this->params.ctx_dft;
        GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");

        n_embd = llama_model_n_embd(llama_get_model(ctx_dft));

        const int32_t n_b = (int32_t) llama_n_batch(ctx_dft);
        batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
        // llama_batch_init allocates only one of token/embd; MTP needs both.
        // TODO: fix, how to call without malloc
        // NOTE(review): the malloc result is not checked - confirm whether that is acceptable here
        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b);

        smpls.resize(n_seq);
        for (auto & s : smpls) {
            common_params_sampling sparams;
            sparams.no_perf = false;
            sparams.top_k   = 1;
            sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
            s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
        }

        llama_set_embeddings_pre_norm(ctx_tgt, true);
        llama_set_embeddings_pre_norm(ctx_dft, true);

        pending_h.assign(n_seq, std::vector(n_embd, 0.0f));
        pending_pos.assign(n_seq, -1);
        last_n_drafted.assign(n_seq, 0);
        last_n_accepted.assign(n_seq, -1);
        last_trunk_n_outputs.assign(n_seq, 0);
    }

    ~common_speculative_state_mtp() override {
        // batch.token was allocated with malloc in the constructor: release it here and
        // null it before handing the batch to llama_batch_free
        if (batch.token != nullptr) {
            free(batch.token);
            batch.token = nullptr;
        }
        llama_batch_free(batch);
    }

    // Reset the per-sequence bookkeeping for a new prompt and warn when the draft
    // context's memory for this sequence ends before the last prompt position.
    void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
        GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < pending_pos.size());

        last_n_accepted[seq_id] = -1;
        last_n_drafted [seq_id] = 0;
        pending_pos    [seq_id] = -1;

        const int32_t N = (int32_t) prompt.size();
        if (N <= 0) {
            return;
        }

        auto * ctx_dft = this->params.ctx_dft;

        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
        if (pos_max < N - 1) {
            LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d — "
                    "process() hook may not have run on every prefill ubatch "
                    "(need_embd / logits=1 on every prompt position?). "
                    "Drafts may degrade.\n",
                    __func__, (int) pos_max, N - 1);
        }
    }

    // Mirror a batch that was just decoded on the target context into the draft context.
    //
    // Each hidden row h_k of the target is paired with the following token of the batch
    // and decoded on ctx_dft. The final hidden row has no following token yet, so it is
    // stashed in pending_h / pending_pos for the next call.
    //
    // Returns false only when llama_decode on the draft context fails.
    bool process(const llama_batch & batch_in) override {
        if (batch_in.n_tokens <= 0) {
            return true;
        }

        // Single-seq for now (asserted in ctor). Future: bucket by seq_id.
        const llama_seq_id seq_id = 0;

        // TODO: how to make it work with vision tokens?
        // batches without token ids (or with embeddings) cannot be paired: drop any pending row
        if (batch_in.token == nullptr || batch_in.embd != nullptr) {
            pending_pos[seq_id] = -1;
            return true;
        }

        auto * ctx_tgt = this->params.ctx_tgt;
        auto * ctx_dft = this->params.ctx_dft;

        const int32_t   n_rows    = batch_in.n_tokens;
        const llama_pos pos_start = batch_in.pos[0];

        // nothing to do when the draft memory already reaches the start of this batch
        const llama_pos pos_max_dft = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
        if (pos_start <= pos_max_dft) {
            return true;
        }

        // Stale pending: discard if the new batch doesn't start one past it.
        const bool pending_continues = pending_pos[seq_id] >= 0 && pending_pos[seq_id] + 1 == pos_start;
        if (pending_pos[seq_id] >= 0 && !pending_continues) {
            pending_pos[seq_id] = -1;
        }

        // Build a paired hook batch:
        //   row 0           = (pending_h, batch_in.token[0]) at pos_start if pending_continues
        //   rows 1..n_rows-1 = (h_k from this batch, batch_in.token[k+1]) at pos[k+1]
        // The last h-row (h_{n_rows-1}) is stashed as the new pending and is *not*
        // decoded this call — it waits for the next batch's first token to pair.
        const size_t row_bytes = (size_t) n_embd * sizeof(float);

        common_batch_clear(batch);

        int out_idx = 0;

        // append one (hidden row, token) pair to the hook batch; no logits are requested
        auto add_pair = [&](const float * h_row, llama_token tok, llama_pos pos) {
            std::memcpy(batch.embd + (size_t) out_idx * n_embd, h_row, row_bytes);
            batch.token   [out_idx]    = tok;
            batch.pos     [out_idx]    = pos;
            batch.n_seq_id[out_idx]    = 1;
            batch.seq_id  [out_idx][0] = seq_id;
            batch.logits  [out_idx]    = 0;
            ++out_idx;
        };

        if (pending_continues) {
            add_pair(pending_h[seq_id].data(), batch_in.token[0], pos_start);
        }

        // TODO: is there a fast way to build this batch?
        for (int k = 0; k + 1 < n_rows; ++k) {
            if (batch_in.logits[k] == 0) {
                LOG_WRN("%s: batch_in.logits[%d] == 0 (need_embd / logits=1 missing on prefill); stopping hook at this row\n", __func__, k);
                break;
            }
            const float * h_k = llama_get_embeddings_pre_norm_ith(ctx_tgt, k);
            if (h_k == nullptr) {
                LOG_WRN("%s: ctx_tgt has no pre-norm row at i=%d; stopping hook\n", __func__, k);
                break;
            }
            add_pair(h_k, batch_in.token[k + 1], batch_in.pos[k + 1]);
        }

        if (out_idx > 0) {
            batch.n_tokens = out_idx;
            const int32_t rc = llama_decode(ctx_dft, batch);
            if (rc != 0) {
                LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d, n=%d)\n", __func__, (int) rc, (int) pos_start, out_idx);
                return false;
            }
        }

        // Record the number of trunk rows so that draft() (first step after begin(),
        // i.e. last_n_accepted < 0) can find the last pre-norm row of this batch.
        // We assume every batch position has logits=1 (server sets need_embd
        // for MTP slots) → n_outputs == n_tokens.
        last_trunk_n_outputs[seq_id] = n_rows;

        // Stash the last h-row (h_{n_rows-1}) as the new pending for the next
        // process() call's first token to pair with.
        if (batch_in.logits[n_rows - 1] != 0) {
            const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, n_rows - 1);
            if (h_last != nullptr) {
                std::memcpy(pending_h[seq_id].data(), h_last, row_bytes);
                pending_pos[seq_id] = batch_in.pos[n_rows - 1];
            } else {
                pending_pos[seq_id] = -1;
            }
        } else {
            // No trunk output at the tail — can't carry over.
            pending_pos[seq_id] = -1;
        }

        return true;
    }

    // Draft up to n_max tokens for sequence 0, one token per decode on ctx_dft.
    //
    // Step 0 conditions on a pre-norm row of the target context; every later step
    // conditions on the pre-norm row the draft context produced in the previous step.
    // The drafted tokens are written to *dp.result (cleared first).
    void draft(common_speculative_draft_params_vec & dparams) override {
        // Single-seq for now (asserted in ctor). Future: iterate over dparams.
        const llama_seq_id seq_id = 0;
        if ((size_t) seq_id >= dparams.size()) {
            return;
        }

        auto & dp = dparams[seq_id];
        if (!dp.drafting) {
            return;
        }

        auto * ctx_tgt = this->params.ctx_tgt;
        auto * ctx_dft = this->params.ctx_dft;
        auto * smpl    = smpls[seq_id].get();

        GGML_ASSERT(dp.result != nullptr);
        auto & draft_tokens = *dp.result;
        draft_tokens.clear();

        // a previous draft is still recorded (accept() resets last_n_drafted to 0, so
        // presumably it was not called): remove all but the first of the positions that
        // draft added to the draft memory
        if (last_n_drafted[seq_id] > 0) {
            const int32_t n_to_drop = (int32_t) last_n_drafted[seq_id] - 1;
            if (n_to_drop > 0) {
                const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
                if (pos_max >= 0) {
                    const llama_pos drop_from = pos_max - n_to_drop + 1;
                    llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
                }
            }
            last_n_drafted[seq_id]  = 0;
            last_n_accepted[seq_id] = 0;
        }

        // Effective draft length: min(global cap, per-sequence override).
        int32_t n_max = std::max(1, params.n_max);
        if (dp.n_max > 0) {
            n_max = std::min(n_max, dp.n_max);
        }

        const size_t row_bytes = (size_t) n_embd * sizeof(float);

        common_sampler_reset(smpl);

        // token and position fed to the draft context at the current step
        llama_token cond_tok = dp.id_last;
        llama_pos   pos      = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id) + 1;

        for (int32_t k = 0; k < n_max; ++k) {
            const float * h_row = nullptr;
            if (k == 0) {
                // Condition on the trunk's pre-norm row.
                int32_t row_idx;
                if (last_n_accepted[seq_id] < 0) {
                    // First draft after begin(): use the last prefill row.
                    row_idx = last_trunk_n_outputs[seq_id] - 1;
                } else {
                    // After accept(n_accepted): row of the next conditioning
                    // position in the trunk's verify batch.
                    row_idx = last_n_accepted[seq_id];
                }
                if (row_idx < 0) {
                    LOG_WRN("%s: no trunk pre-norm row available (row_idx=%d); stopping chain\n", __func__, row_idx);
                    break;
                }
                h_row = llama_get_embeddings_pre_norm_ith(ctx_tgt, row_idx);
            } else {
                // AR step: condition on the MTP head's own pre-norm row from
                // the just-completed single-token decode. n_outputs=1 there,
                // so the row is at batch position 0.
                h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, 0);
            }
            if (h_row == nullptr) {
                LOG_WRN("%s: missing pre-norm row at k=%d; stopping chain\n", __func__, k);
                break;
            }

            // 1-token batch carrying both (token, h_pre_norm).
            common_batch_clear(batch);
            std::memcpy(batch.embd, h_row, row_bytes);
            batch.token   [0]    = cond_tok;
            batch.pos     [0]    = pos;
            batch.n_seq_id[0]    = 1;
            batch.seq_id  [0][0] = seq_id;
            batch.logits  [0]    = 1; // need logits for sampling
            batch.n_tokens = 1;

            const int32_t rc = llama_decode(ctx_dft, batch);
            if (rc != 0) {
                LOG_WRN("%s: llama_decode(ctx_dft) failed rc=%d at k=%d; stopping chain\n", __func__, rc, k);
                break;
            }

            const llama_token best = common_sampler_sample(smpl, ctx_dft, 0);
            common_sampler_accept(smpl, best, /*is_generated=*/ false);

            draft_tokens.push_back(best);

            // the sampled token becomes the conditioning token of the next step
            cond_tok = best;
            ++pos;
        }

        last_n_drafted[seq_id] = (uint16_t) draft_tokens.size();
    }

    // Record how many drafted tokens the target accepted and remove the trailing
    // max(0, n_drafted - n_accepted - 1) positions from the draft memory.
    void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
        GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < last_n_drafted.size());

        auto * ctx_dft = this->params.ctx_dft;

        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);

        const int32_t n_drafted_last = (int32_t) last_n_drafted[seq_id];
        const int32_t n_to_drop      = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);

        // NOTE(review): this early return leaves last_n_drafted untouched - confirm this is intended
        if (pos_max < 0) {
            last_n_accepted[seq_id] = (int32_t) n_accepted;
            return;
        }

        if (n_to_drop > 0) {
            const llama_pos drop_from = pos_max - n_to_drop + 1;
            llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
        }

        last_n_drafted [seq_id] = 0;
        last_n_accepted[seq_id] = (int32_t) n_accepted;
    }
};

// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_state_ngram_simple : public common_speculative_impl {
    common_params_speculative_ngram_map params;

    // shared across all sequences
    common_ngram_simple_config config;

    common_speculative_state_ngram_simple(
            const common_params_speculative & params,
            uint32_t n_seq,
            common_ngram_simple_config config)
        :
common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq) , params(params.ngram_simple) , config(config) {} void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop } bool process(const llama_batch & /*batch*/) override { // TODO: implement return true; } void draft(common_speculative_draft_params_vec & dparams) override { assert(dparams.size() == n_seq); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } *dp.result = common_ngram_simple_draft(config, *dp.prompt, dp.id_last); } } void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } }; struct common_speculative_state_ngram_map_k : public common_speculative_impl { common_params_speculative_ngram_map params; // n_seq configs std::vector config; common_speculative_state_ngram_map_k( const common_params_speculative & params, const common_ngram_map & config, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq) , params(params.ngram_map_k) { for (uint32_t i = 0; i < n_seq; i++) { this->config.push_back(config); } } void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { GGML_ASSERT(seq_id < (llama_seq_id) n_seq); common_ngram_map_begin(config[seq_id], prompt); } bool process(const llama_batch & /*batch*/) override { // TODO: implement return true; } void draft(common_speculative_draft_params_vec & dparams) override { assert(dparams.size() == n_seq); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } common_ngram_map_draft(config[seq_id], *dp.prompt, dp.id_last, *dp.result); } } void accept(llama_seq_id seq_id, uint16_t n_accepted) override { GGML_ASSERT((seq_id < (llama_seq_id) config.size())); common_ngram_map_accept(config[seq_id], n_accepted); } }; struct common_speculative_state_ngram_mod : public common_speculative_impl { 
common_params_speculative_ngram_mod params; // shared across all sequences common_ngram_mod mod; // enable trace logging if LLAMA_TRACE is set const bool verbose; struct seq_info { // the last position in the prompt that was added to the ngram container size_t i_last = 0; // length of the last drafted n‑gram (number of tokens returned by draft) size_t n_draft_last = 0; // consecutive accept rounds with low acceptance fraction (< 0.5) int n_low = 0; }; std::vector sinfos; common_speculative_state_ngram_mod( const common_params_speculative & params, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq) , params(params.ngram_mod) , mod(params.ngram_mod.n_match, 4*1024*1024) , verbose(std::getenv("LLAMA_TRACE") != nullptr) { static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, this->params.n_match, mod.size(), (float)(mod.size_bytes())/1024/1024); if (this->params.n_match < 16) { LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match); } sinfos.resize(n_seq); } void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { auto & sinfo = sinfos[seq_id]; sinfo.i_last = 0; sinfo.n_draft_last = 0; const size_t n = mod.get_n(); if (prompt.size() < n) { return; } for (size_t i = 0; i < prompt.size() - n; ++i) { mod.add(prompt.data() + i); } sinfo.i_last = prompt.size() - n; const double f = (double)mod.get_used() / (double)mod.size(); LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); constexpr double f_thold = 0.25; if (f > f_thold) { LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold); mod.reset(); } } void draft_one( llama_seq_id seq_id, common_speculative_draft_params & dparams) { auto & sinfo = sinfos[seq_id]; auto & result = 
*dparams.result; const auto & prompt = *dparams.prompt; sinfo.n_draft_last = 0; const size_t cur_len = prompt.size(); if (cur_len < mod.get_n()) { return; } const size_t n = mod.get_n(); // add new ngrams in chunks if (sinfo.i_last + 32 < cur_len) { for (size_t i = sinfo.i_last; i < cur_len - n; ++i) { mod.add(prompt.data() + i); } sinfo.i_last = cur_len - n; } result.resize(n + params.n_max); for (size_t i = 0; i < n - 1; ++i) { result[i] = prompt.at(cur_len - n + 1 + i); } result[n - 1] = dparams.id_last; for (int i = 0; i < params.n_max; ++i) { const llama_token token = mod.get(result.data() + i); if (token == common_ngram_mod::EMPTY) { if (i < params.n_min) { result.clear(); return; } result.resize(n + i); break; } result[n + i] = token; } // only return the m tokens that were drafted for (size_t i = 0; n + i < result.size(); ++i) { result[i] = result[n + i]; } result.resize(result.size() - n); // store length of drafted n‑gram for later acceptance analysis sinfo.n_draft_last = result.size(); } bool process(const llama_batch & /*batch*/) override { // TODO: implement return true; } void draft(common_speculative_draft_params_vec & dparams) override { assert(dparams.size() == n_seq); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } draft_one(seq_id, dp); } } void accept(llama_seq_id seq_id, uint16_t n_accepted) override { auto & sinfo = sinfos[seq_id]; // compute acceptance fraction if we have a recorded draft length if (sinfo.n_draft_last > 0) { const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last; if (f_acc < 0.5) { sinfo.n_low++; if (sinfo.n_low >= 3) { if (verbose) { LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, sinfo.n_low); } mod.reset(); sinfo.n_low = 0; sinfo.i_last = 0; } } else { sinfo.n_low = 0; } } } }; struct common_speculative_state_ngram_cache : public common_speculative_impl { common_params_speculative_ngram_cache 
params; uint16_t n_draft; bool save_dynamic; bool save_static; struct seq_info { size_t cache_size = 0; // number of tokens in n-gram cache common_ngram_cache ngram_cache_context; common_ngram_cache ngram_cache_dynamic; common_ngram_cache ngram_cache_static; }; std::vector sinfos; common_speculative_state_ngram_cache( const common_params_speculative & params, uint32_t n_seq, uint16_t n_draft, const std::string & path_static, const std::string & path_dynamic, bool save_dynamic, bool save_static) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, n_seq) , params(params.ngram_cache) , n_draft(n_draft) , save_dynamic(save_dynamic) , save_static(save_static) { sinfos.resize(n_seq); if (!path_static.empty()) { try { auto ngram_cache_static = common_ngram_cache_load(path_static); for (auto & sinfo : sinfos) { sinfo.ngram_cache_static = ngram_cache_static; } } catch (...) { LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); GGML_ABORT("Couldn't read static lookup cache"); } } if (!path_dynamic.empty()) { try { auto ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); for (auto & sinfo : sinfos) { sinfo.ngram_cache_dynamic = ngram_cache_dynamic; } } catch (...) 
{ LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); GGML_ABORT("Couldn't read dynamic lookup cache"); } } } void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop } void draft_one( llama_seq_id seq_id, common_speculative_draft_params & dparams) { auto & sinfo = sinfos[seq_id]; auto & result = *dparams.result; const auto & prompt = *dparams.prompt; if (sinfo.cache_size < prompt.size() + 1) { llama_tokens tokens_new; tokens_new.reserve(prompt.size() + 1 - sinfo.cache_size); for (size_t j = sinfo.cache_size; j < prompt.size(); ++j) { tokens_new.push_back(prompt[j]); } tokens_new.push_back(dparams.id_last); // add the last token // Update context ngram cache with new dparams.prompt: common_ngram_cache_update( sinfo.ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, tokens_new, tokens_new.size(), false); sinfo.cache_size = prompt.size() + 1; } llama_tokens inp; inp.reserve(prompt.size() + 1); for (size_t j = 0; j < prompt.size(); ++j) { inp.push_back(prompt[j]); } inp.push_back(dparams.id_last); result.push_back(dparams.id_last); common_ngram_cache_draft( inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, sinfo.ngram_cache_context, sinfo.ngram_cache_dynamic, sinfo.ngram_cache_static); if (result.size() > 0) { // delete first token in result (which is the id_last token) result.erase(result.begin()); } } bool process(const llama_batch & /*batch*/) override { // TODO: implement return true; } void draft(common_speculative_draft_params_vec & dparams) override { assert(dparams.size() == n_seq); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } draft_one(seq_id, dp); } } void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } }; struct common_speculative { common_speculative_draft_params_vec dparams; // list of implementations to use and their states std::vector> impls; // which implementaion was used 
for a given seq_id std::vector impl_last; }; static common_ngram_map get_common_ngram_map( common_speculative_type type, const common_params_speculative_ngram_map & config) { uint16_t size_key = config.size_n; uint16_t size_value = config.size_m; bool key_only = type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K; uint16_t min_hits = config.min_hits; return common_ngram_map(size_key, size_value, key_only, min_hits); } static common_speculative_state_ngram_cache create_state_ngram_cache( const common_speculative_config & config, uint32_t n_seq, const std::string & path_static, const std::string & path_dynamic) { uint16_t n_draft = 8; // TODO get from config? // TODO bool param in common/common.h to set save_static/save_dynamic? bool save_static = false; bool save_dynamic = false; common_speculative_state_ngram_cache state(config.params, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic); return state; } std::string common_speculative_type_name_str() { std::string result; for (size_t i = 0; i < common_speculative_types.size(); i++) { if (i > 0) { result += ", "; } result += common_speculative_type_to_str(common_speculative_types[i]); } return result; } std::string common_speculative_type_to_str(enum common_speculative_type type) { switch (type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; default: return "unknown"; } } enum common_speculative_type common_speculative_type_from_name(const std::string & name) { const auto it = common_speculative_type_from_name_map.find(name); if (it == 
common_speculative_type_from_name_map.end()) { return COMMON_SPECULATIVE_TYPE_COUNT; } return it->second; } // initialization of the speculative decoding system // common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq) { // Compute the implementations to use based on the config and their order of preference std::vector configs = {}; // list of speculative configs to try { bool has_draft = !params.draft.mparams.path.empty(); bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP) && params.draft.ctx_dft != nullptr; bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V); bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD); // In a more complex implementation we could use the same implementation but with different parameters. // This was initially used in PR-18471 but removed to simplify the code. if (has_ngram_simple) { // This implementation can guess a lot of tokens without any draft model. configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params)); } if (has_ngram_map_k) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, params)); } if (has_ngram_map_k4v) { // This implementation can guess tokens with high acceptance rate but is more expensive. 
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params)); } if (has_ngram_mod) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params)); } if (has_ngram_cache) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); } if (has_draft) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); } if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); } if (has_mtp) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params)); } } std::vector> impls = {}; for (const common_speculative_config & config : configs) { LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); switch (config.type) { case COMMON_SPECULATIVE_TYPE_NONE: break; case COMMON_SPECULATIVE_TYPE_DRAFT: { impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_EAGLE3: { impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_MTP: { impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple); uint16_t ngram_size_key = ngram_map.size_key; uint16_t mgram_size_value = ngram_map.size_value; auto config_simple = common_ngram_simple_config { /* .size_ngram = */ ngram_size_key, /* .size_mgram = */ mgram_size_value }; auto state = std::make_unique( /* .params = */ config.params, /* .n_seq = */ n_seq, /* .state = */ config_simple ); impls.push_back(std::move(state)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { impls.push_back( std::make_unique( config.params, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { 
impls.push_back( std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { auto state = create_state_ngram_cache( config, n_seq, params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic); impls.push_back(std::make_unique(state)); break; } default: break; } } if (impls.empty()) { LOG_WRN("%s", "no implementations specified for speculative decoding\n"); return nullptr; } auto * result = new common_speculative { /* .dparams = */ common_speculative_draft_params_vec(n_seq), /* .impls = */ std::move(impls), /* .impl_last = */ std::vector(n_seq, nullptr) }; return result; } void common_speculative_free(common_speculative * spec) { if (spec == nullptr) { return; } delete spec; } common_speculative_draft_params & common_speculative_get_draft_params( common_speculative * spec, llama_seq_id seq_id) { GGML_ASSERT(spec); GGML_ASSERT(seq_id < (llama_seq_id) spec->dparams.size()); return spec->dparams[seq_id]; } void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt) { if (spec == nullptr) { return; } for (auto & impl : spec->impls) { common_time_meas tm(impl->t_begin_us, !impl->gen_perf); impl->begin(seq_id, prompt); impl->n_call_begin++; } } bool common_speculative_process(common_speculative * spec, const llama_batch & batch) { bool result = true; if (spec == nullptr) { return result; } for (auto & impl : spec->impls) { result = result && impl->process(batch); } return result; } void common_speculative_draft(common_speculative * spec) { if (spec == nullptr) { return; } auto & dparams = spec->dparams; { int n_drafting = 0; for (auto & dp : dparams) { GGML_ASSERT(!dp.drafting || dp.result->empty()); if (dp.drafting) { n_drafting++; } } if (n_drafting == 0) { return; } } for (auto & impl : spec->impls) { { common_time_meas tm(impl->t_draft_us, !impl->gen_perf); impl->draft(dparams); impl->n_call_draft++; } int n_drafting = 0; for (llama_seq_id seq_id = 0; seq_id < 
(llama_seq_id) dparams.size(); ++seq_id) { auto & dp = dparams[seq_id]; auto & result = *dp.result; // a new draft has been sampled if (dp.drafting && !result.empty()) { dp.drafting = false; if (dp.n_max > 0) { if (!result.empty() && (int) result.size() > dp.n_max) { LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max); result.resize(dp.n_max); } } if (!result.empty()) { LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(), impl.get()->n_call_draft, result.size()); // remember which implementation was used spec->impl_last[seq_id] = impl.get(); impl->n_gen_drafts++; impl->n_gen_tokens += result.size(); } } if (dp.drafting) { n_drafting++; } } if (n_drafting == 0) { break; } } // these sequences failed to generate a draft for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) { auto & dp = dparams[seq_id]; if (dp.drafting) { dp.drafting = false; } } } void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) { if (n_accepted == 0) { return; } common_speculative_impl * impl = spec->impl_last[seq_id]; GGML_ASSERT(impl); { common_time_meas tm(impl->t_accept_us, !impl->gen_perf); if (n_accepted > 0) { impl->n_acc_drafts++; impl->n_acc_tokens += n_accepted; } impl->accept(seq_id, n_accepted); impl->n_call_accept++; } } void common_speculative_print_stats(const common_speculative * spec) { if (spec == nullptr) { return; } for (const auto & impl : spec->impls) { std::string str_perf; if (impl->gen_perf) { std::ostringstream oss; oss << std::fixed << std::setprecision(3) << impl->t_begin_us / 1000.0 << ", "; oss << std::fixed << std::setprecision(3) << impl->t_draft_us / 1000.0 << ", "; oss << std::fixed << std::setprecision(3) << impl->t_accept_us / 1000.0; str_perf = ", dur(b,g,a) = " + oss.str() + " ms"; } else { str_perf = ""; } LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, 
#gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n", common_speculative_type_to_str(impl->type).c_str(), impl->n_call_begin, impl->n_call_draft, impl->n_call_accept, impl->n_gen_drafts, impl->n_acc_drafts, impl->n_gen_tokens, impl->n_acc_tokens, str_perf.c_str()); } }