|
|
|
|
@@ -54,8 +54,8 @@ enum server_state {
|
|
|
|
|
struct server_slot {
|
|
|
|
|
int id;
|
|
|
|
|
|
|
|
|
|
llama_context * ctx_main = nullptr;
|
|
|
|
|
llama_context * ctx_drft = nullptr;
|
|
|
|
|
llama_context * ctx_tgt = nullptr;
|
|
|
|
|
llama_context * ctx_dft = nullptr;
|
|
|
|
|
|
|
|
|
|
// multimodal
|
|
|
|
|
mtmd_context * mctx = nullptr;
|
|
|
|
|
@@ -108,27 +108,27 @@ struct server_slot {
|
|
|
|
|
void prompt_save(server_prompt_cache & prompt_cache) const {
|
|
|
|
|
GGML_ASSERT(prompt.data.size() == 0);
|
|
|
|
|
|
|
|
|
|
const size_t cur_size_main = llama_state_seq_get_size_ext(ctx_main, id, LLAMA_STATE_SEQ_FLAGS_NONE);
|
|
|
|
|
const size_t cur_size_drft = ctx_drft ? llama_state_seq_get_size_ext(ctx_drft, id, LLAMA_STATE_SEQ_FLAGS_NONE) : 0;
|
|
|
|
|
const size_t cur_size_tgt = llama_state_seq_get_size_ext(ctx_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE);
|
|
|
|
|
const size_t cur_size_dft = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE) : 0;
|
|
|
|
|
|
|
|
|
|
const size_t cur_size = cur_size_main + cur_size_drft;
|
|
|
|
|
const size_t cur_size = cur_size_tgt + cur_size_dft;
|
|
|
|
|
|
|
|
|
|
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n",
|
|
|
|
|
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_drft / (1024.0 * 1024.0));
|
|
|
|
|
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0));
|
|
|
|
|
|
|
|
|
|
auto * cur = prompt_cache.alloc(prompt, cur_size_main, cur_size_drft);
|
|
|
|
|
auto * cur = prompt_cache.alloc(prompt, cur_size_tgt, cur_size_dft);
|
|
|
|
|
if (cur == nullptr) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
llama_state_seq_get_data_ext(ctx_main, cur->data.main.data(), cur_size_main, id, LLAMA_STATE_SEQ_FLAGS_NONE);
|
|
|
|
|
if (ctx_drft) {
|
|
|
|
|
llama_state_seq_get_data_ext(ctx_drft, cur->data.drft.data(), cur_size_drft, id, LLAMA_STATE_SEQ_FLAGS_NONE);
|
|
|
|
|
llama_state_seq_get_data_ext(ctx_tgt, cur->data.main.data(), cur_size_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE);
|
|
|
|
|
if (ctx_dft) {
|
|
|
|
|
llama_state_seq_get_data_ext(ctx_dft, cur->data.drft.data(), cur_size_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
|
|
|
|
|
bool res = prompt_cache.load(prompt, tokens, ctx_main, ctx_drft, id);
|
|
|
|
|
bool res = prompt_cache.load(prompt, tokens, ctx_tgt, ctx_dft, id);
|
|
|
|
|
if (!res) {
|
|
|
|
|
SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
|
|
|
|
|
}
|
|
|
|
|
@@ -143,9 +143,9 @@ struct server_slot {
|
|
|
|
|
|
|
|
|
|
SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
|
|
|
|
|
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(ctx_main), id, -1, -1);
|
|
|
|
|
if (ctx_drft) {
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(ctx_drft), id, -1, -1);
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1);
|
|
|
|
|
if (ctx_dft) {
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
prompt.tokens.clear();
|
|
|
|
|
@@ -205,7 +205,7 @@ struct server_slot {
|
|
|
|
|
task_prev = std::move(task);
|
|
|
|
|
task.reset();
|
|
|
|
|
|
|
|
|
|
llama_set_sampler(ctx_main, id, nullptr);
|
|
|
|
|
llama_set_sampler(ctx_tgt, id, nullptr);
|
|
|
|
|
|
|
|
|
|
// clear alora start
|
|
|
|
|
alora_invocation_start = -1;
|
|
|
|
|
@@ -242,7 +242,7 @@ struct server_slot {
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
|
!task->need_embd() ||
|
|
|
|
|
(llama_get_memory(ctx_main) && llama_pooling_type(ctx_main) == LLAMA_POOLING_TYPE_LAST);
|
|
|
|
|
(llama_get_memory(ctx_tgt) && llama_pooling_type(ctx_tgt) == LLAMA_POOLING_TYPE_LAST);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool can_batch_with(server_slot & other_slot) const {
|
|
|
|
|
@@ -316,7 +316,7 @@ struct server_slot {
|
|
|
|
|
return n_draft_max;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void update_batch(llama_batch & batch, bool use_ckpt_main, bool use_ckpt_drft) {
|
|
|
|
|
void update_batch(llama_batch & batch, bool use_ckpt_tgt, bool use_ckpt_dft) {
|
|
|
|
|
const int n_draft_max = get_n_draft_max();
|
|
|
|
|
if (n_draft_max > 0) {
|
|
|
|
|
GGML_ASSERT(can_speculate());
|
|
|
|
|
@@ -330,7 +330,7 @@ struct server_slot {
|
|
|
|
|
|
|
|
|
|
if (!spec_draft.empty()) {
|
|
|
|
|
// we have a previous (partial) draft to reuse
|
|
|
|
|
if (use_ckpt_main) {
|
|
|
|
|
if (use_ckpt_tgt) {
|
|
|
|
|
GGML_ASSERT(!spec_ckpt.empty());
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
@@ -338,11 +338,11 @@ struct server_slot {
|
|
|
|
|
|
|
|
|
|
spec_ckpt.update_pos(
|
|
|
|
|
prompt.n_tokens(),
|
|
|
|
|
llama_memory_seq_pos_min(llama_get_memory(ctx_main), id),
|
|
|
|
|
llama_memory_seq_pos_max(llama_get_memory(ctx_main), id));
|
|
|
|
|
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), id),
|
|
|
|
|
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), id));
|
|
|
|
|
|
|
|
|
|
if (use_ckpt_drft) {
|
|
|
|
|
spec_ckpt.update_drft(ctx_drft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
|
|
|
|
if (use_ckpt_dft) {
|
|
|
|
|
spec_ckpt.update_dft(ctx_dft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// generate a new draft
|
|
|
|
|
@@ -355,24 +355,24 @@ struct server_slot {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!spec_draft.empty()) {
|
|
|
|
|
if (use_ckpt_main) {
|
|
|
|
|
if (use_ckpt_tgt) {
|
|
|
|
|
//const int64_t t_start = ggml_time_us();
|
|
|
|
|
|
|
|
|
|
spec_ckpt.update_main(ctx_main, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
|
|
|
|
spec_ckpt.update_tgt(ctx_tgt, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
|
|
|
|
|
|
|
|
|
//const int64_t t_total = ggml_time_us() - t_start;
|
|
|
|
|
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
|
|
|
|
|
|
|
|
|
|
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n",
|
|
|
|
|
spec_ckpt.pos_min, spec_ckpt.pos_max, prompt.n_tokens(), (float) spec_ckpt.size() / 1024 / 1024, (float) spec_ckpt.data_drft.size() / 1024 / 1024);
|
|
|
|
|
spec_ckpt.pos_min, spec_ckpt.pos_max, prompt.n_tokens(), (float) spec_ckpt.size() / 1024 / 1024, (float) spec_ckpt.data_dft.size() / 1024 / 1024);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
|
|
|
|
|
if (ctx_drft) {
|
|
|
|
|
spec_ckpt.load_drft(ctx_drft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
|
|
|
|
if (ctx_dft) {
|
|
|
|
|
spec_ckpt.load_dft(ctx_dft, this->id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
|
|
|
|
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(ctx_drft), this->id, spec_ckpt.pos_max + 1, -1);
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(ctx_dft), this->id, spec_ckpt.pos_max + 1, -1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -538,7 +538,7 @@ struct server_slot {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if (!only_metrics) {
|
|
|
|
|
res["prompt"] = ptask->tokens.detokenize(ctx_main, true);
|
|
|
|
|
res["prompt"] = ptask->tokens.detokenize(ctx_tgt, true);
|
|
|
|
|
res["generated"] = generated_text.empty() ? debug_generated_text : generated_text;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -549,12 +549,12 @@ struct server_slot {
|
|
|
|
|
void copy_state_to(server_slot & other) const {
|
|
|
|
|
GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);
|
|
|
|
|
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(ctx_main), other.id, -1, -1);
|
|
|
|
|
llama_memory_seq_cp(llama_get_memory(ctx_main), id, other.id, -1, -1);
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1);
|
|
|
|
|
llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1);
|
|
|
|
|
|
|
|
|
|
if (ctx_drft) {
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(ctx_drft), other.id, -1, -1);
|
|
|
|
|
llama_memory_seq_cp(llama_get_memory(ctx_drft), id, other.id, -1, -1);
|
|
|
|
|
if (ctx_dft) {
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1);
|
|
|
|
|
llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
other.n_decoded = n_decoded;
|
|
|
|
|
@@ -646,7 +646,7 @@ public:
|
|
|
|
|
// only use these pointers outside of this class:
|
|
|
|
|
// - when not in sleeping state
|
|
|
|
|
// - and, with thread-safe APIs (e.g., tokenizer calls)
|
|
|
|
|
llama_model * model_main = nullptr;
|
|
|
|
|
llama_model * model_tgt = nullptr;
|
|
|
|
|
|
|
|
|
|
mtmd_context * mctx = nullptr;
|
|
|
|
|
const llama_vocab * vocab = nullptr;
|
|
|
|
|
@@ -674,15 +674,15 @@ private:
|
|
|
|
|
// note: keep these alive - they determine the lifetime of the model, context, etc.
|
|
|
|
|
common_init_result_ptr llama_init;
|
|
|
|
|
|
|
|
|
|
llama_context * ctx_main = nullptr;
|
|
|
|
|
llama_context * ctx_tgt = nullptr;
|
|
|
|
|
|
|
|
|
|
llama_batch batch {};
|
|
|
|
|
|
|
|
|
|
llama_model_ptr model_drft;
|
|
|
|
|
llama_context_ptr ctx_drft;
|
|
|
|
|
llama_model_ptr model_dft;
|
|
|
|
|
llama_context_ptr ctx_dft;
|
|
|
|
|
|
|
|
|
|
common_context_seq_rm_type ctx_main_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
|
|
|
|
common_context_seq_rm_type ctx_drft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
|
|
|
|
common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
|
|
|
|
common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
|
|
|
|
|
|
|
|
|
bool add_bos_token = true;
|
|
|
|
|
|
|
|
|
|
@@ -717,8 +717,8 @@ private:
|
|
|
|
|
void destroy() {
|
|
|
|
|
llama_init.reset();
|
|
|
|
|
|
|
|
|
|
ctx_main = nullptr;
|
|
|
|
|
model_main = nullptr;
|
|
|
|
|
ctx_tgt = nullptr;
|
|
|
|
|
model_tgt = nullptr;
|
|
|
|
|
|
|
|
|
|
mtmd_free(mctx);
|
|
|
|
|
mctx = nullptr;
|
|
|
|
|
@@ -768,17 +768,17 @@ private:
|
|
|
|
|
|
|
|
|
|
llama_init = common_init_from_params(params_base);
|
|
|
|
|
|
|
|
|
|
model_main = llama_init->model();
|
|
|
|
|
ctx_main = llama_init->context();
|
|
|
|
|
model_tgt = llama_init->model();
|
|
|
|
|
ctx_tgt = llama_init->context();
|
|
|
|
|
|
|
|
|
|
if (model_main == nullptr) {
|
|
|
|
|
if (model_tgt == nullptr) {
|
|
|
|
|
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
vocab = llama_model_get_vocab(model_main);
|
|
|
|
|
vocab = llama_model_get_vocab(model_tgt);
|
|
|
|
|
|
|
|
|
|
n_ctx = llama_n_ctx(ctx_main);
|
|
|
|
|
n_ctx = llama_n_ctx(ctx_tgt);
|
|
|
|
|
|
|
|
|
|
add_bos_token = llama_vocab_get_add_bos(vocab);
|
|
|
|
|
|
|
|
|
|
@@ -805,19 +805,19 @@ private:
|
|
|
|
|
|
|
|
|
|
auto mparams_dft = common_model_params_to_llama(params_dft);
|
|
|
|
|
|
|
|
|
|
model_drft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
|
|
|
|
|
if (model_drft == nullptr) {
|
|
|
|
|
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
|
|
|
|
|
if (model_dft == nullptr) {
|
|
|
|
|
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto cparams = common_context_params_to_llama(params_dft);
|
|
|
|
|
ctx_drft.reset(llama_init_from_model(model_drft.get(), cparams));
|
|
|
|
|
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
|
|
|
|
|
|
|
|
|
ctx_drft_seq_rm_type = common_context_can_seq_rm(ctx_drft.get());
|
|
|
|
|
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
|
|
|
|
|
|
|
|
|
|
params_base.speculative.draft.ctx_tgt = ctx_main;
|
|
|
|
|
params_base.speculative.draft.ctx_dft = ctx_drft.get();
|
|
|
|
|
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
|
|
|
|
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string & mmproj_path = params_base.mmproj.path;
|
|
|
|
|
@@ -837,7 +837,7 @@ private:
|
|
|
|
|
mparams.image_max_tokens = params_base.image_max_tokens;
|
|
|
|
|
mparams.media_marker = get_media_marker();
|
|
|
|
|
|
|
|
|
|
mctx = mtmd_init_from_file(mmproj_path.c_str(), model_main, mparams);
|
|
|
|
|
mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams);
|
|
|
|
|
if (mctx == nullptr) {
|
|
|
|
|
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
|
|
|
|
|
return false;
|
|
|
|
|
@@ -855,7 +855,7 @@ private:
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!llama_memory_can_shift(llama_get_memory(ctx_main))) {
|
|
|
|
|
if (!llama_memory_can_shift(llama_get_memory(ctx_tgt))) {
|
|
|
|
|
if (params_base.ctx_shift) {
|
|
|
|
|
params_base.ctx_shift = false;
|
|
|
|
|
SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled");
|
|
|
|
|
@@ -867,14 +867,14 @@ private:
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (llama_model_n_swa(model_main) == 0) {
|
|
|
|
|
if (llama_model_n_swa(model_tgt) == 0) {
|
|
|
|
|
if (params_base.swa_full) {
|
|
|
|
|
params_base.swa_full = false;
|
|
|
|
|
SRV_WRN("%s\n", "swa_full is not supported by this model, it will be disabled");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model_main);
|
|
|
|
|
n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model_tgt);
|
|
|
|
|
|
|
|
|
|
// Necessary similarity of prompt for slot selection
|
|
|
|
|
slot_prompt_similarity = params_base.slot_prompt_similarity;
|
|
|
|
|
@@ -882,9 +882,9 @@ private:
|
|
|
|
|
// setup slots
|
|
|
|
|
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
|
|
|
|
|
|
|
|
|
|
const int n_ctx_train = llama_model_n_ctx_train(model_main);
|
|
|
|
|
const int n_ctx_train = llama_model_n_ctx_train(model_tgt);
|
|
|
|
|
|
|
|
|
|
int n_ctx_slot = llama_n_ctx_seq(ctx_main);
|
|
|
|
|
int n_ctx_slot = llama_n_ctx_seq(ctx_tgt);
|
|
|
|
|
if (n_ctx_slot > n_ctx_train) {
|
|
|
|
|
SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
|
|
|
|
|
n_ctx_slot = n_ctx_train;
|
|
|
|
|
@@ -892,12 +892,12 @@ private:
|
|
|
|
|
|
|
|
|
|
slots.clear();
|
|
|
|
|
|
|
|
|
|
ctx_main_seq_rm_type = common_context_can_seq_rm(ctx_main);
|
|
|
|
|
if (ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
|
|
|
|
ctx_tgt_seq_rm_type = common_context_can_seq_rm(ctx_tgt);
|
|
|
|
|
if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
|
|
|
|
SRV_WRN("%s", "speculative decoding not supported by this context\n");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
|
|
|
|
if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
|
|
|
|
SRV_WRN("%s", "speculative decoding will use checkpoints\n");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -906,27 +906,27 @@ private:
|
|
|
|
|
slots.emplace_back();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool no_drft = false;
|
|
|
|
|
bool no_dft = false;
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < params_base.n_parallel; i++) {
|
|
|
|
|
server_slot & slot = slots[i];
|
|
|
|
|
|
|
|
|
|
slot.id = i;
|
|
|
|
|
slot.ctx_main = ctx_main;
|
|
|
|
|
slot.ctx_drft = ctx_drft.get();
|
|
|
|
|
slot.ctx_tgt = ctx_tgt;
|
|
|
|
|
slot.ctx_dft = ctx_dft.get();
|
|
|
|
|
slot.n_ctx = n_ctx_slot;
|
|
|
|
|
|
|
|
|
|
slot.mctx = mctx;
|
|
|
|
|
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
|
|
|
|
|
|
|
|
|
// try speculative decoding
|
|
|
|
|
if (ctx_main_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
|
|
|
|
if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
|
|
|
|
try {
|
|
|
|
|
slot.spec.reset(common_speculative_init(params_base.speculative, slot.id));
|
|
|
|
|
} catch (const std::exception & e) {
|
|
|
|
|
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
|
|
|
|
|
|
|
|
|
|
no_drft = true;
|
|
|
|
|
no_dft = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (slot.spec) {
|
|
|
|
|
@@ -943,13 +943,13 @@ private:
|
|
|
|
|
slot.reset();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (no_drft && ctx_drft) {
|
|
|
|
|
if (no_dft && ctx_dft) {
|
|
|
|
|
SRV_WRN("%s", "destroying the draft model as it is not going to be used\n");
|
|
|
|
|
|
|
|
|
|
ctx_drft.reset();
|
|
|
|
|
ctx_dft.reset();
|
|
|
|
|
|
|
|
|
|
for (auto & slot : slots) {
|
|
|
|
|
slot.ctx_drft = nullptr;
|
|
|
|
|
slot.ctx_dft = nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -974,7 +974,7 @@ private:
|
|
|
|
|
// the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
|
|
|
|
|
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
|
|
|
|
|
{
|
|
|
|
|
const int32_t n_batch = llama_n_batch(ctx_main);
|
|
|
|
|
const int32_t n_batch = llama_n_batch(ctx_tgt);
|
|
|
|
|
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1018,8 +1018,8 @@ private:
|
|
|
|
|
|
|
|
|
|
// unlike load_model(), this is only called once during initialization
|
|
|
|
|
bool init() {
|
|
|
|
|
GGML_ASSERT(ctx_main != nullptr);
|
|
|
|
|
GGML_ASSERT(model_main != nullptr);
|
|
|
|
|
GGML_ASSERT(ctx_tgt != nullptr);
|
|
|
|
|
GGML_ASSERT(model_tgt != nullptr);
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(!sleeping);
|
|
|
|
|
|
|
|
|
|
@@ -1066,7 +1066,7 @@ private:
|
|
|
|
|
common_chat_templates_ptr chat_templates;
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
chat_templates = common_chat_templates_init(model_main, params_base.chat_template);
|
|
|
|
|
chat_templates = common_chat_templates_init(model_tgt, params_base.chat_template);
|
|
|
|
|
|
|
|
|
|
LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
|
|
|
|
|
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
|
|
|
|
|
@@ -1329,7 +1329,7 @@ private:
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!task.tokens.validate(ctx_main)) {
|
|
|
|
|
if (!task.tokens.validate(ctx_tgt)) {
|
|
|
|
|
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
@@ -1339,7 +1339,7 @@ private:
|
|
|
|
|
// initialize samplers
|
|
|
|
|
if (task.need_sampling()) {
|
|
|
|
|
try {
|
|
|
|
|
slot.smpl.reset(common_sampler_init(model_main, task.params.sampling));
|
|
|
|
|
slot.smpl.reset(common_sampler_init(model_tgt, task.params.sampling));
|
|
|
|
|
} catch (std::exception & e) {
|
|
|
|
|
std::string err_msg = std::string("Failed to initialize samplers: ") + e.what();
|
|
|
|
|
send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST);
|
|
|
|
|
@@ -1360,9 +1360,9 @@ private:
|
|
|
|
|
|
|
|
|
|
// TODO: tmp until backend sampling is fully implemented
|
|
|
|
|
if (backend_sampling) {
|
|
|
|
|
llama_set_sampler(ctx_main, slot.id, common_sampler_get(slot.smpl.get()));
|
|
|
|
|
llama_set_sampler(ctx_tgt, slot.id, common_sampler_get(slot.smpl.get()));
|
|
|
|
|
} else {
|
|
|
|
|
llama_set_sampler(ctx_main, slot.id, nullptr);
|
|
|
|
|
llama_set_sampler(ctx_tgt, slot.id, nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
|
|
|
|
|
@@ -1535,13 +1535,13 @@ private:
|
|
|
|
|
for (size_t i = 0; i < n_probs; i++) {
|
|
|
|
|
result.probs.push_back({
|
|
|
|
|
cur_p->data[i].id,
|
|
|
|
|
common_token_to_piece(ctx_main, cur_p->data[i].id, special),
|
|
|
|
|
common_token_to_piece(ctx_tgt, cur_p->data[i].id, special),
|
|
|
|
|
cur_p->data[i].p
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// TODO: optimize this with min-p optimization
|
|
|
|
|
std::vector<llama_token_data> cur = get_token_probabilities(ctx_main, idx);
|
|
|
|
|
std::vector<llama_token_data> cur = get_token_probabilities(ctx_tgt, idx);
|
|
|
|
|
const size_t max_probs = cur.size();
|
|
|
|
|
const size_t n_probs = std::min(max_probs, n_probs_request);
|
|
|
|
|
|
|
|
|
|
@@ -1559,7 +1559,7 @@ private:
|
|
|
|
|
for (size_t i = 0; i < n_probs; i++) {
|
|
|
|
|
result.probs.push_back({
|
|
|
|
|
cur[i].id,
|
|
|
|
|
common_token_to_piece(ctx_main, cur[i].id, special),
|
|
|
|
|
common_token_to_piece(ctx_tgt, cur[i].id, special),
|
|
|
|
|
cur[i].p
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
@@ -1662,7 +1662,7 @@ private:
|
|
|
|
|
res->tokens = std::move(slot.generated_tokens);
|
|
|
|
|
}
|
|
|
|
|
res->timings = slot.get_timings();
|
|
|
|
|
res->prompt = slot.task->tokens.detokenize(ctx_main, true);
|
|
|
|
|
res->prompt = slot.task->tokens.detokenize(ctx_tgt, true);
|
|
|
|
|
res->response_fields = std::move(slot.task->params.response_fields);
|
|
|
|
|
|
|
|
|
|
res->truncated = slot.truncated;
|
|
|
|
|
@@ -1685,7 +1685,7 @@ private:
|
|
|
|
|
// populate res.probs_output
|
|
|
|
|
if (slot.task->params.sampling.n_probs > 0) {
|
|
|
|
|
if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) {
|
|
|
|
|
const llama_tokens stop_word_toks = common_tokenize(ctx_main, slot.stopping_word, false);
|
|
|
|
|
const llama_tokens stop_word_toks = common_tokenize(ctx_tgt, slot.stopping_word, false);
|
|
|
|
|
|
|
|
|
|
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
|
|
|
|
|
res->probs_output = std::vector<completion_token_output>(
|
|
|
|
|
@@ -1710,7 +1710,7 @@ private:
|
|
|
|
|
res->n_tokens = slot.task->n_tokens();
|
|
|
|
|
res->res_type = slot.task->params.res_type;
|
|
|
|
|
|
|
|
|
|
const int n_embd_out = llama_model_n_embd_out(model_main);
|
|
|
|
|
const int n_embd_out = llama_model_n_embd_out(model_tgt);
|
|
|
|
|
|
|
|
|
|
std::vector<float> embd_res(n_embd_out, 0.0f);
|
|
|
|
|
|
|
|
|
|
@@ -1720,10 +1720,10 @@ private:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const float * embd = nullptr;
|
|
|
|
|
if (llama_pooling_type(slot.ctx_main) == LLAMA_POOLING_TYPE_NONE) {
|
|
|
|
|
embd = llama_get_embeddings_ith(slot.ctx_main, i);
|
|
|
|
|
if (llama_pooling_type(slot.ctx_tgt) == LLAMA_POOLING_TYPE_NONE) {
|
|
|
|
|
embd = llama_get_embeddings_ith(slot.ctx_tgt, i);
|
|
|
|
|
} else {
|
|
|
|
|
embd = llama_get_embeddings_seq(slot.ctx_main, batch.seq_id[i][0]);
|
|
|
|
|
embd = llama_get_embeddings_seq(slot.ctx_tgt, batch.seq_id[i][0]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (embd == nullptr) {
|
|
|
|
|
@@ -1734,7 +1734,7 @@ private:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// normalize only when there is pooling
|
|
|
|
|
if (llama_pooling_type(slot.ctx_main) != LLAMA_POOLING_TYPE_NONE) {
|
|
|
|
|
if (llama_pooling_type(slot.ctx_tgt) != LLAMA_POOLING_TYPE_NONE) {
|
|
|
|
|
common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize);
|
|
|
|
|
res->embedding.push_back(embd_res);
|
|
|
|
|
break;
|
|
|
|
|
@@ -1759,9 +1759,9 @@ private:
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const float * embd = llama_get_embeddings_seq(ctx_main, batch.seq_id[i][0]);
|
|
|
|
|
const float * embd = llama_get_embeddings_seq(ctx_tgt, batch.seq_id[i][0]);
|
|
|
|
|
if (embd == NULL) {
|
|
|
|
|
embd = llama_get_embeddings_ith(ctx_main, i);
|
|
|
|
|
embd = llama_get_embeddings_ith(ctx_tgt, i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (embd == NULL) {
|
|
|
|
|
@@ -1875,8 +1875,8 @@ private:
|
|
|
|
|
|
|
|
|
|
cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
|
|
|
|
|
|
|
|
|
|
cur.update_main(ctx_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
|
|
|
cur.update_drft(ctx_drft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
|
|
|
cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
|
|
|
cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
|
|
|
|
|
|
|
|
SLT_WRN(slot,
|
|
|
|
|
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
|
|
|
|
@@ -2036,7 +2036,7 @@ private:
|
|
|
|
|
std::string filepath = task.slot_action.filepath;
|
|
|
|
|
|
|
|
|
|
const llama_tokens & tokens = slot->prompt.tokens.get_tokens();
|
|
|
|
|
const size_t nwrite = llama_state_seq_save_file(ctx_main, filepath.c_str(), slot->id, tokens.data(), token_count);
|
|
|
|
|
const size_t nwrite = llama_state_seq_save_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), token_count);
|
|
|
|
|
|
|
|
|
|
const int64_t t_end = ggml_time_us();
|
|
|
|
|
const double t_save_ms = (t_end - t_start) / 1000.0;
|
|
|
|
|
@@ -2075,7 +2075,7 @@ private:
|
|
|
|
|
llama_tokens tokens;
|
|
|
|
|
tokens.resize(slot->n_ctx);
|
|
|
|
|
size_t token_count = 0;
|
|
|
|
|
size_t nread = llama_state_seq_load_file(ctx_main, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
|
|
|
|
|
size_t nread = llama_state_seq_load_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
|
|
|
|
|
if (nread == 0) {
|
|
|
|
|
slot->prompt.tokens.clear(); // KV may already been invalidated?
|
|
|
|
|
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
|
|
|
|
@@ -2234,12 +2234,12 @@ private:
|
|
|
|
|
|
|
|
|
|
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
|
|
|
|
|
|
|
|
|
|
llama_memory_seq_rm (llama_get_memory(ctx_main), slot.id, n_keep , n_keep + n_discard);
|
|
|
|
|
llama_memory_seq_add(llama_get_memory(ctx_main), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
|
|
|
|
|
llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard);
|
|
|
|
|
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
|
|
|
|
|
|
|
|
|
|
if (ctx_drft) {
|
|
|
|
|
llama_memory_seq_rm (llama_get_memory(ctx_drft.get()), slot.id, n_keep , n_keep + n_discard);
|
|
|
|
|
llama_memory_seq_add(llama_get_memory(ctx_drft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
|
|
|
|
|
if (ctx_dft) {
|
|
|
|
|
llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard);
|
|
|
|
|
llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// add generated tokens to cache
|
|
|
|
|
@@ -2287,13 +2287,13 @@ private:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
slot.update_batch(batch,
|
|
|
|
|
ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL,
|
|
|
|
|
ctx_drft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
|
|
|
|
|
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL,
|
|
|
|
|
ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// process in chunks of params.n_batch
|
|
|
|
|
int32_t n_batch = llama_n_batch(ctx_main);
|
|
|
|
|
int32_t n_ubatch = llama_n_ubatch(ctx_main);
|
|
|
|
|
int32_t n_batch = llama_n_batch(ctx_tgt);
|
|
|
|
|
int32_t n_ubatch = llama_n_ubatch(ctx_tgt);
|
|
|
|
|
|
|
|
|
|
float alora_scale = -1.0f;
|
|
|
|
|
size_t alora_disabled_id = 0;
|
|
|
|
|
@@ -2337,12 +2337,12 @@ private:
|
|
|
|
|
/*if (1) {
|
|
|
|
|
// first 16 tokens (avoid flooding logs)
|
|
|
|
|
for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
|
|
|
|
|
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_main, input_tokens[i]).c_str());
|
|
|
|
|
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str());
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// all
|
|
|
|
|
for (int i = 0; i < (int) input_tokens.size(); i++) {
|
|
|
|
|
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_main, input_tokens[i]).c_str());
|
|
|
|
|
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str());
|
|
|
|
|
}
|
|
|
|
|
}*/
|
|
|
|
|
|
|
|
|
|
@@ -2361,7 +2361,7 @@ private:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: support memory-less logits computation
|
|
|
|
|
if (slot.task->need_logits() && !llama_get_memory(ctx_main)) {
|
|
|
|
|
if (slot.task->need_logits() && !llama_get_memory(ctx_tgt)) {
|
|
|
|
|
send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
|
|
|
|
|
slot.release();
|
|
|
|
|
continue;
|
|
|
|
|
@@ -2413,7 +2413,7 @@ private:
|
|
|
|
|
const auto n_cache_reuse = slot.task->params.n_cache_reuse;
|
|
|
|
|
|
|
|
|
|
const bool can_cache_reuse =
|
|
|
|
|
llama_memory_can_shift(llama_get_memory(ctx_main)) &&
|
|
|
|
|
llama_memory_can_shift(llama_get_memory(ctx_tgt)) &&
|
|
|
|
|
!slot.prompt.tokens.has_mtmd;
|
|
|
|
|
|
|
|
|
|
if (!can_cache_reuse && n_cache_reuse > 0) {
|
|
|
|
|
@@ -2447,17 +2447,17 @@ private:
|
|
|
|
|
if (n_match >= (size_t) n_cache_reuse) {
|
|
|
|
|
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
|
|
|
|
|
//for (size_t i = head_p; i < head_p + n_match; i++) {
|
|
|
|
|
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx_main, prompt_tokens[i]).c_str());
|
|
|
|
|
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx_tgt, prompt_tokens[i]).c_str());
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
|
|
|
|
|
|
|
|
|
|
llama_memory_seq_rm (llama_get_memory(ctx_main), slot.id, head_p, head_c);
|
|
|
|
|
llama_memory_seq_add(llama_get_memory(ctx_main), slot.id, head_c, head_c + n_match, kv_shift);
|
|
|
|
|
llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c);
|
|
|
|
|
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift);
|
|
|
|
|
|
|
|
|
|
if (ctx_drft) {
|
|
|
|
|
llama_memory_seq_rm (llama_get_memory(ctx_drft.get()), slot.id, head_p, head_c);
|
|
|
|
|
llama_memory_seq_add(llama_get_memory(ctx_drft.get()), slot.id, head_c, head_c + n_match, kv_shift);
|
|
|
|
|
if (ctx_dft) {
|
|
|
|
|
llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c);
|
|
|
|
|
llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < n_match; i++) {
|
|
|
|
|
@@ -2485,7 +2485,7 @@ private:
|
|
|
|
|
const auto pos_min_thold = std::max(0, pos_next - n_swa);
|
|
|
|
|
|
|
|
|
|
if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
|
|
|
|
|
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_main), slot.id);
|
|
|
|
|
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
|
|
|
|
|
if (pos_min == -1) {
|
|
|
|
|
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
|
|
|
|
|
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
|
|
|
|
|
@@ -2514,14 +2514,14 @@ private:
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
const auto token = slot.prompt.tokens[i];
|
|
|
|
|
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_main, token) : "[mtmd]";
|
|
|
|
|
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]";
|
|
|
|
|
ss0 << piece;
|
|
|
|
|
st0 << std::setw(8) << token;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
const auto token = slot.task->tokens[i];
|
|
|
|
|
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_main, token) : "[mtmd]";
|
|
|
|
|
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]";
|
|
|
|
|
ss1 << piece;
|
|
|
|
|
st1 << std::setw(8) << token;
|
|
|
|
|
}
|
|
|
|
|
@@ -2554,8 +2554,8 @@ private:
|
|
|
|
|
if (!do_reset) {
|
|
|
|
|
// restore the context checkpoint
|
|
|
|
|
|
|
|
|
|
it->load_main(ctx_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
|
|
|
it->load_drft(ctx_drft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
|
|
|
it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
|
|
|
it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
|
|
|
|
|
|
|
|
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
|
|
|
|
|
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
|
|
|
|
|
@@ -2616,7 +2616,7 @@ private:
|
|
|
|
|
|
|
|
|
|
SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
|
|
|
|
|
|
|
|
|
|
if (!llama_memory_seq_rm(llama_get_memory(ctx_main), slot.id, p0, -1)) {
|
|
|
|
|
if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) {
|
|
|
|
|
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
|
|
|
|
|
|
|
|
|
|
slot.prompt_clear(true);
|
|
|
|
|
@@ -2624,7 +2624,7 @@ private:
|
|
|
|
|
// there is no common part left
|
|
|
|
|
slot.n_prompt_tokens_cache = 0;
|
|
|
|
|
} else {
|
|
|
|
|
if (ctx_drft && !llama_memory_seq_rm(llama_get_memory(ctx_drft.get()), slot.id, p0, -1)) {
|
|
|
|
|
if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) {
|
|
|
|
|
GGML_ABORT("failed to truncate draft context\n");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -2653,7 +2653,7 @@ private:
|
|
|
|
|
// - the model does not support partial sequence removal
|
|
|
|
|
// - the model uses SWA (and we are not using `swa_full`)
|
|
|
|
|
do_checkpoint = do_checkpoint && (
|
|
|
|
|
(ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
|
|
|
|
|
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
|
|
|
|
|
(n_swa > 0));
|
|
|
|
|
|
|
|
|
|
bool has_mtmd = false;
|
|
|
|
|
@@ -2662,7 +2662,7 @@ private:
|
|
|
|
|
while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
|
|
|
|
|
// process the image
|
|
|
|
|
size_t n_tokens_out = 0;
|
|
|
|
|
int32_t res = input_tokens.process_chunk(ctx_main, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
|
|
|
|
|
int32_t res = input_tokens.process_chunk(ctx_tgt, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
|
|
|
|
|
if (res != 0) {
|
|
|
|
|
SLT_ERR(slot, "failed to process image, res = %d\n", res);
|
|
|
|
|
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
|
|
|
|
|
@@ -2670,10 +2670,10 @@ private:
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx_drft) {
|
|
|
|
|
if (ctx_dft) {
|
|
|
|
|
// TODO: in the future, figure out how to infuse target embeddings to the images
|
|
|
|
|
// for now, we skip this for simplicity
|
|
|
|
|
res = input_tokens.process_chunk(ctx_drft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
|
|
|
|
|
res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
|
|
|
|
|
if (res != 0) {
|
|
|
|
|
GGML_ABORT("failed to process multi-modal data on draft context\n");
|
|
|
|
|
}
|
|
|
|
|
@@ -2780,8 +2780,8 @@ private:
|
|
|
|
|
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_main), slot.id);
|
|
|
|
|
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_main), slot.id);
|
|
|
|
|
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
|
|
|
|
|
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
|
|
|
|
|
|
|
|
|
|
// no need for empty or small checkpoints
|
|
|
|
|
do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64);
|
|
|
|
|
@@ -2814,7 +2814,7 @@ private:
|
|
|
|
|
|
|
|
|
|
if (slot_batched) {
|
|
|
|
|
// apply lora, only need to do it once per batch
|
|
|
|
|
common_set_adapter_lora(ctx_main, slot_batched->lora);
|
|
|
|
|
common_set_adapter_lora(ctx_tgt, slot_batched->lora);
|
|
|
|
|
|
|
|
|
|
// if the lora is temporarily disabled for an alora, re-enable it
|
|
|
|
|
// for next time
|
|
|
|
|
@@ -2823,7 +2823,7 @@ private:
|
|
|
|
|
slot_batched->lora[alora_disabled_id].scale = alora_scale;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
llama_set_embeddings(ctx_main, slot_batched->task->need_embd());
|
|
|
|
|
llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (batch.n_tokens == 0) {
|
|
|
|
|
@@ -2852,7 +2852,7 @@ private:
|
|
|
|
|
batch.logits + i,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
const int ret = llama_decode(ctx_main, batch_view);
|
|
|
|
|
const int ret = llama_decode(ctx_tgt, batch_view);
|
|
|
|
|
|
|
|
|
|
metrics.on_decoded(slots);
|
|
|
|
|
|
|
|
|
|
@@ -2917,12 +2917,12 @@ private:
|
|
|
|
|
// | Eagle3 | yes |
|
|
|
|
|
// | DFlash | yes? |
|
|
|
|
|
//
|
|
|
|
|
if (ctx_drft) {
|
|
|
|
|
if (ctx_dft) {
|
|
|
|
|
// TODO: update as needed for MTP, Eagle3, etc.
|
|
|
|
|
const bool need_tgt_embd = false;
|
|
|
|
|
|
|
|
|
|
if (need_tgt_embd) {
|
|
|
|
|
llama_synchronize(ctx_main);
|
|
|
|
|
llama_synchronize(ctx_tgt);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// the logic here varies depending on the speculative decoding method
|
|
|
|
|
@@ -2931,13 +2931,13 @@ private:
|
|
|
|
|
// TODO: extract this in a function ?
|
|
|
|
|
{
|
|
|
|
|
// TODO: hook the embeddings from the last target batch here
|
|
|
|
|
if (llama_model_has_encoder(model_drft.get())) {
|
|
|
|
|
//llama_encode(ctx_drft, ...);
|
|
|
|
|
if (llama_model_has_encoder(model_dft.get())) {
|
|
|
|
|
//llama_encode(ctx_dft, ...);
|
|
|
|
|
|
|
|
|
|
GGML_ABORT("not implemented yet\n");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const int ret = llama_decode(ctx_drft.get(), batch_view);
|
|
|
|
|
const int ret = llama_decode(ctx_dft.get(), batch_view);
|
|
|
|
|
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
SRV_ERR("failed to decode draft batch, ret = %d\n", ret);
|
|
|
|
|
@@ -2952,7 +2952,7 @@ private:
|
|
|
|
|
i_next = i + n_tokens;
|
|
|
|
|
|
|
|
|
|
// on successful decode, restore the original batch size
|
|
|
|
|
n_batch = llama_n_batch(ctx_main);
|
|
|
|
|
n_batch = llama_n_batch(ctx_tgt);
|
|
|
|
|
|
|
|
|
|
// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
|
|
|
|
|
for (auto & slot : slots) {
|
|
|
|
|
@@ -3023,7 +3023,7 @@ private:
|
|
|
|
|
|
|
|
|
|
const int tok_idx = slot.i_batch - i;
|
|
|
|
|
|
|
|
|
|
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_main, tok_idx);
|
|
|
|
|
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
|
|
|
|
|
|
|
|
|
|
slot.i_batch = -1;
|
|
|
|
|
|
|
|
|
|
@@ -3044,7 +3044,7 @@ private:
|
|
|
|
|
|
|
|
|
|
completion_token_output result;
|
|
|
|
|
result.tok = id;
|
|
|
|
|
result.text_to_send = common_token_to_piece(slot.ctx_main, result.tok, accept_special_token(slot, result.tok));
|
|
|
|
|
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
|
|
|
|
|
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
|
|
|
|
|
|
|
|
|
if (slot.task->params.sampling.n_probs > 0) {
|
|
|
|
|
@@ -3075,23 +3075,23 @@ private:
|
|
|
|
|
|
|
|
|
|
// verify and try to accept the draft
|
|
|
|
|
{
|
|
|
|
|
const bool use_ckpt_main = ctx_main_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
|
|
|
|
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
|
|
|
|
|
|
|
|
|
// only save the sampler sampler state if we use checkpoints
|
|
|
|
|
common_sampler_ptr smpl_save;
|
|
|
|
|
if (use_ckpt_main) {
|
|
|
|
|
if (use_ckpt_tgt) {
|
|
|
|
|
smpl_save.reset(common_sampler_clone(slot.smpl.get()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
|
|
|
|
|
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_main, slot.spec_i_batch, slot.spec_draft);
|
|
|
|
|
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft);
|
|
|
|
|
slot.spec_i_batch.clear();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(accepted.size() >= 1);
|
|
|
|
|
|
|
|
|
|
// check for partial draft acceptance
|
|
|
|
|
if (accepted.size() < slot.spec_draft.size() + 1) {
|
|
|
|
|
if (use_ckpt_main) {
|
|
|
|
|
if (use_ckpt_tgt) {
|
|
|
|
|
if (trace > 0) {
|
|
|
|
|
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
|
|
|
|
|
}
|
|
|
|
|
@@ -3104,15 +3104,15 @@ private:
|
|
|
|
|
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
ckpt.load_main(slot.ctx_main, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
|
|
|
|
ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
|
|
|
|
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(slot.ctx_main), slot.id, ckpt.pos_max + 1, -1);
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (slot.ctx_drft) {
|
|
|
|
|
ckpt.load_drft(slot.ctx_drft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
|
|
|
|
if (slot.ctx_dft) {
|
|
|
|
|
ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
|
|
|
|
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, ckpt.pos_max + 1, -1);
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
slot.prompt.tokens.keep_first(ckpt.n_tokens);
|
|
|
|
|
@@ -3148,16 +3148,16 @@ private:
|
|
|
|
|
slot.sampled = ids.back(); // last accepted token
|
|
|
|
|
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
|
|
|
|
|
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(slot.ctx_main), slot.id, slot.prompt.tokens.pos_next(), -1);
|
|
|
|
|
if (slot.ctx_drft) {
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(slot.ctx_drft), slot.id, slot.prompt.tokens.pos_next(), -1);
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1);
|
|
|
|
|
if (slot.ctx_dft) {
|
|
|
|
|
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < ids.size(); ++i) {
|
|
|
|
|
completion_token_output result;
|
|
|
|
|
|
|
|
|
|
result.tok = ids[i];
|
|
|
|
|
result.text_to_send = common_token_to_piece(slot.ctx_main, result.tok, accept_special_token(slot, result.tok));
|
|
|
|
|
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
|
|
|
|
|
result.prob = 1.0f; // set later
|
|
|
|
|
|
|
|
|
|
// TODO: set result.probs
|
|
|
|
|
@@ -3209,7 +3209,7 @@ void server_context::terminate() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
llama_context * server_context::get_llama_context() const {
|
|
|
|
|
return impl->ctx_main;
|
|
|
|
|
return impl->ctx_tgt;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
server_response_reader server_context::get_response_reader() {
|
|
|
|
|
@@ -3219,8 +3219,8 @@ server_response_reader server_context::get_response_reader() {
|
|
|
|
|
server_context_meta server_context::get_meta() const {
|
|
|
|
|
auto bos_id = llama_vocab_bos(impl->vocab);
|
|
|
|
|
auto eos_id = llama_vocab_eos(impl->vocab);
|
|
|
|
|
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_main, bos_id, true) : "";
|
|
|
|
|
auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_main, eos_id, true) : "";
|
|
|
|
|
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, bos_id, true) : "";
|
|
|
|
|
auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, eos_id, true) : "";
|
|
|
|
|
|
|
|
|
|
return server_context_meta {
|
|
|
|
|
/* build_info */ std::string(llama_build_info()),
|
|
|
|
|
@@ -3233,7 +3233,7 @@ server_context_meta server_context::get_meta() const {
|
|
|
|
|
/* has_inp_audio */ impl->chat_params.allow_audio,
|
|
|
|
|
/* json_webui_settings */ impl->json_webui_settings,
|
|
|
|
|
/* slot_n_ctx */ impl->get_slot_n_ctx(),
|
|
|
|
|
/* pooling_type */ llama_pooling_type(impl->ctx_main),
|
|
|
|
|
/* pooling_type */ llama_pooling_type(impl->ctx_tgt),
|
|
|
|
|
|
|
|
|
|
/* chat_params */ impl->chat_params,
|
|
|
|
|
/* chat_template_caps */ common_chat_templates_get_caps(impl->chat_params.tmpls.get()),
|
|
|
|
|
@@ -3251,10 +3251,10 @@ server_context_meta server_context::get_meta() const {
|
|
|
|
|
|
|
|
|
|
/* model_vocab_type */ llama_vocab_type(impl->vocab),
|
|
|
|
|
/* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
|
|
|
|
|
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model_main),
|
|
|
|
|
/* model_n_embd_inp */ llama_model_n_embd(impl->model_main),
|
|
|
|
|
/* model_n_params */ llama_model_n_params(impl->model_main),
|
|
|
|
|
/* model_size */ llama_model_size(impl->model_main),
|
|
|
|
|
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model_tgt),
|
|
|
|
|
/* model_n_embd_inp */ llama_model_n_embd(impl->model_tgt),
|
|
|
|
|
/* model_n_params */ llama_model_n_params(impl->model_tgt),
|
|
|
|
|
/* model_size */ llama_model_size(impl->model_tgt),
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -4156,7 +4156,7 @@ void server_routes::init_routes() {
|
|
|
|
|
std::vector<server_task> tasks;
|
|
|
|
|
tasks.reserve(documents.size());
|
|
|
|
|
for (size_t i = 0; i < documents.size(); i++) {
|
|
|
|
|
auto tmp = format_prompt_rerank(ctx_server.model_main, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
|
|
|
|
|
auto tmp = format_prompt_rerank(ctx_server.model_tgt, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
|
|
|
|
|
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
|
|
|
|
|
task.id = rd.get_new_id();
|
|
|
|
|
task.tokens = std::move(tmp);
|
|
|
|
|
|