spec : disacard last drafted token with low prob (#22506)

This commit is contained in:
Georgi Gerganov
2026-04-29 17:00:00 +03:00
committed by GitHub
parent b1d5f5b449
commit 683c5acb90
2 changed files with 7 additions and 7 deletions

View File

@@ -467,7 +467,7 @@ struct common_speculative_state_draft : public common_speculative_state {
prompt_dft.push_back(id_last);
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
int ret = llama_decode(ctx_dft, batch);
if (ret != 0 && ret != 1) {
@@ -495,14 +495,14 @@ struct common_speculative_state_draft : public common_speculative_state {
common_sampler_accept(smpl, id, true);
result.push_back(id);
if (sparams.n_max <= (int) result.size()) {
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < sparams.p_min) {
break;
}
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < sparams.p_min) {
result.push_back(id);
if (sparams.n_max <= (int) result.size()) {
break;
}

View File

@@ -354,6 +354,7 @@ struct server_slot {
// generate a new draft
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
n_draft_total += spec_draft.size();
if (spec_draft.size() > (size_t) n_draft_max) {
SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max);
@@ -3019,7 +3020,6 @@ private:
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
slot.n_draft_total += n_draft;
// add accepted tokens to the prompt
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);