diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 79f0e1df75..dfe42ca64c 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -82,8 +82,6 @@ struct server_slot {
 
     llama_context * ctx = nullptr;
 
-    common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
-
     // multimodal
     mtmd_context * mctx = nullptr;
 
@@ -333,7 +331,7 @@ struct server_slot {
         return n_draft_max;
     }
 
-    void update_batch(llama_batch & batch) {
+    void update_batch(llama_batch & batch, bool use_ckpt) {
         const int n_draft_max = get_n_draft_max();
         if (n_draft_max > 0) {
             GGML_ASSERT(can_speculate());
@@ -347,7 +345,7 @@ struct server_slot {
 
             if (!spec_draft.empty()) {
                 // we have a previous (partial) draft to reuse
-                if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
+                if (use_ckpt) {
                     GGML_ASSERT(!spec_ckpt.empty());
                 }
             } else {
@@ -362,7 +360,7 @@ struct server_slot {
             spec_draft.resize(n_draft_max);
         }
 
-        if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
+        if (!spec_draft.empty() && use_ckpt) {
             const auto n_tokens = prompt.tokens.size();
 
             //const int64_t t_start = ggml_time_us();
@@ -676,6 +674,9 @@ private:
     llama_model_ptr model_dft;
     llama_context_ptr ctx_dft;
 
+    common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+    common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+
     bool add_bos_token = true;
 
     int32_t n_ctx; // total context for all clients / slots
@@ -806,9 +807,11 @@ private:
             auto cparams = common_context_params_to_llama(params_dft);
             ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
 
+            ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
+
             params_base.speculative.draft.ctx_tgt = ctx;
             params_base.speculative.draft.ctx_dft = ctx_dft.get();
-            params_base.speculative.draft.use_ckpt = common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
+            params_base.speculative.draft.use_ckpt = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
         }
 
         std::string & mmproj_path = params_base.mmproj.path;
@@ -883,7 +886,7 @@ private:
 
         slots.clear();
 
-        const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx);
+        ctx_seq_rm_type = common_context_can_seq_rm(ctx);
         if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
             SRV_WRN("%s", "speculative decoding not supported by this context\n");
         }
@@ -904,8 +907,6 @@ private:
             slot.ctx = ctx;
             slot.n_ctx = n_ctx_slot;
 
-            slot.ctx_seq_rm_type = ctx_seq_rm_type;
-
             slot.mctx = mctx;
             slot.prompt.tokens.has_mtmd = mctx != nullptr;
 
@@ -2254,7 +2255,7 @@ private:
                 continue;
             }
 
-            slot.update_batch(batch);
+            slot.update_batch(batch, ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
         }
 
         // process in chunks of params.n_batch
@@ -2615,7 +2616,7 @@ private:
                     // - the model does not support partial sequence removal
                     // - the model uses SWA (and we are not using `swa_full`)
                     do_checkpoint = do_checkpoint && (
-                            (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
+                            (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
                             (n_swa > 0));
 
                     bool has_mtmd = false;
@@ -2985,7 +2986,7 @@ private:
 
                 // verify and try to accept the draft
                 {
-                    const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
+                    const bool use_ckpt = ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
 
                     // only save the sampler state if we use checkpoints
                     common_sampler_ptr smpl_save;