cont : simplify

This commit is contained in:
Georgi Gerganov
2026-05-11 09:41:00 +03:00
parent c417ddfc74
commit c8f8e2364c
2 changed files with 163 additions and 208 deletions

@@ -391,22 +391,14 @@ struct common_speculative_state_mtp : public common_speculative_impl {
// The last h-row of one process() call needs the first token of the NEXT
// call to pair with, so it's stashed here until that next call fires.
std::vector<std::vector<float>> pending_h; // [n_seq][n_embd]
std::vector<llama_pos> pending_pos; // [n_seq]
std::vector<uint16_t> last_n_drafted;
std::vector<int32_t> last_n_accepted;
// Number of trunk output rows produced by the most recent process() call.
// Used by draft() for the first AR step (when last_n_accepted is -1) to
// pick the last prefill row out of ctx_tgt's pre-norm buffer.
std::vector<int32_t> last_trunk_n_outputs;
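// First and last batch index occupied by each sequence in the most recent
// process() batch; -1 if the sequence did not appear in it.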
std::vector<int32_t> i_batch_beg;
std::vector<int32_t> i_batch_end;
common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq)
, params(params.draft)
{
GGML_ASSERT(n_seq == 1 && "MTP currently supports only single-sequence speculation");
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
@@ -423,7 +415,7 @@ struct common_speculative_state_mtp : public common_speculative_impl {
for (auto & s : smpls) {
common_params_sampling sparams;
sparams.no_perf = false;
sparams.top_k = 1;
sparams.top_k = 10;
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
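// note: top_k > 1 keeps a candidate list around: data[0] provides the
// drafted token, while the candidate probabilities feed the p_min
// confidence gate and the debug logging in draft()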
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
}
@@ -432,11 +424,9 @@ struct common_speculative_state_mtp : public common_speculative_impl {
llama_set_embeddings_pre_norm(ctx_dft, true);
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
pending_pos.assign(n_seq, -1);
last_n_drafted.assign(n_seq, 0);
last_n_accepted.assign(n_seq, -1);
last_trunk_n_outputs.assign(n_seq, 0);
i_batch_beg.assign(n_seq, -1);
i_batch_end.assign(n_seq, -1);
}
~common_speculative_state_mtp() override {
@@ -448,12 +438,6 @@ struct common_speculative_state_mtp : public common_speculative_impl {
}
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < pending_pos.size());
last_n_accepted[seq_id] = -1;
last_n_drafted [seq_id] = 0;
pending_pos [seq_id] = -1;
const int32_t N = (int32_t) prompt.size();
if (N <= 0) {
return;
@@ -474,231 +458,207 @@ struct common_speculative_state_mtp : public common_speculative_impl {
return true;
}
// Single-seq for now (asserted in ctor). Future: bucket by seq_id.
const llama_seq_id seq_id = 0;
// TODO: how to make it work with vision tokens?
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
pending_pos[seq_id] = -1;
return true;
}
const int32_t n_tokens = batch_in.n_tokens;
// remember the first and last batch index for each sequence
std::fill(i_batch_beg.begin(), i_batch_beg.end(), -1);
std::fill(i_batch_end.begin(), i_batch_end.end(), -1);
for (int k = 0; k < n_tokens; ++k) {
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
GGML_ASSERT(batch_in.n_seq_id[k] == 1);
if (batch_in.seq_id[k][0] == seq_id) {
i_batch_end[seq_id] = k;
if (i_batch_beg[seq_id] < 0) {
i_batch_beg[seq_id] = k;
}
}
}
}
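// e.g. a batch with seq_id pattern [0, 0, 1, 1] yields
// i_batch_beg = {0, 2} and i_batch_end = {1, 3}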
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
const int32_t n_rows = batch_in.n_tokens;
const llama_pos pos_start = batch_in.pos[0];
const llama_pos pos_max_dft = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_start <= pos_max_dft) {
return true;
}
// Stale pending: discard if the new batch doesn't start one past it.
const bool pending_continues =
pending_pos[seq_id] >= 0 && pending_pos[seq_id] + 1 == pos_start;
if (pending_pos[seq_id] >= 0 && !pending_continues) {
pending_pos[seq_id] = -1;
}
// Build a paired hook batch:
// row 0 = (pending_h, batch_in.token[0]) at pos_start if pending_continues
// rows 1..n_rows-1 = (h_k from this batch, batch_in.token[k+1]) at pos[k+1]
// The last h-row (h_{n_rows-1}) is stashed as the new pending and is *not*
// decoded this call — it waits for the next batch's first token to pair.
const size_t row_bytes = (size_t) n_embd * sizeof(float);
common_batch_clear(batch);
int out_idx = 0;
auto add_pair = [&](const float * h_row, llama_token tok, llama_pos pos) {
std::memcpy(batch.embd + (size_t) out_idx * n_embd, h_row, row_bytes);
batch.token [out_idx] = tok;
batch.pos [out_idx] = pos;
batch.n_seq_id[out_idx] = 1;
batch.seq_id [out_idx][0] = seq_id;
batch.logits [out_idx] = 0;
++out_idx;
for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}
// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
//                                                       ^--- this is a problem
// TODO: this is generally true, but it would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
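// i.e. trunk rows [0, n_tokens-2] land in batch rows [1, n_tokens-1];
// each sequence's first row is overwritten from pending_h below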
//{
// // string with seq_ids in the batch
// std::stringstream ss;
// for (int i = 0; i < n_tokens; ++i) {
// ss << batch_in.seq_id[i][0] << ",";
// }
// LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
//}
}
// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};
if (pending_continues) {
add_pair(pending_h[seq_id].data(), batch_in.token[0], pos_start);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
// TODO: is there a fast way to build this batch?
for (int k = 0; k + 1 < n_rows; ++k) {
if (batch_in.logits[k] == 0) {
LOG_WRN("%s: batch_in.logits[%d] == 0 (need_embd / logits=1 missing on prefill); stopping hook at this row\n",
__func__, k);
break;
}
const float * h_k = llama_get_embeddings_pre_norm_ith(ctx_tgt, k);
if (h_k == nullptr) {
LOG_WRN("%s: ctx_tgt has no pre-norm row at i=%d; stopping hook\n", __func__, k);
break;
}
add_pair(h_k, batch_in.token[k + 1], batch_in.pos[k + 1]);
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
}
if (out_idx > 0) {
batch.n_tokens = out_idx;
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d, n=%d)\n",
__func__, (int) rc, (int) pos_start, out_idx);
return false;
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_end[seq_id] < 0) {
continue;
}
}
// Record the number of trunk output rows so that draft() (when
// last_n_accepted < 0) can find the last pre-norm row of this batch.
// We assume every batch position has logits=1 (server sets need_embd
// for MTP slots) → n_outputs == n_tokens.
last_trunk_n_outputs[seq_id] = n_rows;
// Stash the last h-row (h_{n_rows-1}) as the new pending for the next
// process() call's first token to pair with.
if (batch_in.logits[n_rows - 1] != 0) {
const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, n_rows - 1);
if (h_last != nullptr) {
std::memcpy(pending_h[seq_id].data(), h_last, row_bytes);
pending_pos[seq_id] = batch_in.pos[n_rows - 1];
} else {
pending_pos[seq_id] = -1;
}
} else {
// No trunk output at the tail — can't carry over.
pending_pos[seq_id] = -1;
const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_end[seq_id]);
std::memcpy(pending_h[seq_id].data(), h_last, row_bytes);
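// this tail row has no successor token in the current batch to pair with,
// so it waits here for the first token of the next process() call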
}
return true;
}
void draft(common_speculative_draft_params_vec & dparams) override {
// Single-seq for now (asserted in ctor). Future: iterate over dparams.
const llama_seq_id seq_id = 0;
if ((size_t) seq_id >= dparams.size()) {
return;
}
auto & dp = dparams[seq_id];
if (!dp.drafting) {
return;
}
auto & ctx_dft = params.ctx_dft;
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
auto * smpl = smpls[seq_id].get();
common_batch_clear(batch);
GGML_ASSERT(dp.result != nullptr);
auto & draft_tokens = *dp.result;
draft_tokens.clear();
if (last_n_drafted[seq_id] > 0) {
const int32_t n_to_drop = (int32_t) last_n_drafted[seq_id] - 1;
if (n_to_drop > 0) {
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max >= 0) {
const llama_pos drop_from = pos_max - n_to_drop + 1;
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
}
}
last_n_drafted[seq_id] = 0;
last_n_accepted[seq_id] = 0;
}
// Effective draft length: min(global cap, per-sequence override).
int32_t n_max = std::max(1, params.n_max);
if (dp.n_max > 0) {
n_max = std::min(n_max, dp.n_max);
}
// keep track of which sequences are still drafting
int n_drafting = 0;
std::vector<bool> drafting(n_seq);
const float * h_row = nullptr;
const size_t row_bytes = (size_t) n_embd * sizeof(float);
common_sampler_reset(smpl);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
llama_token cond_tok = dp.id_last;
llama_pos pos = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id) + 1;
for (int32_t k = 0; k < n_max; ++k) {
const float * h_row = nullptr;
if (k == 0) {
// Condition on the trunk's pre-norm row.
int32_t row_idx;
if (last_n_accepted[seq_id] < 0) {
// First draft after begin(): use the last prefill row.
row_idx = last_trunk_n_outputs[seq_id] - 1;
} else {
// After accept(n_accepted): row of the next conditioning
// position in the trunk's verify batch.
row_idx = last_n_accepted[seq_id];
}
if (row_idx < 0) {
LOG_WRN("%s: no trunk pre-norm row available (row_idx=%d); stopping chain\n",
__func__, row_idx);
break;
}
h_row = llama_get_embeddings_pre_norm_ith(ctx_tgt, row_idx);
} else {
// AR step: condition on the MTP head's own pre-norm row from
// the just-completed single-token decode. n_outputs=1 there,
// so the row is at batch position 0.
h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, 0);
if (!dp.drafting) {
continue;
}
if (h_row == nullptr) {
LOG_WRN("%s: missing pre-norm row at k=%d; stopping chain\n", __func__, k);
break;
}
n_drafting++;
drafting[seq_id] = true;
common_sampler_reset(smpls[seq_id].get());
// 1-token batch carrying both (token, h_pre_norm).
common_batch_clear(batch);
std::memcpy(batch.embd, h_row, row_bytes);
batch.token [0] = cond_tok;
batch.pos [0] = pos;
batch.n_seq_id[0] = 1;
batch.seq_id [0][0] = seq_id;
batch.logits [0] = 1; // need logits for sampling
batch.n_tokens = 1;
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_WRN("%s: llama_decode(ctx_dft) failed rc=%d at k=%d; stopping chain\n",
__func__, rc, k);
break;
}
const llama_token best = common_sampler_sample(smpl, ctx_dft, 0);
common_sampler_accept(smpl, best, /*is_generated=*/ false);
draft_tokens.push_back(best);
cond_tok = best;
++pos;
h_row = pending_h[seq_id].data();
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
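// seed the first AR step with the h-row stashed by process(); later steps
// swap in the MTP head's own pre-norm row (see the loop below)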
}
last_n_drafted[seq_id] = (uint16_t) draft_tokens.size();
}
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < last_n_drafted.size());
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
const int32_t n_drafted_last = (int32_t) last_n_drafted[seq_id];
const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);
if (pos_max < 0) {
last_n_accepted[seq_id] = (int32_t) n_accepted;
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
}
if (n_to_drop > 0) {
const llama_pos drop_from = pos_max - n_to_drop + 1;
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
int i = 0;
while (n_drafting > 0) {
int i_batch = 0;
common_batch_clear(batch);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (!drafting[seq_id]) {
continue;
}
auto * smpl = smpls[seq_id].get();
common_sampler_sample(smpl, ctx_dft, i_batch, true);
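// the MTP head's own pre-norm row from the decode that just completed
// becomes the conditioning input for this sequence's next AR step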
h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch);
++i_batch;
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
// add the drafted token for each sequence
const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
drafting[seq_id] = false;
n_drafting--;
continue;
}
common_sampler_accept(smpl, id, true);
auto & dp = dparams.at(seq_id);
auto & result = *dp.result;
result.push_back(id);
if ((params.n_max <= (int) result.size()) ||
(dp.n_max > 0 && dp.n_max <= (int) result.size())) {
drafting[seq_id] = false;
n_drafting--;
continue;
}
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
}
if (batch.n_tokens == 0) {
break;
}
// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
++i;
}
last_n_drafted [seq_id] = 0;
last_n_accepted[seq_id] = (int32_t) n_accepted;
for (auto & dp : dparams) {
if (!dp.drafting) {
continue;
}
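// a draft shorter than n_min is unlikely to amortize its verification
// cost on the target model, so discard it entirely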
if (dp.result->size() < (size_t) params.n_min) {
dp.result->clear();
}
}
}
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
}
};

@@ -778,11 +778,6 @@ private:
return false;
}
if (params_base.n_parallel > 1) {
SRV_ERR("MTP currently supports only n_parallel=1; got %d\n", params_base.n_parallel);
return false;
}
SRV_INF("loading MTP head from '%s' (override_arch=%s)\n",
params_base.model.path.c_str(), mtp_arch);