server, spec: transition to unified spec context

This commit is contained in:
Georgi Gerganov
2026-05-07 17:57:59 +03:00
parent d719d8aafc
commit 78faa2b79f
3 changed files with 37 additions and 217 deletions

View File

@@ -380,16 +380,7 @@ struct server_slot {
} else {
GGML_ASSERT(spec_i_batch.empty());
// generate a new draft
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
n_draft_total += spec_draft.size();
if (spec_draft.size() > (size_t) n_draft_max) {
SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max);
spec_draft.resize(n_draft_max);
}
if (!spec_draft.empty() && use_ckpt) {
if (use_ckpt) {
const auto n_tokens = prompt.tokens.size();
//const int64_t t_start = ggml_time_us();
@@ -402,6 +393,25 @@ struct server_slot {
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
}
// generate a new draft
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
n_draft_total += spec_draft.size();
if (spec_draft.size() > (size_t) n_draft_max) {
SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max);
spec_draft.resize(n_draft_max);
}
if (ctx_drft) {
const size_t n = llama_state_seq_set_data_ext(ctx_drft, spec_ckpt.data_drft.data(), spec_ckpt.data_drft.size(), this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != spec_ckpt.data_drft.size()) {
GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, spec_ckpt.pos_min, spec_ckpt.pos_max, spec_ckpt.data_drft.size(), spec_ckpt.data_drft.size(), n);
}
llama_memory_seq_rm(llama_get_memory(ctx_drft), this->id, spec_ckpt.pos_max + 1, -1);
}
}
GGML_ASSERT(spec_draft.size() <= (size_t) n_draft_max);
@@ -841,7 +851,6 @@ private:
params_base.speculative.draft.ctx_tgt = ctx_main;
params_base.speculative.draft.ctx_dft = ctx_drft.get();
params_base.speculative.draft.use_ckpt = ctx_drft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
}
std::string & mmproj_path = params_base.mmproj.path;
@@ -935,6 +944,7 @@ private:
slot.id = i;
slot.ctx_main = ctx_main;
slot.ctx_drft = ctx_drft.get();
slot.n_ctx = n_ctx_slot;
slot.mctx = mctx;
@@ -2899,8 +2909,6 @@ private:
}
if (ctx_drft) {
SRV_WRN("%s", "processing the batch using the draft context\n");
// note: for now, to keep things simple, synchronize the target context
// TODO: revisit later on
llama_synchronize(ctx_main);
@@ -3086,9 +3094,9 @@ private:
{
const size_t n = llama_state_seq_set_data_ext(slot.ctx_main, ckpt.data_main.data(), ckpt.data_main.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
if (n != ckpt.data_main.size()) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.data_main.size(), ckpt.data_main.size(), n);
}
llama_memory_seq_rm(llama_get_memory(slot.ctx_main), slot.id, ckpt.pos_max + 1, -1);
@@ -3096,9 +3104,9 @@ private:
if (slot.ctx_drft) {
const size_t n = llama_state_seq_set_data_ext(slot.ctx_drft, ckpt.data_drft.data(), ckpt.data_drft.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
if (n != ckpt.data_drft.size()) {
GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.data_drft.size(), ckpt.data_drft.size(), n);
}
llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, ckpt.pos_max + 1, -1);