naming : improve consistency

Georgi Gerganov
2026-05-08 12:24:57 +03:00
parent 778f9e247e
commit efa2f8e5a7
5 changed files with 201 additions and 201 deletions

View File

@@ -1962,11 +1962,11 @@ bool common_prompt_batch_decode(
}
size_t common_prompt_checkpoint::size() const {
return data_main.size() + data_drft.size();
return data_tgt.size() + data_dft.size();
}
bool common_prompt_checkpoint::empty() const {
return data_main.empty();
return data_tgt.empty();
}
void common_prompt_checkpoint::clear() {
@@ -1975,8 +1975,8 @@ void common_prompt_checkpoint::clear() {
pos_min = 0;
pos_max = 0;
data_main.clear();
data_drft.clear();
data_tgt.clear();
data_dft.clear();
}
void common_prompt_checkpoint::update_pos(
@@ -1988,7 +1988,7 @@ void common_prompt_checkpoint::update_pos(
this->pos_max = pos_max;
}
void common_prompt_checkpoint::update_main(
void common_prompt_checkpoint::update_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) {
@@ -1998,15 +1998,15 @@ void common_prompt_checkpoint::update_main(
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
data_main.resize(ckpt_size);
data_tgt.resize(ckpt_size);
const size_t n = llama_state_seq_get_data_ext(ctx, data_main.data(), ckpt_size, seq_id, flags);
const size_t n = llama_state_seq_get_data_ext(ctx, data_tgt.data(), ckpt_size, seq_id, flags);
if (n != ckpt_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
}
}
void common_prompt_checkpoint::update_drft(
void common_prompt_checkpoint::update_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) {
@@ -2016,15 +2016,15 @@ void common_prompt_checkpoint::update_drft(
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags);
data_drft.resize(ckpt_size);
data_dft.resize(ckpt_size);
const size_t n = llama_state_seq_get_data_ext(ctx, data_drft.data(), ckpt_size, seq_id, flags);
const size_t n = llama_state_seq_get_data_ext(ctx, data_dft.data(), ckpt_size, seq_id, flags);
if (n != ckpt_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n);
}
}
void common_prompt_checkpoint::load_main(
void common_prompt_checkpoint::load_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const {
@@ -2032,17 +2032,17 @@ void common_prompt_checkpoint::load_main(
return;
}
if (data_main.empty()) {
if (data_tgt.empty()) {
return;
}
const size_t n = llama_state_seq_set_data_ext(ctx, data_main.data(), data_main.size(), seq_id, flags);
if (n != data_main.size()) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_main.size(), n);
const size_t n = llama_state_seq_set_data_ext(ctx, data_tgt.data(), data_tgt.size(), seq_id, flags);
if (n != data_tgt.size()) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_tgt.size(), n);
}
}
void common_prompt_checkpoint::load_drft(
void common_prompt_checkpoint::load_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const {
@@ -2050,12 +2050,12 @@ void common_prompt_checkpoint::load_drft(
return;
}
if (data_drft.empty()) {
if (data_dft.empty()) {
return;
}
const size_t n = llama_state_seq_set_data_ext(ctx, data_drft.data(), data_drft.size(), seq_id, flags);
if (n != data_drft.size()) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_drft.size(), n);
const size_t n = llama_state_seq_set_data_ext(ctx, data_dft.data(), data_dft.size(), seq_id, flags);
if (n != data_dft.size()) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n);
}
}

View File

