mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-14 21:14:10 +00:00
Merge branch 'master' into pr/18039
This commit is contained in:
@@ -36,7 +36,7 @@ using json = nlohmann::ordered_json;
|
||||
|
||||
constexpr int HTTP_POLLING_SECONDS = 1;
|
||||
|
||||
static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) {
|
||||
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) {
|
||||
if (pos_min == -1) {
|
||||
pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
|
||||
}
|
||||
@@ -46,19 +46,15 @@ static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int i
|
||||
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
auto cur = server_prompt_checkpoint {
|
||||
/*.pos_min = */ pos_min,
|
||||
/*.pos_max = */ pos_max,
|
||||
/*.n_tokens = */ n_tokens,
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
};
|
||||
ckpt.pos_min = pos_min;
|
||||
ckpt.pos_max = pos_max;
|
||||
ckpt.n_tokens = n_tokens;
|
||||
ckpt.data.resize(checkpoint_size);
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
if (n != checkpoint_size) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
|
||||
@@ -364,7 +360,12 @@ struct server_slot {
|
||||
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
const auto n_tokens = prompt.tokens.size();
|
||||
|
||||
spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens);
|
||||
|
||||
//const int64_t t_total = ggml_time_us() - t_start;
|
||||
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
|
||||
|
||||
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
|
||||
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
|
||||
@@ -680,6 +681,7 @@ private:
|
||||
// slots / clients
|
||||
std::vector<server_slot> slots;
|
||||
|
||||
int trace = 0;
|
||||
int slots_debug = 0;
|
||||
int n_empty_consecutive = 0;
|
||||
|
||||
@@ -930,12 +932,21 @@ private:
|
||||
slot.reset();
|
||||
}
|
||||
|
||||
{
|
||||
const char * LLAMA_TRACE = getenv("LLAMA_TRACE");
|
||||
trace = LLAMA_TRACE ? atoi(LLAMA_TRACE) : 0;
|
||||
|
||||
if (trace) {
|
||||
SRV_WRN("LLAMA_TRACE = %d\n", trace);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
|
||||
slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;
|
||||
|
||||
if (slots_debug) {
|
||||
SRV_WRN("slots debug = %d\n", slots_debug);
|
||||
SRV_WRN("LLAMA_SERVER_SLOTS_DEBUG = %d\n", slots_debug);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1838,7 +1849,8 @@ private:
|
||||
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
||||
}
|
||||
|
||||
const auto & cur = slot.prompt.checkpoints.emplace_back(server_get_checkpoint(ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max));
|
||||
auto & cur = slot.prompt.checkpoints.emplace_back();
|
||||
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
|
||||
|
||||
SLT_WRN(slot,
|
||||
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
@@ -2986,13 +2998,15 @@ private:
|
||||
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
|
||||
slot.spec_i_batch.clear();
|
||||
|
||||
SLT_DBG(slot, "%s: n_draft=%zu, accepted=%zu\n", __func__, slot.spec_draft.size(), accepted.size());
|
||||
|
||||
GGML_ASSERT(accepted.size() >= 1);
|
||||
|
||||
// check for partial draft acceptance
|
||||
if (accepted.size() < slot.spec_draft.size() + 1) {
|
||||
if (use_ckpt) {
|
||||
if (trace > 0) {
|
||||
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
|
||||
}
|
||||
|
||||
// partial acceptance is not supported by the context -> truncate the draft and restore the state
|
||||
slot.spec_draft = std::move(accepted);
|
||||
|
||||
@@ -3014,8 +3028,10 @@ private:
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_DBG("%s: partial acceptance: %zu < %zu\n", __func__, accepted.size(), slot.spec_draft.size());
|
||||
if (trace > 0) {
|
||||
SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft);
|
||||
}
|
||||
|
||||
common_speculative_accept(slot.spec.get(), accepted.size() - 1);
|
||||
|
||||
Reference in New Issue
Block a user