Merge branch 'master' into pr/18039

This commit is contained in:
Georgi Gerganov
2026-04-30 10:08:10 +03:00
281 changed files with 23333 additions and 12703 deletions

View File

@@ -309,8 +309,10 @@ struct server_slot {
return 0;
}
const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative);
// determine the max draft that fits the current slot state
int n_draft_max = task->params.speculative.n_max;
int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative);
// note: slot.prompt is not yet expanded with the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
@@ -322,8 +324,8 @@ struct server_slot {
SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
if (n_draft_max < task->params.speculative.n_min) {
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
if (n_draft_max < n_draft_min) {
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min);
n_draft_max = 0;
}
@@ -352,17 +354,13 @@ struct server_slot {
// generate a new draft
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
n_draft_total += spec_draft.size();
if (spec_draft.size() > (size_t) n_draft_max) {
SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max);
spec_draft.resize(n_draft_max);
}
if (spec_draft.size() < (size_t) params_spec.n_min) {
SLT_DBG(*this, "ignoring small draft: %d < %d\n", (int) spec_draft.size(), params_spec.n_min);
spec_draft.clear();
}
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
const auto n_tokens = prompt.tokens.size();
@@ -770,9 +768,9 @@ private:
if (params_base.speculative.has_dft()) {
// TODO speculative: move to common/speculative.cpp?
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
const auto & params_spec = params_base.speculative.draft;
const auto & params_spec = params_base.speculative;
SRV_INF("loading draft model '%s'\n", params_spec.mparams.path.c_str());
auto params_dft = params_base;
@@ -780,7 +778,7 @@ private:
params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
params_dft.n_batch = llama_n_ctx_seq(ctx);
params_dft.devices = params_spec.devices;
params_dft.model = params_spec.mparams_dft;
params_dft.model = params_spec.mparams;
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
params_dft.cache_type_k = params_spec.cache_type_k;
params_dft.cache_type_v = params_spec.cache_type_v;
@@ -800,11 +798,12 @@ private:
return false;
}
params_base.speculative.model_dft = model_dft.get();
params_base.speculative.model_tgt = model;
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
params_base.speculative.draft.model = model_dft.get();
params_base.speculative.draft.model_tgt = model;
if (params_base.speculative.eagle3) {
params_base.speculative.draft.cparams = common_context_params_to_llama(params_dft);
if (params_base.speculative.draft.eagle3) {
// EAGLE3 current limitation: extracted target features are per-context; multiple slots would overwrite each other
if (params_base.n_parallel > 1) {
SRV_ERR("%s", "EAGLE3 speculative decoding is not supported with n_parallel > 1\n");
@@ -1321,7 +1320,7 @@ private:
backend_sampling &= task.params.sampling.backend_sampling;
// TODO: speculative decoding requires multiple samples per batch - not supported yet
backend_sampling &= !(slot.can_speculate() && task.params.speculative.n_max > 0);
backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_logits;
@@ -3033,7 +3032,6 @@ private:
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
slot.n_draft_total += n_draft;
// add accepted tokens to the prompt
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
@@ -3042,7 +3040,7 @@ private:
slot.sampled = ids.back(); // last accepted token
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.n_tokens(), -1);
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.tokens.pos_next(), -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;