@@ -1034,8 +1034,8 @@ struct common_prompt_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
std::vector<uint8_t> data_main;
std::vector<uint8_t> data_drft;
std::vector<uint8_t> data_tgt;
std::vector<uint8_t> data_dft;
size_t size() const;
@@ -1047,22 +1047,22 @@ struct common_prompt_checkpoint {
llama_pos pos_min,
llama_pos pos_max);
void update_main(
void update_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags);
void update_drft(
void update_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags);
void load_main(
void load_tgt(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const;
void load_drft(
void load_dft(
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const;

View File

@@ -175,7 +175,7 @@ int main(int argc, char ** argv) {
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id));
if (use_ckpt_dft) {
ckpt.update_drft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
// generate a new draft
@@ -188,12 +188,12 @@ int main(int argc, char ** argv) {
// this allows us to restore the state if partial draft acceptance occurs
if (!draft.empty()) {
if (use_ckpt_tgt) {
ckpt.update_main(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
}
{
ckpt.load_drft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
}
@@ -253,13 +253,13 @@ int main(int argc, char ** argv) {
draft = std::move(ids);
{
ckpt.load_main(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, ckpt.pos_max + 1, -1);
}
{
ckpt.load_drft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
}

View File

@@ -54,8 +54,8 @@ enum server_state {
struct server_slot {
int id;
llama_context * ctx_main = nullptr;
llama_context * ctx_drft = nullptr;
llama_context * ctx_tgt = nullptr;
llama_context * ctx_dft = nullptr;
// multimodal
mtmd_context * mctx = nullptr;
@@ -108,27 +108,27 @@ struct server_slot {
void prompt_save(server_prompt_cache & prompt_cache) const {
GGML_ASSERT(prompt.data.size() == 0);
const size_t cur_size_main = llama_state_seq_get_size_ext(ctx_main, id, LLAMA_STATE_SEQ_FLAGS_NONE);
const size_t cur_size_drft = ctx_drft ? llama_state_seq_get_size_ext(ctx_drft, id, LLAMA_STATE_SEQ_FLAGS_NONE) : 0;
const size_t cur_size_tgt = llama_state_seq_get_size_ext(ctx_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE);
const size_t cur_size_dft = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE) : 0;
const size_t cur_size = cur_size_main + cur_size_drft;
const size_t cur_size = cur_size_tgt + cur_size_dft;
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n",
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_drft / (1024.0 * 1024.0));
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0));
auto * cur = prompt_cache.alloc(prompt, cur_size_main, cur_size_drft);
auto * cur = prompt_cache.alloc(prompt, cur_size_tgt, cur_size_dft);
if (cur == nullptr) {
return;
}
llama_state_seq_get_data_ext(ctx_main, cur->data.main.data(), cur_size_main, id, LLAMA_STATE_SEQ_FLAGS_NONE);
if (ctx_drft) {
llama_state_seq_get_data_ext(ctx_drft, cur->data.drft.data(), cur_size_drft, id, LLAMA_STATE_SEQ_FLAGS_NONE);
llama_state_seq_get_data_ext(ctx_tgt, cur->data.main.data(), cur_size_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE);
if (ctx_dft) {
llama_state_seq_get_data_ext(ctx_dft, cur->data.drft.data(), cur_size_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE);
}
}
bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
bool res = prompt_cache.load(prompt, tokens, ctx_main, ctx_drft, id);
bool res = prompt_cache.load(prompt, tokens, ctx_tgt, ctx_dft, id);
if (!res) {
SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
}
@@ -143,9 +143,9 @@ struct server_slot {
SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx_main), id, -1, -1);
if (ctx_drft) {
llama_memory_seq_rm(llama_get_memory(ctx_drft), id, -1, -1);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1);
if (ctx_dft) {
llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1);
}
prompt.tokens.clear();
@@ -205,7 +205,7 @@ struct server_slot {
task_prev = std::move(task);
task.reset();
llama_set_sampler(ctx_main, id, nullptr);
llama_set_sampler(ctx_tgt, id, nullptr);
// clear alora start
alora_invocation_start = -1;
@@ -242,7 +242,7 @@ struct server_slot {
return
!task->need_embd() ||
(llama_get_memory(ctx_main) && llama_pooling_type(ctx_main) == LLAMA_POOLING_TYPE_LAST);
(llama_get_memory(ctx_tgt) && llama_pooling_type(ctx_tgt) == LLAMA_POOLING_TYPE_LAST);
}
bool can_batch_with(server_slot & other_slot) const {
@@ -316,7 +316,7 @@ struct server_slot {
return n_draft_max;
}
void update_batch(llama_batch & batch, bool use_ckpt_main, bool use_ckpt_drft) {
void update_batch(llama_batch & batch, bool use_ckpt_tgt, bool use_ckpt_dft) {
const int n_draft_max = get_n_draft_max();
if (n_draft_max > 0) {
GGML_ASSERT(can_speculate());
@@ -330,7 +330,7 @@ struct server_slot {
if (!spec_draft.empty()) {
// we have a previous (partial) draft to reuse
if (use_ckpt_main) {
if (use_ckpt_tgt) {
GGML_ASSERT(!spec_ckpt.empty());
}
} else {
@@ -338,11 +338,11 @@ struct server_slot {
spec_ckpt.update_pos(
prompt.n_tokens(),
llama_memory_seq_pos_min(llama_get_memory(ctx_main), id),
llama_memory_seq_pos_max(llama_get_memory(ctx_main), id));
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), id),
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), id));
if (use_ckpt_drft) {
spec_ckpt.update_drft(ctx_drft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (use_ckpt_dft) {
spec_ckpt.update_dft(ctx_dft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
// generate a new draft
@@ -355,24 +355,24 @@ struct server_slot {
}
if (!spec_draft.empty()) {
if (use_ckpt_main) {
if (use_ckpt_tgt) {
//const int64_t t_start = ggml_time_us();
spec_ckpt.update_main(ctx_main, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
spec_ckpt.update_tgt(ctx_tgt, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
//const int64_t t_total = ggml_time_us() - t_start;
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n",
spec_ckpt.pos_min, spec_ckpt.pos_max, prompt.n_tokens(), (float) spec_ckpt.size() / 1024 / 1024, (float) spec_ckpt.data_drft.size() / 1024 / 1024);
spec_ckpt.pos_min, spec_ckpt.pos_max, prompt.n_tokens(), (float) spec_ckpt.size() / 1024 / 1024, (float) spec_ckpt.data_dft.size() / 1024 / 1024);
}
}
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
if (ctx_drft) {
spec_ckpt.load_drft(ctx_drft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (ctx_dft) {
spec_ckpt.load_dft(ctx_dft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(ctx_drft), this->id, spec_ckpt.pos_max + 1, -1);
llama_memory_seq_rm(llama_get_memory(ctx_dft), this->id, spec_ckpt.pos_max + 1, -1);
}
}
@@ -538,7 +538,7 @@ struct server_slot {
};
if (!only_metrics) {
res["prompt"] = ptask->tokens.detokenize(ctx_main, true);
res["prompt"] = ptask->tokens.detokenize(ctx_tgt, true);
res["generated"] = generated_text.empty() ? debug_generated_text : generated_text;
}
}
@@ -549,12 +549,12 @@ struct server_slot {
void copy_state_to(server_slot & other) const {
GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);
llama_memory_seq_rm(llama_get_memory(ctx_main), other.id, -1, -1);
llama_memory_seq_cp(llama_get_memory(ctx_main), id, other.id, -1, -1);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1);
llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1);
if (ctx_drft) {
llama_memory_seq_rm(llama_get_memory(ctx_drft), other.id, -1, -1);
llama_memory_seq_cp(llama_get_memory(ctx_drft), id, other.id, -1, -1);
if (ctx_dft) {
llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1);
llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1);
}
other.n_decoded = n_decoded;
@@ -646,7 +646,7 @@ public:
// only use these pointers outside of this class:
// - when not in sleeping state
// - and, with thread-safe APIs (e.g., tokenizer calls)
llama_model * model_main = nullptr;
llama_model * model_tgt = nullptr;
mtmd_context * mctx = nullptr;
const llama_vocab * vocab = nullptr;
@@ -674,15 +674,15 @@ private:
// note: keep these alive - they determine the lifetime of the model, context, etc.
common_init_result_ptr llama_init;
llama_context * ctx_main = nullptr;
llama_context * ctx_tgt = nullptr;
llama_batch batch {};
llama_model_ptr model_drft;
llama_context_ptr ctx_drft;
llama_model_ptr model_dft;
llama_context_ptr ctx_dft;
common_context_seq_rm_type ctx_main_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
common_context_seq_rm_type ctx_drft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
bool add_bos_token = true;
@@ -717,8 +717,8 @@ private:
void destroy() {
llama_init.reset();
ctx_main = nullptr;
model_main = nullptr;
ctx_tgt = nullptr;
model_tgt = nullptr;
mtmd_free(mctx);
mctx = nullptr;
@@ -768,17 +768,17 @@ private:
llama_init = common_init_from_params(params_base);
model_main = llama_init->model();
ctx_main = llama_init->context();
model_tgt = llama_init->model();
ctx_tgt = llama_init->context();
if (model_main == nullptr) {
if (model_tgt == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
return false;
}
vocab = llama_model_get_vocab(model_main);
vocab = llama_model_get_vocab(model_tgt);
n_ctx = llama_n_ctx(ctx_main);
n_ctx = llama_n_ctx(ctx_tgt);
add_bos_token = llama_vocab_get_add_bos(vocab);
@@ -805,19 +805,19 @@ private:
auto mparams_dft = common_model_params_to_llama(params_dft);
model_drft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
if (model_drft == nullptr) {
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
return false;
}
auto cparams = common_context_params_to_llama(params_dft);
ctx_drft.reset(llama_init_from_model(model_drft.get(), cparams));
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
ctx_drft_seq_rm_type = common_context_can_seq_rm(ctx_drft.get());
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx_main;
params_base.speculative.draft.ctx_dft = ctx_drft.get();
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
}
std::string & mmproj_path = params_base.mmproj.path;
@@ -837,7 +837,7 @@ private:
mparams.image_max_tokens = params_base.image_max_tokens;
mparams.media_marker = get_media_marker();
mctx = mtmd_init_from_file(mmproj_path.c_str(), model_main, mparams);
mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams);
if (mctx == nullptr) {
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
return false;
@@ -855,7 +855,7 @@ private:
}
}
if (!llama_memory_can_shift(llama_get_memory(ctx_main))) {
if (!llama_memory_can_shift(llama_get_memory(ctx_tgt))) {
if (params_base.ctx_shift) {
params_base.ctx_shift = false;
SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled");
@@ -867,14 +867,14 @@ private:
}
}
if (llama_model_n_swa(model_main) == 0) {
if (llama_model_n_swa(model_tgt) == 0) {
if (params_base.swa_full) {
params_base.swa_full = false;
SRV_WRN("%s\n", "swa_full is not supported by this model, it will be disabled");
}
}
n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model_main);
n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model_tgt);
// Necessary similarity of prompt for slot selection
slot_prompt_similarity = params_base.slot_prompt_similarity;
@@ -882,9 +882,9 @@ private:
// setup slots
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
const int n_ctx_train = llama_model_n_ctx_train(model_main);
const int n_ctx_train = llama_model_n_ctx_train(model_tgt);
int n_ctx_slot = llama_n_ctx_seq(ctx_main);
int n_ctx_slot = llama_n_ctx_seq(ctx_tgt);
if (n_ctx_slot > n_ctx_train) {
SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
n_ctx_slot = n_ctx_train;
@@ -892,12 +892,12 @@ private:
slots.clear();
ctx_main_seq_rm_type = common_context_can_seq_rm(ctx_main);
if (ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
ctx_tgt_seq_rm_type = common_context_can_seq_rm(ctx_tgt);
if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
if (ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
SRV_WRN("%s", "speculative decoding will use checkpoints\n");
}
@@ -906,27 +906,27 @@ private:
slots.emplace_back();
}
bool no_drft = false;
bool no_dft = false;
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot & slot = slots[i];
slot.id = i;
slot.ctx_main = ctx_main;
slot.ctx_drft = ctx_drft.get();
slot.ctx_tgt = ctx_tgt;
slot.ctx_dft = ctx_dft.get();
slot.n_ctx = n_ctx_slot;
slot.mctx = mctx;
slot.prompt.tokens.has_mtmd = mctx != nullptr;
// try speculative decoding
if (ctx_main_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
try {
slot.spec.reset(common_speculative_init(params_base.speculative, slot.id));
} catch (const std::exception & e) {
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
no_drft = true;
no_dft = true;
}
if (slot.spec) {
@@ -943,13 +943,13 @@ private:
slot.reset();
}
if (no_drft && ctx_drft) {
if (no_dft && ctx_dft) {
SRV_WRN("%s", "destroying the draft model as it is not going to be used\n");
ctx_drft.reset();
ctx_dft.reset();
for (auto & slot : slots) {
slot.ctx_drft = nullptr;
slot.ctx_dft = nullptr;
}
}
@@ -974,7 +974,7 @@ private:
// the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
{
const int32_t n_batch = llama_n_batch(ctx_main);
const int32_t n_batch = llama_n_batch(ctx_tgt);
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
@@ -1018,8 +1018,8 @@ private:
// unlike load_model(), this is only called once during initialization
bool init() {
GGML_ASSERT(ctx_main != nullptr);
GGML_ASSERT(model_main != nullptr);
GGML_ASSERT(ctx_tgt != nullptr);
GGML_ASSERT(model_tgt != nullptr);
GGML_ASSERT(!sleeping);
@@ -1066,7 +1066,7 @@ private:
common_chat_templates_ptr chat_templates;
try {
chat_templates = common_chat_templates_init(model_main, params_base.chat_template);
chat_templates = common_chat_templates_init(model_tgt, params_base.chat_template);
LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
@@ -1329,7 +1329,7 @@ private:
}
}
if (!task.tokens.validate(ctx_main)) {
if (!task.tokens.validate(ctx_tgt)) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
}
@@ -1339,7 +1339,7 @@ private:
// initialize samplers
if (task.need_sampling()) {
try {
slot.smpl.reset(common_sampler_init(model_main, task.params.sampling));
slot.smpl.reset(common_sampler_init(model_tgt, task.params.sampling));
} catch (std::exception & e) {
std::string err_msg = std::string("Failed to initialize samplers: ") + e.what();
send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST);
@@ -1360,9 +1360,9 @@ private:
// TODO: tmp until backend sampling is fully implemented
if (backend_sampling) {
llama_set_sampler(ctx_main, slot.id, common_sampler_get(slot.smpl.get()));
llama_set_sampler(ctx_tgt, slot.id, common_sampler_get(slot.smpl.get()));
} else {
llama_set_sampler(ctx_main, slot.id, nullptr);
llama_set_sampler(ctx_tgt, slot.id, nullptr);
}
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
@@ -1535,13 +1535,13 @@ private:
for (size_t i = 0; i < n_probs; i++) {
result.probs.push_back({
cur_p->data[i].id,
common_token_to_piece(ctx_main, cur_p->data[i].id, special),
common_token_to_piece(ctx_tgt, cur_p->data[i].id, special),
cur_p->data[i].p
});
}
} else {
// TODO: optimize this with min-p optimization
std::vector<llama_token_data> cur = get_token_probabilities(ctx_main, idx);
std::vector<llama_token_data> cur = get_token_probabilities(ctx_tgt, idx);
const size_t max_probs = cur.size();
const size_t n_probs = std::min(max_probs, n_probs_request);
@@ -1559,7 +1559,7 @@ private:
for (size_t i = 0; i < n_probs; i++) {
result.probs.push_back({
cur[i].id,
common_token_to_piece(ctx_main, cur[i].id, special),
common_token_to_piece(ctx_tgt, cur[i].id, special),
cur[i].p
});
}
@@ -1662,7 +1662,7 @@ private:
res->tokens = std::move(slot.generated_tokens);
}
res->timings = slot.get_timings();
res->prompt = slot.task->tokens.detokenize(ctx_main, true);
res->prompt = slot.task->tokens.detokenize(ctx_tgt, true);
res->response_fields = std::move(slot.task->params.response_fields);
res->truncated = slot.truncated;
@@ -1685,7 +1685,7 @@ private:
// populate res.probs_output
if (slot.task->params.sampling.n_probs > 0) {
if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) {
const llama_tokens stop_word_toks = common_tokenize(ctx_main, slot.stopping_word, false);
const llama_tokens stop_word_toks = common_tokenize(ctx_tgt, slot.stopping_word, false);
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
res->probs_output = std::vector<completion_token_output>(
@@ -1710,7 +1710,7 @@ private:
res->n_tokens = slot.task->n_tokens();
res->res_type = slot.task->params.res_type;
const int n_embd_out = llama_model_n_embd_out(model_main);
const int n_embd_out = llama_model_n_embd_out(model_tgt);
std::vector<float> embd_res(n_embd_out, 0.0f);
@@ -1720,10 +1720,10 @@ private:
}
const float * embd = nullptr;
if (llama_pooling_type(slot.ctx_main) == LLAMA_POOLING_TYPE_NONE) {
embd = llama_get_embeddings_ith(slot.ctx_main, i);
if (llama_pooling_type(slot.ctx_tgt) == LLAMA_POOLING_TYPE_NONE) {
embd = llama_get_embeddings_ith(slot.ctx_tgt, i);
} else {
embd = llama_get_embeddings_seq(slot.ctx_main, batch.seq_id[i][0]);
embd = llama_get_embeddings_seq(slot.ctx_tgt, batch.seq_id[i][0]);
}
if (embd == nullptr) {
@@ -1734,7 +1734,7 @@ private:
}
// normalize only when there is pooling
if (llama_pooling_type(slot.ctx_main) != LLAMA_POOLING_TYPE_NONE) {
if (llama_pooling_type(slot.ctx_tgt) != LLAMA_POOLING_TYPE_NONE) {
common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize);
res->embedding.push_back(embd_res);
break;
@@ -1759,9 +1759,9 @@ private:
continue;
}
const float * embd = llama_get_embeddings_seq(ctx_main, batch.seq_id[i][0]);
const float * embd = llama_get_embeddings_seq(ctx_tgt, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx_main, i);
embd = llama_get_embeddings_ith(ctx_tgt, i);
}
if (embd == NULL) {
@@ -1875,8 +1875,8 @@ private:
cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
cur.update_main(ctx_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
cur.update_drft(ctx_drft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_WRN(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
@@ -2036,7 +2036,7 @@ private:
std::string filepath = task.slot_action.filepath;
const llama_tokens & tokens = slot->prompt.tokens.get_tokens();
const size_t nwrite = llama_state_seq_save_file(ctx_main, filepath.c_str(), slot->id, tokens.data(), token_count);
const size_t nwrite = llama_state_seq_save_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), token_count);
const int64_t t_end = ggml_time_us();
const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -2075,7 +2075,7 @@ private:
llama_tokens tokens;
tokens.resize(slot->n_ctx);
size_t token_count = 0;
size_t nread = llama_state_seq_load_file(ctx_main, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
size_t nread = llama_state_seq_load_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
if (nread == 0) {
slot->prompt.tokens.clear(); // KV may have already been invalidated?
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
@@ -2234,12 +2234,12 @@ private:
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
llama_memory_seq_rm (llama_get_memory(ctx_main), slot.id, n_keep , n_keep + n_discard);
llama_memory_seq_add(llama_get_memory(ctx_main), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard);
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
if (ctx_drft) {
llama_memory_seq_rm (llama_get_memory(ctx_drft.get()), slot.id, n_keep , n_keep + n_discard);
llama_memory_seq_add(llama_get_memory(ctx_drft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
if (ctx_dft) {
llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard);
llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
}
// add generated tokens to cache
@@ -2287,13 +2287,13 @@ private:
}
slot.update_batch(batch,
ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL,
ctx_drft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL,
ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
}
// process in chunks of params.n_batch
int32_t n_batch = llama_n_batch(ctx_main);
int32_t n_ubatch = llama_n_ubatch(ctx_main);
int32_t n_batch = llama_n_batch(ctx_tgt);
int32_t n_ubatch = llama_n_ubatch(ctx_tgt);
float alora_scale = -1.0f;
size_t alora_disabled_id = 0;
@@ -2337,12 +2337,12 @@ private:
/*if (1) {
// first 16 tokens (avoid flooding logs)
for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_main, input_tokens[i]).c_str());
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str());
}
} else {
// all
for (int i = 0; i < (int) input_tokens.size(); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_main, input_tokens[i]).c_str());
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str());
}
}*/
@@ -2361,7 +2361,7 @@ private:
}
// TODO: support memory-less logits computation
if (slot.task->need_logits() && !llama_get_memory(ctx_main)) {
if (slot.task->need_logits() && !llama_get_memory(ctx_tgt)) {
send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
slot.release();
continue;
@@ -2413,7 +2413,7 @@ private:
const auto n_cache_reuse = slot.task->params.n_cache_reuse;
const bool can_cache_reuse =
llama_memory_can_shift(llama_get_memory(ctx_main)) &&
llama_memory_can_shift(llama_get_memory(ctx_tgt)) &&
!slot.prompt.tokens.has_mtmd;
if (!can_cache_reuse && n_cache_reuse > 0) {
@@ -2447,17 +2447,17 @@ private:
if (n_match >= (size_t) n_cache_reuse) {
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
//for (size_t i = head_p; i < head_p + n_match; i++) {
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx_main, prompt_tokens[i]).c_str());
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx_tgt, prompt_tokens[i]).c_str());
//}
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
llama_memory_seq_rm (llama_get_memory(ctx_main), slot.id, head_p, head_c);
llama_memory_seq_add(llama_get_memory(ctx_main), slot.id, head_c, head_c + n_match, kv_shift);
llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c);
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift);
if (ctx_drft) {
llama_memory_seq_rm (llama_get_memory(ctx_drft.get()), slot.id, head_p, head_c);
llama_memory_seq_add(llama_get_memory(ctx_drft.get()), slot.id, head_c, head_c + n_match, kv_shift);
if (ctx_dft) {
llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c);
llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift);
}
for (size_t i = 0; i < n_match; i++) {
@@ -2485,7 +2485,7 @@ private:
const auto pos_min_thold = std::max(0, pos_next - n_swa);
if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_main), slot.id);
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
if (pos_min == -1) {
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
@@ -2514,14 +2514,14 @@ private:
{
const auto token = slot.prompt.tokens[i];
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_main, token) : "[mtmd]";
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]";
ss0 << piece;
st0 << std::setw(8) << token;
}
{
const auto token = slot.task->tokens[i];
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_main, token) : "[mtmd]";
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]";
ss1 << piece;
st1 << std::setw(8) << token;
}
@@ -2554,8 +2554,8 @@ private:
if (!do_reset) {
// restore the context checkpoint
it->load_main(ctx_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
it->load_drft(ctx_drft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
@@ -2616,7 +2616,7 @@ private:
SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
if (!llama_memory_seq_rm(llama_get_memory(ctx_main), slot.id, p0, -1)) {
if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
slot.prompt_clear(true);
@@ -2624,7 +2624,7 @@ private:
// there is no common part left
slot.n_prompt_tokens_cache = 0;
} else {
if (ctx_drft && !llama_memory_seq_rm(llama_get_memory(ctx_drft.get()), slot.id, p0, -1)) {
if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) {
GGML_ABORT("failed to truncate draft context\n");
}
}
@@ -2653,7 +2653,7 @@ private:
// - the model does not support partial sequence removal
// - the model uses SWA (and we are not using `swa_full`)
do_checkpoint = do_checkpoint && (
(ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
(n_swa > 0));
bool has_mtmd = false;
@@ -2662,7 +2662,7 @@ private:
while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
// process the image
size_t n_tokens_out = 0;
int32_t res = input_tokens.process_chunk(ctx_main, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
int32_t res = input_tokens.process_chunk(ctx_tgt, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
if (res != 0) {
SLT_ERR(slot, "failed to process image, res = %d\n", res);
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
@@ -2670,10 +2670,10 @@ private:
continue;
}
if (ctx_drft) {
if (ctx_dft) {
// TODO: in the future, figure out how to infuse target embeddings to the images
// for now, we skip this for simplicity
res = input_tokens.process_chunk(ctx_drft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
if (res != 0) {
GGML_ABORT("failed to process multi-modal data on draft context\n");
}
@@ -2780,8 +2780,8 @@ private:
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
}
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_main), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_main), slot.id);
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
// no need for empty or small checkpoints
do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64);
@@ -2814,7 +2814,7 @@ private:
if (slot_batched) {
// apply lora, only need to do it once per batch
common_set_adapter_lora(ctx_main, slot_batched->lora);
common_set_adapter_lora(ctx_tgt, slot_batched->lora);
// if the lora is temporarily disabled for an alora, re-enable it
// for next time
@@ -2823,7 +2823,7 @@ private:
slot_batched->lora[alora_disabled_id].scale = alora_scale;
}
llama_set_embeddings(ctx_main, slot_batched->task->need_embd());
llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd());
}
if (batch.n_tokens == 0) {
@@ -2852,7 +2852,7 @@ private:
batch.logits + i,
};
const int ret = llama_decode(ctx_main, batch_view);
const int ret = llama_decode(ctx_tgt, batch_view);
metrics.on_decoded(slots);
@@ -2917,12 +2917,12 @@ private:
// | Eagle3 | yes |
// | DFlash | yes? |
//
if (ctx_drft) {
if (ctx_dft) {
// TODO: update as needed for MTP, Eagle3, etc.
const bool need_tgt_embd = false;
if (need_tgt_embd) {
llama_synchronize(ctx_main);
llama_synchronize(ctx_tgt);
}
// the logic here varies depending on the speculative decoding method
@@ -2931,13 +2931,13 @@ private:
// TODO: extract this in a function ?
{
// TODO: hook the embeddings from the last target batch here
if (llama_model_has_encoder(model_drft.get())) {
//llama_encode(ctx_drft, ...);
if (llama_model_has_encoder(model_dft.get())) {
//llama_encode(ctx_dft, ...);
GGML_ABORT("not implemented yet\n");
}
const int ret = llama_decode(ctx_drft.get(), batch_view);
const int ret = llama_decode(ctx_dft.get(), batch_view);
if (ret != 0) {
SRV_ERR("failed to decode draft batch, ret = %d\n", ret);
@@ -2952,7 +2952,7 @@ private:
i_next = i + n_tokens;
// on successful decode, restore the original batch size
n_batch = llama_n_batch(ctx_main);
n_batch = llama_n_batch(ctx_tgt);
// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
for (auto & slot : slots) {
@@ -3023,7 +3023,7 @@ private:
const int tok_idx = slot.i_batch - i;
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_main, tok_idx);
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
slot.i_batch = -1;
@@ -3044,7 +3044,7 @@ private:
completion_token_output result;
result.tok = id;
result.text_to_send = common_token_to_piece(slot.ctx_main, result.tok, accept_special_token(slot, result.tok));
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
if (slot.task->params.sampling.n_probs > 0) {
@@ -3075,23 +3075,23 @@ private:
// verify and try to accept the draft
{
const bool use_ckpt_main = ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
// only save the sampler state if we use checkpoints
common_sampler_ptr smpl_save;
if (use_ckpt_main) {
if (use_ckpt_tgt) {
smpl_save.reset(common_sampler_clone(slot.smpl.get()));
}
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_main, slot.spec_i_batch, slot.spec_draft);
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft);
slot.spec_i_batch.clear();
GGML_ASSERT(accepted.size() >= 1);
// check for partial draft acceptance
if (accepted.size() < slot.spec_draft.size() + 1) {
if (use_ckpt_main) {
if (use_ckpt_tgt) {
if (trace > 0) {
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
}
@@ -3104,15 +3104,15 @@ private:
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
{
ckpt.load_main(slot.ctx_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(slot.ctx_main), slot.id, ckpt.pos_max + 1, -1);
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1);
}
if (slot.ctx_drft) {
ckpt.load_drft(slot.ctx_drft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (slot.ctx_dft) {
ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, ckpt.pos_max + 1, -1);
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1);
}
slot.prompt.tokens.keep_first(ckpt.n_tokens);
@@ -3148,16 +3148,16 @@ private:
slot.sampled = ids.back(); // last accepted token
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
llama_memory_seq_rm(llama_get_memory(slot.ctx_main), slot.id, slot.prompt.tokens.pos_next(), -1);
if (slot.ctx_drft) {
llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, slot.prompt.tokens.pos_next(), -1);
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1);
if (slot.ctx_dft) {
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1);
}
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
result.text_to_send = common_token_to_piece(slot.ctx_main, result.tok, accept_special_token(slot, result.tok));
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // set later
// TODO: set result.probs
@@ -3209,7 +3209,7 @@ void server_context::terminate() {
}
llama_context * server_context::get_llama_context() const {
return impl->ctx_main;
return impl->ctx_tgt;
}
server_response_reader server_context::get_response_reader() {
@@ -3219,8 +3219,8 @@ server_response_reader server_context::get_response_reader() {
server_context_meta server_context::get_meta() const {
auto bos_id = llama_vocab_bos(impl->vocab);
auto eos_id = llama_vocab_eos(impl->vocab);
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_main, bos_id, true) : "";
auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_main, eos_id, true) : "";
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, bos_id, true) : "";
auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, eos_id, true) : "";
return server_context_meta {
/* build_info */ std::string(llama_build_info()),
@@ -3233,7 +3233,7 @@ server_context_meta server_context::get_meta() const {
/* has_inp_audio */ impl->chat_params.allow_audio,
/* json_webui_settings */ impl->json_webui_settings,
/* slot_n_ctx */ impl->get_slot_n_ctx(),
/* pooling_type */ llama_pooling_type(impl->ctx_main),
/* pooling_type */ llama_pooling_type(impl->ctx_tgt),
/* chat_params */ impl->chat_params,
/* chat_template_caps */ common_chat_templates_get_caps(impl->chat_params.tmpls.get()),
@@ -3251,10 +3251,10 @@ server_context_meta server_context::get_meta() const {
/* model_vocab_type */ llama_vocab_type(impl->vocab),
/* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model_main),
/* model_n_embd_inp */ llama_model_n_embd(impl->model_main),
/* model_n_params */ llama_model_n_params(impl->model_main),
/* model_size */ llama_model_size(impl->model_main),
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model_tgt),
/* model_n_embd_inp */ llama_model_n_embd(impl->model_tgt),
/* model_n_params */ llama_model_n_params(impl->model_tgt),
/* model_size */ llama_model_size(impl->model_tgt),
};
}
@@ -4156,7 +4156,7 @@ void server_routes::init_routes() {
std::vector<server_task> tasks;
tasks.reserve(documents.size());
for (size_t i = 0; i < documents.size(); i++) {
auto tmp = format_prompt_rerank(ctx_server.model_main, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
auto tmp = format_prompt_rerank(ctx_server.model_tgt, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = rd.get_new_id();
task.tokens = std::move(tmp);

View File

@@ -1981,7 +1981,7 @@ size_t server_prompt_cache::n_tokens() const {
return res;
}
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft) {
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) {
// first check if the current state is contained fully in the cache
for (auto it = states.begin(); it != states.end(); ++it) {
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
@@ -2005,13 +2005,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
}
}
std::vector<uint8_t> state_data_main;
std::vector<uint8_t> state_data_drft;
std::vector<uint8_t> state_data_tgt;
std::vector<uint8_t> state_data_dft;
// check if we can allocate enough memory for the new state
try {
state_data_main.resize(state_size_main);
state_data_drft.resize(state_size_drft);
state_data_tgt.resize(state_size_tgt);
state_data_dft.resize(state_size_dft);
} catch (const std::bad_alloc & e) {
SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
@@ -2027,8 +2027,8 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
states.push_back({
/*.tokens =*/ prompt.tokens.clone(),
/*.data =*/ {
/*.main =*/ std::move(state_data_main),
/*.drft =*/ std::move(state_data_drft),
/*.main =*/ std::move(state_data_tgt),
/*.drft =*/ std::move(state_data_dft),
},
/*.checkpoints =*/ prompt.checkpoints,
});
@@ -2036,7 +2036,7 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
return &states.back();
}
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot) {
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) {
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
@@ -2073,7 +2073,7 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
auto & data = it_best->data.main;
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx_main, data.data(), size, id_slot, 0);
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
@@ -2088,10 +2088,10 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
auto & data = it_best->data.drft;
if (!data.empty()) {
GGML_ASSERT(ctx_drft);
GGML_ASSERT(ctx_dft);
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx_drft, data.data(), size, id_slot, 0);
const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);