diff --git a/common/speculative.cpp b/common/speculative.cpp index f7450e46c4..b4a3ec7fee 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -144,9 +144,7 @@ struct common_speculative_state { virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0; - virtual void draft( - llama_seq_id seq_id, - common_speculative_draft_params & dparams) = 0; + virtual void draft(common_speculative_draft_params_map & dparams) = 0; virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; }; @@ -169,7 +167,7 @@ struct common_speculative_state_draft : public common_speculative_state { auto * ctx_dft = this->params.ctx_dft; auto * ctx_tgt = this->params.ctx_tgt; - batch = llama_batch_init(llama_n_batch(ctx_dft), 0, n_seq); + batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); smpl = nullptr; // TODO: optimize or pass from outside? @@ -223,20 +221,30 @@ struct common_speculative_state_draft : public common_speculative_state { // noop } - void draft( - llama_seq_id seq_id, - common_speculative_draft_params & dparams) override { + void draft(common_speculative_draft_params_map & dparams) override { auto * spec = this; auto & batch = spec->batch; auto & ctx_dft = spec->params.ctx_dft; auto & smpl = spec->smpl; - // sanity check - GGML_ASSERT(dparams.n_past >= (llama_pos) dparams.prompt.size()); - common_batch_clear(batch); - common_batch_add (batch, dparams.id_last, dparams.n_past, { seq_id }, true); + common_batch_clear(batch); + + int n_drafting = 0; + std::map<llama_seq_id, bool> drafting; + + for (auto & dp : dparams) { + if (!dp.second.drafting) { + continue; + } + + llama_seq_id seq_id = dp.first; + + n_drafting++; + drafting[seq_id] = true; + + common_batch_add(batch, dp.second.id_last, dp.second.n_past, { seq_id }, true); + } int ret = llama_decode(ctx_dft, batch); if (ret != 0) { @@ -246,49 +254,76 @@ struct common_speculative_state_draft : public common_speculative_state { common_sampler_reset(smpl.get()); - const auto n_max = dparams.n_max > 0 ? 
std::min(dparams.n_max, spec->params.n_max) : spec->params.n_max; + int i = 0; + + while (n_drafting > 0) { + int i_batch = 0; - // sample n_draft tokens from the draft model - for (int i = 0; i < n_max; ++i) { common_batch_clear(batch); - common_sampler_sample(smpl.get(), ctx_dft, 0, true); + for (auto [seq_id, _] : drafting) { + if (!drafting[seq_id]) { + continue; + } - const auto * cur_p = common_sampler_get_candidates(smpl.get(), true); + common_sampler_sample(smpl.get(), ctx_dft, i_batch, true); - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + const auto * cur_p = common_sampler_get_candidates(smpl.get(), true); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl.get(), id, true); + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < spec->params.p_min) { + drafting[seq_id] = false; + n_drafting--; + + continue; + } + + auto & dp = dparams.at(seq_id); + + dp.result->push_back(id); + + if ((spec->params.n_max <= (int) dp.result->size()) || + (dp.n_max > 0 && dp.n_max <= (int) dp.result->size())) { + drafting[seq_id] = false; + n_drafting--; + continue; + } + + common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); } - // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; - - common_sampler_accept(smpl.get(), id, true); - - // only collect very high-confidence draft tokens - if (cur_p->data[0].p < spec->params.p_min) { + if (batch.n_tokens == 0) { break; } - dparams.result.push_back(id); - - if (spec->params.n_max <= 
(int) dparams.result.size()) { - break; - } - - common_batch_add(batch, id, dparams.n_past + i + 1, { seq_id }, true); - // evaluate the drafted tokens on the draft model ret = llama_decode(ctx_dft, batch); if (ret != 0) { LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); break; } + + ++i; } - if (dparams.result.size() < (size_t) spec->params.n_min) { - dparams.result.clear(); + for (auto & dp : dparams) { + if (!dp.second.drafting) { + continue; + } + + if (dp.second.result->size() < (size_t) spec->params.n_min) { + dp.second.result->clear(); + } } } @@ -307,9 +342,7 @@ struct common_speculative_state_eagle3 : public common_speculative_state { // noop } - void draft( - llama_seq_id /*seq_id*/, - common_speculative_draft_params & /*dparams*/) override { + void draft(common_speculative_draft_params_map & /*dparams*/) override { // TODO: implement } @@ -335,10 +368,14 @@ struct common_speculative_state_ngram_simple : public common_speculative_state { // noop } - void draft( - llama_seq_id /*seq_id*/, - common_speculative_draft_params & dparams) override { - dparams.result = common_ngram_simple_draft(config, dparams.prompt, dparams.id_last); + void draft(common_speculative_draft_params_map & dparams) override { + for (auto & dp : dparams) { + if (!dp.second.drafting) { + continue; + } + + *dp.second.result = common_ngram_simple_draft(config, *dp.second.prompt, dp.second.id_last); + } } void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { @@ -368,10 +405,14 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state { common_ngram_map_begin(config[seq_id], prompt); } - void draft( - llama_seq_id seq_id, - common_speculative_draft_params & dparams) override { - common_ngram_map_draft(config[seq_id], dparams.prompt, dparams.id_last, dparams.result); + void draft(common_speculative_draft_params_map & dparams) override { + for (auto & dp : dparams) { + if (!dp.second.drafting) { + continue; + } + + 
common_ngram_map_draft(config[dp.first], *dp.second.prompt, dp.second.id_last, *dp.second.result); + } } void accept(llama_seq_id seq_id, uint16_t n_accepted) override { @@ -452,15 +493,15 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { } } - void draft( + void draft_one( llama_seq_id seq_id, - common_speculative_draft_params & dparams) override { + common_speculative_draft_params & dparams) { auto & sinfo = sinfos[seq_id]; - auto & result = dparams.result; + auto & result = *dparams.result; sinfo.n_draft_last = 0; - const size_t cur_len = dparams.prompt.size(); + const size_t cur_len = dparams.prompt->size(); if (cur_len < mod.get_n()) { return; } @@ -470,7 +511,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { // add new ngrams in chunks if (sinfo.i_last + 32 < cur_len) { for (size_t i = sinfo.i_last; i < cur_len - n; ++i) { - mod.add(dparams.prompt.data() + i); + mod.add(dparams.prompt->data() + i); } sinfo.i_last = cur_len - n; @@ -478,7 +519,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { result.resize(n + params.n_max); for (size_t i = 0; i < n - 1; ++i) { - result[i] = dparams.prompt[cur_len - n + 1 + i]; + result[i] = dparams.prompt->at(cur_len - n + 1 + i); } result[n - 1] = dparams.id_last; @@ -506,6 +547,16 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { sinfo.n_draft_last = result.size(); } + void draft(common_speculative_draft_params_map & dparams) override { + for (auto & dp : dparams) { + if (!dp.second.drafting) { + continue; + } + + draft_one(dp.first, dp.second); + } + } + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { auto & sinfo = sinfos[seq_id]; @@ -595,17 +646,19 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { // noop } - void draft( + void draft_one( llama_seq_id seq_id, - common_speculative_draft_params & dparams) override { + 
common_speculative_draft_params & dparams) { auto & sinfo = sinfos[seq_id]; - auto & result = dparams.result; + auto & result = *dparams.result; - if (sinfo.cache_size < dparams.prompt.size() + 1) { + const auto & prompt = *dparams.prompt; + + if (sinfo.cache_size < prompt.size() + 1) { llama_tokens tokens_new; - tokens_new.reserve(dparams.prompt.size() + 1 - sinfo.cache_size); - for (size_t j = sinfo.cache_size; j < dparams.prompt.size(); ++j) { - tokens_new.push_back(dparams.prompt[j]); + tokens_new.reserve(prompt.size() + 1 - sinfo.cache_size); + for (size_t j = sinfo.cache_size; j < prompt.size(); ++j) { + tokens_new.push_back(prompt[j]); } tokens_new.push_back(dparams.id_last); // add the last token @@ -614,13 +667,13 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { sinfo.ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, tokens_new, tokens_new.size(), false); - sinfo.cache_size = dparams.prompt.size() + 1; + sinfo.cache_size = prompt.size() + 1; } llama_tokens inp; - inp.reserve(dparams.prompt.size() + 1); - for (size_t j = 0; j < dparams.prompt.size(); ++j) { - inp.push_back(dparams.prompt[j]); + inp.reserve(prompt.size() + 1); + for (size_t j = 0; j < prompt.size(); ++j) { + inp.push_back(prompt[j]); } inp.push_back(dparams.id_last); @@ -638,15 +691,27 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { } } + void draft(common_speculative_draft_params_map & dparams) override { + for (auto & dp : dparams) { + if (!dp.second.drafting) { + continue; + } + + draft_one(dp.first, dp.second); + } + } + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop } }; struct common_speculative { - std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states + // list of implementations to use and their states + std::vector<std::unique_ptr<common_speculative_state>> impls; - common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) + // which implementation was used for a 
given seq_id + std::vector<common_speculative_state *> impl_for_seq_id; }; static common_ngram_map get_common_ngram_map( @@ -814,8 +879,8 @@ common_speculative * common_speculative_init(common_params_speculative & params, } auto * result = new common_speculative { - /* .impls = */ std::move(impls), - /* .curr_impl = */ nullptr, + /* .impls = */ std::move(impls), + /* .impl_for_seq_id = */ std::vector<common_speculative_state *>(n_seq, nullptr) }; return result; @@ -843,36 +908,55 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co void common_speculative_draft( common_speculative * spec, - llama_seq_id seq_id, - common_speculative_draft_params & dparams) { - spec->curr_impl = nullptr; // reset current implementation + common_speculative_draft_params_map & dparams) { + for (auto & dp : dparams) { + GGML_ASSERT(dp.second.drafting); + GGML_ASSERT(dp.second.result->empty()); + } for (auto & impl : spec->impls) { { common_time_meas tm(impl->t_draft_us, !impl->gen_perf); - impl->draft(seq_id, dparams); + impl->draft(dparams); impl->n_call_draft++; } - auto & result = dparams.result; + int n_drafting = 0; + for (auto & dp : dparams) { + auto & result = *dp.second.result; - if (dparams.n_max > 0) { - if (!result.empty() && (int) result.size() > dparams.n_max) { - LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dparams.n_max); - result.resize(dparams.n_max); + if (!result.empty() && dp.second.drafting) { + dp.second.drafting = false; + + if (dp.second.n_max > 0) { + if (!result.empty() && (int) result.size() > dp.second.n_max) { + LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.second.n_max); + result.resize(dp.second.n_max); + } + } + + if (!result.empty()) { + LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, + common_speculative_type_to_str(impl.get()->type).c_str(), dp.second.prompt->size(), + impl.get()->n_call_draft, result.size()); + + // remember which implementation was used + spec->impl_for_seq_id[dp.first] = impl.get(); + + 
impl->n_gen_drafts++; + impl->n_gen_tokens += result.size(); + + break; // we have a draft, so break out of the loop and return it. + } + } + + if (dp.second.drafting) { + n_drafting++; + } + } - if (!result.empty()) { - LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, - common_speculative_type_to_str(impl.get()->type).c_str(), dparams.prompt.size(), - impl.get()->n_call_draft, result.size()); - - spec->curr_impl = impl.get(); // set current implementation for stats - impl->n_gen_drafts++; - impl->n_gen_tokens += result.size(); - - break; // we have a draft, so break out of the loop and return it. + if (n_drafting == 0) { + break; } } } @@ -882,7 +966,7 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u return; } - common_speculative_state * impl = spec->curr_impl; + common_speculative_state * impl = spec->impl_for_seq_id[seq_id]; GGML_ASSERT(impl); @@ -896,6 +980,8 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u impl->accept(seq_id, n_accepted); impl->n_call_accept++; } + + spec->impl_for_seq_id[seq_id] = nullptr; } void common_speculative_print_stats(const common_speculative * spec) { diff --git a/common/speculative.h b/common/speculative.h index a8d75d30c0..ad2a553cc2 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -22,20 +22,23 @@ void common_speculative_free(common_speculative * spec); void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt); struct common_speculative_draft_params { + bool drafting = true; + int32_t n_max = -1; // overrides individual configurations llama_pos n_past; llama_token id_last; - const llama_tokens & prompt; + const llama_tokens * prompt; - llama_tokens & result; + llama_tokens * result; }; +using common_speculative_draft_params_map = std::map<llama_seq_id, common_speculative_draft_params>; + void common_speculative_draft( common_speculative * spec, - llama_seq_id seq_id, - common_speculative_draft_params & 
dparams); + common_speculative_draft_params_map & dparams); // informs the speculative decoder that n_accepted tokens were accepted by the target model void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 6cd531bc8a..5116ca73b6 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -179,14 +179,16 @@ int main(int argc, char ** argv) { } // generate a new draft - auto dparams = common_speculative_draft_params { + common_speculative_draft_params_map dparams; + dparams[seq_id] = { + /* .drafting = */ true, /* .n_max = */ -1, /* .n_past = */ n_past, /* .id_last = */ id_last, - /* .prompt = */ prompt_tgt, - /* .result = */ draft, // output + /* .prompt = */ &prompt_tgt, + /* .result = */ &draft, // output }; - common_speculative_draft(spec, seq_id, dparams); + common_speculative_draft(spec, dparams); // save the original draft size n_draft = draft.size(); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 9d4b939d9d..be8bb6fbab 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -336,14 +336,16 @@ struct server_slot { } // generate a new draft - auto dparams = common_speculative_draft_params { - /* .n_max = */ n_draft_max, - /* .n_past = */ prompt.n_tokens(), - /* .id_last = */ sampled, - /* .prompt = */ tokens_text, - /* .result = */ spec_draft, + common_speculative_draft_params_map dparams; + dparams[this->id] = common_speculative_draft_params { + /* .drafting = */ true, + /* .n_max = */ n_draft_max, + /* .n_past = */ prompt.n_tokens(), + /* .id_last = */ sampled, + /* .prompt = */ &tokens_text, + /* .result = */ &spec_draft, }; - common_speculative_draft(spec, this->id, dparams); + common_speculative_draft(spec, dparams); n_draft_total += spec_draft.size(); if 
(spec_draft.size() > (size_t) n_draft_max) {