mirror of https://github.com/ggml-org/llama.cpp.git
synced 2026-05-12 12:04:08 +00:00
cont : sync main and drft contexts
@@ -43,11 +43,13 @@ static void server_prompt_checkpoint_update(
         int id, int64_t n_tokens, bool on_device,
         llama_pos pos_min = -1,
         llama_pos pos_max = -1) {
+    auto * ctx = ctx_main ? ctx_main : ctx_drft;
+
     if (pos_min == -1) {
-        pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_main), id);
+        pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
     }
     if (pos_max == -1) {
-        pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_main), id);
+        pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id);
     }
 
     auto flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY;
@@ -55,27 +57,29 @@ static void server_prompt_checkpoint_update(
         flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE;
     }
 
-    const size_t ckpt_size_main = llama_state_seq_get_size_ext(ctx_main, id, flags);
-    const size_t ckpt_size_drft = ctx_drft ? llama_state_seq_get_size_ext(ctx_drft, id, flags) : 0;
-
     ckpt.pos_min = pos_min;
     ckpt.pos_max = pos_max;
     ckpt.n_tokens = n_tokens;
 
-    ckpt.data_main.resize(ckpt_size_main);
-    ckpt.data_drft.resize(ckpt_size_drft);
+    if (ctx_main) {
+        const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_main, id, flags);
 
-    {
-        const size_t n = llama_state_seq_get_data_ext(ctx_main, ckpt.data_main.data(), ckpt_size_main, id, flags);
-        if (n != ckpt_size_main) {
-            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_main, n);
+        ckpt.data_main.resize(ckpt_size);
+
+        const size_t n = llama_state_seq_get_data_ext(ctx_main, ckpt.data_main.data(), ckpt_size, id, flags);
+        if (n != ckpt_size) {
+            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
         }
     }
 
     if (ctx_drft) {
-        const size_t n = llama_state_seq_get_data_ext(ctx_drft, ckpt.data_drft.data(), ckpt_size_drft, id, flags);
-        if (n != ckpt_size_drft) {
-            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_drft, n);
+        const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_drft, id, flags);
+
+        ckpt.data_drft.resize(ckpt_size);
+
+        const size_t n = llama_state_seq_get_data_ext(ctx_drft, ckpt.data_drft.data(), ckpt_size, id, flags);
+        if (n != ckpt_size) {
+            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
        }
     }
 }
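Note: the per-context body above follows a single pattern: query the sequence-state size, size the destination buffer, copy the state out, and verify the returned size. A minimal standalone sketch of that save step, assuming a std::vector<uint8_t> destination; the helper name and its parameters are illustrative, not part of the patch, while the llama_state_seq_*_ext calls and the flag constant are the ones used in the hunks above.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    #include "llama.h"

    // Capture the state of sequence `seq_id` from `ctx` into `dst` (sketch).
    // `flags` would be e.g. LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY, as in the hunks above.
    static bool seq_state_save(llama_context * ctx, llama_seq_id seq_id, uint32_t flags, std::vector<uint8_t> & dst) {
        const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, flags);

        dst.resize(size);

        const size_t n = llama_state_seq_get_data_ext(ctx, dst.data(), size, seq_id, flags);
        if (n != size) {
            fprintf(stderr, "checkpoint size mismatch: expected %zu, got %zu\n", size, n);
            return false;
        }

        return true;
    }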
@@ -383,12 +387,7 @@ struct server_slot {
         if (use_ckpt) {
             const auto n_tokens = prompt.tokens.size();
 
-            //const int64_t t_start = ggml_time_us();
-
-            server_prompt_checkpoint_update(spec_ckpt, ctx_main, ctx_drft, this->id, n_tokens, true);
-
-            //const int64_t t_total = ggml_time_us() - t_start;
-            //printf("checkpoint total: %f ms\n", t_total / 1000.0);
+            server_prompt_checkpoint_update(spec_ckpt, nullptr, ctx_drft, this->id, n_tokens, true);
 
             SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
                 spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
@@ -403,6 +402,21 @@ struct server_slot {
             spec_draft.resize(n_draft_max);
         }
+
+        if (!spec_draft.empty() && use_ckpt) {
+            const auto n_tokens = prompt.tokens.size();
+
+            //const int64_t t_start = ggml_time_us();
+
+            server_prompt_checkpoint_update(spec_ckpt, ctx_main, nullptr, this->id, n_tokens, true);
+
+            //const int64_t t_total = ggml_time_us() - t_start;
+            //printf("checkpoint total: %f ms\n", t_total / 1000.0);
+
+            SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
+                spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
+        }
+
         // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
         if (ctx_drft) {
             const size_t n = llama_state_seq_set_data_ext(ctx_drft, spec_ckpt.data_drft.data(), spec_ckpt.data_drft.size(), this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
             if (n != spec_ckpt.data_drft.size()) {
@@ -590,6 +604,11 @@ struct server_slot {
         llama_memory_seq_rm(llama_get_memory(ctx_main), other.id, -1, -1);
         llama_memory_seq_cp(llama_get_memory(ctx_main), id, other.id, -1, -1);
 
+        if (ctx_drft) {
+            llama_memory_seq_rm(llama_get_memory(ctx_drft), other.id, -1, -1);
+            llama_memory_seq_cp(llama_get_memory(ctx_drft), id, other.id, -1, -1);
+        }
+
         other.n_decoded = n_decoded;
         other.n_remaining = n_remaining;
         other.i_batch = i_batch;
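Note: the slot-copy hunk above mirrors the sequence into the other slot id in both contexts. The same two calls, written as a standalone sketch on a single context; the helper name is illustrative, the llama_memory_* calls are the ones used in the hunk.

    #include "llama.h"

    // Mirror one slot's sequence into another slot id within a single context (sketch).
    // The hunk above applies these same two calls to ctx_main and, if present, ctx_drft.
    static void copy_seq_to_slot(llama_context * ctx, llama_seq_id src, llama_seq_id dst) {
        llama_memory_seq_rm(llama_get_memory(ctx), dst, -1, -1);      // drop whatever dst currently holds
        llama_memory_seq_cp(llama_get_memory(ctx), src, dst, -1, -1); // copy the full position range from src
    }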
@@ -2252,6 +2271,11 @@ private:
            llama_memory_seq_rm (llama_get_memory(ctx_main), slot.id, n_keep            , n_keep + n_discard);
            llama_memory_seq_add(llama_get_memory(ctx_main), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
 
+            if (ctx_drft) {
+                llama_memory_seq_rm (llama_get_memory(ctx_drft.get()), slot.id, n_keep            , n_keep + n_discard);
+                llama_memory_seq_add(llama_get_memory(ctx_drft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
+            }
+
            // add generated tokens to cache
            // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
            {
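Note: the context-shift hunk above applies the same discard-and-shift to the draft memory as to the main one. A minimal sketch of that step on a single context, using only the llama_memory_* calls that appear in the hunk; the helper name and parameter names are illustrative.

    #include "llama.h"

    // Drop positions [n_keep, n_keep + n_discard) from a sequence and slide the
    // remaining tokens back by n_discard (the pattern the hunk applies to both contexts).
    static void discard_and_shift(llama_context * ctx, llama_seq_id seq_id, llama_pos n_keep, llama_pos n_discard, llama_pos n_end) {
        llama_memory_seq_rm (llama_get_memory(ctx), seq_id, n_keep,             n_keep + n_discard);
        llama_memory_seq_add(llama_get_memory(ctx), seq_id, n_keep + n_discard, n_end, -n_discard);
    }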
@@ -2431,7 +2455,6 @@ private:
            // reuse chunks from the cached prompt by shifting their KV cache in the new position
            if (can_cache_reuse && n_cache_reuse > 0) {
                GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
-                GGML_ASSERT(ctx_drft == nullptr && "TODO: add support for draft context cache reuse");
 
                size_t head_c = n_past; // cache
                size_t head_p = n_past; // current prompt
@@ -2464,6 +2487,11 @@ private:
                    llama_memory_seq_rm (llama_get_memory(ctx_main), slot.id, head_p, head_c);
                    llama_memory_seq_add(llama_get_memory(ctx_main), slot.id, head_c, head_c + n_match, kv_shift);
 
+                    if (ctx_drft) {
+                        llama_memory_seq_rm (llama_get_memory(ctx_drft.get()), slot.id, head_p, head_c);
+                        llama_memory_seq_add(llama_get_memory(ctx_drft.get()), slot.id, head_c, head_c + n_match, kv_shift);
+                    }
+
                    for (size_t i = 0; i < n_match; i++) {
                        slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
                        n_past++;
@@ -2569,7 +2597,7 @@ private:
                SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) ckpt_size_main / 1024 / 1024);
 
                if (ctx_drft) {
-                    const size_t ckpt_size_drft = it->data_drft.size();
+                    const size_t ckpt_size_drft = it->data_drft.size();
                    const size_t n = llama_state_seq_set_data_ext(ctx_drft.get(), it->data_drft.data(), ckpt_size_drft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
                    if (n != ckpt_size_drft) {
                        GGML_ABORT("inconsistent draft state");
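Note: restoring is the mirror image of the capture shown earlier: hand the saved buffer back to llama_state_seq_set_data_ext and treat a short write as an inconsistency (the hunk above aborts on it). A standalone sketch, assuming the checkpoint bytes live in a std::vector<uint8_t>; the helper name is illustrative.

    #include <cstdint>
    #include <vector>

    #include "llama.h"

    // Restore a previously captured sequence state into `ctx` (sketch).
    // A short write is treated as fatal above (GGML_ABORT); here it is just reported.
    static bool seq_state_restore(llama_context * ctx, llama_seq_id seq_id, uint32_t flags, const std::vector<uint8_t> & src) {
        const size_t n = llama_state_seq_set_data_ext(ctx, src.data(), src.size(), seq_id, flags);

        return n == src.size();
    }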
@@ -2639,7 +2667,11 @@ private:
 
                    // there is no common part left
                    slot.n_prompt_tokens_cache = 0;
-                }
+                } else {
+                    if (ctx_drft && !llama_memory_seq_rm(llama_get_memory(ctx_drft.get()), slot.id, p0, -1)) {
+                        GGML_ABORT("failed to truncate draft context\n");
+                    }
+                }
 
                // If using an alora, there may be uncached tokens that come
                // before the invocation sequence. When this happens, the
@@ -2908,6 +2940,8 @@ private:
                    continue; // continue loop of n_batch
                }
 
+                // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
+                // for methods that require the target embeddings, I think we have to re-evaluate the draft tokens?
                if (ctx_drft) {
                    // note: for now, to keep things simple, synchronize the target context
                    // TODO: revisit later on
@@ -3146,6 +3180,9 @@ private:
                SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
 
                llama_memory_seq_rm(llama_get_memory(slot.ctx_main), slot.id, slot.prompt.tokens.pos_next(), -1);
+                if (slot.ctx_drft) {
+                    llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, slot.prompt.tokens.pos_next(), -1);
+                }
 
                for (size_t i = 0; i < ids.size(); ++i) {
                    completion_token_output result;
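Note: the last hunk clears any speculative KV entries past the accepted position in both contexts before the accepted tokens are appended. The same trim as a standalone sketch; the names are illustrative, the calls are the ones used in the hunk.

    #include "llama.h"

    // Remove all cache entries at positions >= pos_next for a slot's sequence, in the
    // main context and, when a draft context is in use, in the draft context as well.
    static void trim_after_accept(llama_context * ctx_main, llama_context * ctx_drft, llama_seq_id seq_id, llama_pos pos_next) {
        llama_memory_seq_rm(llama_get_memory(ctx_main), seq_id, pos_next, -1);

        if (ctx_drft) {
            llama_memory_seq_rm(llama_get_memory(ctx_drft), seq_id, pos_next, -1);
        }
    }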