#include "speculative.h"

#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
#include "ngram-mod.h"
#include "sampling.h"

// NOTE(review): the original header names were lost; this set covers every stdlib
// name used in this file (std::strcmp, std::min, std::map, std::getenv,
// std::runtime_error, std::ostringstream, std::setprecision, std::unique_ptr, ...)
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// maximum tolerated difference between the target and draft vocab sizes
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128

// skip the first few token ids when comparing vocabs (special tokens often differ)
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

const std::vector<enum common_speculative_type> common_speculative_types = {
    COMMON_SPECULATIVE_TYPE_NONE,
    COMMON_SPECULATIVE_TYPE_DRAFT,
    COMMON_SPECULATIVE_TYPE_EAGLE3,
    COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
    COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
    COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
    COMMON_SPECULATIVE_TYPE_NGRAM_CACHE
};

const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
    {"none",          COMMON_SPECULATIVE_TYPE_NONE},
    {"draft",         COMMON_SPECULATIVE_TYPE_DRAFT},
    {"eagle3",        COMMON_SPECULATIVE_TYPE_EAGLE3},
    {"ngram_simple",  COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
    {"ngram_map_k",   COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
    {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
    {"ngram_mod",     COMMON_SPECULATIVE_TYPE_NGRAM_MOD},
    {"ngram_cache",   COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
};

// one speculative decoding implementation type together with its parameters
struct common_speculative_config {
    common_speculative_type type;
    common_params_speculative params;

    common_speculative_config(common_speculative_type t, const common_params_speculative & p = common_params_speculative{}) : type(t), params(p) {}
};

// check whether the draft model can speculate for the target model:
// same vocab type, same BOS/EOS handling and (nearly) the same vocabulary
static bool common_speculative_are_compatible(
        const llama_model * model_tgt,
        const llama_model * model_dft) {
    const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
    const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);

    // BUG FIX: these were declared `bool`, which collapses every non-zero vocab type
    // to `true` and makes different vocab types compare equal - keep the enum type
    const enum llama_vocab_type vocab_type_tgt = llama_vocab_type(vocab_tgt);
    LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);

    const enum llama_vocab_type vocab_type_dft = llama_vocab_type(vocab_dft);
    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);

    if (vocab_type_tgt != vocab_type_dft) {
        LOG_WRN("%s: draft model vocab type must match target model to use speculation but "
                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
        return false;
    }

    if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
        (llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
        LOG_WRN("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
                __func__, llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft), llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
        return false;
    }

    if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
        (llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
        LOG_WRN("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
                __func__, llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft), llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
        return false;
    }

    {
        const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
        const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);

        const int vocab_diff = n_vocab_tgt > n_vocab_dft ?
            n_vocab_tgt - n_vocab_dft :
            n_vocab_dft - n_vocab_tgt;

        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
            LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
            LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
                    n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
            return false;
        }

        // the token texts in the shared id range must match exactly
        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
            const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
            const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);

            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
                LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
                LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
                        common_token_to_piece(vocab_tgt, i).c_str(), common_token_to_piece(vocab_dft, i).c_str());
                return false;
            }
        }
    }

    return true;
}

// state of an implementation of speculative decoding
//
// each implementation has a unique type and a state that is implementation-specific
// in a subclass of common_speculative_state
struct common_speculative_state {
    const enum common_speculative_type type;

    size_t n_call_begin  = 0; // number of times this implementation was called for refresh.
    size_t n_call_draft  = 0; // number of times this implementation was called for generation.
    size_t n_call_accept = 0; // number of times this implementation was called for accumulation.

    size_t n_gen_drafts = 0; // number of times a draft or part was generated by this implementation.
    size_t n_acc_drafts = 0; // number of times a draft or part was accepted by the target model.

    size_t n_gen_tokens = 0; // number of tokens generated by this implementation.
    size_t n_acc_tokens = 0; // number of tokens accepted by the target model.

