server : draft prompt cache and checkpoints

[no ci]
This commit is contained in:
Georgi Gerganov
2026-05-07 12:47:56 +03:00
parent e22a090f12
commit 4c957c4749
3 changed files with 150 additions and 57 deletions

View File

@@ -36,7 +36,13 @@ using json = nlohmann::ordered_json;
constexpr int HTTP_POLLING_SECONDS = 1;
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, bool on_device, llama_pos pos_min = -1, llama_pos pos_max = -1) {
static void server_prompt_checkpoint_update(
server_prompt_checkpoint & ckpt,
llama_context * ctx,
llama_context * ctx_dft,
int id, int64_t n_tokens, bool on_device,
llama_pos pos_min = -1,
llama_pos pos_max = -1) {
if (pos_min == -1) {
pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
}
@@ -49,16 +55,30 @@ static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, lla
flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE;
}
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, flags);
const size_t ckpt_size_main = llama_state_seq_get_size_ext(ctx, id, flags);
const size_t ckpt_size_dft = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, flags) : 0;
ckpt.pos_min = pos_min;
ckpt.pos_max = pos_max;
ckpt.n_tokens = n_tokens;
ckpt.data.resize(checkpoint_size);
const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, flags);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
ckpt.data_main.resize(ckpt_size_main);
ckpt.data_dft.resize (ckpt_size_dft);
{
const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data_main.data(), ckpt_size_main, id, flags);
if (n != ckpt_size_main) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_main, n);
}
}
{
if (ctx_dft) {
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data_dft.data(), ckpt_size_dft, id, flags);
if (n != ckpt_size_dft) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_dft, n);
}
}
}
}
@@ -81,6 +101,7 @@ struct server_slot {
int id;
llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;
// multimodal
mtmd_context * mctx = nullptr;
@@ -133,21 +154,27 @@ struct server_slot {
void prompt_save(server_prompt_cache & prompt_cache) const {
GGML_ASSERT(prompt.data.size() == 0);
const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
const size_t cur_size_main = llama_state_seq_get_size_ext(ctx, id, 0);
const size_t cur_size_dft = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, 0) : 0;
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
const size_t cur_size = cur_size_main + cur_size_dft;
auto * cur = prompt_cache.alloc(prompt, cur_size);
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n",
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0));
auto * cur = prompt_cache.alloc(prompt, cur_size_main, cur_size_dft);
if (cur == nullptr) {
return;
}
llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
llama_state_seq_get_data_ext(ctx, cur->data.main.data(), cur_size_main, id, 0);
if (ctx_dft) {
llama_state_seq_get_data_ext(ctx_dft, cur->data.dft.data(), cur_size_dft, id, 0);
}
}
bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
bool res = prompt_cache.load(prompt, tokens, ctx, id);
bool res = prompt_cache.load(prompt, tokens, ctx, ctx_dft, id);
if (!res) {
SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
}
@@ -163,6 +190,10 @@ struct server_slot {
SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
if (ctx_dft) {
llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1);
}
prompt.tokens.clear();
}
@@ -365,13 +396,13 @@ struct server_slot {
//const int64_t t_start = ggml_time_us();
server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens, true);
server_prompt_checkpoint_update(spec_ckpt, ctx, ctx_dft, this->id, n_tokens, true);
//const int64_t t_total = ggml_time_us() - t_start;
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
}
}
@@ -667,10 +698,12 @@ private:
// note: keep these alive - they determine the lifetime of the model, context, etc.
common_init_result_ptr llama_init;
// TODO: rename to ctx_main
llama_context * ctx = nullptr;
llama_batch batch {};
// TODO: rename to *_drft
llama_model_ptr model_dft;
llama_context_ptr ctx_dft;
@@ -1844,18 +1877,18 @@ private:
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
auto & cur = slot.prompt.checkpoints.emplace_back();
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max);
server_prompt_checkpoint_update(cur, ctx, ctx_dft.get(), slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max);
SLT_WRN(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
}
void process_single_task(server_task && task) {
@@ -2390,6 +2423,7 @@ private:
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (can_cache_reuse && n_cache_reuse > 0) {
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
GGML_ASSERT(ctx_dft == nullptr && "TODO: add support for draft context cache reuse");
size_t head_c = n_past; // cache
size_t head_p = n_past; // current prompt
@@ -2515,17 +2549,24 @@ private:
if (!do_reset) {
// restore the context checkpoint
const size_t checkpoint_size = it->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
const size_t ckpt_size_main = it->data_main.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data_main.data(), ckpt_size_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt_size_main) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) ckpt_size_main / 1024 / 1024);
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) checkpoint_size / 1024 / 1024);
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) ckpt_size_main / 1024 / 1024);
if (ctx_dft) {
const size_t ckpt_size_dft = it->data_dft.size();
const size_t n = llama_state_seq_set_data_ext(ctx_dft.get(), it->data_dft.data(), ckpt_size_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt_size_dft) {
GGML_ABORT("inconsistent draft state");
}
}
}
}
@@ -2543,7 +2584,7 @@ private:
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_max > pos_next) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
it = slot.prompt.checkpoints.erase(it);
} else {
++it;
@@ -3045,13 +3086,25 @@ private:
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
ckpt.pos_min, ckpt.pos_max, ckpt.size());
const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
{
const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data_main.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
}
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1);
}
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1);
{
const size_t n = llama_state_seq_set_data_ext(slot.ctx_dft, ckpt.data_dft.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
}
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1);
}
slot.prompt.tokens.keep_first(ckpt.n_tokens);
slot.smpl = std::move(smpl_save);