server : draft prompt cache and checkpoints
[no ci]
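The diff below splits each cached prompt's serialized state into a main-model part and a draft-model part, referenced as `data.main` and `data.dft`. For orientation, the container implied by the `/*.main =*/` and `/*.dft =*/` initializers further down might look roughly like the following sketch (the struct name and comments are inferred, not copied from the server sources):

    #include <cstdint>
    #include <vector>

    // Sketch only: field names inferred from the initializers in the diff below.
    struct server_prompt_state {
        std::vector<uint8_t> main; // serialized state of the main (target) model for this prompt
        std::vector<uint8_t> dft;  // serialized state of the draft model; empty when no draft model is used
    };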
@@ -1981,7 +1981,7 @@ size_t server_prompt_cache::n_tokens() const {
     return res;
 }
 
-server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
+server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_dft) {
     // first check if the current state is contained fully in the cache
     for (auto it = states.begin(); it != states.end(); ++it) {
         const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
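`alloc` now takes separate byte sizes for the main-model and the draft-model state. The call site is not part of this diff; a hedged sketch of how a caller could obtain the two sizes, assuming `llama_state_seq_get_size_ext` is the size-query counterpart of the `llama_state_seq_set_data_ext` call used further down (flags = 0 as in the diff), and assuming a cache object named `prompt_cache`:

    // Hypothetical call site; ctx, ctx_dft, id_slot and prompt_cache are assumed names.
    const size_t state_size_main = llama_state_seq_get_size_ext(ctx, id_slot, 0);
    const size_t state_size_dft  = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id_slot, 0) : 0;

    server_prompt * cached = prompt_cache.alloc(prompt, state_size_main, state_size_dft);
    if (cached == nullptr) {
        // allocation failed (e.g. std::bad_alloc caught inside alloc) - skip caching
    }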
@@ -2005,11 +2005,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
         }
     }
 
-    std::vector<uint8_t> state_data;
+    std::vector<uint8_t> state_data_main;
+    std::vector<uint8_t> state_data_dft;
 
     // check if we can allocate enough memory for the new state
     try {
-        state_data.resize(state_size);
+        state_data_main.resize(state_size_main);
+        state_data_dft.resize(state_size_dft);
     } catch (const std::bad_alloc & e) {
         SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
 
@@ -2022,17 +2024,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
         return nullptr;
     }
 
-    auto & cur = states.emplace_back();
-    cur = {
+    states.push_back({
         /*.tokens      =*/ prompt.tokens.clone(),
-        /*.data        =*/ std::move(state_data),
+        /*.data        =*/ {
+            /*.main =*/ std::move(state_data_main),
+            /*.dft  =*/ std::move(state_data_dft),
+        },
         /*.checkpoints =*/ prompt.checkpoints,
-    };
+    });
 
-    return &cur;
+    return &states.back();
 }
 
-bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
+bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, llama_context * ctx_dft, int32_t id_slot) {
     const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
 
     float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
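After a successful `alloc`, the caller presumably serializes both contexts into the freshly allocated buffers; that step is outside this diff, but a hedged sketch mirroring the restore calls in the next hunk might look like:

    // Assumption: llama_state_seq_get_data_ext mirrors the set_data_ext call shown in the restore path.
    llama_state_seq_get_data_ext(ctx, cached->data.main.data(), cached->data.main.size(), id_slot, 0);
    if (ctx_dft) {
        llama_state_seq_get_data_ext(ctx_dft, cached->data.dft.data(), cached->data.dft.size(), id_slot, 0);
    }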
@@ -2065,16 +2069,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
     if (it_best != states.end()) {
         SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
 
-        const size_t size = it_best->data.size();
-        const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
-        if (n != size) {
-            SRV_WRN("failed to restore state with size %zu\n", size);
+        {
+            auto & data = it_best->data.main;
 
-            return false;
+            const size_t size = data.size();
+            const size_t n = llama_state_seq_set_data_ext(ctx, data.data(), size, id_slot, 0);
+            if (n != size) {
+                SRV_WRN("failed to restore state with size %zu\n", size);
+
+                return false;
+            }
+
+            data.clear();
+            data.shrink_to_fit();
         }
 
-        it_best->data.clear();
-        it_best->data.shrink_to_fit();
+        {
+            auto & data = it_best->data.dft;
+
+            if (!data.empty()) {
+                GGML_ASSERT(ctx_dft);
+
+                const size_t size = data.size();
+                const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
+                if (n != size) {
+                    SRV_WRN("failed to restore state with size %zu\n", size);
+
+                    return false;
+                }
+
+                data.clear();
+                data.shrink_to_fit();
+            }
+        }
 
         prompt = std::move(*it_best);
 
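Worth noting from the restore path above: the draft buffer is optional (guarded by `!data.empty()`, with `GGML_ASSERT(ctx_dft)` catching a cached draft state when no draft context is available), and both buffers are cleared and shrunk immediately after a successful restore, since the state now lives in the respective llama_context and the cache entry itself is handed to the caller via `prompt = std::move(*it_best)`.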