mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-13 20:44:09 +00:00
spec : refactor for multi-sequence speculative context
This commit is contained in:
@@ -61,10 +61,11 @@ struct server_slot {
|
||||
mtmd_context * mctx = nullptr;
|
||||
|
||||
// speculative decoding
|
||||
common_speculative * spec;
|
||||
|
||||
llama_tokens spec_draft;
|
||||
std::vector<int32_t> spec_i_batch;
|
||||
common_prompt_checkpoint spec_ckpt;
|
||||
common_speculative_ptr spec;
|
||||
|
||||
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
|
||||
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
|
||||
@@ -293,14 +294,10 @@ struct server_slot {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative);
|
||||
|
||||
// determine the max draft that fits the current slot state
|
||||
int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative);
|
||||
|
||||
// note: slot.prompt is not yet expanded with the `id` token sampled above
|
||||
// also, need to leave space for 1 extra token to allow context shifts
|
||||
n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2);
|
||||
int n_draft_max = n_ctx - prompt.n_tokens() - 2;
|
||||
|
||||
if (n_remaining > 0) {
|
||||
n_draft_max = std::min(n_draft_max, n_remaining - 1);
|
||||
@@ -308,11 +305,6 @@ struct server_slot {
|
||||
|
||||
SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
|
||||
|
||||
if (n_draft_max < n_draft_min) {
|
||||
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min);
|
||||
n_draft_max = 0;
|
||||
}
|
||||
|
||||
return n_draft_max;
|
||||
}
|
||||
|
||||
@@ -346,7 +338,7 @@ struct server_slot {
|
||||
}
|
||||
|
||||
// generate a new draft
|
||||
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens_text, prompt.n_tokens(), sampled);
|
||||
spec_draft = common_speculative_draft(spec, this->id, params_spec, tokens_text, prompt.n_tokens(), sampled);
|
||||
n_draft_total += spec_draft.size();
|
||||
|
||||
if (spec_draft.size() > (size_t) n_draft_max) {
|
||||
@@ -510,7 +502,7 @@ struct server_slot {
|
||||
);
|
||||
}
|
||||
|
||||
common_speculative_print_stats(spec.get());
|
||||
common_speculative_print_stats(spec);
|
||||
}
|
||||
|
||||
json to_json(bool only_metrics = false) const {
|
||||
@@ -684,6 +676,8 @@ private:
|
||||
common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
|
||||
common_speculative_ptr spec;
|
||||
|
||||
bool add_bos_token = true;
|
||||
|
||||
int32_t n_ctx; // total context for all clients / slots
|
||||
@@ -723,12 +717,6 @@ private:
|
||||
mtmd_free(mctx);
|
||||
mctx = nullptr;
|
||||
|
||||
for (server_slot & slot : slots) {
|
||||
if (slot.can_speculate()) {
|
||||
slot.spec.reset();
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
@@ -906,34 +894,33 @@ private:
|
||||
slots.emplace_back();
|
||||
}
|
||||
|
||||
bool no_dft = false;
|
||||
// try speculative decoding
|
||||
if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
try {
|
||||
spec.reset(common_speculative_init(params_base.speculative, params_base.n_parallel));
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
|
||||
}
|
||||
|
||||
if (spec) {
|
||||
SRV_INF("%s", "speculative decoding context initialized\n");
|
||||
} else {
|
||||
ctx_dft.reset();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < params_base.n_parallel; i++) {
|
||||
server_slot & slot = slots[i];
|
||||
|
||||
slot.id = i;
|
||||
slot.id = i;
|
||||
slot.ctx_tgt = ctx_tgt;
|
||||
slot.ctx_dft = ctx_dft.get();
|
||||
slot.n_ctx = n_ctx_slot;
|
||||
slot.spec = spec.get();
|
||||
slot.n_ctx = n_ctx_slot;
|
||||
|
||||
slot.mctx = mctx;
|
||||
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
||||
|
||||
// try speculative decoding
|
||||
if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
try {
|
||||
slot.spec.reset(common_speculative_init(params_base.speculative, slot.id));
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
|
||||
|
||||
no_dft = true;
|
||||
}
|
||||
|
||||
if (slot.spec) {
|
||||
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
||||
}
|
||||
}
|
||||
|
||||
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
|
||||
|
||||
slot.callback_on_release = [this](int id_slot) {
|
||||
@@ -943,16 +930,6 @@ private:
|
||||
slot.reset();
|
||||
}
|
||||
|
||||
if (no_dft && ctx_dft) {
|
||||
SRV_WRN("%s", "destroying the draft model as it is not going to be used\n");
|
||||
|
||||
ctx_dft.reset();
|
||||
|
||||
for (auto & slot : slots) {
|
||||
slot.ctx_dft = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const char * LLAMA_TRACE = getenv("LLAMA_TRACE");
|
||||
trace = LLAMA_TRACE ? atoi(LLAMA_TRACE) : 0;
|
||||
@@ -1353,7 +1330,7 @@ private:
|
||||
backend_sampling &= task.params.sampling.backend_sampling;
|
||||
|
||||
// TODO: speculative decoding requires multiple samples per batch - not supported yet
|
||||
backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
|
||||
backend_sampling &= !(slot.can_speculate());
|
||||
|
||||
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
|
||||
backend_sampling &= !need_logits;
|
||||
@@ -3011,7 +2988,7 @@ private:
|
||||
slot.state = SLOT_STATE_GENERATING;
|
||||
|
||||
if (slot.can_speculate()) {
|
||||
common_speculative_begin(slot.spec.get(), slot.prompt.tokens.get_text_tokens());
|
||||
common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens());
|
||||
}
|
||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||
continue; // continue loop of slots
|
||||
@@ -3126,7 +3103,7 @@ private:
|
||||
SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft);
|
||||
}
|
||||
|
||||
common_speculative_accept(slot.spec.get(), accepted.size() - 1);
|
||||
common_speculative_accept(spec.get(), slot.id, accepted.size() - 1);
|
||||
|
||||
slot.spec_draft = std::move(accepted);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user