Compare commits

...

24 Commits

Author SHA1 Message Date
Georgi Gerganov
efa2f8e5a7 naming : improve consistency 2026-05-08 12:24:57 +03:00
Georgi Gerganov
778f9e247e tools : update readme 2026-05-08 11:55:16 +03:00
Georgi Gerganov
1dbc054da5 server : fix slot ctx_drft ptr 2026-05-08 11:55:05 +03:00
Georgi Gerganov
161eae0adf spec : fix n_past type 2026-05-08 11:54:32 +03:00
Georgi Gerganov
e5b1401318 speculative-simple : update 2026-05-08 11:09:34 +03:00
Georgi Gerganov
3b1a8df8fd server : clean-up + dry 2026-05-08 10:20:01 +03:00
Georgi Gerganov
233d1aee69 server : add comment
[no ci]
2026-05-08 08:50:23 +03:00
Georgi Gerganov
12c7cfbe83 server : fix URL for draft model 2026-05-08 08:03:49 +03:00
Georgi Gerganov
6a4b05a030 server : fix mtmd draft processing 2026-05-08 08:02:11 +03:00
Georgi Gerganov
8be14e40de spec : handle draft running out of context 2026-05-08 07:11:51 +03:00
Georgi Gerganov
7e118cdce0 cont : process images throught the draft context 2026-05-07 21:44:09 +03:00
Georgi Gerganov
ae6703fa89 cont : pass correct n_past for drafting 2026-05-07 21:44:08 +03:00
Georgi Gerganov
0239f4c611 cont : handle non-ckpt models 2026-05-07 21:44:08 +03:00
Georgi Gerganov
c7facb0fe1 cont : async drft eval when possible 2026-05-07 21:44:08 +03:00
Georgi Gerganov
08c8012bde cont : sync main and drft contexts 2026-05-07 21:44:08 +03:00
Georgi Gerganov
de35b1255c server, spec : transition to unified spec context 2026-05-07 21:44:08 +03:00
Georgi Gerganov
1afee5b262 server : improve ctx names
[no ci]
2026-05-07 21:44:08 +03:00
Georgi Gerganov
11fd5e7272 server : draft prompt cache and checkpoints
[no ci]
2026-05-07 21:44:08 +03:00
Georgi Gerganov
c97dc3605e server : sketch the ctx_dft decode loop
[no ci]
2026-05-07 21:44:08 +03:00
Georgi Gerganov
8a50f6f0b9 cont : dedup ctx_seq_rm_type
[no ci]
2026-05-07 21:44:07 +03:00
Georgi Gerganov
77269ad8a7 cont : pass seq_id
[no ci]
2026-05-07 21:44:07 +03:00
Georgi Gerganov
4550f0f08b spec : update common_speculative_init()
[no ci]
2026-05-07 21:44:07 +03:00
Georgi Gerganov
befc7ef635 spec : drop support for incompatible vocabs
[no ci]
2026-05-07 21:44:07 +03:00
Georgi Gerganov
2c9a40849f spec : refactor
[no ci]
2026-05-07 21:44:07 +03:00
13 changed files with 584 additions and 597 deletions

View File

@@ -622,10 +622,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
string_process_escapes(seq_breaker);
}
for (auto & pair : params.speculative.draft.replacements) {
string_process_escapes(pair.first);
string_process_escapes(pair.second);
}
}
if (!params.kv_overrides.empty()) {
@@ -3518,13 +3514,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.draft.p_min = std::stof(value);
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN"));
add_opt(common_arg(
{"--spec-draft-ctx-size", "-cd", "--ctx-size-draft"}, "N",
string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.draft.n_ctx),
[](common_params & params, int value) {
params.speculative.draft.n_ctx = value;
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_CTX_SIZE"));
add_opt(common_arg(
{"--spec-draft-device", "-devd", "--device-draft"}, "<dev1,dev2,..>",
"comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"
@@ -3560,13 +3549,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.draft.mparams.path = value;
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
add_opt(common_arg(
{"--spec-draft-replace", "--spec-replace"}, "TARGET", "DRAFT",
"translate the string in TARGET into DRAFT if the draft model and main model are not compatible",
[](common_params & params, const std::string & tgt, const std::string & dft) {
params.speculative.draft.replacements.push_back({ tgt, dft });
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",

View File

@@ -1422,7 +1422,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
LOG_WRN("%s: the context does not support partial sequence removal\n", __func__);
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
goto done;
}
@@ -1960,3 +1960,102 @@ bool common_prompt_batch_decode(
return true;
}
size_t common_prompt_checkpoint::size() const {
return data_tgt.size() + data_dft.size();
}
bool common_prompt_checkpoint::empty() const {
return data_tgt.empty();
}
void common_prompt_checkpoint::clear() {
n_tokens = 0;
pos_min = 0;
pos_max = 0;
data_tgt.clear();
data_dft.clear();
}
void common_prompt_checkpoint::update_pos(
int64_t n_tokens,
llama_pos pos_min,
llama_pos pos_max) {
this->n_tokens = n_tokens;
this->pos_min = pos_min;
this->pos_max = pos_max;
}
void common_prompt_checkpoint::update_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) {
if (ctx == nullptr) {
return;
}
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
data_tgt.resize(ckpt_size);
const size_t n = llama_state_seq_get_data_ext(ctx, data_tgt.data(), ckpt_size, seq_id, flags);
if (n != ckpt_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
}
}
void common_prompt_checkpoint::update_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) {
if (ctx == nullptr) {
return;
}
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
data_dft.resize(ckpt_size);
const size_t n = llama_state_seq_get_data_ext(ctx, data_dft.data(), ckpt_size, seq_id, flags);
if (n != ckpt_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
}
}
void common_prompt_checkpoint::load_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const {
if (ctx == nullptr) {
return;
}
if (data_tgt.empty()) {
return;
}
const size_t n = llama_state_seq_set_data_ext(ctx, data_tgt.data(), data_tgt.size(), seq_id, flags);
if (n != data_tgt.size()) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_tgt.size(), n);
}
}
void common_prompt_checkpoint::load_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const {
if (ctx == nullptr) {
return;
}
if (data_dft.empty()) {
return;
}
const size_t n = llama_state_seq_set_data_ext(ctx, data_dft.data(), data_dft.size(), seq_id, flags);
if (n != data_dft.size()) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n);
}
}

View File

@@ -307,11 +307,9 @@ struct common_params_speculative_draft {
common_params_model mparams;
llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts
llama_context * ctx_tgt = nullptr;
llama_context * ctx_dft = nullptr;
llama_context_params cparams; // these are the parameters for the draft llama_context
int32_t n_ctx = 0; // draft context size
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
@@ -322,7 +320,6 @@ struct common_params_speculative_draft {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
};
@@ -1026,3 +1023,47 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
// "adamw" or "sgd" (case insensitive)
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
//
// prompt utils
//
struct common_prompt_checkpoint {
int64_t n_tokens;
llama_pos pos_min;
llama_pos pos_max;
std::vector<uint8_t> data_tgt;
std::vector<uint8_t> data_dft;
size_t size() const;
bool empty() const;
void clear();
void update_pos(
int64_t n_tokens,
llama_pos pos_min,
llama_pos pos_max);
void update_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags);
void update_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags);
void load_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const;
void load_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const;
};

