mirror of https://github.com/ggml-org/llama.cpp.git (synced 2026-05-13 12:34:05 +00:00)

spec : refactor
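
Refactor of the server-side speculative decoding setup. The per-slot draft
context (ctx_dft) and the load-time draft bookkeeping in server_context
(llama_init_dft, vocab_dft_compatible, cparams_dft) are removed. The server
now only loads the draft model, keeps it alive through a llama_model_ptr, and
hands it to common_speculative_init() together with per-slot context
parameters. Judging by the removed code, draft-context creation, the
vocab-compatibility check, and the draft/target token replacements now happen
inside the speculative module, and can_speculate() reduces to a null check on
the slot's speculator.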
@@ -50,7 +50,6 @@ struct server_slot {

     // TODO: change to unique_ptrs for consistency:
     llama_context * ctx = nullptr;
-    llama_context * ctx_dft = nullptr;

     // multimodal
     mtmd_context * mctx = nullptr;
@@ -256,9 +255,8 @@ struct server_slot {
         return state != SLOT_STATE_IDLE;
     }

-    // Checks if a draft model is active or self-speculation using context-tokens
     bool can_speculate() const {
-        return task->params.speculative.configs.size() > 0;
+        return !!spec;
     }

     void add_token(const completion_token_output & token) {
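
Note: after this change a slot can speculate exactly when its speculator was
successfully created, so the per-task check on speculative.configs (and the
self-speculation comment) gives way to a simple null test on spec.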
@@ -553,18 +551,13 @@ private:

     // note: keep these alive - they determine the lifetime of the model, context, etc.
     common_init_result_ptr llama_init;
-    common_init_result_ptr llama_init_dft;

     llama_context * ctx = nullptr;

-    bool vocab_dft_compatible = true;
-
-    llama_model * model_dft = nullptr;
-
-    llama_context_params cparams_dft;
-
     llama_batch batch {};

+    llama_model_ptr model_dft;
+
     bool add_bos_token = true;

     int32_t n_ctx; // total context for all clients / slots
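
The draft model is now held by value as a llama_model_ptr. Below is a minimal
sketch of this ownership pattern, assuming the smart-pointer aliases from
llama.cpp's llama-cpp.h header (std::unique_ptr wrappers whose deleters call
llama_model_free and friends); the load_draft helper itself is hypothetical:

    #include "llama-cpp.h" // provides llama_model_ptr

    // hypothetical helper mirroring the pattern introduced in this diff
    static llama_model_ptr load_draft(const char * path) {
        llama_model_params mparams = llama_model_default_params();

        // the unique_ptr calls llama_model_free automatically when it is
        // reset or goes out of scope, so no explicit cleanup path is needed
        return llama_model_ptr(llama_model_load_from_file(path, mparams));
    }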
@@ -597,9 +590,6 @@ private:

         // Clear any sampling context
         for (server_slot & slot : slots) {
-            llama_free(slot.ctx_dft);
-            slot.ctx_dft = nullptr;
-
             common_speculative_free(slot.spec);
             slot.spec = nullptr;
         }
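
With ctx_dft gone from the slots, the cleanup loop no longer frees a draft
context explicitly; presumably common_speculative_free() now tears down
whatever context the speculator created internally.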
@@ -646,44 +636,26 @@ private:

         add_bos_token = llama_vocab_get_add_bos(vocab);

-        if (params_base.has_speculative()) {
+        if (params_base.speculative.has_dft()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());

+            const auto & params_spec = params_base.speculative;
+
             auto params_dft = params_base;

-            params_dft.devices = params_base.speculative.devices;
-            params_dft.model = params_base.speculative.model;
-            params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
-            params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
-            params_dft.n_parallel = 1;
-            params_dft.cache_type_k = params_base.speculative.cache_type_k;
-            params_dft.cache_type_v = params_base.speculative.cache_type_v;
+            params_dft.devices = params_spec.devices;
+            params_dft.model = params_spec.model;
+            params_dft.n_gpu_layers = params_spec.n_gpu_layers;

-            params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads;
-            params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads;
-            params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
+            params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides;

-            llama_init_dft = common_init_from_params(params_dft);
-
-            model_dft = llama_init_dft->model();
+            auto mparams_dft = common_model_params_to_llama(params_dft);

+            model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
             if (model_dft == nullptr) {
-                SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
+                SRV_ERR("failed to load draft model, '%s'\n", params_spec.model.path.c_str());
                 return false;
             }
-
-            vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context());
-            if (!vocab_dft_compatible) {
-                SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
-            }
-
-            const int n_ctx_dft = llama_n_ctx(llama_init_dft->context());
-
-            cparams_dft = common_context_params_to_llama(params_dft);
-            cparams_dft.n_batch = n_ctx_dft;
-
-            // the context is not needed - we will create one for each slot
-            llama_init_dft->free_context();
         }

         std::string & mmproj_path = params_base.mmproj.path;
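
Note that the vocab-compatibility probe (common_speculative_are_compatible)
and the draft context parameters are no longer computed here at load time:
only the model itself is loaded, and the context parameters are rebuilt per
slot and passed to common_speculative_init() in the slot-initialization hunk
further down.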
@@ -693,6 +665,7 @@ private:
         }

         mtmd_context_params mparams = mtmd_context_params_default();
+
         mparams.use_gpu = params_base.mmproj_use_gpu;
         mparams.print_timings = false;
         mparams.n_threads = params_base.cpuparams.n_threads;
@@ -700,6 +673,7 @@ private:
         mparams.warmup = params_base.warmup;
         mparams.image_min_tokens = params_base.image_min_tokens;
         mparams.image_max_tokens = params_base.image_max_tokens;
+
         mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
         if (mctx == nullptr) {
             SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
@@ -716,11 +690,6 @@ private:
                 params_base.n_cache_reuse = 0;
                 SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
             }
-
-            if (params_base.has_speculative()) {
-                SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
-                return false;
-            }
         }

         if (!llama_memory_can_shift(llama_get_memory(ctx))) {
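
The blanket load-time rejection of speculative decoding with multimodal moves
into the per-slot path: the equivalent mctx check now runs right after the
speculator is created (see the next hunk).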
@@ -755,30 +724,40 @@ private:
         for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;

-            slot.id = i;
-            slot.ctx = ctx;
+            slot.id    = i;
+            slot.ctx   = ctx;
+            slot.n_ctx = n_ctx_slot;
+            slot.mctx  = mctx;

-            slot.mctx = mctx;
             slot.prompt.tokens.has_mtmd = mctx != nullptr;

-            if (model_dft) {
-                // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
-                slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
-                if (slot.ctx_dft == nullptr) {
-                    SRV_ERR("%s", "failed to create draft context\n");
-                    return false;
-                }
+            // try speculative decoding
+            {
+                const auto & params_spec = params_base.speculative;

-                slot.spec = common_speculative_init(params_base.speculative, slot.ctx, slot.ctx_dft);
-                if (slot.spec == nullptr) {
-                    SRV_ERR("%s", "failed to create speculator\n");
-                    return false;
+                auto params_dft = params_base;
+
+                params_dft.n_parallel = 1;
+                params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
+                params_dft.n_batch = llama_n_ctx_seq(ctx);
+                params_dft.cache_type_k = params_spec.cache_type_k;
+                params_dft.cache_type_v = params_spec.cache_type_v;
+
+                params_dft.cpuparams.n_threads = params_spec.cpuparams.n_threads;
+                params_dft.cpuparams_batch.n_threads = params_spec.cpuparams_batch.n_threads;
+
+                auto cparams_dft = common_context_params_to_llama(params_dft);
+
+                slot.spec = common_speculative_init(params_base.speculative, slot.ctx, cparams_dft, model_dft.get());
+                if (slot.spec) {
+                    if (mctx) {
+                        SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
+                        return false;
+                    }
+                    SRV_WRN("%s", "speculative decoding context initialized\n");
+                } else {
+                    SRV_WRN("%s", "speculative decoding context not initialized\n");
                 }
-                for (auto & pair : params_base.speculative.replacements) {
-                    common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
-                }
-            } else if (params_base.speculative.configs.size() > 0) {
-                slot.spec = common_speculative_init(params_base.speculative, nullptr, nullptr);
             }

             SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
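
From the two call sites in this diff (the new four-argument call above and
the removed three-argument one), the refactored initializer appears to take
the draft context parameters and the draft model instead of a ready-made
draft context. A sketch of the assumed shape follows; the authoritative
declaration lives in common/speculative.h and may differ in exact types and
qualifiers:

    // inferred from the call site:
    //   common_speculative_init(params_base.speculative, slot.ctx, cparams_dft, model_dft.get())
    // model_dft may be nullptr; the function returns nullptr when no
    // speculator could be set up (the server then logs "speculative decoding
    // context not initialized" and continues without it)
    struct common_speculative * common_speculative_init(
            const common_params_speculative & params,
            llama_context * ctx_tgt,           // target context of this slot
            llama_context_params cparams_dft,  // params for the draft context it creates
            llama_model * model_dft);          // draft model, or nullptr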
@@ -1057,7 +1036,7 @@ private:
         return res;
     }

-    std::vector<common_adapter_lora_info> construct_lora_list(const std::map<int, float> & config) {
+    std::vector<common_adapter_lora_info> construct_lora_list(const std::map<int, float> & config) const {
         std::vector<common_adapter_lora_info> output = params_base.lora_adapters; // copy
         for (size_t i = 0; i < output.size(); ++i) {
             auto it = config.find(i);
@@ -1160,7 +1139,7 @@ private:
         backend_sampling &= task.params.sampling.backend_sampling;

         // TODO: speculative decoding requires multiple samples per batch - not supported yet
-        backend_sampling &= !(slot.ctx_dft && task.params.speculative.n_max > 0);
+        backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0);

         // TODO: getting post/pre sampling logits is not yet supported with backend sampling
         backend_sampling &= !need_logits;
@@ -2058,7 +2037,6 @@ private:

             struct common_speculative_params params_spec;
             params_spec.n_draft = n_draft_max;
-            params_spec.n_reuse = slot.ctx_dft ? (llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max) : 0;
             params_spec.p_min = slot.task->params.speculative.p_min;

             const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
             llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
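
n_reuse was derived from the size of the draft context, which the server no
longer owns, so the assignment disappears here; the reuse policy presumably
lives inside the speculative module now. The drafting call itself is
unchanged: common_speculative_gen_draft() takes the cached prompt tokens plus
the last sampled token and returns candidate tokens for the target model to
verify in a single batch.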