server, spec : transition to unified spec context

Georgi Gerganov
2026-05-07 17:57:59 +03:00
parent d719d8aafc
commit 78faa2b79f
3 changed files with 37 additions and 217 deletions


@@ -310,8 +310,6 @@ struct common_params_speculative_draft {
llama_context * ctx_tgt = nullptr;
llama_context * ctx_dft = nullptr;
bool use_ckpt = false;
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
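The hunk above moves the target/draft contexts and the checkpoint flag into the shared speculative params. A minimal sketch of how a caller wires these fields up (ctx_main and ctx_draft are assumed names for contexts created elsewhere):

common_params_speculative_draft draft;
draft.ctx_tgt  = ctx_main;  // target context that owns the verified tokens
draft.ctx_dft  = ctx_draft; // dedicated context for the draft model
draft.use_ckpt = true;      // opt in to checkpoint-based state restore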


@@ -156,43 +156,24 @@ struct common_speculative_state {
virtual int32_t n_min(const common_params_speculative & params) const = 0;
};
struct common_speculative_checkpoint {
llama_pos pos_min = 0;
llama_pos pos_max = 0;
int64_t n_tokens = 0;
std::vector<uint8_t> data;
size_t size() const {
return data.size();
}
};
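The checkpoint is a plain byte snapshot of one sequence's state plus the position range it covers. Illustration only, with made-up values, of the bookkeeping the struct carries:

common_speculative_checkpoint ckpt;
ckpt.pos_min  = 0;    // first position covered by the snapshot
ckpt.pos_max  = 511;  // last position covered by the snapshot
ckpt.n_tokens = 512;  // prompt length at the time of the snapshot
ckpt.data.resize(16u * 1024 * 1024); // raw sequence state; sized via the state API in practice
LOG_DBG("checkpoint spans [%d, %d], %.3f MiB\n", ckpt.pos_min, ckpt.pos_max, (float) ckpt.size() / 1024 / 1024);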
struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
llama_context * ctx_dft;
bool use_ckpt = false;
common_speculative_checkpoint ckpt;
llama_seq_id seq_id;
common_sampler * smpl;
llama_batch batch;
llama_tokens prompt_dft;
common_speculative_state_draft(
enum common_speculative_type type,
llama_context * ctx_tgt,
llama_context * ctx_dft,
bool use_ckpt,
llama_seq_id seq_id)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
, ctx_dft(ctx_dft)
, use_ckpt(use_ckpt)
, seq_id(seq_id)
{
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
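llama_batch_init() sizes the scratch batch for the draft context's logical batch size, token-only (embd = 0), with one sequence per entry. The batch owns heap buffers, so the matching teardown looks like this (a sketch; whether the destructor also frees the sampler is an assumption):

~common_speculative_state_draft() {
    common_sampler_free(smpl); // assumed: the state owns its sampler
    llama_batch_free(batch);   // release the buffers from llama_batch_init()
}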
@@ -244,37 +225,6 @@ struct common_speculative_state_draft : public common_speculative_state {
void begin(const llama_tokens & /*prompt*/) override {
}
size_t create_checkpoint(int n_tokens_prompt) {
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), seq_id);
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
ckpt.n_tokens = n_tokens_prompt;
ckpt.data.resize(checkpoint_size);
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
}
LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__,
ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024);
return n;
}
size_t restore_checkpoint() {
int seq_id = 0;
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
}
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, ckpt.pos_max + 1, -1);
return n;
}
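create_checkpoint()/restore_checkpoint() are thin wrappers around the sequence-state API: query the size, copy the bytes out, and later copy them back and truncate the sequence. The round trip in isolation, using the same flags as above (ctx, seq_id and pos_max are assumed):

const uint32_t flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE;

// save: query the size first, then copy the state out
std::vector<uint8_t> data(llama_state_seq_get_size_ext(ctx, seq_id, flags));
GGML_ASSERT(llama_state_seq_get_data_ext(ctx, data.data(), data.size(), seq_id, flags) == data.size());

// restore: copy the state back, then drop any cells past the snapshot
GGML_ASSERT(llama_state_seq_set_data_ext(ctx, data.data(), data.size(), seq_id, flags) == data.size());
llama_memory_seq_rm(llama_get_memory(ctx), seq_id, pos_max + 1, -1);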
void draft(
const common_params_speculative & sparams,
const llama_tokens & prompt_tgt,
@@ -284,152 +234,20 @@ struct common_speculative_state_draft : public common_speculative_state {
auto * spec = this;
auto & batch = spec->batch;
auto & ctx_dft = spec->ctx_dft;
auto & smpl = spec->smpl;
auto & prompt_dft = spec->prompt_dft;
auto * mem_dft = llama_get_memory(ctx_dft);
int reuse_i = 0; // index of part to be reused in prompt_dft
int reuse_n = 0; // length of part to be reused in prompt_dft
const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max;
const llama_tokens & prompt_cur = prompt_tgt;
const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);
if (use_ckpt && i_start > 0) {
LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__);
return;
}
// reuse as much as possible from the old draft context
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
for (int i = 0; i < (int) prompt_dft.size(); ++i) {
int cur = 0;
while (i_start + cur < (int) prompt_cur.size() &&
i + cur < (int) prompt_dft.size() &&
prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
cur++;
}
if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) {
reuse_i = i;
reuse_n = cur;
}
if (use_ckpt) {
break;
}
}
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
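The loop above is the prompt-reuse heuristic: slide the new prompt (from i_start) along the old draft prompt and keep the longest match, accepting it only if at least 256 tokens overlap or the whole prompt fits in the draft context. Restated as a standalone helper (checkpoint handling omitted):

static std::pair<int, int> find_reuse(const llama_tokens & prompt_dft,
                                      const llama_tokens & prompt_cur,
                                      int i_start, int n_ctx) {
    int reuse_i = 0; // offset of the reusable span in prompt_dft
    int reuse_n = 0; // length of the reusable span
    for (int i = 0; i < (int) prompt_dft.size(); ++i) {
        int cur = 0;
        while (i_start + cur < (int) prompt_cur.size() &&
               i       + cur < (int) prompt_dft.size() &&
               prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
            cur++;
        }
        if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) {
            reuse_i = i;
            reuse_n = cur;
        }
    }
    return { reuse_i, reuse_n };
}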
if (use_ckpt && ckpt.n_tokens > reuse_n) {
LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens);
reuse_i = 0;
reuse_n = 0;
ckpt = {};
}
result.clear();
result.reserve(sparams.n_max);
if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
llama_memory_clear(mem_dft, false);
prompt_dft.clear();
} else {
// this happens when a previous draft has been discarded (for example, due to being too small), but the
// target model agreed with it. in this case, we simply pass back the previous results to save compute
if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
result.push_back(prompt_dft[i]);
if (sparams.n_max <= (int) result.size()) {
break;
}
}
return;
}
if (reuse_i > 0) {
GGML_ASSERT(!use_ckpt);
bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
if (!is_removed) {
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
return;
}
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
}
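The two memory calls above drop the stale prefix and re-base the kept cells so their positions start at 0 again. In isolation:

llama_memory_t mem = llama_get_memory(ctx_dft);
llama_memory_seq_rm (mem, 0, 0,       reuse_i);      // remove positions [0, reuse_i)
llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i); // shift the rest left by reuse_i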
if (reuse_n < (int) prompt_dft.size()) {
if (use_ckpt) {
if (ckpt.n_tokens > 0) {
LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
restore_checkpoint();
reuse_n = ckpt.n_tokens;
prompt_dft.resize(reuse_n);
}
} else {
const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1);
if (!is_removed) {
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
return;
}
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
}
}
}
// prepare a batch to evaluate any new tokens in the prompt
common_batch_clear(batch);
for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) {
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]);
common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false);
prompt_dft.push_back(prompt_cur[i]);
}
// we should rarely end up here during normal decoding
if (batch.n_tokens > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens);
int ret = llama_decode(ctx_dft, batch);
if (ret != 0 && ret != 1) {
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
__func__, ret, prompt_cur.size());
}
if (use_ckpt) {
create_checkpoint(prompt_dft.size());
}
}
const llama_pos n_past = prompt_dft.size();
const llama_pos n_past = prompt_tgt.size();
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
common_batch_clear(batch);
common_batch_add (batch, id_last, n_past, { 0 }, true);
prompt_dft.push_back(id_last);
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
common_batch_add (batch, id_last, n_past, { seq_id }, true);
int ret = llama_decode(ctx_dft, batch);
if (ret != 0 && ret != 1) {
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
__func__, ret, prompt_cur.size(), prompt_dft.size());
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
}
common_sampler_reset(smpl);
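The steps above (clear, add the last verified token at n_past with logits requested, decode, reset the sampler) are the standard way to extend a sequence by one token before drafting. Compacted into an illustrative helper (not in the codebase):

static bool decode_one(llama_context * ctx, llama_batch & batch,
                       llama_token id, llama_pos pos, llama_seq_id seq_id) {
    common_batch_clear(batch);
    common_batch_add (batch, id, pos, { seq_id }, /*logits=*/true);
    return llama_decode(ctx, batch) == 0;
}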
@@ -463,16 +281,13 @@ struct common_speculative_state_draft : public common_speculative_state {
break;
}
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
common_batch_add(batch, id, n_past + i + 1, { seq_id }, true);
// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
__func__, i, ret, prompt_cur.size(), prompt_dft.size());
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
}
prompt_dft.push_back(id);
}
if (result.size() < (size_t) sparams.n_min) {
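The hunk cuts off here. For context, a minimum-size guard like this typically discards drafts too short to pay for a target-model verification pass; a hedged sketch of what the guard amounts to (the actual body is outside this diff):

if (result.size() < (size_t) sparams.n_min) {
    result.clear(); // too short to be worth verifying on the target model
}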
@@ -962,10 +777,9 @@ common_speculative * common_speculative_init(common_params_speculative & params,
break;
case COMMON_SPECULATIVE_TYPE_DRAFT: {
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
/* .ctx_tgt = */ params.draft.ctx_tgt,
/* .ctx_dft = */ params.draft.ctx_dft,
/* .use_ckpt = */ params.draft.use_ckpt,
/* .seq_id = */ seq_id
/* .ctx_tgt = */ params.draft.ctx_tgt,
/* .ctx_dft = */ params.draft.ctx_dft,
/* .seq_id = */ seq_id
));
break;
}
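With the unified context there is one draft state per sequence, all sharing the same pair of contexts. A sketch of the per-slot construction this factory performs (n_seq_max is an assumed name for the slot count):

for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq_max; ++seq_id) {
    impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
        params.draft.ctx_tgt, // shared target context
        params.draft.ctx_dft, // shared draft context
        seq_id));             // each state pinned to its own sequence
}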


@@ -380,16 +380,7 @@ struct server_slot {
} else {
GGML_ASSERT(spec_i_batch.empty());
if (!spec_draft.empty() && use_ckpt) {
if (use_ckpt) {
const auto n_tokens = prompt.tokens.size();
//const int64_t t_start = ggml_time_us();
@@ -402,6 +393,25 @@ struct server_slot {
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
}
// generate a new draft
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
n_draft_total += spec_draft.size();
if (spec_draft.size() > (size_t) n_draft_max) {
SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max);
spec_draft.resize(n_draft_max);
}
if (ctx_drft) {
const size_t n = llama_state_seq_set_data_ext(ctx_drft, spec_ckpt.data_drft.data(), spec_ckpt.data_drft.size(), this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != spec_ckpt.data_drft.size()) {
GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, spec_ckpt.pos_min, spec_ckpt.pos_max, spec_ckpt.data_drft.size(), spec_ckpt.data_drft.size(), n);
}
llama_memory_seq_rm(llama_get_memory(ctx_drft), this->id, spec_ckpt.pos_max + 1, -1);
}
}
GGML_ASSERT(spec_draft.size() <= (size_t) n_draft_max);
@@ -841,7 +851,6 @@ private:
params_base.speculative.draft.ctx_tgt = ctx_main;
params_base.speculative.draft.ctx_dft = ctx_drft.get();
params_base.speculative.draft.use_ckpt = ctx_drft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
}
std::string & mmproj_path = params_base.mmproj.path;
@@ -935,6 +944,7 @@ private:
slot.id = i;
slot.ctx_main = ctx_main;
slot.ctx_drft = ctx_drft.get();
slot.n_ctx = n_ctx_slot;
slot.mctx = mctx;
@@ -2899,8 +2909,6 @@ private:
}
if (ctx_drft) {
SRV_WRN("%s", "processing the batch using the draft context\n");
// note: for now, to keep things simple, synchronize the target context
// TODO: revisit later on
llama_synchronize(ctx_main);
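llama_decode() can leave work queued on the backend, and llama_synchronize() blocks until the context's pending computation has finished, so anything the draft pass depends on is actually available. The pattern in isolation (batch names assumed):

llama_decode(ctx_main, batch_main); // may still be executing asynchronously
llama_synchronize(ctx_main);        // wait for the target context to finish
llama_decode(ctx_drft, batch_drft); // the draft pass can now safely proceed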
@@ -3086,9 +3094,9 @@ private:
{
const size_t n = llama_state_seq_set_data_ext(slot.ctx_main, ckpt.data_main.data(), ckpt.data_main.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
if (n != ckpt.data_main.size()) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.data_main.size(), ckpt.data_main.size(), n);
}
llama_memory_seq_rm(llama_get_memory(slot.ctx_main), slot.id, ckpt.pos_max + 1, -1);
@@ -3096,9 +3104,9 @@ private:
if (slot.ctx_drft) {
const size_t n = llama_state_seq_set_data_ext(slot.ctx_drft, ckpt.data_drft.data(), ckpt.data_drft.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
if (n != ckpt.data_drft.size()) {
GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.data_drft.size(), ckpt.data_drft.size(), n);
}
llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, ckpt.pos_max + 1, -1);
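The main-context and draft-context restore paths differ only in which buffer and context they touch; both roll back to the same checkpoint and truncate at the same pos_max, keeping the two sequences aligned. The shared step as an illustrative helper (name not from the codebase):

static void restore_seq(llama_context * ctx, const std::vector<uint8_t> & data,
                        llama_seq_id seq_id, llama_pos pos_max) {
    const uint32_t flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE;
    const size_t n = llama_state_seq_set_data_ext(ctx, data.data(), data.size(), seq_id, flags);
    GGML_ASSERT(n == data.size()); // partial restore is a hard failure
    llama_memory_seq_rm(llama_get_memory(ctx), seq_id, pos_max + 1, -1); // drop cells past the snapshot
}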