server : draft prompt cache and checkpoints

[no ci]
This commit is contained in:
Georgi Gerganov
2026-05-07 12:47:56 +03:00
parent e22a090f12
commit 4c957c4749
3 changed files with 150 additions and 57 deletions

View File

@@ -36,7 +36,13 @@ using json = nlohmann::ordered_json;
constexpr int HTTP_POLLING_SECONDS = 1;
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, bool on_device, llama_pos pos_min = -1, llama_pos pos_max = -1) {
static void server_prompt_checkpoint_update(
server_prompt_checkpoint & ckpt,
llama_context * ctx,
llama_context * ctx_dft,
int id, int64_t n_tokens, bool on_device,
llama_pos pos_min = -1,
llama_pos pos_max = -1) {
if (pos_min == -1) {
pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
}
@@ -49,16 +55,30 @@ static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, lla
flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE;
}
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, flags);
const size_t ckpt_size_main = llama_state_seq_get_size_ext(ctx, id, flags);
const size_t ckpt_size_dft = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, flags) : 0;
ckpt.pos_min = pos_min;
ckpt.pos_max = pos_max;
ckpt.n_tokens = n_tokens;
ckpt.data.resize(checkpoint_size);
const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, flags);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
ckpt.data_main.resize(ckpt_size_main);
ckpt.data_dft.resize (ckpt_size_dft);
{
const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data_main.data(), ckpt_size_main, id, flags);
if (n != ckpt_size_main) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_main, n);
}
}
{
if (ctx_dft) {
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data_dft.data(), ckpt_size_dft, id, flags);
if (n != ckpt_size_dft) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_dft, n);
}
}
}
}
@@ -81,6 +101,7 @@ struct server_slot {
int id;
llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;
// multimodal
mtmd_context * mctx = nullptr;
@@ -133,21 +154,27 @@ struct server_slot {
void prompt_save(server_prompt_cache & prompt_cache) const {
GGML_ASSERT(prompt.data.size() == 0);
const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
const size_t cur_size_main = llama_state_seq_get_size_ext(ctx, id, 0);
const size_t cur_size_dft = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, 0) : 0;
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
const size_t cur_size = cur_size_main + cur_size_dft;
auto * cur = prompt_cache.alloc(prompt, cur_size);
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n",
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0));
auto * cur = prompt_cache.alloc(prompt, cur_size_main, cur_size_dft);
if (cur == nullptr) {
return;
}
llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
llama_state_seq_get_data_ext(ctx, cur->data.main.data(), cur_size_main, id, 0);
if (ctx_dft) {
llama_state_seq_get_data_ext(ctx_dft, cur->data.dft.data(), cur_size_dft, id, 0);
}
}
bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
bool res = prompt_cache.load(prompt, tokens, ctx, id);
bool res = prompt_cache.load(prompt, tokens, ctx, ctx_dft, id);
if (!res) {
SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
}
@@ -163,6 +190,10 @@ struct server_slot {
SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
if (ctx_dft) {
llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1);
}
prompt.tokens.clear();
}
@@ -365,13 +396,13 @@ struct server_slot {
//const int64_t t_start = ggml_time_us();
server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens, true);
server_prompt_checkpoint_update(spec_ckpt, ctx, ctx_dft, this->id, n_tokens, true);
//const int64_t t_total = ggml_time_us() - t_start;
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
}
}
@@ -667,10 +698,12 @@ private:
// note: keep these alive - they determine the lifetime of the model, context, etc.
common_init_result_ptr llama_init;
// TODO: rename to ctx_main
llama_context * ctx = nullptr;
llama_batch batch {};
// TODO: rename to *_drft
llama_model_ptr model_dft;
llama_context_ptr ctx_dft;
@@ -1844,18 +1877,18 @@ private:
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
auto & cur = slot.prompt.checkpoints.emplace_back();
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max);
server_prompt_checkpoint_update(cur, ctx, ctx_dft.get(), slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max);
SLT_WRN(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
}
void process_single_task(server_task && task) {
@@ -2390,6 +2423,7 @@ private:
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (can_cache_reuse && n_cache_reuse > 0) {
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
GGML_ASSERT(ctx_dft == nullptr && "TODO: add support for draft context cache reuse");
size_t head_c = n_past; // cache
size_t head_p = n_past; // current prompt
@@ -2515,17 +2549,24 @@ private:
if (!do_reset) {
// restore the context checkpoint
const size_t checkpoint_size = it->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
const size_t ckpt_size_main = it->data_main.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data_main.data(), ckpt_size_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt_size_main) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) ckpt_size_main / 1024 / 1024);
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) checkpoint_size / 1024 / 1024);
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) ckpt_size_main / 1024 / 1024);
if (ctx_dft) {
const size_t ckpt_size_dft = it->data_dft.size();
const size_t n = llama_state_seq_set_data_ext(ctx_dft.get(), it->data_dft.data(), ckpt_size_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt_size_dft) {
GGML_ABORT("inconsistent draft state");
}
}
}
}
@@ -2543,7 +2584,7 @@ private:
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_max > pos_next) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
it = slot.prompt.checkpoints.erase(it);
} else {
++it;
@@ -3045,13 +3086,25 @@ private:
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
ckpt.pos_min, ckpt.pos_max, ckpt.size());
const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
{
const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data_main.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
}
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1);
}
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1);
{
const size_t n = llama_state_seq_set_data_ext(slot.ctx_dft, ckpt.data_dft.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
}
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1);
}
slot.prompt.tokens.keep_first(ckpt.n_tokens);
slot.smpl = std::move(smpl_save);

