mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-10 19:14:07 +00:00
spec : update common_speculative_init()
[no ci]
This commit is contained in:
@@ -307,8 +307,8 @@ struct common_params_speculative_draft {
|
||||
|
||||
common_params_model mparams;
|
||||
|
||||
// the draft context
|
||||
llama_context * ctx = nullptr;
|
||||
llama_context * ctx_tgt = nullptr;
|
||||
llama_context * ctx_dft = nullptr;
|
||||
|
||||
bool use_ckpt = false;
|
||||
|
||||
|
||||
@@ -282,7 +282,6 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
auto * spec = this;
|
||||
|
||||
auto & batch = spec->batch;
|
||||
auto & ctx_tgt = spec->ctx_tgt;
|
||||
auto & ctx_dft = spec->ctx_dft;
|
||||
auto & smpl = spec->smpl;
|
||||
auto & prompt_dft = spec->prompt_dft;
|
||||
@@ -897,11 +896,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
||||
|
||||
// initialization of the speculative decoding system
|
||||
//
|
||||
common_speculative * common_speculative_init(
|
||||
common_params_speculative & params,
|
||||
llama_context * ctx_tgt) {
|
||||
llama_context * ctx_dft = params.draft.ctx;
|
||||
|
||||
common_speculative * common_speculative_init(common_params_speculative & params) {
|
||||
// Compute the implementations to use based on the config and their order of preference
|
||||
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
|
||||
{
|
||||
@@ -964,8 +959,8 @@ common_speculative * common_speculative_init(
|
||||
break;
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
|
||||
/* .ctx_tgt = */ ctx_tgt,
|
||||
/* .ctx_dft = */ ctx_dft,
|
||||
/* .ctx_tgt = */ params.draft.ctx_tgt,
|
||||
/* .ctx_dft = */ params.draft.ctx_dft,
|
||||
/* .use_ckpt = */ params.draft.use_ckpt
|
||||
));
|
||||
break;
|
||||
|
||||
@@ -14,9 +14,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
||||
// convert type to string
|
||||
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
||||
|
||||
common_speculative * common_speculative_init(
|
||||
common_params_speculative & params,
|
||||
llama_context * ctx_tgt);
|
||||
common_speculative * common_speculative_init(common_params_speculative & params);
|
||||
|
||||
void common_speculative_free(common_speculative * spec);
|
||||
|
||||
|
||||
@@ -105,7 +105,8 @@ int main(int argc, char ** argv) {
|
||||
auto cparams = common_context_params_to_llama(params_dft);
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
|
||||
params.speculative.draft.ctx = ctx_dft.get();
|
||||
params.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
}
|
||||
|
||||
// Tokenize the prompt
|
||||
@@ -161,7 +162,7 @@ int main(int argc, char ** argv) {
|
||||
// init the speculator
|
||||
const auto & params_spec = params.speculative;
|
||||
|
||||
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
|
||||
struct common_speculative * spec = common_speculative_init(params.speculative);
|
||||
|
||||
common_speculative_begin(spec, prompt_tgt);
|
||||
|
||||
|
||||
@@ -806,7 +806,8 @@ private:
|
||||
auto cparams = common_context_params_to_llama(params_dft);
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
|
||||
params_base.speculative.draft.ctx = ctx_dft.get();
|
||||
params_base.speculative.draft.ctx_tgt = ctx;
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
params_base.speculative.draft.use_ckpt = common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
}
|
||||
|
||||
@@ -911,7 +912,7 @@ private:
|
||||
// try speculative decoding
|
||||
if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
try {
|
||||
slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
|
||||
slot.spec.reset(common_speculative_init(params_base.speculative));
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user