spec : refactor for multi-sequence speculative context

Georgi Gerganov
2026-05-08 15:43:36 +03:00
parent efa2f8e5a7
commit 6582523eaa
5 changed files with 220 additions and 297 deletions
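
As read from the call sites below, the refactor replaces the per-slot speculative contexts (a `common_speculative_ptr` owned by each `server_slot`) with a single server-level context created once for `params_base.n_parallel` sequences; each slot keeps only a non-owning `common_speculative *` and identifies its sequence by slot id in the begin/draft/accept calls. The following is a minimal sketch of the implied lifecycle; the signatures are inferred from the call sites in this diff rather than copied from the headers, and `prompt_tokens`, `n_past` and `n_accepted` are placeholder names, not identifiers from the commit.

    // sketch only: signatures inferred from the call sites in this commit
    // one speculative context serves all parallel slots, keyed by slot id
    common_speculative_ptr spec(common_speculative_init(params_base.speculative, params_base.n_parallel));

    // when slot `slot.id` enters the generating state, register its prompt with the shared context
    common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens());

    // each decode step: draft against this slot's sequence, starting from the token just sampled ...
    llama_tokens draft = common_speculative_draft(spec.get(), slot.id, params_spec,
                                                  prompt_tokens, n_past, sampled);

    // ... then report how many of the drafted tokens the target model accepted
    common_speculative_accept(spec.get(), slot.id, n_accepted);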

@@ -61,10 +61,11 @@ struct server_slot {
mtmd_context * mctx = nullptr;
// speculative decoding
common_speculative * spec;
llama_tokens spec_draft;
std::vector<int32_t> spec_i_batch;
common_prompt_checkpoint spec_ckpt;
common_speculative_ptr spec;
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
@@ -293,14 +294,10 @@ struct server_slot {
return 0;
}
const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative);
// determine the max draft that fits the current slot state
int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative);
// note: slot.prompt is not yet expanded with the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2);
int n_draft_max = n_ctx - prompt.n_tokens() - 2;
if (n_remaining > 0) {
n_draft_max = std::min(n_draft_max, n_remaining - 1);
@@ -308,11 +305,6 @@ struct server_slot {
SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
if (n_draft_max < n_draft_min) {
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min);
n_draft_max = 0;
}
return n_draft_max;
}
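
In the new version the draft budget computed here appears to drop the `common_speculative_n_min`/`common_speculative_n_max` lookups: it is simply the context space left in the slot, `n_ctx - prompt.n_tokens() - 2` (one position for the token just sampled, which is not yet part of the prompt, plus one spare position so a context shift stays possible), further clamped to `n_remaining - 1` when a generation limit is set. With illustrative numbers only: `n_ctx = 4096`, a 4000-token prompt and 10 tokens left to generate give `min(4096 - 4000 - 2, 10 - 1) = 9` draft tokens at most.
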
@@ -346,7 +338,7 @@ struct server_slot {
}
// generate a new draft
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens_text, prompt.n_tokens(), sampled);
spec_draft = common_speculative_draft(spec, this->id, params_spec, tokens_text, prompt.n_tokens(), sampled);
n_draft_total += spec_draft.size();
if (spec_draft.size() > (size_t) n_draft_max) {
@@ -510,7 +502,7 @@ struct server_slot {
);
}
common_speculative_print_stats(spec.get());
common_speculative_print_stats(spec);
}
json to_json(bool only_metrics = false) const {
@@ -684,6 +676,8 @@ private:
common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
common_speculative_ptr spec;
bool add_bos_token = true;
int32_t n_ctx; // total context for all clients / slots
@@ -723,12 +717,6 @@ private:
mtmd_free(mctx);
mctx = nullptr;
for (server_slot & slot : slots) {
if (slot.can_speculate()) {
slot.spec.reset();
}
}
llama_batch_free(batch);
}
@@ -906,34 +894,33 @@ private:
slots.emplace_back();
}
bool no_dft = false;
// try speculative decoding
if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
try {
spec.reset(common_speculative_init(params_base.speculative, params_base.n_parallel));
} catch (const std::exception & e) {
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
}
if (spec) {
SRV_INF("%s", "speculative decoding context initialized\n");
} else {
ctx_dft.reset();
}
}
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot & slot = slots[i];
slot.id = i;
slot.id = i;
slot.ctx_tgt = ctx_tgt;
slot.ctx_dft = ctx_dft.get();
slot.n_ctx = n_ctx_slot;
slot.spec = spec.get();
slot.n_ctx = n_ctx_slot;
slot.mctx = mctx;
slot.prompt.tokens.has_mtmd = mctx != nullptr;
// try speculative decoding
if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
try {
slot.spec.reset(common_speculative_init(params_base.speculative, slot.id));
} catch (const std::exception & e) {
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
no_dft = true;
}
if (slot.spec) {
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
}
}
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
slot.callback_on_release = [this](int id_slot) {
@@ -943,16 +930,6 @@ private:
slot.reset();
}
if (no_dft && ctx_dft) {
SRV_WRN("%s", "destroying the draft model as it is not going to be used\n");
ctx_dft.reset();
for (auto & slot : slots) {
slot.ctx_dft = nullptr;
}
}
{
const char * LLAMA_TRACE = getenv("LLAMA_TRACE");
trace = LLAMA_TRACE ? atoi(LLAMA_TRACE) : 0;
@@ -1353,7 +1330,7 @@ private:
backend_sampling &= task.params.sampling.backend_sampling;
// TODO: speculative decoding requires multiple samples per batch - not supported yet
backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
backend_sampling &= !(slot.can_speculate());
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_logits;
@@ -3011,7 +2988,7 @@ private:
slot.state = SLOT_STATE_GENERATING;
if (slot.can_speculate()) {
common_speculative_begin(slot.spec.get(), slot.prompt.tokens.get_text_tokens());
common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens());
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
@@ -3126,7 +3103,7 @@ private:
SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft);
}
common_speculative_accept(slot.spec.get(), accepted.size() - 1);
common_speculative_accept(spec.get(), slot.id, accepted.size() - 1);
slot.spec_draft = std::move(accepted);
}
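
A note on ownership as this diff reads: the server keeps the only owning handle (the `common_speculative_ptr spec` member added to the private section) and initializes it once, so a failed initialization can drop the draft context right away with `ctx_dft.reset()` instead of going through the removed `no_dft` bookkeeping, and slot teardown no longer needs the per-slot `slot.spec.reset()` loop. Each slot merely borrows the shared context via `slot.spec = spec.get()`.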