View File

@@ -1981,7 +1981,7 @@ size_t server_prompt_cache::n_tokens() const {
return res;
}
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_dft) {
// first check if the current state is contained fully in the cache
for (auto it = states.begin(); it != states.end(); ++it) {
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
@@ -2005,11 +2005,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
}
}
std::vector<uint8_t> state_data;
std::vector<uint8_t> state_data_main;
std::vector<uint8_t> state_data_dft;
// check if we can allocate enough memory for the new state
try {
state_data.resize(state_size);
state_data_main.resize(state_size_main);
state_data_dft.resize(state_size_dft);
} catch (const std::bad_alloc & e) {
SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
@@ -2022,17 +2024,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
return nullptr;
}
auto & cur = states.emplace_back();
cur = {
states.push_back({
/*.tokens =*/ prompt.tokens.clone(),
/*.data =*/ std::move(state_data),
/*.data =*/ {
/*.main =*/ std::move(state_data_main),
/*.dft =*/ std::move(state_data_dft),
},
/*.checkpoints =*/ prompt.checkpoints,
};
});
return &cur;
return &states.back();
}
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, llama_context * ctx_dft, int32_t id_slot) {
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
@@ -2065,16 +2069,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
if (it_best != states.end()) {
SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
const size_t size = it_best->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
{
auto & data = it_best->data.main;
return false;
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
return false;
}
data.clear();
data.shrink_to_fit();
}
it_best->data.clear();
it_best->data.shrink_to_fit();
{
auto & data = it_best->data.dft;
if (!data.empty()) {
GGML_ASSERT(ctx_dft);
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
return false;
}
data.clear();
data.shrink_to_fit();
}
}
prompt = std::move(*it_best);

View File

@@ -571,36 +571,49 @@ struct server_prompt_checkpoint {
int64_t n_tokens;
std::vector<uint8_t> data;
std::vector<uint8_t> data_main;
std::vector<uint8_t> data_dft;
size_t size() const {
return data.size();
return data_main.size() + data_dft.size();
}
bool empty() const {
return data.empty();
return data_main.empty();
}
void clear() {
pos_min = 0;
pos_max = 0;
n_tokens = 0;
data.clear();
data_main.clear();
data_dft.clear();
}
};
struct server_prompt_data {
std::vector<uint8_t> main;
std::vector<uint8_t> dft;
size_t size() const {
return main.size() + dft.size();
}
};
struct server_prompt {
server_tokens tokens;
std::vector<uint8_t> data;
server_prompt_data data;
std::list<server_prompt_checkpoint> checkpoints;
size_t size() const {
size_t res = data.size();
size_t res = 0;
for (const auto & checkpoint : checkpoints) {
res += checkpoint.size();
res += data.size();
for (const auto & ckpt : checkpoints) {
res += ckpt.size();
}
return res;
@@ -614,7 +627,7 @@ struct server_prompt {
return server_prompt {
tokens.clone(),
data,
checkpoints
checkpoints,
};
}
};
@@ -637,9 +650,9 @@ struct server_prompt_cache {
size_t n_tokens() const;
server_prompt * alloc(const server_prompt & prompt, size_t state_size);
server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_dft);
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, llama_context * ctx_dft, int32_t id_slot);
void update();
};