server : draft prompt cache and checkpoints

[no ci]
This commit is contained in:
Georgi Gerganov
2026-05-07 12:47:56 +03:00
parent e22a090f12
commit 4c957c4749
3 changed files with 150 additions and 57 deletions

View File

@@ -1981,7 +1981,7 @@ size_t server_prompt_cache::n_tokens() const {
return res;
}
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_dft) {
// first check if the current state is contained fully in the cache
for (auto it = states.begin(); it != states.end(); ++it) {
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
@@ -2005,11 +2005,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
}
}
std::vector<uint8_t> state_data;
std::vector<uint8_t> state_data_main;
std::vector<uint8_t> state_data_dft;
// check if we can allocate enough memory for the new state
try {
state_data.resize(state_size);
state_data_main.resize(state_size_main);
state_data_dft.resize(state_size_dft);
} catch (const std::bad_alloc & e) {
SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
@@ -2022,17 +2024,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
return nullptr;
}
auto & cur = states.emplace_back();
cur = {
states.push_back({
/*.tokens =*/ prompt.tokens.clone(),
/*.data =*/ std::move(state_data),
/*.data =*/ {
/*.main =*/ std::move(state_data_main),
/*.dft =*/ std::move(state_data_dft),
},
/*.checkpoints =*/ prompt.checkpoints,
};
});
return &cur;
return &states.back();
}
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, llama_context * ctx_dft, int32_t id_slot) {
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
@@ -2065,16 +2069,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
if (it_best != states.end()) {
SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
const size_t size = it_best->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
{
auto & data = it_best->data.main;
return false;
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
return false;
}
data.clear();
data.shrink_to_fit();
}
it_best->data.clear();
it_best->data.shrink_to_fit();
{
auto & data = it_best->data.dft;
if (!data.empty()) {
GGML_ASSERT(ctx_dft);
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
return false;
}
data.clear();
data.shrink_to_fit();
}
}
prompt = std::move(*it_best);