This commit removes the `--backend-dist` option and instead uses the
configured --samplers chain to determine which samplers run on the
backend.
Backend sampling is still enabled with `--backend-sampling`, and the
sampler chain, either explicitly specified using `--samplers` or the
default, is automatically analyzed to determine which samplers can run
on the backend. The system finds the longest contiguous chain of
backend-supported samplers from the start of the sampler sequence; a
sketch of this selection logic follows the examples below.
For example:
* If the chain is `top-k -> temperature -> top-p`, and both `top-k` and
`temperature` are backend-supported but `top-p` is not, then `top-k`
and `temperature` will run on the backend, while `top-p` and
subsequent samplers run on the CPU.
* If all configured samplers are supported, the final distribution
sampling will also happen on the backend, transferring only the
sampled token IDs back to the host.
* If the sampler chain starts with an unsupported sampler (e.g.,
`penalties`), all sampling runs on the CPU. Note that this is
currently the case with the default sampler chain, so to use backend
sampling a sampler chain must be specified explicitly. See below for an example.
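The following is a minimal sketch of that prefix-selection logic. The `sampler_type` enum and `backend_supports()` helper are placeholders for illustration, not the actual llama.cpp API:
```c++
#include <cstddef>
#include <vector>

// Illustrative sampler identifiers and backend capability check (placeholders).
enum class sampler_type { PENALTIES, TOP_K, TOP_P, TYP_P, TEMPERATURE, DIST };

static bool backend_supports(sampler_type s) {
    // Assumed for this sketch: the backend implements top-k, temperature and dist.
    return s == sampler_type::TOP_K || s == sampler_type::TEMPERATURE || s == sampler_type::DIST;
}

// Length of the longest contiguous run of backend-supported samplers from the
// start of the chain. Samplers before the returned index run on the backend,
// the remaining samplers run on the CPU.
static size_t backend_prefix_len(const std::vector<sampler_type> & chain) {
    size_t n = 0;
    while (n < chain.size() && backend_supports(chain[n])) {
        ++n;
    }
    return n;
}
```
With this rule, `top-k -> temperature -> top-p` yields a prefix length of 2 (hybrid sampling), a fully supported chain yields the full length (including the final distribution sampling), and a chain starting with `penalties` yields 0, so everything stays on the CPU.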
The following shows how llama-cli can be run with backend sampling:
```console
$ llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
--prompt 'What is the capital of Sweden?' \
-n 20 \
-no-cnv \
--verbose-prompt \
-ngl 40 \
--backend-sampling \
--samplers 'top_k;temperature'
```
In this case all sampling will happen on the backend, since both
`top_k` and `temperature` are supported backend samplers.
To enable partial backend sampling (hybrid sampling), for example
running `top_k` and `temperature` on the backend and `top_p` on the CPU,
the following sampler chain could be specified:
```console
$ llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
--prompt 'What is the capital of Sweden?' \
-n 20 \
-no-cnv \
--verbose-prompt \
-ngl 40 \
--backend-sampling \
--samplers 'top_k;temperature;top_p'
```
If this looks good, I'll follow up with updates to the llama-cli and
llama-server documentation to reflect these changes.
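For llama-server, the same toggle can be requested per completion. The sketch below assumes the `backend_sampling` and `samplers` request fields parsed in `server-task.cpp` (shown further down) and the standard `/completion` endpoint; note that, per that parsing code, the request flag only takes effect when backend sampling is also enabled on the server side:
```console
$ curl http://localhost:8080/completion -d '{
    "prompt": "What is the capital of Sweden?",
    "n_predict": 20,
    "backend_sampling": true,
    "samplers": ["top_k", "temperature"]
  }'
```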
#include "server-common.h"
|
|
#include "server-task.h"
|
|
|
|
#include "common.h"
|
|
#include "llama.h"
|
|
#include "chat.h"
|
|
#include "sampling.h"
|
|
#include "json-schema-to-grammar.h"
|
|
|
|
using json = nlohmann::ordered_json;
|
|
|
|
//
|
|
// task_params
|
|
//
|
|
|
|
json task_params::format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const {
    json data = json::array();
    for (const auto & lb : logit_bias) {
        data.push_back(json{
            {"bias", lb.bias},
            {"token", lb.token},
        });
    }
    return data;
}
json task_params::to_json(bool only_metrics) const {
    std::vector<std::string> samplers;
    samplers.reserve(sampling.samplers.size());
    for (const auto & sampler : sampling.samplers) {
        samplers.emplace_back(common_sampler_type_to_str(sampler));
    }

    json lora = json::array();
    for (size_t i = 0; i < this->lora.size(); ++i) {
        lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
    }

    if (only_metrics) {
        return json {
            {"seed", sampling.seed},
            {"temperature", sampling.temp},
            {"dynatemp_range", sampling.dynatemp_range},
            {"dynatemp_exponent", sampling.dynatemp_exponent},
            {"top_k", sampling.top_k},
            {"top_p", sampling.top_p},
            {"min_p", sampling.min_p},
            {"top_n_sigma", sampling.top_n_sigma},
            {"xtc_probability", sampling.xtc_probability},
            {"xtc_threshold", sampling.xtc_threshold},
            {"typical_p", sampling.typ_p},
            {"repeat_last_n", sampling.penalty_last_n},
            {"repeat_penalty", sampling.penalty_repeat},
            {"presence_penalty", sampling.penalty_present},
            {"frequency_penalty", sampling.penalty_freq},
            {"dry_multiplier", sampling.dry_multiplier},
            {"dry_base", sampling.dry_base},
            {"dry_allowed_length", sampling.dry_allowed_length},
            {"dry_penalty_last_n", sampling.dry_penalty_last_n},
            {"mirostat", sampling.mirostat},
            {"mirostat_tau", sampling.mirostat_tau},
            {"mirostat_eta", sampling.mirostat_eta},
            {"max_tokens", n_predict},
            {"n_predict", n_predict}, // TODO: deduplicate?
            {"n_keep", n_keep},
            {"n_discard", n_discard},
            {"ignore_eos", sampling.ignore_eos},
            {"stream", stream},
            {"n_probs", sampling.n_probs},
            {"min_keep", sampling.min_keep},
            {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
            {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
            {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
            {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
            {"samplers", samplers},
            {"speculative.n_max", speculative.n_max},
            {"speculative.n_min", speculative.n_min},
            {"speculative.p_min", speculative.p_min},
            {"timings_per_token", timings_per_token},
            {"post_sampling_probs", post_sampling_probs},
            {"backend_sampling", sampling.backend_sampling},
            {"lora", lora},
        };
    }

    auto grammar_triggers = json::array();
    for (const auto & trigger : sampling.grammar_triggers) {
        server_grammar_trigger ct(trigger);
        grammar_triggers.push_back(ct.to_json());
    }

    return json {
        {"seed", sampling.seed},
        {"temperature", sampling.temp},
        {"dynatemp_range", sampling.dynatemp_range},
        {"dynatemp_exponent", sampling.dynatemp_exponent},
        {"top_k", sampling.top_k},
        {"top_p", sampling.top_p},
        {"min_p", sampling.min_p},
        {"top_n_sigma", sampling.top_n_sigma},
        {"xtc_probability", sampling.xtc_probability},
        {"xtc_threshold", sampling.xtc_threshold},
        {"typical_p", sampling.typ_p},
        {"repeat_last_n", sampling.penalty_last_n},
        {"repeat_penalty", sampling.penalty_repeat},
        {"presence_penalty", sampling.penalty_present},
        {"frequency_penalty", sampling.penalty_freq},
        {"dry_multiplier", sampling.dry_multiplier},
        {"dry_base", sampling.dry_base},
        {"dry_allowed_length", sampling.dry_allowed_length},
        {"dry_penalty_last_n", sampling.dry_penalty_last_n},
        {"dry_sequence_breakers", sampling.dry_sequence_breakers},
        {"mirostat", sampling.mirostat},
        {"mirostat_tau", sampling.mirostat_tau},
        {"mirostat_eta", sampling.mirostat_eta},
        {"stop", antiprompt},
        {"max_tokens", n_predict},
        {"n_predict", n_predict}, // TODO: deduplicate?
        {"n_keep", n_keep},
        {"n_discard", n_discard},
        {"ignore_eos", sampling.ignore_eos},
        {"stream", stream},
        {"logit_bias", format_logit_bias(sampling.logit_bias)},
        {"n_probs", sampling.n_probs},
        {"min_keep", sampling.min_keep},
        {"grammar", sampling.grammar},
        {"grammar_lazy", sampling.grammar_lazy},
        {"grammar_triggers", grammar_triggers},
        {"preserved_tokens", sampling.preserved_tokens},
        {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
        {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
        {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
        {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
        {"samplers", samplers},
        {"speculative.n_max", speculative.n_max},
        {"speculative.n_min", speculative.n_min},
        {"speculative.p_min", speculative.p_min},
        {"timings_per_token", timings_per_token},
        {"post_sampling_probs", post_sampling_probs},
        {"backend_sampling", sampling.backend_sampling},
        {"lora", lora},
    };
}
|
|
|
|
//
|
|
// server_task
|
|
//
|
|
|
|
task_params server_task::params_from_json_cmpl(
|
|
const llama_context * ctx,
|
|
const common_params & params_base,
|
|
const json & data) {
|
|
const llama_model * model = llama_get_model(ctx);
|
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
|
|
|
task_params params;
|
|
|
|
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
|
|
task_params defaults;
|
|
defaults.sampling = params_base.sampling;
|
|
defaults.speculative = params_base.speculative;
|
|
defaults.n_keep = params_base.n_keep;
|
|
defaults.n_predict = params_base.n_predict;
|
|
defaults.antiprompt = params_base.antiprompt;
|
|
|
|
// enabling this will output extra debug information in the HTTP responses from the server
|
|
params.verbose = params_base.verbosity > 9;
|
|
params.timings_per_token = json_value(data, "timings_per_token", false);
|
|
|
|
params.stream = json_value(data, "stream", false);
|
|
auto stream_opt = json_value(data, "stream_options", json::object());
|
|
params.include_usage = json_value(stream_opt, "include_usage", false);
|
|
params.cache_prompt = json_value(data, "cache_prompt", true);
|
|
params.return_tokens = json_value(data, "return_tokens", false);
|
|
params.return_progress = json_value(data, "return_progress", false);
|
|
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
|
|
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
|
|
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
|
|
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
|
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
|
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
|
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
|
|
|
|
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
|
|
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
|
|
params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
|
|
params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma);
|
|
params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
|
|
params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
|
|
params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
|
|
params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
|
|
params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
|
|
params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
|
|
params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
|
|
params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
|
|
params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
|
|
params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
|
|
params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
|
|
params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
|
|
params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
|
|
params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
|
|
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
|
|
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
|
|
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
|
|
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
|
|
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
|
|
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
|
|
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
|
|
|
|
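// per-request backend_sampling only takes effect if backend sampling was also enabled when the server was started (both must be true)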
const bool request_backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling);
|
|
params.sampling.backend_sampling = defaults.sampling.backend_sampling && request_backend_sampling;
|
|
|
|
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
|
|
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
|
|
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
|
|
|
|
params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
|
|
params.speculative.n_min = std::max(params.speculative.n_min, 0);
|
|
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
|
|
|
// Use OpenAI API logprobs only if n_probs wasn't provided
|
|
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
|
|
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
|
|
}
|
|
|
|
if (data.contains("lora")) {
|
|
if (data.at("lora").is_array()) {
|
|
params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
|
|
} else {
|
|
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
|
|
}
|
|
} else {
|
|
params.lora = params_base.lora_adapters;
|
|
}
|
|
|
|
// TODO: add more sanity checks for the input parameters
|
|
|
|
if (params.sampling.penalty_last_n < -1) {
|
|
throw std::runtime_error("Error: repeat_last_n must be >= -1");
|
|
}
|
|
|
|
if (params.sampling.dry_penalty_last_n < -1) {
|
|
throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
|
|
}
|
|
|
|
if (params.sampling.penalty_last_n == -1) {
|
|
// note: should be the slot's context and not the full context, but it's ok
|
|
params.sampling.penalty_last_n = llama_n_ctx(ctx);
|
|
}
|
|
|
|
if (params.sampling.dry_penalty_last_n == -1) {
|
|
params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
|
|
}
|
|
|
|
if (params.sampling.dry_base < 1.0f) {
|
|
params.sampling.dry_base = defaults.sampling.dry_base;
|
|
}
|
|
|
|
// sequence breakers for DRY
|
|
{
|
|
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
|
|
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
|
|
|
|
if (data.contains("dry_sequence_breakers")) {
|
|
params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
|
|
if (params.sampling.dry_sequence_breakers.empty()) {
|
|
throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
|
|
}
|
|
}
|
|
}
|
|
|
|
// process "json_schema" and "grammar"
|
|
if (data.contains("json_schema") && !data.contains("grammar")) {
|
|
try {
|
|
auto schema = json_value(data, "json_schema", json::object());
|
|
SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
|
|
params.sampling.grammar = json_schema_to_grammar(schema);
|
|
SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
|
|
} catch (const std::exception & e) {
|
|
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
|
}
|
|
} else {
|
|
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
|
SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
|
|
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
|
|
SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
|
|
}
|
|
|
|
{
|
|
auto it = data.find("chat_format");
|
|
if (it != data.end()) {
|
|
params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
|
|
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format));
|
|
} else {
|
|
params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
|
|
}
|
|
common_reasoning_format reasoning_format = params_base.reasoning_format;
|
|
if (data.contains("reasoning_format")) {
|
|
reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
|
|
}
|
|
params.oaicompat_chat_syntax.reasoning_format = reasoning_format;
|
|
params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
|
|
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
|
|
params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
|
|
}
|
|
|
|
{
|
|
const auto preserved_tokens = data.find("preserved_tokens");
|
|
if (preserved_tokens != data.end()) {
|
|
for (const auto & t : *preserved_tokens) {
|
|
auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
|
|
if (ids.size() == 1) {
|
|
SRV_DBG("Preserved token: %d\n", ids[0]);
|
|
params.sampling.preserved_tokens.insert(ids[0]);
|
|
} else {
|
|
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
|
|
SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
|
|
}
|
|
}
|
|
}
|
|
const auto grammar_triggers = data.find("grammar_triggers");
|
|
if (grammar_triggers != data.end()) {
|
|
for (const auto & t : *grammar_triggers) {
|
|
server_grammar_trigger ct(t);
|
|
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
|
const auto & word = ct.value.value;
|
|
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
|
|
if (ids.size() == 1) {
|
|
auto token = ids[0];
|
|
if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
|
|
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
|
|
}
|
|
SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
|
|
common_grammar_trigger trigger;
|
|
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
|
|
trigger.value = word;
|
|
trigger.token = token;
|
|
params.sampling.grammar_triggers.push_back(std::move(trigger));
|
|
} else {
|
|
SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
|
|
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
|
|
}
|
|
} else {
|
|
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
|
|
SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
|
|
} else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
|
|
SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
|
|
} else {
|
|
throw std::runtime_error("Unknown grammar trigger type");
|
|
}
|
|
params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
|
|
}
|
|
}
|
|
}
|
|
if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) {
|
|
throw std::runtime_error("Error: no triggers set for lazy grammar!");
|
|
}
|
|
}
|
|
|
|
{
|
|
params.sampling.logit_bias.clear();
|
|
|
|
const auto & logit_bias = data.find("logit_bias");
|
|
if (logit_bias != data.end() && logit_bias->is_array()) {
|
|
const int n_vocab = llama_vocab_n_tokens(vocab);
|
|
for (const auto & el : *logit_bias) {
|
|
// TODO: we may want to throw errors here, in case "el" is incorrect
|
|
if (el.is_array() && el.size() == 2) {
|
|
float bias;
|
|
if (el[1].is_number()) {
|
|
bias = el[1].get<float>();
|
|
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
|
|
bias = -INFINITY;
|
|
} else {
|
|
continue;
|
|
}
|
|
|
|
if (el[0].is_number_integer()) {
|
|
llama_token tok = el[0].get<llama_token>();
|
|
if (tok >= 0 && tok < n_vocab) {
|
|
params.sampling.logit_bias.push_back({tok, bias});
|
|
}
|
|
} else if (el[0].is_string()) {
|
|
auto toks = common_tokenize(vocab, el[0].get<std::string>(), false);
|
|
for (auto tok : toks) {
|
|
params.sampling.logit_bias.push_back({tok, bias});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
} else if (logit_bias != data.end() && logit_bias->is_object()) {
|
|
const int n_vocab = llama_vocab_n_tokens(vocab);
|
|
for (const auto & el : logit_bias->items()) {
|
|
float bias;
|
|
const auto & key = el.key();
|
|
const auto & value = el.value();
|
|
if (value.is_number()) {
|
|
bias = value.get<float>();
|
|
} else if (value.is_boolean() && !value.get<bool>()) {
|
|
bias = -INFINITY;
|
|
} else {
|
|
continue;
|
|
}
|
|
|
|
char *end;
|
|
llama_token tok = strtol(key.c_str(), &end, 10);
|
|
if (*end == 0) {
|
|
if (tok >= 0 && tok < n_vocab) {
|
|
params.sampling.logit_bias.push_back({tok, bias});
|
|
}
|
|
} else {
|
|
auto toks = common_tokenize(vocab, key, false);
|
|
for (auto tok : toks) {
|
|
params.sampling.logit_bias.push_back({tok, bias});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
|
|
if (params.sampling.ignore_eos) {
|
|
params.sampling.logit_bias.insert(
|
|
params.sampling.logit_bias.end(),
|
|
defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
|
|
}
|
|
}
|
|
|
|
{
|
|
params.antiprompt.clear();
|
|
|
|
const auto & stop = data.find("stop");
|
|
if (stop != data.end() && stop->is_array()) {
|
|
for (const auto & word : *stop) {
|
|
if (!word.empty()) {
|
|
params.antiprompt.push_back(word);
|
|
}
|
|
}
|
|
}
|
|
// set reverse prompt from cli args if not set in the request
|
|
if (params.antiprompt.empty()) {
|
|
params.antiprompt = defaults.antiprompt;
|
|
}
|
|
}
|
|
|
|
{
|
|
const auto samplers = data.find("samplers");
|
|
if (samplers != data.end()) {
|
|
if (samplers->is_array()) {
|
|
params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
|
|
} else if (samplers->is_string()){
|
|
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
|
|
}
|
|
} else {
|
|
params.sampling.samplers = defaults.sampling.samplers;
|
|
}
|
|
}
|
|
|
|
std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
|
|
params.oaicompat_model = json_value(data, "model", model_name);
|
|
|
|
return params;
|
|
}
|
|
|
|
//
|
|
// result_timings
|
|
//
|
|
|
|
json result_timings::to_json() const {
|
|
json base = {
|
|
{"cache_n", cache_n},
|
|
|
|
{"prompt_n", prompt_n},
|
|
{"prompt_ms", prompt_ms},
|
|
{"prompt_per_token_ms", prompt_per_token_ms},
|
|
{"prompt_per_second", prompt_per_second},
|
|
|
|
{"predicted_n", predicted_n},
|
|
{"predicted_ms", predicted_ms},
|
|
{"predicted_per_token_ms", predicted_per_token_ms},
|
|
{"predicted_per_second", predicted_per_second},
|
|
};
|
|
|
|
if (draft_n > 0) {
|
|
base["draft_n"] = draft_n;
|
|
base["draft_n_accepted"] = draft_n_accepted;
|
|
}
|
|
|
|
return base;
|
|
}
|
|
|
|
//
|
|
// result_prompt_progress
|
|
//
|
|
json result_prompt_progress::to_json() const {
|
|
return json {
|
|
{"total", total},
|
|
{"cache", cache},
|
|
{"processed", processed},
|
|
{"time_ms", time_ms},
|
|
};
|
|
}
|
|
|
|
static inline std::string stop_type_to_str(stop_type type) {
|
|
switch (type) {
|
|
case STOP_TYPE_EOS: return "eos";
|
|
case STOP_TYPE_WORD: return "word";
|
|
case STOP_TYPE_LIMIT: return "limit";
|
|
default: return "none";
|
|
}
|
|
}
|
|
|
|
//
|
|
// completion_token_output
|
|
//
|
|
|
|
json completion_token_output::to_json(bool post_sampling_probs) const {
|
|
json probs_for_token = json::array();
|
|
for (const auto & p : probs) {
|
|
std::string txt(p.txt);
|
|
txt.resize(validate_utf8(txt));
|
|
probs_for_token.push_back(json {
|
|
{"id", p.tok},
|
|
{"token", txt},
|
|
{"bytes", str_to_bytes(p.txt)},
|
|
{
|
|
post_sampling_probs ? "prob" : "logprob",
|
|
post_sampling_probs ? p.prob : logarithm(p.prob)
|
|
},
|
|
});
|
|
}
|
|
return probs_for_token;
|
|
}
|
|
|
|
json completion_token_output::probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
|
|
json out = json::array();
|
|
for (const auto & p : probs) {
|
|
std::string txt(p.text_to_send);
|
|
txt.resize(validate_utf8(txt));
|
|
out.push_back(json {
|
|
{"id", p.tok},
|
|
{"token", txt},
|
|
{"bytes", str_to_bytes(p.text_to_send)},
|
|
{
|
|
post_sampling_probs ? "prob" : "logprob",
|
|
post_sampling_probs ? p.prob : logarithm(p.prob)
|
|
},
|
|
{
|
|
post_sampling_probs ? "top_probs" : "top_logprobs",
|
|
p.to_json(post_sampling_probs)
|
|
},
|
|
});
|
|
}
|
|
return out;
|
|
}
|
|
|
|
float completion_token_output::logarithm(float x) {
|
|
// nlohmann::json converts -inf to null, so we need to prevent that
|
|
return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
|
|
}
|
|
|
|
std::vector<unsigned char> completion_token_output::str_to_bytes(const std::string & str) {
|
|
std::vector<unsigned char> bytes;
|
|
for (unsigned char c : str) {
|
|
bytes.push_back(c);
|
|
}
|
|
return bytes;
|
|
}
|
|
|
|
//
|
|
// server_task_result_cmpl_final
|
|
//
|
|
json server_task_result_cmpl_final::to_json() {
|
|
switch (oaicompat) {
|
|
case OAICOMPAT_TYPE_NONE:
|
|
return to_json_non_oaicompat();
|
|
case OAICOMPAT_TYPE_COMPLETION:
|
|
return to_json_oaicompat();
|
|
case OAICOMPAT_TYPE_CHAT:
|
|
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
|
|
default:
|
|
GGML_ASSERT(false && "Invalid oaicompat_type");
|
|
}
|
|
}
|
|
|
|
json server_task_result_cmpl_final::to_json_non_oaicompat() {
|
|
json res = json {
|
|
{"index", index},
|
|
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
|
{"tokens", stream ? llama_tokens {} : tokens},
|
|
{"id_slot", id_slot},
|
|
{"stop", true},
|
|
{"model", oaicompat_model},
|
|
{"tokens_predicted", n_decoded},
|
|
{"tokens_evaluated", n_prompt_tokens},
|
|
{"generation_settings", generation_params.to_json()},
|
|
{"prompt", prompt},
|
|
{"has_new_line", has_new_line},
|
|
{"truncated", truncated},
|
|
{"stop_type", stop_type_to_str(stop)},
|
|
{"stopping_word", stopping_word},
|
|
{"tokens_cached", n_tokens_cached},
|
|
{"timings", timings.to_json()},
|
|
};
|
|
if (!stream && !probs_output.empty()) {
|
|
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
|
}
|
|
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
|
}
|
|
|
|
json server_task_result_cmpl_final::to_json_oaicompat() {
|
|
std::time_t t = std::time(0);
|
|
json logprobs = json(nullptr); // OAI default to null
|
|
if (!stream && probs_output.size() > 0) {
|
|
logprobs = json{
|
|
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
|
};
|
|
}
|
|
json finish_reason = "length";
|
|
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
|
finish_reason = "stop";
|
|
}
|
|
json res = json {
|
|
{"choices", json::array({
|
|
json{
|
|
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
|
{"index", index},
|
|
{"logprobs", logprobs},
|
|
{"finish_reason", finish_reason},
|
|
}
|
|
})},
|
|
{"created", t},
|
|
{"model", oaicompat_model},
|
|
{"system_fingerprint", build_info},
|
|
{"object", "text_completion"},
|
|
{"usage", json {
|
|
{"completion_tokens", n_decoded},
|
|
{"prompt_tokens", n_prompt_tokens},
|
|
{"total_tokens", n_decoded + n_prompt_tokens}
|
|
}},
|
|
{"id", oaicompat_cmpl_id}
|
|
};
|
|
|
|
// extra fields for debugging purposes
|
|
if (verbose) {
|
|
res["__verbose"] = to_json_non_oaicompat();
|
|
}
|
|
if (timings.prompt_n >= 0) {
|
|
res.push_back({"timings", timings.to_json()});
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
json server_task_result_cmpl_final::to_json_oaicompat_chat() {
|
|
std::string finish_reason = "length";
|
|
common_chat_msg msg;
|
|
if (!oaicompat_msg.empty()) {
|
|
msg = oaicompat_msg;
|
|
} else {
|
|
msg.role = "assistant";
|
|
msg.content = content;
|
|
}
|
|
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
|
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
|
|
}
|
|
|
|
json choice {
|
|
{"finish_reason", finish_reason},
|
|
{"index", 0},
|
|
{"message", msg.to_json_oaicompat<json>()},
|
|
};
|
|
|
|
if (!stream && probs_output.size() > 0) {
|
|
choice["logprobs"] = json{
|
|
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
|
};
|
|
}
|
|
|
|
std::time_t t = std::time(0);
|
|
|
|
json res = json {
|
|
{"choices", json::array({choice})},
|
|
{"created", t},
|
|
{"model", oaicompat_model},
|
|
{"system_fingerprint", build_info},
|
|
{"object", "chat.completion"},
|
|
{"usage", json {
|
|
{"completion_tokens", n_decoded},
|
|
{"prompt_tokens", n_prompt_tokens},
|
|
{"total_tokens", n_decoded + n_prompt_tokens}
|
|
}},
|
|
{"id", oaicompat_cmpl_id}
|
|
};
|
|
|
|
// extra fields for debugging purposes
|
|
if (verbose) {
|
|
res["__verbose"] = to_json_non_oaicompat();
|
|
}
|
|
if (timings.prompt_n >= 0) {
|
|
res.push_back({"timings", timings.to_json()});
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
|
|
std::time_t t = std::time(0);
|
|
std::string finish_reason = "length";
|
|
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
|
finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
|
|
}
|
|
|
|
json deltas = json::array();
|
|
for (const auto & diff : oaicompat_msg_diffs) {
|
|
deltas.push_back({
|
|
{"choices", json::array({
|
|
json {
|
|
{"finish_reason", nullptr},
|
|
{"index", 0},
|
|
{"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
|
|
},
|
|
})},
|
|
{"created", t},
|
|
{"id", oaicompat_cmpl_id},
|
|
{"model", oaicompat_model},
|
|
{"system_fingerprint", build_info},
|
|
{"object", "chat.completion.chunk"},
|
|
});
|
|
}
|
|
|
|
deltas.push_back({
|
|
{"choices", json::array({
|
|
json {
|
|
{"finish_reason", finish_reason},
|
|
{"index", 0},
|
|
{"delta", json::object()},
|
|
},
|
|
})},
|
|
{"created", t},
|
|
{"id", oaicompat_cmpl_id},
|
|
{"model", oaicompat_model},
|
|
{"system_fingerprint", build_info},
|
|
{"object", "chat.completion.chunk"},
|
|
});
|
|
|
|
if (include_usage) {
|
|
// OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
|
|
// https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
|
|
deltas.push_back({
|
|
{"choices", json::array()},
|
|
{"created", t},
|
|
{"id", oaicompat_cmpl_id},
|
|
{"model", oaicompat_model},
|
|
{"system_fingerprint", build_info},
|
|
{"object", "chat.completion.chunk"},
|
|
{"usage", json {
|
|
{"completion_tokens", n_decoded},
|
|
{"prompt_tokens", n_prompt_tokens},
|
|
{"total_tokens", n_decoded + n_prompt_tokens},
|
|
}},
|
|
});
|
|
}
|
|
|
|
if (timings.prompt_n >= 0) {
|
|
deltas.back().push_back({"timings", timings.to_json()});
|
|
}
|
|
|
|
// extra fields for debugging purposes
|
|
if (verbose && !deltas.empty()) {
|
|
deltas.front()["__verbose"] = to_json_non_oaicompat();
|
|
}
|
|
|
|
return deltas;
|
|
}
|
|
|
|
//
|
|
// server_task_result_cmpl_partial
|
|
//
|
|
json server_task_result_cmpl_partial::to_json() {
|
|
switch (oaicompat) {
|
|
case OAICOMPAT_TYPE_NONE:
|
|
return to_json_non_oaicompat();
|
|
case OAICOMPAT_TYPE_COMPLETION:
|
|
return to_json_oaicompat();
|
|
case OAICOMPAT_TYPE_CHAT:
|
|
return to_json_oaicompat_chat();
|
|
default:
|
|
GGML_ASSERT(false && "Invalid oaicompat_type");
|
|
}
|
|
}
|
|
|
|
json server_task_result_cmpl_partial::to_json_non_oaicompat() {
|
|
// non-OAI-compat JSON
|
|
json res = json {
|
|
{"index", index},
|
|
{"content", content},
|
|
{"tokens", tokens},
|
|
{"stop", false},
|
|
{"id_slot", id_slot},
|
|
{"tokens_predicted", n_decoded},
|
|
{"tokens_evaluated", n_prompt_tokens},
|
|
};
|
|
// populate the timings object when needed (usually for the last response or with timings_per_token enabled)
|
|
if (timings.prompt_n > 0) {
|
|
res.push_back({"timings", timings.to_json()});
|
|
}
|
|
if (is_progress) {
|
|
res.push_back({"prompt_progress", progress.to_json()});
|
|
}
|
|
if (!prob_output.probs.empty()) {
|
|
res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
|
|
}
|
|
return res;
|
|
}
|
|
|
|
json server_task_result_cmpl_partial::to_json_oaicompat() {
|
|
std::time_t t = std::time(0);
|
|
json logprobs = json(nullptr); // OAI default to null
|
|
if (prob_output.probs.size() > 0) {
|
|
logprobs = json{
|
|
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
|
};
|
|
}
|
|
json res = json {
|
|
{"choices", json::array({
|
|
json{
|
|
{"text", content},
|
|
{"index", index},
|
|
{"logprobs", logprobs},
|
|
{"finish_reason", nullptr},
|
|
}
|
|
})},
|
|
{"created", t},
|
|
{"model", oaicompat_model},
|
|
{"system_fingerprint", build_info},
|
|
{"object", "text_completion"},
|
|
{"id", oaicompat_cmpl_id}
|
|
};
|
|
|
|
// extra fields for debugging purposes
|
|
if (verbose) {
|
|
res["__verbose"] = to_json_non_oaicompat();
|
|
}
|
|
if (timings.prompt_n >= 0) {
|
|
res.push_back({"timings", timings.to_json()});
|
|
}
|
|
if (is_progress) {
|
|
res.push_back({"prompt_progress", progress.to_json()});
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
|
|
bool first = n_decoded == 1;
|
|
std::time_t t = std::time(0);
|
|
json choices;
|
|
|
|
std::vector<json> deltas;
|
|
auto add_delta = [&](const json & delta) {
|
|
deltas.push_back({
|
|
{"choices", json::array({
|
|
json {
|
|
{"finish_reason", nullptr},
|
|
{"index", 0},
|
|
{"delta", delta},
|
|
},
|
|
})},
|
|
{"created", t},
|
|
{"id", oaicompat_cmpl_id},
|
|
{"model", oaicompat_model},
|
|
{"system_fingerprint", build_info},
|
|
{"object", "chat.completion.chunk"},
|
|
});
|
|
};
|
|
// We have to send an initial update to conform to openai behavior
|
|
if (first || is_progress) {
|
|
add_delta({
|
|
{"role", "assistant"},
|
|
{"content", nullptr},
|
|
});
|
|
}
|
|
|
|
for (const auto & diff : oaicompat_msg_diffs) {
|
|
add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
|
|
}
|
|
|
|
if (!deltas.empty()) {
|
|
auto & last_json = deltas[deltas.size() - 1];
|
|
GGML_ASSERT(last_json.at("choices").size() >= 1);
|
|
|
|
if (prob_output.probs.size() > 0) {
|
|
last_json.at("choices").at(0)["logprobs"] = json {
|
|
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
|
};
|
|
}
|
|
|
|
if (timings.prompt_n >= 0) {
|
|
last_json.push_back({"timings", timings.to_json()});
|
|
}
|
|
if (is_progress) {
|
|
last_json.push_back({"prompt_progress", progress.to_json()});
|
|
}
|
|
}
|
|
|
|
return deltas;
|
|
}
|
|
|
|
//
|
|
// server_task_result_embd
|
|
//
|
|
json server_task_result_embd::to_json() {
|
|
return oaicompat == OAICOMPAT_TYPE_EMBEDDING
|
|
? to_json_oaicompat()
|
|
: to_json_non_oaicompat();
|
|
}
|
|
|
|
json server_task_result_embd::to_json_non_oaicompat() {
|
|
return json {
|
|
{"index", index},
|
|
{"embedding", embedding},
|
|
};
|
|
}
|
|
|
|
json server_task_result_embd::to_json_oaicompat() {
|
|
return json {
|
|
{"index", index},
|
|
{"embedding", embedding[0]},
|
|
{"tokens_evaluated", n_tokens},
|
|
};
|
|
}
|
|
|
|
//
|
|
// server_task_result_rerank
|
|
//
|
|
json server_task_result_rerank::to_json() {
|
|
return json {
|
|
{"index", index},
|
|
{"score", score},
|
|
{"tokens_evaluated", n_tokens},
|
|
};
|
|
}
|
|
|
|
//
|
|
// server_task_result_error
|
|
//
|
|
json server_task_result_error::to_json() {
|
|
json res = format_error_response(err_msg, err_type);
|
|
if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
|
|
res["n_prompt_tokens"] = n_prompt_tokens;
|
|
res["n_ctx"] = n_ctx;
|
|
}
|
|
return res;
|
|
}
|
|
|
|
//
|
|
// server_task_result_metrics
|
|
//
|
|
json server_task_result_metrics::to_json() {
|
|
return json {
|
|
{ "idle", n_idle_slots },
|
|
{ "processing", n_processing_slots },
|
|
{ "deferred", n_tasks_deferred },
|
|
{ "t_start", t_start },
|
|
|
|
{ "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total },
|
|
{ "t_tokens_generation_total", t_tokens_generation_total },
|
|
{ "n_tokens_predicted_total", n_tokens_predicted_total },
|
|
{ "t_prompt_processing_total", t_prompt_processing_total },
|
|
|
|
{ "n_tokens_max", n_tokens_max },
|
|
|
|
{ "n_prompt_tokens_processed", n_prompt_tokens_processed },
|
|
{ "t_prompt_processing", t_prompt_processing },
|
|
{ "n_tokens_predicted", n_tokens_predicted },
|
|
{ "t_tokens_generation", t_tokens_generation },
|
|
|
|
{ "n_decode_total", n_decode_total },
|
|
{ "n_busy_slots_total", n_busy_slots_total },
|
|
|
|
{ "slots", slots_data },
|
|
};
|
|
}
|
|
|
|
//
|
|
// server_task_result_slot_save_load
|
|
//
|
|
json server_task_result_slot_save_load::to_json() {
|
|
if (is_save) {
|
|
return json {
|
|
{ "id_slot", id_slot },
|
|
{ "filename", filename },
|
|
{ "n_saved", n_tokens },
|
|
{ "n_written", n_bytes },
|
|
{ "timings", {
|
|
{ "save_ms", t_ms }
|
|
}},
|
|
};
|
|
}
|
|
|
|
return json {
|
|
{ "id_slot", id_slot },
|
|
{ "filename", filename },
|
|
{ "n_restored", n_tokens },
|
|
{ "n_read", n_bytes },
|
|
{ "timings", {
|
|
{ "restore_ms", t_ms }
|
|
}},
|
|
};
|
|
}
|
|
|
|
//
|
|
// server_task_result_slot_erase
|
|
//
|
|
json server_task_result_slot_erase::to_json() {
|
|
return json {
|
|
{ "id_slot", id_slot },
|
|
{ "n_erased", n_erased },
|
|
};
|
|
}
|
|
|
|
//
|
|
// server_task_result_apply_lora
|
|
//
|
|
|
|
json server_task_result_apply_lora::to_json() {
|
|
return json {{ "success", true }};
|
|
}
|
|
|
|
//
|
|
// server_prompt_cache
|
|
//
|
|
size_t server_prompt_cache::size() const {
|
|
size_t res = 0;
|
|
|
|
for (const auto & state : states) {
|
|
res += state.size();
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
size_t server_prompt_cache::n_tokens() const {
|
|
size_t res = 0;
|
|
|
|
for (const auto & state : states) {
|
|
res += state.n_tokens();
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
|
|
// first check if the current state is contained fully in the cache
|
|
for (auto it = states.begin(); it != states.end(); ++it) {
|
|
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
|
|
|
|
if (cur_lcp_len == (int) prompt.tokens.size()) {
|
|
SRV_WRN("%s", " - prompt is already in the cache, skipping\n");
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
// next, remove any cached prompts that are fully contained in the current prompt
|
|
for (auto it = states.begin(); it != states.end();) {
|
|
const int len = it->tokens.get_common_prefix(prompt.tokens);
|
|
|
|
if (len == (int) it->tokens.size()) {
|
|
SRV_WRN(" - removing obsolete cached prompt with length %d\n", len);
|
|
|
|
it = states.erase(it);
|
|
} else {
|
|
++it;
|
|
}
|
|
}
|
|
|
|
std::vector<uint8_t> state_data;
|
|
|
|
// check if we can allocate enough memory for the new state
|
|
try {
|
|
state_data.resize(state_size);
|
|
} catch (const std::bad_alloc & e) {
|
|
SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
|
|
|
|
limit_size = std::max<size_t>(1, 0.4*size());
|
|
|
|
SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0));
|
|
|
|
update();
|
|
|
|
return nullptr;
|
|
}
|
|
|
|
// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
|
|
auto & cur = states.emplace_back();
|
|
cur = {
|
|
/*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
|
|
/*.data =*/ std::move(state_data),
|
|
/*.checkpoints =*/ prompt.checkpoints,
|
|
};
|
|
|
|
return &cur;
|
|
}
|
|
|
|
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
|
|
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
|
|
|
|
float f_keep_best = float(lcp_best) / prompt.tokens.size();
|
|
float sim_best = float(lcp_best) / tokens_new.size();
|
|
|
|
SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
|
|
|
auto it_best = states.end();
|
|
|
|
// find the most similar cached prompt, that would also preserve the most context
|
|
for (auto it = states.begin(); it != states.end(); ++it) {
|
|
const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
|
|
|
|
const float f_keep_cur = float(lcp_cur) / it->tokens.size();
|
|
const float sim_cur = float(lcp_cur) / tokens_new.size();
|
|
|
|
// don't trash large prompts
|
|
if (f_keep_cur < 0.25f) {
|
|
continue;
|
|
}
|
|
|
|
if (f_keep_best < f_keep_cur && sim_best < sim_cur) {
|
|
f_keep_best = f_keep_cur;
|
|
sim_best = sim_cur;
|
|
|
|
it_best = it;
|
|
}
|
|
}
|
|
|
|
if (it_best != states.end()) {
|
|
SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
|
|
|
const size_t size = it_best->data.size();
|
|
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
|
|
if (n != size) {
|
|
SRV_WRN("failed to restore state with size %zu\n", size);
|
|
|
|
return false;
|
|
}
|
|
|
|
it_best->data.clear();
|
|
it_best->data.shrink_to_fit();
|
|
|
|
prompt = std::move(*it_best);
|
|
|
|
states.erase(it_best);
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
void server_prompt_cache::update() {
|
|
if (limit_size > 0) {
|
|
// always keep at least one state, regardless of the limits
|
|
while (states.size() > 1 && size() > limit_size) {
|
|
if (states.empty()) {
|
|
break;
|
|
}
|
|
|
|
SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
|
|
|
|
states.pop_front();
|
|
}
|
|
}
|
|
|
|
// average size per token
|
|
const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));
|
|
|
|
// dynamically increase the token limit if it can fit in the memory limit
|
|
const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size/size_per_token) : limit_tokens;
|
|
|
|
if (limit_tokens > 0) {
|
|
while (states.size() > 1 && n_tokens() > limit_tokens_cur) {
|
|
if (states.empty()) {
|
|
break;
|
|
}
|
|
|
|
SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n",
|
|
limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0));
|
|
|
|
states.pop_front();
|
|
}
|
|
}
|
|
|
|
SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
|
|
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
|
|
|
|
for (const auto & state : states) {
|
|
SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
|
|
(const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
|
|
}
|
|
}
|