mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-14 21:14:10 +00:00
sampling : remove backend-dist option (wip)
This commit removes the `--backend-dist` option and instead uses the
configured --samplers chain to determine which samplers run on the
backend.
Backend sampling is still enabled using With `--backend_sampling`, and
the sampler chain, either explictly specified using `--samplers` or the
default, is automatically analyzed to determine which samplers can run
on the backend. The system finds the longest contiguous chain of
backend supported samplers from the start of the sampler sequence.
For example:
* If the chain is `top-k -> temperature -> top-p`, and both `top-k` and
`temperature` are backend-supported but `top-p` is not, then `top-k`
and `temperature` will run on the backend, while `top-p` and
subsequent samplers run on the CPU.
* If all configured samplers are supported, the final distribution
sampling will also happen on the backend, transferring only the
sampled token IDs back to the host.
* If the sampler chain starts with an unsupported sampler (e.g.,
`penalties`), all sampling runs on the CPU. Note that this is
currently the case with the default sampler so to use backend sampling
it is required to specify a sampler chain. See below for an example.
The following shows how llama-cli can be run with backend sampling:
```console
$ llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
--prompt 'What is the capital of Sweden?' \
-n 20 \
-no-cnv \
--verbose-prompt \
-ngl 40 \
--backend-sampling \
--samplers 'top_k;temperature'
```
In this case the all sampling will happen on the backend since both
`top_k` and `temperature` are supported backend samplers.
To enable a partial backend sampling (hybrid sampling), for example
running `top_k` and `temperature` on the backend and `typ_p` on the CPU
the following sampler chain could be specified:
```console
$ llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
--prompt 'What is the capital of Sweden?' \
-n 20 \
-no-cnv \
--verbose-prompt \
-ngl 40 \
--backend-sampling \
--samplers 'top_k;temperature;top_p'
```
If this looks good then I'll follow up with updates the llama-cli and
llama-server documentation to reflect these changes.
This commit is contained in:
@@ -105,7 +105,8 @@ struct common_sampler {
|
||||
common_params_sampling params;
|
||||
|
||||
struct llama_sampler * grmr;
|
||||
struct llama_sampler * chain;
|
||||
struct llama_sampler * chain; // CPU sampling chain
|
||||
struct llama_sampler * backend_chain; // Backend sampling chain
|
||||
|
||||
ring_buffer<llama_token> prev;
|
||||
|
||||
@@ -118,6 +119,9 @@ struct common_sampler {
|
||||
|
||||
llama_sampler_reset(grmr);
|
||||
llama_sampler_reset(chain);
|
||||
if (backend_chain) {
|
||||
llama_sampler_reset(backend_chain);
|
||||
}
|
||||
}
|
||||
|
||||
void set_logits(struct llama_context * ctx, int idx) {
|
||||
@@ -165,6 +169,20 @@ static bool sampler_enabled(const struct common_params_sampling & params, enum c
|
||||
return std::find(params.samplers.begin(), params.samplers.end(), type) != params.samplers.end();
|
||||
}
|
||||
|
||||
static bool sampler_backend_supported(enum common_sampler_type type) {
|
||||
switch (type) {
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool has_logit_bias(const struct common_params_sampling & params) {
|
||||
return !params.logit_bias.empty();
|
||||
}
|
||||
|
||||
std::string common_params_sampling::print() const {
|
||||
char result[1024];
|
||||
|
||||
@@ -249,22 +267,86 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
}
|
||||
|
||||
auto * result = new common_sampler {
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ grmr,
|
||||
/* .chain = */ llama_sampler_chain_init(lparams),
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
/* .cur_p = */ {},
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ grmr,
|
||||
/* .chain = */ llama_sampler_chain_init(lparams),
|
||||
/* .backend_chain = */ nullptr,
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
/* .cur_p = */ {},
|
||||
};
|
||||
|
||||
llama_sampler_chain_add(result->chain,
|
||||
llama_sampler_init_logit_bias(
|
||||
llama_vocab_n_tokens(vocab),
|
||||
params.logit_bias.size(),
|
||||
params.logit_bias.data()));
|
||||
size_t backend_sampler_count = 0;
|
||||
if (params.backend_sampling && params.mirostat == 0) {
|
||||
if (has_logit_bias(params)) {
|
||||
backend_sampler_count++;
|
||||
}
|
||||
|
||||
// Find the longest contiguous chain of backend-supported samplers from the start
|
||||
for (const auto & sampler_type : params.samplers) {
|
||||
if (sampler_backend_supported(sampler_type)) {
|
||||
backend_sampler_count++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the samplers combination is supported then we can build the backend chain.
|
||||
if (backend_sampler_count > 0 || (params.backend_sampling && has_logit_bias(params))) {
|
||||
llama_sampler_chain_params backend_params = llama_sampler_chain_default_params();
|
||||
backend_params.no_perf = params.no_perf;
|
||||
result->backend_chain = llama_sampler_chain_init(backend_params);
|
||||
|
||||
if (has_logit_bias(params)) {
|
||||
llama_sampler_chain_add(result->backend_chain,
|
||||
llama_sampler_backend_init_logit_bias(
|
||||
llama_vocab_n_tokens(vocab),
|
||||
params.logit_bias.size(),
|
||||
params.logit_bias.data()));
|
||||
}
|
||||
|
||||
size_t backend_idx = 0;
|
||||
for (const auto & sampler_type : params.samplers) {
|
||||
if (backend_idx >= backend_sampler_count - has_logit_bias(params)) {
|
||||
break;
|
||||
}
|
||||
|
||||
switch (sampler_type) {
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
if (params.top_k > 0) {
|
||||
llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_top_k(params.top_k));
|
||||
}
|
||||
backend_idx++;
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
if (params.temp > 0.0f) {
|
||||
llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_temp(params.temp));
|
||||
}
|
||||
backend_idx++;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unsupported backend sampler");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params);
|
||||
bool cpu_has_samplers = cpu_start_idx < params.samplers.size();
|
||||
|
||||
// Build CPU chain
|
||||
if (!params.backend_sampling || !has_logit_bias(params)) {
|
||||
llama_sampler_chain_add(result->chain,
|
||||
llama_sampler_init_logit_bias(
|
||||
llama_vocab_n_tokens(vocab),
|
||||
params.logit_bias.size(),
|
||||
params.logit_bias.data()));
|
||||
}
|
||||
|
||||
if (params.mirostat == 0) {
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
// Add remaining CPU samplers
|
||||
for (size_t i = cpu_start_idx; i < params.samplers.size(); i++) {
|
||||
const auto & cnstr = params.samplers[i];
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
{
|
||||
@@ -308,7 +390,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
}
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||
|
||||
// If all samplers are on backend, add dist to backend; otherwise add to CPU
|
||||
if (result->backend_chain && !cpu_has_samplers) {
|
||||
llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_dist(params.seed));
|
||||
} else {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||
}
|
||||
} else if (params.mirostat == 1) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
||||
@@ -323,36 +411,74 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
}
|
||||
|
||||
struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
|
||||
if (!params.backend_sampling) {
|
||||
if (!params.backend_sampling || params.mirostat != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// Determine the split point for backend sampling using the same logic as common_sampler_init
|
||||
size_t backend_sampler_count = 0;
|
||||
if (has_logit_bias(params)) {
|
||||
backend_sampler_count++;
|
||||
}
|
||||
|
||||
// Find the longest contiguous chain of backend-supported samplers from the start
|
||||
for (const auto & sampler_type : params.samplers) {
|
||||
if (sampler_backend_supported(sampler_type)) {
|
||||
backend_sampler_count++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (backend_sampler_count == 0 && !has_logit_bias(params)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
|
||||
chain_params.no_perf = params.no_perf;
|
||||
|
||||
struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
|
||||
|
||||
const bool enable_temp = params.temp > 0.0f && sampler_enabled(params, COMMON_SAMPLER_TYPE_TEMPERATURE);
|
||||
const bool enable_top_k = params.top_k > 0 && sampler_enabled(params, COMMON_SAMPLER_TYPE_TOP_K);
|
||||
const bool enable_dist = params.backend_dist;
|
||||
|
||||
if (!params.logit_bias.empty()) {
|
||||
// Add logit_bias to backend chain if present
|
||||
if (has_logit_bias(params)) {
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_logit_bias(
|
||||
llama_vocab_n_tokens(vocab),
|
||||
params.logit_bias.size(),
|
||||
params.logit_bias.data()));
|
||||
}
|
||||
|
||||
if (enable_temp) {
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
|
||||
size_t backend_idx = 0;
|
||||
for (const auto & sampler_type : params.samplers) {
|
||||
if (backend_idx >= backend_sampler_count - has_logit_bias(params)) {
|
||||
break;
|
||||
}
|
||||
|
||||
switch (sampler_type) {
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
if (params.top_k > 0) {
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
|
||||
}
|
||||
backend_idx++;
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
if (params.temp > 0.0f) {
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
|
||||
}
|
||||
backend_idx++;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unsupported backend sampler");
|
||||
}
|
||||
}
|
||||
|
||||
if (enable_top_k) {
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
|
||||
}
|
||||
// Determine if we should add dist sampler to backend chain
|
||||
// Only add it if all samplers from params.samplers are on the backend
|
||||
size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params);
|
||||
bool cpu_has_samplers = cpu_start_idx < params.samplers.size();
|
||||
|
||||
if (enable_dist) {
|
||||
if (!cpu_has_samplers) {
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed));
|
||||
}
|
||||
|
||||
@@ -362,9 +488,12 @@ struct llama_sampler * common_sampler_backend_init(const struct llama_model * mo
|
||||
void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
if (gsmpl) {
|
||||
llama_sampler_free(gsmpl->grmr);
|
||||
|
||||
llama_sampler_free(gsmpl->chain);
|
||||
|
||||
if (gsmpl->backend_chain) {
|
||||
llama_sampler_free(gsmpl->backend_chain);
|
||||
}
|
||||
|
||||
delete gsmpl;
|
||||
}
|
||||
}
|
||||
@@ -387,12 +516,13 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
|
||||
|
||||
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||
return new common_sampler {
|
||||
/* .params = */ gsmpl->params,
|
||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
/* .cur_p = */ gsmpl->cur_p,
|
||||
/* .params = */ gsmpl->params,
|
||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .backend_chain = */ gsmpl->backend_chain ? llama_sampler_clone(gsmpl->backend_chain) : nullptr,
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
/* .cur_p = */ gsmpl->cur_p,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user