common : simplify sampler chain initialization
@@ -163,84 +163,6 @@ struct common_sampler {
     mutable int64_t t_total_us = 0;
 };
 
-// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
-static bool common_sampler_type_has_backend_support(enum common_sampler_type type) {
-    switch (type) {
-        case COMMON_SAMPLER_TYPE_TOP_K:
-        case COMMON_SAMPLER_TYPE_TEMPERATURE:
-        case COMMON_SAMPLER_TYPE_MIN_P:
-        case COMMON_SAMPLER_TYPE_TOP_P:
-            return true;
-        default:
-            return false;
-    }
-}
-
-bool common_params_sampling::is_disabled(enum common_sampler_type type) const {
-    switch (type) {
-        case COMMON_SAMPLER_TYPE_PENALTIES:
-            if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_DRY:
-            if (dry_multiplier == 0.0f || dry_base < 1.0f || dry_penalty_last_n == 0) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TYPICAL_P:
-            if (typ_p >= 1.0) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-            if (top_n_sigma <= 0.0) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TOP_K:
-            if (top_k <= 0) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TEMPERATURE:
-            if (dynatemp_range <= 0.0f) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_MIN_P:
-            if (min_p <= 0.0f) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TOP_P:
-            if (top_p >= 1.0f) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_XTC:
-            if (xtc_probability <= 0.0f || xtc_threshold == 0.50f) {
-                return true;
-            }
-            break;
-        default:
-            break;
-    }
-
-    return false;
-}
-
-void common_params_sampling::filter_disabled() {
-    for (auto it = samplers.begin(); it != samplers.end();) {
-        if (is_disabled(*it)) {
-            LOG_WRN("%s: removing disabled sampler %s\n", __func__, common_sampler_type_to_str(*it).c_str());
-            it = samplers.erase(it);
-        } else {
-            ++it;
-        }
-    }
-}
-
 std::string common_params_sampling::print() const {
     char result[1024];
 
@@ -257,7 +179,7 @@ std::string common_params_sampling::print() const {
     return std::string(result);
 }
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
@@ -324,11 +246,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
         }
     }
 
-    // TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
-    if (params.backend_sampling) {
-        params.filter_disabled();
-    }
-
     auto * result = new common_sampler {
         /* .params = */ params,
         /* .grmr   = */ grmr,
@@ -339,54 +256,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
         /* .cur_p = */ {},
     };
 
-    size_t idx_smpl = 0;
-
-    bool is_backend = true;
-
-    is_backend = is_backend && params.backend_sampling;
-    is_backend = is_backend && (params.samplers.size() == 0 || common_sampler_type_has_backend_support(params.samplers[idx_smpl]));
-
+    std::vector<llama_sampler *> samplers;
     if (params.has_logit_bias()) {
-        llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain,
-                llama_sampler_init_logit_bias(
-                    llama_vocab_n_tokens(vocab),
-                    params.logit_bias.size(),
-                    params.logit_bias.data()));
+        samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
     }
 
     if (params.mirostat == 0) {
-        // backend samplers are added first
-        while (is_backend && idx_smpl < params.samplers.size()) {
-            const auto & cnstr = params.samplers[idx_smpl++];
-
-            if (!common_sampler_type_has_backend_support(cnstr)) {
-                is_backend = false;
-                --idx_smpl;
-                break;
-            }
-
-            switch (cnstr) {
-                case COMMON_SAMPLER_TYPE_TOP_K:
-                    llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_k(params.top_k));
-                    break;
-                case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    llama_sampler_chain_add(result->chain_backend, llama_sampler_init_temp(params.temp));
-                    break;
-                case COMMON_SAMPLER_TYPE_MIN_P:
-                    llama_sampler_chain_add(result->chain_backend, llama_sampler_init_min_p(params.min_p, params.min_keep));
-                    break;
-                case COMMON_SAMPLER_TYPE_TOP_P:
-                    llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_p(params.top_p, params.min_keep));
-                    break;
-                default:
-                    GGML_ASSERT(false && "unsupported backend sampler");
-            }
-        }
-
-        // Add remaining CPU samplers
-        while (idx_smpl < params.samplers.size()) {
-            const auto & cnstr = params.samplers[idx_smpl++];
-
+        for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
                 case COMMON_SAMPLER_TYPE_DRY:
                     {
@@ -396,52 +272,63 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
                            c_breakers.push_back(str.c_str());
                        }
 
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                        samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     }
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
+                    samplers.push_back(llama_sampler_init_top_k (params.top_k));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
+                    samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
                    break;
                case COMMON_SAMPLER_TYPE_MIN_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_XTC:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                    samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                    break;
                case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                    break;
                case COMMON_SAMPLER_TYPE_INFILL:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
+                    samplers.push_back(llama_sampler_init_infill (vocab));
                    break;
                case COMMON_SAMPLER_TYPE_PENALTIES:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                    break;
                default:
                    GGML_ASSERT(false && "unknown sampler type");
            }
        }
 
-        llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, llama_sampler_init_dist(params.seed));
+        samplers.push_back(llama_sampler_init_dist(params.seed));
     } else if (params.mirostat == 1) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
     } else if (params.mirostat == 2) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
     } else {
         GGML_ASSERT(false && "unknown mirostat version");
     }
 
+    bool is_backend = params.backend_sampling;
+
+    // split in two chains: backend -> CPU
+    for (auto * smpl : samplers) {
+        if (!smpl->iface->backend_apply) {
+            is_backend = false;
+        }
+
+        llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, smpl);
+    }
+
     return result;
 }
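The heart of the change is in the last hunk: rather than deciding backend vs. CPU placement while walking the sampler types (the removed idx_smpl/while-loop bookkeeping), the full pipeline is first collected into a flat std::vector<llama_sampler *> and then split in a single ordered pass. Every sampler up to the first one whose interface lacks backend_apply goes to chain_backend; that sampler and everything after it go to the CPU chain. Below is a minimal standalone sketch of that splitting pass, with a plain struct standing in for llama_sampler and hypothetical sampler names; it illustrates the pattern only and is not the llama.cpp API.

#include <iostream>
#include <string>
#include <vector>

// Stand-in for llama_sampler: only what the split logic needs.
struct sampler {
    std::string name;
    bool backend_apply; // plays the role of smpl->iface->backend_apply != nullptr
};

int main() {
    // Hypothetical pipeline: top_k and temp have backend kernels, dry does not.
    const std::vector<sampler> samplers = {
        { "top_k", true }, { "temp", true }, { "dry", false }, { "dist", true },
    };

    std::vector<sampler> chain_backend;
    std::vector<sampler> chain_cpu;

    bool is_backend = true; // assume backend sampling was requested
    for (const auto & smpl : samplers) {
        if (!smpl.backend_apply) {
            is_backend = false; // first unsupported sampler ends the backend prefix
        }
        (is_backend ? chain_backend : chain_cpu).push_back(smpl);
    }

    for (const auto & s : chain_backend) std::cout << "backend: " << s.name << '\n';
    for (const auto & s : chain_cpu)     std::cout << "cpu:     " << s.name << '\n';
}

Note that dist supports the backend in this sketch yet lands on the CPU chain, because it follows dry: once is_backend flips to false it stays false. The split preserves sampler order rather than maximizing backend placement, since samplers are order-sensitive and reordering them would change the sampling distribution.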
||||