spec : drop support for incompatible vocabs

[no ci]
Georgi Gerganov
2026-05-07 09:54:09 +03:00
parent cc36b7a242
commit ab4718f199
4 changed files with 10 additions and 89 deletions
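In essence, the dropped feature bridged mismatched vocabularies by round-tripping through text: detokenize with the target model, apply the user-supplied --spec-draft-replace / --spec-replace string pairs, then retokenize with the draft model. A condensed sketch of that removed path, reconstructed from the speculative.cpp hunks below (vocab_map held the replacement pairs):

    // condensed sketch of the removed main->draft translation, reconstructed
    // from the diff below; after this commit, mismatched vocabs throw instead
    std::string text = common_detokenize(ctx_tgt, prompt_tgt, true);

    for (const auto & pair : vocab_map) { // pairs collected from --spec-replace TARGET DRAFT
        size_t pos = text.find(pair.first);
        while (pos != std::string::npos) {
            text.replace(pos, pair.first.length(), pair.second);
            pos = text.find(pair.first, pos + pair.second.length());
        }
    }

    llama_tokens prompt_cnv = common_tokenize(ctx_dft, text, false, true); // draft-side tokens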

==== file 1 of 4 ====

@@ -622,10 +622,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
             string_process_escapes(seq_breaker);
         }
-        for (auto & pair : params.speculative.draft.replacements) {
-            string_process_escapes(pair.first);
-            string_process_escapes(pair.second);
-        }
     }

     if (!params.kv_overrides.empty()) {
@@ -3553,13 +3549,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.draft.mparams.path = value;
         }
     ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
-    add_opt(common_arg(
-        {"--spec-draft-replace", "--spec-replace"}, "TARGET", "DRAFT",
-        "translate the string in TARGET into DRAFT if the draft model and main model are not compatible",
-        [](common_params & params, const std::string & tgt, const std::string & dft) {
-            params.speculative.draft.replacements.push_back({ tgt, dft });
-        }
-    ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
     add_opt(common_arg(
         {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
         string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",

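With the option removed, --spec-draft-replace / --spec-replace should now be rejected as unknown arguments, and a draft model whose vocab differs from the target can no longer be used at all (see the speculative.cpp changes below).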
==== file 2 of 4 ====

@@ -322,7 +322,6 @@ struct common_params_speculative_draft {
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
-    std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
     std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
 };

==== file 3 of 4 ====

@@ -181,14 +181,10 @@ struct common_speculative_state_draft : public common_speculative_state {
     llama_batch batch;
     llama_tokens prompt_dft;

-    bool vocab_cmpt = true; // whether retokenization is needed
-    std::unordered_map<std::string, std::string> vocab_map;
-
     common_speculative_state_draft(
             enum common_speculative_type type,
             llama_context * ctx_tgt,
             llama_context * ctx_dft,
-            const std::vector<std::pair<std::string, std::string>> & replacements,
             bool use_ckpt)
         : common_speculative_state(type)
         , ctx_tgt(ctx_tgt)
@@ -225,15 +221,13 @@ struct common_speculative_state_draft : public common_speculative_state {
             smpl = common_sampler_init(llama_get_model(ctx_dft), params);
         }

-        vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
-        LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt);
+        const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
+        LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt);

         if (!vocab_cmpt) {
-            LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n");
+            LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__);

-            for (const auto & pair : replacements) {
-                vocab_map[pair.first] = pair.second;
-            }
+            throw std::runtime_error("draft model vocab type must match target model to use speculation");
         }
     }
@@ -300,33 +294,7 @@ struct common_speculative_state_draft : public common_speculative_state {
         const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max;

-        llama_tokens prompt_cnv;
-
-        if (!spec->vocab_cmpt) {
-            std::string text;
-
-            text = common_detokenize(ctx_tgt, prompt_tgt, true);
-            text = replace_to_dft(text);
-            LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
-
-            prompt_cnv = common_tokenize(ctx_dft, text, false, true);
-
-            // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
-            const auto * model_tgt = llama_get_model(ctx_tgt);
-            const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
-
-            int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
-            GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
-            text.resize(-n_chars);
-            llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
-            text = replace_to_dft(text);
-
-            LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
-
-            id_last = common_tokenize(ctx_dft, text, false, true)[0];
-        }
-
-        const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv;
+        const llama_tokens & prompt_cur = prompt_tgt;

         const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);
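Aside: the removed id_last conversion relies on llama_detokenize's sizing convention: called with a null buffer, it returns the negative of the number of bytes required. That is why the removed code asserts n_chars < 0, resizes the string to -n_chars, and then calls llama_detokenize a second time into the properly sized buffer.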
@@ -505,16 +473,6 @@ struct common_speculative_state_draft : public common_speculative_state {
             prompt_dft.push_back(id);
         }

-        if (!spec->vocab_cmpt) {
-            std::string detokenized = common_detokenize(ctx_dft, result, true);
-            detokenized = replace_to_tgt(detokenized);
-            LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
-            result = common_tokenize(ctx_tgt, detokenized, false, true);
-            if (result.size() > (size_t) sparams.n_max) {
-                result.resize(sparams.n_max);
-            }
-        }
-
         if (result.size() < (size_t) sparams.n_min) {
             result.clear();
         }
@@ -532,34 +490,6 @@ struct common_speculative_state_draft : public common_speculative_state {
     int32_t n_min(const common_params_speculative & params) const override {
         return params.draft.n_min;
     }
-
-    std::string replace_to_dft(const std::string & input) const {
-        std::string result = input;
-
-        for (const auto & pair : this->vocab_map) {
-            size_t pos = result.find(pair.first);
-            while (pos != std::string::npos) {
-                result.replace(pos, pair.first.length(), pair.second);
-                pos = result.find(pair.first, pos + pair.second.length());
-            }
-        }
-
-        return result;
-    }
-
-    std::string replace_to_tgt(const std::string & input) const {
-        std::string result = input;
-
-        for (const auto & pair : this->vocab_map) {
-            size_t pos = result.find(pair.second);
-            while (pos != std::string::npos) {
-                result.replace(pos, pair.second.length(), pair.first);
-                pos = result.find(pair.second, pos + pair.first.length());
-            }
-        }
-
-        return result;
-    }
 };

 struct common_speculative_state_eagle3 : public common_speculative_state {
@@ -1036,7 +966,6 @@ common_speculative * common_speculative_init(
             impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
                 /* .ctx_tgt      = */ ctx_tgt,
                 /* .ctx_dft      = */ ctx_dft,
-                /* .replacements = */ params.draft.replacements,
                 /* .use_ckpt     = */ params.draft.use_ckpt
             ));
             break;

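With the translation path gone, common_speculative_are_compatible is the sole gate: when it returns false, the draft state constructor now throws and no retokenization is attempted. For orientation only, a rough sketch of the kind of check such a compatibility helper performs; this is an assumption for illustration, the actual criteria live elsewhere in the repo and are not part of this diff:

    // rough sketch (assumed, not from this commit): vocabs count as compatible
    // only if they tokenize identically - same vocab type, same special tokens
    static bool vocabs_compatible(const llama_model * model_tgt, const llama_model * model_dft) {
        const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
        const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);

        return llama_vocab_type(vocab_tgt) == llama_vocab_type(vocab_dft) &&
               llama_vocab_bos(vocab_tgt)  == llama_vocab_bos(vocab_dft)  &&
               llama_vocab_eos(vocab_tgt)  == llama_vocab_eos(vocab_dft);
    }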
==== file 4 of 4 ====

@@ -910,7 +910,11 @@ private:
         // try speculative decoding
         if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
-            slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
+            try {
+                slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
+            } catch (const std::exception & e) {
+                SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
+            }

             if (slot.spec) {
                 SLT_INF(slot, "%s", "speculative decoding context initialized\n");
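Note that when common_speculative_init throws, reset() is never reached and slot.spec stays unset, so the if (slot.spec) check that follows quietly leaves speculative decoding disabled for that slot rather than taking the server down.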