Merge branch 'master' into pr/18039
@@ -309,8 +309,10 @@ struct server_slot {
             return 0;
         }
 
+        const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative);
+
         // determine the max draft that fits the current slot state
-        int n_draft_max = task->params.speculative.n_max;
+        int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative);
 
         // note: slot.prompt is not yet expanded with the `id` token sampled above
         // also, need to leave space for 1 extra token to allow context shifts
@@ -322,8 +324,8 @@ struct server_slot {
         SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
 
-        if (n_draft_max < task->params.speculative.n_min) {
-            SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
+        if (n_draft_max < n_draft_min) {
+            SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min);
             n_draft_max = 0;
         }
 
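
Note on the two hunks above: the raw user parameters (task->params.speculative.n_min / .n_max) are replaced by the new common_speculative_n_min / common_speculative_n_max helpers, so the effective limits are decided by the drafter rather than taken verbatim from the request. A minimal sketch of that idea, assuming (as a guess, not the real common/speculative.cpp code) a drafter that reports a fixed draft length:

    // hypothetical stand-ins so the sketch compiles on its own; the real code
    // uses common_speculative * and common_params_speculative
    struct params_sketch  { int n_min; int n_max; };
    struct drafter_sketch { int n_fixed; }; // e.g. always emits exactly n_fixed tokens

    // guess: the helpers let the drafter override the user-supplied limits
    static int spec_n_min(const drafter_sketch & d, const params_sketch & p) {
        return d.n_fixed > 0 ? d.n_fixed : p.n_min;
    }
    static int spec_n_max(const drafter_sketch & d, const params_sketch & p) {
        return d.n_fixed > 0 ? d.n_fixed : p.n_max;
    }

    int main() {
        const drafter_sketch d = { 8 };
        const params_sketch  p = { 2, 16 };

        const int n_draft_min = spec_n_min(d, p);
        int       n_draft_max = spec_n_max(d, p);

        // same shape as the server logic above: skip speculation entirely
        // when the slot cannot fit even the minimum draft
        if (n_draft_max < n_draft_min) {
            n_draft_max = 0;
        }
        return n_draft_max > 0 ? 0 : 1;
    }
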
@@ -352,17 +354,13 @@ struct server_slot {
         // generate a new draft
         spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
+        n_draft_total += spec_draft.size();
 
         if (spec_draft.size() > (size_t) n_draft_max) {
             SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max);
             spec_draft.resize(n_draft_max);
         }
 
-        if (spec_draft.size() < (size_t) params_spec.n_min) {
-            SLT_DBG(*this, "ignoring small draft: %d < %d\n", (int) spec_draft.size(), params_spec.n_min);
-            spec_draft.clear();
-        }
-
         if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
             const auto n_tokens = prompt.tokens.size();
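
This hunk also moves the n_draft_total accounting from the verification path (the @@ -3033 hunk further down) to draft-generation time, counting spec_draft.size() before truncation. A simplified sketch of the resulting bookkeeping (names are placeholders; the real counters live on the slot):

    #include <cstddef>

    // simplified bookkeeping sketch
    struct draft_stats_sketch {
        size_t n_draft_total    = 0; // bumped when a draft is generated
        size_t n_draft_accepted = 0; // bumped when the target model accepts tokens

        double acceptance_rate() const {
            return n_draft_total == 0 ? 0.0
                                      : double(n_draft_accepted) / double(n_draft_total);
        }
    };

One observable consequence: drafted tokens that are truncated away now still count toward the total, so the reported acceptance rate is measured against what the drafter produced, not against what was actually verified.
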
@@ -770,9 +768,9 @@ private:
         if (params_base.speculative.has_dft()) {
             // TODO speculative: move to common/speculative.cpp?
-            SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
+            const auto & params_spec = params_base.speculative.draft;
 
-            const auto & params_spec = params_base.speculative;
+            SRV_INF("loading draft model '%s'\n", params_spec.mparams.path.c_str());
 
             auto params_dft = params_base;
@@ -780,7 +778,7 @@ private:
             params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
             params_dft.n_batch = llama_n_ctx_seq(ctx);
             params_dft.devices = params_spec.devices;
-            params_dft.model = params_spec.mparams_dft;
+            params_dft.model = params_spec.mparams;
             params_dft.n_gpu_layers = params_spec.n_gpu_layers;
             params_dft.cache_type_k = params_spec.cache_type_k;
             params_dft.cache_type_v = params_spec.cache_type_v;
@@ -800,11 +798,12 @@ private:
                 return false;
             }
 
-            params_base.speculative.model_dft = model_dft.get();
-            params_base.speculative.model_tgt = model;
-            params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
+            params_base.speculative.draft.model = model_dft.get();
+            params_base.speculative.draft.model_tgt = model;
 
-            if (params_base.speculative.eagle3) {
+            params_base.speculative.draft.cparams = common_context_params_to_llama(params_dft);
+
+            if (params_base.speculative.draft.eagle3) {
                 // EAGLE3 current limitation: extracted target features are per-context; multiple slots would overwrite each other
                 if (params_base.n_parallel > 1) {
                     SRV_ERR("%s", "EAGLE3 speculative decoding is not supported with n_parallel > 1\n");
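
The @@ -770 through @@ -800 hunks are all one refactor: the drafter-related fields (model_dft, model_tgt, cparams_dft, mparams_dft, eagle3) move into a nested draft member, dropping the now-redundant _dft suffixes. A rough guess at the regrouped shape, since the actual struct definition is not part of this diff:

    // placeholder types so this compiles standalone; the real fields use
    // llama_model * and llama_context_params from llama.h
    struct model_sketch;
    struct cparams_sketch { int n_ctx = 0; };

    struct speculative_draft_sketch {
        model_sketch * model     = nullptr; // was speculative.model_dft
        model_sketch * model_tgt = nullptr; // target model; EAGLE3 extracts features from it
        cparams_sketch cparams;             // was speculative.cparams_dft
        bool           eagle3    = false;   // was speculative.eagle3
    };

    struct speculative_sketch {
        speculative_draft_sketch draft;     // accessed as params_base.speculative.draft.*
    };
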
@@ -1321,7 +1320,7 @@ private:
         backend_sampling &= task.params.sampling.backend_sampling;
 
         // TODO: speculative decoding requires multiple samples per batch - not supported yet
-        backend_sampling &= !(slot.can_speculate() && task.params.speculative.n_max > 0);
+        backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
 
         // TODO: getting post/pre sampling logits is not yet supported with backend sampling
         backend_sampling &= !need_logits;
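
The backend-sampling gate now consults the same common_speculative_n_max helper as the slot logic above, instead of the raw n_max request parameter, so both sites agree on whether a draft can actually be produced. A condensed sketch of the gating pattern (the booleans stand in for the real per-request and per-slot state):

    // each &= can only turn the capability off, never back on
    static bool use_backend_sampling(bool request_allows, bool will_speculate, bool need_logits) {
        bool backend_sampling = true;
        backend_sampling &= request_allows;   // per-request opt-in
        backend_sampling &= !will_speculate;  // speculative decoding needs several samples per batch
        backend_sampling &= !need_logits;     // post/pre-sampling logits not supported here
        return backend_sampling;
    }
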
@@ -3033,7 +3032,6 @@ private:
                 // update how many tokens out of those tested were accepted
                 slot.n_draft_accepted += ids.size() - 1;
-                slot.n_draft_total += n_draft;
 
                 // add accepted tokens to the prompt
                 slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
@@ -3042,7 +3040,7 @@ private:
                 slot.sampled = ids.back(); // last accepted token
                 SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
 
-                llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.n_tokens(), -1);
+                llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.tokens.pos_next(), -1);
 
                 for (size_t i = 0; i < ids.size(); ++i) {
                     completion_token_output result;
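
The final hunk trims the KV cache from slot.prompt.tokens.pos_next() instead of slot.prompt.n_tokens(). Those two values agree only while every prompt entry occupies exactly one position; a plausible reading (an assumption, since the server_tokens internals are not shown in this diff) is that entries such as multimodal chunks can span several positions, in which case trimming by entry count would clear the wrong range:

    #include <vector>

    // hypothetical sketch of a token list whose position counter can diverge
    // from its entry count
    struct prompt_tokens_sketch {
        std::vector<int> toks;   // token entries
        int n_extra_pos = 0;     // extra positions used by non-text chunks (assumption)

        int n_tokens() const { return (int) toks.size(); }
        int pos_next() const { return n_tokens() + n_extra_pos; } // first unoccupied position
    };

    // trimming must therefore use positions, not entry counts, e.g.:
    //   llama_memory_seq_rm(mem, seq_id, prompt.pos_next(), -1);
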