cont : dedup ctx_seq_rm_type

[no ci]
This commit is contained in:
Georgi Gerganov
2026-05-07 10:22:20 +03:00
parent 0791b0d95b
commit ae1f10b110

View File

@@ -82,8 +82,6 @@ struct server_slot {
llama_context * ctx = nullptr;
common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
// multimodal
mtmd_context * mctx = nullptr;
@@ -333,7 +331,7 @@ struct server_slot {
return n_draft_max;
}
void update_batch(llama_batch & batch) {
void update_batch(llama_batch & batch, bool use_ckpt) {
const int n_draft_max = get_n_draft_max();
if (n_draft_max > 0) {
GGML_ASSERT(can_speculate());
@@ -347,7 +345,7 @@ struct server_slot {
if (!spec_draft.empty()) {
// we have a previous (partial) draft to reuse
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
if (use_ckpt) {
GGML_ASSERT(!spec_ckpt.empty());
}
} else {
@@ -362,7 +360,7 @@ struct server_slot {
spec_draft.resize(n_draft_max);
}
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
if (!spec_draft.empty() && use_ckpt) {
const auto n_tokens = prompt.tokens.size();
//const int64_t t_start = ggml_time_us();
@@ -676,6 +674,9 @@ private:
llama_model_ptr model_dft;
llama_context_ptr ctx_dft;
common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
bool add_bos_token = true;
int32_t n_ctx; // total context for all clients / slots
@@ -806,9 +807,11 @@ private:
auto cparams = common_context_params_to_llama(params_dft);
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
params_base.speculative.draft.use_ckpt = common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
params_base.speculative.draft.use_ckpt = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
}
std::string & mmproj_path = params_base.mmproj.path;
@@ -883,7 +886,7 @@ private:
slots.clear();
const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx);
ctx_seq_rm_type = common_context_can_seq_rm(ctx);
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
@@ -904,8 +907,6 @@ private:
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
slot.ctx_seq_rm_type = ctx_seq_rm_type;
slot.mctx = mctx;
slot.prompt.tokens.has_mtmd = mctx != nullptr;
@@ -2254,7 +2255,7 @@ private:
continue;
}
slot.update_batch(batch);
slot.update_batch(batch, ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
}
// process in chunks of params.n_batch
@@ -2615,7 +2616,7 @@ private:
// - the model does not support partial sequence removal
// - the model uses SWA (and we are not using `swa_full`)
do_checkpoint = do_checkpoint && (
(slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
(ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
(n_swa > 0));
bool has_mtmd = false;
@@ -2985,7 +2986,7 @@ private:
// verify and try to accept the draft
{
const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
const bool use_ckpt = ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
// only save the sampler state if we use checkpoints
common_sampler_ptr smpl_save;