cont : clean-up

This commit is contained in:
Georgi Gerganov
2026-05-09 14:09:45 +03:00
parent ce0acf03ea
commit b3bd3bd4cc
4 changed files with 84 additions and 77 deletions

View File

@@ -116,11 +116,13 @@ static bool common_speculative_are_compatible(
return true;
}
using common_speculative_draft_params_vec = std::vector<common_speculative_draft_params>;
// state of an implementation of speculative decoding
//
// each implementation has a unique type and a state that is implementation-specific
// in a subclass of common_speculative_state
struct common_speculative_state {
// in a subclass of common_speculative_impl
struct common_speculative_impl {
const common_speculative_type type;
uint32_t n_seq;
@@ -141,9 +143,9 @@ struct common_speculative_state {
int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds.
int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds.
common_speculative_state(common_speculative_type type, uint32_t n_seq) : type(type), n_seq(n_seq) {}
common_speculative_impl(common_speculative_type type, uint32_t n_seq) : type(type), n_seq(n_seq) {}
virtual ~common_speculative_state() = default;
virtual ~common_speculative_impl() = default;
virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0;
@@ -152,7 +154,7 @@ struct common_speculative_state {
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
};
struct common_speculative_state_draft : public common_speculative_state {
struct common_speculative_state_draft : public common_speculative_impl {
common_params_speculative_draft params;
llama_batch batch;
@@ -160,7 +162,7 @@ struct common_speculative_state_draft : public common_speculative_state {
std::vector<common_sampler_ptr> smpls;
common_speculative_state_draft(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_state(COMMON_SPECULATIVE_TYPE_DRAFT, n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT, n_seq)
, params(params.draft)
{
auto * ctx_dft = this->params.ctx_dft;
@@ -333,11 +335,11 @@ struct common_speculative_state_draft : public common_speculative_state {
}
};
struct common_speculative_state_eagle3 : public common_speculative_state {
struct common_speculative_state_eagle3 : public common_speculative_impl {
//common_params_speculative_eagle3 params;
common_speculative_state_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq)
: common_speculative_state(COMMON_SPECULATIVE_TYPE_EAGLE3, n_seq) {}
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_EAGLE3, n_seq) {}
void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
// noop
@@ -353,7 +355,7 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
};
// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_state_ngram_simple : public common_speculative_state {
struct common_speculative_state_ngram_simple : public common_speculative_impl {
common_params_speculative_ngram_map params;
// shared across all sequences
@@ -362,7 +364,7 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
common_speculative_state_ngram_simple(
const common_params_speculative & params, uint32_t n_seq,
common_ngram_simple_config config)
: common_speculative_state(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq)
, params(params.ngram_simple)
, config(config) {}
@@ -388,7 +390,7 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
}
};
struct common_speculative_state_ngram_map_k : public common_speculative_state {
struct common_speculative_state_ngram_map_k : public common_speculative_impl {
common_params_speculative_ngram_map params;
// n_seq configs
@@ -398,7 +400,7 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state {
const common_params_speculative & params,
const common_ngram_map & config,
uint32_t n_seq)
: common_speculative_state(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq)
, params(params.ngram_map_k) {
for (uint32_t i = 0; i < n_seq; i++) {
this->config.push_back(config);
@@ -431,7 +433,7 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state {
}
};
struct common_speculative_state_ngram_mod : public common_speculative_state {
struct common_speculative_state_ngram_mod : public common_speculative_impl {
common_params_speculative_ngram_mod params;
// shared across all sequences
@@ -456,7 +458,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
common_speculative_state_ngram_mod(
const common_params_speculative & params,
uint32_t n_seq)
: common_speculative_state(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq)
, params(params.ngram_mod)
, mod(params.ngram_mod.n_match, 4*1024*1024)
, verbose(std::getenv("LLAMA_TRACE") != nullptr) {
@@ -507,9 +509,11 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
auto & sinfo = sinfos[seq_id];
auto & result = *dparams.result;
const auto & prompt = *dparams.prompt;
sinfo.n_draft_last = 0;
const size_t cur_len = dparams.prompt->size();
const size_t cur_len = prompt.size();
if (cur_len < mod.get_n()) {
return;
}
@@ -519,7 +523,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
// add new ngrams in chunks
if (sinfo.i_last + 32 < cur_len) {
for (size_t i = sinfo.i_last; i < cur_len - n; ++i) {
mod.add(dparams.prompt->data() + i);
mod.add(prompt.data() + i);
}
sinfo.i_last = cur_len - n;
@@ -527,7 +531,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
result.resize(n + params.n_max);
for (size_t i = 0; i < n - 1; ++i) {
result[i] = dparams.prompt->at(cur_len - n + 1 + i);
result[i] = prompt.at(cur_len - n + 1 + i);
}
result[n - 1] = dparams.id_last;
@@ -592,7 +596,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
}
};
struct common_speculative_state_ngram_cache : public common_speculative_state {
struct common_speculative_state_ngram_cache : public common_speculative_impl {
common_params_speculative_ngram_cache params;
uint16_t n_draft;
@@ -618,7 +622,7 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
const std::string & path_dynamic,
bool save_dynamic,
bool save_static)
: common_speculative_state(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, n_seq)
, params(params.ngram_cache)
, n_draft(n_draft)
, save_dynamic(save_dynamic)
@@ -721,11 +725,13 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
};
struct common_speculative {
common_speculative_draft_params_vec dparams;
// list of implementations to use and their states
std::vector<std::unique_ptr<common_speculative_state>> impls;
std::vector<std::unique_ptr<common_speculative_impl>> impls;
// which implementation was used for a given seq_id
std::vector<common_speculative_state *> impl_last;
std::vector<common_speculative_impl *> impl_last;
};
static common_ngram_map get_common_ngram_map(
@@ -830,7 +836,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
}
}
std::vector<std::unique_ptr<common_speculative_state>> impls = {};
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
for (const common_speculative_config & config : configs) {
LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str());
@@ -894,8 +900,9 @@ common_speculative * common_speculative_init(common_params_speculative & params,
}
auto * result = new common_speculative {
/* .dparams = */ common_speculative_draft_params_vec(n_seq),
/* .impls = */ std::move(impls),
/* .impl_last = */ std::vector<common_speculative_state *>(n_seq, nullptr)
/* .impl_last = */ std::vector<common_speculative_impl *>(n_seq, nullptr)
};
return result;
@@ -909,6 +916,15 @@ void common_speculative_free(common_speculative * spec) {
delete spec;
}
// return a mutable reference to the per-sequence draft parameters held in spec->dparams,
// so callers can fill them in before calling common_speculative_draft()
// asserts that spec is non-null and that seq_id indexes within spec->dparams
common_speculative_draft_params & common_speculative_get_draft_params(
common_speculative * spec,
llama_seq_id seq_id) {
GGML_ASSERT(spec);
GGML_ASSERT(seq_id < (llama_seq_id) spec->dparams.size());
return spec->dparams[seq_id];
}
void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt) {
if (spec == nullptr) {
return;
@@ -921,9 +937,13 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
}
}
void common_speculative_draft(
common_speculative * spec,
common_speculative_draft_params_vec & dparams) {
void common_speculative_draft(common_speculative * spec) {
if (!spec) {
return;
}
auto & dparams = spec->dparams;
{
int n_drafting = 0;
@@ -994,7 +1014,7 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
return;
}
common_speculative_state * impl = spec->impl_last[seq_id];
common_speculative_impl * impl = spec->impl_last[seq_id];
GGML_ASSERT(impl);

View File

@@ -18,19 +18,11 @@ common_speculative * common_speculative_init(common_params_speculative & params,
void common_speculative_free(common_speculative * spec);
// optionally call once at the beginning of a new generation
// TODO: when common_speculative_process() is implemented, we can remove this _begin() function and
// implement all the logic within common_speculative_process()
void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt);
// TODO: implement [TAG_COMMON_SPECULATIVE_PROCESS]
//bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
struct common_speculative_draft_params {
// this flag helps chain the drafts through all the implementations
// this flag is used to chain the drafts through all the available implementations
// after the first successful draft from an implementation, we set it
// to false to prevent further drafts for that sequence
bool drafting = true;
bool drafting = false;
// overrides individual configurations (-1 disabled)
// can be used to constrain the max draft based on the remaining context size
@@ -39,18 +31,24 @@ struct common_speculative_draft_params {
llama_pos n_past;
llama_token id_last;
// TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls
const llama_tokens * prompt;
// the generated draft from the last _draft() call
llama_tokens * result;
};
using common_speculative_draft_params_vec = std::vector<common_speculative_draft_params>;
common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id);
// optionally call once at the beginning of a new generation
void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt);
// TODO: implement [TAG_COMMON_SPECULATIVE_PROCESS]
//bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
// generate drafts for the sequences specified in dparams
// requires that `dparams.size()` equals the `n_seq` value used during common_speculative_init()
void common_speculative_draft(
common_speculative * spec,
common_speculative_draft_params_vec & dparams);
void common_speculative_draft(common_speculative * spec);
// informs the speculative decoder that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);

View File

@@ -179,8 +179,7 @@ int main(int argc, char ** argv) {
}
// generate a new draft
common_speculative_draft_params_vec dparams(1);
dparams[seq_id] = {
common_speculative_get_draft_params(spec, seq_id) = {
/* .drafting = */ true,
/* .n_max = */ -1,
/* .n_past = */ n_past,
@@ -188,7 +187,7 @@ int main(int argc, char ** argv) {
/* .prompt = */ &prompt_tgt,
/* .result = */ &draft, // output
};
common_speculative_draft(spec, dparams);
common_speculative_draft(spec);
// save the original draft size
n_draft = draft.size();

View File

@@ -616,7 +616,6 @@ private:
common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
common_speculative_ptr spec;
common_speculative_draft_params_vec spec_dparams;
bool add_bos_token = true;
@@ -838,7 +837,6 @@ private:
if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
try {
spec.reset(common_speculative_init(params_base.speculative, params_base.n_parallel));
spec_dparams.resize(params_base.n_parallel);
} catch (const std::exception & e) {
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
}
@@ -2186,10 +2184,11 @@ private:
// track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr;
// first, process slots that are speculative decoding
for (auto & slot : slots) {
spec_dparams[slot.id].drafting = false;
std::vector<server_slot *> generating;
std::vector<server_slot *> drafting;
// determine which slots are generating and drafting
for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING) {
continue;
}
@@ -2201,6 +2200,12 @@ private:
continue;
}
generating.push_back(&slot);
if (spec) {
common_speculative_get_draft_params(spec.get(), slot.id).drafting = false;
}
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
@@ -2226,34 +2231,30 @@ private:
slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
spec_dparams[slot.id] = common_speculative_draft_params {
slot.spec_prompt = slot.prompt.tokens.get_text_tokens();
common_speculative_get_draft_params(spec.get(), slot.id) = {
/* .drafting = */ true,
/* .n_max = */ n_draft_max,
/* .n_past = */ slot.prompt.n_tokens(),
/* .id_last = */ slot.sampled,
/* .prompt = */ nullptr,
/* .prompt = */ &slot.spec_prompt,
/* .result = */ &slot.spec_draft,
};
drafting.push_back(&slot);
}
}
}
// generate the actual drafts (if any)
{
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) spec_dparams.size(); seq_id++) {
auto & slot = slots[seq_id];
auto & dp = spec_dparams[seq_id];
slot.spec_prompt = slot.prompt.tokens.get_text_tokens();
dp.prompt = &slot.spec_prompt;
}
common_speculative_draft(spec.get(), spec_dparams);
common_speculative_draft(spec.get());
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) spec_dparams.size(); seq_id++) {
auto & slot = slots[seq_id];
// make checkpoints if needed
for (auto * slot_ptr : drafting) {
auto & slot = *slot_ptr;
slot.n_draft_total += slot.spec_draft.size();
@@ -2282,20 +2283,9 @@ private:
}
}
slot_batched = nullptr;
// add the speculative drafts to the batch, or simply add the sampled tokens
for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING) {
continue;
}
// check if we can batch this slot with the previous one
if (!slot_batched) {
slot_batched = &slot;
} else if (!slot_batched->can_batch_with(slot)) {
continue;
}
// update the batch with the sampled/drafted tokens
for (auto * slot_ptr : generating) {
auto & slot = *slot_ptr;
slot.update_batch(batch);
}