spec : update common_speculative_init()

[no ci]
This commit is contained in:
Georgi Gerganov
2026-05-07 10:06:32 +03:00
parent 2466149c25
commit 4eec5542ce
5 changed files with 12 additions and 17 deletions

View File

@@ -307,8 +307,8 @@ struct common_params_speculative_draft {
common_params_model mparams;
// the draft context
llama_context * ctx = nullptr;
llama_context * ctx_tgt = nullptr;
llama_context * ctx_dft = nullptr;
bool use_ckpt = false;

View File

@@ -282,7 +282,6 @@ struct common_speculative_state_draft : public common_speculative_state {
auto * spec = this;
auto & batch = spec->batch;
auto & ctx_tgt = spec->ctx_tgt;
auto & ctx_dft = spec->ctx_dft;
auto & smpl = spec->smpl;
auto & prompt_dft = spec->prompt_dft;
@@ -897,11 +896,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// initialization of the speculative decoding system
//
common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt) {
llama_context * ctx_dft = params.draft.ctx;
common_speculative * common_speculative_init(common_params_speculative & params) {
// Compute the implementations to use based on the config and their order of preference
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
{
@@ -964,8 +959,8 @@ common_speculative * common_speculative_init(
break;
case COMMON_SPECULATIVE_TYPE_DRAFT: {
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .ctx_tgt = */ params.draft.ctx_tgt,
/* .ctx_dft = */ params.draft.ctx_dft,
/* .use_ckpt = */ params.draft.use_ckpt
));
break;

View File

@@ -14,9 +14,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt);
common_speculative * common_speculative_init(common_params_speculative & params);
void common_speculative_free(common_speculative * spec);

View File

@@ -105,7 +105,8 @@ int main(int argc, char ** argv) {
auto cparams = common_context_params_to_llama(params_dft);
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
params.speculative.draft.ctx = ctx_dft.get();
params.speculative.draft.ctx_tgt = ctx_tgt;
params.speculative.draft.ctx_dft = ctx_dft.get();
}
// Tokenize the prompt
@@ -161,7 +162,7 @@ int main(int argc, char ** argv) {
// init the speculator
const auto & params_spec = params.speculative;
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
struct common_speculative * spec = common_speculative_init(params.speculative);
common_speculative_begin(spec, prompt_tgt);

View File

@@ -806,7 +806,8 @@ private:
auto cparams = common_context_params_to_llama(params_dft);
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
params_base.speculative.draft.ctx = ctx_dft.get();
params_base.speculative.draft.ctx_tgt = ctx;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
params_base.speculative.draft.use_ckpt = common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
}
@@ -911,7 +912,7 @@ private:
// try speculative decoding
if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
try {
slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
slot.spec.reset(common_speculative_init(params_base.speculative));
} catch (const std::exception & e) {
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
}