diff --git a/common/common.h b/common/common.h index 5b266f44f9..6d13ac40e1 100644 --- a/common/common.h +++ b/common/common.h @@ -295,8 +295,6 @@ struct common_params_model { std::string name = ""; // in format /[:] (tag is optional) // NOLINT }; -struct common_ngram_mod; - // draft-model-based speculative decoding parameters struct common_params_speculative_draft { int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding @@ -328,9 +326,6 @@ struct common_params_speculative_ngram_mod { int32_t n_max = 64; int32_t n_min = 48; - - // shared instance of the ngram container for all speculative decoding contexts - std::shared_ptr obj; }; struct common_params_speculative_ngram_map { diff --git a/common/speculative.cpp b/common/speculative.cpp index f72c51f09e..13f87ae428 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -142,28 +142,27 @@ struct common_speculative_state { virtual ~common_speculative_state() = default; - virtual void begin(const llama_tokens & prompt) = 0; + virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0; virtual void draft( + llama_seq_id seq_id, const common_params_speculative & params, const llama_tokens & prompt_tgt, - llama_pos n_past, + llama_pos n_past, llama_token id_last, llama_tokens & result) = 0; - virtual void accept(uint16_t n_accepted) = 0; - - virtual int32_t n_max(const common_params_speculative & params) const = 0; - virtual int32_t n_min(const common_params_speculative & params) const = 0; + virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; }; struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_tgt; // only used for retokenizing from ctx_dft llama_context * ctx_dft; - llama_seq_id seq_id; + uint32_t n_seq; - common_sampler * smpl; + // TODO: can become n_seq separate samplers + common_sampler_ptr smpl; llama_batch batch; @@ -171,13 +170,13 @@ struct common_speculative_state_draft : public common_speculative_state { enum common_speculative_type type, llama_context * ctx_tgt, llama_context * ctx_dft, - llama_seq_id seq_id) + uint32_t n_seq) : common_speculative_state(type) , ctx_tgt(ctx_tgt) , ctx_dft(ctx_dft) - , seq_id(seq_id) + , n_seq(n_seq) { - batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); + batch = llama_batch_init(llama_n_batch(ctx_dft), 0, n_seq); smpl = nullptr; // TODO: optimize or pass from outside? @@ -204,7 +203,7 @@ struct common_speculative_state_draft : public common_speculative_state { COMMON_SAMPLER_TYPE_TOP_K, }; - smpl = common_sampler_init(llama_get_model(ctx_dft), params); + smpl.reset(common_sampler_init(llama_get_model(ctx_dft), params)); } const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); @@ -215,18 +214,24 @@ struct common_speculative_state_draft : public common_speculative_state { throw std::runtime_error("draft model vocab type must match target model to use speculation"); } + + if (n_seq != llama_n_seq_max(ctx_dft)) { + LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft)); + + throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq"); + } } ~common_speculative_state_draft() override { - common_sampler_free(smpl); - llama_batch_free(batch); } - void begin(const llama_tokens & /*prompt*/) override { + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } void draft( + llama_seq_id seq_id, const common_params_speculative & params, const llama_tokens & prompt_tgt, llama_pos n_past, @@ -251,15 +256,15 @@ struct common_speculative_state_draft : public common_speculative_state { return; } - common_sampler_reset(smpl); + common_sampler_reset(smpl.get()); // sample n_draft tokens from the draft model for (int i = 0; i < sparams.n_max; ++i) { common_batch_clear(batch); - common_sampler_sample(smpl, ctx_dft, 0, true); + common_sampler_sample(smpl.get(), ctx_dft, 0, true); - const auto * cur_p = common_sampler_get_candidates(smpl, true); + const auto * cur_p = common_sampler_get_candidates(smpl.get(), true); for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", @@ -269,7 +274,7 @@ struct common_speculative_state_draft : public common_speculative_state { // add drafted token for each sequence const llama_token id = cur_p->data[0].id; - common_sampler_accept(smpl, id, true); + common_sampler_accept(smpl.get(), id, true); // only collect very high-confidence draft tokens if (cur_p->data[0].p < sparams.p_min) { @@ -297,52 +302,30 @@ struct common_speculative_state_draft : public common_speculative_state { } } - void accept(uint16_t n_accepted) override { + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop - GGML_UNUSED(n_accepted); - } - - int32_t n_max(const common_params_speculative & params) const override { - return params.draft.n_max; - } - - int32_t n_min(const common_params_speculative & params) const override { - return params.draft.n_min; } }; struct common_speculative_state_eagle3 : public common_speculative_state { common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_pos n_past, - llama_token id_last, - llama_tokens & draft_tokens) override { + llama_seq_id /*seq_id*/, + const common_params_speculative & /*params*/, + const llama_tokens & /*prompt_tgt*/, + llama_pos /*n_past*/, + llama_token /*id_last*/, + llama_tokens & /*draft_tokens*/) override { // TODO: implement - GGML_UNUSED(params); - GGML_UNUSED(prompt_tgt); - GGML_UNUSED(n_past); - GGML_UNUSED(id_last); - GGML_UNUSED(draft_tokens); } - void accept(uint16_t n_accepted) override { + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop - GGML_UNUSED(n_accepted); - } - - int32_t n_max(const common_params_speculative & params) const override { - return params.draft.n_max; - } - - int32_t n_min(const common_params_speculative & params) const override { - return params.draft.n_min; } }; @@ -355,98 +338,106 @@ struct common_speculative_state_ngram_simple : public common_speculative_state { common_ngram_simple_config config) : common_speculative_state(type), config(config) {} - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } void draft( - const common_params_speculative & params, + llama_seq_id /*seq_id*/, + const common_params_speculative & /*params*/, const llama_tokens & prompt_tgt, - llama_pos n_past, + llama_pos /*n_past*/, llama_token id_last, llama_tokens & result) override { - GGML_UNUSED(params); - GGML_UNUSED(n_past); result = common_ngram_simple_draft(config, prompt_tgt, id_last); } - void accept(uint16_t n_accepted) override { + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop - GGML_UNUSED(n_accepted); - } - - int32_t n_max(const common_params_speculative & /*params*/) const override { - return config.size_mgram; - } - - int32_t n_min(const common_params_speculative & /*params*/) const override { - return config.size_mgram; } }; struct common_speculative_state_ngram_map_k : public common_speculative_state { - // draft ngram map for speculative decoding without draft model - common_ngram_map config; + std::vector config; common_speculative_state_ngram_map_k( - enum common_speculative_type type, - common_ngram_map config) - : common_speculative_state(type), config(std::move(config)) {} + common_speculative_type type, + const common_ngram_map & config, + uint32_t n_seq) : common_speculative_state(type) { + for (uint32_t i = 0; i < n_seq; i++) { + this->config.push_back(config); + } + } - void begin(const llama_tokens & prompt) override { - common_ngram_map_begin(config, prompt); + void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { + GGML_ASSERT((seq_id < (llama_seq_id) config.size())); + + common_ngram_map_begin(config[seq_id], prompt); } void draft( - const common_params_speculative & params, + llama_seq_id seq_id, + const common_params_speculative & /*params*/, const llama_tokens & prompt_tgt, - llama_pos n_past, + llama_pos /*n_past*/, llama_token id_last, llama_tokens & result) override { - GGML_UNUSED(params); - GGML_UNUSED(n_past); - - common_ngram_map_draft(config, prompt_tgt, id_last, result); + common_ngram_map_draft(config[seq_id], prompt_tgt, id_last, result); } - void accept(uint16_t n_accepted) override { - common_ngram_map_accept(config, n_accepted); - } + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + GGML_ASSERT((seq_id < (llama_seq_id) config.size())); - int32_t n_max(const common_params_speculative & /*params*/) const override { - return config.size_value; - } - - int32_t n_min(const common_params_speculative & /*params*/) const override { - return config.size_value; + common_ngram_map_accept(config[seq_id], n_accepted); } }; struct common_speculative_state_ngram_mod : public common_speculative_state { - common_ngram_mod & mod; - - // the last position in the prompt that was added to the ngram container - size_t i_last = 0; - - // length of the last drafted n‑gram (number of tokens returned by draft) - size_t n_draft_last = 0; - - // consecutive accept rounds with low acceptance fraction (< 0.5) - int n_low = 0; + common_ngram_mod mod; // enable trace logging if LLAMA_TRACE is set const bool verbose; - common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod) - : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) { + struct seq_info { + // the last position in the prompt that was added to the ngram container + size_t i_last = 0; + + // length of the last drafted n‑gram (number of tokens returned by draft) + size_t n_draft_last = 0; + + // consecutive accept rounds with low acceptance fraction (< 0.5) + int n_low = 0; + }; + + std::vector sinfos; + + common_speculative_state_ngram_mod( + common_speculative_type type, + const common_params_speculative_ngram_mod & sparams, + uint32_t n_seq) + : common_speculative_state(type) + , mod(sparams.n_match, 4*1024*1024) + , verbose(std::getenv("LLAMA_TRACE") != nullptr) { static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); + + LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, + sparams.n_match, mod.size(), (float)(mod.size_bytes())/1024/1024); + + if (sparams.n_match < 16) { + LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " + "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, sparams.n_match); + } + + sinfos.resize(n_seq); } - void begin(const llama_tokens & prompt) override { - i_last = 0; + void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { + auto & sinfo = sinfos[seq_id]; - n_draft_last = 0; + sinfo.i_last = 0; + + sinfo.n_draft_last = 0; const size_t n = mod.get_n(); @@ -458,7 +449,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { mod.add(prompt.data() + i); } - i_last = prompt.size() - n; + sinfo.i_last = prompt.size() - n; const double f = (double)mod.get_used() / (double)mod.size(); LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); @@ -472,16 +463,17 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { } void draft( + llama_seq_id seq_id, const common_params_speculative & params, const llama_tokens & prompt_tgt, - llama_pos n_past, + llama_pos /*n_past*/, llama_token id_last, llama_tokens & result) override { - GGML_UNUSED(n_past); - const auto & sparams = params.ngram_mod; - n_draft_last = 0; + auto & sinfo = sinfos[seq_id]; + + sinfo.n_draft_last = 0; const size_t cur_len = prompt_tgt.size(); if (cur_len < mod.get_n()) { @@ -491,12 +483,12 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { const size_t n = mod.get_n(); // add new ngrams in chunks - if (i_last + 32 < cur_len) { - for (size_t i = i_last; i < cur_len - n; ++i) { + if (sinfo.i_last + 32 < cur_len) { + for (size_t i = sinfo.i_last; i < cur_len - n; ++i) { mod.add(prompt_tgt.data() + i); } - i_last = cur_len - n; + sinfo.i_last = cur_len - n; } result.resize(n + sparams.n_max); @@ -526,65 +518,71 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { result.resize(result.size() - n); // store length of drafted n‑gram for later acceptance analysis - n_draft_last = result.size(); + sinfo.n_draft_last = result.size(); } - void accept(uint16_t n_accepted) override { + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + auto & sinfo = sinfos[seq_id]; + // compute acceptance fraction if we have a recorded draft length - if (n_draft_last > 0) { - const double f_acc = (double)n_accepted / (double)n_draft_last; + if (sinfo.n_draft_last > 0) { + const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last; if (f_acc < 0.5) { - n_low++; - if (n_low >= 3) { + sinfo.n_low++; + if (sinfo.n_low >= 3) { if (verbose) { - LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); + LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, sinfo.n_low); } mod.reset(); - n_low = 0; - i_last = 0; + sinfo.n_low = 0; + sinfo.i_last = 0; } } else { - n_low = 0; + sinfo.n_low = 0; } } } - - int32_t n_max(const common_params_speculative & params) const override { - return params.ngram_mod.n_max; - } - - int32_t n_min(const common_params_speculative & params) const override { - return params.ngram_mod.n_min; - } }; struct common_speculative_state_ngram_cache : public common_speculative_state { uint16_t n_draft; + bool save_dynamic; bool save_static; - common_ngram_cache ngram_cache_context; - common_ngram_cache ngram_cache_dynamic; - common_ngram_cache ngram_cache_static; + struct seq_info { + size_t cache_size = 0; // number of tokens in n-gram cache - size_t cache_size = 0; // number of tokens in n-gram cache + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; + }; + + std::vector sinfos; common_speculative_state_ngram_cache( - const enum common_speculative_type type, + const common_speculative_type type, + uint32_t n_seq, + uint16_t n_draft, const std::string & path_static, const std::string & path_dynamic, - uint16_t n_draft, - bool save_dynamic, - bool save_static) + bool save_dynamic, + bool save_static) : common_speculative_state(type) , n_draft(n_draft) , save_dynamic(save_dynamic) , save_static(save_static) { + sinfos.resize(n_seq); + if (!path_static.empty()) { try { - ngram_cache_static = common_ngram_cache_load(path_static); + auto ngram_cache_static = common_ngram_cache_load(path_static); + + for (auto & sinfo : sinfos) { + sinfo.ngram_cache_static = ngram_cache_static; + } } catch (...) { LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); GGML_ABORT("Couldn't read static lookup cache"); @@ -593,7 +591,11 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { if (!path_dynamic.empty()) { try { - ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + auto ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + + for (auto & sinfo : sinfos) { + sinfo.ngram_cache_dynamic = ngram_cache_dynamic; + } } catch (...) { LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); GGML_ABORT("Couldn't read dynamic lookup cache"); @@ -601,31 +603,33 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { } } - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } void draft( - const common_params_speculative & params, + llama_seq_id seq_id, + const common_params_speculative & /*params*/, const llama_tokens & prompt_tgt, - llama_pos n_past, + llama_pos /*n_past*/, llama_token id_last, llama_tokens & result) override { - GGML_UNUSED(params); - GGML_UNUSED(n_past); + auto & sinfo = sinfos[seq_id]; - if (cache_size < prompt_tgt.size() + 1) { + if (sinfo.cache_size < prompt_tgt.size() + 1) { llama_tokens tokens_new; - tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); - for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { + tokens_new.reserve(prompt_tgt.size() + 1 - sinfo.cache_size); + for (size_t j = sinfo.cache_size; j < prompt_tgt.size(); ++j) { tokens_new.push_back(prompt_tgt[j]); } tokens_new.push_back(id_last); // add the last token // Update context ngram cache with new prompt_tgt: - common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + common_ngram_cache_update( + sinfo.ngram_cache_context, + LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, tokens_new, tokens_new.size(), false); - cache_size = prompt_tgt.size() + 1; + sinfo.cache_size = prompt_tgt.size() + 1; } llama_tokens inp; @@ -637,10 +641,11 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { result.push_back(id_last); - common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, - ngram_cache_context, - ngram_cache_dynamic, - ngram_cache_static); + common_ngram_cache_draft( + inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + sinfo.ngram_cache_context, + sinfo.ngram_cache_dynamic, + sinfo.ngram_cache_static); if (result.size() > 0) { // delete first token in result (which is the id_last token) @@ -648,17 +653,8 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { } } - void accept(uint16_t n_accepted) override { - // TODO: noop - GGML_UNUSED(n_accepted); - } - - int32_t n_max(const common_params_speculative & /*params*/) const override { - return n_draft; - } - - int32_t n_min(const common_params_speculative & /*params*/) const override { - return 0; + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + // noop } }; @@ -680,15 +676,17 @@ static common_ngram_map get_common_ngram_map( } static common_speculative_state_ngram_cache create_state_ngram_cache( - const std::string & path_static, const std::string & path_dynamic, - const common_speculative_config & config) { + const common_speculative_config & config, + uint32_t n_seq, + const std::string & path_static, + const std::string & path_dynamic) { uint16_t n_draft = 8; // TODO get from config? // TODO bool param in common/common.h to set save_static/save_dynamic? bool save_static = false; bool save_dynamic = false; - common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); + common_speculative_state_ngram_cache state(config.type, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic); return state; } @@ -728,7 +726,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // initialization of the speculative decoding system // -common_speculative * common_speculative_init(common_params_speculative & params, llama_seq_id seq_id) { +common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq) { // Compute the implementations to use based on the config and their order of preference std::vector configs = {}; // list of speculative configs to try { @@ -755,20 +753,6 @@ common_speculative * common_speculative_init(common_params_speculative & params, configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params)); } if (has_ngram_mod) { - auto & sparams = params.ngram_mod; - - if (!sparams.obj) { - sparams.obj = std::make_shared(sparams.n_match, 4*1024*1024); - - LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, - sparams.n_match, sparams.obj->size(), (float)(sparams.obj->size_bytes())/1024/1024); - - if (sparams.n_match < 16) { - LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " - "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, sparams.n_match); - } - } - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params)); } if (has_ngram_cache) { @@ -793,7 +777,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, impls.push_back(std::make_unique(config.type, /* .ctx_tgt = */ params.draft.ctx_tgt, /* .ctx_dft = */ params.draft.ctx_dft, - /* .seq_id = */ seq_id + /* .n_seq = */ n_seq )); break; } @@ -820,19 +804,22 @@ common_speculative * common_speculative_init(common_params_speculative & params, } case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { - impls.push_back(std::make_unique( - (config.type), - get_common_ngram_map(config.type, config.params.ngram_map_k) - )); + impls.push_back( + std::make_unique( + config.type, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { - GGML_ASSERT(config.params.ngram_mod.obj); - impls.push_back(std::make_unique(config.type, *config.params.ngram_mod.obj)); + impls.push_back( + std::make_unique( + config.type, config.params.ngram_mod, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { - auto state = create_state_ngram_cache(params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic, config); + auto state = create_state_ngram_cache( + config, n_seq, + params.ngram_cache.lookup_cache_static, + params.ngram_cache.lookup_cache_dynamic); impls.push_back(std::make_unique(state)); break; } @@ -862,23 +849,24 @@ void common_speculative_free(common_speculative * spec) { delete spec; } -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { +void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt) { if (spec == nullptr) { return; } for (auto & impl : spec->impls) { common_time_meas tm(impl->t_begin_us, !impl->gen_perf); - impl->begin(prompt); + impl->begin(seq_id, prompt); impl->n_call_begin++; } } llama_tokens common_speculative_draft( common_speculative * spec, + llama_seq_id seq_id, const common_params_speculative & params, const llama_tokens & prompt_tgt, // specified in target model vocab - llama_pos n_past, + llama_pos n_past, llama_token id_last) { llama_tokens result; @@ -887,19 +875,10 @@ llama_tokens common_speculative_draft( for (auto & impl : spec->impls) { { common_time_meas tm(impl->t_draft_us, !impl->gen_perf); - impl->draft(params, prompt_tgt, n_past, id_last, result); + impl->draft(seq_id, params, prompt_tgt, n_past, id_last, result); impl->n_call_draft++; } - { - const int n_min = impl->n_min(params); - - if (!result.empty() && (int) result.size() < n_min) { - LOG_DBG("%s: ignoring small draft: %d < %d\n", __func__, (int) result.size(), n_min); - result.clear(); - } - } - if (!result.empty()) { LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), @@ -916,7 +895,7 @@ llama_tokens common_speculative_draft( return result; } -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { +void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) { if (n_accepted == 0) { return; } @@ -932,37 +911,11 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { impl->n_acc_tokens += n_accepted; } - impl->accept(n_accepted); + impl->accept(seq_id, n_accepted); impl->n_call_accept++; } } -int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params) { - if (spec == nullptr) { - return 0; - } - - int32_t n_max = 0; - for (const auto & impl : spec->impls) { - n_max = std::max(n_max, impl->n_max(params)); - } - - return n_max; -} - -int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params) { - if (spec == nullptr) { - return 0; - } - - int32_t n_min = 0; - for (const auto & impl : spec->impls) { - n_min = std::max(n_min, impl->n_min(params)); - } - - return n_min; -} - void common_speculative_print_stats(const common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index 2f7e3144bb..f9d3cf8a55 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -14,26 +14,24 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // convert type to string std::string common_speculative_type_to_str(enum common_speculative_type type); -common_speculative * common_speculative_init(common_params_speculative & params, llama_seq_id seq_id); +common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq); void common_speculative_free(common_speculative * spec); // optionally call once at the beginning of a new generation -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt); +void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt); // sample up to n_draft tokens and add them to the batch using the draft model llama_tokens common_speculative_draft( common_speculative * spec, + llama_seq_id seq_id, const common_params_speculative & params, const llama_tokens & prompt, llama_pos n_past, llama_token id_last); // informs the speculative decoder that n_accepted tokens were accepted by the target model -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); - -int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params); -int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params); +void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted); // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 6585da7382..08be5679e2 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -147,7 +147,7 @@ int main(int argc, char ** argv) { struct common_speculative * spec = common_speculative_init(params.speculative, seq_id); - common_speculative_begin(spec, prompt_tgt); + common_speculative_begin(spec, seq_id, prompt_tgt); llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); @@ -179,7 +179,7 @@ int main(int argc, char ** argv) { } // generate a new draft - draft = common_speculative_draft(spec, params_spec, prompt_tgt, prompt_tgt.size(), id_last); + draft = common_speculative_draft(spec, seq_id, params_spec, prompt_tgt, prompt_tgt.size(), id_last); // save the original draft size n_draft = draft.size(); @@ -272,7 +272,7 @@ int main(int argc, char ** argv) { continue; } - common_speculative_accept(spec, ids.size() - 1); + common_speculative_accept(spec, seq_id, ids.size() - 1); // full acceptance: consume the draft and commit accepted tokens n_past += ids.size() - 1; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 9b2fab832a..bac2338d43 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -61,10 +61,11 @@ struct server_slot { mtmd_context * mctx = nullptr; // speculative decoding + common_speculative * spec; + llama_tokens spec_draft; std::vector spec_i_batch; common_prompt_checkpoint spec_ckpt; - common_speculative_ptr spec; // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 @@ -293,14 +294,10 @@ struct server_slot { return 0; } - const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative); - // determine the max draft that fits the current slot state - int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative); - // note: slot.prompt is not yet expanded with the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2); + int n_draft_max = n_ctx - prompt.n_tokens() - 2; if (n_remaining > 0) { n_draft_max = std::min(n_draft_max, n_remaining - 1); @@ -308,11 +305,6 @@ struct server_slot { SLT_DBG(*this, "max possible draft: %d\n", n_draft_max); - if (n_draft_max < n_draft_min) { - SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min); - n_draft_max = 0; - } - return n_draft_max; } @@ -346,7 +338,7 @@ struct server_slot { } // generate a new draft - spec_draft = common_speculative_draft(spec.get(), params_spec, tokens_text, prompt.n_tokens(), sampled); + spec_draft = common_speculative_draft(spec, this->id, params_spec, tokens_text, prompt.n_tokens(), sampled); n_draft_total += spec_draft.size(); if (spec_draft.size() > (size_t) n_draft_max) { @@ -510,7 +502,7 @@ struct server_slot { ); } - common_speculative_print_stats(spec.get()); + common_speculative_print_stats(spec); } json to_json(bool only_metrics = false) const { @@ -684,6 +676,8 @@ private: common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + common_speculative_ptr spec; + bool add_bos_token = true; int32_t n_ctx; // total context for all clients / slots @@ -723,12 +717,6 @@ private: mtmd_free(mctx); mctx = nullptr; - for (server_slot & slot : slots) { - if (slot.can_speculate()) { - slot.spec.reset(); - } - } - llama_batch_free(batch); } @@ -906,34 +894,33 @@ private: slots.emplace_back(); } - bool no_dft = false; + // try speculative decoding + if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { + try { + spec.reset(common_speculative_init(params_base.speculative, params_base.n_parallel)); + } catch (const std::exception & e) { + SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what()); + } + + if (spec) { + SRV_INF("%s", "speculative decoding context initialized\n"); + } else { + ctx_dft.reset(); + } + } for (int i = 0; i < params_base.n_parallel; i++) { server_slot & slot = slots[i]; - slot.id = i; + slot.id = i; slot.ctx_tgt = ctx_tgt; slot.ctx_dft = ctx_dft.get(); - slot.n_ctx = n_ctx_slot; + slot.spec = spec.get(); + slot.n_ctx = n_ctx_slot; slot.mctx = mctx; slot.prompt.tokens.has_mtmd = mctx != nullptr; - // try speculative decoding - if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { - try { - slot.spec.reset(common_speculative_init(params_base.speculative, slot.id)); - } catch (const std::exception & e) { - SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what()); - - no_dft = true; - } - - if (slot.spec) { - SLT_INF(slot, "%s", "speculative decoding context initialized\n"); - } - } - SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); slot.callback_on_release = [this](int id_slot) { @@ -943,16 +930,6 @@ private: slot.reset(); } - if (no_dft && ctx_dft) { - SRV_WRN("%s", "destroying the draft model as it is not going to be used\n"); - - ctx_dft.reset(); - - for (auto & slot : slots) { - slot.ctx_dft = nullptr; - } - } - { const char * LLAMA_TRACE = getenv("LLAMA_TRACE"); trace = LLAMA_TRACE ? atoi(LLAMA_TRACE) : 0; @@ -1353,7 +1330,7 @@ private: backend_sampling &= task.params.sampling.backend_sampling; // TODO: speculative decoding requires multiple samples per batch - not supported yet - backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0); + backend_sampling &= !(slot.can_speculate()); // TODO: getting post/pre sampling logits is not yet supported with backend sampling backend_sampling &= !need_logits; @@ -3011,7 +2988,7 @@ private: slot.state = SLOT_STATE_GENERATING; if (slot.can_speculate()) { - common_speculative_begin(slot.spec.get(), slot.prompt.tokens.get_text_tokens()); + common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots @@ -3126,7 +3103,7 @@ private: SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); } - common_speculative_accept(slot.spec.get(), accepted.size() - 1); + common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); slot.spec_draft = std::move(accepted); }