    // TODO: track performance of most recent calls
    const bool gen_perf = true; // whether to generate performance stats.
int64_t t_begin_us = 0; // total time spent in refresh of this implementation in microseconds. int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds. int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds. common_speculative_state(enum common_speculative_type type) : type(type) {} virtual ~common_speculative_state() = default; virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0; virtual void draft( llama_seq_id seq_id, const common_params_speculative & params, const llama_tokens & prompt_tgt, llama_pos n_past, llama_token id_last, llama_tokens & result) = 0; virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; }; struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_tgt; // only used for retokenizing from ctx_dft llama_context * ctx_dft; uint32_t n_seq; // TODO: can become n_seq separate samplers common_sampler_ptr smpl; llama_batch batch; common_speculative_state_draft( enum common_speculative_type type, llama_context * ctx_tgt, llama_context * ctx_dft, uint32_t n_seq) : common_speculative_state(type) , ctx_tgt(ctx_tgt) , ctx_dft(ctx_dft) , n_seq(n_seq) { batch = llama_batch_init(llama_n_batch(ctx_dft), 0, n_seq); smpl = nullptr; // TODO: optimize or pass from outside? 
// { // common_params_sampling params; // params.no_perf = false; // // params.top_k = 40; // params.top_p = 0.9; // // params.samplers = { // COMMON_SAMPLER_TYPE_TOP_K, // COMMON_SAMPLER_TYPE_TOP_P, // COMMON_SAMPLER_TYPE_INFILL, // }; // // result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); // } { common_params_sampling params; params.no_perf = false; params.top_k = 10; params.samplers = { COMMON_SAMPLER_TYPE_TOP_K, }; smpl.reset(common_sampler_init(llama_get_model(ctx_dft), params)); } const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt); if (!vocab_cmpt) { LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__); throw std::runtime_error("draft model vocab type must match target model to use speculation"); } if (n_seq != llama_n_seq_max(ctx_dft)) { LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft)); throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq"); } } ~common_speculative_state_draft() override { llama_batch_free(batch); } void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop } void draft( llama_seq_id seq_id, const common_params_speculative & params, const llama_tokens & prompt_tgt, llama_pos n_past, llama_token id_last, llama_tokens & result) override { const auto & sparams = params.draft; auto * spec = this; auto & batch = spec->batch; auto & ctx_dft = spec->ctx_dft; auto & smpl = spec->smpl; GGML_ASSERT(n_past >= (llama_pos) prompt_tgt.size()); common_batch_clear(batch); common_batch_add (batch, id_last, n_past, { seq_id }, true); int ret = llama_decode(ctx_dft, batch); if (ret != 0) { LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); return; } common_sampler_reset(smpl.get()); // sample n_draft tokens from the draft model for (int i = 0; i < sparams.n_max; ++i) { 
common_batch_clear(batch); common_sampler_sample(smpl.get(), ctx_dft, 0, true); const auto * cur_p = common_sampler_get_candidates(smpl.get(), true); for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); } // add drafted token for each sequence const llama_token id = cur_p->data[0].id; common_sampler_accept(smpl.get(), id, true); // only collect very high-confidence draft tokens if (cur_p->data[0].p < sparams.p_min) { break; } result.push_back(id); if (sparams.n_max <= (int) result.size()) { break; } common_batch_add(batch, id, n_past + i + 1, { seq_id }, true); // evaluate the drafted tokens on the draft model ret = llama_decode(ctx_dft, batch); if (ret != 0) { LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); break; } } if (result.size() < (size_t) sparams.n_min) { result.clear(); } } void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } }; struct common_speculative_state_eagle3 : public common_speculative_state { common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop } void draft( llama_seq_id /*seq_id*/, const common_params_speculative & /*params*/, const llama_tokens & /*prompt_tgt*/, llama_pos /*n_past*/, llama_token /*id_last*/, llama_tokens & /*draft_tokens*/) override { // TODO: implement } void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } }; // state of self-speculation (simple implementation, not ngram-map) struct common_speculative_state_ngram_simple : public common_speculative_state { common_ngram_simple_config config; common_speculative_state_ngram_simple( enum common_speculative_type type, common_ngram_simple_config config) : common_speculative_state(type), config(config) 
{} void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop } void draft( llama_seq_id /*seq_id*/, const common_params_speculative & /*params*/, const llama_tokens & prompt_tgt, llama_pos /*n_past*/, llama_token id_last, llama_tokens & result) override { result = common_ngram_simple_draft(config, prompt_tgt, id_last); } void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } }; struct common_speculative_state_ngram_map_k : public common_speculative_state { std::vector config; common_speculative_state_ngram_map_k( common_speculative_type type, const common_ngram_map & config, uint32_t n_seq) : common_speculative_state(type) { for (uint32_t i = 0; i < n_seq; i++) { this->config.push_back(config); } } void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { GGML_ASSERT((seq_id < (llama_seq_id) config.size())); common_ngram_map_begin(config[seq_id], prompt); } void draft( llama_seq_id seq_id, const common_params_speculative & /*params*/, const llama_tokens & prompt_tgt, llama_pos /*n_past*/, llama_token id_last, llama_tokens & result) override { common_ngram_map_draft(config[seq_id], prompt_tgt, id_last, result); } void accept(llama_seq_id seq_id, uint16_t n_accepted) override { GGML_ASSERT((seq_id < (llama_seq_id) config.size())); common_ngram_map_accept(config[seq_id], n_accepted); } }; struct common_speculative_state_ngram_mod : public common_speculative_state { common_ngram_mod mod; // enable trace logging if LLAMA_TRACE is set const bool verbose; struct seq_info { // the last position in the prompt that was added to the ngram container size_t i_last = 0; // length of the last drafted n‑gram (number of tokens returned by draft) size_t n_draft_last = 0; // consecutive accept rounds with low acceptance fraction (< 0.5) int n_low = 0; }; std::vector sinfos; common_speculative_state_ngram_mod( common_speculative_type type, const common_params_speculative_ngram_mod & sparams, uint32_t n_seq) : 
common_speculative_state(type) , mod(sparams.n_match, 4*1024*1024) , verbose(std::getenv("LLAMA_TRACE") != nullptr) { static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, sparams.n_match, mod.size(), (float)(mod.size_bytes())/1024/1024); if (sparams.n_match < 16) { LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, sparams.n_match); } sinfos.resize(n_seq); } void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { auto & sinfo = sinfos[seq_id]; sinfo.i_last = 0; sinfo.n_draft_last = 0; const size_t n = mod.get_n(); if (prompt.size() < n) { return; } for (size_t i = 0; i < prompt.size() - n; ++i) { mod.add(prompt.data() + i); } sinfo.i_last = prompt.size() - n; const double f = (double)mod.get_used() / (double)mod.size(); LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); constexpr double f_thold = 0.25; if (f > f_thold) { LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold); mod.reset(); } } void draft( llama_seq_id seq_id, const common_params_speculative & params, const llama_tokens & prompt_tgt, llama_pos /*n_past*/, llama_token id_last, llama_tokens & result) override { const auto & sparams = params.ngram_mod; auto & sinfo = sinfos[seq_id]; sinfo.n_draft_last = 0; const size_t cur_len = prompt_tgt.size(); if (cur_len < mod.get_n()) { return; } const size_t n = mod.get_n(); // add new ngrams in chunks if (sinfo.i_last + 32 < cur_len) { for (size_t i = sinfo.i_last; i < cur_len - n; ++i) { mod.add(prompt_tgt.data() + i); } sinfo.i_last = cur_len - n; } result.resize(n + sparams.n_max); for (size_t i = 0; i < n - 1; ++i) { result[i] = prompt_tgt[cur_len - n + 1 + i]; } result[n - 1] = id_last; for (int i = 0; i < sparams.n_max; ++i) { const llama_token token = 
mod.get(result.data() + i); if (token == common_ngram_mod::EMPTY) { if (i < sparams.n_min) { result.clear(); return; } result.resize(n + i); break; } result[n + i] = token; } // only return the m tokens that were drafted for (size_t i = 0; n + i < result.size(); ++i) { result[i] = result[n + i]; } result.resize(result.size() - n); // store length of drafted n‑gram for later acceptance analysis sinfo.n_draft_last = result.size(); } void accept(llama_seq_id seq_id, uint16_t n_accepted) override { auto & sinfo = sinfos[seq_id]; // compute acceptance fraction if we have a recorded draft length if (sinfo.n_draft_last > 0) { const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last; if (f_acc < 0.5) { sinfo.n_low++; if (sinfo.n_low >= 3) { if (verbose) { LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, sinfo.n_low); } mod.reset(); sinfo.n_low = 0; sinfo.i_last = 0; } } else { sinfo.n_low = 0; } } } }; struct common_speculative_state_ngram_cache : public common_speculative_state { uint16_t n_draft; bool save_dynamic; bool save_static; struct seq_info { size_t cache_size = 0; // number of tokens in n-gram cache common_ngram_cache ngram_cache_context; common_ngram_cache ngram_cache_dynamic; common_ngram_cache ngram_cache_static; }; std::vector sinfos; common_speculative_state_ngram_cache( const common_speculative_type type, uint32_t n_seq, uint16_t n_draft, const std::string & path_static, const std::string & path_dynamic, bool save_dynamic, bool save_static) : common_speculative_state(type) , n_draft(n_draft) , save_dynamic(save_dynamic) , save_static(save_static) { sinfos.resize(n_seq); if (!path_static.empty()) { try { auto ngram_cache_static = common_ngram_cache_load(path_static); for (auto & sinfo : sinfos) { sinfo.ngram_cache_static = ngram_cache_static; } } catch (...) 
{ LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); GGML_ABORT("Couldn't read static lookup cache"); } } if (!path_dynamic.empty()) { try { auto ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); for (auto & sinfo : sinfos) { sinfo.ngram_cache_dynamic = ngram_cache_dynamic; } } catch (...) { LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); GGML_ABORT("Couldn't read dynamic lookup cache"); } } } void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop } void draft( llama_seq_id seq_id, const common_params_speculative & /*params*/, const llama_tokens & prompt_tgt, llama_pos /*n_past*/, llama_token id_last, llama_tokens & result) override { auto & sinfo = sinfos[seq_id]; if (sinfo.cache_size < prompt_tgt.size() + 1) { llama_tokens tokens_new; tokens_new.reserve(prompt_tgt.size() + 1 - sinfo.cache_size); for (size_t j = sinfo.cache_size; j < prompt_tgt.size(); ++j) { tokens_new.push_back(prompt_tgt[j]); } tokens_new.push_back(id_last); // add the last token // Update context ngram cache with new prompt_tgt: common_ngram_cache_update( sinfo.ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, tokens_new, tokens_new.size(), false); sinfo.cache_size = prompt_tgt.size() + 1; } llama_tokens inp; inp.reserve(prompt_tgt.size() + 1); for (size_t j = 0; j < prompt_tgt.size(); ++j) { inp.push_back(prompt_tgt[j]); } inp.push_back(id_last); result.push_back(id_last); common_ngram_cache_draft( inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, sinfo.ngram_cache_context, sinfo.ngram_cache_dynamic, sinfo.ngram_cache_static); if (result.size() > 0) { // delete first token in result (which is the id_last token) result.erase(result.begin()); } } void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } }; struct common_speculative { std::vector> impls; // list of implementations to use and their states common_speculative_state * curr_impl = nullptr; // current 
implementation in use (for stats) }; static common_ngram_map get_common_ngram_map( common_speculative_type type, const common_params_speculative_ngram_map & config) { uint16_t size_key = config.size_n; uint16_t size_value = config.size_m; bool key_only = type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K; uint16_t min_hits = config.min_hits; return common_ngram_map(size_key, size_value, key_only, min_hits); } static common_speculative_state_ngram_cache create_state_ngram_cache( const common_speculative_config & config, uint32_t n_seq, const std::string & path_static, const std::string & path_dynamic) { uint16_t n_draft = 8; // TODO get from config? // TODO bool param in common/common.h to set save_static/save_dynamic? bool save_static = false; bool save_dynamic = false; common_speculative_state_ngram_cache state(config.type, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic); return state; } std::string common_speculative_type_name_str() { std::string result; for (size_t i = 0; i < common_speculative_types.size(); i++) { if (i > 0) { result += ", "; } result += common_speculative_type_to_str(common_speculative_types[i]); } return result; } std::string common_speculative_type_to_str(enum common_speculative_type type) { switch (type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; default: return "unknown"; } } enum common_speculative_type common_speculative_type_from_name(const std::string & name) { const auto it = common_speculative_type_from_name_map.find(name); if (it == common_speculative_type_from_name_map.end()) { return 
COMMON_SPECULATIVE_TYPE_COUNT; } return it->second; } // initialization of the speculative decoding system // common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq) { // Compute the implementations to use based on the config and their order of preference std::vector configs = {}; // list of speculative configs to try { bool has_draft = !params.draft.mparams.path.empty(); bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V); bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD); // In a more complex implementation we could use the same implementation but with different parameters. // This was initially used in PR-18471 but removed to simplify the code. if (has_ngram_simple) { // This implementation can guess a lot of tokens without any draft model. configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params)); } if (has_ngram_map_k) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, params)); } if (has_ngram_map_k4v) { // This implementation can guess tokens with high acceptance rate but is more expensive. 
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params)); } if (has_ngram_mod) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params)); } if (has_ngram_cache) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); } if (has_draft) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); } if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); } } std::vector> impls = {}; for (const common_speculative_config & config : configs) { LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); switch (config.type) { case COMMON_SPECULATIVE_TYPE_NONE: break; case COMMON_SPECULATIVE_TYPE_DRAFT: { impls.push_back(std::make_unique(config.type, /* .ctx_tgt = */ params.draft.ctx_tgt, /* .ctx_dft = */ params.draft.ctx_dft, /* .n_seq = */ n_seq )); break; } case COMMON_SPECULATIVE_TYPE_EAGLE3: { impls.push_back(std::make_unique(config.type)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple); uint16_t ngram_size_key = ngram_map.size_key; uint16_t mgram_size_value = ngram_map.size_value; auto config_simple = common_ngram_simple_config { /* .size_ngram = */ ngram_size_key, /* .size_mgram = */ mgram_size_value }; auto state = std::make_unique( /* .type = */ config.type, /* .state = */ config_simple ); impls.push_back(std::move(state)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { impls.push_back( std::make_unique( config.type, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { impls.push_back( std::make_unique( config.type, config.params.ngram_mod, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { auto state = 
create_state_ngram_cache( config, n_seq, params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic); impls.push_back(std::make_unique(state)); break; } default: break; } } if (impls.empty()) { LOG_WRN("%s", "no implementations specified for speculative decoding\n"); return nullptr; } auto * result = new common_speculative { /* .impls = */ std::move(impls), /* .curr_impl = */ nullptr, }; return result; } void common_speculative_free(common_speculative * spec) { if (spec == nullptr) { return; } delete spec; } void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt) { if (spec == nullptr) { return; } for (auto & impl : spec->impls) { common_time_meas tm(impl->t_begin_us, !impl->gen_perf); impl->begin(seq_id, prompt); impl->n_call_begin++; } } llama_tokens common_speculative_draft( common_speculative * spec, llama_seq_id seq_id, const common_params_speculative & params, const llama_tokens & prompt_tgt, // specified in target model vocab llama_pos n_past, llama_token id_last) { llama_tokens result; spec->curr_impl = nullptr; // reset current implementation for (auto & impl : spec->impls) { { common_time_meas tm(impl->t_draft_us, !impl->gen_perf); impl->draft(seq_id, params, prompt_tgt, n_past, id_last, result); impl->n_call_draft++; } if (!result.empty()) { LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), impl.get()->n_call_draft, result.size()); spec->curr_impl = impl.get(); // set current implementation for stats impl->n_gen_drafts++; impl->n_gen_tokens += result.size(); break; // we have a draft, so break out of the loop and return it. 
} } return result; } void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) { if (n_accepted == 0) { return; } common_speculative_state * impl = spec->curr_impl; GGML_ASSERT(impl); { common_time_meas tm(impl->t_accept_us, !impl->gen_perf); if (n_accepted > 0) { impl->n_acc_drafts++; impl->n_acc_tokens += n_accepted; } impl->accept(seq_id, n_accepted); impl->n_call_accept++; } } void common_speculative_print_stats(const common_speculative * spec) { if (spec == nullptr) { return; } for (const auto & impl : spec->impls) { std::string str_perf; if (impl->gen_perf) { std::ostringstream oss; oss << std::fixed << std::setprecision(3) << impl->t_begin_us / 1000.0 << ", "; oss << std::fixed << std::setprecision(3) << impl->t_draft_us / 1000.0 << ", "; oss << std::fixed << std::setprecision(3) << impl->t_accept_us / 1000.0; str_perf = ", dur(b,g,a) = " + oss.str() + " ms"; } else { str_perf = ""; } LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n", common_speculative_type_to_str(impl->type).c_str(), impl->n_call_begin, impl->n_call_draft, impl->n_call_accept, impl->n_gen_drafts, impl->n_acc_drafts, impl->n_gen_tokens, impl->n_acc_tokens, str_perf.c_str()); } }