diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 17e39cb5bf..70df76df0d 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -43,11 +43,13 @@ static void server_prompt_checkpoint_update(
         int id, int64_t n_tokens, bool on_device, llama_pos pos_min = -1, llama_pos pos_max = -1) {
+    auto * ctx = ctx_main ? ctx_main : ctx_drft;
+
     if (pos_min == -1) {
-        pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_main), id);
+        pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
     }
     if (pos_max == -1) {
-        pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_main), id);
+        pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id);
     }
 
     auto flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY;
@@ -55,27 +57,29 @@ static void server_prompt_checkpoint_update(
         flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE;
     }
 
-    const size_t ckpt_size_main = llama_state_seq_get_size_ext(ctx_main, id, flags);
-    const size_t ckpt_size_drft = ctx_drft ? llama_state_seq_get_size_ext(ctx_drft, id, flags) : 0;
-
     ckpt.pos_min  = pos_min;
     ckpt.pos_max  = pos_max;
     ckpt.n_tokens = n_tokens;
 
-    ckpt.data_main.resize(ckpt_size_main);
-    ckpt.data_drft.resize(ckpt_size_drft);
+    if (ctx_main) {
+        const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_main, id, flags);
 
-    {
-        const size_t n = llama_state_seq_get_data_ext(ctx_main, ckpt.data_main.data(), ckpt_size_main, id, flags);
-        if (n != ckpt_size_main) {
-            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_main, n);
+        ckpt.data_main.resize(ckpt_size);
+
+        const size_t n = llama_state_seq_get_data_ext(ctx_main, ckpt.data_main.data(), ckpt_size, id, flags);
+        if (n != ckpt_size) {
+            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
         }
     }
 
     if (ctx_drft) {
-        const size_t n = llama_state_seq_get_data_ext(ctx_drft, ckpt.data_drft.data(), ckpt_size_drft, id, flags);
-        if (n != ckpt_size_drft) {
-            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_drft, n);
+        const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_drft, id, flags);
+
+        ckpt.data_drft.resize(ckpt_size);
+
+        const size_t n = llama_state_seq_get_data_ext(ctx_drft, ckpt.data_drft.data(), ckpt_size, id, flags);
+        if (n != ckpt_size) {
+            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
         }
     }
 }
@@ -383,12 +387,7 @@ struct server_slot {
         if (use_ckpt) {
             const auto n_tokens = prompt.tokens.size();
 
-            //const int64_t t_start = ggml_time_us();
-
-            server_prompt_checkpoint_update(spec_ckpt, ctx_main, ctx_drft, this->id, n_tokens, true);
-
-            //const int64_t t_total = ggml_time_us() - t_start;
-            //printf("checkpoint total: %f ms\n", t_total / 1000.0);
+            server_prompt_checkpoint_update(spec_ckpt, nullptr, ctx_drft, this->id, n_tokens, true);
 
             SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
                     spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
@@ -403,6 +402,21 @@ struct server_slot {
             spec_draft.resize(n_draft_max);
         }
 
+        if (!spec_draft.empty() && use_ckpt) {
+            const auto n_tokens = prompt.tokens.size();
+
+            //const int64_t t_start = ggml_time_us();
+
+            server_prompt_checkpoint_update(spec_ckpt, ctx_main, nullptr, this->id, n_tokens, true);
+
+            //const int64_t t_total = ggml_time_us() - t_start;
+            //printf("checkpoint total: %f ms\n", t_total / 1000.0);
+
+            SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
+                    spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
+        }
+
+        // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
         if (ctx_drft) {
             const size_t n = llama_state_seq_set_data_ext(ctx_drft, spec_ckpt.data_drft.data(), spec_ckpt.data_drft.size(), this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
             if (n != spec_ckpt.data_drft.size()) {
@@ -590,6 +604,11 @@ struct server_slot {
         llama_memory_seq_rm(llama_get_memory(ctx_main), other.id, -1, -1);
         llama_memory_seq_cp(llama_get_memory(ctx_main), id, other.id, -1, -1);
 
+        if (ctx_drft) {
+            llama_memory_seq_rm(llama_get_memory(ctx_drft), other.id, -1, -1);
+            llama_memory_seq_cp(llama_get_memory(ctx_drft), id, other.id, -1, -1);
+        }
+
         other.n_decoded   = n_decoded;
         other.n_remaining = n_remaining;
         other.i_batch     = i_batch;
@@ -2252,6 +2271,11 @@ private:
             llama_memory_seq_rm (llama_get_memory(ctx_main), slot.id, n_keep            , n_keep + n_discard);
             llama_memory_seq_add(llama_get_memory(ctx_main), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
 
+            if (ctx_drft) {
+                llama_memory_seq_rm (llama_get_memory(ctx_drft.get()), slot.id, n_keep            , n_keep + n_discard);
+                llama_memory_seq_add(llama_get_memory(ctx_drft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
+            }
+
             // add generated tokens to cache
             // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
             {
@@ -2431,7 +2455,6 @@ private:
                     // reuse chunks from the cached prompt by shifting their KV cache in the new position
                     if (can_cache_reuse && n_cache_reuse > 0) {
                         GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
-                        GGML_ASSERT(ctx_drft == nullptr && "TODO: add support for draft context cache reuse");
 
                         size_t head_c = n_past; // cache
                         size_t head_p = n_past; // current prompt
@@ -2464,6 +2487,11 @@ private:
                             llama_memory_seq_rm (llama_get_memory(ctx_main), slot.id, head_p, head_c);
                             llama_memory_seq_add(llama_get_memory(ctx_main), slot.id, head_c, head_c + n_match, kv_shift);
 
+                            if (ctx_drft) {
+                                llama_memory_seq_rm (llama_get_memory(ctx_drft.get()), slot.id, head_p, head_c);
+                                llama_memory_seq_add(llama_get_memory(ctx_drft.get()), slot.id, head_c, head_c + n_match, kv_shift);
+                            }
+
                             for (size_t i = 0; i < n_match; i++) {
                                 slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
                                 n_past++;
                             }
@@ -2569,7 +2597,7 @@ private:
                         SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) ckpt_size_main / 1024 / 1024);
 
                         if (ctx_drft) {
-                            const size_t ckpt_size_drft = it->data_drft.size();
+                            const size_t ckpt_size_drft = it->data_drft.size();
                             const size_t n = llama_state_seq_set_data_ext(ctx_drft.get(), it->data_drft.data(), ckpt_size_drft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
                             if (n != ckpt_size_drft) {
                                 GGML_ABORT("inconsistent draft state");
@@ -2639,7 +2667,11 @@ private:
                         // there is no common part left
                         slot.n_prompt_tokens_cache = 0;
-                    }
+                    } else {
+                        if (ctx_drft && !llama_memory_seq_rm(llama_get_memory(ctx_drft.get()), slot.id, p0, -1)) {
+                            GGML_ABORT("failed to truncate draft context\n");
+                        }
+                    }
 
                     // If using an alora, there may be uncached tokens that come
                     // before the invocation sequence. When this happens, the
@@ -2908,6 +2940,8 @@ private:
                         continue; // continue loop of n_batch
                     }
 
+                    // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
+                    // for methods that require the target embeddings, I think we have to re-evaluate the draft tokens?
                     if (ctx_drft) {
                         // note: for now, to keep things simple, synchronize the target context
                         // TODO: revisit later on
@@ -3146,6 +3180,9 @@ private:
        SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
 
        llama_memory_seq_rm(llama_get_memory(slot.ctx_main), slot.id, slot.prompt.tokens.pos_next(), -1);
+       if (slot.ctx_drft) {
+           llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, slot.prompt.tokens.pos_next(), -1);
+       }
 
        for (size_t i = 0; i < ids.size(); ++i) {
            completion_token_output result;
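
Note on the first two hunks: `server_prompt_checkpoint_update` now treats each context independently, so either side can be checkpointed on its own; the callers at +387 and +402 exploit this by passing `nullptr` for the other context, saving the draft state before drafting and the main state after. The per-context save path boils down to the sketch below. `save_seq_state` is a hypothetical helper name, not code from this patch; the `llama_state_seq_*_ext` calls and the size check are exactly the ones the patch uses.

    #include <cstdint>
    #include <vector>

    #include "llama.h" // also provides GGML_ABORT via ggml.h

    // hypothetical helper: capture the state of sequence `id` from a single
    // context into a byte buffer, aborting on a short read (same checks as the patch)
    static std::vector<uint8_t> save_seq_state(llama_context * ctx, llama_seq_id id, llama_state_seq_flags flags) {
        std::vector<uint8_t> buf(llama_state_seq_get_size_ext(ctx, id, flags));

        const size_t n = llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), id, flags);
        if (n != buf.size()) {
            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", buf.size(), n);
        }

        return buf;
    }

Splitting the save per context is what allows the two snapshots to be taken at different points in time instead of as a single combined checkpoint.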
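
The remaining hunks all apply one rule: every KV-cache edit made on the main context (`llama_memory_seq_rm`/`seq_add`/`seq_cp`) is mirrored on the draft context when one exists, so the two sequences stay position-aligned; this is also what makes it safe to drop the old `GGML_ASSERT(ctx_drft == nullptr && ...)` guard in the cache-reuse path. A minimal sketch of the context-shift case, assuming a hypothetical `discard_range` helper and raw pointers for both contexts:

    #include "llama.h"

    // hypothetical helper: drop [n_keep, n_keep + n_discard) from sequence `id`
    // and slide the tail [n_keep + n_discard, n_end) left by n_discard positions,
    // repeating the edit on the draft context so it cannot drift out of sync
    static void discard_range(llama_context * ctx_main, llama_context * ctx_drft,
                              llama_seq_id id, llama_pos n_keep, llama_pos n_discard, llama_pos n_end) {
        llama_memory_seq_rm (llama_get_memory(ctx_main), id, n_keep,             n_keep + n_discard);
        llama_memory_seq_add(llama_get_memory(ctx_main), id, n_keep + n_discard, n_end, -n_discard);

        if (ctx_drft) {
            llama_memory_seq_rm (llama_get_memory(ctx_drft), id, n_keep,             n_keep + n_discard);
            llama_memory_seq_add(llama_get_memory(ctx_drft), id, n_keep + n_discard, n_end, -n_discard);
        }
    }

Skipping the mirrored edit would leave stale positions in the draft KV cache, and the next speculative pass would draft from a desynchronized prefix.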