cont : sync main and drft contexts

Georgi Gerganov
2026-05-07 18:47:34 +03:00
parent 78faa2b79f
commit 30af6bbbdc
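The recurring pattern in the hunks below: every sequence operation applied to the main context's memory is mirrored on the draft context whenever one is present, so the two KV caches stay in sync. A minimal sketch of that idea, not the server's own code (the helper name and signature are illustrative):

#include "llama.h"

// Illustrative only: remove a position range of sequence `id` from the target
// (main) context and mirror the removal on the draft context if it exists.
static void seq_rm_synced(llama_context * ctx_main, llama_context * ctx_drft,
                          llama_seq_id id, llama_pos p0, llama_pos p1) {
    llama_memory_seq_rm(llama_get_memory(ctx_main), id, p0, p1);
    if (ctx_drft) {
        llama_memory_seq_rm(llama_get_memory(ctx_drft), id, p0, p1);
    }
}

The hunks below apply this same shape to the seq_rm, seq_cp and seq_add calls in the slot and batching code.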


@@ -43,11 +43,13 @@ static void server_prompt_checkpoint_update(
int id, int64_t n_tokens, bool on_device,
llama_pos pos_min = -1,
llama_pos pos_max = -1) {
+auto * ctx = ctx_main ? ctx_main : ctx_drft;
if (pos_min == -1) {
-pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_main), id);
+pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
}
if (pos_max == -1) {
-pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_main), id);
+pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id);
}
auto flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY;
@@ -55,27 +57,29 @@ static void server_prompt_checkpoint_update(
flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE;
}
-const size_t ckpt_size_main = llama_state_seq_get_size_ext(ctx_main, id, flags);
-const size_t ckpt_size_drft = ctx_drft ? llama_state_seq_get_size_ext(ctx_drft, id, flags) : 0;
ckpt.pos_min = pos_min;
ckpt.pos_max = pos_max;
ckpt.n_tokens = n_tokens;
-ckpt.data_main.resize(ckpt_size_main);
-ckpt.data_drft.resize(ckpt_size_drft);
+if (ctx_main) {
+const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_main, id, flags);
-{
-const size_t n = llama_state_seq_get_data_ext(ctx_main, ckpt.data_main.data(), ckpt_size_main, id, flags);
-if (n != ckpt_size_main) {
-GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_main, n);
+ckpt.data_main.resize(ckpt_size);
+const size_t n = llama_state_seq_get_data_ext(ctx_main, ckpt.data_main.data(), ckpt_size, id, flags);
+if (n != ckpt_size) {
+GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
}
}
if (ctx_drft) {
-const size_t n = llama_state_seq_get_data_ext(ctx_drft, ckpt.data_drft.data(), ckpt_size_drft, id, flags);
-if (n != ckpt_size_drft) {
-GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_drft, n);
+const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_drft, id, flags);
+ckpt.data_drft.resize(ckpt_size);
+const size_t n = llama_state_seq_get_data_ext(ctx_drft, ckpt.data_drft.data(), ckpt_size, id, flags);
+if (n != ckpt_size) {
+GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
}
}
}
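With this hunk the helper sizes and serializes each context's sequence state on its own, so callers can pass nullptr for either context and checkpoint only the other one. The per-context save step reduces to something like the following (a sketch built from the state calls visible above; the free-standing function and its name are illustrative):

#include "llama.h"

#include <cstdint>
#include <vector>

// Illustrative only: serialize the partial sequence state of one context into a buffer.
static std::vector<uint8_t> save_seq_state(llama_context * ctx, llama_seq_id id, bool on_device) {
    auto flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY;
    if (on_device) {
        flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE;
    }

    std::vector<uint8_t> buf(llama_state_seq_get_size_ext(ctx, id, flags));

    const size_t n = llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), id, flags);
    if (n != buf.size()) {
        GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", buf.size(), n);
    }

    return buf;
}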
@@ -383,12 +387,7 @@ struct server_slot {
if (use_ckpt) {
const auto n_tokens = prompt.tokens.size();
-//const int64_t t_start = ggml_time_us();
-server_prompt_checkpoint_update(spec_ckpt, ctx_main, ctx_drft, this->id, n_tokens, true);
-//const int64_t t_total = ggml_time_us() - t_start;
-//printf("checkpoint total: %f ms\n", t_total / 1000.0);
+server_prompt_checkpoint_update(spec_ckpt, nullptr, ctx_drft, this->id, n_tokens, true);
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
@@ -403,6 +402,21 @@ struct server_slot {
spec_draft.resize(n_draft_max);
}
+if (!spec_draft.empty() && use_ckpt) {
+const auto n_tokens = prompt.tokens.size();
+//const int64_t t_start = ggml_time_us();
+server_prompt_checkpoint_update(spec_ckpt, ctx_main, nullptr, this->id, n_tokens, true);
+//const int64_t t_total = ggml_time_us() - t_start;
+//printf("checkpoint total: %f ms\n", t_total / 1000.0);
+SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
+spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
+}
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
if (ctx_drft) {
const size_t n = llama_state_seq_set_data_ext(ctx_drft, spec_ckpt.data_drft.data(), spec_ckpt.data_drft.size(), this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != spec_ckpt.data_drft.size()) {
@@ -590,6 +604,11 @@ struct server_slot {
llama_memory_seq_rm(llama_get_memory(ctx_main), other.id, -1, -1);
llama_memory_seq_cp(llama_get_memory(ctx_main), id, other.id, -1, -1);
+if (ctx_drft) {
+llama_memory_seq_rm(llama_get_memory(ctx_drft), other.id, -1, -1);
+llama_memory_seq_cp(llama_get_memory(ctx_drft), id, other.id, -1, -1);
+}
other.n_decoded = n_decoded;
other.n_remaining = n_remaining;
other.i_batch = i_batch;
@@ -2252,6 +2271,11 @@ private:
llama_memory_seq_rm (llama_get_memory(ctx_main), slot.id, n_keep , n_keep + n_discard);
llama_memory_seq_add(llama_get_memory(ctx_main), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
+if (ctx_drft) {
+llama_memory_seq_rm (llama_get_memory(ctx_drft.get()), slot.id, n_keep , n_keep + n_discard);
+llama_memory_seq_add(llama_get_memory(ctx_drft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
+}
// add generated tokens to cache
// ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
{
@@ -2431,7 +2455,6 @@ private:
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (can_cache_reuse && n_cache_reuse > 0) {
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
-GGML_ASSERT(ctx_drft == nullptr && "TODO: add support for draft context cache reuse");
size_t head_c = n_past; // cache
size_t head_p = n_past; // current prompt
@@ -2464,6 +2487,11 @@ private:
llama_memory_seq_rm (llama_get_memory(ctx_main), slot.id, head_p, head_c);
llama_memory_seq_add(llama_get_memory(ctx_main), slot.id, head_c, head_c + n_match, kv_shift);
+if (ctx_drft) {
+llama_memory_seq_rm (llama_get_memory(ctx_drft.get()), slot.id, head_p, head_c);
+llama_memory_seq_add(llama_get_memory(ctx_drft.get()), slot.id, head_c, head_c + n_match, kv_shift);
+}
for (size_t i = 0; i < n_match; i++) {
slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
n_past++;
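With the assertion above removed, the cache-reuse shift is now issued on the draft memory as well as on the main one. Per matching chunk the operation is roughly the following (a sketch; the helper name is illustrative, and kv_shift is the signed delta that moves the chunk from its cached position head_c to its new position head_p):

#include "llama.h"

#include <cstddef>

// Illustrative only: move a reused chunk of `n_match` cached tokens of sequence
// `id` from position head_c to position head_p within one context.
static void shift_reused_chunk(llama_context * ctx, llama_seq_id id,
                               llama_pos head_p, llama_pos head_c, size_t n_match) {
    const llama_pos kv_shift = head_p - head_c;

    llama_memory_seq_rm (llama_get_memory(ctx), id, head_p, head_c);
    llama_memory_seq_add(llama_get_memory(ctx), id, head_c, head_c + (llama_pos) n_match, kv_shift);
}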
@@ -2569,7 +2597,7 @@ private:
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) ckpt_size_main / 1024 / 1024);
if (ctx_drft) {
const size_t ckpt_size_drft = it->data_drft.size();
const size_t n = llama_state_seq_set_data_ext(ctx_drft.get(), it->data_drft.data(), ckpt_size_drft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt_size_drft) {
GGML_ABORT("inconsistent draft state");
@@ -2639,7 +2667,11 @@ private:
// there is no common part left
slot.n_prompt_tokens_cache = 0;
-}
+} else {
+if (ctx_drft && !llama_memory_seq_rm(llama_get_memory(ctx_drft.get()), slot.id, p0, -1)) {
+GGML_ABORT("failed to truncate draft context\n");
+}
+}
// If using an alora, there may be uncached tokens that come
// before the invocation sequence. When this happens, the
@@ -2908,6 +2940,8 @@ private:
continue; // continue loop of n_batch
}
+// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
+// for methods that require the target embeddings, I think we have to re-evaluate the draft tokens?
if (ctx_drft) {
// note: for now, to keep things simple, synchronize the target context
// TODO: revisit later on
@@ -3146,6 +3180,9 @@ private:
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
llama_memory_seq_rm(llama_get_memory(slot.ctx_main), slot.id, slot.prompt.tokens.pos_next(), -1);
+if (slot.ctx_drft) {
+llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, slot.prompt.tokens.pos_next(), -1);
+}
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;