From 456268fa7fcafe71e78d699e5664b559167b473d Mon Sep 17 00:00:00 2001
From: Sascha Rogmann
Date: Wed, 14 Jan 2026 23:44:23 +0100
Subject: [PATCH] common: ngram map, config self-speculative decoding

---
 common/arg.cpp                  |  42 ++++-
 common/common.h                 |   3 +-
 common/ngram-map.cpp            | 296 ++++++++++++++++++++++++++++++++
 common/ngram-map.h              |  74 ++++++++
 common/speculative.cpp          | 109 ++++++++++--
 common/speculative.h            |  29 ++--
 tools/server/server-context.cpp |  17 +-
 tools/server/server-task.cpp    |   9 +-
 8 files changed, 538 insertions(+), 41 deletions(-)
 create mode 100644 common/ngram-map.cpp
 create mode 100644 common/ngram-map.h

diff --git a/common/arg.cpp b/common/arg.cpp
index 91743ebaa6..8fc15261a5 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -3393,10 +3393,46 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
     add_opt(common_arg(
-        {"--spec-self"}, "<0|1>",
-        "use self-speculation without a draft model (default: 0, no self speculation without draft model)",
+        {"--spec-self"}, "N",
+        string_format("mode of self-speculation without a draft model: disabled(0), fixed(1), keys-only(2), key-values(3) (default: %d)", params.speculative.self_mode),
         [](common_params & params, int value) {
+            if (value < 0 || value > 3) {
+                throw std::invalid_argument("invalid value");
+            }
-            params.speculative.use_self = value;
+            params.speculative.self_mode = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--spec-self-config"}, "N0,N1,N2,...",
+        string_format("self-speculative decoding config: ngram size (key), mgram size (value), check rate, min hits (default: %d,%d,%d,%d)",
+                params.speculative.self_cfg[0], params.speculative.self_cfg[1],
+                params.speculative.self_cfg[2], params.speculative.self_cfg[3]),
+        [](common_params & params, const std::string & value) {
+            std::string arg_next = value;
+
+            // split string by , and /
+            const std::regex regex{ R"([,/]+)" };
+            std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
+            std::vector<std::string> split_arg{ it, {} };
+            if (split_arg.size() > 4) {
+                throw std::invalid_argument(
+                    string_format("got %d config values, but the self-speculative decoding config accepts at most 4", (int)split_arg.size())
+                );
+            }
+            for (size_t i = 0; i < split_arg.size(); ++i) {
+                int val = std::stoi(split_arg[i]);
+                if (i == 0 && (val < 1 || val > 255)) {
+                    throw std::invalid_argument("ngram size must be between 1 and 255");
+                }
+                if (i == 1 && (val < 1 || val > 255)) {
+                    throw std::invalid_argument("mgram size must be between 1 and 255");
+                }
+                if (i == 2 && val == 0) {
+                    throw std::invalid_argument("check rate must be greater than 0");
+                }
+                if (i == 3 && (val < 1 || val > 255)) {
+                    throw std::invalid_argument("min hits must be between 1 and 255");
+                }
+                params.speculative.self_cfg[i] = (uint16_t) val;
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
diff --git a/common/common.h b/common/common.h
index 7204020d28..be38ef1185 100644
--- a/common/common.h
+++ b/common/common.h
@@ -251,7 +251,8 @@ struct common_params_speculative {
     int32_t n_gpu_layers = -1;    // number of layers to store in VRAM for the draft model (-1 - use default)
     float   p_split      = 0.1f;  // speculative decoding split probability
     float   p_min        = 0.75f; // minimum speculative decoding probability (greedy)
-    int32_t use_self     = 0;     // use self-speculative decoding without draft model (default: 0 = off)
+    int32_t self_mode    = 0;     // mode of self-speculative decoding without draft model (default: 0 = off)
+    std::vector<uint16_t> self_cfg = {12, 48, 2, 1}; // self-speculative decoding config (n-gram size, m-gram size, check rate, min hits)

     std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
     std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp
new file mode 100644
index 0000000000..487f48362f
--- /dev/null
+++ b/common/ngram-map.cpp
@@ -0,0 +1,296 @@
+#include "ngram-map.h"
+#include "common.h"
+#include "log.h"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+// upper bound for the occurrence counters (key_num, value_num) in the ngram map.
+#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
+
+std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length);
+
+void common_ngram_map_draft(common_ngram_map & map,
+        const llama_tokens & inp, llama_token sampled,
+        llama_tokens & draft) {
+    // reset last key and value.
+    map.last_draft_created   = false;
+    map.last_draft_key_idx   = 0;
+    map.last_draft_value_idx = 0;
+
+    const size_t cur_len = inp.size();
+    const uint16_t n = map.size_key;
+    const uint16_t m = map.size_value;
+    if (cur_len < static_cast<size_t>(2 * n + m)) {
+        return;
+    }
+
+    // Only check every check_rate tokens to save compute,
+    // i.e., perform a check if (cur_len - idx_last_check) >= check_rate.
+    if (map.idx_last_check + map.check_rate > cur_len) {
+        return;
+    }
+    map.idx_last_check = cur_len;
+
+    // search pattern, the key n-gram
+    std::vector<llama_token> key_tokens;
+    key_tokens.reserve(n);
+    for (size_t j = cur_len - n + 1; j < cur_len; ++j) {
+        key_tokens.push_back(inp[j]);
+    }
+    key_tokens.push_back(sampled);
+
+    // search for the most recent earlier occurrence of the key in the token history
+    size_t match_pos = 0;
+    for (size_t j = cur_len - n - m - 1; j > 0; --j) {
+        bool match = true;
+        for (size_t k = 0; k < n; ++k) {
+            if (inp[j + k] != key_tokens[k]) {
+                match = false;
+                break;
+            }
+        }
+        if (match) {
+            match_pos = j;
+            break;
+        }
+    }
+    if (match_pos > 0) {
+        LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
+                cur_len, n, m, key_tokens.size(), sampled, match_pos);
+    }
+
+    if (match_pos == 0) {
+        return;
+    }
+
+    // We have a match, now we look for the statistics of the key.
+    size_t key_offset = map.keys.size(); // offset in the map
+    // We iterate through the vector map.keys.
+    for (size_t i = 0; i < map.keys.size(); ++i) {
+        bool match = true;
+        for (size_t j = 0; j < n; ++j) {
+            if (inp[map.keys[i].key_idx + j] != key_tokens[j]) {
+                match = false;
+                break;
+            }
+        }
+        if (match) {
+            key_offset = i;
+            break;
+        }
+    }
+    if (key_offset == map.keys.size()) {
+        // We create a new key entry, it will get offset key_offset.
+        common_ngram_map_key new_key;
+        new_key.key_idx  = match_pos;
+        new_key.stat_idx = 0;
+        new_key.key_num  = 0;
+        for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) {
+            new_key.values[i].value_idx  = 0; // 0 marks an unused value slot
+            new_key.values[i].value_num  = 0;
+            new_key.values[i].n_accepted = m;
+        }
+        map.keys.push_back(new_key);
+    }
+
+    // our key n-gram:
+    common_ngram_map_key & curr_key = map.keys[key_offset];
+
+    // update number of key hits
+    curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1,
+                                           (int) COMMON_NGRAM_MAX_VALUE_COUNT);
+
+    if (map.key_only) {
+        // simple mode:
+        // Fill in the draft with the m tokens following the key.
+        // We work with value values[0] only.
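+        // Note: values[0].n_accepted is updated via common_ngram_map_send_accepted(),
+        // so a draft that was cut short by the target model shortens the next draft.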
+        int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted);
+
+        for (int i = 0; i < n_draft_tokens; ++i) {
+            draft.push_back(inp[match_pos + n + i]);
+        }
+
+        LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
+                key_offset, curr_key.key_num, draft.size());
+
+        map.last_draft_created   = true; // enable acceptance feedback for values[0]
+        map.last_draft_key_idx   = key_offset;
+        map.last_draft_value_idx = 0; // value 0 is used for simple mode
+        map.drafts_generated_tokens += draft.size();
+        return;
+    }
+
+    if (curr_key.key_num < map.min_hits) {
+        // not enough hits to consider this a good draft
+        LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__,
+                key_offset, curr_key.key_num, map.min_hits);
+        return;
+    }
+
+    // complex mode: examine the different m-grams after this key n-gram.
+    //
+
+    // determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram.
+    for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) {
+        // does the key n-gram begin at index i?
+        bool match_key = true;
+        for (size_t k = 0; k < n; ++k) {
+            if (inp[i + k] != key_tokens[k]) {
+                match_key = false;
+                break;
+            }
+        }
+        if (!match_key) {
+            continue;
+        }
+
+        // Do we have an existing value m-gram or a new one after the key at index i?
+        size_t idx_begin_value_key = i + n;
+        int idx_value = -1;
+        for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
+            size_t idx_begin_value_v = curr_key.values[v].value_idx;
+            if (idx_begin_value_v == 0) {
+                // We found an empty value slot => we found a new value m-gram after the key n-gram.
+                curr_key.values[v].value_idx  = idx_begin_value_key;
+                curr_key.values[v].value_num  = 0;
+                curr_key.values[v].n_accepted = m;
+                idx_value = v;
+                break;
+            }
+            bool match = true;
+            for (size_t j = 0; j < m; ++j) {
+                if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) {
+                    match = false;
+                    break;
+                }
+            }
+            if (match) {
+                // We found an existing value m-gram after the key n-gram.
+                idx_value = v;
+                break;
+            }
+        }
+        if (idx_value >= 0) {
+            // We found a value m-gram of the key n-gram.
+            curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1,
+                                                                       (int) COMMON_NGRAM_MAX_VALUE_COUNT);
+        }
+    }
+    // the statistics are updated up to match_pos.
+    curr_key.stat_idx = match_pos;
+
+    // Do we have a value we could use for the draft?
+    uint16_t max_occur = 0;
+    int slot_max = 0;
+    for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
+        uint16_t curr_occur = curr_key.values[v].value_num;
+        if (curr_occur > max_occur) {
+            max_occur = curr_occur;
+            slot_max = v;
+        }
+    }
+    // What is the sum of the other occurrences?
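+    // (sum_occur is compared against max_occur below: the draft is only used if the
+    // most frequent m-gram occurs at least 3x as often as all the others combined)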
+    uint32_t sum_occur = 0;
+    for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
+        if (v == slot_max) {
+            continue;
+        }
+        uint16_t curr_occur = curr_key.values[v].value_num;
+        sum_occur += curr_occur;
+    }
+
+    LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
+            key_offset,
+            max_occur, sum_occur, slot_max,
+            curr_key.values[0].value_idx, curr_key.values[0].value_num,
+            curr_key.values[1].value_idx, curr_key.values[1].value_num,
+            curr_key.values[2].value_idx, curr_key.values[2].value_num,
+            curr_key.values[3].value_idx, curr_key.values[3].value_num
+    );
+    // Print the tokens of the four values (if value_idx != 0).
+    for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
+        if (curr_key.values[v].value_idx != 0) {
+            LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
+        }
+    }
+
+    if (sum_occur > 0 && max_occur < 3 * sum_occur) {
+        // The most frequent value is not much more frequent than the other values.
+        // We do not use the draft.
+        return;
+    }
+
+    // We use the most frequent value values[slot_max] for the draft.
+    // Fill in the draft with the m tokens of that value m-gram.
+    int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted);
+
+    for (int i = 0; i < n_draft_tokens; ++i) {
+        draft.push_back(inp[curr_key.values[slot_max].value_idx + i]);
+    }
+
+    LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
+            key_offset, slot_max,
+            curr_key.key_num, draft.size());
+
+    map.last_draft_created   = true;
+    map.last_draft_key_idx   = key_offset;
+    map.last_draft_value_idx = slot_max; // value used for draft generation.
+    map.drafts_generated_tokens += draft.size();
+}
+
+void common_ngram_map_send_accepted(common_ngram_map & map, uint16_t n_accepted) {
+    if (!map.last_draft_created) {
+        return;
+    }
+
+    // find the key and its chosen value.
+    const size_t key_idx = map.last_draft_key_idx;
+    const size_t val_idx = map.last_draft_value_idx;
+
+    // find key corresponding to key_idx.
+    common_ngram_map_key & curr_key = map.keys[key_idx];
+    // find value corresponding to val_idx.
+    struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
+
+    // update the value statistics
+    LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev n_accepted = %d\n",
+            n_accepted, curr_value.n_accepted);
+    curr_value.n_accepted = n_accepted;
+
+    // draft statistics update
+    if (n_accepted > 0) {
+        map.drafts_accepted_count++;
+    } else {
+        map.drafts_rejected_count++;
+    }
+    map.drafts_accepted_tokens += n_accepted;
+}
+
+// Display statistics of the ngram map.
+void common_ngram_map_print_stats(const common_ngram_map & map) {
+    LOG_INF("ngram map: size_key = %d, size_value = %d, key_only = %s, min_hits = %d\n",
+            map.size_key, map.size_value,
+            map.key_only ? "true" : "false",
+            map.min_hits);
+    LOG_INF("drafts_accepted_count = %zu, drafts_rejected_count = %zu, drafts_generated_tokens = %zu, drafts_accepted_tokens = %zu\n",
+            map.drafts_accepted_count, map.drafts_rejected_count,
+            map.drafts_generated_tokens, map.drafts_accepted_tokens);
+}
+
+// Helper functions.
+//
+
+// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
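+// Used for debug output, e.g. "[15, 7, 99]" for the three tokens 15, 7 and 99.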
+std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
+    std::string result = "[";
+    for (size_t i = 0; i < length; ++i) {
+        if (i > 0) {
+            result += ", ";
+        }
+        result += std::to_string(inp[start + i]);
+    }
+    result += "]";
+    return result;
+}
diff --git a/common/ngram-map.h b/common/ngram-map.h
new file mode 100644
index 0000000000..478cc6d54b
--- /dev/null
+++ b/common/ngram-map.h
@@ -0,0 +1,74 @@
+#pragma once
+//
+// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams
+//
+// These structures are used to look up n-grams followed by m-grams in the token history.
+
+#include "common.h" // llama_tokens
+
+#include <cstdint>
+#include <vector>
+
+// maximum number of m-gram values stored for each key n-gram.
+#define COMMON_NGRAM_MAX_VALUES 4
+
+// statistics of an m-gram after a known n-gram
+struct common_ngram_map_value {
+    size_t   value_idx;  // index of value m-gram in token history (0 if unused)
+    uint16_t value_num;  // number of occurrences of this value m-gram after the key n-gram (0 in an unused value slot)
+    int16_t  n_accepted; // number of accepted tokens at last draft (-1 if unused)
+};
+
+// statistics of an n-gram
+struct common_ngram_map_key {
+    size_t key_idx;  // index of key n-gram in token history
+    size_t stat_idx; // index of last token of statistics computation (key_num, values)
+
+    uint16_t key_num; // number of occurrences of this key n-gram in token history
+    common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
+};
+
+// map from n-grams to following m-grams in the token history
+struct common_ngram_map {
+    uint16_t size_key;   // size of key n-grams
+    uint16_t size_value; // size of value m-grams
+
+    bool key_only; // true if only key n-grams are used, no values.
+
+    // first draft: vector only, no map.
+    std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in the token history
+    uint16_t check_rate; // check for speculative decoding without draft model every check_rate tokens
+    uint16_t min_hits;   // minimum number of key hits to consider a draft
+
+    common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
+            uint16_t check_rate, uint16_t min_hits)
+        : size_key(sz_key), size_value(sz_value), key_only(only_keys), keys(std::vector<common_ngram_map_key>{}),
+          check_rate(check_rate), min_hits(min_hits) {}
+
+    size_t drafts_accepted_count   = 0; // number of drafts accepted by the target model.
+    size_t drafts_rejected_count   = 0; // number of drafts rejected by the target model.
+    size_t drafts_generated_tokens = 0; // number of tokens generated by this ngram map.
+    size_t drafts_accepted_tokens  = 0; // number of tokens accepted by the target model.
+
+    bool     last_draft_created   = false; // true if a draft was created at the last call.
+    size_t   last_draft_key_idx   = 0;     // index of last key used for draft generation.
+    uint16_t last_draft_value_idx = 0;     // index of last value used for draft generation.
+
+    size_t idx_last_check = 0; // index of last check in context history
+};
+
+// Searches for the n-gram in the history and checks whether a draft sequence should be generated.
+// map:     the ngram map to search in.
+// inp:     the tokens generated so far.
+// sampled: the token that was just sampled.
+// draft:   vector to store the draft tokens, initially empty.
+void common_ngram_map_draft(
+        common_ngram_map & map,
+        const llama_tokens & inp, llama_token sampled,
+        llama_tokens & draft);
+
+// Update the statistics of a value after a draft was accepted.
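+// map:        the ngram map that produced the last draft.
+// n_accepted: number of draft tokens accepted by the target model (0 if the draft was rejected).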
+void common_ngram_map_send_accepted(common_ngram_map & map, uint16_t n_accepted);
+
+// Display statistics of the ngram map.
+void common_ngram_map_print_stats(const common_ngram_map & map);
diff --git a/common/speculative.cpp b/common/speculative.cpp
index b4c7f6ed62..6f35edf6be 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -5,6 +5,7 @@
 #include "log.h"
 #include "common.h"
 #include "sampling.h"
+#include "ngram-map.cpp"

 #include <cstring>
 #include <algorithm>
@@ -13,6 +14,13 @@
 #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

+struct common_speculative_self {
+    uint16_t size_ngram = 12;      // size of n-grams to look up in self mode
+    uint16_t size_mgram = 48;      // size of m-grams to draft in self mode
+    const uint16_t check_rate = 3; // check for speculative decoding without draft model every check_rate tokens
+    size_t idx_last_check = 0;     // index of last check in context history
+};
+
 struct common_speculative {
     struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
     struct llama_context * ctx_dft;
@@ -22,20 +30,44 @@ struct common_speculative {
     llama_tokens prompt_dft;
     bool vocab_dft_compatible = true; // whether retokenization is needed
     std::map<std::string, std::string> tgt_dft_replacements = {};
+
+    const uint16_t self_mode = 0;       // 0: off, 1: self speculative, 2: n-grams (keys) only, 3: n-grams/m-grams (key-values)
+    common_ngram_map map;               // draft ngram map for speculative decoding without draft model
+    common_speculative_self self_state; // state of self-speculation (simple implementation, not ngram-map)
 };

 struct common_speculative * common_speculative_init(
         struct llama_context * ctx_tgt,
-        struct llama_context * ctx_dft) {
+        struct llama_context * ctx_dft,
+        uint16_t self_mode,                  // 0: off, 1: self speculative, 2: n-grams (keys) only, 3: n-grams/m-grams (key-values)
+        const std::vector<uint16_t> self_cfg // ngram size, mgram size, check rate, min hits
+        ) {
+    uint16_t ngram_size_key   = self_cfg.size() >= 1 ? self_cfg[0] : 12;
+    uint16_t mgram_size_value = self_cfg.size() >= 2 ? self_cfg[1] : 48;
+    uint16_t check_rate       = self_cfg.size() >= 3 ? self_cfg[2] : 3;
+    bool     key_only         = (self_mode != 3);
+    uint16_t min_hits         = self_cfg.size() >= 4 ? self_cfg[3] : 1;
+    common_ngram_map ngram_map = common_ngram_map(ngram_size_key, mgram_size_value, key_only, check_rate, min_hits);
+    common_speculative_self self_state = common_speculative_self{
+        /* .size_ngram     = */ ngram_size_key,
+        /* .size_mgram     = */ mgram_size_value,
+        /* .check_rate     = */ check_rate,
+        /* .idx_last_check = */ 0,
+    };
     auto * result = new common_speculative {
         /* .ctx_tgt              = */ ctx_tgt,
         /* .ctx_dft              = */ ctx_dft,
         /* .smpl                 = */ nullptr,
-        /* .batch                = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .batch                = */ llama_batch_init(ctx_dft ? llama_n_batch(ctx_dft) : 64, 0, 1),
         /* .prompt_dft           = */ {},
         /* .vocab_dft_compatible = */ false,
+        /* .tgt_dft_replacements = */ {},
+        /* .self_mode            = */ self_mode,
+        /* .map                  = */ ngram_map,
+        /* .self_state           = */ self_state
     };
+    LOG_INF("common_speculative_init: created speculative decoder, map.n = %d\n", result->map.size_key);

     // TODO: optimize or pass from outside?
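+    // In self-speculation mode there is no draft context (ctx_dft == nullptr):
+    // the batch above falls back to size 64 and no draft sampler is created below.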
#if 0
    {
@@ -64,7 +96,9 @@ struct common_speculative * common_speculative_init(
             COMMON_SAMPLER_TYPE_TOP_K,
         };

-        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+        if (ctx_dft) {
+            result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+        }
     }
 #endif
@@ -89,6 +123,9 @@ void common_speculative_free(struct common_speculative * spec) {
 bool common_speculative_are_compatible(
     const struct llama_context * ctx_tgt,
     const struct llama_context * ctx_dft) {
+    if (ctx_tgt == nullptr && ctx_dft == nullptr) {
+        return true;
+    }
     const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
     const struct llama_model * model_dft = llama_get_model(ctx_dft);
@@ -181,22 +218,25 @@ static std::string replace_to_tgt(
     return result;
 }

+llama_tokens common_speculative_gen_self_draft(
+        common_speculative * spec,
+        const llama_tokens & tokens, llama_token sampled);
 llama_tokens common_speculative_gen_draft(
         struct common_speculative * spec,
         struct common_speculative_params params,
         const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
         llama_token id_last) {
-    if (params.self_mode == 1) {
+    if (spec && spec->self_mode) {
         // Look in the current context for a n-gram and return the following tokens as the draft.
-        llama_tokens draft_self = common_speculative_gen_self_draft(prompt_tgt_main_model, id_last,
-                params.self_ngram_size, params.n_draft);
+        llama_tokens draft_self = common_speculative_gen_self_draft(spec,
+                prompt_tgt_main_model, id_last);
         if (!draft_self.empty()) {
             return draft_self;
         }
     }

-    if (spec == nullptr) {
-        return {};
+    if (spec == nullptr || spec->ctx_dft == nullptr) {
+        return {}; // no draft model, return
     }

     auto & batch = spec->batch;
@@ -372,14 +412,54 @@ llama_tokens common_speculative_gen_draft(
     return result;
 }

-llama_tokens common_speculative_gen_self_draft(const llama_tokens & tokens, llama_token sampled,
-        size_t n_draft_min, size_t n_draft_max) {
+void common_speculative_send_accepted(struct common_speculative * spec, const uint16_t n_accepted) {
+    // update the ngram map statistics.
+    common_ngram_map_send_accepted(spec->map, n_accepted);
+}
+
+// self-speculative decoding
+//
+
+/**
+ * Perform speculative generation using the model's own token history.
+ * Searches for a matching pattern in the token history and returns draft tokens.
+ *
+ * @param spec    the speculative decoding state
+ * @param tokens  Token history to search in
+ * @param sampled Last sampled token
+ * @return Vector of draft tokens, empty if no matching pattern is found
+ */
+llama_tokens common_speculative_gen_self_draft(
+        common_speculative * spec,
+        const llama_tokens & tokens, llama_token sampled) {
+
+    common_ngram_map & map = spec->map;
+    if (spec->self_mode != 1) {
+        // Use common_ngram_map_draft to generate a draft from the current context.
+        llama_tokens draft_tokens;
+        common_ngram_map_draft(map, tokens, sampled, draft_tokens);
+        return draft_tokens;
+    }

+    // Simple implementation of self-speculative decoding without draft model, without ngram-map.
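+    // It scans the history backwards for the most recent occurrence of the last
+    // size_ngram tokens and proposes the size_mgram tokens that followed it,
+    // without keeping any statistics about previous matches or acceptance rates.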
+    //
+    common_speculative_self & self_state = spec->self_state;
     const size_t cur_len = tokens.size();

+    // Only check every check_rate tokens to save compute,
+    // i.e., perform a check if (cur_len - idx_last_check) >= check_rate.
+    if (self_state.idx_last_check + self_state.check_rate > cur_len) {
+        llama_tokens draft_tokens;
+        return draft_tokens;
+    }
+
+    size_t n_draft_min = self_state.size_ngram; // size of n-gram to look up in the token history
+    size_t n_draft_max = self_state.size_mgram; // the m-gram following the found n-gram is used for the draft

     // vector for tokens we want to verify.
     // return empty vector if there is no match.
     llama_tokens draft_tokens;

+    // We need at least n_draft_min + n_draft_max + 1 tokens.
     if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
         return draft_tokens;
     }
@@ -392,6 +472,9 @@ llama_tokens common_speculative_gen_self_draft(const llama_tokens & tokens, llam
     }
     pattern.push_back(sampled); // add the last token to the pattern

+    // We do a search in the token history.
+    self_state.idx_last_check = tokens.size();
+
     size_t match_pos = 0; // we ignore position 0, position 0 == no match
     // search backwards, but skip the current match (we are currently there)
     for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
@@ -428,3 +511,9 @@ llama_tokens common_speculative_gen_self_draft(const llama_tokens & tokens, llam
     }
     return draft_tokens;
 }
+
+void common_speculative_print_stats(const struct common_speculative * spec) {
+    if (spec->map.drafts_generated_tokens > 0) { // only print if we have some stats
+        common_ngram_map_print_stats(spec->map);
+    }
+}
diff --git a/common/speculative.h b/common/speculative.h
index 6407563c6f..94d4538876 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -10,14 +10,13 @@ struct common_speculative_params {
     int n_reuse = 256;

     float p_min = 0.75f; // min probability required to accept a token in the draft
-
-    int self_mode = 0;        // 0: off, 1: self speculative lookup
-    int self_ngram_size = 12; // length of pattern to search for in self mode
 };

 struct common_speculative * common_speculative_init(
         struct llama_context * ctx_tgt,
-        struct llama_context * ctx_dft
+        struct llama_context * ctx_dft,
+        const uint16_t self_mode = 0,                            // 0: off, 1: self speculative, 2: n-grams (keys) only, 3: n-grams/m-grams (key-values)
+        const std::vector<uint16_t> self_cfg = { 12, 48, 3, 1 }  // ngram size, mgram size, check rate, min hits
 );

 void common_speculative_free(struct common_speculative * spec);
@@ -37,18 +36,10 @@ llama_tokens common_speculative_gen_draft(
         const llama_tokens & prompt,
         llama_token id_last);

-/**
- * Perform speculative generation using the model's own token history.
- * Searches for a matching pattern in the token history and returns draft tokens.
- *
- * @param tokens Token history to search in
- * @param sampled Last sampled token
- * @param n_draft_min Minimum number of draft tokens required
- * @param n_draft_max Maximum number of draft tokens to generate
- * @return Vector of draft tokens, empty if no matching pattern is found
- */
-llama_tokens common_speculative_gen_self_draft(
-        const llama_tokens & tokens,
-        llama_token sampled,
-        size_t n_draft_min,
-        size_t n_draft_max);
+// informs the speculative decoder that n_accepted tokens were accepted by the target model
+void common_speculative_send_accepted(
+        struct common_speculative * spec,
+        const uint16_t n_accepted);
+
+// print statistics about the speculative decoding
+void common_speculative_print_stats(const struct common_speculative * spec);
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 099c34056d..ed4cd44ca6 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -260,7 +260,7 @@ struct server_slot {
     // Checks if a draft model is active or self-speculation using context-tokens
     bool can_speculate() const {
-        return ctx_dft || task->params.speculative.use_self;
+        return ctx_dft || task->params.speculative.self_mode;
     }

     void add_token(const completion_token_output & token) {
@@ -397,6 +397,7 @@ struct server_slot {
                     "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
                     draft_ratio, n_draft_accepted, n_draft_total
             );
+            common_speculative_print_stats(spec);
         }
     }
@@ -774,7 +775,9 @@ private:
                 return false;
             }

-            slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
+            slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft,
+                    params_base.speculative.self_mode,
+                    params_base.speculative.self_cfg);
             if (slot.spec == nullptr) {
                 SRV_ERR("%s", "failed to create speculator\n");
                 return false;
@@ -782,6 +785,11 @@ private:
             for (auto & pair : params_base.speculative.replacements) {
                 common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
             }
+        } else if (params_base.speculative.self_mode) {
+            SLT_INF(slot, "init spec for self-speculative decoding, slot %d\n", i);
+            slot.spec = common_speculative_init(nullptr, nullptr,
+                    params_base.speculative.self_mode,
+                    params_base.speculative.self_cfg);
         }

         SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
@@ -2071,8 +2079,6 @@ private:
                 params_spec.n_draft = n_draft_max;
                 params_spec.n_reuse = slot.ctx_dft ? (llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max) : 0;
                 params_spec.p_min = slot.task->params.speculative.p_min;
-                params_spec.self_mode = slot.task->params.speculative.use_self;
-                params_spec.self_ngram_size = std::max(5, slot.task->params.speculative.n_min);

                 const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
                 llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
@@ -2816,6 +2822,9 @@ private:
                 // update how many tokens out of those tested were accepted
                 slot.n_draft_accepted += ids.size() - 1;

+                // inform the speculative decoder about the accepted tokens
+                common_speculative_send_accepted(slot.spec, ids.size() - 1);
+
                 // rollback to the state before sampling the draft tokens
                 slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);

diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp
index 79426dbc84..0a721c3142 100644
--- a/tools/server/server-task.cpp
+++ b/tools/server/server-task.cpp
@@ -234,10 +234,11 @@ task_params server_task::params_from_json_cmpl(
     params.sampling.backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling);
     params.post_sampling_probs       = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);

-    params.speculative.n_min    = json_value(data, "speculative.n_min",    defaults.speculative.n_min);
-    params.speculative.n_max    = json_value(data, "speculative.n_max",    defaults.speculative.n_max);
-    params.speculative.p_min    = json_value(data, "speculative.p_min",    defaults.speculative.p_min);
-    params.speculative.use_self = json_value(data, "speculative.use_self", defaults.speculative.use_self);
+    params.speculative.n_min     = json_value(data, "speculative.n_min",     defaults.speculative.n_min);
+    params.speculative.n_max     = json_value(data, "speculative.n_max",     defaults.speculative.n_max);
+    params.speculative.p_min     = json_value(data, "speculative.p_min",     defaults.speculative.p_min);
+    params.speculative.self_mode = json_value(data, "speculative.self_mode", defaults.speculative.self_mode);
+    params.speculative.self_cfg  = json_value(data, "speculative.self_cfg",  defaults.speculative.self_cfg);

     params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
     params.speculative.n_min = std::max(params.speculative.n_min, 0);
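
Reviewer note (not part of the patch): a minimal sketch of how the new entry
points compose without a draft model, mirroring the server usage above;
`history`, `id_last` and `n_accepted` are hypothetical placeholders for the
slot's token history, the last sampled token and the verification result:

    // mode 3 (key-values); config: ngram size, mgram size, check rate, min hits
    common_speculative * spec = common_speculative_init(nullptr, nullptr, 3, { 12, 48, 2, 1 });

    common_speculative_params params_spec;
    llama_tokens draft = common_speculative_gen_draft(spec, params_spec, history, id_last);
    // ... verify `draft` against the target model, count accepted tokens ...
    common_speculative_send_accepted(spec, n_accepted); // feeds acceptance back into the ngram map
    common_speculative_print_stats(spec);
    common_speculative_free(spec);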