From c8f8e2364c90d5e5fd97cc8e20bb4802ba59834a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 11 May 2026 09:41:00 +0300
Subject: [PATCH] cont : simplify

---
 common/speculative.cpp          | 366 ++++++++++++++------------------
 tools/server/server-context.cpp |   5 -
 2 files changed, 163 insertions(+), 208 deletions(-)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index ef13edd34e..3b36e04bda 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -391,22 +391,14 @@ struct common_speculative_state_mtp : public common_speculative_impl {
     // The last h-row of one process() call needs the first token of the NEXT
     // call to pair with, so it's stashed here until that next call fires.
     std::vector<std::vector<float>> pending_h; // [n_seq][n_embd]
-    std::vector<llama_pos> pending_pos;        // [n_seq]
-    std::vector<uint16_t>  last_n_drafted;
-    std::vector<int32_t>   last_n_accepted;
-
-    // Number of trunk output rows produced by the most recent process() call.
-    // Used by draft() for the first AR step (when last_n_accepted is -1) to
-    // pick the last prefill row out of ctx_tgt's pre-norm buffer.
-    std::vector<int32_t> last_trunk_n_outputs;
+    std::vector<int32_t> i_batch_beg;
+    std::vector<int32_t> i_batch_end;

     common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq)
         : common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq)
         , params(params.draft) {
-        GGML_ASSERT(n_seq == 1 && "MTP currently supports only single-sequence speculation");
-
         auto * ctx_tgt = this->params.ctx_tgt;
         auto * ctx_dft = this->params.ctx_dft;
         GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
@@ -423,7 +415,7 @@ struct common_speculative_state_mtp : public common_speculative_impl {
         for (auto & s : smpls) {
             common_params_sampling sparams;
             sparams.no_perf = false;
-            sparams.top_k   = 1;
+            sparams.top_k   = 10;
             sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
             s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
         }
@@ -432,11 +424,9 @@ struct common_speculative_state_mtp : public common_speculative_impl {
         llama_set_embeddings_pre_norm(ctx_dft, true);

         pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
-        pending_pos.assign(n_seq, -1);
-        last_n_drafted.assign(n_seq, 0);
-        last_n_accepted.assign(n_seq, -1);
-        last_trunk_n_outputs.assign(n_seq, 0);
+
+        i_batch_beg.assign(n_seq, -1);
+        i_batch_end.assign(n_seq, -1);
     }

     ~common_speculative_state_mtp() override {
@@ -448,12 +438,6 @@ struct common_speculative_state_mtp : public common_speculative_impl {
     }

     void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
-        GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < pending_pos.size());
-
-        last_n_accepted[seq_id] = -1;
-        last_n_drafted [seq_id] = 0;
-        pending_pos    [seq_id] = -1;
-
         const int32_t N = (int32_t) prompt.size();
         if (N <= 0) {
             return;
@@ -474,231 +458,207 @@ struct common_speculative_state_mtp : public common_speculative_impl {
             return true;
         }

-        // Single-seq for now (asserted in ctor). Future: bucket by seq_id.
-        const llama_seq_id seq_id = 0;
-
         // TODO: how to make it work with vision tokens?
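+        // MTP conditions each draft step on a (token, pre-norm hidden state)
+        // pair, so batches that carry raw embeddings instead of token ids
+        // cannot be paired with a trunk h-row and are skipped here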
        if (batch_in.token == nullptr || batch_in.embd != nullptr) {
-            pending_pos[seq_id] = -1;
             return true;
         }

+        const int32_t n_tokens = batch_in.n_tokens;
+
+        // remember the first and last batch index for each sequence
+        std::fill(i_batch_beg.begin(), i_batch_beg.end(), -1);
+        std::fill(i_batch_end.begin(), i_batch_end.end(), -1);
+
+        for (int k = 0; k < n_tokens; ++k) {
+            for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
+                GGML_ASSERT(batch_in.n_seq_id[k] == 1);
+
+                if (batch_in.seq_id[k][0] == seq_id) {
+                    i_batch_end[seq_id] = k;
+                    if (i_batch_beg[seq_id] < 0) {
+                        i_batch_beg[seq_id] = k;
+                    }
+                }
+            }
+        }
+
         auto * ctx_tgt = this->params.ctx_tgt;
         auto * ctx_dft = this->params.ctx_dft;

-        const int32_t   n_rows    = batch_in.n_tokens;
-        const llama_pos pos_start = batch_in.pos[0];
-
-        const llama_pos pos_max_dft = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
-        if (pos_start <= pos_max_dft) {
-            return true;
-        }
-
-        // Stale pending: discard if the new batch doesn't start one past it.
-        const bool pending_continues =
-            pending_pos[seq_id] >= 0 && pending_pos[seq_id] + 1 == pos_start;
-        if (pending_pos[seq_id] >= 0 && !pending_continues) {
-            pending_pos[seq_id] = -1;
-        }
-
-        // Build a paired hook batch:
-        //   row 0 = (pending_h, batch_in.token[0]) at pos_start if pending_continues
-        //   rows 1..n_rows-1 = (h_k from this batch, batch_in.token[k+1]) at pos[k+1]
-        // The last h-row (h_{n_rows-1}) is stashed as the new pending and is *not*
-        // decoded this call — it waits for the next batch's first token to pair.
         const size_t row_bytes = (size_t) n_embd * sizeof(float);

         common_batch_clear(batch);

-        int out_idx = 0;
-        auto add_pair = [&](const float * h_row, llama_token tok, llama_pos pos) {
-            std::memcpy(batch.embd + (size_t) out_idx * n_embd, h_row, row_bytes);
-            batch.token   [out_idx] = tok;
-            batch.pos     [out_idx] = pos;
-            batch.n_seq_id[out_idx] = 1;
-            batch.seq_id  [out_idx][0] = seq_id;
-            batch.logits  [out_idx] = 0;
-            ++out_idx;
+        for (int k = 0; k < n_tokens; ++k) {
+            common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
+        }
+
+        // shift the tgt embeddings to the right by one position
+        // assumes that the tokens in the batch are sequential for each sequence
+        // i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
+        //                                                       ^--- this is a problem
+        // TODO: this is generally true, but would be nice to assert it
+        {
+            const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
+            std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));

+            //{
+            //    // string with seq_ids in the batch
+            //    std::stringstream ss;
+            //    for (int i = 0; i < n_tokens; ++i) {
+            //        ss << batch_in.seq_id[i][0] << ",";
+            //    }
+            //    LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
+            //}
+        }
+
+        // fill the pending embeddings from a previous run
+        auto set_h = [&](int idx, const float * h_row) {
+            std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
         };

-        if (pending_continues) {
-            add_pair(pending_h[seq_id].data(), batch_in.token[0], pos_start);
+        for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
+            if (i_batch_beg[seq_id] < 0) {
+                continue;
+            }
+
+            set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
         }

-        // TODO: is there a fast way to build this batch?
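+        // decode the paired (token, shifted h) rows through the MTP head in a
+        // single call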
-        for (int k = 0; k + 1 < n_rows; ++k) {
-            if (batch_in.logits[k] == 0) {
-                LOG_WRN("%s: batch_in.logits[%d] == 0 (need_embd / logits=1 missing on prefill); stopping hook at this row\n",
-                        __func__, k);
-                break;
-            }
-            const float * h_k = llama_get_embeddings_pre_norm_ith(ctx_tgt, k);
-            if (h_k == nullptr) {
-                LOG_WRN("%s: ctx_tgt has no pre-norm row at i=%d; stopping hook\n", __func__, k);
-                break;
-            }
-            add_pair(h_k, batch_in.token[k + 1], batch_in.pos[k + 1]);
+        const int32_t rc = llama_decode(ctx_dft, batch);
+        if (rc != 0) {
+            LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
+            return false;
         }

-        if (out_idx > 0) {
-            batch.n_tokens = out_idx;
-            const int32_t rc = llama_decode(ctx_dft, batch);
-            if (rc != 0) {
-                LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d, n=%d)\n",
-                        __func__, (int) rc, (int) pos_start, out_idx);
-                return false;
+        for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
+            if (i_batch_end[seq_id] < 0) {
+                continue;
             }
-        }

-        // Record the trunk output count so that draft() (when
-        // last_n_accepted < 0) can find the last pre-norm row of this batch.
-        // We assume every batch position has logits=1 (server sets need_embd
-        // for MTP slots) → n_outputs == n_tokens.
-        last_trunk_n_outputs[seq_id] = n_rows;
-
-        // Stash the last h-row (h_{n_rows-1}) as the new pending for the next
-        // process() call's first token to pair with.
-        if (batch_in.logits[n_rows - 1] != 0) {
-            const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, n_rows - 1);
-            if (h_last != nullptr) {
-                std::memcpy(pending_h[seq_id].data(), h_last, row_bytes);
-                pending_pos[seq_id] = batch_in.pos[n_rows - 1];
-            } else {
-                pending_pos[seq_id] = -1;
-            }
-        } else {
-            // No trunk output at the tail — can't carry over.
-            pending_pos[seq_id] = -1;
+            const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_end[seq_id]);
+            std::memcpy(pending_h[seq_id].data(), h_last, row_bytes);
         }

         return true;
     }

     void draft(common_speculative_draft_params_vec & dparams) override {
-        // Single-seq for now (asserted in ctor). Future: iterate over dparams.
-        const llama_seq_id seq_id = 0;
-        if ((size_t) seq_id >= dparams.size()) {
-            return;
-        }
-        auto & dp = dparams[seq_id];
-        if (!dp.drafting) {
-            return;
-        }
+        auto & ctx_dft = params.ctx_dft;

-        auto * ctx_tgt = this->params.ctx_tgt;
-        auto * ctx_dft = this->params.ctx_dft;
-        auto * smpl    = smpls[seq_id].get();
+        common_batch_clear(batch);

-        GGML_ASSERT(dp.result != nullptr);
-        auto & draft_tokens = *dp.result;
-        draft_tokens.clear();
-
-        if (last_n_drafted[seq_id] > 0) {
-            const int32_t n_to_drop = (int32_t) last_n_drafted[seq_id] - 1;
-            if (n_to_drop > 0) {
-                const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
-                if (pos_max >= 0) {
-                    const llama_pos drop_from = pos_max - n_to_drop + 1;
-                    llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
-                }
-            }
-            last_n_drafted[seq_id] = 0;
-            last_n_accepted[seq_id] = 0;
-        }
-
-        // Effective draft length: min(global cap, per-sequence override).
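+        // draft for all sequences in lockstep: seed one (id_last, pending_h)
+        // row per drafting sequence, then decode one batch per AR step until
+        // every sequence hits its confidence (p_min) or length (n_max) cutoff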
-        int32_t n_max = std::max(1, params.n_max);
-        if (dp.n_max > 0) {
-            n_max = std::min(n_max, dp.n_max);
-        }
+        // keep track of which sequences are still drafting
+        int n_drafting = 0;
+        std::vector<bool> drafting(n_seq);

+        const float * h_row = nullptr;
         const size_t row_bytes = (size_t) n_embd * sizeof(float);

-        common_sampler_reset(smpl);
+        for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
+            auto & dp = dparams[seq_id];

-        llama_token cond_tok = dp.id_last;
-        llama_pos   pos      = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id) + 1;
-
-        for (int32_t k = 0; k < n_max; ++k) {
-            const float * h_row = nullptr;
-
-            if (k == 0) {
-                // Condition on the trunk's pre-norm row.
-                int32_t row_idx;
-                if (last_n_accepted[seq_id] < 0) {
-                    // First draft after begin(): use the last prefill row.
-                    row_idx = last_trunk_n_outputs[seq_id] - 1;
-                } else {
-                    // After accept(n_accepted): row of the next conditioning
-                    // position in the trunk's verify batch.
-                    row_idx = last_n_accepted[seq_id];
-                }
-                if (row_idx < 0) {
-                    LOG_WRN("%s: no trunk pre-norm row available (row_idx=%d); stopping chain\n",
-                            __func__, row_idx);
-                    break;
-                }
-                h_row = llama_get_embeddings_pre_norm_ith(ctx_tgt, row_idx);
-            } else {
-                // AR step: condition on the MTP head's own pre-norm row from
-                // the just-completed single-token decode. n_outputs=1 there,
-                // so the row is at batch position 0.
-                h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, 0);
+            if (!dp.drafting) {
+                continue;
             }

-            if (h_row == nullptr) {
-                LOG_WRN("%s: missing pre-norm row at k=%d; stopping chain\n", __func__, k);
-                break;
-            }
+            n_drafting++;
+            drafting[seq_id] = true;
+            common_sampler_reset(smpls[seq_id].get());

-            // 1-token batch carrying both (token, h_pre_norm).
-            common_batch_clear(batch);
-            std::memcpy(batch.embd, h_row, row_bytes);
-            batch.token   [0] = cond_tok;
-            batch.pos     [0] = pos;
-            batch.n_seq_id[0] = 1;
-            batch.seq_id  [0][0] = seq_id;
-            batch.logits  [0] = 1; // need logits for sampling
-            batch.n_tokens    = 1;
+            common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);

-            const int32_t rc = llama_decode(ctx_dft, batch);
-            if (rc != 0) {
-                LOG_WRN("%s: llama_decode(ctx_dft) failed rc=%d at k=%d; stopping chain\n",
-                        __func__, rc, k);
-                break;
-            }
-
-            const llama_token best = common_sampler_sample(smpl, ctx_dft, 0);
-            common_sampler_accept(smpl, best, /*is_generated=*/ false);
-            draft_tokens.push_back(best);
-            cond_tok = best;
-            ++pos;
+            h_row = pending_h[seq_id].data();
+            std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
         }

-        last_n_drafted[seq_id] = (uint16_t) draft_tokens.size();
-    }
-
-    void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
-        GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < last_n_drafted.size());
-
-        auto * ctx_dft = this->params.ctx_dft;
-
-        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
-        const int32_t n_drafted_last = (int32_t) last_n_drafted[seq_id];
-
-        const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);
-
-        if (pos_max < 0) {
-            last_n_accepted[seq_id] = (int32_t) n_accepted;
+        int ret = llama_decode(ctx_dft, batch);
+        if (ret != 0) {
+            LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
             return;
         }

-        if (n_to_drop > 0) {
-            const llama_pos drop_from = pos_max - n_to_drop + 1;
-            llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
+        int i = 0;
+
+        while (n_drafting > 0) {
+            int i_batch = 0;
+
+            common_batch_clear(batch);
+
+            for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
+                if (!drafting[seq_id]) {
+                    continue;
+                }
+
+                auto * smpl = smpls[seq_id].get();
+
+                common_sampler_sample(smpl, ctx_dft, i_batch, true);
+                h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch);
+                ++i_batch;
+
+                const auto * cur_p = common_sampler_get_candidates(smpl, true);
+
+                for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+                    LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                            seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
+                            common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
+                }
+
+                // add drafted token for each sequence
+                const llama_token id = cur_p->data[0].id;
+
+                // only collect very high-confidence draft tokens
+                if (cur_p->data[0].p < params.p_min) {
+                    drafting[seq_id] = false;
+                    n_drafting--;
+
+                    continue;
+                }
+
+                common_sampler_accept(smpl, id, true);
+
+                auto & dp = dparams.at(seq_id);
+                auto & result = *dp.result;
+
+                result.push_back(id);
+
+                if ((params.n_max <= (int) result.size()) ||
+                    (dp.n_max > 0 && dp.n_max <= (int) result.size())) {
+                    drafting[seq_id] = false;
+                    n_drafting--;
+                    continue;
+                }
+
+                common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
+                std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
+            }
+
+            if (batch.n_tokens == 0) {
+                break;
+            }
+
+            // evaluate the drafted tokens on the draft model
+            ret = llama_decode(ctx_dft, batch);
+            if (ret != 0) {
+                LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
+                break;
+            }
+
+            ++i;
         }

-        last_n_drafted [seq_id] = 0;
-        last_n_accepted[seq_id] = (int32_t) n_accepted;
+        for (auto & dp : dparams) {
+            if (!dp.drafting) {
+                continue;
+            }
+
+            if (dp.result->size() < (size_t) params.n_min) {
+                dp.result->clear();
+            }
+        }
+    }
+
+    void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
     }
 };
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 8561c83faf..f7219e687f 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -778,11 +778,6 @@ private:
             return false;
         }

-        if (params_base.n_parallel > 1) {
-            SRV_ERR("MTP currently supports only n_parallel=1; got %d\n", params_base.n_parallel);
-            return false;
-        }
-
         SRV_INF("loading MTP head from '%s' (override_arch=%s)\n",
                 params_base.model.path.c_str(), mtp_arch);