mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-08 18:14:07 +00:00
Compare commits
24 Commits
b9072
...
gg/spec-re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
efa2f8e5a7 | ||
|
|
778f9e247e | ||
|
|
1dbc054da5 | ||
|
|
161eae0adf | ||
|
|
e5b1401318 | ||
|
|
3b1a8df8fd | ||
|
|
233d1aee69 | ||
|
|
12c7cfbe83 | ||
|
|
6a4b05a030 | ||
|
|
8be14e40de | ||
|
|
7e118cdce0 | ||
|
|
ae6703fa89 | ||
|
|
0239f4c611 | ||
|
|
c7facb0fe1 | ||
|
|
08c8012bde | ||
|
|
de35b1255c | ||
|
|
1afee5b262 | ||
|
|
11fd5e7272 | ||
|
|
c97dc3605e | ||
|
|
8a50f6f0b9 | ||
|
|
77269ad8a7 | ||
|
|
4550f0f08b | ||
|
|
befc7ef635 | ||
|
|
2c9a40849f |
@@ -622,10 +622,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
|
||||
string_process_escapes(seq_breaker);
|
||||
}
|
||||
for (auto & pair : params.speculative.draft.replacements) {
|
||||
string_process_escapes(pair.first);
|
||||
string_process_escapes(pair.second);
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.kv_overrides.empty()) {
|
||||
@@ -3518,13 +3514,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.speculative.draft.p_min = std::stof(value);
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN"));
|
||||
add_opt(common_arg(
|
||||
{"--spec-draft-ctx-size", "-cd", "--ctx-size-draft"}, "N",
|
||||
string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.draft.n_ctx),
|
||||
[](common_params & params, int value) {
|
||||
params.speculative.draft.n_ctx = value;
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_CTX_SIZE"));
|
||||
add_opt(common_arg(
|
||||
{"--spec-draft-device", "-devd", "--device-draft"}, "<dev1,dev2,..>",
|
||||
"comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"
|
||||
@@ -3560,13 +3549,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.speculative.draft.mparams.path = value;
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
|
||||
add_opt(common_arg(
|
||||
{"--spec-draft-replace", "--spec-replace"}, "TARGET", "DRAFT",
|
||||
"translate the string in TARGET into DRAFT if the draft model and main model are not compatible",
|
||||
[](common_params & params, const std::string & tgt, const std::string & dft) {
|
||||
params.speculative.draft.replacements.push_back({ tgt, dft });
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
|
||||
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
|
||||
|
||||
@@ -1422,7 +1422,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
|
||||
|
||||
// try to remove the last tokens
|
||||
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
|
||||
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
|
||||
LOG_WRN("%s: the context does not support partial sequence removal\n", __func__);
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
goto done;
|
||||
}
|
||||
@@ -1960,3 +1960,102 @@ bool common_prompt_batch_decode(
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t common_prompt_checkpoint::size() const {
|
||||
return data_tgt.size() + data_dft.size();
|
||||
}
|
||||
|
||||
bool common_prompt_checkpoint::empty() const {
|
||||
return data_tgt.empty();
|
||||
}
|
||||
|
||||
void common_prompt_checkpoint::clear() {
|
||||
n_tokens = 0;
|
||||
|
||||
pos_min = 0;
|
||||
pos_max = 0;
|
||||
|
||||
data_tgt.clear();
|
||||
data_dft.clear();
|
||||
}
|
||||
|
||||
void common_prompt_checkpoint::update_pos(
|
||||
int64_t n_tokens,
|
||||
llama_pos pos_min,
|
||||
llama_pos pos_max) {
|
||||
this->n_tokens = n_tokens;
|
||||
this->pos_min = pos_min;
|
||||
this->pos_max = pos_max;
|
||||
}
|
||||
|
||||
void common_prompt_checkpoint::update_tgt(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags) {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
|
||||
|
||||
data_tgt.resize(ckpt_size);
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx, data_tgt.data(), ckpt_size, seq_id, flags);
|
||||
if (n != ckpt_size) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
|
||||
}
|
||||
}
|
||||
|
||||
void common_prompt_checkpoint::update_dft(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags) {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
|
||||
|
||||
data_dft.resize(ckpt_size);
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx, data_dft.data(), ckpt_size, seq_id, flags);
|
||||
if (n != ckpt_size) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
|
||||
}
|
||||
}
|
||||
|
||||
void common_prompt_checkpoint::load_tgt(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags) const {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (data_tgt.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx, data_tgt.data(), data_tgt.size(), seq_id, flags);
|
||||
if (n != data_tgt.size()) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_tgt.size(), n);
|
||||
}
|
||||
}
|
||||
|
||||
void common_prompt_checkpoint::load_dft(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags) const {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (data_dft.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx, data_dft.data(), data_dft.size(), seq_id, flags);
|
||||
if (n != data_dft.size()) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -307,11 +307,9 @@ struct common_params_speculative_draft {
|
||||
|
||||
common_params_model mparams;
|
||||
|
||||
llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts
|
||||
llama_context * ctx_tgt = nullptr;
|
||||
llama_context * ctx_dft = nullptr;
|
||||
|
||||
llama_context_params cparams; // these are the parameters for the draft llama_context
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
|
||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||
@@ -322,7 +320,6 @@ struct common_params_speculative_draft {
|
||||
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
};
|
||||
|
||||
@@ -1026,3 +1023,47 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
|
||||
|
||||
// "adamw" or "sgd" (case insensitive)
|
||||
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
|
||||
|
||||
//
|
||||
// prompt utils
|
||||
//
|
||||
|
||||
struct common_prompt_checkpoint {
|
||||
int64_t n_tokens;
|
||||
|
||||
llama_pos pos_min;
|
||||
llama_pos pos_max;
|
||||
|
||||
std::vector<uint8_t> data_tgt;
|
||||
std::vector<uint8_t> data_dft;
|
||||
|
||||
size_t size() const;
|
||||
|
||||
bool empty() const;
|
||||
void clear();
|
||||
|
||||
void update_pos(
|
||||
int64_t n_tokens,
|
||||
llama_pos pos_min,
|
||||
llama_pos pos_max);
|
||||
|
||||
void update_tgt(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags);
|
||||
|
||||
void update_dft(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags);
|
||||
|
||||
void load_tgt(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags) const;
|
||||
|
||||
void load_dft(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags) const;
|
||||
};
|
||||
|
||||
@@ -147,6 +147,7 @@ struct common_speculative_state {
|
||||
virtual void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_pos n_past,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) = 0;
|
||||
|
||||
@@ -156,44 +157,25 @@ struct common_speculative_state {
|
||||
virtual int32_t n_min(const common_params_speculative & params) const = 0;
|
||||
};
|
||||
|
||||
struct common_speculative_checkpoint {
|
||||
llama_pos pos_min = 0;
|
||||
llama_pos pos_max = 0;
|
||||
|
||||
int64_t n_tokens = 0;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
|
||||
size_t size() const {
|
||||
return data.size();
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_state_draft : public common_speculative_state {
|
||||
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
||||
llama_context * ctx_dft;
|
||||
|
||||
bool use_ckpt = false;
|
||||
common_speculative_checkpoint ckpt;
|
||||
llama_seq_id seq_id;
|
||||
|
||||
common_sampler * smpl;
|
||||
|
||||
llama_batch batch;
|
||||
llama_tokens prompt_dft;
|
||||
|
||||
bool vocab_cmpt = true; // whether retokenization is needed
|
||||
std::unordered_map<std::string, std::string> vocab_map;
|
||||
llama_batch batch;
|
||||
|
||||
common_speculative_state_draft(
|
||||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt,
|
||||
llama_context * ctx_dft,
|
||||
const std::vector<std::pair<std::string, std::string>> & replacements,
|
||||
bool use_ckpt)
|
||||
llama_seq_id seq_id)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
, ctx_dft(ctx_dft)
|
||||
, use_ckpt(use_ckpt)
|
||||
, seq_id(seq_id)
|
||||
{
|
||||
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
||||
smpl = nullptr;
|
||||
@@ -225,23 +207,17 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
|
||||
vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
|
||||
LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt);
|
||||
const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
|
||||
LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt);
|
||||
|
||||
if (!vocab_cmpt) {
|
||||
LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n");
|
||||
LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__);
|
||||
|
||||
for (const auto & pair : replacements) {
|
||||
vocab_map[pair.first] = pair.second;
|
||||
}
|
||||
throw std::runtime_error("draft model vocab type must match target model to use speculation");
|
||||
}
|
||||
}
|
||||
|
||||
~common_speculative_state_draft() override {
|
||||
llama_perf_context_print(ctx_dft);
|
||||
|
||||
llama_free(ctx_dft);
|
||||
|
||||
common_sampler_free(smpl);
|
||||
|
||||
llama_batch_free(batch);
|
||||
@@ -250,220 +226,29 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
void begin(const llama_tokens & /*prompt*/) override {
|
||||
}
|
||||
|
||||
size_t create_checkpoint(int n_tokens_prompt) {
|
||||
int slot_id = 0;
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
|
||||
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
|
||||
ckpt.n_tokens = n_tokens_prompt;
|
||||
ckpt.data.resize(checkpoint_size);
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (n != checkpoint_size) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
|
||||
}
|
||||
|
||||
LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__,
|
||||
ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024);
|
||||
return n;
|
||||
}
|
||||
|
||||
size_t restore_checkpoint() {
|
||||
int slot_id = 0;
|
||||
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (n != ckpt.size()) {
|
||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
|
||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
}
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);
|
||||
|
||||
return n;
|
||||
}
|
||||
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_pos n_past,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
const auto & sparams = params.draft;
|
||||
|
||||
auto * spec = this;
|
||||
|
||||
auto & batch = spec->batch;
|
||||
auto & ctx_tgt = spec->ctx_tgt;
|
||||
auto & ctx_dft = spec->ctx_dft;
|
||||
auto & smpl = spec->smpl;
|
||||
auto & prompt_dft = spec->prompt_dft;
|
||||
auto & batch = spec->batch;
|
||||
auto & ctx_dft = spec->ctx_dft;
|
||||
auto & smpl = spec->smpl;
|
||||
|
||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||
|
||||
int reuse_i = 0; // index of part to be reused in prompt_dft
|
||||
int reuse_n = 0; // length of part to be reused in prompt_dft
|
||||
|
||||
const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max;
|
||||
|
||||
llama_tokens prompt_cnv;
|
||||
if (!spec->vocab_cmpt) {
|
||||
std::string text;
|
||||
|
||||
text = common_detokenize(ctx_tgt, prompt_tgt, true);
|
||||
text = replace_to_dft(text);
|
||||
|
||||
LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
|
||||
|
||||
prompt_cnv = common_tokenize(ctx_dft, text, false, true);
|
||||
|
||||
// convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
|
||||
const auto * model_tgt = llama_get_model(ctx_tgt);
|
||||
const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
|
||||
|
||||
int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
|
||||
GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
|
||||
|
||||
text.resize(-n_chars);
|
||||
llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
|
||||
text = replace_to_dft(text);
|
||||
|
||||
LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
|
||||
id_last = common_tokenize(ctx_dft, text, false, true)[0];
|
||||
}
|
||||
|
||||
const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv;
|
||||
|
||||
const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);
|
||||
|
||||
if (use_ckpt && i_start > 0) {
|
||||
LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
// reuse as much as possible from the old draft context
|
||||
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
|
||||
for (int i = 0; i < (int) prompt_dft.size(); ++i) {
|
||||
int cur = 0;
|
||||
while (i_start + cur < (int) prompt_cur.size() &&
|
||||
i + cur < (int) prompt_dft.size() &&
|
||||
prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
|
||||
cur++;
|
||||
}
|
||||
|
||||
if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) {
|
||||
reuse_i = i;
|
||||
reuse_n = cur;
|
||||
}
|
||||
|
||||
if (use_ckpt) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
|
||||
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
|
||||
if (use_ckpt && ckpt.n_tokens > reuse_n) {
|
||||
LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens);
|
||||
|
||||
reuse_i = 0;
|
||||
reuse_n = 0;
|
||||
|
||||
ckpt = {};
|
||||
}
|
||||
|
||||
result.clear();
|
||||
result.reserve(sparams.n_max);
|
||||
|
||||
if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
|
||||
llama_memory_clear(mem_dft, false);
|
||||
prompt_dft.clear();
|
||||
} else {
|
||||
// this happens when a previous draft has been discarded (for example, due to being too small), but the
|
||||
// target model agreed with it. in this case, we simply pass back the previous results to save compute
|
||||
if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
|
||||
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
|
||||
result.push_back(prompt_dft[i]);
|
||||
|
||||
if (sparams.n_max <= (int) result.size()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (reuse_i > 0) {
|
||||
GGML_ASSERT(!use_ckpt);
|
||||
|
||||
bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
|
||||
if (!is_removed) {
|
||||
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
|
||||
return;
|
||||
}
|
||||
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
|
||||
|
||||
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
|
||||
}
|
||||
|
||||
if (reuse_n < (int) prompt_dft.size()) {
|
||||
if (use_ckpt) {
|
||||
if (ckpt.n_tokens > 0) {
|
||||
LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
|
||||
restore_checkpoint();
|
||||
reuse_n = ckpt.n_tokens;
|
||||
prompt_dft.resize(reuse_n);
|
||||
}
|
||||
} else {
|
||||
const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1);
|
||||
if (!is_removed) {
|
||||
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size());
|
||||
return;
|
||||
}
|
||||
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// prepare a batch to evaluate any new tokens in the prompt
|
||||
common_batch_clear(batch);
|
||||
|
||||
for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) {
|
||||
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]);
|
||||
common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false);
|
||||
|
||||
prompt_dft.push_back(prompt_cur[i]);
|
||||
}
|
||||
|
||||
// we should rarely end-up here during normal decoding
|
||||
if (batch.n_tokens > 0) {
|
||||
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
||||
LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens);
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0 && ret != 1) {
|
||||
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
|
||||
__func__, ret, prompt_cur.size());
|
||||
}
|
||||
|
||||
if (use_ckpt) {
|
||||
create_checkpoint(prompt_dft.size());
|
||||
}
|
||||
}
|
||||
|
||||
const llama_pos n_past = prompt_dft.size();
|
||||
|
||||
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
|
||||
GGML_ASSERT(n_past >= (llama_pos) prompt_tgt.size());
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add (batch, id_last, n_past, { 0 }, true);
|
||||
|
||||
prompt_dft.push_back(id_last);
|
||||
|
||||
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
|
||||
common_batch_add (batch, id_last, n_past, { seq_id }, true);
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0 && ret != 1) {
|
||||
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
|
||||
__func__, ret, prompt_cur.size(), prompt_dft.size());
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
return;
|
||||
}
|
||||
|
||||
common_sampler_reset(smpl);
|
||||
@@ -497,25 +282,13 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
break;
|
||||
}
|
||||
|
||||
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
|
||||
common_batch_add(batch, id, n_past + i + 1, { seq_id }, true);
|
||||
|
||||
// evaluate the drafted tokens on the draft model
|
||||
ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
|
||||
__func__, i, ret, prompt_cur.size(), prompt_dft.size());
|
||||
}
|
||||
|
||||
prompt_dft.push_back(id);
|
||||
}
|
||||
|
||||
if (!spec->vocab_cmpt) {
|
||||
std::string detokenized = common_detokenize(ctx_dft, result, true);
|
||||
detokenized = replace_to_tgt(detokenized);
|
||||
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
|
||||
result = common_tokenize(ctx_tgt, detokenized, false, true);
|
||||
if (result.size() > (size_t) sparams.n_max) {
|
||||
result.resize(sparams.n_max);
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -536,34 +309,6 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
int32_t n_min(const common_params_speculative & params) const override {
|
||||
return params.draft.n_min;
|
||||
}
|
||||
|
||||
std::string replace_to_dft(const std::string & input) const {
|
||||
std::string result = input;
|
||||
|
||||
for (const auto & pair : this->vocab_map) {
|
||||
size_t pos = result.find(pair.first);
|
||||
while (pos != std::string::npos) {
|
||||
result.replace(pos, pair.first.length(), pair.second);
|
||||
pos = result.find(pair.first, pos + pair.second.length());
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string replace_to_tgt(const std::string & input) const {
|
||||
std::string result = input;
|
||||
|
||||
for (const auto & pair : this->vocab_map) {
|
||||
size_t pos = result.find(pair.second);
|
||||
while (pos != std::string::npos) {
|
||||
result.replace(pos, pair.second.length(), pair.first);
|
||||
pos = result.find(pair.second, pos + pair.first.length());
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_state_eagle3 : public common_speculative_state {
|
||||
@@ -576,11 +321,13 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_pos n_past,
|
||||
llama_token id_last,
|
||||
llama_tokens & draft_tokens) override {
|
||||
// TODO: implement
|
||||
GGML_UNUSED(params);
|
||||
GGML_UNUSED(prompt_tgt);
|
||||
GGML_UNUSED(n_past);
|
||||
GGML_UNUSED(id_last);
|
||||
GGML_UNUSED(draft_tokens);
|
||||
}
|
||||
@@ -615,11 +362,13 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_pos n_past,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
GGML_UNUSED(params);
|
||||
GGML_UNUSED(n_past);
|
||||
|
||||
result = common_ngram_simple_draft(config, prompt_tgt, id_last);
|
||||
GGML_UNUSED(params);
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
@@ -652,10 +401,13 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state {
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_pos n_past,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
common_ngram_map_draft(config, prompt_tgt, id_last, result);
|
||||
GGML_UNUSED(params);
|
||||
GGML_UNUSED(n_past);
|
||||
|
||||
common_ngram_map_draft(config, prompt_tgt, id_last, result);
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
@@ -722,8 +474,11 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_pos n_past,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
GGML_UNUSED(n_past);
|
||||
|
||||
const auto & sparams = params.ngram_mod;
|
||||
|
||||
n_draft_last = 0;
|
||||
@@ -853,9 +608,11 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_pos n_past,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
GGML_UNUSED(params);
|
||||
GGML_UNUSED(n_past);
|
||||
|
||||
if (cache_size < prompt_tgt.size() + 1) {
|
||||
llama_tokens tokens_new;
|
||||
@@ -971,18 +728,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
||||
|
||||
// initialization of the speculative decoding system
|
||||
//
|
||||
common_speculative * common_speculative_init(
|
||||
common_params_speculative & params,
|
||||
llama_context * ctx_tgt) {
|
||||
llama_context * ctx_dft = nullptr;
|
||||
if (params.draft.model) {
|
||||
ctx_dft = llama_init_from_model(params.draft.model, params.draft.cparams);
|
||||
if (ctx_dft == nullptr) {
|
||||
LOG_ERR("%s", "failed to create draft context\n");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
common_speculative * common_speculative_init(common_params_speculative & params, llama_seq_id seq_id) {
|
||||
// Compute the implementations to use based on the config and their order of preference
|
||||
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
|
||||
{
|
||||
@@ -1044,13 +790,10 @@ common_speculative * common_speculative_init(
|
||||
case COMMON_SPECULATIVE_TYPE_NONE:
|
||||
break;
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT: {
|
||||
const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
|
||||
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
|
||||
/* .ctx_tgt = */ ctx_tgt,
|
||||
/* .ctx_dft = */ ctx_dft,
|
||||
/* .replacements = */ params.draft.replacements,
|
||||
/* .use_ckpt = */ use_ckpt
|
||||
/* .ctx_tgt = */ params.draft.ctx_tgt,
|
||||
/* .ctx_dft = */ params.draft.ctx_dft,
|
||||
/* .seq_id = */ seq_id
|
||||
));
|
||||
break;
|
||||
}
|
||||
@@ -1135,6 +878,7 @@ llama_tokens common_speculative_draft(
|
||||
common_speculative * spec,
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt, // specified in target model vocab
|
||||
llama_pos n_past,
|
||||
llama_token id_last) {
|
||||
llama_tokens result;
|
||||
|
||||
@@ -1143,7 +887,7 @@ llama_tokens common_speculative_draft(
|
||||
for (auto & impl : spec->impls) {
|
||||
{
|
||||
common_time_meas tm(impl->t_draft_us, !impl->gen_perf);
|
||||
impl->draft(params, prompt_tgt, id_last, result);
|
||||
impl->draft(params, prompt_tgt, n_past, id_last, result);
|
||||
impl->n_call_draft++;
|
||||
}
|
||||
|
||||
|
||||
@@ -14,9 +14,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
||||
// convert type to string
|
||||
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
||||
|
||||
common_speculative * common_speculative_init(
|
||||
common_params_speculative & params,
|
||||
llama_context * ctx_tgt);
|
||||
common_speculative * common_speculative_init(common_params_speculative & params, llama_seq_id seq_id);
|
||||
|
||||
void common_speculative_free(common_speculative * spec);
|
||||
|
||||
@@ -28,6 +26,7 @@ llama_tokens common_speculative_draft(
|
||||
common_speculative * spec,
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt,
|
||||
llama_pos n_past,
|
||||
llama_token id_last);
|
||||
|
||||
// informs the speculative decoder that n_accepted tokens were accepted by the target model
|
||||
|
||||
@@ -13,20 +13,6 @@
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
struct spec_checkpoint {
|
||||
int64_t n_tokens = 0;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
|
||||
size_t size() const {
|
||||
return data.size();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return data.empty();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
@@ -43,11 +29,6 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.speculative.draft.mparams.path.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
@@ -62,18 +43,11 @@ int main(int argc, char ** argv) {
|
||||
model_tgt = llama_init_tgt->model();
|
||||
ctx_tgt = llama_init_tgt->context();
|
||||
|
||||
// check if the context supports partial sequence removal
|
||||
const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt);
|
||||
const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
|
||||
|
||||
if (use_ckpt) {
|
||||
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
|
||||
|
||||
// load the draft model
|
||||
llama_model_ptr model_dft;
|
||||
llama_context_ptr ctx_dft;
|
||||
|
||||
// TODO: simplify this logic
|
||||
{
|
||||
@@ -81,9 +55,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
auto params_dft = params;
|
||||
|
||||
params_dft.n_parallel = 1;
|
||||
params_dft.n_ctx = params_spec.n_ctx;
|
||||
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
|
||||
params_dft.devices = params_spec.devices;
|
||||
params_dft.model = params_spec.mparams;
|
||||
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
|
||||
@@ -103,8 +74,19 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.speculative.draft.model = model_dft.get();
|
||||
params.speculative.draft.cparams = common_context_params_to_llama(params_dft);
|
||||
auto cparams = common_context_params_to_llama(params_dft);
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
|
||||
params.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
}
|
||||
|
||||
// check if the context supports partial sequence removal
|
||||
const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
|
||||
const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
|
||||
|
||||
if (use_ckpt_tgt) {
|
||||
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
|
||||
}
|
||||
|
||||
// Tokenize the prompt
|
||||
@@ -136,6 +118,8 @@ int main(int argc, char ** argv) {
|
||||
// used to determine end of generation
|
||||
bool has_eos = false;
|
||||
|
||||
llama_seq_id seq_id = 0;
|
||||
|
||||
// ================================================
|
||||
// everything until here is standard initialization
|
||||
// the relevant stuff for speculative decoding starts here
|
||||
@@ -146,7 +130,8 @@ int main(int argc, char ** argv) {
|
||||
common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling));
|
||||
|
||||
// eval the prompt
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
|
||||
llama_decode(ctx_dft.get(), llama_batch_get_one(inp.data(), inp.size() - 1));
|
||||
|
||||
// note: keep the last token separate!
|
||||
llama_token id_last = inp.back();
|
||||
@@ -160,7 +145,7 @@ int main(int argc, char ** argv) {
|
||||
// init the speculator
|
||||
const auto & params_spec = params.speculative;
|
||||
|
||||
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
|
||||
struct common_speculative * spec = common_speculative_init(params.speculative, seq_id);
|
||||
|
||||
common_speculative_begin(spec, prompt_tgt);
|
||||
|
||||
@@ -169,7 +154,7 @@ int main(int argc, char ** argv) {
|
||||
size_t n_draft = 0;
|
||||
|
||||
llama_tokens draft;
|
||||
spec_checkpoint spec_ckpt;
|
||||
common_prompt_checkpoint ckpt;
|
||||
|
||||
const auto t_enc_end = ggml_time_us();
|
||||
|
||||
@@ -184,40 +169,49 @@ int main(int argc, char ** argv) {
|
||||
// from a cache or lookup tables.
|
||||
//
|
||||
if (draft.empty()) {
|
||||
ckpt.update_pos(
|
||||
prompt_tgt.size(),
|
||||
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), seq_id),
|
||||
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id));
|
||||
|
||||
if (use_ckpt_dft) {
|
||||
ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
}
|
||||
|
||||
// generate a new draft
|
||||
draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
|
||||
draft = common_speculative_draft(spec, params_spec, prompt_tgt, prompt_tgt.size(), id_last);
|
||||
|
||||
// save the original draft size
|
||||
n_draft = draft.size();
|
||||
|
||||
// save a checkpoint of the target context before evaluating the draft
|
||||
// this allows us to restore the state if partial draft acceptance occurs
|
||||
if (!draft.empty() && use_ckpt) {
|
||||
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
spec_ckpt.data.resize(ckpt_size);
|
||||
if (!draft.empty()) {
|
||||
if (use_ckpt_tgt) {
|
||||
ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
GGML_ASSERT(n == ckpt_size);
|
||||
{
|
||||
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
spec_ckpt.n_tokens = (int64_t) prompt_tgt.size();
|
||||
LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
} else {
|
||||
// we have a previous (partial) draft to reuse from checkpoint restoration
|
||||
if (use_ckpt) {
|
||||
GGML_ASSERT(!spec_ckpt.empty());
|
||||
if (use_ckpt_tgt) {
|
||||
GGML_ASSERT(!ckpt.empty());
|
||||
}
|
||||
}
|
||||
|
||||
// always have a token to evaluate from before - id_last
|
||||
common_batch_clear(batch_tgt);
|
||||
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
|
||||
common_batch_add (batch_tgt, id_last, n_past++, { seq_id }, true);
|
||||
|
||||
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
|
||||
{
|
||||
for (size_t i = 0; i < draft.size(); ++i) {
|
||||
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
|
||||
common_batch_add(batch_tgt, draft[i], n_past + i, { seq_id }, true);
|
||||
}
|
||||
|
||||
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
|
||||
@@ -225,9 +219,15 @@ int main(int argc, char ** argv) {
|
||||
llama_decode(ctx_tgt, batch_tgt);
|
||||
}
|
||||
|
||||
// evaluate the same batch with the draft model
|
||||
{
|
||||
// TODO: extend to support MTP, Eagle, etc. See server code for reference
|
||||
llama_decode(ctx_dft.get(), batch_tgt);
|
||||
}
|
||||
|
||||
// only save the sampler sampler state if we use checkpoints
|
||||
common_sampler_ptr smpl_save;
|
||||
if (use_ckpt) {
|
||||
if (use_ckpt_tgt) {
|
||||
smpl_save.reset(common_sampler_clone(smpl.get()));
|
||||
}
|
||||
|
||||
@@ -247,17 +247,24 @@ int main(int argc, char ** argv) {
|
||||
// check for partial draft acceptance:
|
||||
// if the context doesn't support partial sequence removal, restore the checkpoint
|
||||
// and make the accepted tokens the new partial draft for the next iteration
|
||||
if (use_ckpt && ids.size() - 1 < draft.size()) {
|
||||
if (use_ckpt_tgt && ids.size() - 1 < draft.size()) {
|
||||
LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size());
|
||||
|
||||
draft = std::move(ids);
|
||||
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
GGML_ASSERT(n == spec_ckpt.size());
|
||||
{
|
||||
ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
prompt_tgt.resize(spec_ckpt.n_tokens);
|
||||
{
|
||||
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
prompt_tgt.resize(ckpt.n_tokens);
|
||||
smpl = std::move(smpl_save);
|
||||
|
||||
n_past = (int) prompt_tgt.size();
|
||||
@@ -305,7 +312,8 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, n_past, -1);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, n_past, -1);
|
||||
}
|
||||
|
||||
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
|
||||
|
||||
@@ -858,6 +858,8 @@ extern "C" {
|
||||
size_t n_token_capacity,
|
||||
size_t * n_token_count_out);
|
||||
|
||||
#define LLAMA_STATE_SEQ_FLAGS_NONE 0
|
||||
|
||||
// for backwards-compat
|
||||
#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
|
||||
|
||||
|
||||
@@ -195,11 +195,9 @@
|
||||
| `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
|
||||
| `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
|
||||
| `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
|
||||
| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
|
||||
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
|
||||
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
|
||||
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
|
||||
| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
|
||||
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
|
||||
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
|
||||
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |
|
||||
|
||||
@@ -244,11 +244,9 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
|
||||
| `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
|
||||
| `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
|
||||
| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
|
||||
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
|
||||
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
|
||||
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
|
||||
| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
|
||||
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
|
||||
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
|
||||
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |
|
||||
|
||||
@@ -36,32 +36,6 @@ using json = nlohmann::ordered_json;
|
||||
|
||||
constexpr int HTTP_POLLING_SECONDS = 1;
|
||||
|
||||
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, bool on_device, llama_pos pos_min = -1, llama_pos pos_max = -1) {
|
||||
if (pos_min == -1) {
|
||||
pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
|
||||
}
|
||||
if (pos_max == -1) {
|
||||
pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id);
|
||||
}
|
||||
|
||||
auto flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY;
|
||||
if (on_device) {
|
||||
flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE;
|
||||
}
|
||||
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, flags);
|
||||
|
||||
ckpt.pos_min = pos_min;
|
||||
ckpt.pos_max = pos_max;
|
||||
ckpt.n_tokens = n_tokens;
|
||||
ckpt.data.resize(checkpoint_size);
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, flags);
|
||||
if (n != checkpoint_size) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
|
||||
}
|
||||
}
|
||||
|
||||
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
|
||||
enum slot_state {
|
||||
SLOT_STATE_IDLE,
|
||||
@@ -80,9 +54,8 @@ enum server_state {
|
||||
struct server_slot {
|
||||
int id;
|
||||
|
||||
llama_context * ctx = nullptr;
|
||||
|
||||
common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
llama_context * ctx_tgt = nullptr;
|
||||
llama_context * ctx_dft = nullptr;
|
||||
|
||||
// multimodal
|
||||
mtmd_context * mctx = nullptr;
|
||||
@@ -90,7 +63,7 @@ struct server_slot {
|
||||
// speculative decoding
|
||||
llama_tokens spec_draft;
|
||||
std::vector<int32_t> spec_i_batch;
|
||||
server_prompt_checkpoint spec_ckpt;
|
||||
common_prompt_checkpoint spec_ckpt;
|
||||
common_speculative_ptr spec;
|
||||
|
||||
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
|
||||
@@ -135,21 +108,27 @@ struct server_slot {
|
||||
void prompt_save(server_prompt_cache & prompt_cache) const {
|
||||
GGML_ASSERT(prompt.data.size() == 0);
|
||||
|
||||
const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
|
||||
const size_t cur_size_tgt = llama_state_seq_get_size_ext(ctx_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE);
|
||||
const size_t cur_size_dft = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE) : 0;
|
||||
|
||||
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
|
||||
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
|
||||
const size_t cur_size = cur_size_tgt + cur_size_dft;
|
||||
|
||||
auto * cur = prompt_cache.alloc(prompt, cur_size);
|
||||
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n",
|
||||
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0));
|
||||
|
||||
auto * cur = prompt_cache.alloc(prompt, cur_size_tgt, cur_size_dft);
|
||||
if (cur == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
|
||||
llama_state_seq_get_data_ext(ctx_tgt, cur->data.main.data(), cur_size_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE);
|
||||
if (ctx_dft) {
|
||||
llama_state_seq_get_data_ext(ctx_dft, cur->data.drft.data(), cur_size_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE);
|
||||
}
|
||||
}
|
||||
|
||||
bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
|
||||
bool res = prompt_cache.load(prompt, tokens, ctx, id);
|
||||
bool res = prompt_cache.load(prompt, tokens, ctx_tgt, ctx_dft, id);
|
||||
if (!res) {
|
||||
SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
|
||||
}
|
||||
@@ -164,7 +143,11 @@ struct server_slot {
|
||||
|
||||
SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1);
|
||||
if (ctx_dft) {
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1);
|
||||
}
|
||||
|
||||
prompt.tokens.clear();
|
||||
}
|
||||
|
||||
@@ -222,7 +205,7 @@ struct server_slot {
|
||||
task_prev = std::move(task);
|
||||
task.reset();
|
||||
|
||||
llama_set_sampler(ctx, id, nullptr);
|
||||
llama_set_sampler(ctx_tgt, id, nullptr);
|
||||
|
||||
// clear alora start
|
||||
alora_invocation_start = -1;
|
||||
@@ -259,7 +242,7 @@ struct server_slot {
|
||||
|
||||
return
|
||||
!task->need_embd() ||
|
||||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
|
||||
(llama_get_memory(ctx_tgt) && llama_pooling_type(ctx_tgt) == LLAMA_POOLING_TYPE_LAST);
|
||||
}
|
||||
|
||||
bool can_batch_with(server_slot & other_slot) const {
|
||||
@@ -333,7 +316,7 @@ struct server_slot {
|
||||
return n_draft_max;
|
||||
}
|
||||
|
||||
void update_batch(llama_batch & batch) {
|
||||
void update_batch(llama_batch & batch, bool use_ckpt_tgt, bool use_ckpt_dft) {
|
||||
const int n_draft_max = get_n_draft_max();
|
||||
if (n_draft_max > 0) {
|
||||
GGML_ASSERT(can_speculate());
|
||||
@@ -341,20 +324,29 @@ struct server_slot {
|
||||
// generate draft tokens in speculative decoding mode
|
||||
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
||||
// perform the speculative drafting for all sequences at the same time in a single batch
|
||||
const llama_tokens & tokens = prompt.tokens.get_text_tokens();
|
||||
const llama_tokens & tokens_text = prompt.tokens.get_text_tokens();
|
||||
|
||||
const auto & params_spec = task->params.speculative;
|
||||
|
||||
if (!spec_draft.empty()) {
|
||||
// we have a previous (partial) draft to reuse
|
||||
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
if (use_ckpt_tgt) {
|
||||
GGML_ASSERT(!spec_ckpt.empty());
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(spec_i_batch.empty());
|
||||
|
||||
spec_ckpt.update_pos(
|
||||
prompt.n_tokens(),
|
||||
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), id),
|
||||
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), id));
|
||||
|
||||
if (use_ckpt_dft) {
|
||||
spec_ckpt.update_dft(ctx_dft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
}
|
||||
|
||||
// generate a new draft
|
||||
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
|
||||
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens_text, prompt.n_tokens(), sampled);
|
||||
n_draft_total += spec_draft.size();
|
||||
|
||||
if (spec_draft.size() > (size_t) n_draft_max) {
|
||||
@@ -362,18 +354,25 @@ struct server_slot {
|
||||
spec_draft.resize(n_draft_max);
|
||||
}
|
||||
|
||||
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
const auto n_tokens = prompt.tokens.size();
|
||||
if (!spec_draft.empty()) {
|
||||
if (use_ckpt_tgt) {
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
spec_ckpt.update_tgt(ctx_tgt, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens, true);
|
||||
//const int64_t t_total = ggml_time_us() - t_start;
|
||||
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
|
||||
|
||||
//const int64_t t_total = ggml_time_us() - t_start;
|
||||
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
|
||||
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n",
|
||||
spec_ckpt.pos_min, spec_ckpt.pos_max, prompt.n_tokens(), (float) spec_ckpt.size() / 1024 / 1024, (float) spec_ckpt.data_dft.size() / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
|
||||
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
|
||||
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
|
||||
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
|
||||
if (ctx_dft) {
|
||||
spec_ckpt.load_dft(ctx_dft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), this->id, spec_ckpt.pos_max + 1, -1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -539,7 +538,7 @@ struct server_slot {
|
||||
};
|
||||
|
||||
if (!only_metrics) {
|
||||
res["prompt"] = ptask->tokens.detokenize(ctx, true);
|
||||
res["prompt"] = ptask->tokens.detokenize(ctx_tgt, true);
|
||||
res["generated"] = generated_text.empty() ? debug_generated_text : generated_text;
|
||||
}
|
||||
}
|
||||
@@ -550,8 +549,13 @@ struct server_slot {
|
||||
void copy_state_to(server_slot & other) const {
|
||||
GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), other.id, -1, -1);
|
||||
llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, -1, -1);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1);
|
||||
llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1);
|
||||
|
||||
if (ctx_dft) {
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1);
|
||||
llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1);
|
||||
}
|
||||
|
||||
other.n_decoded = n_decoded;
|
||||
other.n_remaining = n_remaining;
|
||||
@@ -642,7 +646,8 @@ public:
|
||||
// only use these pointers outside of this class:
|
||||
// - when not in sleeping state
|
||||
// - and, with thread-safe APIs (e.g., tokenizer calls)
|
||||
llama_model * model = nullptr;
|
||||
llama_model * model_tgt = nullptr;
|
||||
|
||||
mtmd_context * mctx = nullptr;
|
||||
const llama_vocab * vocab = nullptr;
|
||||
|
||||
@@ -669,11 +674,15 @@ private:
|
||||
// note: keep these alive - they determine the lifetime of the model, context, etc.
|
||||
common_init_result_ptr llama_init;
|
||||
|
||||
llama_context * ctx = nullptr;
|
||||
llama_context * ctx_tgt = nullptr;
|
||||
|
||||
llama_batch batch {};
|
||||
|
||||
llama_model_ptr model_dft;
|
||||
llama_context_ptr ctx_dft;
|
||||
|
||||
common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
|
||||
bool add_bos_token = true;
|
||||
|
||||
@@ -708,8 +717,8 @@ private:
|
||||
void destroy() {
|
||||
llama_init.reset();
|
||||
|
||||
ctx = nullptr;
|
||||
model = nullptr;
|
||||
ctx_tgt = nullptr;
|
||||
model_tgt = nullptr;
|
||||
|
||||
mtmd_free(mctx);
|
||||
mctx = nullptr;
|
||||
@@ -759,17 +768,17 @@ private:
|
||||
|
||||
llama_init = common_init_from_params(params_base);
|
||||
|
||||
model = llama_init->model();
|
||||
ctx = llama_init->context();
|
||||
model_tgt = llama_init->model();
|
||||
ctx_tgt = llama_init->context();
|
||||
|
||||
if (model == nullptr) {
|
||||
if (model_tgt == nullptr) {
|
||||
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
vocab = llama_model_get_vocab(model);
|
||||
vocab = llama_model_get_vocab(model_tgt);
|
||||
|
||||
n_ctx = llama_n_ctx(ctx);
|
||||
n_ctx = llama_n_ctx(ctx_tgt);
|
||||
|
||||
add_bos_token = llama_vocab_get_add_bos(vocab);
|
||||
|
||||
@@ -781,9 +790,6 @@ private:
|
||||
|
||||
auto params_dft = params_base;
|
||||
|
||||
params_dft.n_parallel = 1;
|
||||
params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
|
||||
params_dft.n_batch = llama_n_ctx_seq(ctx);
|
||||
params_dft.devices = params_spec.devices;
|
||||
params_dft.model = params_spec.mparams;
|
||||
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
|
||||
@@ -805,8 +811,13 @@ private:
|
||||
return false;
|
||||
}
|
||||
|
||||
params_base.speculative.draft.model = model_dft.get();
|
||||
params_base.speculative.draft.cparams = common_context_params_to_llama(params_dft);
|
||||
auto cparams = common_context_params_to_llama(params_dft);
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
|
||||
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
}
|
||||
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
@@ -826,7 +837,7 @@ private:
|
||||
mparams.image_max_tokens = params_base.image_max_tokens;
|
||||
mparams.media_marker = get_media_marker();
|
||||
|
||||
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
|
||||
mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams);
|
||||
if (mctx == nullptr) {
|
||||
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
|
||||
return false;
|
||||
@@ -844,7 +855,7 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
if (!llama_memory_can_shift(llama_get_memory(ctx))) {
|
||||
if (!llama_memory_can_shift(llama_get_memory(ctx_tgt))) {
|
||||
if (params_base.ctx_shift) {
|
||||
params_base.ctx_shift = false;
|
||||
SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled");
|
||||
@@ -856,14 +867,14 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
if (llama_model_n_swa(model) == 0) {
|
||||
if (llama_model_n_swa(model_tgt) == 0) {
|
||||
if (params_base.swa_full) {
|
||||
params_base.swa_full = false;
|
||||
SRV_WRN("%s\n", "swa_full is not supported by this model, it will be disabled");
|
||||
}
|
||||
}
|
||||
|
||||
n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model);
|
||||
n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model_tgt);
|
||||
|
||||
// Necessary similarity of prompt for slot selection
|
||||
slot_prompt_similarity = params_base.slot_prompt_similarity;
|
||||
@@ -871,9 +882,9 @@ private:
|
||||
// setup slots
|
||||
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
|
||||
|
||||
const int n_ctx_train = llama_model_n_ctx_train(model);
|
||||
const int n_ctx_train = llama_model_n_ctx_train(model_tgt);
|
||||
|
||||
int n_ctx_slot = llama_n_ctx_seq(ctx);
|
||||
int n_ctx_slot = llama_n_ctx_seq(ctx_tgt);
|
||||
if (n_ctx_slot > n_ctx_train) {
|
||||
SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
|
||||
n_ctx_slot = n_ctx_train;
|
||||
@@ -881,12 +892,12 @@ private:
|
||||
|
||||
slots.clear();
|
||||
|
||||
const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx);
|
||||
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
ctx_tgt_seq_rm_type = common_context_can_seq_rm(ctx_tgt);
|
||||
if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
SRV_WRN("%s", "speculative decoding not supported by this context\n");
|
||||
}
|
||||
|
||||
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
SRV_WRN("%s", "speculative decoding will use checkpoints\n");
|
||||
}
|
||||
|
||||
@@ -895,21 +906,28 @@ private:
|
||||
slots.emplace_back();
|
||||
}
|
||||
|
||||
bool no_dft = false;
|
||||
|
||||
for (int i = 0; i < params_base.n_parallel; i++) {
|
||||
server_slot & slot = slots[i];
|
||||
|
||||
slot.id = i;
|
||||
slot.ctx = ctx;
|
||||
slot.n_ctx = n_ctx_slot;
|
||||
|
||||
slot.ctx_seq_rm_type = ctx_seq_rm_type;
|
||||
slot.id = i;
|
||||
slot.ctx_tgt = ctx_tgt;
|
||||
slot.ctx_dft = ctx_dft.get();
|
||||
slot.n_ctx = n_ctx_slot;
|
||||
|
||||
slot.mctx = mctx;
|
||||
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
||||
|
||||
// try speculative decoding
|
||||
if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
|
||||
if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
try {
|
||||
slot.spec.reset(common_speculative_init(params_base.speculative, slot.id));
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
|
||||
|
||||
no_dft = true;
|
||||
}
|
||||
|
||||
if (slot.spec) {
|
||||
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
||||
@@ -925,6 +943,16 @@ private:
|
||||
slot.reset();
|
||||
}
|
||||
|
||||
if (no_dft && ctx_dft) {
|
||||
SRV_WRN("%s", "destroying the draft model as it is not going to be used\n");
|
||||
|
||||
ctx_dft.reset();
|
||||
|
||||
for (auto & slot : slots) {
|
||||
slot.ctx_dft = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const char * LLAMA_TRACE = getenv("LLAMA_TRACE");
|
||||
trace = LLAMA_TRACE ? atoi(LLAMA_TRACE) : 0;
|
||||
@@ -946,7 +974,7 @@ private:
|
||||
// the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
|
||||
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
|
||||
{
|
||||
const int32_t n_batch = llama_n_batch(ctx);
|
||||
const int32_t n_batch = llama_n_batch(ctx_tgt);
|
||||
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
|
||||
}
|
||||
|
||||
@@ -990,8 +1018,9 @@ private:
|
||||
|
||||
// unlike load_model(), this is only called once during initialization
|
||||
bool init() {
|
||||
GGML_ASSERT(ctx != nullptr);
|
||||
GGML_ASSERT(model != nullptr);
|
||||
GGML_ASSERT(ctx_tgt != nullptr);
|
||||
GGML_ASSERT(model_tgt != nullptr);
|
||||
|
||||
GGML_ASSERT(!sleeping);
|
||||
|
||||
// wiring up server queues
|
||||
@@ -1037,7 +1066,7 @@ private:
|
||||
common_chat_templates_ptr chat_templates;
|
||||
|
||||
try {
|
||||
chat_templates = common_chat_templates_init(model, params_base.chat_template);
|
||||
chat_templates = common_chat_templates_init(model_tgt, params_base.chat_template);
|
||||
|
||||
LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
|
||||
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
|
||||
@@ -1300,7 +1329,7 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
if (!task.tokens.validate(ctx)) {
|
||||
if (!task.tokens.validate(ctx_tgt)) {
|
||||
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
@@ -1310,7 +1339,7 @@ private:
|
||||
// initialize samplers
|
||||
if (task.need_sampling()) {
|
||||
try {
|
||||
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
|
||||
slot.smpl.reset(common_sampler_init(model_tgt, task.params.sampling));
|
||||
} catch (std::exception & e) {
|
||||
std::string err_msg = std::string("Failed to initialize samplers: ") + e.what();
|
||||
send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST);
|
||||
@@ -1331,9 +1360,9 @@ private:
|
||||
|
||||
// TODO: tmp until backend sampling is fully implemented
|
||||
if (backend_sampling) {
|
||||
llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
|
||||
llama_set_sampler(ctx_tgt, slot.id, common_sampler_get(slot.smpl.get()));
|
||||
} else {
|
||||
llama_set_sampler(ctx, slot.id, nullptr);
|
||||
llama_set_sampler(ctx_tgt, slot.id, nullptr);
|
||||
}
|
||||
|
||||
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
|
||||
@@ -1506,13 +1535,13 @@ private:
|
||||
for (size_t i = 0; i < n_probs; i++) {
|
||||
result.probs.push_back({
|
||||
cur_p->data[i].id,
|
||||
common_token_to_piece(ctx, cur_p->data[i].id, special),
|
||||
common_token_to_piece(ctx_tgt, cur_p->data[i].id, special),
|
||||
cur_p->data[i].p
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// TODO: optimize this with min-p optimization
|
||||
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
|
||||
std::vector<llama_token_data> cur = get_token_probabilities(ctx_tgt, idx);
|
||||
const size_t max_probs = cur.size();
|
||||
const size_t n_probs = std::min(max_probs, n_probs_request);
|
||||
|
||||
@@ -1530,7 +1559,7 @@ private:
|
||||
for (size_t i = 0; i < n_probs; i++) {
|
||||
result.probs.push_back({
|
||||
cur[i].id,
|
||||
common_token_to_piece(ctx, cur[i].id, special),
|
||||
common_token_to_piece(ctx_tgt, cur[i].id, special),
|
||||
cur[i].p
|
||||
});
|
||||
}
|
||||
@@ -1633,7 +1662,7 @@ private:
|
||||
res->tokens = std::move(slot.generated_tokens);
|
||||
}
|
||||
res->timings = slot.get_timings();
|
||||
res->prompt = slot.task->tokens.detokenize(ctx, true);
|
||||
res->prompt = slot.task->tokens.detokenize(ctx_tgt, true);
|
||||
res->response_fields = std::move(slot.task->params.response_fields);
|
||||
|
||||
res->truncated = slot.truncated;
|
||||
@@ -1656,7 +1685,7 @@ private:
|
||||
// populate res.probs_output
|
||||
if (slot.task->params.sampling.n_probs > 0) {
|
||||
if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) {
|
||||
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
|
||||
const llama_tokens stop_word_toks = common_tokenize(ctx_tgt, slot.stopping_word, false);
|
||||
|
||||
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
|
||||
res->probs_output = std::vector<completion_token_output>(
|
||||
@@ -1681,7 +1710,7 @@ private:
|
||||
res->n_tokens = slot.task->n_tokens();
|
||||
res->res_type = slot.task->params.res_type;
|
||||
|
||||
const int n_embd_out = llama_model_n_embd_out(model);
|
||||
const int n_embd_out = llama_model_n_embd_out(model_tgt);
|
||||
|
||||
std::vector<float> embd_res(n_embd_out, 0.0f);
|
||||
|
||||
@@ -1691,10 +1720,10 @@ private:
|
||||
}
|
||||
|
||||
const float * embd = nullptr;
|
||||
if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||
embd = llama_get_embeddings_ith(ctx, i);
|
||||
if (llama_pooling_type(slot.ctx_tgt) == LLAMA_POOLING_TYPE_NONE) {
|
||||
embd = llama_get_embeddings_ith(slot.ctx_tgt, i);
|
||||
} else {
|
||||
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
embd = llama_get_embeddings_seq(slot.ctx_tgt, batch.seq_id[i][0]);
|
||||
}
|
||||
|
||||
if (embd == nullptr) {
|
||||
@@ -1705,7 +1734,7 @@ private:
|
||||
}
|
||||
|
||||
// normalize only when there is pooling
|
||||
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
|
||||
if (llama_pooling_type(slot.ctx_tgt) != LLAMA_POOLING_TYPE_NONE) {
|
||||
common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize);
|
||||
res->embedding.push_back(embd_res);
|
||||
break;
|
||||
@@ -1730,9 +1759,9 @@ private:
|
||||
continue;
|
||||
}
|
||||
|
||||
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
const float * embd = llama_get_embeddings_seq(ctx_tgt, batch.seq_id[i][0]);
|
||||
if (embd == NULL) {
|
||||
embd = llama_get_embeddings_ith(ctx, i);
|
||||
embd = llama_get_embeddings_ith(ctx_tgt, i);
|
||||
}
|
||||
|
||||
if (embd == NULL) {
|
||||
@@ -1837,18 +1866,22 @@ private:
|
||||
const auto & cur = slot.prompt.checkpoints.front();
|
||||
|
||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
||||
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
|
||||
|
||||
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
||||
}
|
||||
|
||||
auto & cur = slot.prompt.checkpoints.emplace_back();
|
||||
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max);
|
||||
|
||||
cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
|
||||
|
||||
cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
SLT_WRN(slot,
|
||||
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
|
||||
cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
||||
cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
|
||||
}
|
||||
|
||||
void process_single_task(server_task && task) {
|
||||
@@ -2003,7 +2036,7 @@ private:
|
||||
std::string filepath = task.slot_action.filepath;
|
||||
|
||||
const llama_tokens & tokens = slot->prompt.tokens.get_tokens();
|
||||
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
|
||||
const size_t nwrite = llama_state_seq_save_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), token_count);
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
const double t_save_ms = (t_end - t_start) / 1000.0;
|
||||
@@ -2042,7 +2075,7 @@ private:
|
||||
llama_tokens tokens;
|
||||
tokens.resize(slot->n_ctx);
|
||||
size_t token_count = 0;
|
||||
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
|
||||
size_t nread = llama_state_seq_load_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
|
||||
if (nread == 0) {
|
||||
slot->prompt.tokens.clear(); // KV may already been invalidated?
|
||||
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
||||
@@ -2201,8 +2234,13 @@ private:
|
||||
|
||||
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
|
||||
|
||||
llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard);
|
||||
llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
|
||||
llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard);
|
||||
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
|
||||
|
||||
if (ctx_dft) {
|
||||
llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard);
|
||||
llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
|
||||
}
|
||||
|
||||
// add generated tokens to cache
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
|
||||
@@ -2248,12 +2286,14 @@ private:
|
||||
continue;
|
||||
}
|
||||
|
||||
slot.update_batch(batch);
|
||||
slot.update_batch(batch,
|
||||
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL,
|
||||
ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
|
||||
}
|
||||
|
||||
// process in chunks of params.n_batch
|
||||
int32_t n_batch = llama_n_batch(ctx);
|
||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||
int32_t n_batch = llama_n_batch(ctx_tgt);
|
||||
int32_t n_ubatch = llama_n_ubatch(ctx_tgt);
|
||||
|
||||
float alora_scale = -1.0f;
|
||||
size_t alora_disabled_id = 0;
|
||||
@@ -2297,12 +2337,12 @@ private:
|
||||
/*if (1) {
|
||||
// first 16 tokens (avoid flooding logs)
|
||||
for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str());
|
||||
}
|
||||
} else {
|
||||
// all
|
||||
for (int i = 0; i < (int) input_tokens.size(); i++) {
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str());
|
||||
}
|
||||
}*/
|
||||
|
||||
@@ -2321,7 +2361,7 @@ private:
|
||||
}
|
||||
|
||||
// TODO: support memory-less logits computation
|
||||
if (slot.task->need_logits() && !llama_get_memory(ctx)) {
|
||||
if (slot.task->need_logits() && !llama_get_memory(ctx_tgt)) {
|
||||
send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
|
||||
slot.release();
|
||||
continue;
|
||||
@@ -2373,7 +2413,7 @@ private:
|
||||
const auto n_cache_reuse = slot.task->params.n_cache_reuse;
|
||||
|
||||
const bool can_cache_reuse =
|
||||
llama_memory_can_shift(llama_get_memory(ctx)) &&
|
||||
llama_memory_can_shift(llama_get_memory(ctx_tgt)) &&
|
||||
!slot.prompt.tokens.has_mtmd;
|
||||
|
||||
if (!can_cache_reuse && n_cache_reuse > 0) {
|
||||
@@ -2407,13 +2447,18 @@ private:
|
||||
if (n_match >= (size_t) n_cache_reuse) {
|
||||
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
|
||||
//for (size_t i = head_p; i < head_p + n_match; i++) {
|
||||
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx_tgt, prompt_tokens[i]).c_str());
|
||||
//}
|
||||
|
||||
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
|
||||
|
||||
llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c);
|
||||
llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift);
|
||||
llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c);
|
||||
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift);
|
||||
|
||||
if (ctx_dft) {
|
||||
llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c);
|
||||
llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < n_match; i++) {
|
||||
slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
|
||||
@@ -2440,7 +2485,7 @@ private:
|
||||
const auto pos_min_thold = std::max(0, pos_next - n_swa);
|
||||
|
||||
if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
|
||||
if (pos_min == -1) {
|
||||
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
|
||||
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
|
||||
@@ -2469,14 +2514,14 @@ private:
|
||||
|
||||
{
|
||||
const auto token = slot.prompt.tokens[i];
|
||||
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
|
||||
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]";
|
||||
ss0 << piece;
|
||||
st0 << std::setw(8) << token;
|
||||
}
|
||||
|
||||
{
|
||||
const auto token = slot.task->tokens[i];
|
||||
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
|
||||
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]";
|
||||
ss1 << piece;
|
||||
st1 << std::setw(8) << token;
|
||||
}
|
||||
@@ -2508,18 +2553,13 @@ private:
|
||||
|
||||
if (!do_reset) {
|
||||
// restore the context checkpoint
|
||||
const size_t checkpoint_size = it->data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
if (n != checkpoint_size) {
|
||||
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
|
||||
do_reset = true;
|
||||
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
|
||||
} else {
|
||||
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
|
||||
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
|
||||
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) checkpoint_size / 1024 / 1024);
|
||||
}
|
||||
it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
|
||||
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
|
||||
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024);
|
||||
}
|
||||
|
||||
if (do_reset) {
|
||||
@@ -2536,7 +2576,7 @@ private:
|
||||
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
|
||||
const auto & cur = *it;
|
||||
if (cur.pos_max > pos_next) {
|
||||
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.data.size() / 1024 / 1024);
|
||||
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
|
||||
it = slot.prompt.checkpoints.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
@@ -2576,14 +2616,18 @@ private:
|
||||
|
||||
SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
|
||||
|
||||
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
|
||||
if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) {
|
||||
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
|
||||
|
||||
slot.prompt_clear(true);
|
||||
|
||||
// there is no common part left
|
||||
slot.n_prompt_tokens_cache = 0;
|
||||
}
|
||||
} else {
|
||||
if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) {
|
||||
GGML_ABORT("failed to truncate draft context\n");
|
||||
}
|
||||
}
|
||||
|
||||
// If using an alora, there may be uncached tokens that come
|
||||
// before the invocation sequence. When this happens, the
|
||||
@@ -2609,7 +2653,7 @@ private:
|
||||
// - the model does not support partial sequence removal
|
||||
// - the model uses SWA (and we are not using `swa_full`)
|
||||
do_checkpoint = do_checkpoint && (
|
||||
(slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
|
||||
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
|
||||
(n_swa > 0));
|
||||
|
||||
bool has_mtmd = false;
|
||||
@@ -2618,7 +2662,7 @@ private:
|
||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
|
||||
// process the image
|
||||
size_t n_tokens_out = 0;
|
||||
int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
|
||||
int32_t res = input_tokens.process_chunk(ctx_tgt, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
|
||||
if (res != 0) {
|
||||
SLT_ERR(slot, "failed to process image, res = %d\n", res);
|
||||
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
|
||||
@@ -2626,6 +2670,15 @@ private:
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ctx_dft) {
|
||||
// TODO: in the future, figure out how to infuse target embeddings to the images
|
||||
// for now, we skip this for simplicity
|
||||
res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
|
||||
if (res != 0) {
|
||||
GGML_ABORT("failed to process multi-modal data on draft context\n");
|
||||
}
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens_processed += n_tokens_out;
|
||||
|
||||
// add the image chunk to cache
|
||||
@@ -2727,8 +2780,8 @@ private:
|
||||
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
|
||||
}
|
||||
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
|
||||
|
||||
// no need for empty or small checkpoints
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64);
|
||||
@@ -2761,7 +2814,7 @@ private:
|
||||
|
||||
if (slot_batched) {
|
||||
// apply lora, only need to do it once per batch
|
||||
common_set_adapter_lora(ctx, slot_batched->lora);
|
||||
common_set_adapter_lora(ctx_tgt, slot_batched->lora);
|
||||
|
||||
// if the lora is temporarily disabled for an alora, re-enable it
|
||||
// for next time
|
||||
@@ -2770,7 +2823,7 @@ private:
|
||||
slot_batched->lora[alora_disabled_id].scale = alora_scale;
|
||||
}
|
||||
|
||||
llama_set_embeddings(ctx, slot_batched->task->need_embd());
|
||||
llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd());
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
@@ -2799,7 +2852,7 @@ private:
|
||||
batch.logits + i,
|
||||
};
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
const int ret = llama_decode(ctx_tgt, batch_view);
|
||||
|
||||
metrics.on_decoded(slots);
|
||||
|
||||
@@ -2852,11 +2905,54 @@ private:
|
||||
continue; // continue loop of n_batch
|
||||
}
|
||||
|
||||
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
|
||||
// for now, always re-evaluate for simplicity
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384
|
||||
//
|
||||
// | spec type | need re-eval |
|
||||
// | --- | --- |
|
||||
// | draft model | no | because the draft model does not use embeddings from the target
|
||||
// | MTP (std) | yes |
|
||||
// | MTP Gemma4 | no | because the KV cache is shared
|
||||
// | Eagle3 | yes |
|
||||
// | DFlash | yes? |
|
||||
//
|
||||
if (ctx_dft) {
|
||||
// TODO: update as needed for MTP, Eagle3, etc.
|
||||
const bool need_tgt_embd = false;
|
||||
|
||||
if (need_tgt_embd) {
|
||||
llama_synchronize(ctx_tgt);
|
||||
}
|
||||
|
||||
// the logic here varies depending on the speculative decoding method
|
||||
// - some draft contexts require embeddings from the target context, others don't
|
||||
// - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings
|
||||
// TODO: extract this in a function ?
|
||||
{
|
||||
// TODO: hook the embeddings from the last target batch here
|
||||
if (llama_model_has_encoder(model_dft.get())) {
|
||||
//llama_encode(ctx_dft, ...);
|
||||
|
||||
GGML_ABORT("not implemented yet\n");
|
||||
}
|
||||
|
||||
const int ret = llama_decode(ctx_dft.get(), batch_view);
|
||||
|
||||
if (ret != 0) {
|
||||
SRV_ERR("failed to decode draft batch, ret = %d\n", ret);
|
||||
|
||||
// TODO: handle error
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// move the head of the batch forward with the number of tokens we just processed
|
||||
i_next = i + n_tokens;
|
||||
|
||||
// on successful decode, restore the original batch size
|
||||
n_batch = llama_n_batch(ctx);
|
||||
n_batch = llama_n_batch(ctx_tgt);
|
||||
|
||||
// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
|
||||
for (auto & slot : slots) {
|
||||
@@ -2927,7 +3023,7 @@ private:
|
||||
|
||||
const int tok_idx = slot.i_batch - i;
|
||||
|
||||
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx);
|
||||
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
|
||||
|
||||
slot.i_batch = -1;
|
||||
|
||||
@@ -2948,7 +3044,7 @@ private:
|
||||
|
||||
completion_token_output result;
|
||||
result.tok = id;
|
||||
result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok));
|
||||
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
|
||||
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
||||
|
||||
if (slot.task->params.sampling.n_probs > 0) {
|
||||
@@ -2979,23 +3075,23 @@ private:
|
||||
|
||||
// verify and try to accept the draft
|
||||
{
|
||||
const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
|
||||
// only save the sampler sampler state if we use checkpoints
|
||||
common_sampler_ptr smpl_save;
|
||||
if (use_ckpt) {
|
||||
if (use_ckpt_tgt) {
|
||||
smpl_save.reset(common_sampler_clone(slot.smpl.get()));
|
||||
}
|
||||
|
||||
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
|
||||
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
|
||||
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft);
|
||||
slot.spec_i_batch.clear();
|
||||
|
||||
GGML_ASSERT(accepted.size() >= 1);
|
||||
|
||||
// check for partial draft acceptance
|
||||
if (accepted.size() < slot.spec_draft.size() + 1) {
|
||||
if (use_ckpt) {
|
||||
if (use_ckpt_tgt) {
|
||||
if (trace > 0) {
|
||||
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
|
||||
}
|
||||
@@ -3005,16 +3101,19 @@ private:
|
||||
|
||||
const auto & ckpt = slot.spec_ckpt;
|
||||
|
||||
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
|
||||
ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
|
||||
const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (n != ckpt.size()) {
|
||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
|
||||
{
|
||||
ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1);
|
||||
if (slot.ctx_dft) {
|
||||
ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
slot.prompt.tokens.keep_first(ckpt.n_tokens);
|
||||
slot.smpl = std::move(smpl_save);
|
||||
@@ -3049,13 +3148,16 @@ private:
|
||||
slot.sampled = ids.back(); // last accepted token
|
||||
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.tokens.pos_next(), -1);
|
||||
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1);
|
||||
if (slot.ctx_dft) {
|
||||
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < ids.size(); ++i) {
|
||||
completion_token_output result;
|
||||
|
||||
result.tok = ids[i];
|
||||
result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok));
|
||||
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
|
||||
result.prob = 1.0f; // set later
|
||||
|
||||
// TODO: set result.probs
|
||||
@@ -3107,7 +3209,7 @@ void server_context::terminate() {
|
||||
}
|
||||
|
||||
llama_context * server_context::get_llama_context() const {
|
||||
return impl->ctx;
|
||||
return impl->ctx_tgt;
|
||||
}
|
||||
|
||||
server_response_reader server_context::get_response_reader() {
|
||||
@@ -3117,8 +3219,8 @@ server_response_reader server_context::get_response_reader() {
|
||||
server_context_meta server_context::get_meta() const {
|
||||
auto bos_id = llama_vocab_bos(impl->vocab);
|
||||
auto eos_id = llama_vocab_eos(impl->vocab);
|
||||
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : "";
|
||||
auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, eos_id, true) : "";
|
||||
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, bos_id, true) : "";
|
||||
auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, eos_id, true) : "";
|
||||
|
||||
return server_context_meta {
|
||||
/* build_info */ std::string(llama_build_info()),
|
||||
@@ -3131,7 +3233,7 @@ server_context_meta server_context::get_meta() const {
|
||||
/* has_inp_audio */ impl->chat_params.allow_audio,
|
||||
/* json_webui_settings */ impl->json_webui_settings,
|
||||
/* slot_n_ctx */ impl->get_slot_n_ctx(),
|
||||
/* pooling_type */ llama_pooling_type(impl->ctx),
|
||||
/* pooling_type */ llama_pooling_type(impl->ctx_tgt),
|
||||
|
||||
/* chat_params */ impl->chat_params,
|
||||
/* chat_template_caps */ common_chat_templates_get_caps(impl->chat_params.tmpls.get()),
|
||||
@@ -3149,10 +3251,10 @@ server_context_meta server_context::get_meta() const {
|
||||
|
||||
/* model_vocab_type */ llama_vocab_type(impl->vocab),
|
||||
/* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
|
||||
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model),
|
||||
/* model_n_embd_inp */ llama_model_n_embd(impl->model),
|
||||
/* model_n_params */ llama_model_n_params(impl->model),
|
||||
/* model_size */ llama_model_size(impl->model),
|
||||
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model_tgt),
|
||||
/* model_n_embd_inp */ llama_model_n_embd(impl->model_tgt),
|
||||
/* model_n_params */ llama_model_n_params(impl->model_tgt),
|
||||
/* model_size */ llama_model_size(impl->model_tgt),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -4054,7 +4156,7 @@ void server_routes::init_routes() {
|
||||
std::vector<server_task> tasks;
|
||||
tasks.reserve(documents.size());
|
||||
for (size_t i = 0; i < documents.size(); i++) {
|
||||
auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
|
||||
auto tmp = format_prompt_rerank(ctx_server.model_tgt, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
|
||||
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
|
||||
task.id = rd.get_new_id();
|
||||
task.tokens = std::move(tmp);
|
||||
|
||||
@@ -1981,7 +1981,7 @@ size_t server_prompt_cache::n_tokens() const {
|
||||
return res;
|
||||
}
|
||||
|
||||
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
|
||||
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) {
|
||||
// first check if the current state is contained fully in the cache
|
||||
for (auto it = states.begin(); it != states.end(); ++it) {
|
||||
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
|
||||
@@ -2005,11 +2005,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint8_t> state_data;
|
||||
std::vector<uint8_t> state_data_tgt;
|
||||
std::vector<uint8_t> state_data_dft;
|
||||
|
||||
// check if we can allocate enough memory for the new state
|
||||
try {
|
||||
state_data.resize(state_size);
|
||||
state_data_tgt.resize(state_size_tgt);
|
||||
state_data_dft.resize(state_size_dft);
|
||||
} catch (const std::bad_alloc & e) {
|
||||
SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
|
||||
|
||||
@@ -2022,17 +2024,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto & cur = states.emplace_back();
|
||||
cur = {
|
||||
states.push_back({
|
||||
/*.tokens =*/ prompt.tokens.clone(),
|
||||
/*.data =*/ std::move(state_data),
|
||||
/*.data =*/ {
|
||||
/*.main =*/ std::move(state_data_tgt),
|
||||
/*.drft =*/ std::move(state_data_dft),
|
||||
},
|
||||
/*.checkpoints =*/ prompt.checkpoints,
|
||||
};
|
||||
});
|
||||
|
||||
return &cur;
|
||||
return &states.back();
|
||||
}
|
||||
|
||||
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
|
||||
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) {
|
||||
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
|
||||
|
||||
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
|
||||
@@ -2065,16 +2069,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
|
||||
if (it_best != states.end()) {
|
||||
SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
||||
|
||||
const size_t size = it_best->data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
SRV_WRN("failed to restore state with size %zu\n", size);
|
||||
{
|
||||
auto & data = it_best->data.main;
|
||||
|
||||
return false;
|
||||
const size_t size = data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
SRV_WRN("failed to restore state with size %zu\n", size);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
data.clear();
|
||||
data.shrink_to_fit();
|
||||
}
|
||||
|
||||
it_best->data.clear();
|
||||
it_best->data.shrink_to_fit();
|
||||
{
|
||||
auto & data = it_best->data.drft;
|
||||
|
||||
if (!data.empty()) {
|
||||
GGML_ASSERT(ctx_dft);
|
||||
|
||||
const size_t size = data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
SRV_WRN("failed to restore state with size %zu\n", size);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
data.clear();
|
||||
data.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
prompt = std::move(*it_best);
|
||||
|
||||
|
||||
@@ -565,42 +565,29 @@ struct server_task_result_apply_lora : server_task_result {
|
||||
virtual json to_json() override;
|
||||
};
|
||||
|
||||
struct server_prompt_checkpoint {
|
||||
llama_pos pos_min;
|
||||
llama_pos pos_max;
|
||||
|
||||
int64_t n_tokens;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
struct server_prompt_data {
|
||||
std::vector<uint8_t> main;
|
||||
std::vector<uint8_t> drft;
|
||||
|
||||
size_t size() const {
|
||||
return data.size();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return data.empty();
|
||||
}
|
||||
|
||||
void clear() {
|
||||
pos_min = 0;
|
||||
pos_max = 0;
|
||||
n_tokens = 0;
|
||||
data.clear();
|
||||
return main.size() + drft.size();
|
||||
}
|
||||
};
|
||||
|
||||
struct server_prompt {
|
||||
server_tokens tokens;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
server_prompt_data data;
|
||||
|
||||
std::list<server_prompt_checkpoint> checkpoints;
|
||||
std::list<common_prompt_checkpoint> checkpoints;
|
||||
|
||||
size_t size() const {
|
||||
size_t res = data.size();
|
||||
size_t res = 0;
|
||||
|
||||
for (const auto & checkpoint : checkpoints) {
|
||||
res += checkpoint.size();
|
||||
res += data.size();
|
||||
|
||||
for (const auto & ckpt : checkpoints) {
|
||||
res += ckpt.size();
|
||||
}
|
||||
|
||||
return res;
|
||||
@@ -614,7 +601,7 @@ struct server_prompt {
|
||||
return server_prompt {
|
||||
tokens.clone(),
|
||||
data,
|
||||
checkpoints
|
||||
checkpoints,
|
||||
};
|
||||
}
|
||||
};
|
||||
@@ -637,9 +624,9 @@ struct server_prompt_cache {
|
||||
|
||||
size_t n_tokens() const;
|
||||
|
||||
server_prompt * alloc(const server_prompt & prompt, size_t state_size);
|
||||
server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft);
|
||||
|
||||
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
|
||||
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot);
|
||||
|
||||
void update();
|
||||
};
|
||||
|
||||
@@ -5,7 +5,7 @@ from utils import *
|
||||
|
||||
server = ServerPreset.stories15m_moe()
|
||||
|
||||
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
|
||||
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/tiny-llamas/resolve/main/stories15M-q4_0.gguf"
|
||||
|
||||
def create_server():
|
||||
global server
|
||||
|
||||
Reference in New Issue
Block a user