View File

@@ -147,6 +147,7 @@ struct common_speculative_state {
virtual void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_pos n_past,
llama_token id_last,
llama_tokens & result) = 0;
@@ -156,44 +157,25 @@ struct common_speculative_state {
virtual int32_t n_min(const common_params_speculative & params) const = 0;
};
struct common_speculative_checkpoint {
llama_pos pos_min = 0;
llama_pos pos_max = 0;
int64_t n_tokens = 0;
std::vector<uint8_t> data;
size_t size() const {
return data.size();
}
};
struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
llama_context * ctx_dft;
bool use_ckpt = false;
common_speculative_checkpoint ckpt;
llama_seq_id seq_id;
common_sampler * smpl;
llama_batch batch;
llama_tokens prompt_dft;
bool vocab_cmpt = true; // whether retokenization is needed
std::unordered_map<std::string, std::string> vocab_map;
llama_batch batch;
common_speculative_state_draft(
enum common_speculative_type type,
llama_context * ctx_tgt,
llama_context * ctx_dft,
const std::vector<std::pair<std::string, std::string>> & replacements,
bool use_ckpt)
llama_seq_id seq_id)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
, ctx_dft(ctx_dft)
, use_ckpt(use_ckpt)
, seq_id(seq_id)
{
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
smpl = nullptr;
@@ -225,23 +207,17 @@ struct common_speculative_state_draft : public common_speculative_state {
smpl = common_sampler_init(llama_get_model(ctx_dft), params);
}
vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt);
const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt);
if (!vocab_cmpt) {
LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n");
LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__);
for (const auto & pair : replacements) {
vocab_map[pair.first] = pair.second;
}
throw std::runtime_error("draft model vocab type must match target model to use speculation");
}
}
~common_speculative_state_draft() override {
llama_perf_context_print(ctx_dft);
llama_free(ctx_dft);
common_sampler_free(smpl);
llama_batch_free(batch);
@@ -250,220 +226,29 @@ struct common_speculative_state_draft : public common_speculative_state {
void begin(const llama_tokens & /*prompt*/) override {
}
size_t create_checkpoint(int n_tokens_prompt) {
int slot_id = 0;
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
ckpt.n_tokens = n_tokens_prompt;
ckpt.data.resize(checkpoint_size);
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
}
LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__,
ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024);
return n;
}
size_t restore_checkpoint() {
int slot_id = 0;
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
}
llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);
return n;
}
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_pos n_past,
llama_token id_last,
llama_tokens & result) override {
const auto & sparams = params.draft;
auto * spec = this;
auto & batch = spec->batch;
auto & ctx_tgt = spec->ctx_tgt;
auto & ctx_dft = spec->ctx_dft;
auto & smpl = spec->smpl;
auto & prompt_dft = spec->prompt_dft;
auto & batch = spec->batch;
auto & ctx_dft = spec->ctx_dft;
auto & smpl = spec->smpl;
auto * mem_dft = llama_get_memory(ctx_dft);
int reuse_i = 0; // index of part to be reused in prompt_dft
int reuse_n = 0; // length of part to be reused in prompt_dft
const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max;
llama_tokens prompt_cnv;
if (!spec->vocab_cmpt) {
std::string text;
text = common_detokenize(ctx_tgt, prompt_tgt, true);
text = replace_to_dft(text);
LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
prompt_cnv = common_tokenize(ctx_dft, text, false, true);
// convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
const auto * model_tgt = llama_get_model(ctx_tgt);
const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
text.resize(-n_chars);
llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
text = replace_to_dft(text);
LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
id_last = common_tokenize(ctx_dft, text, false, true)[0];
}
const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv;
const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);
if (use_ckpt && i_start > 0) {
LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__);
return;
}
// reuse as much as possible from the old draft context
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
for (int i = 0; i < (int) prompt_dft.size(); ++i) {
int cur = 0;
while (i_start + cur < (int) prompt_cur.size() &&
i + cur < (int) prompt_dft.size() &&
prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
cur++;
}
if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) {
reuse_i = i;
reuse_n = cur;
}
if (use_ckpt) {
break;
}
}
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
if (use_ckpt && ckpt.n_tokens > reuse_n) {
LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens);
reuse_i = 0;
reuse_n = 0;
ckpt = {};
}
result.clear();
result.reserve(sparams.n_max);
if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
llama_memory_clear(mem_dft, false);
prompt_dft.clear();
} else {
// this happens when a previous draft has been discarded (for example, due to being too small), but the
// target model agreed with it. in this case, we simply pass back the previous results to save compute
if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
result.push_back(prompt_dft[i]);
if (sparams.n_max <= (int) result.size()) {
break;
}
}
return;
}
if (reuse_i > 0) {
GGML_ASSERT(!use_ckpt);
bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
if (!is_removed) {
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
return;
}
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
}
if (reuse_n < (int) prompt_dft.size()) {
if (use_ckpt) {
if (ckpt.n_tokens > 0) {
LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
restore_checkpoint();
reuse_n = ckpt.n_tokens;
prompt_dft.resize(reuse_n);
}
} else {
const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1);
if (!is_removed) {
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
return;
}
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
}
}
}
// prepare a batch to evaluate any new tokens in the prompt
common_batch_clear(batch);
for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) {
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]);
common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false);
prompt_dft.push_back(prompt_cur[i]);
}
// we should rarely end-up here during normal decoding
if (batch.n_tokens > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens);
int ret = llama_decode(ctx_dft, batch);
if (ret != 0 && ret != 1) {
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
__func__, ret, prompt_cur.size());
}
if (use_ckpt) {
create_checkpoint(prompt_dft.size());
}
}
const llama_pos n_past = prompt_dft.size();
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
GGML_ASSERT(n_past >= (llama_pos) prompt_tgt.size());
common_batch_clear(batch);
common_batch_add (batch, id_last, n_past, { 0 }, true);
prompt_dft.push_back(id_last);
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
common_batch_add (batch, id_last, n_past, { seq_id }, true);
int ret = llama_decode(ctx_dft, batch);
if (ret != 0 && ret != 1) {
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
__func__, ret, prompt_cur.size(), prompt_dft.size());
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
}
common_sampler_reset(smpl);
@@ -497,25 +282,13 @@ struct common_speculative_state_draft : public common_speculative_state {
break;
}
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
common_batch_add(batch, id, n_past + i + 1, { seq_id }, true);
// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
__func__, i, ret, prompt_cur.size(), prompt_dft.size());
}
prompt_dft.push_back(id);
}
if (!spec->vocab_cmpt) {
std::string detokenized = common_detokenize(ctx_dft, result, true);
detokenized = replace_to_tgt(detokenized);
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
result = common_tokenize(ctx_tgt, detokenized, false, true);
if (result.size() > (size_t) sparams.n_max) {
result.resize(sparams.n_max);
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
}
@@ -536,34 +309,6 @@ struct common_speculative_state_draft : public common_speculative_state {
int32_t n_min(const common_params_speculative & params) const override {
return params.draft.n_min;
}
std::string replace_to_dft(const std::string & input) const {
std::string result = input;
for (const auto & pair : this->vocab_map) {
size_t pos = result.find(pair.first);
while (pos != std::string::npos) {
result.replace(pos, pair.first.length(), pair.second);
pos = result.find(pair.first, pos + pair.second.length());
}
}
return result;
}
std::string replace_to_tgt(const std::string & input) const {
std::string result = input;
for (const auto & pair : this->vocab_map) {
size_t pos = result.find(pair.second);
while (pos != std::string::npos) {
result.replace(pos, pair.second.length(), pair.first);
pos = result.find(pair.second, pos + pair.first.length());
}
}
return result;
}
};
struct common_speculative_state_eagle3 : public common_speculative_state {
@@ -576,11 +321,13 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_pos n_past,
llama_token id_last,
llama_tokens & draft_tokens) override {
// TODO: implement
GGML_UNUSED(params);
GGML_UNUSED(prompt_tgt);
GGML_UNUSED(n_past);
GGML_UNUSED(id_last);
GGML_UNUSED(draft_tokens);
}
@@ -615,11 +362,13 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_pos n_past,
llama_token id_last,
llama_tokens & result) override {
GGML_UNUSED(params);
GGML_UNUSED(n_past);
result = common_ngram_simple_draft(config, prompt_tgt, id_last);
GGML_UNUSED(params);
}
void accept(uint16_t n_accepted) override {
@@ -652,10 +401,13 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state {
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_pos n_past,
llama_token id_last,
llama_tokens & result) override {
common_ngram_map_draft(config, prompt_tgt, id_last, result);
GGML_UNUSED(params);
GGML_UNUSED(n_past);
common_ngram_map_draft(config, prompt_tgt, id_last, result);
}
void accept(uint16_t n_accepted) override {
@@ -722,8 +474,11 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_pos n_past,
llama_token id_last,
llama_tokens & result) override {
GGML_UNUSED(n_past);
const auto & sparams = params.ngram_mod;
n_draft_last = 0;
@@ -853,9 +608,11 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_pos n_past,
llama_token id_last,
llama_tokens & result) override {
GGML_UNUSED(params);
GGML_UNUSED(n_past);
if (cache_size < prompt_tgt.size() + 1) {
llama_tokens tokens_new;
@@ -971,18 +728,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// initialization of the speculative decoding system
//
common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt) {
llama_context * ctx_dft = nullptr;
if (params.draft.model) {
ctx_dft = llama_init_from_model(params.draft.model, params.draft.cparams);
if (ctx_dft == nullptr) {
LOG_ERR("%s", "failed to create draft context\n");
return nullptr;
}
}
common_speculative * common_speculative_init(common_params_speculative & params, llama_seq_id seq_id) {
// Compute the implementations to use based on the config and their order of preference
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
{
@@ -1044,13 +790,10 @@ common_speculative * common_speculative_init(
case COMMON_SPECULATIVE_TYPE_NONE:
break;
case COMMON_SPECULATIVE_TYPE_DRAFT: {
const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .replacements = */ params.draft.replacements,
/* .use_ckpt = */ use_ckpt
/* .ctx_tgt = */ params.draft.ctx_tgt,
/* .ctx_dft = */ params.draft.ctx_dft,
/* .seq_id = */ seq_id
));
break;
}
@@ -1135,6 +878,7 @@ llama_tokens common_speculative_draft(
common_speculative * spec,
const common_params_speculative & params,
const llama_tokens & prompt_tgt, // specified in target model vocab
llama_pos n_past,
llama_token id_last) {
llama_tokens result;
@@ -1143,7 +887,7 @@ llama_tokens common_speculative_draft(
for (auto & impl : spec->impls) {
{
common_time_meas tm(impl->t_draft_us, !impl->gen_perf);
impl->draft(params, prompt_tgt, id_last, result);
impl->draft(params, prompt_tgt, n_past, id_last, result);
impl->n_call_draft++;
}

View File

@@ -14,9 +14,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt);
common_speculative * common_speculative_init(common_params_speculative & params, llama_seq_id seq_id);
void common_speculative_free(common_speculative * spec);
@@ -28,6 +26,7 @@ llama_tokens common_speculative_draft(
common_speculative * spec,
const common_params_speculative & params,
const llama_tokens & prompt,
llama_pos n_past,
llama_token id_last);
// informs the speculative decoder that n_accepted tokens were accepted by the target model

View File

@@ -13,20 +13,6 @@
#include <vector>
#include <utility>
struct spec_checkpoint {
int64_t n_tokens = 0;
std::vector<uint8_t> data;
size_t size() const {
return data.size();
}
bool empty() const {
return data.empty();
}
};
int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
@@ -43,11 +29,6 @@ int main(int argc, char ** argv) {
return 1;
}
if (params.speculative.draft.mparams.path.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
// init llama.cpp
llama_backend_init();
llama_numa_init(params.numa);
@@ -62,18 +43,11 @@ int main(int argc, char ** argv) {
model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt->context();
// check if the context supports partial sequence removal
const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt);
const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
if (use_ckpt) {
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
}
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
// load the draft model
llama_model_ptr model_dft;
llama_context_ptr ctx_dft;
// TODO: simplify this logic
{
@@ -81,9 +55,6 @@ int main(int argc, char ** argv) {
auto params_dft = params;
params_dft.n_parallel = 1;
params_dft.n_ctx = params_spec.n_ctx;
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
params_dft.devices = params_spec.devices;
params_dft.model = params_spec.mparams;
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
@@ -103,8 +74,19 @@ int main(int argc, char ** argv) {
return 1;
}
params.speculative.draft.model = model_dft.get();
params.speculative.draft.cparams = common_context_params_to_llama(params_dft);
auto cparams = common_context_params_to_llama(params_dft);
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
params.speculative.draft.ctx_tgt = ctx_tgt;
params.speculative.draft.ctx_dft = ctx_dft.get();
}
// check if the context supports partial sequence removal
const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
if (use_ckpt_tgt) {
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
}
// Tokenize the prompt
@@ -136,6 +118,8 @@ int main(int argc, char ** argv) {
// used to determine end of generation
bool has_eos = false;
llama_seq_id seq_id = 0;
// ================================================
// everything until here is standard initialization
// the relevant stuff for speculative decoding starts here
@@ -146,7 +130,8 @@ int main(int argc, char ** argv) {
common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling));
// eval the prompt
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
llama_decode(ctx_dft.get(), llama_batch_get_one(inp.data(), inp.size() - 1));
// note: keep the last token separate!
llama_token id_last = inp.back();
@@ -160,7 +145,7 @@ int main(int argc, char ** argv) {
// init the speculator
const auto & params_spec = params.speculative;
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
struct common_speculative * spec = common_speculative_init(params.speculative, seq_id);
common_speculative_begin(spec, prompt_tgt);
@@ -169,7 +154,7 @@ int main(int argc, char ** argv) {
size_t n_draft = 0;
llama_tokens draft;
spec_checkpoint spec_ckpt;
common_prompt_checkpoint ckpt;
const auto t_enc_end = ggml_time_us();
@@ -184,40 +169,49 @@ int main(int argc, char ** argv) {
// from a cache or lookup tables.
//
if (draft.empty()) {
ckpt.update_pos(
prompt_tgt.size(),
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), seq_id),
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id));
if (use_ckpt_dft) {
ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
// generate a new draft
draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
draft = common_speculative_draft(spec, params_spec, prompt_tgt, prompt_tgt.size(), id_last);
// save the original draft size
n_draft = draft.size();
// save a checkpoint of the target context before evaluating the draft
// this allows us to restore the state if partial draft acceptance occurs
if (!draft.empty() && use_ckpt) {
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
spec_ckpt.data.resize(ckpt_size);
if (!draft.empty()) {
if (use_ckpt_tgt) {
ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
}
const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
GGML_ASSERT(n == ckpt_size);
{
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
spec_ckpt.n_tokens = (int64_t) prompt_tgt.size();
LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n",
spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
}
} else {
// we have a previous (partial) draft to reuse from checkpoint restoration
if (use_ckpt) {
GGML_ASSERT(!spec_ckpt.empty());
if (use_ckpt_tgt) {
GGML_ASSERT(!ckpt.empty());
}
}
// always have a token to evaluate from before - id_last
common_batch_clear(batch_tgt);
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
common_batch_add (batch_tgt, id_last, n_past++, { seq_id }, true);
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
common_batch_add(batch_tgt, draft[i], n_past + i, { seq_id }, true);
}
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
@@ -225,9 +219,15 @@ int main(int argc, char ** argv) {
llama_decode(ctx_tgt, batch_tgt);
}
// evaluate the same batch with the draft model
{
// TODO: extend to support MTP, Eagle, etc. See server code for reference
llama_decode(ctx_dft.get(), batch_tgt);
}
// only save the sampler sampler state if we use checkpoints
common_sampler_ptr smpl_save;
if (use_ckpt) {
if (use_ckpt_tgt) {
smpl_save.reset(common_sampler_clone(smpl.get()));
}
@@ -247,17 +247,24 @@ int main(int argc, char ** argv) {
// check for partial draft acceptance:
// if the context doesn't support partial sequence removal, restore the checkpoint
// and make the accepted tokens the new partial draft for the next iteration
if (use_ckpt && ids.size() - 1 < draft.size()) {
if (use_ckpt_tgt && ids.size() - 1 < draft.size()) {
LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size());
draft = std::move(ids);
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
GGML_ASSERT(n == spec_ckpt.size());
{
ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, ckpt.pos_max + 1, -1);
}
prompt_tgt.resize(spec_ckpt.n_tokens);
{
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
}
prompt_tgt.resize(ckpt.n_tokens);
smpl = std::move(smpl_save);
n_past = (int) prompt_tgt.size();
@@ -305,7 +312,8 @@ int main(int argc, char ** argv) {
{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, n_past, -1);
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, n_past, -1);
}
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {

View File

@@ -858,6 +858,8 @@ extern "C" {
size_t n_token_capacity,
size_t * n_token_count_out);
#define LLAMA_STATE_SEQ_FLAGS_NONE 0
// for backwards-compat
#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1

View File

@@ -195,11 +195,9 @@
| `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
| `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
| `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |

View File

@@ -244,11 +244,9 @@ For the full list of features, please refer to [server's changelog](https://gith
| `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
| `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
| `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |

View File

@@ -36,32 +36,6 @@ using json = nlohmann::ordered_json;
constexpr int HTTP_POLLING_SECONDS = 1;
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, bool on_device, llama_pos pos_min = -1, llama_pos pos_max = -1) {
if (pos_min == -1) {
pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
}
if (pos_max == -1) {
pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id);
}
auto flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY;
if (on_device) {
flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE;
}
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, flags);
ckpt.pos_min = pos_min;
ckpt.pos_max = pos_max;
ckpt.n_tokens = n_tokens;
ckpt.data.resize(checkpoint_size);
const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, flags);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
}
}
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
enum slot_state {
SLOT_STATE_IDLE,
@@ -80,9 +54,8 @@ enum server_state {
struct server_slot {
int id;
llama_context * ctx = nullptr;
common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
llama_context * ctx_tgt = nullptr;
llama_context * ctx_dft = nullptr;
// multimodal
mtmd_context * mctx = nullptr;
@@ -90,7 +63,7 @@ struct server_slot {
// speculative decoding
llama_tokens spec_draft;
std::vector<int32_t> spec_i_batch;
server_prompt_checkpoint spec_ckpt;
common_prompt_checkpoint spec_ckpt;
common_speculative_ptr spec;
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
@@ -135,21 +108,27 @@ struct server_slot {
void prompt_save(server_prompt_cache & prompt_cache) const {
GGML_ASSERT(prompt.data.size() == 0);
const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
const size_t cur_size_tgt = llama_state_seq_get_size_ext(ctx_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE);
const size_t cur_size_dft = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE) : 0;
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
const size_t cur_size = cur_size_tgt + cur_size_dft;
auto * cur = prompt_cache.alloc(prompt, cur_size);
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n",
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0));
auto * cur = prompt_cache.alloc(prompt, cur_size_tgt, cur_size_dft);
if (cur == nullptr) {
return;
}
llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
llama_state_seq_get_data_ext(ctx_tgt, cur->data.main.data(), cur_size_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE);
if (ctx_dft) {
llama_state_seq_get_data_ext(ctx_dft, cur->data.drft.data(), cur_size_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE);
}
}
bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
bool res = prompt_cache.load(prompt, tokens, ctx, id);
bool res = prompt_cache.load(prompt, tokens, ctx_tgt, ctx_dft, id);
if (!res) {
SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
}
@@ -164,7 +143,11 @@ struct server_slot {
SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1);
if (ctx_dft) {
llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1);
}
prompt.tokens.clear();
}
@@ -222,7 +205,7 @@ struct server_slot {
task_prev = std::move(task);
task.reset();
llama_set_sampler(ctx, id, nullptr);
llama_set_sampler(ctx_tgt, id, nullptr);
// clear alora start
alora_invocation_start = -1;
@@ -259,7 +242,7 @@ struct server_slot {
return
!task->need_embd() ||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
(llama_get_memory(ctx_tgt) && llama_pooling_type(ctx_tgt) == LLAMA_POOLING_TYPE_LAST);
}
bool can_batch_with(server_slot & other_slot) const {
@@ -333,7 +316,7 @@ struct server_slot {
return n_draft_max;
}
void update_batch(llama_batch & batch) {
void update_batch(llama_batch & batch, bool use_ckpt_tgt, bool use_ckpt_dft) {
const int n_draft_max = get_n_draft_max();
if (n_draft_max > 0) {
GGML_ASSERT(can_speculate());
@@ -341,20 +324,29 @@ struct server_slot {
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
const llama_tokens & tokens = prompt.tokens.get_text_tokens();
const llama_tokens & tokens_text = prompt.tokens.get_text_tokens();
const auto & params_spec = task->params.speculative;
if (!spec_draft.empty()) {
// we have a previous (partial) draft to reuse
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
if (use_ckpt_tgt) {
GGML_ASSERT(!spec_ckpt.empty());
}
} else {
GGML_ASSERT(spec_i_batch.empty());
spec_ckpt.update_pos(
prompt.n_tokens(),
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), id),
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), id));
if (use_ckpt_dft) {
spec_ckpt.update_dft(ctx_dft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
// generate a new draft
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens_text, prompt.n_tokens(), sampled);
n_draft_total += spec_draft.size();
if (spec_draft.size() > (size_t) n_draft_max) {
@@ -362,18 +354,25 @@ struct server_slot {
spec_draft.resize(n_draft_max);
}
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
const auto n_tokens = prompt.tokens.size();
if (!spec_draft.empty()) {
if (use_ckpt_tgt) {
//const int64_t t_start = ggml_time_us();
//const int64_t t_start = ggml_time_us();
spec_ckpt.update_tgt(ctx_tgt, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens, true);
//const int64_t t_total = ggml_time_us() - t_start;
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
//const int64_t t_total = ggml_time_us() - t_start;
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n",
spec_ckpt.pos_min, spec_ckpt.pos_max, prompt.n_tokens(), (float) spec_ckpt.size() / 1024 / 1024, (float) spec_ckpt.data_dft.size() / 1024 / 1024);
}
}
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
if (ctx_dft) {
spec_ckpt.load_dft(ctx_dft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(ctx_dft), this->id, spec_ckpt.pos_max + 1, -1);
}
}
@@ -539,7 +538,7 @@ struct server_slot {
};
if (!only_metrics) {
res["prompt"] = ptask->tokens.detokenize(ctx, true);
res["prompt"] = ptask->tokens.detokenize(ctx_tgt, true);
res["generated"] = generated_text.empty() ? debug_generated_text : generated_text;
}
}
@@ -550,8 +549,13 @@ struct server_slot {
void copy_state_to(server_slot & other) const {
GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);
llama_memory_seq_rm(llama_get_memory(ctx), other.id, -1, -1);
llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, -1, -1);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1);
llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1);
if (ctx_dft) {
llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1);
llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1);
}
other.n_decoded = n_decoded;
other.n_remaining = n_remaining;
@@ -642,7 +646,8 @@ public:
// only use these pointers outside of this class:
// - when not in sleeping state
// - and, with thread-safe APIs (e.g., tokenizer calls)
llama_model * model = nullptr;
llama_model * model_tgt = nullptr;
mtmd_context * mctx = nullptr;
const llama_vocab * vocab = nullptr;
@@ -669,11 +674,15 @@ private:
// note: keep these alive - they determine the lifetime of the model, context, etc.
common_init_result_ptr llama_init;
llama_context * ctx = nullptr;
llama_context * ctx_tgt = nullptr;
llama_batch batch {};
llama_model_ptr model_dft;
llama_context_ptr ctx_dft;
common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
bool add_bos_token = true;
@@ -708,8 +717,8 @@ private:
void destroy() {
llama_init.reset();
ctx = nullptr;
model = nullptr;
ctx_tgt = nullptr;
model_tgt = nullptr;
mtmd_free(mctx);
mctx = nullptr;
@@ -759,17 +768,17 @@ private:
llama_init = common_init_from_params(params_base);
model = llama_init->model();
ctx = llama_init->context();
model_tgt = llama_init->model();
ctx_tgt = llama_init->context();
if (model == nullptr) {
if (model_tgt == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
return false;
}
vocab = llama_model_get_vocab(model);
vocab = llama_model_get_vocab(model_tgt);
n_ctx = llama_n_ctx(ctx);
n_ctx = llama_n_ctx(ctx_tgt);
add_bos_token = llama_vocab_get_add_bos(vocab);
@@ -781,9 +790,6 @@ private:
auto params_dft = params_base;
params_dft.n_parallel = 1;
params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
params_dft.n_batch = llama_n_ctx_seq(ctx);
params_dft.devices = params_spec.devices;
params_dft.model = params_spec.mparams;
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
@@ -805,8 +811,13 @@ private:
return false;
}
params_base.speculative.draft.model = model_dft.get();
params_base.speculative.draft.cparams = common_context_params_to_llama(params_dft);
auto cparams = common_context_params_to_llama(params_dft);
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
}
std::string & mmproj_path = params_base.mmproj.path;
@@ -826,7 +837,7 @@ private:
mparams.image_max_tokens = params_base.image_max_tokens;
mparams.media_marker = get_media_marker();
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams);
if (mctx == nullptr) {
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
return false;
@@ -844,7 +855,7 @@ private:
}
}
if (!llama_memory_can_shift(llama_get_memory(ctx))) {
if (!llama_memory_can_shift(llama_get_memory(ctx_tgt))) {
if (params_base.ctx_shift) {
params_base.ctx_shift = false;
SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled");
@@ -856,14 +867,14 @@ private:
}
}
if (llama_model_n_swa(model) == 0) {
if (llama_model_n_swa(model_tgt) == 0) {
if (params_base.swa_full) {
params_base.swa_full = false;
SRV_WRN("%s\n", "swa_full is not supported by this model, it will be disabled");
}
}
n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model);
n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model_tgt);
// Necessary similarity of prompt for slot selection
slot_prompt_similarity = params_base.slot_prompt_similarity;
@@ -871,9 +882,9 @@ private:
// setup slots
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx_train = llama_model_n_ctx_train(model_tgt);
int n_ctx_slot = llama_n_ctx_seq(ctx);
int n_ctx_slot = llama_n_ctx_seq(ctx_tgt);
if (n_ctx_slot > n_ctx_train) {
SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
n_ctx_slot = n_ctx_train;
@@ -881,12 +892,12 @@ private:
slots.clear();
const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx);
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
ctx_tgt_seq_rm_type = common_context_can_seq_rm(ctx_tgt);
if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
SRV_WRN("%s", "speculative decoding will use checkpoints\n");
}
@@ -895,21 +906,28 @@ private:
slots.emplace_back();
}
bool no_dft = false;
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot & slot = slots[i];
slot.id = i;
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
slot.ctx_seq_rm_type = ctx_seq_rm_type;
slot.id = i;
slot.ctx_tgt = ctx_tgt;
slot.ctx_dft = ctx_dft.get();
slot.n_ctx = n_ctx_slot;
slot.mctx = mctx;
slot.prompt.tokens.has_mtmd = mctx != nullptr;
// try speculative decoding
if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
try {
slot.spec.reset(common_speculative_init(params_base.speculative, slot.id));
} catch (const std::exception & e) {
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
no_dft = true;
}
if (slot.spec) {
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
@@ -925,6 +943,16 @@ private:
slot.reset();
}
if (no_dft && ctx_dft) {
SRV_WRN("%s", "destroying the draft model as it is not going to be used\n");
ctx_dft.reset();
for (auto & slot : slots) {
slot.ctx_dft = nullptr;
}
}
{
const char * LLAMA_TRACE = getenv("LLAMA_TRACE");
trace = LLAMA_TRACE ? atoi(LLAMA_TRACE) : 0;
@@ -946,7 +974,7 @@ private:
// the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
{
const int32_t n_batch = llama_n_batch(ctx);
const int32_t n_batch = llama_n_batch(ctx_tgt);
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
@@ -990,8 +1018,9 @@ private:
// unlike load_model(), this is only called once during initialization
bool init() {
GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(model != nullptr);
GGML_ASSERT(ctx_tgt != nullptr);
GGML_ASSERT(model_tgt != nullptr);
GGML_ASSERT(!sleeping);
// wiring up server queues
@@ -1037,7 +1066,7 @@ private:
common_chat_templates_ptr chat_templates;
try {
chat_templates = common_chat_templates_init(model, params_base.chat_template);
chat_templates = common_chat_templates_init(model_tgt, params_base.chat_template);
LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
@@ -1300,7 +1329,7 @@ private:
}
}
if (!task.tokens.validate(ctx)) {
if (!task.tokens.validate(ctx_tgt)) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
}
@@ -1310,7 +1339,7 @@ private:
// initialize samplers
if (task.need_sampling()) {
try {
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
slot.smpl.reset(common_sampler_init(model_tgt, task.params.sampling));
} catch (std::exception & e) {
std::string err_msg = std::string("Failed to initialize samplers: ") + e.what();
send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST);
@@ -1331,9 +1360,9 @@ private:
// TODO: tmp until backend sampling is fully implemented
if (backend_sampling) {
llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
llama_set_sampler(ctx_tgt, slot.id, common_sampler_get(slot.smpl.get()));
} else {
llama_set_sampler(ctx, slot.id, nullptr);
llama_set_sampler(ctx_tgt, slot.id, nullptr);
}
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
@@ -1506,13 +1535,13 @@ private:
for (size_t i = 0; i < n_probs; i++) {
result.probs.push_back({
cur_p->data[i].id,
common_token_to_piece(ctx, cur_p->data[i].id, special),
common_token_to_piece(ctx_tgt, cur_p->data[i].id, special),
cur_p->data[i].p
});
}
} else {
// TODO: optimize this with min-p optimization
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
std::vector<llama_token_data> cur = get_token_probabilities(ctx_tgt, idx);
const size_t max_probs = cur.size();
const size_t n_probs = std::min(max_probs, n_probs_request);
@@ -1530,7 +1559,7 @@ private:
for (size_t i = 0; i < n_probs; i++) {
result.probs.push_back({
cur[i].id,
common_token_to_piece(ctx, cur[i].id, special),
common_token_to_piece(ctx_tgt, cur[i].id, special),
cur[i].p
});
}
@@ -1633,7 +1662,7 @@ private:
res->tokens = std::move(slot.generated_tokens);
}
res->timings = slot.get_timings();
res->prompt = slot.task->tokens.detokenize(ctx, true);
res->prompt = slot.task->tokens.detokenize(ctx_tgt, true);
res->response_fields = std::move(slot.task->params.response_fields);
res->truncated = slot.truncated;
@@ -1656,7 +1685,7 @@ private:
// populate res.probs_output
if (slot.task->params.sampling.n_probs > 0) {
if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) {
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
const llama_tokens stop_word_toks = common_tokenize(ctx_tgt, slot.stopping_word, false);
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
res->probs_output = std::vector<completion_token_output>(
@@ -1681,7 +1710,7 @@ private:
res->n_tokens = slot.task->n_tokens();
res->res_type = slot.task->params.res_type;
const int n_embd_out = llama_model_n_embd_out(model);
const int n_embd_out = llama_model_n_embd_out(model_tgt);
std::vector<float> embd_res(n_embd_out, 0.0f);
@@ -1691,10 +1720,10 @@ private:
}
const float * embd = nullptr;
if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
embd = llama_get_embeddings_ith(ctx, i);
if (llama_pooling_type(slot.ctx_tgt) == LLAMA_POOLING_TYPE_NONE) {
embd = llama_get_embeddings_ith(slot.ctx_tgt, i);
} else {
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
embd = llama_get_embeddings_seq(slot.ctx_tgt, batch.seq_id[i][0]);
}
if (embd == nullptr) {
@@ -1705,7 +1734,7 @@ private:
}
// normalize only when there is pooling
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
if (llama_pooling_type(slot.ctx_tgt) != LLAMA_POOLING_TYPE_NONE) {
common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize);
res->embedding.push_back(embd_res);
break;
@@ -1730,9 +1759,9 @@ private:
continue;
}
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
const float * embd = llama_get_embeddings_seq(ctx_tgt, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
embd = llama_get_embeddings_ith(ctx_tgt, i);
}
if (embd == NULL) {
@@ -1837,18 +1866,22 @@ private:
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
auto & cur = slot.prompt.checkpoints.emplace_back();
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max);
cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_WRN(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
}
void process_single_task(server_task && task) {
@@ -2003,7 +2036,7 @@ private:
std::string filepath = task.slot_action.filepath;
const llama_tokens & tokens = slot->prompt.tokens.get_tokens();
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
const size_t nwrite = llama_state_seq_save_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), token_count);
const int64_t t_end = ggml_time_us();
const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -2042,7 +2075,7 @@ private:
llama_tokens tokens;
tokens.resize(slot->n_ctx);
size_t token_count = 0;
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
size_t nread = llama_state_seq_load_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
if (nread == 0) {
slot->prompt.tokens.clear(); // KV may already been invalidated?
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
@@ -2201,8 +2234,13 @@ private:
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard);
llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard);
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
if (ctx_dft) {
llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard);
llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
}
// add generated tokens to cache
// ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
@@ -2248,12 +2286,14 @@ private:
continue;
}
slot.update_batch(batch);
slot.update_batch(batch,
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL,
ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
}
// process in chunks of params.n_batch
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);
int32_t n_batch = llama_n_batch(ctx_tgt);
int32_t n_ubatch = llama_n_ubatch(ctx_tgt);
float alora_scale = -1.0f;
size_t alora_disabled_id = 0;
@@ -2297,12 +2337,12 @@ private:
/*if (1) {
// first 16 tokens (avoid flooding logs)
for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str());
}
} else {
// all
for (int i = 0; i < (int) input_tokens.size(); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str());
}
}*/
@@ -2321,7 +2361,7 @@ private:
}
// TODO: support memory-less logits computation
if (slot.task->need_logits() && !llama_get_memory(ctx)) {
if (slot.task->need_logits() && !llama_get_memory(ctx_tgt)) {
send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
slot.release();
continue;
@@ -2373,7 +2413,7 @@ private:
const auto n_cache_reuse = slot.task->params.n_cache_reuse;
const bool can_cache_reuse =
llama_memory_can_shift(llama_get_memory(ctx)) &&
llama_memory_can_shift(llama_get_memory(ctx_tgt)) &&
!slot.prompt.tokens.has_mtmd;
if (!can_cache_reuse && n_cache_reuse > 0) {
@@ -2407,13 +2447,18 @@ private:
if (n_match >= (size_t) n_cache_reuse) {
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
//for (size_t i = head_p; i < head_p + n_match; i++) {
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx_tgt, prompt_tokens[i]).c_str());
//}
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c);
llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift);
llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c);
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift);
if (ctx_dft) {
llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c);
llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift);
}
for (size_t i = 0; i < n_match; i++) {
slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
@@ -2440,7 +2485,7 @@ private:
const auto pos_min_thold = std::max(0, pos_next - n_swa);
if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
if (pos_min == -1) {
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
@@ -2469,14 +2514,14 @@ private:
{
const auto token = slot.prompt.tokens[i];
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]";
ss0 << piece;
st0 << std::setw(8) << token;
}
{
const auto token = slot.task->tokens[i];
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]";
ss1 << piece;
st1 << std::setw(8) << token;
}
@@ -2508,18 +2553,13 @@ private:
if (!do_reset) {
// restore the context checkpoint
const size_t checkpoint_size = it->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) checkpoint_size / 1024 / 1024);
}
it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024);
}
if (do_reset) {
@@ -2536,7 +2576,7 @@ private:
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_max > pos_next) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
it = slot.prompt.checkpoints.erase(it);
} else {
++it;
@@ -2576,14 +2616,18 @@ private:
SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
slot.prompt_clear(true);
// there is no common part left
slot.n_prompt_tokens_cache = 0;
}
} else {
if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) {
GGML_ABORT("failed to truncate draft context\n");
}
}
// If using an alora, there may be uncached tokens that come
// before the invocation sequence. When this happens, the
@@ -2609,7 +2653,7 @@ private:
// - the model does not support partial sequence removal
// - the model uses SWA (and we are not using `swa_full`)
do_checkpoint = do_checkpoint && (
(slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
(n_swa > 0));
bool has_mtmd = false;
@@ -2618,7 +2662,7 @@ private:
while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
// process the image
size_t n_tokens_out = 0;
int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
int32_t res = input_tokens.process_chunk(ctx_tgt, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
if (res != 0) {
SLT_ERR(slot, "failed to process image, res = %d\n", res);
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
@@ -2626,6 +2670,15 @@ private:
continue;
}
if (ctx_dft) {
// TODO: in the future, figure out how to infuse target embeddings to the images
// for now, we skip this for simplicity
res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
if (res != 0) {
GGML_ABORT("failed to process multi-modal data on draft context\n");
}
}
slot.n_prompt_tokens_processed += n_tokens_out;
// add the image chunk to cache
@@ -2727,8 +2780,8 @@ private:
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
}
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
// no need for empty or small checkpoints
do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64);
@@ -2761,7 +2814,7 @@ private:
if (slot_batched) {
// apply lora, only need to do it once per batch
common_set_adapter_lora(ctx, slot_batched->lora);
common_set_adapter_lora(ctx_tgt, slot_batched->lora);
// if the lora is temporarily disabled for an alora, re-enable it
// for next time
@@ -2770,7 +2823,7 @@ private:
slot_batched->lora[alora_disabled_id].scale = alora_scale;
}
llama_set_embeddings(ctx, slot_batched->task->need_embd());
llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd());
}
if (batch.n_tokens == 0) {
@@ -2799,7 +2852,7 @@ private:
batch.logits + i,
};
const int ret = llama_decode(ctx, batch_view);
const int ret = llama_decode(ctx_tgt, batch_view);
metrics.on_decoded(slots);
@@ -2852,11 +2905,54 @@ private:
continue; // continue loop of n_batch
}
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
// for now, always re-evaluate for simplicity
// ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384
//
// | spec type | need re-eval |
// | --- | --- |
// | draft model | no | because the draft model does not use embeddings from the target
// | MTP (std) | yes |
// | MTP Gemma4 | no | because the KV cache is shared
// | Eagle3 | yes |
// | DFlash | yes? |
//
if (ctx_dft) {
// TODO: update as needed for MTP, Eagle3, etc.
const bool need_tgt_embd = false;
if (need_tgt_embd) {
llama_synchronize(ctx_tgt);
}
// the logic here varies depending on the speculative decoding method
// - some draft contexts require embeddings from the target context, others don't
// - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings
// TODO: extract this in a function ?
{
// TODO: hook the embeddings from the last target batch here
if (llama_model_has_encoder(model_dft.get())) {
//llama_encode(ctx_dft, ...);
GGML_ABORT("not implemented yet\n");
}
const int ret = llama_decode(ctx_dft.get(), batch_view);
if (ret != 0) {
SRV_ERR("failed to decode draft batch, ret = %d\n", ret);
// TODO: handle error
break;
}
}
}
// move the head of the batch forward with the number of tokens we just processed
i_next = i + n_tokens;
// on successful decode, restore the original batch size
n_batch = llama_n_batch(ctx);
n_batch = llama_n_batch(ctx_tgt);
// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
for (auto & slot : slots) {
@@ -2927,7 +3023,7 @@ private:
const int tok_idx = slot.i_batch - i;
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx);
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
slot.i_batch = -1;
@@ -2948,7 +3044,7 @@ private:
completion_token_output result;
result.tok = id;
result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok));
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
if (slot.task->params.sampling.n_probs > 0) {
@@ -2979,23 +3075,23 @@ private:
// verify and try to accept the draft
{
const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
// only save the sampler sampler state if we use checkpoints
common_sampler_ptr smpl_save;
if (use_ckpt) {
if (use_ckpt_tgt) {
smpl_save.reset(common_sampler_clone(slot.smpl.get()));
}
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft);
slot.spec_i_batch.clear();
GGML_ASSERT(accepted.size() >= 1);
// check for partial draft acceptance
if (accepted.size() < slot.spec_draft.size() + 1) {
if (use_ckpt) {
if (use_ckpt_tgt) {
if (trace > 0) {
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
}
@@ -3005,16 +3101,19 @@ private:
const auto & ckpt = slot.spec_ckpt;
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
ckpt.pos_min, ckpt.pos_max, ckpt.size());
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
{
ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1);
}
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1);
if (slot.ctx_dft) {
ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1);
}
slot.prompt.tokens.keep_first(ckpt.n_tokens);
slot.smpl = std::move(smpl_save);
@@ -3049,13 +3148,16 @@ private:
slot.sampled = ids.back(); // last accepted token
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.tokens.pos_next(), -1);
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1);
if (slot.ctx_dft) {
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1);
}
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok));
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // set later
// TODO: set result.probs
@@ -3107,7 +3209,7 @@ void server_context::terminate() {
}
llama_context * server_context::get_llama_context() const {
return impl->ctx;
return impl->ctx_tgt;
}
server_response_reader server_context::get_response_reader() {
@@ -3117,8 +3219,8 @@ server_response_reader server_context::get_response_reader() {
server_context_meta server_context::get_meta() const {
auto bos_id = llama_vocab_bos(impl->vocab);
auto eos_id = llama_vocab_eos(impl->vocab);
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : "";
auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, eos_id, true) : "";
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, bos_id, true) : "";
auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, eos_id, true) : "";
return server_context_meta {
/* build_info */ std::string(llama_build_info()),
@@ -3131,7 +3233,7 @@ server_context_meta server_context::get_meta() const {
/* has_inp_audio */ impl->chat_params.allow_audio,
/* json_webui_settings */ impl->json_webui_settings,
/* slot_n_ctx */ impl->get_slot_n_ctx(),
/* pooling_type */ llama_pooling_type(impl->ctx),
/* pooling_type */ llama_pooling_type(impl->ctx_tgt),
/* chat_params */ impl->chat_params,
/* chat_template_caps */ common_chat_templates_get_caps(impl->chat_params.tmpls.get()),
@@ -3149,10 +3251,10 @@ server_context_meta server_context::get_meta() const {
/* model_vocab_type */ llama_vocab_type(impl->vocab),
/* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model),
/* model_n_embd_inp */ llama_model_n_embd(impl->model),
/* model_n_params */ llama_model_n_params(impl->model),
/* model_size */ llama_model_size(impl->model),
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model_tgt),
/* model_n_embd_inp */ llama_model_n_embd(impl->model_tgt),
/* model_n_params */ llama_model_n_params(impl->model_tgt),
/* model_size */ llama_model_size(impl->model_tgt),
};
}
@@ -4054,7 +4156,7 @@ void server_routes::init_routes() {
std::vector<server_task> tasks;
tasks.reserve(documents.size());
for (size_t i = 0; i < documents.size(); i++) {
auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
auto tmp = format_prompt_rerank(ctx_server.model_tgt, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = rd.get_new_id();
task.tokens = std::move(tmp);

View File

@@ -1981,7 +1981,7 @@ size_t server_prompt_cache::n_tokens() const {
return res;
}
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) {
// first check if the current state is contained fully in the cache
for (auto it = states.begin(); it != states.end(); ++it) {
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
@@ -2005,11 +2005,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
}
}
std::vector<uint8_t> state_data;
std::vector<uint8_t> state_data_tgt;
std::vector<uint8_t> state_data_dft;
// check if we can allocate enough memory for the new state
try {
state_data.resize(state_size);
state_data_tgt.resize(state_size_tgt);
state_data_dft.resize(state_size_dft);
} catch (const std::bad_alloc & e) {
SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
@@ -2022,17 +2024,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
return nullptr;
}
auto & cur = states.emplace_back();
cur = {
states.push_back({
/*.tokens =*/ prompt.tokens.clone(),
/*.data =*/ std::move(state_data),
/*.data =*/ {
/*.main =*/ std::move(state_data_tgt),
/*.drft =*/ std::move(state_data_dft),
},
/*.checkpoints =*/ prompt.checkpoints,
};
});
return &cur;
return &states.back();
}
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) {
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
@@ -2065,16 +2069,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
if (it_best != states.end()) {
SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
const size_t size = it_best->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
{
auto & data = it_best->data.main;
return false;
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
return false;
}
data.clear();
data.shrink_to_fit();
}
it_best->data.clear();
it_best->data.shrink_to_fit();
{
auto & data = it_best->data.drft;
if (!data.empty()) {
GGML_ASSERT(ctx_dft);
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
return false;
}
data.clear();
data.shrink_to_fit();
}
}
prompt = std::move(*it_best);

