spec : parallel drafting support (#22838)

* spec : refactor

* spec : drop support for incompatible vocabs

* spec : update common_speculative_init()

* cont : pass seq_id

* cont : dedup ctx_seq_rm_type

* server : sketch the ctx_dft decode loop

* server : draft prompt cache and checkpoints

* server : improve ctx names

* server, spec : transition to unified spec context

* cont : sync main and drft contexts

* cont : async drft eval when possible

* cont : handle non-ckpt models

* cont : pass correct n_past for drafting

* cont : process images through the draft context

* spec : handle draft running out of context

* server : fix mtmd draft processing

* server : fix URL for draft model

* server : add comment

* server : clean-up + dry

* speculative-simple : update

* spec : fix n_past type

* server : fix slot ctx_drft ptr

* tools : update readme

* naming : improve consistency

* spec : refactor for multi-sequence speculative context

* cont : prepare params

* cont : prepare params

* spec : support parallel drafts

* server : support parallel drafting

* llama : reuse device buffers when possible

* server, spec : clean-up

* cont : clean-up

* cont : minor

* spec : reset `drafting` flag at the end

* spec : introduce `common_speculative_process()`

* spec : allow for multiple spec types (chain of speculators)

* replace the old `type` field (of type common_speculative_type) in the
  common_params_speculative struct with a vector to allow multiple
  types to be specified

* introduce common_get_enabled_speculative_impls(const std::vector<enum common_speculative_type>)
  to figure out which implementations the user has enabled

* introduce common_speculative_types_from_names(const std::vector<std::string> & names)
  to parse the user-provided spec type names

* all speculators run sequentially; the best one wins (we verify its drafted tokens)

* maximize the expected number of accepted tokens for the current round by
  computing the product of the per-token acceptance probability
  (n_acc_tokens / n_gen_drafts) and the draft's length (see the sketch below)
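
A minimal sketch of this selection rule; the `impl_stats` and `draft_candidate`
containers below are illustrative assumptions, not the actual structs used in
the implementation:

    // Sketch only: choose the speculator whose draft maximizes the expected
    // number of accepted tokens, i.e. (n_acc_tokens / n_gen_drafts) * draft length.
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    struct impl_stats {
        size_t n_acc_tokens = 0; // tokens accepted by the target model so far
        size_t n_gen_drafts = 0; // drafted tokens generated so far
    };

    struct draft_candidate {
        impl_stats           stats;
        std::vector<int32_t> tokens; // draft proposed by this speculator
    };

    static size_t pick_best_draft(const std::vector<draft_candidate> & cands) {
        size_t best       = 0;
        double best_score = -1.0;
        for (size_t i = 0; i < cands.size(); ++i) {
            const auto & c = cands[i];
            const double p_acc = c.stats.n_gen_drafts > 0
                ? (double) c.stats.n_acc_tokens / (double) c.stats.n_gen_drafts
                : 0.0;
            const double score = p_acc * (double) c.tokens.size();
            if (score > best_score) {
                best_score = score;
                best       = i;
            }
        }
        return best;
    }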

---------

Co-authored-by: Petros Sideris <petros.sideris@nokia.com>
Author: Georgi Gerganov
Date: 2026-05-11 19:09:43 +03:00
Committed by: GitHub
Parent: 928b486b0c
Commit: 68e7ea3eab
14 changed files with 1286 additions and 1071 deletions

View File

@@ -622,10 +622,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
string_process_escapes(seq_breaker);
}
for (auto & pair : params.speculative.draft.replacements) {
string_process_escapes(pair.first);
string_process_escapes(pair.second);
}
}
if (!params.kv_overrides.empty()) {
@@ -3518,13 +3514,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.draft.p_min = std::stof(value);
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN"));
add_opt(common_arg(
{"--spec-draft-ctx-size", "-cd", "--ctx-size-draft"}, "N",
string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.draft.n_ctx),
[](common_params & params, int value) {
params.speculative.draft.n_ctx = value;
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_CTX_SIZE"));
add_opt(common_arg(
{"--spec-draft-device", "-devd", "--device-draft"}, "<dev1,dev2,..>",
"comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"
@@ -3561,32 +3550,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
add_opt(common_arg(
{"--spec-draft-replace", "--spec-replace"}, "TARGET", "DRAFT",
"translate the string in TARGET into DRAFT if the draft model and main model are not compatible",
[](common_params & params, const std::string & tgt, const std::string & dft) {
params.speculative.draft.replacements.push_back({ tgt, dft });
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
{"--spec-type"}, common_speculative_all_types_str(),
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
common_speculative_type_to_str(params.speculative.type).c_str()),
common_speculative_type_name_str(params.speculative.types).c_str()),
[](common_params & params, const std::string & value) {
if (value == "none") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
} else if (value == "ngram-cache") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
} else if (value == "ngram-simple") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
} else if (value == "ngram-map-k") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
} else if (value == "ngram-map-k4v") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
} else if (value == "ngram-mod") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
} else {
throw std::invalid_argument("unknown speculative decoding type without draft model");
}
const auto enabled_types = string_split<std::string>(value, ',');
params.speculative.types = common_speculative_types_from_names(enabled_types);
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_TYPE"));
add_opt(common_arg(
@@ -4075,7 +4044,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--spec-default"},
string_format("enable default speculative decoding config"),
[](common_params & params) {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
params.speculative.types = { COMMON_SPECULATIVE_TYPE_NGRAM_MOD };
params.speculative.ngram_mod.n_match = 24;
params.speculative.ngram_mod.n_min = 48;
params.speculative.ngram_mod.n_max = 64;

View File

@@ -1422,7 +1422,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
LOG_WRN("%s: the context does not support partial sequence removal\n", __func__);
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
goto done;
}
@@ -1960,3 +1960,102 @@ bool common_prompt_batch_decode(
return true;
}
size_t common_prompt_checkpoint::size() const {
return data_tgt.size() + data_dft.size();
}
bool common_prompt_checkpoint::empty() const {
return data_tgt.empty();
}
void common_prompt_checkpoint::clear() {
n_tokens = 0;
pos_min = 0;
pos_max = 0;
data_tgt.clear();
data_dft.clear();
}
void common_prompt_checkpoint::update_pos(
int64_t n_tokens,
llama_pos pos_min,
llama_pos pos_max) {
this->n_tokens = n_tokens;
this->pos_min = pos_min;
this->pos_max = pos_max;
}
void common_prompt_checkpoint::update_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) {
if (ctx == nullptr) {
return;
}
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
data_tgt.resize(ckpt_size);
const size_t n = llama_state_seq_get_data_ext(ctx, data_tgt.data(), ckpt_size, seq_id, flags);
if (n != ckpt_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
}
}
void common_prompt_checkpoint::update_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) {
if (ctx == nullptr) {
return;
}
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
data_dft.resize(ckpt_size);
const size_t n = llama_state_seq_get_data_ext(ctx, data_dft.data(), ckpt_size, seq_id, flags);
if (n != ckpt_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
}
}
void common_prompt_checkpoint::load_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const {
if (ctx == nullptr) {
return;
}
if (data_tgt.empty()) {
return;
}
const size_t n = llama_state_seq_set_data_ext(ctx, data_tgt.data(), data_tgt.size(), seq_id, flags);
if (n != data_tgt.size()) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_tgt.size(), n);
}
}
void common_prompt_checkpoint::load_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const {
if (ctx == nullptr) {
return;
}
if (data_dft.empty()) {
return;
}
const size_t n = llama_state_seq_set_data_ext(ctx, data_dft.data(), data_dft.size(), seq_id, flags);
if (n != data_dft.size()) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n);
}
}
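
For orientation, a rough usage sketch of these checkpoint helpers, condensed from
the speculative-simple loop further down in this commit (single sequence, error
handling omitted; `ctx_tgt`/`ctx_dft` are the target and draft llama_context pointers):

    // save a checkpoint before evaluating a new draft
    common_prompt_checkpoint ckpt;
    ckpt.update_pos(prompt_tgt.size(),
            llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), seq_id),
            llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id));
    ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
    ckpt.update_dft(ctx_dft, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);

    // ... decode [id_last, draft...] with the target and draft contexts ...

    // on partial acceptance, restore both contexts and drop the extra tokens
    ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
    llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, ckpt.pos_max + 1, -1);
    ckpt.load_dft(ctx_dft, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
    llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, ckpt.pos_max + 1, -1);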

View File

@@ -295,8 +295,6 @@ struct common_params_model {
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
};
struct common_ngram_mod;
// draft-model-based speculative decoding parameters
struct common_params_speculative_draft {
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
@@ -307,11 +305,9 @@ struct common_params_speculative_draft {
common_params_model mparams;
llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts
llama_context * ctx_tgt = nullptr;
llama_context * ctx_dft = nullptr;
llama_context_params cparams; // these are the parameters for the draft llama_context
int32_t n_ctx = 0; // draft context size
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
@@ -322,7 +318,6 @@ struct common_params_speculative_draft {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
};
@@ -331,9 +326,6 @@ struct common_params_speculative_ngram_mod {
int32_t n_max = 64;
int32_t n_min = 48;
// shared instance of the ngram container for all speculative decoding contexts
std::shared_ptr<common_ngram_mod> obj;
};
struct common_params_speculative_ngram_map {
@@ -348,8 +340,7 @@ struct common_params_speculative_ngram_cache {
};
struct common_params_speculative {
// TODO: become a vector in order to support "chains of speculators"
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
std::vector<enum common_speculative_type> types = { COMMON_SPECULATIVE_TYPE_NONE };
common_params_speculative_draft draft;
@@ -1026,3 +1017,47 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
// "adamw" or "sgd" (case insensitive)
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
//
// prompt utils
//
struct common_prompt_checkpoint {
int64_t n_tokens;
llama_pos pos_min;
llama_pos pos_max;
std::vector<uint8_t> data_tgt;
std::vector<uint8_t> data_dft;
size_t size() const;
bool empty() const;
void clear();
void update_pos(
int64_t n_tokens,
llama_pos pos_min,
llama_pos pos_max);
void update_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags);
void update_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags);
void load_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const;
void load_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const;
};

File diff suppressed because it is too large

View File

@@ -5,8 +5,14 @@
struct common_speculative;
// comma separated list of the provided types
std::string common_speculative_type_name_str(const std::vector<enum common_speculative_type> & types);
// comma separated list of all types
std::string common_speculative_type_name_str();
const char * common_speculative_all_types_str();
// parse user provided types
std::vector<enum common_speculative_type> common_speculative_types_from_names(const std::vector<std::string> & names);
// convert string to type
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
@@ -14,27 +20,44 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt);
common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq);
void common_speculative_free(common_speculative * spec);
struct common_speculative_draft_params {
// this flag is used to chain the drafts through all the available implementations
// after the first successful draft from an implementation, we set it
// to false to prevent further drafts for that sequence
// at the end of the draft() call, all drafting flags will be reset to false
bool drafting = false;
// overrides individual configurations (-1 disabled)
// can be used to constrain the max draft based on the remaining context size
int32_t n_max = -1;
llama_pos n_past;
llama_token id_last;
// TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls
const llama_tokens * prompt;
// the generated draft from the last _draft() call
llama_tokens * result;
};
common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id);
// optionally call once at the beginning of a new generation
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt);
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_draft(
common_speculative * spec,
const common_params_speculative & params,
const llama_tokens & prompt,
llama_token id_last);
// process the batch and update the internal state of the speculative context
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
// informs the speculative decoder that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
void common_speculative_draft(common_speculative * spec);
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params);
int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params);
// informs the speculative context that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);
// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);
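
A condensed usage sketch of the reworked per-sequence interface, pieced together
from the speculative-simple changes below (single sequence; error handling and the
target-side verification loop are omitted):

    struct common_speculative * spec = common_speculative_init(params.speculative, /* n_seq */ 1);
    common_speculative_begin(spec, seq_id, prompt_tgt);

    llama_tokens draft;
    common_speculative_get_draft_params(spec, seq_id) = {
        /* .drafting = */ true,
        /* .n_max    = */ -1,
        /* .n_past   = */ n_past,
        /* .id_last  = */ id_last,
        /* .prompt   = */ &prompt_tgt,
        /* .result   = */ &draft, // output
    };

    common_speculative_draft(spec); // fills `draft` for every sequence with drafting == true

    // ... the target model verifies [id_last, draft...] and samples `ids` ...

    common_speculative_accept(spec, seq_id, ids.size() - 1);
    common_speculative_free(spec);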

View File

@@ -13,20 +13,6 @@
#include <vector>
#include <utility>
struct spec_checkpoint {
int64_t n_tokens = 0;
std::vector<uint8_t> data;
size_t size() const {
return data.size();
}
bool empty() const {
return data.empty();
}
};
int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
@@ -43,11 +29,6 @@ int main(int argc, char ** argv) {
return 1;
}
if (params.speculative.draft.mparams.path.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
// init llama.cpp
llama_backend_init();
llama_numa_init(params.numa);
@@ -62,18 +43,11 @@ int main(int argc, char ** argv) {
model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt->context();
// check if the context supports partial sequence removal
const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt);
const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
if (use_ckpt) {
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
}
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
// load the draft model
llama_model_ptr model_dft;
llama_context_ptr ctx_dft;
// TODO: simplify this logic
{
@@ -81,9 +55,6 @@ int main(int argc, char ** argv) {
auto params_dft = params;
params_dft.n_parallel = 1;
params_dft.n_ctx = params_spec.n_ctx;
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
params_dft.devices = params_spec.devices;
params_dft.model = params_spec.mparams;
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
@@ -103,8 +74,19 @@ int main(int argc, char ** argv) {
return 1;
}
params.speculative.draft.model = model_dft.get();
params.speculative.draft.cparams = common_context_params_to_llama(params_dft);
auto cparams = common_context_params_to_llama(params_dft);
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
params.speculative.draft.ctx_tgt = ctx_tgt;
params.speculative.draft.ctx_dft = ctx_dft.get();
}
// check if the context supports partial sequence removal
const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
if (use_ckpt_tgt) {
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
}
// Tokenize the prompt
@@ -136,6 +118,8 @@ int main(int argc, char ** argv) {
// used to determine end of generation
bool has_eos = false;
llama_seq_id seq_id = 0;
// ================================================
// everything until here is standard initialization
// the relevant stuff for speculative decoding starts here
@@ -146,7 +130,8 @@ int main(int argc, char ** argv) {
common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling));
// eval the prompt
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
llama_decode(ctx_dft.get(), llama_batch_get_one(inp.data(), inp.size() - 1));
// note: keep the last token separate!
llama_token id_last = inp.back();
@@ -160,16 +145,16 @@ int main(int argc, char ** argv) {
// init the speculator
const auto & params_spec = params.speculative;
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
struct common_speculative * spec = common_speculative_init(params.speculative, 1);
common_speculative_begin(spec, prompt_tgt);
common_speculative_begin(spec, seq_id, prompt_tgt);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
size_t n_draft = 0;
llama_tokens draft;
spec_checkpoint spec_ckpt;
common_prompt_checkpoint ckpt;
const auto t_enc_end = ggml_time_us();
@@ -184,40 +169,57 @@ int main(int argc, char ** argv) {
// from a cache or lookup tables.
//
if (draft.empty()) {
ckpt.update_pos(
prompt_tgt.size(),
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), seq_id),
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id));
if (use_ckpt_dft) {
ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
// generate a new draft
draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
common_speculative_get_draft_params(spec, seq_id) = {
/* .drafting = */ true,
/* .n_max = */ -1,
/* .n_past = */ n_past,
/* .id_last = */ id_last,
/* .prompt = */ &prompt_tgt,
/* .result = */ &draft, // output
};
common_speculative_draft(spec);
// save the original draft size
n_draft = draft.size();
// save a checkpoint of the target context before evaluating the draft
// this allows us to restore the state if partial draft acceptance occurs
if (!draft.empty() && use_ckpt) {
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
spec_ckpt.data.resize(ckpt_size);
if (!draft.empty()) {
if (use_ckpt_tgt) {
ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
}
const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
GGML_ASSERT(n == ckpt_size);
{
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
spec_ckpt.n_tokens = (int64_t) prompt_tgt.size();
LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n",
spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
}
} else {
// we have a previous (partial) draft to reuse from checkpoint restoration
if (use_ckpt) {
GGML_ASSERT(!spec_ckpt.empty());
if (use_ckpt_tgt) {
GGML_ASSERT(!ckpt.empty());
}
}
// always have a token to evaluate from before - id_last
common_batch_clear(batch_tgt);
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
common_batch_add (batch_tgt, id_last, n_past++, { seq_id }, true);
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
common_batch_add(batch_tgt, draft[i], n_past + i, { seq_id }, true);
}
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
@@ -225,9 +227,15 @@ int main(int argc, char ** argv) {
llama_decode(ctx_tgt, batch_tgt);
}
// evaluate the same batch with the draft model
{
// TODO: extend to support MTP, Eagle, etc. See server code for reference
llama_decode(ctx_dft.get(), batch_tgt);
}
// only save the sampler state if we use checkpoints
common_sampler_ptr smpl_save;
if (use_ckpt) {
if (use_ckpt_tgt) {
smpl_save.reset(common_sampler_clone(smpl.get()));
}
@@ -247,17 +255,24 @@ int main(int argc, char ** argv) {
// check for partial draft acceptance:
// if the context doesn't support partial sequence removal, restore the checkpoint
// and make the accepted tokens the new partial draft for the next iteration
if (use_ckpt && ids.size() - 1 < draft.size()) {
if (use_ckpt_tgt && ids.size() - 1 < draft.size()) {
LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size());
draft = std::move(ids);
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
GGML_ASSERT(n == spec_ckpt.size());
{
ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, ckpt.pos_max + 1, -1);
}
prompt_tgt.resize(spec_ckpt.n_tokens);
{
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
}
prompt_tgt.resize(ckpt.n_tokens);
smpl = std::move(smpl_save);
n_past = (int) prompt_tgt.size();
@@ -265,7 +280,7 @@ int main(int argc, char ** argv) {
continue;
}
common_speculative_accept(spec, ids.size() - 1);
common_speculative_accept(spec, seq_id, ids.size() - 1);
// full acceptance: consume the draft and commit accepted tokens
n_past += ids.size() - 1;
@@ -305,7 +320,8 @@ int main(int argc, char ** argv) {
{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, n_past, -1);
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, n_past, -1);
}
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {

View File

@@ -858,6 +858,8 @@ extern "C" {
size_t n_token_capacity,
size_t * n_token_count_out);
#define LLAMA_STATE_SEQ_FLAGS_NONE 0
// for backwards-compat
#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1

View File

@@ -2475,11 +2475,29 @@ public:
}
if (need_alloc) {
mbuf_cur = std::move(mbuf);
if (!mbuf_cur.buf || mbuf_cur.total_size != mbuf.total_size) {
mbuf_cur = std::move(mbuf);
mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft));
mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft));
LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0);
LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0);
} else {
//LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0);
// save the old buffer and allocate the new tensors in it
auto buf = std::move(mbuf_cur.buf);
mbuf_cur = std::move(mbuf);
ggml_tallocr talloc = ggml_tallocr_new(buf.get());
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
ggml_backend_view_init(mbuf_cur.org[i]);
ggml_tallocr_alloc(&talloc, mbuf_cur.cpy[i]);
}
mbuf_cur.buf = std::move(buf);
}
}
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
@@ -2559,8 +2577,7 @@ public:
mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset));
auto & view = mbuf.org.back();
view->buffer = rinfo.tensor->buffer;
ggml_backend_view_init(mbuf.org.back());
}
for (auto & [buft, mbuf] : mbufs_new) {

View File

@@ -195,11 +195,9 @@
| `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
| `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
| `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |

View File

@@ -244,11 +244,9 @@ For the full list of features, please refer to [server's changelog](https://gith
| `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
| `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
| `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |

File diff suppressed because it is too large

View File

@@ -76,7 +76,7 @@ json task_params::to_json(bool only_metrics) const {
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
{"generation_prompt", chat_parser_params.generation_prompt},
{"samplers", samplers},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.types", common_speculative_type_name_str(speculative.types)},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@@ -133,7 +133,7 @@ json task_params::to_json(bool only_metrics) const {
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
{"generation_prompt", chat_parser_params.generation_prompt},
{"samplers", samplers},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.types", common_speculative_type_name_str(speculative.types)},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@@ -296,6 +296,8 @@ task_params server_task::params_from_json_cmpl(
params.speculative = defaults.speculative;
// TODO: to keep things simple, we disable speculative parameter adjustments for now
#if 0
// TODO: for now, be able to adjust only the draft-model based speculative parameters
params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min);
params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max);
@@ -305,7 +307,6 @@ task_params server_task::params_from_json_cmpl(
params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0);
params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0);
#if 0
// for debugging and research purposes
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
@@ -1981,7 +1982,7 @@ size_t server_prompt_cache::n_tokens() const {
return res;
}
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) {
// first check if the current state is contained fully in the cache
for (auto it = states.begin(); it != states.end(); ++it) {
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
@@ -2005,11 +2006,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
}
}
std::vector<uint8_t> state_data;
std::vector<uint8_t> state_data_tgt;
std::vector<uint8_t> state_data_dft;
// check if we can allocate enough memory for the new state
try {
state_data.resize(state_size);
state_data_tgt.resize(state_size_tgt);
state_data_dft.resize(state_size_dft);
} catch (const std::bad_alloc & e) {
SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
@@ -2022,17 +2025,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
return nullptr;
}
auto & cur = states.emplace_back();
cur = {
states.push_back({
/*.tokens =*/ prompt.tokens.clone(),
/*.data =*/ std::move(state_data),
/*.data =*/ {
/*.main =*/ std::move(state_data_tgt),
/*.drft =*/ std::move(state_data_dft),
},
/*.checkpoints =*/ prompt.checkpoints,
};
});
return &cur;
return &states.back();
}
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) {
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
@@ -2065,16 +2070,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
if (it_best != states.end()) {
SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
const size_t size = it_best->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
{
auto & data = it_best->data.main;
return false;
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
return false;
}
data.clear();
data.shrink_to_fit();
}
it_best->data.clear();
it_best->data.shrink_to_fit();
{
auto & data = it_best->data.drft;
if (!data.empty()) {
GGML_ASSERT(ctx_dft);
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
return false;
}
data.clear();
data.shrink_to_fit();
}
}
prompt = std::move(*it_best);

View File

@@ -565,42 +565,29 @@ struct server_task_result_apply_lora : server_task_result {
virtual json to_json() override;
};
struct server_prompt_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
int64_t n_tokens;
std::vector<uint8_t> data;
struct server_prompt_data {
std::vector<uint8_t> main;
std::vector<uint8_t> drft;
size_t size() const {
return data.size();
}
bool empty() const {
return data.empty();
}
void clear() {
pos_min = 0;
pos_max = 0;
n_tokens = 0;
data.clear();
return main.size() + drft.size();
}
};
struct server_prompt {
server_tokens tokens;
std::vector<uint8_t> data;
server_prompt_data data;
std::list<server_prompt_checkpoint> checkpoints;
std::list<common_prompt_checkpoint> checkpoints;
size_t size() const {
size_t res = data.size();
size_t res = 0;
for (const auto & checkpoint : checkpoints) {
res += checkpoint.size();
res += data.size();
for (const auto & ckpt : checkpoints) {
res += ckpt.size();
}
return res;
@@ -614,7 +601,7 @@ struct server_prompt {
return server_prompt {
tokens.clone(),
data,
checkpoints
checkpoints,
};
}
};
@@ -637,9 +624,9 @@ struct server_prompt_cache {
size_t n_tokens() const;
server_prompt * alloc(const server_prompt & prompt, size_t state_size);
server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft);
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot);
void update();
};

View File

@@ -5,7 +5,7 @@ from utils import *
server = ServerPreset.stories15m_moe()
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/tiny-llamas/resolve/main/stories15M-q4_0.gguf"
def create_server():
global server