mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-12 03:54:06 +00:00
llama : add llama_sampling API + move grammar in libllama
ggml-ci
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
#include "llama-sampling.h"
|
||||
|
||||
#include "llama-vocab.h"
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
@@ -21,18 +24,104 @@ static void llama_log_softmax(float * array, size_t size) {
|
||||
}
|
||||
}
|
||||
|
||||
void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
|
||||
llama_sampling::llama_sampling(const struct llama_vocab & vocab) : vocab(vocab) {
|
||||
}
|
||||
|
||||
llama_sampling::~llama_sampling() {
|
||||
if (grammar) {
|
||||
llama_grammar_free_impl(grammar);
|
||||
}
|
||||
}
|
||||
|
||||
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) {
|
||||
auto * result = new llama_sampling(vocab);
|
||||
|
||||
result->params = params;
|
||||
|
||||
result->prev = ring_buffer<llama_token>(params.n_prev);
|
||||
|
||||
for (int i = 0; i < params.n_samplers; ++i) {
|
||||
result->samplers.push_back(params.samplers[i]);
|
||||
}
|
||||
|
||||
llama_sampling_set_rng_seed_impl(*result, params.seed);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void llama_sampling_free_impl(struct llama_sampling * sampling) {
|
||||
delete sampling;
|
||||
}
|
||||
|
||||
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) {
|
||||
auto * result = new llama_sampling(smpl.vocab);
|
||||
|
||||
result->params = smpl.params;
|
||||
|
||||
result->grammar_str = smpl.grammar_str;
|
||||
result->grammar_root = smpl.grammar_root;
|
||||
|
||||
result->logit_bias = smpl.logit_bias;
|
||||
|
||||
if (smpl.grammar) {
|
||||
result->grammar = llama_grammar_cp_impl(*smpl.grammar);
|
||||
}
|
||||
|
||||
result->rng = smpl.rng;
|
||||
result->prev = smpl.prev;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void llama_sampling_reset_impl(struct llama_sampling & smpl) {
|
||||
if (smpl.grammar) {
|
||||
llama_grammar_free_impl(smpl.grammar);
|
||||
smpl.grammar = nullptr;
|
||||
}
|
||||
|
||||
if (!smpl.grammar_str.empty()) {
|
||||
smpl.grammar = llama_grammar_init_impl(&smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data());
|
||||
}
|
||||
|
||||
smpl.prev.clear();
|
||||
}
|
||||
|
||||
void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) {
|
||||
if (seed == LLAMA_DEFAULT_SEED) {
|
||||
seed = time(NULL);
|
||||
}
|
||||
|
||||
smpl->rng.seed(seed);
|
||||
smpl.rng.seed(seed);
|
||||
}
|
||||
|
||||
void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||
GGML_ASSERT(candidates->size > 0);
|
||||
void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) {
|
||||
if (smpl.grammar) {
|
||||
llama_grammar_free_impl(smpl.grammar);
|
||||
smpl.grammar = nullptr;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||
smpl.grammar_str = grammar_str;
|
||||
smpl.grammar_root = grammar_root;
|
||||
|
||||
smpl.grammar = llama_grammar_init_impl(&smpl.vocab, grammar_str, grammar_root);
|
||||
} else {
|
||||
smpl.grammar_str.clear();
|
||||
smpl.grammar_root.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
|
||||
smpl.logit_bias.clear();
|
||||
smpl.logit_bias.reserve(n_logit_bias);
|
||||
|
||||
for (int32_t i = 0; i < n_logit_bias; ++i) {
|
||||
smpl.logit_bias.push_back(logit_bias[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sampling_softmax_impl(llama_token_data_array * candidates) {
|
||||
GGML_ASSERT(candidates->size > 0);
|
||||
|
||||
// Sort the logits in descending order
|
||||
if (!candidates->sorted) {
|
||||
@@ -44,28 +133,24 @@ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_ar
|
||||
|
||||
float max_l = candidates->data[0].logit;
|
||||
float cum_sum = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
float p = expf(candidates->data[i].logit - max_l);
|
||||
candidates->data[i].p = p;
|
||||
cum_sum += p;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
candidates->data[i].p /= cum_sum;
|
||||
}
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||
void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
|
||||
// if (k >= (int32_t)candidates->size) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
if (k <= 0) {
|
||||
k = candidates->size;
|
||||
}
|
||||
@@ -101,10 +186,12 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
||||
int ib = nbuckets - 1;
|
||||
for ( ; ib >= 0; --ib) {
|
||||
nhave += histo[ib];
|
||||
if (nhave >= k) break;
|
||||
if (nhave >= k) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::vector<llama_token_data> tmp_tokens(nhave);
|
||||
auto ptr = tmp_tokens.data();
|
||||
auto * ptr = tmp_tokens.data();
|
||||
std::vector<llama_token_data*> bucket_ptrs;
|
||||
bucket_ptrs.reserve(nbuckets - ib);
|
||||
for (int j = nbuckets - 1; j >= ib; --j) {
|
||||
@@ -133,20 +220,14 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
||||
candidates->sorted = true;
|
||||
}
|
||||
candidates->size = k;
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
void llama_sampling_top_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
if (p >= 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax_impl(smpl, candidates);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sampling_softmax_impl(candidates);
|
||||
|
||||
// Compute the cumulative probabilities
|
||||
float cum_sum = 0.0f;
|
||||
@@ -165,19 +246,13 @@ void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_arra
|
||||
|
||||
// Resize the output vector to keep only the top-p tokens
|
||||
candidates->size = last_idx;
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
void llama_sampling_min_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
if (p <= 0.0f || !candidates->size) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
bool min_p_applied = false;
|
||||
|
||||
// if the candidates aren't sorted, try the unsorted implementation first
|
||||
@@ -226,19 +301,14 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
|
||||
// Resize the output vector to keep only the matching tokens
|
||||
candidates->size = i;
|
||||
}
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
|
||||
void llama_sampling_tail_free_impl(llama_token_data_array * candidates, float z, size_t min_keep) {
|
||||
if (z >= 1.0f || candidates->size <= 2) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sampling_softmax_impl(candidates);
|
||||
|
||||
// Compute the first and second derivatives
|
||||
std::vector<float> first_derivatives(candidates->size - 1);
|
||||
@@ -285,13 +355,9 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
|
||||
|
||||
// Resize the output vector to keep only the tokens above the tail location
|
||||
candidates->size = last_idx;
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
void llama_sampling_typical_impl(llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
// Reference implementation:
|
||||
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
|
||||
if (p >= 1.0f) {
|
||||
@@ -299,9 +365,7 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
|
||||
}
|
||||
|
||||
// Compute the softmax of logits and calculate entropy
|
||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sampling_softmax_impl(candidates);
|
||||
|
||||
float entropy = 0.0f;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
@@ -349,15 +413,9 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
|
||||
std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
|
||||
candidates->size = new_candidates.size();
|
||||
candidates->sorted = false;
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
|
||||
// no need to do anything if there is only one (or zero) candidates
|
||||
if(candidates->size <= 1) {
|
||||
return;
|
||||
@@ -366,7 +424,7 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
|
||||
// Calculate maximum possible entropy
|
||||
float max_entropy = -logf(1.0f / candidates->size);
|
||||
|
||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||
llama_sampling_softmax_impl(candidates);
|
||||
|
||||
// Calculate entropy of the softmax probabilities
|
||||
float entropy = 0.0f;
|
||||
@@ -398,13 +456,15 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
|
||||
}
|
||||
|
||||
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
||||
double max_l_double = candidates->data[0].logit;
|
||||
const double max_l_double = candidates->data[0].logit;
|
||||
|
||||
double cum_sum_double = 0.0;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
double p = exp(candidates->data[i].logit - max_l_double);
|
||||
candidates->data[i].p = p; // Store the scaled probability
|
||||
cum_sum_double += p;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
|
||||
}
|
||||
@@ -416,44 +476,24 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
|
||||
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
void llama_sampling_temp_impl(llama_token_data_array * candidates, float temp) {
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
candidates->data[i].logit /= temp;
|
||||
}
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_repetition_penalties_impl(
|
||||
struct llama_sampling * smpl,
|
||||
void llama_sampling_grammar_impl(llama_token_data_array * candidates, const struct llama_grammar & grammar) {
|
||||
llama_grammar_apply_impl(grammar, candidates);
|
||||
}
|
||||
|
||||
void llama_sampling_penalties_impl(
|
||||
llama_token_data_array * candidates,
|
||||
const llama_token * last_tokens,
|
||||
size_t penalty_last_n,
|
||||
float penalty_repeat,
|
||||
float penalty_freq,
|
||||
float penalty_present) {
|
||||
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Create a frequency map to count occurrences of each token in last_tokens
|
||||
std::unordered_map<llama_token, int> token_count;
|
||||
for (size_t i = 0; i < penalty_last_n; ++i) {
|
||||
token_count[last_tokens[i]]++;
|
||||
}
|
||||
|
||||
const llama_token_cnt & token_count,
|
||||
float penalty_repeat,
|
||||
float penalty_freq,
|
||||
float penalty_present) {
|
||||
// Apply frequency and presence penalties to the candidates
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
const auto token_iter = token_count.find(candidates->data[i].id);
|
||||
@@ -475,43 +515,10 @@ void llama_sample_repetition_penalties_impl(
|
||||
}
|
||||
|
||||
candidates->sorted = false;
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_apply_guidance_impl(
|
||||
struct llama_sampling * smpl,
|
||||
float * logits,
|
||||
float * logits_guidance,
|
||||
float scale) {
|
||||
GGML_ASSERT(smpl);
|
||||
|
||||
const auto t_start_sample_us = ggml_time_us();
|
||||
const auto n_vocab = smpl->n_vocab;
|
||||
|
||||
llama_log_softmax(logits, n_vocab);
|
||||
llama_log_softmax(logits_guidance, n_vocab);
|
||||
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
auto & l = logits[i];
|
||||
const auto & g = logits_guidance[i];
|
||||
|
||||
l = scale * (l - g) + g;
|
||||
}
|
||||
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
|
||||
GGML_ASSERT(smpl);
|
||||
|
||||
const int32_t n_vocab = float(smpl->n_vocab);
|
||||
|
||||
int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||
llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) {
|
||||
llama_sampling_softmax_impl(candidates);
|
||||
|
||||
// Estimate s_hat using the most probable m tokens
|
||||
float s_hat = 0.0;
|
||||
@@ -527,13 +534,11 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
|
||||
|
||||
// Compute k from the estimated s_hat and target surprise value
|
||||
float epsilon_hat = s_hat - 1;
|
||||
float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
|
||||
float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
|
||||
|
||||
// Sample the next word X using top-k sampling
|
||||
llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
llama_token X = llama_sample_token_impl(smpl, candidates);
|
||||
t_start_sample_us = ggml_time_us();
|
||||
llama_sampling_top_k_impl(candidates, int(k), 1);
|
||||
llama_token X = llama_sampling_sample_dist_impl(candidates, rng);
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
@@ -543,93 +548,88 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
|
||||
float e = observed_surprise - tau;
|
||||
|
||||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
mu = mu - eta * e;
|
||||
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
return X;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
|
||||
int64_t t_start_sample_us;
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_sample_softmax_impl(smpl, candidates);
|
||||
llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) {
|
||||
llama_sampling_softmax_impl(candidates);
|
||||
|
||||
// Truncate the words with surprise values greater than mu
|
||||
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
return -log2f(candidate.p) > *mu;
|
||||
return -log2f(candidate.p) > mu;
|
||||
}));
|
||||
|
||||
if (candidates->size == 0) {
|
||||
candidates->size = 1;
|
||||
}
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
||||
// Normalize the probabilities of the remaining words
|
||||
llama_sample_softmax_impl(smpl, candidates);
|
||||
llama_sampling_softmax_impl(candidates);
|
||||
|
||||
// Sample the next word X from the remaining words
|
||||
llama_token X = llama_sample_token_impl(smpl, candidates);
|
||||
t_start_sample_us = ggml_time_us();
|
||||
llama_token X = llama_sampling_sample_dist_impl(candidates, rng);
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
return candidate.id == X;
|
||||
}));
|
||||
|
||||
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
||||
float e = observed_surprise - tau;
|
||||
|
||||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
mu = mu - eta * e;
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_token llama_sampling_sample_greedy_impl(llama_token_data_array * candidates) {
|
||||
// Find max element
|
||||
auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit < b.logit;
|
||||
});
|
||||
|
||||
llama_token result = max_iter->id;
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
smpl->n_sample++;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
|
||||
GGML_ASSERT(smpl);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||
llama_token llama_sampling_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) {
|
||||
llama_sampling_softmax_impl(candidates);
|
||||
|
||||
std::vector<float> probs;
|
||||
probs.reserve(candidates->size);
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
probs.push_back(candidates->data[i].p);
|
||||
}
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
const int idx = dist(rng);
|
||||
llama_token result = candidates->data[idx].id;
|
||||
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
smpl->n_sample++;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||
return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
|
||||
void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar) {
|
||||
smpl.prev.push_back(token);
|
||||
|
||||
if (apply_grammar && smpl.grammar) {
|
||||
llama_grammar_accept_impl(*smpl.grammar, token);
|
||||
}
|
||||
}
|
||||
|
||||
llama_token llama_sampling_prev_impl(const struct llama_sampling & smpl, int ith) {
|
||||
if (ith < 0 || ith >= (int) smpl.prev.size()) {
|
||||
return LLAMA_TOKEN_NULL;
|
||||
}
|
||||
|
||||
return smpl.prev.rat(ith);
|
||||
}
|
||||
|
||||
int llama_sampling_n_prev_impl(const struct llama_sampling & smpl) {
|
||||
return smpl.prev.size();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user