diff --git a/common/arg.cpp b/common/arg.cpp index 36f0200a87..27048e4114 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -622,10 +622,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context for (auto & seq_breaker : params.sampling.dry_sequence_breakers) { string_process_escapes(seq_breaker); } - for (auto & pair : params.speculative.draft.replacements) { - string_process_escapes(pair.first); - string_process_escapes(pair.second); - } } if (!params.kv_overrides.empty()) { @@ -3553,13 +3549,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.draft.mparams.path = value; } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL")); - add_opt(common_arg( - {"--spec-draft-replace", "--spec-replace"}, "TARGET", "DRAFT", - "translate the string in TARGET into DRAFT if the draft model and main model are not compatible", - [](common_params & params, const std::string & tgt, const std::string & dft) { - params.speculative.draft.replacements.push_back({ tgt, dft }); - } - ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); add_opt(common_arg( {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]", string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n", diff --git a/common/common.h b/common/common.h index 14da0ef020..a8db51cd1a 100644 --- a/common/common.h +++ b/common/common.h @@ -322,7 +322,6 @@ struct common_params_speculative_draft { std::vector devices; // devices to use for offloading - std::vector> replacements; // main to speculative model replacements std::vector tensor_buft_overrides; }; diff --git a/common/speculative.cpp b/common/speculative.cpp index 77d466a672..36519f0ce9 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -181,14 +181,10 @@ struct common_speculative_state_draft : public common_speculative_state { llama_batch batch; llama_tokens prompt_dft; - bool vocab_cmpt = true; // whether retokenization is needed - std::unordered_map vocab_map; - common_speculative_state_draft( enum common_speculative_type type, llama_context * ctx_tgt, llama_context * ctx_dft, - const std::vector> & replacements, bool use_ckpt) : common_speculative_state(type) , ctx_tgt(ctx_tgt) @@ -225,15 +221,13 @@ struct common_speculative_state_draft : public common_speculative_state { smpl = common_sampler_init(llama_get_model(ctx_dft), params); } - vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); - LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); + const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); + LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt); if (!vocab_cmpt) { - LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); + LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__); - for (const auto & pair : replacements) { - vocab_map[pair.first] = pair.second; - } + throw std::runtime_error("draft model vocab type must match target model to use speculation"); } } @@ -300,33 +294,7 @@ struct common_speculative_state_draft : public common_speculative_state { const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max; - llama_tokens prompt_cnv; - if (!spec->vocab_cmpt) { - std::string text; - - text = common_detokenize(ctx_tgt, prompt_tgt, true); - text = replace_to_dft(text); - - LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); - - prompt_cnv = common_tokenize(ctx_dft, text, false, true); - - // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation - const auto * model_tgt = llama_get_model(ctx_tgt); - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); - - int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); - GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); - - text.resize(-n_chars); - llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); - text = replace_to_dft(text); - - LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); - id_last = common_tokenize(ctx_dft, text, false, true)[0]; - } - - const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv; + const llama_tokens & prompt_cur = prompt_tgt; const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); @@ -505,16 +473,6 @@ struct common_speculative_state_draft : public common_speculative_state { prompt_dft.push_back(id); } - if (!spec->vocab_cmpt) { - std::string detokenized = common_detokenize(ctx_dft, result, true); - detokenized = replace_to_tgt(detokenized); - LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); - result = common_tokenize(ctx_tgt, detokenized, false, true); - if (result.size() > (size_t) sparams.n_max) { - result.resize(sparams.n_max); - } - } - if (result.size() < (size_t) sparams.n_min) { result.clear(); } @@ -532,34 +490,6 @@ struct common_speculative_state_draft : public common_speculative_state { int32_t n_min(const common_params_speculative & params) const override { return params.draft.n_min; } - - std::string replace_to_dft(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.first); - while (pos != std::string::npos) { - result.replace(pos, pair.first.length(), pair.second); - pos = result.find(pair.first, pos + pair.second.length()); - } - } - - return result; - } - - std::string replace_to_tgt(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.second); - while (pos != std::string::npos) { - result.replace(pos, pair.second.length(), pair.first); - pos = result.find(pair.second, pos + pair.first.length()); - } - } - - return result; - } }; struct common_speculative_state_eagle3 : public common_speculative_state { @@ -1036,7 +966,6 @@ common_speculative * common_speculative_init( impls.push_back(std::make_unique(config.type, /* .ctx_tgt = */ ctx_tgt, /* .ctx_dft = */ ctx_dft, - /* .replacements = */ params.draft.replacements, /* .use_ckpt = */ params.draft.use_ckpt )); break; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 4fcb77e9b0..e922d6667a 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -910,7 +910,11 @@ private: // try speculative decoding if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { - slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx)); + try { + slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx)); + } catch (const std::exception & e) { + SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what()); + } if (slot.spec) { SLT_INF(slot, "%s", "speculative decoding context initialized\n");