common : simplify sampler chain initialization

This commit is contained in:
Georgi Gerganov
2025-12-01 17:10:32 +02:00
parent 217469f07f
commit 4032ce2378
7 changed files with 171 additions and 178 deletions

View File

@@ -163,84 +163,6 @@ struct common_sampler {
mutable int64_t t_total_us = 0;
};
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
static bool common_sampler_type_has_backend_support(enum common_sampler_type type) {
    // Only this small set of samplers currently has a backend implementation;
    // everything else must run on the CPU chain.
    return type == COMMON_SAMPLER_TYPE_TOP_K       ||
           type == COMMON_SAMPLER_TYPE_TEMPERATURE ||
           type == COMMON_SAMPLER_TYPE_MIN_P       ||
           type == COMMON_SAMPLER_TYPE_TOP_P;
}
// Returns true when the given sampler type, under the current parameter values,
// would be a no-op — i.e. its configuration effectively disables it.
// Used to prune pointless samplers from the chain (see filter_disabled()).
bool common_params_sampling::is_disabled(enum common_sampler_type type) const {
    switch (type) {
        case COMMON_SAMPLER_TYPE_PENALTIES:
            // no history to penalize, or every penalty at its neutral value
            return penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f);
        case COMMON_SAMPLER_TYPE_DRY:
            return dry_multiplier == 0.0f || dry_base < 1.0f || dry_penalty_last_n == 0;
        case COMMON_SAMPLER_TYPE_TYPICAL_P:
            return typ_p >= 1.0;
        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
            return top_n_sigma <= 0.0;
        case COMMON_SAMPLER_TYPE_TOP_K:
            return top_k <= 0;
        case COMMON_SAMPLER_TYPE_TEMPERATURE:
            // note: this only disables the *dynamic-temperature* extension;
            // plain temperature scaling is handled elsewhere
            return dynatemp_range <= 0.0f;
        case COMMON_SAMPLER_TYPE_MIN_P:
            return min_p <= 0.0f;
        case COMMON_SAMPLER_TYPE_TOP_P:
            return top_p >= 1.0f;
        case COMMON_SAMPLER_TYPE_XTC:
            // XTC removes tokens whose probability exceeds the threshold; with a
            // threshold above 0.5 at most one token can qualify and XTC never
            // removes anything, so it is effectively disabled.
            // was: `xtc_threshold == 0.50f` — exact float equality only matched
            // a threshold of precisely 0.5; the intended check is `> 0.50f`.
            return xtc_probability <= 0.0f || xtc_threshold > 0.50f;
        default:
            return false;
    }
}
// Remove every sampler from the configured chain whose current parameters make
// it a no-op (as reported by is_disabled()), warning once per removal.
void common_params_sampling::filter_disabled() {
    auto it = samplers.begin();
    while (it != samplers.end()) {
        if (!is_disabled(*it)) {
            ++it;
            continue;
        }
        LOG_WRN("%s: removing disabled sampler %s\n", __func__, common_sampler_type_to_str(*it).c_str());
        // erase() returns the iterator to the next element, so no increment here
        it = samplers.erase(it);
    }
}
std::string common_params_sampling::print() const {
char result[1024];
@@ -257,7 +179,7 @@ std::string common_params_sampling::print() const {
return std::string(result);
}
// NOTE(review): this span is a unified-diff capture with the +/- markers
// stripped — the OLD and NEW versions of the same statements appear on
// adjacent lines, and "@@ ... @@" hunk headers are embedded verbatim.
// It is not compilable as-is; the comments below only annotate the apparent
// old/new structure, to be confirmed against the actual commit.
// apparent OLD signature: params taken by mutable reference
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
// apparent NEW signature: params is const — init no longer mutates the caller's params
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
// embedded diff hunk header (not code)
@@ -324,11 +246,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
// apparent OLD code (removed by this commit): eagerly filtered disabled
// samplers up front, which is what required the mutable params reference
if (params.backend_sampling) {
params.filter_disabled();
}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
// embedded diff hunk header (not code)
@@ -339,54 +256,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
/* .cur_p = */ {},
};
// apparent OLD code: tracked an index plus a running is_backend flag while
// adding samplers directly to the backend/CPU chains as it went
size_t idx_smpl = 0;
bool is_backend = true;
is_backend = is_backend && params.backend_sampling;
is_backend = is_backend && (params.samplers.size() == 0 || common_sampler_type_has_backend_support(params.samplers[idx_smpl]));
// apparent NEW code: first collect every sampler into a flat vector, then
// split it across the two chains in a single pass at the end
std::vector<llama_sampler *> samplers;
if (params.has_logit_bias()) {
// OLD: added directly to whichever chain is_backend selected
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain,
llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
// NEW: collected for the final backend/CPU split
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
}
if (params.mirostat == 0) {
// backend samplers are added first
// apparent OLD code (removed): consumed the leading run of backend-capable
// samplers and pushed them onto chain_backend
while (is_backend && idx_smpl < params.samplers.size()) {
const auto & cnstr = params.samplers[idx_smpl++];
if (!common_sampler_type_has_backend_support(cnstr)) {
is_backend = false;
--idx_smpl;
break;
}
switch (cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_k(params.top_k));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_temp(params.temp));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_min_p(params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_p(params.top_p, params.min_keep));
break;
default:
GGML_ASSERT(false && "unsupported backend sampler");
}
}
// Add remaining CPU samplers
// OLD loop header (indexed continuation) immediately followed by its NEW
// replacement (plain range-for over all configured samplers)
while (idx_smpl < params.samplers.size()) {
const auto & cnstr = params.samplers[idx_smpl++];
for (const auto & cnstr : params.samplers) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
{
// embedded diff hunk header (not code)
@@ -396,52 +272,63 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
c_breakers.push_back(str.c_str());
}
// each pair below: OLD llama_sampler_chain_add(result->chain, ...) line
// followed by its NEW samplers.push_back(...) replacement
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
samplers.push_back(llama_sampler_init_top_k (params.top_k));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
samplers.push_back(llama_sampler_init_infill (vocab));
break;
case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
}
// final distribution sampler: OLD chain-add line, then NEW push_back line
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, llama_sampler_init_dist(params.seed));
samplers.push_back(llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
samplers.push_back(llama_sampler_init_temp(params.temp));
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
} else if (params.mirostat == 2) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
samplers.push_back(llama_sampler_init_temp(params.temp));
samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
} else {
GGML_ASSERT(false && "unknown mirostat version");
}
// apparent NEW code: single pass that routes the leading run of samplers
// with a backend_apply implementation to chain_backend, and everything from
// the first CPU-only sampler onward to the CPU chain
bool is_backend = params.backend_sampling;
// split in two chains: backend -> CPU
for (auto * smpl : samplers) {
if (!smpl->iface->backend_apply) {
is_backend = false;
}
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, smpl);
}
return result;
}