server : draft prompt cache and checkpoints
[no ci]
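
This extends the llama.cpp server's prompt cache and context checkpoints to cover speculative decoding: wherever the server previously saved or restored the sequence state of a single llama_context, it now handles a pair of states, one for the main context (ctx) and one for the draft context (ctx_dft). Checkpoints store two buffers (data_main/data_dft), cached prompts store a server_prompt_data with separate main/dft vectors, and server_prompt_cache::alloc()/load() gain the corresponding draft parameters.

The sketch below is not code from this commit; it only illustrates the paired save/restore pattern the diff applies throughout. It uses the plain llama_state_seq_* API rather than the _ext variants from the diff, and paired_state, paired_state_save/paired_state_restore and seq_id are illustrative names, not names defined by the commit.

// Minimal sketch, assuming a main context and an optional draft context:
// a slot's sequence state is only consistent if both states are snapshotted
// and restored together.
#include <vector>

#include "llama.h"

struct paired_state {
    std::vector<uint8_t> main;
    std::vector<uint8_t> dft; // stays empty when no draft model is loaded
};

static paired_state paired_state_save(llama_context * ctx, llama_context * ctx_dft, llama_seq_id seq_id) {
    paired_state st;

    // snapshot the main context state for this sequence
    st.main.resize(llama_state_seq_get_size(ctx, seq_id));
    llama_state_seq_get_data(ctx, st.main.data(), st.main.size(), seq_id);

    // snapshot the draft context state, if speculative decoding is active
    if (ctx_dft) {
        st.dft.resize(llama_state_seq_get_size(ctx_dft, seq_id));
        llama_state_seq_get_data(ctx_dft, st.dft.data(), st.dft.size(), seq_id);
    }

    return st;
}

static bool paired_state_restore(llama_context * ctx, llama_context * ctx_dft, llama_seq_id seq_id, const paired_state & st) {
    // restoring only one of the two states would desynchronize the draft KV cache
    if (llama_state_seq_set_data(ctx, st.main.data(), st.main.size(), seq_id) != st.main.size()) {
        return false;
    }

    if (!st.dft.empty()) {
        return ctx_dft && llama_state_seq_set_data(ctx_dft, st.dft.data(), st.dft.size(), seq_id) == st.dft.size();
    }

    return true;
}
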
@@ -36,7 +36,13 @@ using json = nlohmann::ordered_json;
 
 constexpr int HTTP_POLLING_SECONDS = 1;
 
-static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, bool on_device, llama_pos pos_min = -1, llama_pos pos_max = -1) {
+static void server_prompt_checkpoint_update(
+        server_prompt_checkpoint & ckpt,
+        llama_context * ctx,
+        llama_context * ctx_dft,
+        int id, int64_t n_tokens, bool on_device,
+        llama_pos pos_min = -1,
+        llama_pos pos_max = -1) {
     if (pos_min == -1) {
         pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
     }
@@ -49,16 +55,30 @@ static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, lla
         flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE;
     }
 
-    const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, flags);
+    const size_t ckpt_size_main = llama_state_seq_get_size_ext(ctx, id, flags);
+    const size_t ckpt_size_dft  = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, flags) : 0;
 
     ckpt.pos_min = pos_min;
     ckpt.pos_max = pos_max;
     ckpt.n_tokens = n_tokens;
 
-    ckpt.data.resize(checkpoint_size);
+    ckpt.data_main.resize(ckpt_size_main);
+    ckpt.data_dft.resize (ckpt_size_dft);
 
-    const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, flags);
-    if (n != checkpoint_size) {
-        GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
+    {
+        const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data_main.data(), ckpt_size_main, id, flags);
+        if (n != ckpt_size_main) {
+            GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_main, n);
+        }
+    }
+
+    {
+        if (ctx_dft) {
+            const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data_dft.data(), ckpt_size_dft, id, flags);
+            if (n != ckpt_size_dft) {
+                GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size_dft, n);
+            }
+        }
     }
 }
@@ -81,6 +101,7 @@ struct server_slot {
     int id;
 
     llama_context * ctx = nullptr;
+    llama_context * ctx_dft = nullptr;
 
     // multimodal
     mtmd_context * mctx = nullptr;
@@ -133,21 +154,27 @@ struct server_slot {
     void prompt_save(server_prompt_cache & prompt_cache) const {
         GGML_ASSERT(prompt.data.size() == 0);
 
-        const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
+        const size_t cur_size_main = llama_state_seq_get_size_ext(ctx, id, 0);
+        const size_t cur_size_dft  = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, 0) : 0;
 
-        SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
-                (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
+        const size_t cur_size = cur_size_main + cur_size_dft;
 
-        auto * cur = prompt_cache.alloc(prompt, cur_size);
+        SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n",
+                (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0));
+
+        auto * cur = prompt_cache.alloc(prompt, cur_size_main, cur_size_dft);
         if (cur == nullptr) {
             return;
         }
 
-        llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
+        llama_state_seq_get_data_ext(ctx, cur->data.main.data(), cur_size_main, id, 0);
+        if (ctx_dft) {
+            llama_state_seq_get_data_ext(ctx_dft, cur->data.dft.data(), cur_size_dft, id, 0);
+        }
     }
 
     bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
-        bool res = prompt_cache.load(prompt, tokens, ctx, id);
+        bool res = prompt_cache.load(prompt, tokens, ctx, ctx_dft, id);
         if (!res) {
             SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
         }
@@ -163,6 +190,10 @@ struct server_slot {
         SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
 
         llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
+
+        if (ctx_dft) {
+            llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1);
+        }
 
         prompt.tokens.clear();
     }
@@ -365,13 +396,13 @@ struct server_slot {
 
             //const int64_t t_start = ggml_time_us();
 
-            server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens, true);
+            server_prompt_checkpoint_update(spec_ckpt, ctx, ctx_dft, this->id, n_tokens, true);
 
             //const int64_t t_total = ggml_time_us() - t_start;
             //printf("checkpoint total: %f ms\n", t_total / 1000.0);
 
             SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
-                    spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
+                    spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.size() / 1024 / 1024);
         }
     }
@@ -667,10 +698,12 @@ private:
    // note: keep these alive - they determine the lifetime of the model, context, etc.
    common_init_result_ptr llama_init;
 
+    // TODO: rename to ctx_main
    llama_context * ctx = nullptr;
 
    llama_batch batch {};
 
+    // TODO: rename to *_drft
    llama_model_ptr model_dft;
    llama_context_ptr ctx_dft;
@@ -1844,18 +1877,18 @@ private:
            const auto & cur = slot.prompt.checkpoints.front();
 
            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
-                    cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
+                    cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
 
            slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
        }
 
        auto & cur = slot.prompt.checkpoints.emplace_back();
-        server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max);
+        server_prompt_checkpoint_update(cur, ctx, ctx_dft.get(), slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max);
 
        SLT_WRN(slot,
                "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
                (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
-                cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
+                cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
    }
 
    void process_single_task(server_task && task) {
@@ -2390,6 +2423,7 @@ private:
        // reuse chunks from the cached prompt by shifting their KV cache in the new position
        if (can_cache_reuse && n_cache_reuse > 0) {
            GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
+            GGML_ASSERT(ctx_dft == nullptr && "TODO: add support for draft context cache reuse");
 
            size_t head_c = n_past; // cache
            size_t head_p = n_past; // current prompt
@@ -2515,17 +2549,24 @@ private:
 
                if (!do_reset) {
                    // restore the context checkpoint
-                    const size_t checkpoint_size = it->data.size();
-                    const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                    if (n != checkpoint_size) {
-                        SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
+                    const size_t ckpt_size_main = it->data_main.size();
+                    const size_t n = llama_state_seq_set_data_ext(ctx, it->data_main.data(), ckpt_size_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+                    if (n != ckpt_size_main) {
+                        SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) ckpt_size_main / 1024 / 1024);
                        do_reset = true;
                        //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                    } else {
                        pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
-                        n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
-                        SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) checkpoint_size / 1024 / 1024);
+                        n_past   = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
+                        SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) ckpt_size_main / 1024 / 1024);
+
+                        if (ctx_dft) {
+                            const size_t ckpt_size_dft = it->data_dft.size();
+                            const size_t n = llama_state_seq_set_data_ext(ctx_dft.get(), it->data_dft.data(), ckpt_size_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+                            if (n != ckpt_size_dft) {
+                                GGML_ABORT("inconsistent draft state");
+                            }
+                        }
                    }
                }
@@ -2543,7 +2584,7 @@ private:
        for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
            const auto & cur = *it;
            if (cur.pos_max > pos_next) {
-                SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.data.size() / 1024 / 1024);
+                SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
                it = slot.prompt.checkpoints.erase(it);
            } else {
                ++it;
@@ -3045,13 +3086,25 @@ private:
                SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
                        ckpt.pos_min, ckpt.pos_max, ckpt.size());
 
-                const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
-                if (n != ckpt.size()) {
-                    GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
-                            __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
+                {
+                    const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data_main.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
+                    if (n != ckpt.size()) {
+                        GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
+                                __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
+                    }
+
+                    llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1);
                }
 
-                llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1);
+                {
+                    const size_t n = llama_state_seq_set_data_ext(slot.ctx_dft, ckpt.data_dft.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
+                    if (n != ckpt.size()) {
+                        GGML_ABORT("%s: failed to restore draft checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
+                                __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
+                    }
+
+                    llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1);
+                }
 
                slot.prompt.tokens.keep_first(ckpt.n_tokens);
                slot.smpl = std::move(smpl_save);
@@ -1981,7 +1981,7 @@ size_t server_prompt_cache::n_tokens() const {
    return res;
 }
 
-server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
+server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_dft) {
    // first check if the current state is contained fully in the cache
    for (auto it = states.begin(); it != states.end(); ++it) {
        const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
@@ -2005,11 +2005,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
        }
    }
 
-    std::vector<uint8_t> state_data;
+    std::vector<uint8_t> state_data_main;
+    std::vector<uint8_t> state_data_dft;
 
    // check if we can allocate enough memory for the new state
    try {
-        state_data.resize(state_size);
+        state_data_main.resize(state_size_main);
+        state_data_dft.resize(state_size_dft);
    } catch (const std::bad_alloc & e) {
        SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
@@ -2022,17 +2024,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
        return nullptr;
    }
 
-    auto & cur = states.emplace_back();
-    cur = {
+    states.push_back({
        /*.tokens      =*/ prompt.tokens.clone(),
-        /*.data        =*/ std::move(state_data),
+        /*.data        =*/ {
+            /*.main =*/ std::move(state_data_main),
+            /*.dft  =*/ std::move(state_data_dft),
+        },
        /*.checkpoints =*/ prompt.checkpoints,
-    };
+    });
 
-    return &cur;
+    return &states.back();
 }
 
-bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
+bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, llama_context * ctx_dft, int32_t id_slot) {
    const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
 
    float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
@@ -2065,16 +2069,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
    if (it_best != states.end()) {
        SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
 
-        const size_t size = it_best->data.size();
-        const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
-        if (n != size) {
-            SRV_WRN("failed to restore state with size %zu\n", size);
-
-            return false;
-        }
+        {
+            auto & data = it_best->data.main;
+
+            const size_t size = data.size();
+            const size_t n = llama_state_seq_set_data_ext(ctx, data.data(), size, id_slot, 0);
+            if (n != size) {
+                SRV_WRN("failed to restore state with size %zu\n", size);
+
+                return false;
+            }
+
+            data.clear();
+            data.shrink_to_fit();
+        }
 
-        it_best->data.clear();
-        it_best->data.shrink_to_fit();
+        {
+            auto & data = it_best->data.dft;
+
+            if (!data.empty()) {
+                GGML_ASSERT(ctx_dft);
+
+                const size_t size = data.size();
+                const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
+                if (n != size) {
+                    SRV_WRN("failed to restore state with size %zu\n", size);
+
+                    return false;
+                }
+
+                data.clear();
+                data.shrink_to_fit();
+            }
+        }
 
        prompt = std::move(*it_best);
@@ -571,36 +571,49 @@ struct server_prompt_checkpoint {
 
    int64_t n_tokens;
 
-    std::vector<uint8_t> data;
+    std::vector<uint8_t> data_main;
+    std::vector<uint8_t> data_dft;
 
    size_t size() const {
-        return data.size();
+        return data_main.size() + data_dft.size();
    }
 
    bool empty() const {
-        return data.empty();
+        return data_main.empty();
    }
 
    void clear() {
        pos_min = 0;
        pos_max = 0;
        n_tokens = 0;
-        data.clear();
+        data_main.clear();
+        data_dft.clear();
    }
 };
 
+struct server_prompt_data {
+    std::vector<uint8_t> main;
+    std::vector<uint8_t> dft;
+
+    size_t size() const {
+        return main.size() + dft.size();
+    }
+};
+
 struct server_prompt {
    server_tokens tokens;
 
-    std::vector<uint8_t> data;
+    server_prompt_data data;
 
    std::list<server_prompt_checkpoint> checkpoints;
 
    size_t size() const {
-        size_t res = data.size();
+        size_t res = 0;
 
-        for (const auto & checkpoint : checkpoints) {
-            res += checkpoint.size();
+        res += data.size();
+
+        for (const auto & ckpt : checkpoints) {
+            res += ckpt.size();
        }
 
        return res;
@@ -614,7 +627,7 @@ struct server_prompt {
        return server_prompt {
            tokens.clone(),
            data,
-            checkpoints
+            checkpoints,
        };
    }
 };
@@ -637,9 +650,9 @@ struct server_prompt_cache {
 
    size_t n_tokens() const;
 
-    server_prompt * alloc(const server_prompt & prompt, size_t state_size);
+    server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_dft);
 
-    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
+    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, llama_context * ctx_dft, int32_t id_slot);
 
    void update();
 };
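
Note the deliberate asymmetry in the new layout: the draft buffers may legitimately stay empty. server_prompt_checkpoint::empty() tests data_main only, and server_prompt_cache::load() asserts ctx_dft only when a non-empty data.dft shows the entry was saved while a draft model was loaded, so prompt states saved without speculative decoding remain loadable.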