This commit removes the `--backend-dist` option and instead uses the
configured `--samplers` chain to determine which samplers run on the
backend.

Backend sampling is still enabled with `--backend-sampling`, and the
sampler chain, whether explicitly specified using `--samplers` or the
default, is automatically analyzed to determine which samplers can run
on the backend. The system finds the longest contiguous run of
backend-supported samplers from the start of the sampler sequence
(sketched after the examples below).
For example:
* If the chain is `top-k -> temperature -> top-p`, and both `top-k` and
`temperature` are backend-supported but `top-p` is not, then `top-k`
and `temperature` will run on the backend, while `top-p` and all
subsequent samplers run on the CPU.
* If all configured samplers are supported, the final distribution
sampling will also happen on the backend, transferring only the
sampled token IDs back to the host.
* If the sampler chain starts with an unsupported sampler (e.g.,
`penalties`), all sampling runs on the CPU. Note that this is
currently the case with the default sampler chain, so to use backend
sampling a sampler chain must be specified explicitly. See below for
an example.
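The selection rule can be summarized with a minimal sketch. This is
illustrative only and not the actual llama.cpp implementation; the
`is_backend_supported` predicate and `backend_prefix_len` helper are
hypothetical stand-ins for whatever capability check the backend exposes:
```cpp
#include <string>
#include <vector>

// hypothetical stand-in for the real backend capability check
static bool is_backend_supported(const std::string & sampler) {
    return sampler == "top_k" || sampler == "temperature";
}

// number of samplers, counted from the start of the chain, that can run
// on the backend; the first unsupported sampler and everything after it
// runs on the CPU
static size_t backend_prefix_len(const std::vector<std::string> & chain) {
    size_t n = 0;
    while (n < chain.size() && is_backend_supported(chain[n])) {
        n++;
    }
    return n;
}
```
Under this rule, `{"penalties", "top_k"}` yields a prefix length of 0
(everything runs on the CPU), while `{"top_k", "temperature"}` covers the
whole chain, in which case the final distribution sampling also runs on
the backend.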
The following shows how llama-cli can be run with backend sampling:
```console
$ llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
--prompt 'What is the capital of Sweden?' \
-n 20 \
-no-cnv \
--verbose-prompt \
-ngl 40 \
--backend-sampling \
--samplers 'top_k;temperature'
```
In this case, all sampling will happen on the backend since both
`top_k` and `temperature` are supported backend samplers.
To enable partial backend sampling (hybrid sampling), for example
running `top_k` and `temperature` on the backend and `top_p` on the CPU,
the following sampler chain could be specified:
```console
$ llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
--prompt 'What is the capital of Sweden?' \
-n 20 \
-no-cnv \
--verbose-prompt \
-ngl 40 \
--backend-sampling \
--samplers 'top_k;temperature;top_p'
```
If this looks good, I'll follow up with updates to the llama-cli and
llama-server documentation to reflect these changes.
#include "server-common.h"
|
|
#include "server-http.h"
|
|
#include "server-task.h"
|
|
#include "server-queue.h"
|
|
|
|
#include "arg.h"
|
|
#include "common.h"
|
|
#include "llama.h"
|
|
#include "log.h"
|
|
#include "sampling.h"
|
|
#include "speculative.h"
|
|
#include "mtmd.h"
|
|
#include "mtmd-helper.h"
|
|
|
|
#include <atomic>
|
|
#include <cstddef>
|
|
#include <cinttypes>
|
|
#include <memory>
|
|
#include <signal.h>
|
|
#include <thread>
|
|
#include <unordered_set>
|
|
|
|
// fix problem with std::min and std::max
|
|
#if defined(_WIN32)
|
|
#define WIN32_LEAN_AND_MEAN
|
|
#ifndef NOMINMAX
|
|
# define NOMINMAX
|
|
#endif
|
|
#include <windows.h>
|
|
#endif
|
|
|
|
using json = nlohmann::ordered_json;
|
|
|
|
constexpr int HTTP_POLLING_SECONDS = 1;
|
|
|
|
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
|
|
enum slot_state {
|
|
SLOT_STATE_IDLE,
|
|
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
|
|
SLOT_STATE_PROCESSING_PROMPT,
|
|
SLOT_STATE_DONE_PROMPT,
|
|
SLOT_STATE_GENERATING,
|
|
};
|
|
|
|
enum server_state {
|
|
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
|
|
SERVER_STATE_READY, // Server is ready and model is loaded
|
|
};
|
|
|
|
static bool server_task_type_need_embd(server_task_type task_type) {
|
|
switch (task_type) {
|
|
case SERVER_TASK_TYPE_EMBEDDING:
|
|
case SERVER_TASK_TYPE_RERANK:
|
|
return true;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
static bool server_task_type_need_logits(server_task_type task_type) {
|
|
switch (task_type) {
|
|
case SERVER_TASK_TYPE_COMPLETION:
|
|
case SERVER_TASK_TYPE_INFILL:
|
|
return true;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
|
|
struct server_slot {
    int id;

    llama_batch batch_spec = {};

    // TODO: change to unique_ptrs for consistency:
    llama_context * ctx = nullptr;
    llama_context * ctx_dft = nullptr;

    // multimodal
    mtmd_context * mctx = nullptr;

    common_speculative * spec = nullptr;

    std::unique_ptr<const server_task> task;
    std::unique_ptr<const server_task> task_prev; // used for debugging

    // used to determine the slot that has been used the longest
    int64_t t_last_used = -1;

    // generation props
    int32_t n_ctx = 0; // context size per slot
    int32_t n_keep = 0;
    int32_t n_decoded = 0;
    int32_t n_remaining = -1;
    int32_t i_batch = -1;

    int32_t n_prompt_tokens_cache = 0;
    int32_t n_prompt_tokens_processed = 0;

    size_t last_nl_pos = 0;

    std::string generated_text;
    llama_tokens generated_tokens;

    common_chat_msg chat_msg;

    std::vector<completion_token_output> generated_token_probs;

    bool has_next_token = true;
    bool has_new_line = false;
    bool truncated = false;

    stop_type stop;

    std::string stopping_word;

    // state
    slot_state state = SLOT_STATE_IDLE;

    server_prompt prompt;

    void prompt_save(server_prompt_cache & prompt_cache) const {
        GGML_ASSERT(prompt.data.size() == 0);

        const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);

        SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
                (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));

        auto * cur = prompt_cache.alloc(prompt, cur_size);
        if (cur == nullptr) {
            return;
        }

        llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
    }

    bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
        bool res = prompt_cache.load(prompt, tokens, ctx, id);
        if (!res) {
            SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
        }

        return res;
    }

    std::vector<common_adapter_lora_info> lora;
    int32_t alora_invocation_start = -1;

    // sampling
    json json_schema;

    struct common_sampler * smpl = nullptr;
    llama_sampler * backend_sampler = nullptr;

    llama_token sampled;

    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
    std::vector<std::string> generated_tool_call_ids;

    // stats
    size_t n_sent_text = 0; // number of sent text characters

    int64_t t_start_process_prompt;
    int64_t t_start_generation;

    double t_prompt_processing; // ms
    double t_token_generation; // ms

    std::function<void(int)> callback_on_release;

    // Speculative decoding stats
    int32_t n_draft_total = 0; // Total draft tokens generated
    int32_t n_draft_accepted = 0; // Draft tokens actually accepted

    void reset() {
        SLT_DBG(*this, "%s", "\n");

        n_prompt_tokens_cache = 0;

        last_nl_pos = 0;
        generated_text = "";
        has_new_line = false;
        truncated = false;
        stop = STOP_TYPE_NONE;
        stopping_word = "";
        n_sent_text = 0;
        chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;

        generated_tokens.clear();
        generated_token_probs.clear();
        chat_msg = {};
        json_schema = json();
        generated_tool_call_ids.clear();

        // clear speculative decoding stats
        n_draft_total = 0;
        n_draft_accepted = 0;

        if (backend_sampler != nullptr) {
            if (ctx != nullptr) {
                llama_set_backend_sampler(ctx, id, nullptr);
            }
            backend_sampler = nullptr;
        }

        task.reset();
        task_prev.reset();

        // clear alora start
        alora_invocation_start = -1;
    }

    bool need_embd() const {
        GGML_ASSERT(task);

        return server_task_type_need_embd(task->type);
    }

    bool need_logits() const {
        GGML_ASSERT(task);

        return server_task_type_need_logits(task->type);
    }

    // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
    // also we cannot split if the pooling would require any past tokens
    bool can_split() const {
        return
            !need_embd() ||
            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
    }

    bool can_batch_with(server_slot & other_slot) const {
        GGML_ASSERT(task);

        return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora);
    }

    bool has_budget(const common_params & global_params) {
        GGML_ASSERT(task);

        if (task->params.n_predict == -1 && global_params.n_predict == -1) {
            return true; // limitless
        }

        n_remaining = -1;

        if (task->params.n_predict != -1) {
            n_remaining = task->params.n_predict - n_decoded;
        } else if (global_params.n_predict != -1) {
            n_remaining = global_params.n_predict - n_decoded;
        }

        return n_remaining > 0; // no budget
    }

    bool is_processing() const {
        return state != SLOT_STATE_IDLE;
    }

    bool can_speculate() const {
        return ctx_dft;
    }

    void add_token(const completion_token_output & token) {
        if (!is_processing()) {
            SLT_WRN(*this, "%s", "slot is not processing\n");
            return;
        }
        generated_token_probs.push_back(token);
    }

    void release() {
        if (is_processing()) {
            GGML_ASSERT(task);

            SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);

            t_last_used = ggml_time_us();
            t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
            state = SLOT_STATE_IDLE;

            task_prev = std::move(task);
            task.reset();

            callback_on_release(id);
        }
    }

    result_timings get_timings() const {
        result_timings timings;
        timings.cache_n = n_prompt_tokens_cache;

        timings.prompt_n = n_prompt_tokens_processed;
        timings.prompt_ms = t_prompt_processing;
        timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
        timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;

        timings.predicted_n = n_decoded;
        timings.predicted_ms = t_token_generation;
        timings.predicted_per_token_ms = t_token_generation / n_decoded;
        timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;

        // Add speculative metrics
        if (n_draft_total > 0) {
            timings.draft_n = n_draft_total;
            timings.draft_n_accepted = n_draft_accepted;
        }

        return timings;
    }

    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
        GGML_ASSERT(task);

        auto previous_msg = chat_msg;
        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
        auto new_msg = common_chat_parse(
            generated_text,
            /* is_partial= */ stop != STOP_TYPE_EOS,
            task->params.oaicompat_chat_syntax);
        if (!new_msg.empty()) {
            new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
            chat_msg = new_msg;
            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
        }
        return chat_msg;
    }

    size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
        GGML_ASSERT(task);

        size_t stop_pos = std::string::npos;

        for (const std::string & word : task->params.antiprompt) {
            size_t pos;

            if (is_full_stop) {
                const size_t tmp = word.size() + last_token_size;
                const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;

                pos = text.find(word, from_pos);
            } else {
                // otherwise, partial stop
                pos = string_find_partial_stop(text, word);
            }

            if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
                if (is_full_stop) {
                    stop = STOP_TYPE_WORD;
                    stopping_word = word;
                    has_next_token = false;
                }
                stop_pos = pos;
            }
        }

        return stop_pos;
    }

    void print_timings() const {
        const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
        const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;

        const double t_gen = t_token_generation / n_decoded;
        const double n_gen_second = 1e3 / t_token_generation * n_decoded;

        SLT_INF(*this,
                "\n"
                "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
                "       eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
                "      total time = %10.2f ms / %5d tokens\n",
                t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
                t_token_generation, n_decoded, t_gen, n_gen_second,
                t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);

        if (n_draft_total > 0) {
            const float draft_ratio = (float) n_draft_accepted / n_draft_total;
            SLT_INF(*this,
                    "\n"
                    "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
                    draft_ratio, n_draft_accepted, n_draft_total
            );
        }
    }

    json to_json(bool only_metrics = false) const {
        json res;

        res = {
            {"id", id},
            {"n_ctx", n_ctx},
            {"speculative", can_speculate()},
            {"is_processing", is_processing()},
        };

        const auto & ptask = task ? task : task_prev;

        if (ptask) {
            res["id_task"] = ptask->id;
            res["params"] = ptask->params.to_json(only_metrics);
            res["next_token"] = {
                {
                    {"has_next_token", has_next_token},
                    {"has_new_line", has_new_line},
                    {"n_remain", n_remaining},
                    {"n_decoded", n_decoded},
                }
            };

            if (!only_metrics) {
                res["prompt"] = ptask->tokens.detokenize(ctx, true);
                res["generated"] = generated_text;
            }
        }

        return res;
    }
};

struct server_metrics {
    int64_t t_start = 0;

    uint64_t n_prompt_tokens_processed_total = 0;
    uint64_t t_prompt_processing_total = 0;
    uint64_t n_tokens_predicted_total = 0;
    uint64_t t_tokens_generation_total = 0;

    uint64_t n_tokens_max = 0;

    uint64_t n_prompt_tokens_processed = 0;
    uint64_t t_prompt_processing = 0;

    uint64_t n_tokens_predicted = 0;
    uint64_t t_tokens_generation = 0;

    uint64_t n_decode_total = 0;
    uint64_t n_busy_slots_total = 0;

    void init() {
        t_start = ggml_time_us();
    }

    void on_prompt_eval(const server_slot & slot) {
        n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
        n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
        t_prompt_processing += slot.t_prompt_processing;
        t_prompt_processing_total += slot.t_prompt_processing;

        n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
    }

    void on_prediction(const server_slot & slot) {
        n_tokens_predicted_total += slot.n_decoded;
        n_tokens_predicted += slot.n_decoded;
        t_tokens_generation += slot.t_token_generation;
        t_tokens_generation_total += slot.t_token_generation;
    }

    void on_decoded(const std::vector<server_slot> & slots) {
        n_decode_total++;
        for (const auto & slot : slots) {
            if (slot.is_processing()) {
                n_busy_slots_total++;
            }
            n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
        }
    }

    void reset_bucket() {
        n_prompt_tokens_processed = 0;
        t_prompt_processing = 0;
        n_tokens_predicted = 0;
        t_tokens_generation = 0;
    }
};

struct server_context {
    common_params params_base;

    // note: keep these alive - they determine the lifetime of the model, context, etc.
    common_init_result llama_init;
    common_init_result llama_init_dft;

    llama_model * model = nullptr;
    llama_context * ctx = nullptr;

    // multimodal
    mtmd_context * mctx = nullptr;

    const llama_vocab * vocab = nullptr;
    bool vocab_dft_compatible = true;

    llama_model * model_dft = nullptr;

    llama_context_params cparams_dft;

    llama_batch batch {};

    bool add_bos_token = true;

    int32_t n_ctx; // total context for all clients / slots

    // slots / clients
    std::vector<server_slot> slots;

    int slots_debug = 0;

    server_queue queue_tasks;
    server_response queue_results;

    std::unique_ptr<server_prompt_cache> prompt_cache;

    server_metrics metrics;

    // Necessary similarity of prompt for slot selection
    float slot_prompt_similarity = 0.0f;

    common_chat_templates_ptr chat_templates;
    oaicompat_parser_options oai_parser_opt;

    ~server_context() {
        mtmd_free(mctx);

        // Clear any sampling context
        for (server_slot & slot : slots) {
            common_sampler_free(slot.smpl);
            slot.smpl = nullptr;

            if (slot.backend_sampler != nullptr) {
                if (ctx != nullptr) {
                    llama_set_backend_sampler(ctx, slot.id, nullptr);
                }
                slot.backend_sampler = nullptr;
            }

            llama_free(slot.ctx_dft);
            slot.ctx_dft = nullptr;

            common_speculative_free(slot.spec);
            slot.spec = nullptr;

            llama_batch_free(slot.batch_spec);
        }

        llama_batch_free(batch);
    }

    // load the model and initialize llama_context
    bool load_model(const common_params & params) {
        SRV_INF("loading model '%s'\n", params.model.path.c_str());

        params_base = params;

        llama_init = common_init_from_params(params_base);

        model = llama_init.model.get();
        ctx = llama_init.context.get();

        if (model == nullptr) {
            SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
            return false;
        }

        vocab = llama_model_get_vocab(model);

        n_ctx = llama_n_ctx(ctx);

        add_bos_token = llama_vocab_get_add_bos(vocab);

        if (params_base.has_speculative()) {
            SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());

            auto params_dft = params_base;

            params_dft.devices = params_base.speculative.devices;
            params_dft.model = params_base.speculative.model;
            params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
            params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
            params_dft.n_parallel = 1;
            params_dft.cache_type_k = params_base.speculative.cache_type_k;
            params_dft.cache_type_v = params_base.speculative.cache_type_v;

            params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads;
            params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads;
            params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;

            llama_init_dft = common_init_from_params(params_dft);

            model_dft = llama_init_dft.model.get();

            if (model_dft == nullptr) {
                SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
                return false;
            }

            vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get());
            if (!vocab_dft_compatible) {
                SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
            }

            const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());

            cparams_dft = common_context_params_to_llama(params_dft);
            cparams_dft.n_batch = n_ctx_dft;

            // the context is not needed - we will create one for each slot
            llama_init_dft.context.reset();
        }

        chat_templates = common_chat_templates_init(model, params_base.chat_template);
        try {
            common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs);
        } catch (const std::exception & e) {
            SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what());
            SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
            chat_templates = common_chat_templates_init(model, "chatml");
        }

        std::string & mmproj_path = params_base.mmproj.path;
        if (!mmproj_path.empty()) {
            mtmd_helper_log_set(common_log_default_callback, nullptr);

            mtmd_context_params mparams = mtmd_context_params_default();
            mparams.use_gpu = params_base.mmproj_use_gpu;
            mparams.print_timings = false;
            mparams.n_threads = params_base.cpuparams.n_threads;
            mparams.flash_attn_type = params_base.flash_attn_type;
            mparams.image_min_tokens = params_base.image_min_tokens;
            mparams.image_max_tokens = params_base.image_max_tokens;
            mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
            if (mctx == nullptr) {
                SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
                return false;
            }
            SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str());

            if (params_base.ctx_shift) {
                params_base.ctx_shift = false;
                SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
            }

            if (params_base.n_cache_reuse) {
                params_base.n_cache_reuse = 0;
                SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
            }

            if (params_base.has_speculative()) {
                SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
                return false;
            }
        }

        if (!llama_memory_can_shift(llama_get_memory(ctx))) {
            if (params_base.ctx_shift) {
                params_base.ctx_shift = false;
                SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled");
            }

            if (params_base.n_cache_reuse) {
                params_base.n_cache_reuse = 0;
                SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
            }
        }

        return true;
    }

    // initialize slots and server-related data
    void init() {
        SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);

        const int n_ctx_train = llama_model_n_ctx_train(model);

        int n_ctx_slot = llama_n_ctx_seq(ctx);
        if (n_ctx_slot > n_ctx_train) {
            SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
            n_ctx_slot = n_ctx_train;
        }

        for (int i = 0; i < params_base.n_parallel; i++) {
            server_slot slot;

            slot.id = i;
            slot.ctx = ctx;
            slot.n_ctx = n_ctx_slot;
            slot.mctx = mctx;
            slot.prompt.tokens.has_mtmd = mctx != nullptr;

            if (model_dft) {
                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);

                // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
                slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                if (slot.ctx_dft == nullptr) {
                    SRV_ERR("%s", "failed to create draft context\n");
                    return;
                }

                slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
                if (slot.spec == nullptr) {
                    SRV_ERR("%s", "failed to create speculator\n");
                    return;
                }
                for (auto & pair : params_base.speculative.replacements) {
                    common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
                }
            }

            SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);

            slot.callback_on_release = [this](int) {
                queue_tasks.pop_deferred_task();
            };

            slot.reset();

            slots.push_back(std::move(slot));
        }

        {
            const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
            slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;

            if (slots_debug) {
                SRV_WRN("slots debug = %d\n", slots_debug);
            }
        }

        // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
        // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
        {
            const int32_t n_batch = llama_n_batch(ctx);
            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
        }

        metrics.init();

        if (params_base.cache_ram_mib != 0) {
            if (params_base.cache_ram_mib < 0) {
                SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit");
            } else {
                SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
            }
            SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n");

            prompt_cache = std::make_unique<server_prompt_cache>(params_base.cache_ram_mib, n_ctx);
        } else {
            SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
        }
        SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");

        // thinking is enabled if:
        // 1. It's not explicitly disabled (reasoning_budget == 0)
        // 2. The chat template supports it
        const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
        SRV_INF("thinking = %d\n", enable_thinking);

        oai_parser_opt = {
            /* use_jinja             */ params_base.use_jinja,
            /* prefill_assistant     */ params_base.prefill_assistant,
            /* reasoning_format      */ params_base.reasoning_format,
            /* chat_template_kwargs  */ params_base.default_template_kwargs,
            /* common_chat_templates */ chat_templates.get(),
            /* allow_image           */ mctx ? mtmd_support_vision(mctx) : false,
            /* allow_audio           */ mctx ? mtmd_support_audio (mctx) : false,
            /* enable_thinking       */ enable_thinking,
        };

        // print sample chat example to make it clear which template is used
        LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
            common_chat_templates_source(chat_templates.get()),
            common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
    }

    server_slot * get_slot_by_id(int id) {
        for (server_slot & slot : slots) {
            if (slot.id == id) {
                return &slot;
            }
        }

        return nullptr;
    }

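    // select a slot for the incoming task: prefer the idle slot whose cached
    // prompt shares the longest common prefix with the task prompt (above the
    // similarity threshold), otherwise fall back to the least recently used
    // idle slot; may also refresh the prompt cache when a large portion of a
    // slot's existing context is about to be discarded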
    server_slot * get_available_slot(const server_task & task) {
        server_slot * ret = nullptr;

        bool update_cache = false;

        // find the slot that has at least n% prompt similarity
        if (ret == nullptr && slot_prompt_similarity != 0.0f) {
            float sim_best = 0;

            for (server_slot & slot : slots) {
                // skip the slot if it is not available
                if (slot.is_processing()) {
                    continue;
                }

                const auto & tokens = slot.prompt.tokens;

                // skip the slot if it does not contain cached tokens
                if (tokens.empty()) {
                    continue;
                }

                // fraction of the Longest Common Prefix length with respect to the input prompt length
                const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size();

                // select the current slot if the criteria match
                if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
                    sim_best = sim_cur;

                    ret = &slot;
                }
            }

            if (ret != nullptr) {
                const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size();

                SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n",
                        sim_best, slot_prompt_similarity, f_keep);

                // if we are about to lose a large portion of the existing context - save it in the prompt cache
                if (f_keep < 0.5f) {
                    update_cache = true;
                }
            }
        }

        // find the slot that has been least recently used
        if (ret == nullptr) {
            int64_t t_last = -1;

            for (server_slot & slot : slots) {
                // skip the slot if it is not available
                if (slot.is_processing()) {
                    continue;
                }

                // select the current slot if the criteria match
                if (!ret || slot.t_last_used <= t_last) {
                    t_last = slot.t_last_used;
                    ret = &slot;
                }
            }

            if (ret != nullptr) {
                SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last);

                update_cache = true;
            }
        }

        if (ret) {
            const auto & tokens = ret->prompt.tokens;

            update_cache = update_cache && prompt_cache;

            // cache prompts only for completion tasks
            update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;

            // don't update the cache if the slot's context is empty
            update_cache = update_cache && tokens.size() > 0;

            // TODO: mtmd does not support prompt cache
            update_cache = update_cache && (ret->mctx == nullptr);

            if (update_cache) {
                SRV_WRN("%s", "updating prompt cache\n");

                const int64_t t_start = ggml_time_us();

                ret->prompt_save(*prompt_cache);

                if (!ret->prompt_load(*prompt_cache, task.tokens)) {
                    clear_slot(*ret);
                }

                prompt_cache->update();

                SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
            }
        }

        return ret;
    }

    void clear_slot(server_slot & slot) const {
        GGML_ASSERT(!slot.is_processing());

        SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());

        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
        slot.prompt.tokens.clear();
    }

    // return true if at least one slot has been cleared
    // TODO: improve logic
    //       - smarter decision which slot to clear (LRU or longest prompt?)
    //       - move slot to level 2 cache instead of removing?
    //       - instead of purging, try to store and resume later?
    bool try_clear_idle_slots() {
        bool res = false;

        if (!params_base.kv_unified) {
            return res;
        }

        for (auto & slot : slots) {
            if (slot.is_processing()) {
                continue;
            }

            if (slot.prompt.n_tokens() > 0) {
                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());

                clear_slot(slot);

                res = true;

                // clear slots one by one
                break;
            }
        }

        return res;
    }

    bool launch_slot_with_task(server_slot & slot, server_task && task) {
        slot.reset();

        if (!are_lora_equal(task.params.lora, slot.lora)) {
            // if lora has changed, check to see if the cache should be cleared
            if (lora_should_clear_cache(slot.lora, task.params.lora)) {
                SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size());
                slot.prompt.tokens.clear();
            } else {
                SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size());
            }
            slot.lora = task.params.lora;
        }

        // if using alora, make sure it's only a single one requested and active
        size_t alora_invocation_start = task.tokens.size();
        if (lora_all_alora(slot.lora)) {
            const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
            // TODO: This will error out if a user requests two aloras, but only
            //   provides the activation string for one. We could instead search
            //   for all requested alora activation strings and then either keep
            //   only the last one, or reject if multiple are found.
            if (enabled_ids.size() != 1) {
                send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST);
                return false;
            }
            const auto & lora = slot.lora[enabled_ids[0]].ptr;

            // get the pointer and count for the invocation tokens
            const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora);
            const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora);

            // scan backwards through the prompt tokens to find the last
            // occurrence of the invocation sequence
            int match_idx = static_cast<int>(n_invocation_tokens) - 1;
            for (int i = task.tokens.size() - 1; i >= 0; --i) {
                // the token in this position matches the next token to find in
                // the invocation sequence
                if (task.tokens[i] == invocation_tokens[match_idx]) {
                    // if it's a full match, we've found the start
                    if (match_idx == 0) {
                        alora_invocation_start = i;
                        break;
                    }
                    // otherwise, check the next token in the sequence
                    --match_idx;
                } else {
                    // no match in this position, so start looking over again
                    match_idx = static_cast<int>(n_invocation_tokens) - 1;
                }
            }

            // if the activation string is not found, disable the alora
            if (alora_invocation_start == task.tokens.size()) {
                SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
                slot.lora[enabled_ids[0]].scale = 0.0f;
            } else {
                SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start);
                slot.alora_invocation_start = alora_invocation_start;
            }
        }

        if (!task.tokens.validate(ctx)) {
            send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
            return false;
        }

        SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());

        // initialize samplers
        {
            if (slot.smpl != nullptr) {
                common_sampler_free(slot.smpl);
            }

            slot.smpl = common_sampler_init(model, task.params.sampling);
            if (slot.smpl == nullptr) {
                // for now, the only error that may happen here is invalid grammar
                send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                return false;
            }

            SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str());
        }

        if (!configure_slot_backend_sampler(slot, task.params.sampling)) {
            send_error(task, "Failed to configure backend samplers", ERROR_TYPE_SERVER);
            return false;
        }

        // initialize draft batch
        // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
        if (slot.ctx_dft) {
            llama_batch_free(slot.batch_spec);

            slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
        }

        slot.task = std::make_unique<const server_task>(std::move(task));

        slot.state = SLOT_STATE_STARTED;

        SLT_INF(slot, "%s", "processing task\n");

        return true;
    }

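    // set up (or clear) the per-slot backend sampler chain; if the configured
    // sampler chain does not start with backend-supported samplers, backend
    // sampling is skipped for this slot and sampling stays entirely on the CPU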
    bool configure_slot_backend_sampler(server_slot & slot, const common_params_sampling & sampling) {
        if (!sampling.backend_sampling) {
            if (slot.backend_sampler != nullptr) {
                llama_set_backend_sampler(ctx, slot.id, nullptr);
                slot.backend_sampler = nullptr;
            }
            return true;
        }

        llama_sampler * backend_chain = common_sampler_backend_init(model, sampling);
        // The sampler types configured with --samplers might not be supported
        // by backend samplers, in which case we disable backend sampling and
        // fall back to CPU-only sampling.
        if (backend_chain == nullptr) {
            if (slot.backend_sampler != nullptr) {
                llama_set_backend_sampler(ctx, slot.id, nullptr);
                slot.backend_sampler = nullptr;
            }
            SLT_INF(slot, "%s", "no backend samplers configured (sampler chain doesn't start with backend-supported samplers)\n");
            return true;
        }

        if (slot.backend_sampler != nullptr) {
            llama_set_backend_sampler(ctx, slot.id, nullptr);
            slot.backend_sampler = nullptr;
        }

        slot.backend_sampler = backend_chain;
        llama_set_backend_sampler(ctx, slot.id, backend_chain);
        SLT_INF(slot, "%s", "configured backend samplers\n");
        return true;
    }

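    // handle one newly sampled token: append it to the generated text, apply
    // stop-string/EOS/limit checks, and stream partial results; returns true
    // while generation should continue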
    bool process_token(completion_token_output & result, server_slot & slot) {
        // remember which tokens were sampled - used for repetition penalties during sampling
        const std::string token_str = result.text_to_send;
        slot.sampled = result.tok;

        slot.generated_text += token_str;
        if (slot.task->params.return_tokens) {
            slot.generated_tokens.push_back(result.tok);
        }
        slot.has_next_token = true;

        // check if there is an incomplete UTF-8 character at the end
        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

        // search stop word and delete it
        if (!incomplete) {
            size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

            const std::string str_test = slot.generated_text.substr(pos);
            bool send_text = true;

            size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
            if (stop_pos != std::string::npos) {
                slot.generated_text.erase(
                    slot.generated_text.begin() + pos + stop_pos,
                    slot.generated_text.end());
                pos = std::min(slot.n_sent_text, slot.generated_text.size());
            } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok)) {
                stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
                send_text = stop_pos == std::string::npos;
            }

            // check if there is any token to predict
            if (send_text) {
                // do not send the stop word in the response
                result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
                slot.n_sent_text += result.text_to_send.size();
                // add the token to slot queue and cache
            } else {
                result.text_to_send = "";
            }

            slot.add_token(result);
            if (slot.task->params.stream) {
                send_partial_response(slot, result, false);
            }
        }

        if (incomplete) {
            slot.has_next_token = true;
        }

        // if context shifting is disabled, make sure that we don't run out of context
        if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
            slot.truncated = true;
            slot.stop = STOP_TYPE_LIMIT;
            slot.has_next_token = false;

            SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
                    slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
        }

        // check the limits
        if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
            slot.stop = STOP_TYPE_LIMIT;
            slot.has_next_token = false;

            SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict);
        }

        if (slot.has_new_line) {
            // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
            if (slot.task->params.n_indent > 0) {
                // check the current indentation
                // TODO: improve by not doing it more than once for each new line
                if (slot.last_nl_pos > 0) {
                    size_t pos = slot.last_nl_pos;

                    int n_indent = 0;
                    while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
                        n_indent++;
                        pos++;
                    }

                    if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) {
                        slot.stop = STOP_TYPE_LIMIT;
                        slot.has_next_token = false;

                        // cut the last line
                        slot.generated_text.erase(pos, std::string::npos);

                        SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
                    }
                }

                // find the next new line
                {
                    const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);

                    if (pos != std::string::npos) {
                        slot.last_nl_pos = pos + 1;
                    }
                }
            }
        }

        // check if there is a new line in the generated text
        if (result.text_to_send.find('\n') != std::string::npos) {
            slot.has_new_line = true;

            // if we have seen a new line, we stop after a certain time limit, but only upon another new line
            if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) {
                slot.stop = STOP_TYPE_LIMIT;
                slot.has_next_token = false;

                SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms);
            }
        }

        if (llama_vocab_is_eog(vocab, result.tok)) {
            slot.stop = STOP_TYPE_EOS;
            slot.has_next_token = false;

            SLT_DBG(slot, "%s", "stopped by EOS\n");
        }

        SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());

        return slot.has_next_token; // continue
    }

    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
        size_t n_probs = slot.task->params.sampling.n_probs;
        size_t n_vocab = llama_vocab_n_tokens(vocab);

        if (post_sampling) {
            const auto * cur_p = common_sampler_get_candidates(slot.smpl, true);
            const size_t max_probs = cur_p->size;

            // set probability for sampled token
            for (size_t i = 0; i < max_probs; i++) {
                if (cur_p->data[i].id == result.tok) {
                    result.prob = cur_p->data[i].p;
                    break;
                }
            }

            // set probability for top n_probs tokens
            result.probs.reserve(max_probs);
            for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
                result.probs.push_back({
                    cur_p->data[i].id,
                    common_token_to_piece(ctx, cur_p->data[i].id, special),
                    cur_p->data[i].p
                });
            }
        } else {
            // TODO: optimize this with min-p optimization
            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);

            // set probability for sampled token
            for (size_t i = 0; i < n_vocab; i++) {
                if (cur[i].id == result.tok) {
                    result.prob = cur[i].p;
                    break;
                }
            }

            // set probability for top n_probs tokens
            result.probs.reserve(n_probs);
            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
                result.probs.push_back({
                    cur[i].id,
                    common_token_to_piece(ctx, cur[i].id, special),
                    cur[i].p
                });
            }
        }
    }

    void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
        send_error(task.id, error, type);
    }

    void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
        send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
    }

    void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
        SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());

        if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
            GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0);
        }

        auto res = std::make_unique<server_task_result_error>();
        res->id = id_task;
        res->err_type = type;
        res->err_msg = error;
        res->n_prompt_tokens = n_prompt_tokens;
        res->n_ctx = n_ctx;

        queue_results.send(std::move(res));
    }

    // if multimodal is enabled, send an error and return false
    bool check_no_mtmd(const int id_task) {
        if (mctx) {
            send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
            return false;
        }
        return true;
    }

    void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
        auto res = std::make_unique<server_task_result_cmpl_partial>();

        res->id = slot.task->id;
        res->index = slot.task->index;

        if (is_progress) {
            res->is_progress = true;
            res->progress.total = slot.task->n_tokens();
            res->progress.cache = slot.n_prompt_tokens_cache;
            res->progress.processed = slot.prompt.tokens.size();
            res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000;
        } else {
            res->content = tkn.text_to_send;
            res->tokens = { tkn.tok };

            slot.update_chat_msg(res->oaicompat_msg_diffs);
        }

        res->n_decoded = slot.n_decoded;
        res->n_prompt_tokens = slot.task->n_tokens();
        res->post_sampling_probs = slot.task->params.post_sampling_probs;

        res->verbose = slot.task->params.verbose;
        res->oaicompat = slot.task->params.oaicompat;
        res->oaicompat_model = slot.task->params.oaicompat_model;
        res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;

        // populate res.probs_output
        if (slot.task->params.sampling.n_probs > 0) {
            res->prob_output = tkn; // copy the token probs
        }

        // populate timings if this is final response or timings_per_token is enabled
        if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) {
            res->timings = slot.get_timings();
        }

        queue_results.send(std::move(res));
    }

    void send_final_response(server_slot & slot) {
        auto res = std::make_unique<server_task_result_cmpl_final>();

        res->id = slot.task->id;
        res->id_slot = slot.id;

        res->index = slot.task->index;
        res->content = slot.generated_text;
        res->tokens = std::move(slot.generated_tokens);
        res->timings = slot.get_timings();
        res->prompt = slot.task->tokens.detokenize(ctx, true);
        res->response_fields = std::move(slot.task->params.response_fields);

        res->truncated = slot.truncated;
        res->n_decoded = slot.n_decoded;
        res->n_prompt_tokens = slot.task->n_tokens();
        res->n_tokens_cached = slot.prompt.n_tokens();
        res->has_new_line = slot.has_new_line;
        res->stopping_word = slot.stopping_word;
        res->stop = slot.stop;
        res->post_sampling_probs = slot.task->params.post_sampling_probs;

        res->verbose = slot.task->params.verbose;
        res->stream = slot.task->params.stream;
        res->include_usage = slot.task->params.include_usage;
        res->oaicompat = slot.task->params.oaicompat;
        res->oaicompat_model = slot.task->params.oaicompat_model;
        res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
        res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);

        // populate res.probs_output
        if (slot.task->params.sampling.n_probs > 0) {
            if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) {
                const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);

                size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
                res->probs_output = std::vector<completion_token_output>(
                    slot.generated_token_probs.begin(),
                    slot.generated_token_probs.end() - safe_offset);
            } else {
                res->probs_output = std::vector<completion_token_output>(
                    slot.generated_token_probs.begin(),
                    slot.generated_token_probs.end());
            }
        }

        res->generation_params = slot.task->params; // copy the parameters

        queue_results.send(std::move(res));
    }

    void send_embedding(const server_slot & slot, const llama_batch & batch) {
        auto res = std::make_unique<server_task_result_embd>();
        res->id = slot.task->id;
        res->index = slot.task->index;
        res->n_tokens = slot.task->n_tokens();
        res->oaicompat = slot.task->params.oaicompat;

        const int n_embd = llama_model_n_embd(model);

        std::vector<float> embd_res(n_embd, 0.0f);

        for (int i = 0; i < batch.n_tokens; ++i) {
            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
                continue;
            }

            const float * embd = nullptr;
            if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
                embd = llama_get_embeddings_ith(ctx, i);
            } else {
                embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
            }

            if (embd == nullptr) {
                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                continue;
            }

            // normalize only when there is pooling
            if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
                common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
                res->embedding.push_back(embd_res);
                break;
            }

            res->embedding.emplace_back(embd, embd + n_embd);
        }

        SLT_DBG(slot, "%s", "sending embeddings\n");

        queue_results.send(std::move(res));
    }

    void send_rerank(const server_slot & slot, const llama_batch & batch) {
        auto res = std::make_unique<server_task_result_rerank>();
        res->id = slot.task->id;
        res->index = slot.task->index;
        res->n_tokens = slot.task->n_tokens();

        for (int i = 0; i < batch.n_tokens; ++i) {
            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
                continue;
            }

            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
            if (embd == NULL) {
                embd = llama_get_embeddings_ith(ctx, i);
            }

            if (embd == NULL) {
                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

                res->score = -1e6;
                continue;
            }

            res->score = embd[0];
        }

        SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score);

        queue_results.send(std::move(res));
    }

    //
    // Functions to process the task
    //

    void process_single_task(server_task && task) {
        switch (task.type) {
            case SERVER_TASK_TYPE_COMPLETION:
            case SERVER_TASK_TYPE_INFILL:
            case SERVER_TASK_TYPE_EMBEDDING:
            case SERVER_TASK_TYPE_RERANK:
                {
                    const int id_slot = task.id_slot;

                    server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);

                    if (slot == nullptr) {
                        // if no slot is available, we defer this task for processing later
                        SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
                        queue_tasks.defer(std::move(task));
                        break;
                    }

                    if (slot->is_processing()) {
                        // if requested slot is unavailable, we defer this task for processing later
                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
                        queue_tasks.defer(std::move(task));
                        break;
                    }

                    if (!launch_slot_with_task(*slot, std::move(task))) {
                        SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
                        break;
                    }
                } break;
            case SERVER_TASK_TYPE_CANCEL:
                {
                    // release slot linked with the task id
                    for (auto & slot : slots) {
                        if (slot.task && slot.task->id == task.id_target) {
                            slot.release();
                            break;
                        }
                    }
                } break;
            case SERVER_TASK_TYPE_NEXT_RESPONSE:
                {
                    // do nothing
                } break;
            case SERVER_TASK_TYPE_METRICS:
                {
                    json slots_data = json::array();

                    int n_idle_slots = 0;
                    int n_processing_slots = 0;

                    for (server_slot & slot : slots) {
                        json slot_data = slot.to_json(slots_debug == 0);

                        if (slot.is_processing()) {
                            n_processing_slots++;
                        } else {
                            n_idle_slots++;
                        }

                        slots_data.push_back(slot_data);
                    }
                    SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);

                    auto res = std::make_unique<server_task_result_metrics>();
                    res->id = task.id;
                    res->slots_data = std::move(slots_data);
                    res->n_idle_slots = n_idle_slots;
                    res->n_processing_slots = n_processing_slots;
                    res->n_tasks_deferred = queue_tasks.queue_tasks_deferred_size();
                    res->t_start = metrics.t_start;

                    res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
                    res->t_prompt_processing_total = metrics.t_prompt_processing_total;
                    res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
                    res->t_tokens_generation_total = metrics.t_tokens_generation_total;

                    res->n_tokens_max = metrics.n_tokens_max;

                    res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
                    res->t_prompt_processing = metrics.t_prompt_processing;
                    res->n_tokens_predicted = metrics.n_tokens_predicted;
                    res->t_tokens_generation = metrics.t_tokens_generation;

                    res->n_decode_total = metrics.n_decode_total;
                    res->n_busy_slots_total = metrics.n_busy_slots_total;

                    if (task.metrics_reset_bucket) {
                        metrics.reset_bucket();
                    }
                    queue_results.send(std::move(res));
                } break;
            case SERVER_TASK_TYPE_SLOT_SAVE:
                {
                    if (!check_no_mtmd(task.id)) {
                        break;
                    }

                    int id_slot = task.slot_action.slot_id;
                    server_slot * slot = get_slot_by_id(id_slot);
                    if (slot == nullptr) {
                        send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                        break;
                    }
                    if (slot->is_processing()) {
                        // if requested slot is unavailable, we defer this task for processing later
                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
                        queue_tasks.defer(std::move(task));
                        break;
                    }

                    const size_t token_count = slot->prompt.tokens.size();
                    const int64_t t_start = ggml_time_us();

                    std::string filename = task.slot_action.filename;
                    std::string filepath = task.slot_action.filepath;

                    const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens();
                    const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);

                    const int64_t t_end = ggml_time_us();
                    const double t_save_ms = (t_end - t_start) / 1000.0;

                    auto res = std::make_unique<server_task_result_slot_save_load>();
                    res->id = task.id;
                    res->id_slot = id_slot;
                    res->filename = filename;
                    res->is_save = true;
                    res->n_tokens = token_count;
                    res->n_bytes = nwrite;
                    res->t_ms = t_save_ms;
                    queue_results.send(std::move(res));
                } break;
            case SERVER_TASK_TYPE_SLOT_RESTORE:
                {
                    if (!check_no_mtmd(task.id)) break;
                    int id_slot = task.slot_action.slot_id;
                    server_slot * slot = get_slot_by_id(id_slot);
                    if (slot == nullptr) {
                        send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                        break;
                    }
                    if (slot->is_processing()) {
                        // if requested slot is unavailable, we defer this task for processing later
                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
                        queue_tasks.defer(std::move(task));
                        break;
                    }

                    const int64_t t_start = ggml_time_us();

                    std::string filename = task.slot_action.filename;
                    std::string filepath = task.slot_action.filepath;

                    llama_tokens tokens;
                    tokens.resize(slot->n_ctx);
                    size_t token_count = 0;
                    size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
                    if (nread == 0) {
                        slot->prompt.tokens.clear(); // KV may have already been invalidated?
                        send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
                        break;
                    }
                    tokens.resize(token_count);
                    slot->prompt.tokens.clear();
                    slot->prompt.tokens.insert(tokens);

                    const int64_t t_end = ggml_time_us();
                    const double t_restore_ms = (t_end - t_start) / 1000.0;

                    auto res = std::make_unique<server_task_result_slot_save_load>();
                    res->id = task.id;
                    res->id_slot = id_slot;
                    res->filename = filename;
                    res->is_save = false;
                    res->n_tokens = token_count;
                    res->n_bytes = nread;
                    res->t_ms = t_restore_ms;
                    queue_results.send(std::move(res));
                } break;
            case SERVER_TASK_TYPE_SLOT_ERASE:
                {
                    if (!check_no_mtmd(task.id)) {
                        break;
                    }
                    int id_slot = task.slot_action.slot_id;
                    server_slot * slot = get_slot_by_id(id_slot);
                    if (slot == nullptr) {
                        send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                        break;
                    }
                    if (slot->is_processing()) {
                        // if requested slot is unavailable, we defer this task for processing later
                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
                        queue_tasks.defer(std::move(task));
                        break;
                    }

                    // Erase token cache
                    const size_t n_erased = slot->prompt.tokens.size();

                    clear_slot(*slot);

                    auto res = std::make_unique<server_task_result_slot_erase>();
                    res->id = task.id;
                    res->id_slot = id_slot;
                    res->n_erased = n_erased;
                    queue_results.send(std::move(res));
                } break;
            case SERVER_TASK_TYPE_SET_LORA:
                {
                    params_base.lora_adapters = std::move(task.set_lora);
                    auto res = std::make_unique<server_task_result_apply_lora>();
                    res->id = task.id;
                    queue_results.send(std::move(res));
                } break;

        }
    }

void update_slots() {
|
|
// check if all slots are idle
|
|
{
|
|
bool all_idle = true;
|
|
|
|
for (auto & slot : slots) {
|
|
if (slot.is_processing()) {
|
|
all_idle = false;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (all_idle) {
|
|
SRV_INF("%s", "all slots are idle\n");
|
|
|
|
return;
|
|
}
|
|
}
|
|
|
|
{
|
|
SRV_DBG("%s", "posting NEXT_RESPONSE\n");
|
|
|
|
server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
|
|
task.id = queue_tasks.get_new_id();
|
|
queue_tasks.post(std::move(task));
|
|
}
|
|
|
|
// apply context-shift if needed
|
|
// TODO: simplify and improve
|
|
for (server_slot & slot : slots) {
|
|
if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
|
|
if (!params_base.ctx_shift) {
|
|
// this check is redundant (for good)
|
|
// we should never get here, because generation should already stopped in process_token()
|
|
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
|
|
slot.release();
|
|
continue;
|
|
}
|
|
|
|
if (mctx) {
|
|
// we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded
|
|
// we don't support ctx_shift because an image chunk may contains multiple tokens
|
|
GGML_ABORT("not supported by multimodal");
|
|
}
|
|
|
|
// Shift context
|
|
int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
|
|
|
|
if (add_bos_token) {
|
|
n_keep += 1;
|
|
}
|
|
|
|
n_keep = std::min(slot.n_ctx - 4, n_keep);
|
|
|
|
const int n_left = slot.prompt.n_tokens() - n_keep;
|
|
const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);
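                // illustrative numbers: with n_ctx = 512, n_keep = 128 and a full context,
                // n_left = 384 and n_discard = 192, i.e. the oldest half of the non-kept
                // tokens is dropped and the remainder is shifted down by n_discard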

                SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);

                llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep, n_keep + n_discard);
                llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);

                // add generated tokens to cache
                // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
                {
                    GGML_ASSERT(!slot.prompt.tokens.has_mtmd);

                    llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
                    for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                        new_tokens[i - n_discard] = new_tokens[i];
                    }

                    new_tokens.resize(slot.prompt.tokens.size() - n_discard);

                    slot.prompt.tokens.clear();
                    slot.prompt.tokens.insert(new_tokens);
                }

                slot.truncated = true;
            }
        }

        // start populating the batch for this iteration
        common_batch_clear(batch);

        // track if given slot can be batched with slots already in the batch
        server_slot * slot_batched = nullptr;

        auto accept_special_token = [&](server_slot & slot, llama_token token) {
            return params_base.special ||
                slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
        };

        // first, add sampled tokens from any ongoing sequences
        for (auto & slot : slots) {
            if (slot.state != SLOT_STATE_GENERATING) {
                continue;
            }

            // check if we can batch this slot with the previous one
            if (!slot_batched) {
                slot_batched = &slot;
            } else if (!slot_batched->can_batch_with(slot)) {
                continue;
            }

            slot.i_batch = batch.n_tokens;

            common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);

            slot.prompt.tokens.push_back(slot.sampled);

            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
                    slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
        }

        // process in chunks of params.n_batch
        int32_t n_batch = llama_n_batch(ctx);
        int32_t n_ubatch = llama_n_ubatch(ctx);
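        // note: n_batch bounds how many tokens can be submitted in a single llama_decode()
        //       call, while n_ubatch is the physical batch size that the backend
        //       processes at once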

        float alora_scale = -1.0f;
        size_t alora_disabled_id = 0;

        // next, batch any pending prompts without exceeding n_batch
        if (params_base.cont_batching || batch.n_tokens == 0) {
            for (auto & slot : slots) {
                if (!slot.is_processing()) {
                    continue;
                }

                // check if we can batch this slot with the previous one
                if (slot_batched && !slot_batched->can_batch_with(slot)) {
                    continue;
                }

                // this slot still has a prompt to be processed
                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
                    const auto & input_tokens = slot.task->tokens;

                    // TODO: maybe move branch to outside of this loop in the future
                    if (slot.state == SLOT_STATE_STARTED) {
                        slot.t_start_process_prompt = ggml_time_us();
                        slot.t_start_generation = 0;

                        slot.state = SLOT_STATE_PROCESSING_PROMPT;

                        SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
                                slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());

                        // print prompt tokens (for debugging)
                        /*if (1) {
                            // first 16 tokens (avoid flooding logs)
                            for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
                            }
                        } else {
                            // all
                            for (int i = 0; i < (int) input_tokens.size(); i++) {
                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
                            }
                        }*/

                        // keep track of how many tokens we can reuse from the previous state
                        int n_past = 0;

                        // empty prompt passed -> release the slot and send empty response
                        if (input_tokens.empty()) {
                            SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");

                            slot.print_timings();
                            send_final_response(slot);
                            slot.release();

                            continue;
                        }

                        // TODO: support memory-less logits computation
                        if (slot.need_logits() && !llama_get_memory(ctx)) {
                            send_error(slot, "the current context does not support logits computation. skipping", ERROR_TYPE_SERVER);
                            slot.release();
                            continue;
                        }

                        if (!slot.can_split()) {
                            if (slot.task->n_tokens() > n_ubatch) {
                                send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
                                slot.release();
                                continue;
                            }

                            if (slot.task->n_tokens() > slot.n_ctx) {
                                send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                                slot.release();
                                continue;
                            }
                        } else {
                            if (slot.task->n_tokens() >= slot.n_ctx) {
                                send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                                slot.release();
                                continue;
                            }

                            if (slot.task->params.cache_prompt) {
                                // reuse any previously computed tokens that are common with the new prompt
                                n_past = slot.prompt.tokens.get_common_prefix(input_tokens);

                                // if there is an alora invoked, don't cache after the invocation start
                                if (slot.alora_invocation_start > 0) {
                                    SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
                                    n_past = std::min(n_past, slot.alora_invocation_start - 1);
                                }

                                // reuse chunks from the cached prompt by shifting their KV cache to the new position
                                if (params_base.n_cache_reuse > 0) {
                                    GGML_ASSERT(!slot.prompt.tokens.has_mtmd);

                                    size_t head_c = n_past; // cache
                                    size_t head_p = n_past; // current prompt
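                                    // the two cursors walk the old cache (head_c) and the new
                                    // prompt (head_p) in parallel, re-binding sufficiently long
                                    // matching chunks to their new positions instead of
                                    // recomputing them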

                                    if (mctx) {
                                        // we should never reach this
                                        GGML_ABORT("not supported by multimodal");
                                    }

                                    SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);

                                    while (head_c < slot.prompt.tokens.size() &&
                                           head_p < input_tokens.size()) {

                                        size_t n_match = 0;
                                        while (head_c + n_match < slot.prompt.tokens.size() &&
                                               head_p + n_match < input_tokens.size() &&
                                               slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {

                                            n_match++;
                                        }

                                        if (n_match >= (size_t) params_base.n_cache_reuse) {
                                            SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                                            //for (size_t i = head_p; i < head_p + n_match; i++) {
                                            //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
                                            //}

                                            const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;

                                            llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c);
                                            llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift);

                                            for (size_t i = 0; i < n_match; i++) {
                                                slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
                                                n_past++;
                                            }

                                            head_c += n_match;
                                            head_p += n_match;
                                        } else {
                                            head_c += 1;
                                        }
                                    }

                                    SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
                                }
                            } else {
                                // if we don't cache the prompt, we have to remove all previous tokens
                                n_past = 0;
                            }

                            // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
                            const auto n_swa = std::max(1, llama_model_n_swa(model));

                            // the largest pos_min required for a checkpoint to be useful
                            const auto pos_min_thold = std::max(0, n_past - n_swa);
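                            // tokens with pos < n_past - n_swa may have already been evicted
                            // from the sliding window, so the cached prefix (or a checkpoint)
                            // is only usable if the memory still holds positions >= pos_min_thold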

                            // note: disallow with mtmd contexts for now
                            // https://github.com/ggml-org/llama.cpp/issues/17043
                            if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
                                const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                                if (pos_min == -1) {
                                    SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
                                    GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                                }

                                // when the prompt prefix does not match, print the tokens around the mismatch
                                // this is useful for debugging prompt caching
                                if (slots_debug) {
                                    const int np0 = std::max<int>(n_past - 4, 0);
                                    const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));

                                    std::stringstream ss0;
                                    std::stringstream ss1;

                                    std::stringstream st0;
                                    std::stringstream st1;

                                    ss0 << "old: ... ";
                                    ss1 << "new: ... ";

                                    for (int i = np0; i < np1; i++) {
                                        if (i == n_past) {
                                            ss0 << " | ";
                                            ss1 << " | ";
                                        }

                                        {
                                            const auto token = slot.prompt.tokens[i];
                                            const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
                                            ss0 << piece;
                                            st0 << std::setw(8) << token;
                                        }

                                        {
                                            const auto token = slot.task->tokens[i];
                                            const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
                                            ss1 << piece;
                                            st1 << std::setw(8) << token;
                                        }
                                    }

                                    SLT_WRN(slot, "%s\n", ss0.str().c_str());
                                    SLT_WRN(slot, "%s\n", ss1.str().c_str());

                                    SLT_WRN(slot, "%s\n", st0.str().c_str());
                                    SLT_WRN(slot, "%s\n", st1.str().c_str());
                                }

                                if (pos_min > pos_min_thold) {
                                    // TODO: support can be added in the future when corresponding vision models get released
                                    GGML_ASSERT(!slot.prompt.tokens.has_mtmd);

                                    SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);

                                    // search for a context checkpoint
                                    const auto it = std::find_if(
                                        slot.prompt.checkpoints.rbegin(),
                                        slot.prompt.checkpoints.rend(),
                                        [&](const auto & cur) {
                                            // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
                                            return cur.pos_min < pos_min_thold;
                                        }
                                    );
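                                    // note: iterating in reverse prefers the newest usable
                                    //       checkpoint, which minimizes the amount of prompt
                                    //       that has to be re-processed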

                                    bool do_reset = it == slot.prompt.checkpoints.rend();

                                    if (!do_reset) {
                                        // restore the context checkpoint
                                        const size_t checkpoint_size = it->data.size();
                                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

                                        if (n != checkpoint_size) {
                                            SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
                                            do_reset = true;
                                            //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                                        } else {
                                            n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
                                            SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
                                        }
                                    }

                                    if (do_reset) {
                                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                                "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                                        n_past = 0;
                                    }
                                }
                            }

                            {
                                // erase any checkpoints with pos_min > pos_min_thold
                                for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
                                    const auto & cur = *it;
                                    if (cur.pos_min > pos_min_thold) {
                                        SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
                                        it = slot.prompt.checkpoints.erase(it);
                                    } else {
                                        ++it;
                                    }
                                }
                            }
                        }

                        // [TAG_PROMPT_LOGITS]
                        if (n_past == slot.task->n_tokens() && n_past > 0) {
                            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
                            n_past--;
                            SLT_WRN(slot, "n_past was set to %d\n", n_past);
                        }

                        slot.n_prompt_tokens_cache = n_past;
                        slot.n_prompt_tokens_processed = 0;

                        slot.prompt.tokens.keep_first(n_past);
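                        // only the common prefix stays cached - everything after it is
                        // (re-)processed by the batching logic below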
                    }

                    if (!slot.can_split()) {
                        // cannot fit the prompt in the current batch - will try next iter
                        if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
                            continue;
                        }
                    }

                    // truncate any tokens that are beyond n_past for this slot
                    const llama_pos p0 = slot.prompt.tokens.pos_next();

                    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);

                    if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
                        SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);

                        clear_slot(slot);

                        // there is no common part left
                        slot.n_prompt_tokens_cache = 0;
                    }

                    // check if we should process the image
                    if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                        // process the image
                        size_t n_tokens_out = 0;
                        int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
                        if (res != 0) {
                            SLT_ERR(slot, "failed to process image, res = %d\n", res);
                            send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
                            slot.release();
                            continue;
                        }

                        slot.n_prompt_tokens_processed += n_tokens_out;

                        // add the image chunk to cache
                        {
                            const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
                            slot.prompt.tokens.push_back(chunk.get()); // copy
                        }
                    }

                    // If using an alora, there may be uncached tokens that come
                    // before the invocation sequence. When this happens, the
                    // tokens before the invocation sequence need to be
                    // processed without the adapter in a separate batch, then
                    // the adapter needs to be enabled for the remaining tokens.
                    if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
                        SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
                        const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
                        GGML_ASSERT(enabled_loras.size() == 1);
                        alora_scale = slot.lora[enabled_loras[0]].scale;
                        slot.lora[enabled_loras[0]].scale = 0.0f;
                        alora_disabled_id = enabled_loras[0];
                    }

                    bool do_checkpoint = params_base.n_ctx_checkpoints > 0;

                    // make checkpoints only for completion tasks
                    do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;

                    // make a checkpoint of the parts of the memory that cannot be rolled back.
                    // checkpoints are created only if:
                    // - the model uses SWA and we are not using `swa_full`
                    // - the model architecture is marked as recurrent or hybrid
                    //
                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
                    do_checkpoint = do_checkpoint && (
                        llama_model_is_recurrent(model) ||
                        llama_model_is_hybrid(model) ||
                        (llama_model_n_swa(model) > 0 && !params_base.swa_full)
                    );

                    // add prompt tokens for processing in the current batch
                    while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
                        // get next token to process
                        llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
                        if (cur_tok == LLAMA_TOKEN_NULL) {
                            break; // end of text chunk
                        }

                        // if this is an alora request with pre-invocation
                        // tokens that are not cached, we need to stop filling
                        // this batch at those pre-invocation tokens.
                        if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) {
                            SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
                            break;
                        }

                        // embedding requires all tokens in the batch to be output
                        common_batch_add(batch,
                                         cur_tok,
                                         slot.prompt.tokens.pos_next(),
                                         { slot.id },
                                         slot.need_embd());
                        slot.prompt.tokens.push_back(cur_tok);

                        slot.n_prompt_tokens_processed++;

                        // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
                        if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
                            break;
                        }
                    }

                    // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.prompt.tokens.str().c_str());

                    SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());

                    // entire prompt has been processed
                    if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
                        slot.state = SLOT_STATE_DONE_PROMPT;

                        GGML_ASSERT(batch.n_tokens > 0);

                        common_sampler_reset(slot.smpl);

                        // Process all prompt tokens through sampler system
                        for (int i = 0; i < slot.task->n_tokens(); ++i) {
                            llama_token id = input_tokens[i];
                            if (id != LLAMA_TOKEN_NULL) {
                                common_sampler_accept(slot.smpl, id, false);
                            }
                        }

                        // extract the logits only for the last token
                        batch.logits[batch.n_tokens - 1] = true;

                        slot.n_decoded = 0;
                        slot.i_batch = batch.n_tokens - 1;

                        SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);

                        const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                        const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);

                        // no need for empty or small checkpoints
                        do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);

                        // no need to create checkpoints that are too close together
                        do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);

                        if (do_checkpoint) {
                            while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
                                // make room for the new checkpoint, if needed
                                const auto & cur = slot.prompt.checkpoints.front();

                                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);

                                slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
                            }

                            const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

                            auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
                                /*.pos_min = */ pos_min,
                                /*.pos_max = */ pos_max,
                                /*.data    = */ std::vector<uint8_t>(checkpoint_size),
                            });

                            llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

                            SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
                                    (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                        }
                    }
                }

                if (!slot_batched) {
                    slot_batched = &slot;
                }

                if (batch.n_tokens >= n_batch) {
                    break;
                }
            }
        }

        if (batch.n_tokens == 0) {
            SRV_WRN("%s", "no tokens to decode\n");
            return;
        }

        SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);

        if (slot_batched) {
            // apply lora, only need to do it once per batch
            common_set_adapter_lora(ctx, slot_batched->lora);

            // if the lora is temporarily disabled for an alora, re-enable it
            // for next time
            if (alora_scale > 0.0f) {
                SRV_DBG("re-enabling alora with scale %f\n", alora_scale);
                slot_batched->lora[alora_disabled_id].scale = alora_scale;
            }

            llama_set_embeddings(ctx, slot_batched->need_embd());
        }

        int32_t i_next = 0;

        // process the created batch of tokens
        for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

            llama_batch batch_view = {
                n_tokens,
                batch.token    + i,
                nullptr,
                batch.pos      + i,
                batch.n_seq_id + i,
                batch.seq_id   + i,
                batch.logits   + i,
            };
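            // note: batch_view is a non-owning window into `batch` - the pointers are
            //       offset by i and n_tokens caps the view at the current chunk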

            const int ret = llama_decode(ctx, batch_view);

            metrics.on_decoded(slots);

            if (ret != 0) {
                {
                    std::string err;

                    if (n_batch == 1 && ret == 1) {
                        // TODO: try to terminate only the largest active slot/sequence and continue with the rest
                        // need to remove the tokens from the current batch too
                        err = "Context size has been exceeded.";
                    }

                    if (ret == -1) {
                        err = "Invalid input batch.";
                    }

                    if (ret < -1) {
                        // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
                        err = "Compute error.";
                    }

                    // TODO: handle ret == 2 (abort) when we start aborting

                    if (!err.empty()) {
                        SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);

                        for (auto & slot : slots) {
                            if (slot.is_processing()) {
                                send_error(slot, err);
                                slot.release();

                                // note: it's complicated to keep track of how much of the current batch has been
                                //       processed before the error occurred, so we simply clear the entire context
                                clear_slot(slot);
                            }
                        }

                        break;
                    }
                }

                // retry with half the batch size to try to find a free slot in the KV cache
                if (!try_clear_idle_slots()) {
                    n_batch /= 2;
                }

                SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);

                continue; // continue loop of n_batch
            }

            // move the head of the batch forward with the number of tokens we just processed
            i_next = i + n_tokens;

            // on successful decode, restore the original batch size
            n_batch = llama_n_batch(ctx);

            for (auto & slot : slots) {
                // optionally send prompt processing progress
                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
                    if (slot.task->params.stream && slot.task->params.return_progress) {
                        send_partial_response(slot, {}, true);
                    }
                }

                if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
                    continue; // continue loop of slots
                }

                if (slot.state == SLOT_STATE_DONE_PROMPT) {
                    if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) {
                        // prompt evaluated for embedding
                        send_embedding(slot, batch_view);
                        slot.release();
                        slot.i_batch = -1;
                        continue; // continue loop of slots
                    }

                    if (slot.task->type == SERVER_TASK_TYPE_RERANK) {
                        send_rerank(slot, batch_view);
                        slot.release();
                        slot.i_batch = -1;
                        continue; // continue loop of slots
                    }

                    // prompt evaluated for next-token prediction
                    slot.state = SLOT_STATE_GENERATING;
                } else if (slot.state != SLOT_STATE_GENERATING) {
                    continue; // continue loop of slots
                }

                const int tok_idx = slot.i_batch - i;

                llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);

                slot.i_batch = -1;

                common_sampler_accept(slot.smpl, id, true);

                slot.n_decoded += 1;

                const int64_t t_current = ggml_time_us();

                if (slot.n_decoded == 1) {
                    slot.t_start_generation = t_current;
                    slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
                    metrics.on_prompt_eval(slot);
                }

                slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;

                completion_token_output result;
                result.tok = id;
                result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
                result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs

                if (slot.task->params.sampling.n_probs > 0) {
                    populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
                }

                if (!process_token(result, slot)) {
                    // release slot because of stop condition
                    slot.print_timings();
                    send_final_response(slot);
                    metrics.on_prediction(slot);
                    slot.release();

                    continue;
                }
            }

            // do speculative decoding
            // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
            // perform the speculative drafting for all sequences at the same time in a single batch
            for (auto & slot : slots) {
                if (!slot.is_processing() || !slot.can_speculate()) {
                    continue;
                }

                if (slot.state != SLOT_STATE_GENERATING) {
                    continue;
                }

                if (mctx) {
                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
                    GGML_ABORT("not supported by multimodal");
                }

                // determine the max draft that fits the current slot state
                int n_draft_max = slot.task->params.speculative.n_max;

                // note: slot.prompt is not yet expanded with the `id` token sampled above
                // also, need to leave space for 1 extra token to allow context shifts
                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
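                // illustrative numbers: with n_ctx = 4096 and 4000 cached tokens, at most
                // 4096 - 4000 - 2 = 94 tokens can be drafted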

                if (slot.n_remaining > 0) {
                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
                }

                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);

                if (n_draft_max < slot.task->params.speculative.n_min) {
                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);

                    continue;
                }

                llama_token id = slot.sampled;

                struct common_speculative_params params_spec;
                params_spec.n_draft = n_draft_max;
                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
                params_spec.p_min = slot.task->params.speculative.p_min;

                const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);

                // ignore small drafts
                if (slot.task->params.speculative.n_min > (int) draft.size()) {
                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);

                    continue;
                }

                // keep track of total number of drafted tokens tested
                slot.n_draft_total += draft.size();

                // construct the speculation batch
                common_batch_clear(slot.batch_spec);
                common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);

                for (size_t i = 0; i < draft.size(); ++i) {
                    common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
                }

                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);

                llama_decode(ctx, slot.batch_spec);

                // the accepted tokens from the speculation
                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);

                slot.n_decoded += ids.size();

                // update how many tokens out of those tested were accepted
                slot.n_draft_accepted += ids.size() - 1;

                slot.prompt.tokens.push_back(id);
                slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});

                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
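                // the call above drops the rejected drafted positions from the memory;
                // the last accepted token is intentionally not inserted here - like in the
                // non-speculative path, it is expected to be decoded on the next iteration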

                for (size_t i = 0; i < ids.size(); ++i) {
                    completion_token_output result;

                    result.tok = ids[i];
                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
                    result.prob = 1.0f; // set later

                    // TODO: set result.probs

                    if (!process_token(result, slot)) {
                        slot.print_timings();
                        send_final_response(slot);
                        metrics.on_prediction(slot);
                        slot.release();

                        break;
                    }
                }

                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
            }
        }

        SRV_DBG("%s", "run slots completed\n");
    }

    json model_meta() const {
        return json {
            {"vocab_type", llama_vocab_type (vocab)},
            {"n_vocab", llama_vocab_n_tokens (vocab)},
            {"n_ctx_train", llama_model_n_ctx_train(model)},
            {"n_embd", llama_model_n_embd (model)},
            {"n_params", llama_model_n_params (model)},
            {"size", llama_model_size (model)},
        };
    }
};


// generator-like API for server responses, supports polling the connection state and aggregating results
struct server_response_reader {
    std::unordered_set<int> id_tasks;
    server_context & ctx_server;
    size_t received_count = 0;
    bool cancelled = false;

    server_response_reader(server_context & ctx_server) : ctx_server(ctx_server) {}
    ~server_response_reader() {
        stop();
    }

    void post_tasks(std::vector<server_task> && tasks) {
        id_tasks = server_task::get_list_id(tasks);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(std::move(tasks));
    }

    bool has_next() const {
        return !cancelled && received_count < id_tasks.size();
    }

    // returns nullptr if should_stop() is true before receiving a result
    // note: if an error is received, it will stop further processing and return the error result
    server_task_result_ptr next(const std::function<bool()> & should_stop) {
        while (true) {
            server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
            if (result == nullptr) {
                // timeout, check stop condition
                if (should_stop()) {
                    SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
                    return nullptr;
                }
            } else {
                if (result->is_error()) {
                    stop(); // cancel remaining tasks
                    SRV_DBG("%s", "received error result, stopping further processing\n");
                    return result;
                }
                if (result->is_stop()) {
                    received_count++;
                }
                return result;
            }
        }

        // should not reach here
    }

    struct batch_response {
        bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
        std::vector<server_task_result_ptr> results;
        server_task_result_ptr error; // nullptr if no error
    };

    batch_response wait_for_all(const std::function<bool()> & should_stop) {
        batch_response batch_res;
        batch_res.results.resize(id_tasks.size());
        while (has_next()) {
            auto res = next(should_stop);
            if (res == nullptr) {
                batch_res.is_terminated = true;
                return batch_res;
            }
            if (res->is_error()) {
                batch_res.error = std::move(res);
                return batch_res;
            }
            const size_t idx = res->get_index();
            GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
            GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
            batch_res.results[idx] = std::move(res);
        }
        return batch_res;
    }

    void stop() {
        ctx_server.queue_results.remove_waiting_task_ids(id_tasks);
        if (has_next() && !cancelled) {
            // if the tasks are not finished yet, cancel them
            cancelled = true;
            std::vector<server_task> cancel_tasks;
            cancel_tasks.reserve(id_tasks.size());
            for (const auto & id_task : id_tasks) {
                SRV_WRN("cancel task, id_task = %d\n", id_task);
                server_task task(SERVER_TASK_TYPE_CANCEL);
                task.id_target = id_task;
                ctx_server.queue_results.remove_waiting_task_id(id_task);
                cancel_tasks.push_back(std::move(task));
            }
            // push to the beginning of the queue, so it has the highest priority
            ctx_server.queue_tasks.post(std::move(cancel_tasks), true);
        } else {
            SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
        }
    }
};
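
// typical usage (illustrative sketch - see e.g. the rerank handler below):
//
//   server_response_reader rd(ctx_server);
//   rd.post_tasks(std::move(tasks));
//   auto all = rd.wait_for_all(should_stop);
//   if (!all.is_terminated && !all.error) { /* consume all.results */ }
//
// the destructor calls stop(), which cancels any tasks that are still in flight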

// generator-like API for HTTP response generation
struct server_res_generator : server_http_res {
    server_response_reader rd;
    server_res_generator(server_context & ctx_server_) : rd(ctx_server_) {}
    void ok(const json & response_data) {
        status = 200;
        data = safe_json_to_str(response_data);
    }
    void error(const json & error_data) {
        status = json_value(error_data, "code", 500);
        data = safe_json_to_str({{ "error", error_data }});
    }
};

struct server_routes {
    const common_params & params;
    server_context & ctx_server;
    server_http_context & ctx_http; // for reading is_ready
    server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http)
        : params(params), ctx_server(ctx_server), ctx_http(ctx_http) {}

public:
    // handlers using lambda function, so that they can capture `this` without `std::bind`

    server_http_context::handler_t get_health = [this](const server_http_req &) {
        // error and loading states are handled by middleware
        auto res = std::make_unique<server_res_generator>(ctx_server);
        res->ok({{"status", "ok"}});
        return res;
    };

    server_http_context::handler_t get_metrics = [this](const server_http_req &) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        if (!params.endpoint_metrics) {
            res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
            return res;
        }

        // request slots data using task queue
        // TODO: use server_response_reader
        int task_id = ctx_server.queue_tasks.get_new_id();
        {
            server_task task(SERVER_TASK_TYPE_METRICS);
            task.id = task_id;
            ctx_server.queue_results.add_waiting_task_id(task_id);
            ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
        }

        // get the result
        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
        ctx_server.queue_results.remove_waiting_task_id(task_id);

        if (result->is_error()) {
            res->error(result->to_json());
            return res;
        }

        // TODO: get rid of this dynamic_cast
        auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
        GGML_ASSERT(res_task != nullptr);

        // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
        json all_metrics_def = json {
            {"counter", {{
                {"name", "prompt_tokens_total"},
                {"help", "Number of prompt tokens processed."},
                {"value", (uint64_t) res_task->n_prompt_tokens_processed_total}
            }, {
                {"name", "prompt_seconds_total"},
                {"help", "Prompt process time"},
                {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3}
            }, {
                {"name", "tokens_predicted_total"},
                {"help", "Number of generation tokens processed."},
                {"value", (uint64_t) res_task->n_tokens_predicted_total}
            }, {
                {"name", "tokens_predicted_seconds_total"},
                {"help", "Predict process time"},
                {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3}
            }, {
                {"name", "n_decode_total"},
                {"help", "Total number of llama_decode() calls"},
                {"value", res_task->n_decode_total}
            }, {
                {"name", "n_tokens_max"},
                {"help", "Largest observed n_tokens."},
                {"value", res_task->n_tokens_max}
            }, {
                {"name", "n_busy_slots_per_decode"},
                {"help", "Average number of busy slots per llama_decode() call"},
                {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)}
            }}},
            {"gauge", {{
                {"name", "prompt_tokens_seconds"},
                {"help", "Average prompt throughput in tokens/s."},
                {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.}
            },{
                {"name", "predicted_tokens_seconds"},
                {"help", "Average generation throughput in tokens/s."},
                {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.}
            },{
                {"name", "requests_processing"},
                {"help", "Number of requests processing."},
                {"value", (uint64_t) res_task->n_processing_slots}
            },{
                {"name", "requests_deferred"},
                {"help", "Number of requests deferred."},
                {"value", (uint64_t) res_task->n_tasks_deferred}
            }}}
        };

        std::stringstream prometheus;
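        // each metric is rendered in the Prometheus text exposition format, e.g.
        // (values illustrative):
        //   # HELP llamacpp:prompt_tokens_total Number of prompt tokens processed.
        //   # TYPE llamacpp:prompt_tokens_total counter
        //   llamacpp:prompt_tokens_total 1024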

        for (const auto & el : all_metrics_def.items()) {
            const auto & type = el.key();
            const auto & metrics_def = el.value();

            for (const auto & metric_def : metrics_def) {
                const std::string name = metric_def.at("name");
                const std::string help = metric_def.at("help");

                auto value = json_value(metric_def, "value", 0.);
                prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
                           << "# TYPE llamacpp:" << name << " " << type << "\n"
                           << "llamacpp:" << name << " " << value << "\n";
            }
        }

        res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start);
        res->content_type = "text/plain; version=0.0.4";
        res->ok(prometheus.str());
        return res;
    };

    server_http_context::handler_t get_slots = [this](const server_http_req & req) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        if (!params.endpoint_slots) {
            res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
            return res;
        }

        // request slots data using task queue
        int task_id = ctx_server.queue_tasks.get_new_id();
        {
            server_task task(SERVER_TASK_TYPE_METRICS);
            task.id = task_id;
            ctx_server.queue_results.add_waiting_task_id(task_id);
            ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
        }

        // get the result
        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
        ctx_server.queue_results.remove_waiting_task_id(task_id);

        if (result->is_error()) {
            res->error(result->to_json());
            return res;
        }

        // TODO: get rid of this dynamic_cast
        auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
        GGML_ASSERT(res_task != nullptr);

        // optionally return "fail_on_no_slot" error
        if (!req.get_param("fail_on_no_slot").empty()) {
            if (res_task->n_idle_slots == 0) {
                res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
                return res;
            }
        }

        res->ok(res_task->slots_data);
        return res;
    };

    server_http_context::handler_t post_slots = [this](const server_http_req & req) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        if (params.slot_save_path.empty()) {
            res->error(format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
            return res;
        }

        std::string id_slot_str = req.get_param("id_slot");
        int id_slot;

        try {
            id_slot = std::stoi(id_slot_str);
        } catch (const std::exception &) {
            res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }

        std::string action = req.get_param("action");

        if (action == "save") {
            return handle_slots_save(req, id_slot);
        } else if (action == "restore") {
            return handle_slots_restore(req, id_slot);
        } else if (action == "erase") {
            return handle_slots_erase(req, id_slot);
        } else {
            res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
    };

    server_http_context::handler_t get_props = [this](const server_http_req &) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        json default_generation_settings_for_props;

        {
            task_params params;

            params.sampling = ctx_server.params_base.sampling;

            default_generation_settings_for_props = json {
                {"params", params.to_json(true)},
                {"n_ctx", ctx_server.slots[0].n_ctx},
            };
        }

        // this endpoint is publicly available, please only return what is safe to be exposed
        json data = {
            { "default_generation_settings", default_generation_settings_for_props },
            { "total_slots", ctx_server.params_base.n_parallel },
            { "model_alias", ctx_server.params_base.model_alias },
            { "model_path", ctx_server.params_base.model.path },
            { "modalities", json {
                {"vision", ctx_server.oai_parser_opt.allow_image},
                {"audio", ctx_server.oai_parser_opt.allow_audio},
            } },
            { "endpoint_slots", params.endpoint_slots },
            { "endpoint_props", params.endpoint_props },
            { "endpoint_metrics", params.endpoint_metrics },
            { "webui", params.webui },
            { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
            { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
            { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
            { "build_info", build_info },
        };
        if (ctx_server.params_base.use_jinja) {
            if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
                data["chat_template_tool_use"] = tool_use_src;
            }
        }

        res->ok(data);
        return res;
    };

    server_http_context::handler_t post_props = [this](const server_http_req &) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        if (!params.endpoint_props) {
            res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
            return res;
        }
        // update any props here

        res->ok({{ "success", true }});
        return res;
    };

    server_http_context::handler_t get_api_show = [this](const server_http_req &) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        bool has_mtmd = ctx_server.mctx != nullptr;
        json data = {
            {
                "template", common_chat_templates_source(ctx_server.chat_templates.get()),
            },
            {
                "model_info", {
                    { "llama.context_length", ctx_server.slots.back().n_ctx, },
                }
            },
            {"modelfile", ""},
            {"parameters", ""},
            {"template", common_chat_templates_source(ctx_server.chat_templates.get())},
            {"details", {
                {"parent_model", ""},
                {"format", "gguf"},
                {"family", ""},
                {"families", {""}},
                {"parameter_size", ""},
                {"quantization_level", ""}
            }},
            {"model_info", ""},
            {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}
        };

        res->ok(data);
        return res;
    };

    server_http_context::handler_t post_infill = [this](const server_http_req & req) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        // check model compatibility
        std::string err;
        if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
            err += "prefix token is missing. ";
        }
        if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
            err += "suffix token is missing. ";
        }
        if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
            err += "middle token is missing. ";
        }
        if (!err.empty()) {
            res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
            return res;
        }

        // validate input
        json data = json::parse(req.body);
        if (data.contains("prompt") && !data.at("prompt").is_string()) {
            // prompt is optional
            res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
        }

        if (!data.contains("input_prefix")) {
            res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
        }

        if (!data.contains("input_suffix")) {
            res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
        }

        if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
            // input_extra is optional
            res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }

        json input_extra = json_value(data, "input_extra", json::array());
        for (const auto & chunk : input_extra) {
            // { "text": string, "filename": string }
            if (!chunk.contains("text") || !chunk.at("text").is_string()) {
                res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
                return res;
            }
            // filename is optional
            if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
                res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
                return res;
            }
        }
        data["input_extra"] = input_extra; // default to an empty array if it does not exist

        std::string prompt = json_value(data, "prompt", std::string());
        std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
        SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
        data["prompt"] = format_prompt_infill(
            ctx_server.vocab,
            data.at("input_prefix"),
            data.at("input_suffix"),
            data.at("input_extra"),
            ctx_server.params_base.n_batch,
            ctx_server.params_base.n_predict,
            ctx_server.slots[0].n_ctx, // TODO: there should be a better way
            ctx_server.params_base.spm_infill,
            tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
        );

        std::vector<raw_buffer> files; // dummy
        return handle_completions_impl(
            SERVER_TASK_TYPE_INFILL,
            data,
            files,
            req.should_stop,
            OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
    };

    server_http_context::handler_t post_completions = [this](const server_http_req & req) {
        std::vector<raw_buffer> files; // dummy
        const json body = json::parse(req.body);
        return handle_completions_impl(
            SERVER_TASK_TYPE_COMPLETION,
            body,
            files,
            req.should_stop,
            OAICOMPAT_TYPE_NONE);
    };

    server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) {
        std::vector<raw_buffer> files; // dummy
        const json body = json::parse(req.body);
        return handle_completions_impl(
            SERVER_TASK_TYPE_COMPLETION,
            body,
            files,
            req.should_stop,
            OAICOMPAT_TYPE_COMPLETION);
    };

    server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) {
        std::vector<raw_buffer> files;
        json body = json::parse(req.body);
        json body_parsed = oaicompat_chat_params_parse(
            body,
            ctx_server.oai_parser_opt,
            files);
        return handle_completions_impl(
            SERVER_TASK_TYPE_COMPLETION,
            body_parsed,
            files,
            req.should_stop,
            OAICOMPAT_TYPE_CHAT);
    };

    // same as handle_chat_completions, but without the inference part
    server_http_context::handler_t post_apply_template = [this](const server_http_req & req) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        std::vector<raw_buffer> files; // dummy, unused
        json body = json::parse(req.body);
        json data = oaicompat_chat_params_parse(
            body,
            ctx_server.oai_parser_opt,
            files);
        res->ok({{ "prompt", std::move(data.at("prompt")) }});
        return res;
    };

    server_http_context::handler_t get_models = [this](const server_http_req &) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        bool is_model_ready = ctx_http.is_ready.load();
        json model_meta = nullptr;
        if (is_model_ready) {
            model_meta = ctx_server.model_meta();
        }
        bool has_mtmd = ctx_server.mctx != nullptr;
        json models = {
            {"models", {
                {
                    {"name", params.model_alias.empty() ? params.model.path : params.model_alias},
                    {"model", params.model_alias.empty() ? params.model.path : params.model_alias},
                    {"modified_at", ""},
                    {"size", ""},
                    {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash
                    {"type", "model"},
                    {"description", ""},
                    {"tags", {""}},
                    {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})},
                    {"parameters", ""},
                    {"details", {
                        {"parent_model", ""},
                        {"format", "gguf"},
                        {"family", ""},
                        {"families", {""}},
                        {"parameter_size", ""},
                        {"quantization_level", ""}
                    }}
                }
            }},
            {"object", "list"},
            {"data", {
                {
                    {"id", params.model_alias.empty() ? params.model.path : params.model_alias},
                    {"object", "model"},
                    {"created", std::time(0)},
                    {"owned_by", "llamacpp"},
                    {"meta", model_meta},
                },
            }}
        };

        res->ok(models);
        return res;
    };

    server_http_context::handler_t post_tokenize = [this](const server_http_req & req) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        const json body = json::parse(req.body);
        json tokens_response = json::array();
        if (body.count("content") != 0) {
            const bool add_special = json_value(body, "add_special", false);
            const bool parse_special = json_value(body, "parse_special", true);
            const bool with_pieces = json_value(body, "with_pieces", false);

            llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special);

            if (with_pieces) {
                for (const auto & token : tokens) {
                    std::string piece = common_token_to_piece(ctx_server.ctx, token);
                    json piece_json;

                    // Check if the piece is valid UTF-8
                    if (is_valid_utf8(piece)) {
                        piece_json = piece;
                    } else {
                        // If not valid UTF-8, store as array of byte values
                        piece_json = json::array();
                        for (unsigned char c : piece) {
                            piece_json.push_back(static_cast<int>(c));
                        }
                    }

                    tokens_response.push_back({
                        {"id", token},
                        {"piece", piece_json}
                    });
                }
            } else {
                tokens_response = tokens;
            }
        }
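        // response shape: {"tokens": [1, 2, 3]} or, with "with_pieces": true,
        // {"tokens": [{"id": 1, "piece": "..."}, ...]} - pieces that are not valid
        // UTF-8 are returned as arrays of byte values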

        res->ok(json{{"tokens", std::move(tokens_response)}});
        return res;
    };

    server_http_context::handler_t post_detokenize = [this](const server_http_req & req) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        const json body = json::parse(req.body);

        std::string content;
        if (body.count("tokens") != 0) {
            const llama_tokens tokens = body.at("tokens");
            content = tokens_to_str(ctx_server.ctx, tokens);
        }

        res->ok(json{{"content", std::move(content)}});
        return res;
    };

    server_http_context::handler_t post_embeddings = [this](const server_http_req & req) {
        return handle_embeddings_impl(req, OAICOMPAT_TYPE_NONE);
    };

    server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) {
        return handle_embeddings_impl(req, OAICOMPAT_TYPE_EMBEDDING);
    };

    server_http_context::handler_t post_rerank = [this](const server_http_req & req) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
            res->error(format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
            return res;
        }

        const json body = json::parse(req.body);

        // if true, use TEI API format, otherwise use Jina API format
        // Jina: https://jina.ai/reranker/
        // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
        bool is_tei_format = body.contains("texts");

        json query;
        if (body.count("query") == 1) {
            query = body.at("query");
            if (!query.is_string()) {
                res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
                return res;
            }
        } else {
            res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }

        std::vector<std::string> documents = json_value(body, "documents",
                json_value(body, "texts", std::vector<std::string>()));
        if (documents.empty()) {
            res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }

        int top_n = json_value(body, "top_n", (int) documents.size());

        // create and queue the task
        json responses = json::array();
        server_response_reader rd(ctx_server);
        {
            std::vector<server_task> tasks;
            tasks.reserve(documents.size());
            for (size_t i = 0; i < documents.size(); i++) {
                auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
                task.id = ctx_server.queue_tasks.get_new_id();
                task.index = i;
                task.tokens = std::move(tmp);
                tasks.push_back(std::move(task));
            }
            rd.post_tasks(std::move(tasks));
        }

        // wait for the results
        auto all_results = rd.wait_for_all(req.should_stop);

        // collect results
        if (all_results.is_terminated) {
            return res; // connection is closed
        } else if (all_results.error) {
            res->error(all_results.error->to_json());
            return res;
        } else {
            for (auto & res : all_results.results) {
                GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
                responses.push_back(res->to_json());
            }
        }

        // write JSON response
        json root = format_response_rerank(
            body,
            responses,
            is_tei_format,
            documents,
            top_n);

        res->ok(root);
        return res;
    };
|
|
|
|
server_http_context::handler_t get_lora_adapters = [this](const server_http_req &) {
|
|
auto res = std::make_unique<server_res_generator>(ctx_server);
|
|
json result = json::array();
|
|
const auto & loras = ctx_server.params_base.lora_adapters;
|
|
for (size_t i = 0; i < loras.size(); ++i) {
|
|
auto & lora = loras[i];
|
|
json entry = {
|
|
{"id", i},
|
|
{"path", lora.path},
|
|
{"scale", lora.scale},
|
|
{"task_name", lora.task_name},
|
|
{"prompt_prefix", lora.prompt_prefix},
|
|
};
|
|
std::string alora_invocation_string = "";
|
|
const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr);
|
|
std::vector<llama_token> alora_invocation_tokens;
|
|
if (n_alora_tokens) {
|
|
const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr);
|
|
for (uint64_t i = 0; i < n_alora_tokens; ++i) {
|
|
alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]);
|
|
alora_invocation_tokens.push_back(alora_tokens[i]);
|
|
}
|
|
entry["alora_invocation_string"] = alora_invocation_string;
|
|
entry["alora_invocation_tokens"] = alora_invocation_tokens;
|
|
}
|
|
result.push_back(std::move(entry));
|
|
}
|
|
res->ok(result);
|
|
return res;
|
|
};
|
|
|
|
server_http_context::handler_t post_lora_adapters = [this](const server_http_req & req) {
|
|
auto res = std::make_unique<server_res_generator>(ctx_server);
|
|
const json body = json::parse(req.body);
|
|
if (!body.is_array()) {
|
|
res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
|
|
return res;
|
|
}
|
|
|
|
int task_id = ctx_server.queue_tasks.get_new_id();
|
|
{
|
|
server_task task(SERVER_TASK_TYPE_SET_LORA);
|
|
task.id = task_id;
|
|
task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
|
|
ctx_server.queue_results.add_waiting_task_id(task_id);
|
|
ctx_server.queue_tasks.post(std::move(task));
|
|
}
|
|
|
|
// get the result
|
|
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
|
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
|
|
|
if (result->is_error()) {
|
|
res->error(result->to_json());
|
|
return res;
|
|
}
|
|
|
|
GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
|
|
res->ok(result->to_json());
|
|
return res;
|
|
};
|
|
|
|
private:
    std::unique_ptr<server_res_generator> handle_completions_impl(
            server_task_type type,
            const json & data,
            const std::vector<raw_buffer> & files,
            const std::function<bool()> & should_stop,
            oaicompat_type oaicompat) {
        GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);

        auto res = std::make_unique<server_res_generator>(ctx_server);
        auto completion_id = gen_chatcmplid();
        auto & rd = res->rd;

        try {
            std::vector<server_task> tasks;

            const auto & prompt = data.at("prompt");
            // TODO: this log can become very long, put it behind a flag or think about a more compact format
            //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());

            // process prompt
            std::vector<server_tokens> inputs;

            if (oaicompat && ctx_server.mctx != nullptr) {
                // this is the case used by the OAI-compatible chat path with MTMD
                // TODO: it can be moved to the path below
                inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
            } else {
                // everything else, including multimodal completions
                inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
            }
            tasks.reserve(inputs.size());
            for (size_t i = 0; i < inputs.size(); i++) {
                server_task task = server_task(type);

                task.id = ctx_server.queue_tasks.get_new_id();
                task.index = i;

                task.tokens = std::move(inputs[i]);
                task.params = server_task::params_from_json_cmpl(
                    ctx_server.ctx,
                    ctx_server.params_base,
                    data);
                task.id_slot = json_value(data, "id_slot", -1);

                // OAI-compat
                task.params.oaicompat = oaicompat;
                task.params.oaicompat_cmpl_id = completion_id;
                // oaicompat_model is already populated by params_from_json_cmpl

                tasks.push_back(std::move(task));
            }

            rd.post_tasks(std::move(tasks));
        } catch (const std::exception & e) {
            res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
            return res;
        }

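        // "stream" selects between a single aggregated JSON response and an
        // SSE stream: chunks are framed by format_sse() and, for OAI-compatible
        // requests, the stream ends with a literal "data: [DONE]" sentinel
        // (see below).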
        bool stream = json_value(data, "stream", false);

        if (!stream) {
            // non-stream, wait for the results
            auto all_results = rd.wait_for_all(should_stop);
            if (all_results.is_terminated) {
                return res; // connection is closed
            } else if (all_results.error) {
                res->error(all_results.error->to_json());
                return res;
            } else {
                json arr = json::array();
                for (auto & res : all_results.results) {
                    GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
                    arr.push_back(res->to_json());
                }
                // if single request, return single object instead of array
                res->ok(arr.size() == 1 ? arr[0] : arr);
            }

        } else {
            // in streaming mode, the first error must be treated as a non-stream response
            // this is to match the OAI API behavior
            // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
            server_task_result_ptr first_result = rd.next(should_stop);
            if (first_result == nullptr) {
                return res; // connection is closed
            } else if (first_result->is_error()) {
                res->error(first_result->to_json());
                return res;
            } else {
                GGML_ASSERT(
                    dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
                    || dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
                );
            }

            // next responses are streamed
            res->data = format_sse(first_result->to_json()); // to be sent immediately
            res->status = 200;
            res->content_type = "text/event-stream";
            res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool {
                if (should_stop()) {
                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
                    return false; // should_stop condition met
                }

                if (!res_this->data.empty()) {
                    // flush the first chunk
                    output = std::move(res_this->data);
                    res_this->data.clear();
                    return true;
                }

                server_response_reader & rd = res_this->rd;

                // check if there is more data
                if (!rd.has_next()) {
                    if (oaicompat != OAICOMPAT_TYPE_NONE) {
                        output = "data: [DONE]\n\n";
                    } else {
                        output = "";
                    }
                    SRV_DBG("%s", "all results received, terminating stream\n");
                    return false; // no more data, terminate
                }

                // receive subsequent results
                auto result = rd.next(should_stop);
                if (result == nullptr) {
                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
                    return false; // should_stop condition met
                }

                // send the results
                json res_json = result->to_json();
                if (result->is_error()) {
                    output = format_sse(json {{ "error", res_json }});
                    SRV_DBG("%s", "error received during streaming, terminating stream\n");
                    return false; // terminate on error
                } else {
                    GGML_ASSERT(
                        dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
                        || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
                    );
                    output = format_sse(res_json);
                }

                // has next data, continue
                return true;
            };
        }

        return res;
    }

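    // slot save/restore (reached via POST /slots/:id_slot, registered in main()):
    // both handlers expect a body like {"filename": "slot_cache.bin"} (name
    // illustrative); the file lives under params.slot_save_path and the name is
    // first checked by fs_validate_filename().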
    std::unique_ptr<server_res_generator> handle_slots_save(const server_http_req & req, int id_slot) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        const json request_data = json::parse(req.body);
        std::string filename = request_data.at("filename");
        if (!fs_validate_filename(filename)) {
            res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        std::string filepath = params.slot_save_path + filename;

        int task_id = ctx_server.queue_tasks.get_new_id();
        {
            server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
            task.id = task_id;
            task.slot_action.slot_id = id_slot;
            task.slot_action.filename = filename;
            task.slot_action.filepath = filepath;

            // TODO: use server_response_reader
            ctx_server.queue_results.add_waiting_task_id(task_id);
            ctx_server.queue_tasks.post(std::move(task));
        }

        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
        ctx_server.queue_results.remove_waiting_task_id(task_id);

        if (result->is_error()) {
            res->error(result->to_json());
            return res;
        }

        res->ok(result->to_json());
        return res;
    }

    std::unique_ptr<server_res_generator> handle_slots_restore(const server_http_req & req, int id_slot) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        const json request_data = json::parse(req.body);
        std::string filename = request_data.at("filename");
        if (!fs_validate_filename(filename)) {
            res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        std::string filepath = params.slot_save_path + filename;

        int task_id = ctx_server.queue_tasks.get_new_id();
        {
            server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
            task.id = task_id;
            task.slot_action.slot_id = id_slot;
            task.slot_action.filename = filename;
            task.slot_action.filepath = filepath;

            // TODO: use server_response_reader
            ctx_server.queue_results.add_waiting_task_id(task_id);
            ctx_server.queue_tasks.post(std::move(task));
        }

        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
        ctx_server.queue_results.remove_waiting_task_id(task_id);

        if (result->is_error()) {
            res->error(result->to_json());
            return res;
        }

        GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
        res->ok(result->to_json());
        return res;
    }

    std::unique_ptr<server_res_generator> handle_slots_erase(const server_http_req &, int id_slot) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        int task_id = ctx_server.queue_tasks.get_new_id();
        {
            server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
            task.id = task_id;
            task.slot_action.slot_id = id_slot;

            // TODO: use server_response_reader
            ctx_server.queue_results.add_waiting_task_id(task_id);
            ctx_server.queue_tasks.post(std::move(task));
        }

        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
        ctx_server.queue_results.remove_waiting_task_id(task_id);

        if (result->is_error()) {
            res->error(result->to_json());
            return res;
        }

        GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
        res->ok(result->to_json());
        return res;
    }

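    // example embeddings request (illustrative): POST /v1/embeddings with a body
    // such as {"input": "hello world", "encoding_format": "float"}; the legacy
    // non-OAI form uses "content" instead of "input", and "embd_normalize"
    // selects the normalization (2 = Euclidean/L2, the default below).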
    std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, oaicompat_type oaicompat) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        if (!ctx_server.params_base.embedding) {
            res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return res;
        }

        if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
            res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }

        const json body = json::parse(req.body);

        // for the shape of input/content, see tokenize_input_prompts()
        json prompt;
        if (body.count("input") != 0) {
            prompt = body.at("input");
        } else if (body.contains("content")) {
            oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
            prompt = body.at("content");
        } else {
            res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }

        bool use_base64 = false;
        if (body.count("encoding_format") != 0) {
            const std::string & format = body.at("encoding_format");
            if (format == "base64") {
                use_base64 = true;
            } else if (format != "float") {
                res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
                return res;
            }
        }

        auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
        for (const auto & tokens : tokenized_prompts) {
            // this check is necessary for models that do not add BOS token to the input
            if (tokens.empty()) {
                res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
                return res;
            }
        }

        int embd_normalize = 2; // default to Euclidean/L2 norm
        if (body.count("embd_normalize") != 0) {
            embd_normalize = body.at("embd_normalize");
            if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
                SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx));
            }
        }

        // create and queue the task
        json responses = json::array();
        server_response_reader rd(ctx_server);
        {
            std::vector<server_task> tasks;
            for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);

                task.id = ctx_server.queue_tasks.get_new_id();
                task.index = i;
                task.tokens = std::move(tokenized_prompts[i]);

                // OAI-compat
                task.params.oaicompat = oaicompat;
                task.params.embd_normalize = embd_normalize;

                tasks.push_back(std::move(task));
            }
            rd.post_tasks(std::move(tasks));
        }

        // wait for the results
        auto all_results = rd.wait_for_all(req.should_stop);

        // collect results
        if (all_results.is_terminated) {
            return res; // connection is closed
        } else if (all_results.error) {
            res->error(all_results.error->to_json());
            return res;
        } else {
            for (auto & res : all_results.results) {
                GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
                responses.push_back(res->to_json());
            }
        }

        // write JSON response
        json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
            ? format_embeddings_response_oaicompat(body, responses, use_base64)
            : json(responses);
        res->ok(root);
        return res;
    }
};

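// graceful shutdown: the first SIGINT/SIGTERM invokes shutdown_handler (set in
// main() to terminate the task queue, which unblocks start_loop()); a second
// interrupt force-exits, as implemented in signal_handler below.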
static std::function<void(int)> shutdown_handler;
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;

static inline void signal_handler(int signal) {
    if (is_terminating.test_and_set()) {
        // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
        // this is for better developer experience, we can remove when the server is stable enough
        fprintf(stderr, "Received second interrupt, terminating immediately.\n");
        exit(1);
    }

    shutdown_handler(signal);
}

// wrapper function that handles exceptions and logs errors
// this is to make sure handler_t never throws exceptions; instead, it returns an error response
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
    return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr {
        std::string message;
        try {
            return func(req);
        } catch (const std::exception & e) {
            message = e.what();
        } catch (...) {
            message = "unknown error";
        }

        auto res = std::make_unique<server_http_res>();
        res->status = 500;
        try {
            json error_data = format_error_response(message, ERROR_TYPE_SERVER);
            res->status = json_value(error_data, "code", 500);
            res->data = safe_json_to_str({{ "error", error_data }});
            LOG_WRN("got exception: %s\n", res->data.c_str());
        } catch (const std::exception & e) {
            LOG_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
            res->data = "Internal Server Error";
        }
        return res;
    };
}

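// every route registered below is wrapped with ex_wrapper, so an unhandled
// exception produces a JSON body of the form {"error": {...}} with the HTTP
// status taken from the error's "code" field (falling back to 500).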
int main(int argc, char ** argv) {
    // own arguments required by this example
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
        return 1;
    }

    // TODO: should we have a separate n_parallel parameter for the server?
    // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
    // TODO: this is a common configuration that is suitable for most local use cases
    // however, overriding the parameters is a bit confusing - figure out something more intuitive
    if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) {
        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__);

        params.n_parallel = 4;
        params.kv_unified = true;
    }

    common_init();

    // struct that contains llama context and inference
    server_context ctx_server;

    // necessary similarity of prompt for slot selection
    ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;

    llama_backend_init();
    llama_numa_init(params.numa);

    LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
    LOG_INF("\n");
    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    LOG_INF("\n");

    server_http_context ctx_http;
    if (!ctx_http.init(params)) {
        LOG_ERR("%s: failed to initialize HTTP server\n", __func__);
        return 1;
    }

    //
    // Router
    //

    // register API routes
    server_routes routes(params, ctx_server, ctx_http);

ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
|
|
ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
|
|
ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
|
|
ctx_http.get ("/props", ex_wrapper(routes.get_props));
|
|
ctx_http.post("/props", ex_wrapper(routes.post_props));
|
|
ctx_http.post("/api/show", ex_wrapper(routes.get_api_show));
|
|
ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
|
|
ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
|
|
ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check)
|
|
ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy
|
|
ctx_http.post("/completions", ex_wrapper(routes.post_completions));
|
|
ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai));
|
|
ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
|
|
ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
|
|
ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
|
|
ctx_http.post("/infill", ex_wrapper(routes.post_infill));
|
|
ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy
|
|
ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings));
|
|
ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai));
|
|
ctx_http.post("/rerank", ex_wrapper(routes.post_rerank));
|
|
ctx_http.post("/reranking", ex_wrapper(routes.post_rerank));
|
|
ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank));
|
|
ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank));
|
|
ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize));
|
|
ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize));
|
|
ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template));
|
|
// LoRA adapters hotswap
|
|
ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters));
|
|
ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters));
|
|
// Save & load slots
|
|
ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
|
|
ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
|
|
|
|
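    // example (illustrative, assuming the default port 8080): once the model
    // is loaded,
    //   curl http://localhost:8080/health
    // returns {"status": "ok"}; /health and /models are intentionally served
    // without an API key check, as noted above.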
    //
    // Start the server
    //

    // set up the cleanup function, to be called before exit
    auto clean_up = [&ctx_http, &ctx_server]() {
        SRV_INF("%s: cleaning up before exit...\n", __func__);
        ctx_http.stop();
        ctx_server.queue_results.terminate();
        llama_backend_free();
    };

    // start the HTTP server before loading the model to be able to serve /health requests
    if (!ctx_http.start()) {
        clean_up();
        LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
        return 1;
    }

    // load the model
    LOG_INF("%s: loading model\n", __func__);

    if (!ctx_server.load_model(params)) {
        clean_up();
        if (ctx_http.thread.joinable()) {
            ctx_http.thread.join();
        }
        LOG_ERR("%s: exiting due to model loading error\n", __func__);
        return 1;
    }

    ctx_server.init();
    ctx_http.is_ready.store(true);

    LOG_INF("%s: model loaded\n", __func__);

    ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
        ctx_server.process_single_task(std::move(task));
    });

    ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
        ctx_server.update_slots();
    });

    shutdown_handler = [&](int) {
        // this will unblock start_loop()
        ctx_server.queue_tasks.terminate();
    };

    // TODO: refactor in common/console
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset(&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT,  &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

    LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
    LOG_INF("%s: starting the main loop...\n", __func__);
    // this call blocks the main thread until queue_tasks.terminate() is called
    ctx_server.queue_tasks.start_loop();

    clean_up();
    if (ctx_http.thread.joinable()) {
        ctx_http.thread.join();
    }
    llama_memory_breakdown_print(ctx_server.ctx);

    return 0;
}