mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-13 20:44:09 +00:00
cont : simplify
@@ -391,22 +391,14 @@ struct common_speculative_state_mtp : public common_speculative_impl {
    // The last h-row of one process() call needs the first token of the NEXT
    // call to pair with, so it's stashed here until that next call fires.
    std::vector<std::vector<float>> pending_h; // [n_seq][n_embd]
    std::vector<llama_pos> pending_pos; // [n_seq]

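    // Illustrative timeline of the stash (a sketch, not part of this patch;
    // hypothetical positions, single sequence):
    //   process(pos 0..9):   h_0..h_8 pair with tokens t_1..t_9; h_9 has no
    //                        partner yet -> stashed, pending_pos = 9
    //   process(pos 10..15): row 0 pairs (stashed h_9, t_10), then (h_10, t_11), ...
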
    std::vector<uint16_t> last_n_drafted;
    std::vector<int32_t> last_n_accepted;

    // Number of trunk output rows produced by the most recent process() call.
    // Used by draft() for the first AR step (when last_n_accepted is -1) to
    // pick the last prefill row out of ctx_tgt's pre-norm buffer.
    std::vector<int32_t> last_trunk_n_outputs;
    std::vector<int32_t> i_batch_beg;
    std::vector<int32_t> i_batch_end;

    common_speculative_state_mtp(const common_params_speculative & params, uint32_t n_seq)
        : common_speculative_impl(COMMON_SPECULATIVE_TYPE_MTP, n_seq)
        , params(params.draft)
    {
        GGML_ASSERT(n_seq == 1 && "MTP currently supports only single-sequence speculation");

        auto * ctx_tgt = this->params.ctx_tgt;
        auto * ctx_dft = this->params.ctx_dft;
        GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
@@ -423,7 +415,7 @@ struct common_speculative_state_mtp : public common_speculative_impl {
        for (auto & s : smpls) {
            common_params_sampling sparams;
            sparams.no_perf = false;
            sparams.top_k = 1;
            sparams.top_k = 10;
            sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
            s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
        }
@@ -432,11 +424,9 @@ struct common_speculative_state_mtp : public common_speculative_impl {
        llama_set_embeddings_pre_norm(ctx_dft, true);

        pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
        pending_pos.assign(n_seq, -1);

        last_n_drafted.assign(n_seq, 0);
        last_n_accepted.assign(n_seq, -1);
        last_trunk_n_outputs.assign(n_seq, 0);
        i_batch_beg.assign(n_seq, -1);
        i_batch_end.assign(n_seq, -1);
    }

    ~common_speculative_state_mtp() override {
@@ -448,12 +438,6 @@ struct common_speculative_state_mtp : public common_speculative_impl {
    }

    void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
        GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < pending_pos.size());

        last_n_accepted[seq_id] = -1;
        last_n_drafted [seq_id] = 0;
        pending_pos    [seq_id] = -1;

        const int32_t N = (int32_t) prompt.size();
        if (N <= 0) {
            return;
@@ -474,231 +458,207 @@ struct common_speculative_state_mtp : public common_speculative_impl {
            return true;
        }

        // Single-seq for now (asserted in ctor). Future: bucket by seq_id.
        const llama_seq_id seq_id = 0;

        // TODO: how to make it work with vision tokens?
        if (batch_in.token == nullptr || batch_in.embd != nullptr) {
            pending_pos[seq_id] = -1;
            return true;
        }

        const int32_t n_tokens = batch_in.n_tokens;

        // remember the first and last batch index for each sequence
        std::fill(i_batch_beg.begin(), i_batch_beg.end(), -1);
        std::fill(i_batch_end.begin(), i_batch_end.end(), -1);

        for (int k = 0; k < n_tokens; ++k) {
            for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
                GGML_ASSERT(batch_in.n_seq_id[k] == 1);

                if (batch_in.seq_id[k][0] == seq_id) {
                    i_batch_end[seq_id] = k;
                    if (i_batch_beg[seq_id] < 0) {
                        i_batch_beg[seq_id] = k;
                    }
                }
            }
        }

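        // Illustrative sketch (hypothetical batch, not part of this patch): with
        // n_tokens = 5 and per-row seq ids [0, 0, 0, 1, 1], the scan above yields
        // i_batch_beg = {0, 3} and i_batch_end = {2, 4}.
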
        auto * ctx_tgt = this->params.ctx_tgt;
        auto * ctx_dft = this->params.ctx_dft;

        const int32_t n_rows = batch_in.n_tokens;
        const llama_pos pos_start = batch_in.pos[0];

        const llama_pos pos_max_dft = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
        if (pos_start <= pos_max_dft) {
            return true;
        }

        // Stale pending: discard if the new batch doesn't start one past it.
        const bool pending_continues =
            pending_pos[seq_id] >= 0 && pending_pos[seq_id] + 1 == pos_start;
        if (pending_pos[seq_id] >= 0 && !pending_continues) {
            pending_pos[seq_id] = -1;
        }

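        // Example with hypothetical numbers (not part of this patch): a previous
        // call stashed pending_pos = 9; a batch with pos_start = 10 continues the
        // stash, while pos_start = 12 marks it stale and it is discarded above.
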
        // Build a paired hook batch:
        //   row 0            = (pending_h, batch_in.token[0]) at pos_start if pending_continues
        //   rows 1..n_rows-1 = (h_k from this batch, batch_in.token[k+1]) at pos[k+1]
        // The last h-row (h_{n_rows-1}) is stashed as the new pending and is *not*
        // decoded this call — it waits for the next batch's first token to pair.
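        // Worked example (a sketch with hypothetical values): n_rows = 4, tokens
        // t_0..t_3, trunk rows h_0..h_3, pending_continues == true:
        //   out row 0 -> (pending_h, t_0)    out row 2 -> (h_1, t_2)
        //   out row 1 -> (h_0, t_1)          out row 3 -> (h_2, t_3)
        // h_3 becomes the new pending and is not decoded this call.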
        const size_t row_bytes = (size_t) n_embd * sizeof(float);

        common_batch_clear(batch);
        int out_idx = 0;

        auto add_pair = [&](const float * h_row, llama_token tok, llama_pos pos) {
            std::memcpy(batch.embd + (size_t) out_idx * n_embd, h_row, row_bytes);
            batch.token [out_idx] = tok;
            batch.pos   [out_idx] = pos;
            batch.n_seq_id[out_idx] = 1;
            batch.seq_id [out_idx][0] = seq_id;
            batch.logits [out_idx] = 0;
            ++out_idx;
        for (int k = 0; k < n_tokens; ++k) {
            common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
        }

        // shift the tgt embeddings to the right by one position
        // assumes that the tokens in the batch are sequential for each sequence
        // i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
        //                                                       ^--- this is a problem
        // TODO: this is generally true, but would be nice to assert it
        {
            const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
            std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));

            //{
            //    // string with seq_ids in the batch
            //    std::stringstream ss;
            //    for (int i = 0; i < n_tokens; ++i) {
            //        ss << batch_in.seq_id[i][0] << ",";
            //    }
            //    LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
            //}
        }

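        // Illustrative layout of the shift above (hypothetical n_tokens = 4, not
        // part of this patch): batch.embd row k holds trunk row h_{k-1}:
        //   batch.embd = [ <row 0: pending, filled below>, h_0, h_1, h_2 ]
        // so each token t_k is decoded next to the hidden state of position k-1.
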
        // fill the pending embeddings from a previous run
        auto set_h = [&](int idx, const float * h_row) {
            std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
        };

        if (pending_continues) {
            add_pair(pending_h[seq_id].data(), batch_in.token[0], pos_start);
        for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
            if (i_batch_beg[seq_id] < 0) {
                continue;
            }

            set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
        }

        // TODO: is there a fast way to build this batch?
        for (int k = 0; k + 1 < n_rows; ++k) {
            if (batch_in.logits[k] == 0) {
                LOG_WRN("%s: batch_in.logits[%d] == 0 (need_embd / logits=1 missing on prefill); stopping hook at this row\n",
                        __func__, k);
                break;
            }
            const float * h_k = llama_get_embeddings_pre_norm_ith(ctx_tgt, k);
            if (h_k == nullptr) {
                LOG_WRN("%s: ctx_tgt has no pre-norm row at i=%d; stopping hook\n", __func__, k);
                break;
            }
            add_pair(h_k, batch_in.token[k + 1], batch_in.pos[k + 1]);
        const int32_t rc = llama_decode(ctx_dft, batch);
        if (rc != 0) {
            LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
            return false;
        }

        if (out_idx > 0) {
            batch.n_tokens = out_idx;
            const int32_t rc = llama_decode(ctx_dft, batch);
            if (rc != 0) {
                LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d, n=%d)\n",
                        __func__, (int) rc, (int) pos_start, out_idx);
                return false;
        for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
            if (i_batch_end[seq_id] < 0) {
                continue;
            }
        }

        // Record how many trunk output rows this call produced so that draft()
        // (when last_n_accepted < 0) can find the last pre-norm row of this batch.
        // We assume every batch position has logits=1 (server sets need_embd
        // for MTP slots) → n_outputs == n_tokens.
        last_trunk_n_outputs[seq_id] = n_rows;

        // Stash the last h-row (h_{n_rows-1}) as the new pending for the next
        // process() call's first token to pair with.
        if (batch_in.logits[n_rows - 1] != 0) {
            const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, n_rows - 1);
            if (h_last != nullptr) {
                std::memcpy(pending_h[seq_id].data(), h_last, row_bytes);
                pending_pos[seq_id] = batch_in.pos[n_rows - 1];
            } else {
                pending_pos[seq_id] = -1;
            }
        } else {
            // No trunk output at the tail — can't carry over.
            pending_pos[seq_id] = -1;
            const float * h_last = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_end[seq_id]);
            std::memcpy(pending_h[seq_id].data(), h_last, row_bytes);
        }

        return true;
    }

    void draft(common_speculative_draft_params_vec & dparams) override {
        // Single-seq for now (asserted in ctor). Future: iterate over dparams.
        const llama_seq_id seq_id = 0;
        if ((size_t) seq_id >= dparams.size()) {
            return;
        }
        auto & dp = dparams[seq_id];
        if (!dp.drafting) {
            return;
        }
        auto & ctx_dft = params.ctx_dft;

        auto * ctx_tgt = this->params.ctx_tgt;
        auto * ctx_dft = this->params.ctx_dft;
        auto * smpl = smpls[seq_id].get();
        common_batch_clear(batch);

        GGML_ASSERT(dp.result != nullptr);
        auto & draft_tokens = *dp.result;
        draft_tokens.clear();

        if (last_n_drafted[seq_id] > 0) {
            const int32_t n_to_drop = (int32_t) last_n_drafted[seq_id] - 1;
            if (n_to_drop > 0) {
                const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
                if (pos_max >= 0) {
                    const llama_pos drop_from = pos_max - n_to_drop + 1;
                    llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
                }
            }
            last_n_drafted[seq_id] = 0;
            last_n_accepted[seq_id] = 0;
        }

        // Effective draft length: min(global cap, per-sequence override).
        int32_t n_max = std::max(1, params.n_max);
        if (dp.n_max > 0) {
            n_max = std::min(n_max, dp.n_max);
        }
        // keep track of which sequences are still drafting
        int n_drafting = 0;
        std::vector<bool> drafting(n_seq);

        const float * h_row = nullptr;
        const size_t row_bytes = (size_t) n_embd * sizeof(float);

        common_sampler_reset(smpl);
        for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
            auto & dp = dparams[seq_id];

            llama_token cond_tok = dp.id_last;
            llama_pos pos = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id) + 1;

            for (int32_t k = 0; k < n_max; ++k) {
                const float * h_row = nullptr;

                if (k == 0) {
                    // Condition on the trunk's pre-norm row.
                    int32_t row_idx;
                    if (last_n_accepted[seq_id] < 0) {
                        // First draft after begin(): use the last prefill row.
                        row_idx = last_trunk_n_outputs[seq_id] - 1;
                    } else {
                        // After accept(n_accepted): row of the next conditioning
                        // position in the trunk's verify batch.
                        row_idx = last_n_accepted[seq_id];
                    }
                    if (row_idx < 0) {
                        LOG_WRN("%s: no trunk pre-norm row available (row_idx=%d); stopping chain\n",
                                __func__, row_idx);
                        break;
                    }
                    h_row = llama_get_embeddings_pre_norm_ith(ctx_tgt, row_idx);
                } else {
                    // AR step: condition on the MTP head's own pre-norm row from
                    // the just-completed single-token decode. n_outputs=1 there,
                    // so the row is at batch position 0.
                    h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, 0);
            if (!dp.drafting) {
                continue;
            }

                if (h_row == nullptr) {
                    LOG_WRN("%s: missing pre-norm row at k=%d; stopping chain\n", __func__, k);
                    break;
                }
            n_drafting++;
            drafting[seq_id] = true;
            common_sampler_reset(smpls[seq_id].get());

                // 1-token batch carrying both (token, h_pre_norm).
                common_batch_clear(batch);
                std::memcpy(batch.embd, h_row, row_bytes);
                batch.token [0] = cond_tok;
                batch.pos   [0] = pos;
                batch.n_seq_id[0] = 1;
                batch.seq_id [0][0] = seq_id;
                batch.logits [0] = 1; // need logits for sampling
                batch.n_tokens = 1;
            common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);

                const int32_t rc = llama_decode(ctx_dft, batch);
                if (rc != 0) {
                    LOG_WRN("%s: llama_decode(ctx_dft) failed rc=%d at k=%d; stopping chain\n",
                            __func__, rc, k);
                    break;
                }

                const llama_token best = common_sampler_sample(smpl, ctx_dft, 0);
                common_sampler_accept(smpl, best, /*is_generated=*/ false);
                draft_tokens.push_back(best);
                cond_tok = best;
                ++pos;
            h_row = pending_h[seq_id].data();
            std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
            }

        last_n_drafted[seq_id] = (uint16_t) draft_tokens.size();
    }

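    // Shape of the resulting draft chain (a sketch, not part of this patch):
    //   k = 0: condition on a trunk (ctx_tgt) pre-norm row  -> sample d_0
    //   k > 0: condition on ctx_dft's own row from step k-1 -> sample d_k
    // i.e. one extra single-token decode of the MTP head per drafted token.
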
    void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
        GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < last_n_drafted.size());

        auto * ctx_dft = this->params.ctx_dft;

        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
        const int32_t n_drafted_last = (int32_t) last_n_drafted[seq_id];

        const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);

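        // Worked example with hypothetical numbers (not part of this patch):
        // n_drafted_last = 4, n_accepted = 2 -> n_to_drop = max(0, 4 - 2 - 1) = 1;
        // the accepted rows plus one extra row (the next conditioning position)
        // stay in the draft KV cache, and only the remainder is rolled back.
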
        if (pos_max < 0) {
            last_n_accepted[seq_id] = (int32_t) n_accepted;
        int ret = llama_decode(ctx_dft, batch);
        if (ret != 0) {
            LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
            return;
        }

        if (n_to_drop > 0) {
            const llama_pos drop_from = pos_max - n_to_drop + 1;
            llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
        int i = 0;

        while (n_drafting > 0) {
            int i_batch = 0;

            common_batch_clear(batch);

            for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
                if (!drafting[seq_id]) {
                    continue;
                }

                auto * smpl = smpls[seq_id].get();

                common_sampler_sample(smpl, ctx_dft, i_batch, true);
                h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch);
                ++i_batch;

                const auto * cur_p = common_sampler_get_candidates(smpl, true);

                for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
                    LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
                            seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
                            common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
                }

                // add drafted token for each sequence
                const llama_token id = cur_p->data[0].id;

                // only collect very high-confidence draft tokens
                if (cur_p->data[0].p < params.p_min) {
                    drafting[seq_id] = false;
                    n_drafting--;

                    continue;
                }

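                // Example of the gate above (hypothetical values, not part of
                // this patch): with params.p_min = 0.75, a top candidate with
                // p = 0.60 stops drafting for this sequence; p = 0.90 continues.
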
                common_sampler_accept(smpl, id, true);

                auto & dp = dparams.at(seq_id);
                auto & result = *dp.result;

                result.push_back(id);

                if ((params.n_max <= (int) result.size()) ||
                    (dp.n_max > 0 && dp.n_max <= (int) result.size())) {
                    drafting[seq_id] = false;
                    n_drafting--;
                    continue;
                }

                common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
                std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
            }

            if (batch.n_tokens == 0) {
                break;
            }

            // evaluate the drafted tokens on the draft model
            ret = llama_decode(ctx_dft, batch);
            if (ret != 0) {
                LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
                break;
            }

            ++i;
        }

        last_n_drafted [seq_id] = 0;
        last_n_accepted[seq_id] = (int32_t) n_accepted;
        for (auto & dp : dparams) {
            if (!dp.drafting) {
                continue;
            }

            if (dp.result->size() < (size_t) params.n_min) {
                dp.result->clear();
            }
        }
    }

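    // Example of the n_min trim above (hypothetical values, not part of this
    // patch): with params.n_min = 2, a sequence that drafted only 1 token has
    // its result cleared, so no too-short draft is handed back to the caller.
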
    void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
    }
};

@@ -778,11 +778,6 @@ private:
            return false;
        }

        if (params_base.n_parallel > 1) {
            SRV_ERR("MTP currently supports only n_parallel=1; got %d\n", params_base.n_parallel);
            return false;
        }

        SRV_INF("loading MTP head from '%s' (override_arch=%s)\n",
                params_base.model.path.c_str(), mtp_arch);