diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index e3c2749809..e0036312bf 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -36,7 +36,13 @@ using json = nlohmann::ordered_json;
 
 constexpr int HTTP_POLLING_SECONDS = 1;
 
-static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, bool on_device, llama_pos pos_min = -1, llama_pos pos_max = -1) {
+static void server_prompt_checkpoint_update(
+        server_prompt_checkpoint & ckpt,
+        llama_context * ctx,
+        llama_context * ctx_dft,
+        int id, int64_t n_tokens, bool on_device,
+        llama_pos pos_min = -1,
+        llama_pos pos_max = -1) {
     if (pos_min == -1) {
         pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
     }
@@ -49,16 +55,30 @@ static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, lla
         flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE;
     }
 
-    const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, flags);
+    const size_t ckpt_size_main = llama_state_seq_get_size_ext(ctx, id, flags);
+    const size_t ckpt_size_dft  = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, flags) : 0;
 
     ckpt.pos_min  = pos_min;
     ckpt.pos_max  = pos_max;
     ckpt.n_tokens = n_tokens;
 
-    ckpt.data.resize(checkpoint_size);
-    const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, flags);
-    if (n != checkpoint_size) {
-        GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
+    ckpt.data_main.resize(ckpt_size_main);
+    ckpt.data_dft.resize (ckpt_size_dft);
+
+    {
+        const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data_main.data(), ckpt_size_main, id, flags);
+        if (n != ckpt_size_main) {
+            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_main, n);
+        }
+    }
+
+    {
+        if (ctx_dft) {
+            const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data_dft.data(), ckpt_size_dft, id, flags);
+            if (n != ckpt_size_dft) {
+                GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_dft, n);
+            }
+        }
     }
 }
@@ -81,6 +101,7 @@ struct server_slot {
     int id;
 
     llama_context * ctx = nullptr;
+    llama_context * ctx_dft = nullptr;
 
     // multimodal
     mtmd_context * mctx = nullptr;
@@ -133,21 +154,27 @@ struct server_slot {
     void prompt_save(server_prompt_cache & prompt_cache) const {
         GGML_ASSERT(prompt.data.size() == 0);
 
-        const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
+        const size_t cur_size_main = llama_state_seq_get_size_ext(ctx, id, 0);
+        const size_t cur_size_dft  = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, 0) : 0;
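+        // the two states are sized independently; the combined size below is used only for logging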
 
-        SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
-                (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
+        const size_t cur_size = cur_size_main + cur_size_dft;
 
-        auto * cur = prompt_cache.alloc(prompt, cur_size);
+        SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n",
+                (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0));
+
+        auto * cur = prompt_cache.alloc(prompt, cur_size_main, cur_size_dft);
         if (cur == nullptr) {
             return;
         }
 
-        llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
+        llama_state_seq_get_data_ext(ctx, cur->data.main.data(), cur_size_main, id, 0);
+        if (ctx_dft) {
+            llama_state_seq_get_data_ext(ctx_dft, cur->data.dft.data(), cur_size_dft, id, 0);
+        }
     }
 
     bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
-        bool res = prompt_cache.load(prompt, tokens, ctx, id);
+        bool res = prompt_cache.load(prompt, tokens, ctx, ctx_dft, id);
         if (!res) {
             SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
         }
@@ -163,6 +190,10 @@ struct server_slot {
         SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
 
         llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
+
+        if (ctx_dft) {
+            llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1);
+        }
 
         prompt.tokens.clear();
     }
@@ -365,13 +396,13 @@ struct server_slot {
 
             //const int64_t t_start = ggml_time_us();
 
-            server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens, true);
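+            // snapshot the draft context together with the main one, so that restoring
+            // the checkpoint can resume speculative decoding (ctx_dft may be null)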
+            server_prompt_checkpoint_update(spec_ckpt, ctx, ctx_dft, this->id, n_tokens, true);
 
             //const int64_t t_total = ggml_time_us() - t_start;
             //printf("checkpoint total: %f ms\n", t_total / 1000.0);
 
             SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
-                    spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
+                    spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
         }
     }
@@ -667,10 +698,12 @@ private:
     // note: keep these alive - they determine the lifetime of the model, context, etc.
     common_init_result_ptr llama_init;
 
+    // TODO: rename to ctx_main
     llama_context * ctx = nullptr;
 
     llama_batch batch {};
 
+    // TODO: rename to *_drft
    llama_model_ptr   model_dft;
    llama_context_ptr ctx_dft;
@@ -1844,18 +1877,18 @@ private:
             const auto & cur = slot.prompt.checkpoints.front();
 
             SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
-                    cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
+                    cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
 
             slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
         }
 
         auto & cur = slot.prompt.checkpoints.emplace_back();
 
-        server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max);
+        server_prompt_checkpoint_update(cur, ctx, ctx_dft.get(), slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max);
 
         SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
                 (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
-                cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
+                cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
     }
 
     void process_single_task(server_task && task) {
@@ -2390,6 +2423,7 @@ private:
                 // reuse chunks from the cached prompt by shifting their KV cache in the new position
                 if (can_cache_reuse && n_cache_reuse > 0) {
                     GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
+                    GGML_ASSERT(ctx_dft == nullptr && "TODO: add support for draft context cache reuse");
 
                     size_t head_c = n_past; // cache
                     size_t head_p = n_past; // current prompt
@@ -2515,17 +2549,24 @@ private:
 
                    if (!do_reset) {
                        // restore the context checkpoint
-                        const size_t checkpoint_size = it->data.size();
-                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                        if (n != checkpoint_size) {
-                            SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
+                        const size_t ckpt_size_main = it->data_main.size();
+                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data_main.data(), ckpt_size_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+                        if (n != ckpt_size_main) {
+                            SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) ckpt_size_main / 1024 / 1024);
                            do_reset = true;
                            //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                        } else {
                            pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
-                            n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
-                            SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) checkpoint_size / 1024 / 1024);
+                            n_past   = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
+                            SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) ckpt_size_main / 1024 / 1024);
+
+                            if (ctx_dft) {
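+                                // bring the draft KV cache back in sync with the restored
+                                // main checkpoint; a size mismatch here is unrecoverable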
+                                const size_t ckpt_size_dft = it->data_dft.size();
+                                const size_t n = llama_state_seq_set_data_ext(ctx_dft.get(), it->data_dft.data(), ckpt_size_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+                                if (n != ckpt_size_dft) {
+                                    GGML_ABORT("inconsistent draft state");
+                                }
+                            }
                        }
                    }
@@ -2543,7 +2584,7 @@ private:
                 for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
                     const auto & cur = *it;
                     if (cur.pos_max > pos_next) {
-                        SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.data.size() / 1024 / 1024);
+                        SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
                         it = slot.prompt.checkpoints.erase(it);
                     } else {
                         ++it;
@@ -3045,13 +3086,25 @@ private:
             SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
                     ckpt.pos_min, ckpt.pos_max, ckpt.size());
 
-            const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
-            if (n != ckpt.size()) {
-                GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
-                        __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
+            {
+                const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data_main.data(), ckpt.data_main.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
+                if (n != ckpt.data_main.size()) {
+                    GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
+                            __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.data_main.size(), n);
+                }
+
+                llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1);
             }
 
-            llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1);
+            if (slot.ctx_dft) {
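+                // restore the draft state from the same checkpoint, then trim any
+                // draft KV cells past the checkpoint position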
+                const size_t n = llama_state_seq_set_data_ext(slot.ctx_dft, ckpt.data_dft.data(), ckpt.data_dft.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
+                if (n != ckpt.data_dft.size()) {
+                    GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
+                            __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.data_dft.size(), n);
+                }
+
+                llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1);
+            }
 
             slot.prompt.tokens.keep_first(ckpt.n_tokens);
 
             slot.smpl = std::move(smpl_save);
diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp
index 45e5168fab..7d40b48c64 100644
--- a/tools/server/server-task.cpp
+++ b/tools/server/server-task.cpp
@@ -1981,7 +1981,7 @@ size_t server_prompt_cache::n_tokens() const {
     return res;
 }
 
-server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
+server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_dft) {
     // first check if the current state is contained fully in the cache
     for (auto it = states.begin(); it != states.end(); ++it) {
         const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
@@ -2005,11 +2005,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
         }
     }
 
-    std::vector<uint8_t> state_data;
+    std::vector<uint8_t> state_data_main;
+    std::vector<uint8_t> state_data_dft;
 
     // check if we can allocate enough memory for the new state
     try {
-        state_data.resize(state_size);
+        state_data_main.resize(state_size_main);
+        state_data_dft.resize(state_size_dft);
     } catch (const std::bad_alloc & e) {
         SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
@@ -2022,17 +2024,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
         return nullptr;
     }
 
-    auto & cur = states.emplace_back();
-    cur = {
+    states.push_back({
         /*.tokens      =*/ prompt.tokens.clone(),
-        /*.data        =*/ std::move(state_data),
+        /*.data        =*/ {
+            /*.main =*/ std::move(state_data_main),
+            /*.dft  =*/ std::move(state_data_dft),
+        },
         /*.checkpoints =*/ prompt.checkpoints,
-    };
+    });
 
-    return &cur;
+    return &states.back();
 }
 
-bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
+bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, llama_context * ctx_dft, int32_t id_slot) {
     const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
 
     float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
@@ -2065,16 +2069,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
     if (it_best != states.end()) {
         SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
 
-        const size_t size = it_best->data.size();
-        const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
-        if (n != size) {
-            SRV_WRN("failed to restore state with size %zu\n", size);
+        {
+            auto & data = it_best->data.main;
 
-            return false;
+            const size_t size = data.size();
+            const size_t n = llama_state_seq_set_data_ext(ctx, data.data(), size, id_slot, 0);
+            if (n != size) {
+                SRV_WRN("failed to restore state with size %zu\n", size);
+
+                return false;
+            }
+
+            data.clear();
+            data.shrink_to_fit();
         }
 
-        it_best->data.clear();
-        it_best->data.shrink_to_fit();
+        {
+            auto & data = it_best->data.dft;
+
+            if (!data.empty()) {
+                GGML_ASSERT(ctx_dft);
+
+                const size_t size = data.size();
+                const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
+                if (n != size) {
+                    SRV_WRN("failed to restore state with size %zu\n", size);
+
+                    return false;
+                }
+
+                data.clear();
+                data.shrink_to_fit();
+            }
+        }
 
         prompt = std::move(*it_best);
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index 289e1fb8d2..8618fb3887 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -571,36 +571,49 @@ struct server_prompt_checkpoint {
 
     int64_t n_tokens;
 
-    std::vector<uint8_t> data;
+    std::vector<uint8_t> data_main;
+    std::vector<uint8_t> data_dft;
 
     size_t size() const {
-        return data.size();
+        return data_main.size() + data_dft.size();
     }
 
     bool empty() const {
-        return data.empty();
+        return data_main.empty();
     }
 
     void clear() {
         pos_min  = 0;
         pos_max  = 0;
         n_tokens = 0;
 
-        data.clear();
+        data_main.clear();
+        data_dft.clear();
+    }
+};
+
+struct server_prompt_data {
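+    // main holds the serialized state of the target model; dft holds the state of
+    // the draft model and stays empty when no draft model is loaded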
+    std::vector<uint8_t> main;
+    std::vector<uint8_t> dft;
+
+    size_t size() const {
+        return main.size() + dft.size();
+    }
 };
 
 struct server_prompt {
     server_tokens tokens;
 
-    std::vector<uint8_t> data;
+    server_prompt_data data;
 
     std::list<server_prompt_checkpoint> checkpoints;
 
     size_t size() const {
-        size_t res = data.size();
+        size_t res = 0;
 
-        for (const auto & checkpoint : checkpoints) {
-            res += checkpoint.size();
+        res += data.size();
+
+        for (const auto & ckpt : checkpoints) {
+            res += ckpt.size();
         }
 
         return res;
@@ -614,7 +627,7 @@ struct server_prompt {
         return server_prompt {
             tokens.clone(),
             data,
-            checkpoints
+            checkpoints,
         };
     }
 };
@@ -637,9 +650,9 @@ struct server_prompt_cache {
 
     size_t n_tokens() const;
 
-    server_prompt * alloc(const server_prompt & prompt, size_t state_size);
+    server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_dft);
 
-    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
+    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, llama_context * ctx_dft, int32_t id_slot);
 
     void update();
 };