View File

@@ -565,42 +565,29 @@ struct server_task_result_apply_lora : server_task_result {
virtual json to_json() override;
};
struct server_prompt_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
int64_t n_tokens;
std::vector<uint8_t> data;
struct server_prompt_data {
std::vector<uint8_t> main;
std::vector<uint8_t> drft;
size_t size() const {
return data.size();
}
bool empty() const {
return data.empty();
}
void clear() {
pos_min = 0;
pos_max = 0;
n_tokens = 0;
data.clear();
return main.size() + drft.size();
}
};
struct server_prompt {
server_tokens tokens;
std::vector<uint8_t> data;
server_prompt_data data;
std::list<server_prompt_checkpoint> checkpoints;
std::list<common_prompt_checkpoint> checkpoints;
size_t size() const {
size_t res = data.size();
size_t res = 0;
for (const auto & checkpoint : checkpoints) {
res += checkpoint.size();
res += data.size();
for (const auto & ckpt : checkpoints) {
res += ckpt.size();
}
return res;
@@ -614,7 +601,7 @@ struct server_prompt {
return server_prompt {
tokens.clone(),
data,
checkpoints
checkpoints,
};
}
};
@@ -637,9 +624,9 @@ struct server_prompt_cache {
size_t n_tokens() const;
server_prompt * alloc(const server_prompt & prompt, size_t state_size);
server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft);
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot);
void update();
};

View File

@@ -5,7 +5,7 @@ from utils import *
server = ServerPreset.stories15m_moe()
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/tiny-llamas/resolve/main/stories15M-q4_0.gguf"
def create_server():
global server