diff --git a/common/common.cpp b/common/common.cpp index 374d95d4ee..352af0b178 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1962,11 +1962,11 @@ bool common_prompt_batch_decode( } size_t common_prompt_checkpoint::size() const { - return data_main.size() + data_drft.size(); + return data_tgt.size() + data_dft.size(); } bool common_prompt_checkpoint::empty() const { - return data_main.empty(); + return data_tgt.empty(); } void common_prompt_checkpoint::clear() { @@ -1975,8 +1975,8 @@ void common_prompt_checkpoint::clear() { pos_min = 0; pos_max = 0; - data_main.clear(); - data_drft.clear(); + data_tgt.clear(); + data_dft.clear(); } void common_prompt_checkpoint::update_pos( @@ -1988,7 +1988,7 @@ void common_prompt_checkpoint::update_pos( this->pos_max = pos_max; } -void common_prompt_checkpoint::update_main( +void common_prompt_checkpoint::update_tgt( llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { @@ -1998,15 +1998,15 @@ void common_prompt_checkpoint::update_main( const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags); - data_main.resize(ckpt_size); + data_tgt.resize(ckpt_size); - const size_t n = llama_state_seq_get_data_ext(ctx, data_main.data(), ckpt_size, seq_id, flags); + const size_t n = llama_state_seq_get_data_ext(ctx, data_tgt.data(), ckpt_size, seq_id, flags); if (n != ckpt_size) { GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n); } } -void common_prompt_checkpoint::update_drft( +void common_prompt_checkpoint::update_dft( llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { @@ -2016,15 +2016,15 @@ void common_prompt_checkpoint::update_drft( const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags); - data_drft.resize(ckpt_size); + data_dft.resize(ckpt_size); - const size_t n = llama_state_seq_get_data_ext(ctx, data_drft.data(), ckpt_size, seq_id, flags); + const size_t n = llama_state_seq_get_data_ext(ctx, data_dft.data(), ckpt_size, seq_id, flags); if (n != ckpt_size) { GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n); } } -void common_prompt_checkpoint::load_main( +void common_prompt_checkpoint::load_tgt( llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) const { @@ -2032,17 +2032,17 @@ void common_prompt_checkpoint::load_main( return; } - if (data_main.empty()) { + if (data_tgt.empty()) { return; } - const size_t n = llama_state_seq_set_data_ext(ctx, data_main.data(), data_main.size(), seq_id, flags); - if (n != data_main.size()) { - GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_main.size(), n); + const size_t n = llama_state_seq_set_data_ext(ctx, data_tgt.data(), data_tgt.size(), seq_id, flags); + if (n != data_tgt.size()) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_tgt.size(), n); } } -void common_prompt_checkpoint::load_drft( +void common_prompt_checkpoint::load_dft( llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) const { @@ -2050,12 +2050,12 @@ void common_prompt_checkpoint::load_drft( return; } - if (data_drft.empty()) { + if (data_dft.empty()) { return; } - const size_t n = llama_state_seq_set_data_ext(ctx, data_drft.data(), data_drft.size(), seq_id, flags); - if (n != data_drft.size()) { - GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_drft.size(), n); + const size_t n = llama_state_seq_set_data_ext(ctx, data_dft.data(), data_dft.size(), seq_id, flags); + if (n != data_dft.size()) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n); } } diff --git a/common/common.h b/common/common.h index 83bb5a37ce..5b266f44f9 100644 --- a/common/common.h +++ b/common/common.h @@ -1034,8 +1034,8 @@ struct common_prompt_checkpoint { llama_pos pos_min; llama_pos pos_max; - std::vector data_main; - std::vector data_drft; + std::vector data_tgt; + std::vector data_dft; size_t size() const; @@ -1047,22 +1047,22 @@ struct common_prompt_checkpoint { llama_pos pos_min, llama_pos pos_max); - void update_main( + void update_tgt( llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags); - void update_drft( + void update_dft( llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags); - void load_main( + void load_tgt( llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) const; - void load_drft( + void load_dft( llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) const; diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 9355ea4e34..6585da7382 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -175,7 +175,7 @@ int main(int argc, char ** argv) { llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id)); if (use_ckpt_dft) { - ckpt.update_drft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); } // generate a new draft @@ -188,12 +188,12 @@ int main(int argc, char ** argv) { // this allows us to restore the state if partial draft acceptance occurs if (!draft.empty()) { if (use_ckpt_tgt) { - ckpt.update_main(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); } } { - ckpt.load_drft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1); } @@ -253,13 +253,13 @@ int main(int argc, char ** argv) { draft = std::move(ids); { - ckpt.load_main(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, ckpt.pos_max + 1, -1); } { - ckpt.load_drft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1); } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 1889477a49..9b2fab832a 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -54,8 +54,8 @@ enum server_state { struct server_slot { int id; - llama_context * ctx_main = nullptr; - llama_context * ctx_drft = nullptr; + llama_context * ctx_tgt = nullptr; + llama_context * ctx_dft = nullptr; // multimodal mtmd_context * mctx = nullptr; @@ -108,27 +108,27 @@ struct server_slot { void prompt_save(server_prompt_cache & prompt_cache) const { GGML_ASSERT(prompt.data.size() == 0); - const size_t cur_size_main = llama_state_seq_get_size_ext(ctx_main, id, LLAMA_STATE_SEQ_FLAGS_NONE); - const size_t cur_size_drft = ctx_drft ? llama_state_seq_get_size_ext(ctx_drft, id, LLAMA_STATE_SEQ_FLAGS_NONE) : 0; + const size_t cur_size_tgt = llama_state_seq_get_size_ext(ctx_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE); + const size_t cur_size_dft = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE) : 0; - const size_t cur_size = cur_size_main + cur_size_drft; + const size_t cur_size = cur_size_tgt + cur_size_dft; SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n", - (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_drft / (1024.0 * 1024.0)); + (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0)); - auto * cur = prompt_cache.alloc(prompt, cur_size_main, cur_size_drft); + auto * cur = prompt_cache.alloc(prompt, cur_size_tgt, cur_size_dft); if (cur == nullptr) { return; } - llama_state_seq_get_data_ext(ctx_main, cur->data.main.data(), cur_size_main, id, LLAMA_STATE_SEQ_FLAGS_NONE); - if (ctx_drft) { - llama_state_seq_get_data_ext(ctx_drft, cur->data.drft.data(), cur_size_drft, id, LLAMA_STATE_SEQ_FLAGS_NONE); + llama_state_seq_get_data_ext(ctx_tgt, cur->data.main.data(), cur_size_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE); + if (ctx_dft) { + llama_state_seq_get_data_ext(ctx_dft, cur->data.drft.data(), cur_size_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE); } } bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { - bool res = prompt_cache.load(prompt, tokens, ctx_main, ctx_drft, id); + bool res = prompt_cache.load(prompt, tokens, ctx_tgt, ctx_dft, id); if (!res) { SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); } @@ -143,9 +143,9 @@ struct server_slot { SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size()); - llama_memory_seq_rm(llama_get_memory(ctx_main), id, -1, -1); - if (ctx_drft) { - llama_memory_seq_rm(llama_get_memory(ctx_drft), id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1); + if (ctx_dft) { + llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1); } prompt.tokens.clear(); @@ -205,7 +205,7 @@ struct server_slot { task_prev = std::move(task); task.reset(); - llama_set_sampler(ctx_main, id, nullptr); + llama_set_sampler(ctx_tgt, id, nullptr); // clear alora start alora_invocation_start = -1; @@ -242,7 +242,7 @@ struct server_slot { return !task->need_embd() || - (llama_get_memory(ctx_main) && llama_pooling_type(ctx_main) == LLAMA_POOLING_TYPE_LAST); + (llama_get_memory(ctx_tgt) && llama_pooling_type(ctx_tgt) == LLAMA_POOLING_TYPE_LAST); } bool can_batch_with(server_slot & other_slot) const { @@ -316,7 +316,7 @@ struct server_slot { return n_draft_max; } - void update_batch(llama_batch & batch, bool use_ckpt_main, bool use_ckpt_drft) { + void update_batch(llama_batch & batch, bool use_ckpt_tgt, bool use_ckpt_dft) { const int n_draft_max = get_n_draft_max(); if (n_draft_max > 0) { GGML_ASSERT(can_speculate()); @@ -330,7 +330,7 @@ struct server_slot { if (!spec_draft.empty()) { // we have a previous (partial) draft to reuse - if (use_ckpt_main) { + if (use_ckpt_tgt) { GGML_ASSERT(!spec_ckpt.empty()); } } else { @@ -338,11 +338,11 @@ struct server_slot { spec_ckpt.update_pos( prompt.n_tokens(), - llama_memory_seq_pos_min(llama_get_memory(ctx_main), id), - llama_memory_seq_pos_max(llama_get_memory(ctx_main), id)); + llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), id), + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), id)); - if (use_ckpt_drft) { - spec_ckpt.update_drft(ctx_drft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + if (use_ckpt_dft) { + spec_ckpt.update_dft(ctx_dft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); } // generate a new draft @@ -355,24 +355,24 @@ struct server_slot { } if (!spec_draft.empty()) { - if (use_ckpt_main) { + if (use_ckpt_tgt) { //const int64_t t_start = ggml_time_us(); - spec_ckpt.update_main(ctx_main, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + spec_ckpt.update_tgt(ctx_tgt, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); //const int64_t t_total = ggml_time_us() - t_start; //printf("checkpoint total: %f ms\n", t_total / 1000.0); SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n", - spec_ckpt.pos_min, spec_ckpt.pos_max, prompt.n_tokens(), (float) spec_ckpt.size() / 1024 / 1024, (float) spec_ckpt.data_drft.size() / 1024 / 1024); + spec_ckpt.pos_min, spec_ckpt.pos_max, prompt.n_tokens(), (float) spec_ckpt.size() / 1024 / 1024, (float) spec_ckpt.data_dft.size() / 1024 / 1024); } } // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] - if (ctx_drft) { - spec_ckpt.load_drft(ctx_drft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + if (ctx_dft) { + spec_ckpt.load_dft(ctx_dft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - llama_memory_seq_rm(llama_get_memory(ctx_drft), this->id, spec_ckpt.pos_max + 1, -1); + llama_memory_seq_rm(llama_get_memory(ctx_dft), this->id, spec_ckpt.pos_max + 1, -1); } } @@ -538,7 +538,7 @@ struct server_slot { }; if (!only_metrics) { - res["prompt"] = ptask->tokens.detokenize(ctx_main, true); + res["prompt"] = ptask->tokens.detokenize(ctx_tgt, true); res["generated"] = generated_text.empty() ? debug_generated_text : generated_text; } } @@ -549,12 +549,12 @@ struct server_slot { void copy_state_to(server_slot & other) const { GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT); - llama_memory_seq_rm(llama_get_memory(ctx_main), other.id, -1, -1); - llama_memory_seq_cp(llama_get_memory(ctx_main), id, other.id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1); + llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1); - if (ctx_drft) { - llama_memory_seq_rm(llama_get_memory(ctx_drft), other.id, -1, -1); - llama_memory_seq_cp(llama_get_memory(ctx_drft), id, other.id, -1, -1); + if (ctx_dft) { + llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1); + llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1); } other.n_decoded = n_decoded; @@ -646,7 +646,7 @@ public: // only use these pointers outside of this class: // - when not in sleeping state // - and, with thread-safe APIs (e.g., tokenizer calls) - llama_model * model_main = nullptr; + llama_model * model_tgt = nullptr; mtmd_context * mctx = nullptr; const llama_vocab * vocab = nullptr; @@ -674,15 +674,15 @@ private: // note: keep these alive - they determine the lifetime of the model, context, etc. common_init_result_ptr llama_init; - llama_context * ctx_main = nullptr; + llama_context * ctx_tgt = nullptr; llama_batch batch {}; - llama_model_ptr model_drft; - llama_context_ptr ctx_drft; + llama_model_ptr model_dft; + llama_context_ptr ctx_dft; - common_context_seq_rm_type ctx_main_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; - common_context_seq_rm_type ctx_drft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; bool add_bos_token = true; @@ -717,8 +717,8 @@ private: void destroy() { llama_init.reset(); - ctx_main = nullptr; - model_main = nullptr; + ctx_tgt = nullptr; + model_tgt = nullptr; mtmd_free(mctx); mctx = nullptr; @@ -768,17 +768,17 @@ private: llama_init = common_init_from_params(params_base); - model_main = llama_init->model(); - ctx_main = llama_init->context(); + model_tgt = llama_init->model(); + ctx_tgt = llama_init->context(); - if (model_main == nullptr) { + if (model_tgt == nullptr) { SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); return false; } - vocab = llama_model_get_vocab(model_main); + vocab = llama_model_get_vocab(model_tgt); - n_ctx = llama_n_ctx(ctx_main); + n_ctx = llama_n_ctx(ctx_tgt); add_bos_token = llama_vocab_get_add_bos(vocab); @@ -805,19 +805,19 @@ private: auto mparams_dft = common_model_params_to_llama(params_dft); - model_drft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); - if (model_drft == nullptr) { + model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); + if (model_dft == nullptr) { SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str()); return false; } auto cparams = common_context_params_to_llama(params_dft); - ctx_drft.reset(llama_init_from_model(model_drft.get(), cparams)); + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); - ctx_drft_seq_rm_type = common_context_can_seq_rm(ctx_drft.get()); + ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); - params_base.speculative.draft.ctx_tgt = ctx_main; - params_base.speculative.draft.ctx_dft = ctx_drft.get(); + params_base.speculative.draft.ctx_tgt = ctx_tgt; + params_base.speculative.draft.ctx_dft = ctx_dft.get(); } std::string & mmproj_path = params_base.mmproj.path; @@ -837,7 +837,7 @@ private: mparams.image_max_tokens = params_base.image_max_tokens; mparams.media_marker = get_media_marker(); - mctx = mtmd_init_from_file(mmproj_path.c_str(), model_main, mparams); + mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams); if (mctx == nullptr) { SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); return false; @@ -855,7 +855,7 @@ private: } } - if (!llama_memory_can_shift(llama_get_memory(ctx_main))) { + if (!llama_memory_can_shift(llama_get_memory(ctx_tgt))) { if (params_base.ctx_shift) { params_base.ctx_shift = false; SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); @@ -867,14 +867,14 @@ private: } } - if (llama_model_n_swa(model_main) == 0) { + if (llama_model_n_swa(model_tgt) == 0) { if (params_base.swa_full) { params_base.swa_full = false; SRV_WRN("%s\n", "swa_full is not supported by this model, it will be disabled"); } } - n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model_main); + n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model_tgt); // Necessary similarity of prompt for slot selection slot_prompt_similarity = params_base.slot_prompt_similarity; @@ -882,9 +882,9 @@ private: // setup slots SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - const int n_ctx_train = llama_model_n_ctx_train(model_main); + const int n_ctx_train = llama_model_n_ctx_train(model_tgt); - int n_ctx_slot = llama_n_ctx_seq(ctx_main); + int n_ctx_slot = llama_n_ctx_seq(ctx_tgt); if (n_ctx_slot > n_ctx_train) { SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); n_ctx_slot = n_ctx_train; @@ -892,12 +892,12 @@ private: slots.clear(); - ctx_main_seq_rm_type = common_context_can_seq_rm(ctx_main); - if (ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) { + ctx_tgt_seq_rm_type = common_context_can_seq_rm(ctx_tgt); + if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) { SRV_WRN("%s", "speculative decoding not supported by this context\n"); } - if (ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { + if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { SRV_WRN("%s", "speculative decoding will use checkpoints\n"); } @@ -906,27 +906,27 @@ private: slots.emplace_back(); } - bool no_drft = false; + bool no_dft = false; for (int i = 0; i < params_base.n_parallel; i++) { server_slot & slot = slots[i]; slot.id = i; - slot.ctx_main = ctx_main; - slot.ctx_drft = ctx_drft.get(); + slot.ctx_tgt = ctx_tgt; + slot.ctx_dft = ctx_dft.get(); slot.n_ctx = n_ctx_slot; slot.mctx = mctx; slot.prompt.tokens.has_mtmd = mctx != nullptr; // try speculative decoding - if (ctx_main_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { + if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { try { slot.spec.reset(common_speculative_init(params_base.speculative, slot.id)); } catch (const std::exception & e) { SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what()); - no_drft = true; + no_dft = true; } if (slot.spec) { @@ -943,13 +943,13 @@ private: slot.reset(); } - if (no_drft && ctx_drft) { + if (no_dft && ctx_dft) { SRV_WRN("%s", "destroying the draft model as it is not going to be used\n"); - ctx_drft.reset(); + ctx_dft.reset(); for (auto & slot : slots) { - slot.ctx_drft = nullptr; + slot.ctx_dft = nullptr; } } @@ -974,7 +974,7 @@ private: // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { - const int32_t n_batch = llama_n_batch(ctx_main); + const int32_t n_batch = llama_n_batch(ctx_tgt); batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); } @@ -1018,8 +1018,8 @@ private: // unlike load_model(), this is only called once during initialization bool init() { - GGML_ASSERT(ctx_main != nullptr); - GGML_ASSERT(model_main != nullptr); + GGML_ASSERT(ctx_tgt != nullptr); + GGML_ASSERT(model_tgt != nullptr); GGML_ASSERT(!sleeping); @@ -1066,7 +1066,7 @@ private: common_chat_templates_ptr chat_templates; try { - chat_templates = common_chat_templates_init(model_main, params_base.chat_template); + chat_templates = common_chat_templates_init(model_tgt, params_base.chat_template); LOG_INF("%s: chat template, example_format: '%s'\n", __func__, common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); @@ -1329,7 +1329,7 @@ private: } } - if (!task.tokens.validate(ctx_main)) { + if (!task.tokens.validate(ctx_tgt)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } @@ -1339,7 +1339,7 @@ private: // initialize samplers if (task.need_sampling()) { try { - slot.smpl.reset(common_sampler_init(model_main, task.params.sampling)); + slot.smpl.reset(common_sampler_init(model_tgt, task.params.sampling)); } catch (std::exception & e) { std::string err_msg = std::string("Failed to initialize samplers: ") + e.what(); send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST); @@ -1360,9 +1360,9 @@ private: // TODO: tmp until backend sampling is fully implemented if (backend_sampling) { - llama_set_sampler(ctx_main, slot.id, common_sampler_get(slot.smpl.get())); + llama_set_sampler(ctx_tgt, slot.id, common_sampler_get(slot.smpl.get())); } else { - llama_set_sampler(ctx_main, slot.id, nullptr); + llama_set_sampler(ctx_tgt, slot.id, nullptr); } SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); @@ -1535,13 +1535,13 @@ private: for (size_t i = 0; i < n_probs; i++) { result.probs.push_back({ cur_p->data[i].id, - common_token_to_piece(ctx_main, cur_p->data[i].id, special), + common_token_to_piece(ctx_tgt, cur_p->data[i].id, special), cur_p->data[i].p }); } } else { // TODO: optimize this with min-p optimization - std::vector cur = get_token_probabilities(ctx_main, idx); + std::vector cur = get_token_probabilities(ctx_tgt, idx); const size_t max_probs = cur.size(); const size_t n_probs = std::min(max_probs, n_probs_request); @@ -1559,7 +1559,7 @@ private: for (size_t i = 0; i < n_probs; i++) { result.probs.push_back({ cur[i].id, - common_token_to_piece(ctx_main, cur[i].id, special), + common_token_to_piece(ctx_tgt, cur[i].id, special), cur[i].p }); } @@ -1662,7 +1662,7 @@ private: res->tokens = std::move(slot.generated_tokens); } res->timings = slot.get_timings(); - res->prompt = slot.task->tokens.detokenize(ctx_main, true); + res->prompt = slot.task->tokens.detokenize(ctx_tgt, true); res->response_fields = std::move(slot.task->params.response_fields); res->truncated = slot.truncated; @@ -1685,7 +1685,7 @@ private: // populate res.probs_output if (slot.task->params.sampling.n_probs > 0) { if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { - const llama_tokens stop_word_toks = common_tokenize(ctx_main, slot.stopping_word, false); + const llama_tokens stop_word_toks = common_tokenize(ctx_tgt, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); res->probs_output = std::vector( @@ -1710,7 +1710,7 @@ private: res->n_tokens = slot.task->n_tokens(); res->res_type = slot.task->params.res_type; - const int n_embd_out = llama_model_n_embd_out(model_main); + const int n_embd_out = llama_model_n_embd_out(model_tgt); std::vector embd_res(n_embd_out, 0.0f); @@ -1720,10 +1720,10 @@ private: } const float * embd = nullptr; - if (llama_pooling_type(slot.ctx_main) == LLAMA_POOLING_TYPE_NONE) { - embd = llama_get_embeddings_ith(slot.ctx_main, i); + if (llama_pooling_type(slot.ctx_tgt) == LLAMA_POOLING_TYPE_NONE) { + embd = llama_get_embeddings_ith(slot.ctx_tgt, i); } else { - embd = llama_get_embeddings_seq(slot.ctx_main, batch.seq_id[i][0]); + embd = llama_get_embeddings_seq(slot.ctx_tgt, batch.seq_id[i][0]); } if (embd == nullptr) { @@ -1734,7 +1734,7 @@ private: } // normalize only when there is pooling - if (llama_pooling_type(slot.ctx_main) != LLAMA_POOLING_TYPE_NONE) { + if (llama_pooling_type(slot.ctx_tgt) != LLAMA_POOLING_TYPE_NONE) { common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize); res->embedding.push_back(embd_res); break; @@ -1759,9 +1759,9 @@ private: continue; } - const float * embd = llama_get_embeddings_seq(ctx_main, batch.seq_id[i][0]); + const float * embd = llama_get_embeddings_seq(ctx_tgt, batch.seq_id[i][0]); if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx_main, i); + embd = llama_get_embeddings_ith(ctx_tgt, i); } if (embd == NULL) { @@ -1875,8 +1875,8 @@ private: cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max); - cur.update_main(ctx_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - cur.update_drft(ctx_drft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", @@ -2036,7 +2036,7 @@ private: std::string filepath = task.slot_action.filepath; const llama_tokens & tokens = slot->prompt.tokens.get_tokens(); - const size_t nwrite = llama_state_seq_save_file(ctx_main, filepath.c_str(), slot->id, tokens.data(), token_count); + const size_t nwrite = llama_state_seq_save_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; @@ -2075,7 +2075,7 @@ private: llama_tokens tokens; tokens.resize(slot->n_ctx); size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx_main, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); + size_t nread = llama_state_seq_load_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { slot->prompt.tokens.clear(); // KV may already been invalidated? send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); @@ -2234,12 +2234,12 @@ private: SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - llama_memory_seq_rm (llama_get_memory(ctx_main), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx_main), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); - if (ctx_drft) { - llama_memory_seq_rm (llama_get_memory(ctx_drft.get()), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx_drft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard); + if (ctx_dft) { + llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard); } // add generated tokens to cache @@ -2287,13 +2287,13 @@ private: } slot.update_batch(batch, - ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL, - ctx_drft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL, + ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); } // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx_main); - int32_t n_ubatch = llama_n_ubatch(ctx_main); + int32_t n_batch = llama_n_batch(ctx_tgt); + int32_t n_ubatch = llama_n_ubatch(ctx_tgt); float alora_scale = -1.0f; size_t alora_disabled_id = 0; @@ -2337,12 +2337,12 @@ private: /*if (1) { // first 16 tokens (avoid flooding logs) for (int i = 0; i < std::min(16, input_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_main, input_tokens[i]).c_str()); + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str()); } } else { // all for (int i = 0; i < (int) input_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_main, input_tokens[i]).c_str()); + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str()); } }*/ @@ -2361,7 +2361,7 @@ private: } // TODO: support memory-less logits computation - if (slot.task->need_logits() && !llama_get_memory(ctx_main)) { + if (slot.task->need_logits() && !llama_get_memory(ctx_tgt)) { send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); slot.release(); continue; @@ -2413,7 +2413,7 @@ private: const auto n_cache_reuse = slot.task->params.n_cache_reuse; const bool can_cache_reuse = - llama_memory_can_shift(llama_get_memory(ctx_main)) && + llama_memory_can_shift(llama_get_memory(ctx_tgt)) && !slot.prompt.tokens.has_mtmd; if (!can_cache_reuse && n_cache_reuse > 0) { @@ -2447,17 +2447,17 @@ private: if (n_match >= (size_t) n_cache_reuse) { SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); //for (size_t i = head_p; i < head_p + n_match; i++) { - // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx_main, prompt_tokens[i]).c_str()); + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx_tgt, prompt_tokens[i]).c_str()); //} const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - llama_memory_seq_rm (llama_get_memory(ctx_main), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx_main), slot.id, head_c, head_c + n_match, kv_shift); + llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift); - if (ctx_drft) { - llama_memory_seq_rm (llama_get_memory(ctx_drft.get()), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx_drft.get()), slot.id, head_c, head_c + n_match, kv_shift); + if (ctx_dft) { + llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift); } for (size_t i = 0; i < n_match; i++) { @@ -2485,7 +2485,7 @@ private: const auto pos_min_thold = std::max(0, pos_next - n_swa); if (n_past > 0 && n_past < slot.prompt.n_tokens()) { - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_main), slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); if (pos_min == -1) { SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); @@ -2514,14 +2514,14 @@ private: { const auto token = slot.prompt.tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_main, token) : "[mtmd]"; + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]"; ss0 << piece; st0 << std::setw(8) << token; } { const auto token = slot.task->tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_main, token) : "[mtmd]"; + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]"; ss1 << piece; st1 << std::setw(8) << token; } @@ -2554,8 +2554,8 @@ private: if (!do_reset) { // restore the context checkpoint - it->load_main(ctx_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - it->load_drft(ctx_drft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); @@ -2616,7 +2616,7 @@ private: SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); - if (!llama_memory_seq_rm(llama_get_memory(ctx_main), slot.id, p0, -1)) { + if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) { SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); slot.prompt_clear(true); @@ -2624,7 +2624,7 @@ private: // there is no common part left slot.n_prompt_tokens_cache = 0; } else { - if (ctx_drft && !llama_memory_seq_rm(llama_get_memory(ctx_drft.get()), slot.id, p0, -1)) { + if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) { GGML_ABORT("failed to truncate draft context\n"); } } @@ -2653,7 +2653,7 @@ private: // - the model does not support partial sequence removal // - the model uses SWA (and we are not using `swa_full`) do_checkpoint = do_checkpoint && ( - (ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || (n_swa > 0)); bool has_mtmd = false; @@ -2662,7 +2662,7 @@ private: while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { // process the image size_t n_tokens_out = 0; - int32_t res = input_tokens.process_chunk(ctx_main, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + int32_t res = input_tokens.process_chunk(ctx_tgt, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); send_error(slot, "failed to process image", ERROR_TYPE_SERVER); @@ -2670,10 +2670,10 @@ private: continue; } - if (ctx_drft) { + if (ctx_dft) { // TODO: in the future, figure out how to infuse target embeddings to the images // for now, we skip this for simplicity - res = input_tokens.process_chunk(ctx_drft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); if (res != 0) { GGML_ABORT("failed to process multi-modal data on draft context\n"); } @@ -2780,8 +2780,8 @@ private: SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); } - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_main), slot.id); - const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_main), slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); // no need for empty or small checkpoints do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64); @@ -2814,7 +2814,7 @@ private: if (slot_batched) { // apply lora, only need to do it once per batch - common_set_adapter_lora(ctx_main, slot_batched->lora); + common_set_adapter_lora(ctx_tgt, slot_batched->lora); // if the lora is temporarily disabled for an alora, re-enable it // for next time @@ -2823,7 +2823,7 @@ private: slot_batched->lora[alora_disabled_id].scale = alora_scale; } - llama_set_embeddings(ctx_main, slot_batched->task->need_embd()); + llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd()); } if (batch.n_tokens == 0) { @@ -2852,7 +2852,7 @@ private: batch.logits + i, }; - const int ret = llama_decode(ctx_main, batch_view); + const int ret = llama_decode(ctx_tgt, batch_view); metrics.on_decoded(slots); @@ -2917,12 +2917,12 @@ private: // | Eagle3 | yes | // | DFlash | yes? | // - if (ctx_drft) { + if (ctx_dft) { // TODO: update as needed for MTP, Eagle3, etc. const bool need_tgt_embd = false; if (need_tgt_embd) { - llama_synchronize(ctx_main); + llama_synchronize(ctx_tgt); } // the logic here varies depending on the speculative decoding method @@ -2931,13 +2931,13 @@ private: // TODO: extract this in a function ? { // TODO: hook the embeddings from the last target batch here - if (llama_model_has_encoder(model_drft.get())) { - //llama_encode(ctx_drft, ...); + if (llama_model_has_encoder(model_dft.get())) { + //llama_encode(ctx_dft, ...); GGML_ABORT("not implemented yet\n"); } - const int ret = llama_decode(ctx_drft.get(), batch_view); + const int ret = llama_decode(ctx_dft.get(), batch_view); if (ret != 0) { SRV_ERR("failed to decode draft batch, ret = %d\n", ret); @@ -2952,7 +2952,7 @@ private: i_next = i + n_tokens; // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx_main); + n_batch = llama_n_batch(ctx_tgt); // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too for (auto & slot : slots) { @@ -3023,7 +3023,7 @@ private: const int tok_idx = slot.i_batch - i; - llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_main, tok_idx); + llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); slot.i_batch = -1; @@ -3044,7 +3044,7 @@ private: completion_token_output result; result.tok = id; - result.text_to_send = common_token_to_piece(slot.ctx_main, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs if (slot.task->params.sampling.n_probs > 0) { @@ -3075,23 +3075,23 @@ private: // verify and try to accept the draft { - const bool use_ckpt_main = ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; // only save the sampler sampler state if we use checkpoints common_sampler_ptr smpl_save; - if (use_ckpt_main) { + if (use_ckpt_tgt) { smpl_save.reset(common_sampler_clone(slot.smpl.get())); } GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); - auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_main, slot.spec_i_batch, slot.spec_draft); + auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); slot.spec_i_batch.clear(); GGML_ASSERT(accepted.size() >= 1); // check for partial draft acceptance if (accepted.size() < slot.spec_draft.size() + 1) { - if (use_ckpt_main) { + if (use_ckpt_tgt) { if (trace > 0) { SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); } @@ -3104,15 +3104,15 @@ private: SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); { - ckpt.load_main(slot.ctx_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - llama_memory_seq_rm(llama_get_memory(slot.ctx_main), slot.id, ckpt.pos_max + 1, -1); + llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1); } - if (slot.ctx_drft) { - ckpt.load_drft(slot.ctx_drft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + if (slot.ctx_dft) { + ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, ckpt.pos_max + 1, -1); + llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1); } slot.prompt.tokens.keep_first(ckpt.n_tokens); @@ -3148,16 +3148,16 @@ private: slot.sampled = ids.back(); // last accepted token SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - llama_memory_seq_rm(llama_get_memory(slot.ctx_main), slot.id, slot.prompt.tokens.pos_next(), -1); - if (slot.ctx_drft) { - llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, slot.prompt.tokens.pos_next(), -1); + llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1); + if (slot.ctx_dft) { + llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1); } for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; result.tok = ids[i]; - result.text_to_send = common_token_to_piece(slot.ctx_main, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // set later // TODO: set result.probs @@ -3209,7 +3209,7 @@ void server_context::terminate() { } llama_context * server_context::get_llama_context() const { - return impl->ctx_main; + return impl->ctx_tgt; } server_response_reader server_context::get_response_reader() { @@ -3219,8 +3219,8 @@ server_response_reader server_context::get_response_reader() { server_context_meta server_context::get_meta() const { auto bos_id = llama_vocab_bos(impl->vocab); auto eos_id = llama_vocab_eos(impl->vocab); - auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_main, bos_id, true) : ""; - auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_main, eos_id, true) : ""; + auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, bos_id, true) : ""; + auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, eos_id, true) : ""; return server_context_meta { /* build_info */ std::string(llama_build_info()), @@ -3233,7 +3233,7 @@ server_context_meta server_context::get_meta() const { /* has_inp_audio */ impl->chat_params.allow_audio, /* json_webui_settings */ impl->json_webui_settings, /* slot_n_ctx */ impl->get_slot_n_ctx(), - /* pooling_type */ llama_pooling_type(impl->ctx_main), + /* pooling_type */ llama_pooling_type(impl->ctx_tgt), /* chat_params */ impl->chat_params, /* chat_template_caps */ common_chat_templates_get_caps(impl->chat_params.tmpls.get()), @@ -3251,10 +3251,10 @@ server_context_meta server_context::get_meta() const { /* model_vocab_type */ llama_vocab_type(impl->vocab), /* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab), - /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model_main), - /* model_n_embd_inp */ llama_model_n_embd(impl->model_main), - /* model_n_params */ llama_model_n_params(impl->model_main), - /* model_size */ llama_model_size(impl->model_main), + /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model_tgt), + /* model_n_embd_inp */ llama_model_n_embd(impl->model_tgt), + /* model_n_params */ llama_model_n_params(impl->model_tgt), + /* model_size */ llama_model_size(impl->model_tgt), }; } @@ -4156,7 +4156,7 @@ void server_routes::init_routes() { std::vector tasks; tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { - auto tmp = format_prompt_rerank(ctx_server.model_main, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + auto tmp = format_prompt_rerank(ctx_server.model_tgt, ctx_server.vocab, ctx_server.mctx, query, documents[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = rd.get_new_id(); task.tokens = std::move(tmp); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 809ae345e9..2865e23f31 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -1981,7 +1981,7 @@ size_t server_prompt_cache::n_tokens() const { return res; } -server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft) { +server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) { // first check if the current state is contained fully in the cache for (auto it = states.begin(); it != states.end(); ++it) { const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); @@ -2005,13 +2005,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t } } - std::vector state_data_main; - std::vector state_data_drft; + std::vector state_data_tgt; + std::vector state_data_dft; // check if we can allocate enough memory for the new state try { - state_data_main.resize(state_size_main); - state_data_drft.resize(state_size_drft); + state_data_tgt.resize(state_size_tgt); + state_data_dft.resize(state_size_dft); } catch (const std::bad_alloc & e) { SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); @@ -2027,8 +2027,8 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t states.push_back({ /*.tokens =*/ prompt.tokens.clone(), /*.data =*/ { - /*.main =*/ std::move(state_data_main), - /*.drft =*/ std::move(state_data_drft), + /*.main =*/ std::move(state_data_tgt), + /*.drft =*/ std::move(state_data_dft), }, /*.checkpoints =*/ prompt.checkpoints, }); @@ -2036,7 +2036,7 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t return &states.back(); } -bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot) { +bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) { const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins @@ -2073,7 +2073,7 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok auto & data = it_best->data.main; const size_t size = data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx_main, data.data(), size, id_slot, 0); + const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0); if (n != size) { SRV_WRN("failed to restore state with size %zu\n", size); @@ -2088,10 +2088,10 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok auto & data = it_best->data.drft; if (!data.empty()) { - GGML_ASSERT(ctx_drft); + GGML_ASSERT(ctx_dft); const size_t size = data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx_drft, data.data(), size, id_slot, 0); + const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0); if (n != size) { SRV_WRN("failed to restore state with size %zu\n", size);