commit f648ca2cee
parent b69a480af4
Author: Georgi Gerganov
Date:   2024-08-05 10:08:25 +03:00

    llama : add llama_sampling API + move grammar in libllama

    ggml-ci

48 changed files with 2481 additions and 2590 deletions
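
This change makes sampling a self-contained state object (llama_sampling) owned by libllama and turns the individual samplers into stateless functions over a llama_token_data_array. As a reading aid, here is a minimal sketch of the intended call flow, using only the internal *_impl entry points visible in the diff below; `vocab`, `cur_p`, `grammar_str`, `n_bias` and `bias` are assumed to be provided by the caller, and the public (non-`_impl`) wrappers are not shown:

    // Sketch only: illustrative flow for the new API, not code from the commit.
    struct llama_sampling_params params = {};   // fields as used by llama_sampling_init_impl below
    params.seed   = LLAMA_DEFAULT_SEED;
    params.n_prev = 64;                         // capacity of the token-history ring buffer

    struct llama_sampling * smpl = llama_sampling_init_impl(vocab, params);

    // Optional: constrain output with a GBNF grammar and per-token logit biases.
    llama_sampling_set_grammar_impl(*smpl, grammar_str, "root");
    llama_sampling_set_logit_bias_impl(*smpl, n_bias, bias);

    // Per step: filter the candidate array, draw a token, then feed it back so
    // the grammar state and the prev-token history stay in sync.
    llama_sampling_temp_impl (&cur_p, 0.8f);
    llama_sampling_top_k_impl(&cur_p, 40, 1);

    const llama_token tok = llama_sampling_sample_dist_impl(&cur_p, smpl->rng);
    llama_sampling_accept_impl(*smpl, tok, /*apply_grammar=*/true);

    llama_sampling_free_impl(smpl);

Note that the per-sampler timing bookkeeping (t_sample_us, n_sample) disappears from the samplers in this commit; measurement is left to the caller.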


@@ -1,5 +1,8 @@
 #include "llama-sampling.h"
 
+#include "llama-vocab.h"
+#include "llama-grammar.h"
+
 #include <algorithm>
 #include <cstring>
 #include <ctime>
@@ -21,18 +24,104 @@ static void llama_log_softmax(float * array, size_t size) {
     }
 }
 
-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
+llama_sampling::llama_sampling(const struct llama_vocab & vocab) : vocab(vocab) {
+}
+
+llama_sampling::~llama_sampling() {
+    if (grammar) {
+        llama_grammar_free_impl(grammar);
+    }
+}
+
+struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) {
+    auto * result = new llama_sampling(vocab);
+
+    result->params = params;
+    result->prev   = ring_buffer<llama_token>(params.n_prev);
+
+    for (int i = 0; i < params.n_samplers; ++i) {
+        result->samplers.push_back(params.samplers[i]);
+    }
+
+    llama_sampling_set_rng_seed_impl(*result, params.seed);
+
+    return result;
+}
+
+void llama_sampling_free_impl(struct llama_sampling * sampling) {
+    delete sampling;
+}
+
+struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) {
+    auto * result = new llama_sampling(smpl.vocab);
+
+    result->params = smpl.params;
+
+    result->grammar_str  = smpl.grammar_str;
+    result->grammar_root = smpl.grammar_root;
+
+    result->logit_bias = smpl.logit_bias;
+
+    if (smpl.grammar) {
+        result->grammar = llama_grammar_cp_impl(*smpl.grammar);
+    }
+
+    result->rng  = smpl.rng;
+    result->prev = smpl.prev;
+
+    return result;
+}
+
+void llama_sampling_reset_impl(struct llama_sampling & smpl) {
+    if (smpl.grammar) {
+        llama_grammar_free_impl(smpl.grammar);
+        smpl.grammar = nullptr;
+    }
+
+    if (!smpl.grammar_str.empty()) {
+        smpl.grammar = llama_grammar_init_impl(&smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data());
+    }
+
+    smpl.prev.clear();
+}
+
+void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) {
     if (seed == LLAMA_DEFAULT_SEED) {
         seed = time(NULL);
     }
 
-    smpl->rng.seed(seed);
+    smpl.rng.seed(seed);
 }
 
-void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    GGML_ASSERT(candidates->size > 0);
+void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) {
+    if (smpl.grammar) {
+        llama_grammar_free_impl(smpl.grammar);
+        smpl.grammar = nullptr;
+    }
 
-    const int64_t t_start_sample_us = ggml_time_us();
+    if (grammar_str != nullptr && grammar_str[0] != '\0') {
+        smpl.grammar_str  = grammar_str;
+        smpl.grammar_root = grammar_root;
+
+        smpl.grammar = llama_grammar_init_impl(&smpl.vocab, grammar_str, grammar_root);
+    } else {
+        smpl.grammar_str.clear();
+        smpl.grammar_root.clear();
+    }
+}
+
+void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
+    smpl.logit_bias.clear();
+    smpl.logit_bias.reserve(n_logit_bias);
+
+    for (int32_t i = 0; i < n_logit_bias; ++i) {
+        smpl.logit_bias.push_back(logit_bias[i]);
+    }
+}
+
+void llama_sampling_softmax_impl(llama_token_data_array * candidates) {
+    GGML_ASSERT(candidates->size > 0);
 
     // Sort the logits in descending order
     if (!candidates->sorted) {
@@ -44,28 +133,24 @@ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_ar
     float max_l = candidates->data[0].logit;
     float cum_sum = 0.0f;
 
     for (size_t i = 0; i < candidates->size; ++i) {
         float p = expf(candidates->data[i].logit - max_l);
         candidates->data[i].p = p;
         cum_sum += p;
     }
 
     for (size_t i = 0; i < candidates->size; ++i) {
         candidates->data[i].p /= cum_sum;
     }
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
+void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, size_t min_keep) {
     // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
     // if (k >= (int32_t)candidates->size) {
     //     return;
     // }
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     if (k <= 0) {
         k = candidates->size;
     }
@@ -101,10 +186,12 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
         int ib = nbuckets - 1;
         for ( ; ib >= 0; --ib) {
             nhave += histo[ib];
-            if (nhave >= k) break;
+            if (nhave >= k) {
+                break;
+            }
         }
         std::vector<llama_token_data> tmp_tokens(nhave);
-        auto ptr = tmp_tokens.data();
+        auto * ptr = tmp_tokens.data();
         std::vector<llama_token_data*> bucket_ptrs;
         bucket_ptrs.reserve(nbuckets - ib);
         for (int j = nbuckets - 1; j >= ib; --j) {
@@ -133,20 +220,14 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
         candidates->sorted = true;
     }
     candidates->size = k;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+void llama_sampling_top_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) {
     if (p >= 1.0f) {
         return;
     }
 
-    llama_sample_softmax_impl(smpl, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sampling_softmax_impl(candidates);
 
     // Compute the cumulative probabilities
     float cum_sum = 0.0f;
@@ -165,19 +246,13 @@ void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_arra
     // Resize the output vector to keep only the top-p tokens
     candidates->size = last_idx;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+void llama_sampling_min_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) {
     if (p <= 0.0f || !candidates->size) {
         return;
     }
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     bool min_p_applied = false;
 
     // if the candidates aren't sorted, try the unsorted implementation first
@@ -226,19 +301,14 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
         // Resize the output vector to keep only the matching tokens
         candidates->size = i;
     }
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
+void llama_sampling_tail_free_impl(llama_token_data_array * candidates, float z, size_t min_keep) {
     if (z >= 1.0f || candidates->size <= 2) {
         return;
     }
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sampling_softmax_impl(candidates);
 
     // Compute the first and second derivatives
     std::vector<float> first_derivatives(candidates->size - 1);
@@ -285,13 +355,9 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
     // Resize the output vector to keep only the tokens above the tail location
     candidates->size = last_idx;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+void llama_sampling_typical_impl(llama_token_data_array * candidates, float p, size_t min_keep) {
     // Reference implementation:
     // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
     if (p >= 1.0f) {
@@ -299,9 +365,7 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
     }
 
     // Compute the softmax of logits and calculate entropy
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sampling_softmax_impl(candidates);
 
     float entropy = 0.0f;
     for (size_t i = 0; i < candidates->size; ++i) {
@@ -349,15 +413,9 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
     std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
     candidates->size = new_candidates.size();
     candidates->sorted = false;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
+void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
     // no need to do anything if there is only one (or zero) candidates
     if(candidates->size <= 1) {
         return;
@@ -366,7 +424,7 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
     // Calculate maximum possible entropy
     float max_entropy = -logf(1.0f / candidates->size);
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+    llama_sampling_softmax_impl(candidates);
 
     // Calculate entropy of the softmax probabilities
     float entropy = 0.0f;
@@ -398,13 +456,15 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
     }
 
     // Re-compute softmax probabilities after scaling logits with dynamic temperature
-    double max_l_double = candidates->data[0].logit;
+    const double max_l_double = candidates->data[0].logit;
+
     double cum_sum_double = 0.0;
     for (size_t i = 0; i < candidates->size; ++i) {
         double p = exp(candidates->data[i].logit - max_l_double);
         candidates->data[i].p = p; // Store the scaled probability
         cum_sum_double += p;
     }
+
     for (size_t i = 0; i < candidates->size; ++i) {
         candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
     }
@@ -416,44 +476,24 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
         LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
     }
 #endif
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
+void llama_sampling_temp_impl(llama_token_data_array * candidates, float temp) {
     for (size_t i = 0; i < candidates->size; ++i) {
         candidates->data[i].logit /= temp;
     }
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_repetition_penalties_impl(
-           struct llama_sampling * smpl,
+void llama_sampling_grammar_impl(llama_token_data_array * candidates, const struct llama_grammar & grammar) {
+    llama_grammar_apply_impl(grammar, candidates);
+}
+
+void llama_sampling_penalties_impl(
        llama_token_data_array * candidates,
-         const llama_token * last_tokens,
-                      size_t penalty_last_n,
-                       float penalty_repeat,
-                       float penalty_freq,
-                       float penalty_present) {
-    if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
-        return;
-    }
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // Create a frequency map to count occurrences of each token in last_tokens
-    std::unordered_map<llama_token, int> token_count;
-    for (size_t i = 0; i < penalty_last_n; ++i) {
-        token_count[last_tokens[i]]++;
-    }
+        const llama_token_cnt & token_count,
+                  float penalty_repeat,
+                  float penalty_freq,
+                  float penalty_present) {
     // Apply frequency and presence penalties to the candidates
     for (size_t i = 0; i < candidates->size; ++i) {
         const auto token_iter = token_count.find(candidates->data[i].id);
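
The hunk above also changes the contract of the penalties sampler: it no longer receives the raw `last_tokens` buffer, so the caller now builds the frequency map itself. A hedged sketch of the new call site, assuming `llama_token_cnt` is the `std::unordered_map<llama_token, int>` implied by the removed inline code:

    // Count occurrences of each token in the recent window, then apply the
    // repeat/frequency/presence penalties in one call.
    llama_token_cnt token_count;
    for (size_t i = 0; i < penalty_last_n; ++i) {
        token_count[last_tokens[i]]++;
    }

    llama_sampling_penalties_impl(&cur_p, token_count, penalty_repeat, penalty_freq, penalty_present);
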
@@ -475,43 +515,10 @@ void llama_sample_repetition_penalties_impl(
     }
 
     candidates->sorted = false;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_apply_guidance_impl(
-          struct llama_sampling * smpl,
-                           float * logits,
-                           float * logits_guidance,
-                           float   scale) {
-    GGML_ASSERT(smpl);
-
-    const auto t_start_sample_us = ggml_time_us();
-    const auto n_vocab = smpl->n_vocab;
-
-    llama_log_softmax(logits, n_vocab);
-    llama_log_softmax(logits_guidance, n_vocab);
-
-    for (int i = 0; i < n_vocab; ++i) {
-        auto & l = logits[i];
-        const auto & g = logits_guidance[i];
-
-        l = scale * (l - g) + g;
-    }
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-}
-
-llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    GGML_ASSERT(smpl);
-
-    const int32_t n_vocab = float(smpl->n_vocab);
-
-    int64_t t_start_sample_us = ggml_time_us();
-
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) {
+    llama_sampling_softmax_impl(candidates);
 
     // Estimate s_hat using the most probable m tokens
     float s_hat = 0.0;
@@ -527,13 +534,11 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
     // Compute k from the estimated s_hat and target surprise value
     float epsilon_hat = s_hat - 1;
-    float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
+    float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
 
     // Sample the next word X using top-k sampling
-    llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    llama_token X = llama_sample_token_impl(smpl, candidates);
-    t_start_sample_us = ggml_time_us();
+    llama_sampling_top_k_impl(candidates, int(k), 1);
+    llama_token X = llama_sampling_sample_dist_impl(candidates, rng);
 
     // Compute error as the difference between observed surprise and target surprise value
     size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
@@ -543,93 +548,88 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
     float e = observed_surprise - tau;
 
     // Update mu using the learning rate and error
-    *mu = *mu - eta * e;
+    mu = mu - eta * e;
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-
     return X;
 }
 
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-    int64_t t_start_sample_us;
-    t_start_sample_us = ggml_time_us();
-
-    llama_sample_softmax_impl(smpl, candidates);
+llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) {
+    llama_sampling_softmax_impl(candidates);
 
     // Truncate the words with surprise values greater than mu
     candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return -log2f(candidate.p) > *mu;
+        return -log2f(candidate.p) > mu;
     }));
 
     if (candidates->size == 0) {
         candidates->size = 1;
     }
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-
     // Normalize the probabilities of the remaining words
-    llama_sample_softmax_impl(smpl, candidates);
+    llama_sampling_softmax_impl(candidates);
 
     // Sample the next word X from the remaining words
-    llama_token X = llama_sample_token_impl(smpl, candidates);
-    t_start_sample_us = ggml_time_us();
+    llama_token X = llama_sampling_sample_dist_impl(candidates, rng);
 
     // Compute error as the difference between observed surprise and target surprise value
     size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
         return candidate.id == X;
     }));
 
     float observed_surprise = -log2f(candidates->data[X_idx].p);
     float e = observed_surprise - tau;
 
     // Update mu using the learning rate and error
-    *mu = *mu - eta * e;
+    mu = mu - eta * e;
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-
     return X;
 }
 
-llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
+llama_token llama_sampling_sample_greedy_impl(llama_token_data_array * candidates) {
     // Find max element
     auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
         return a.logit < b.logit;
     });
 
     llama_token result = max_iter->id;
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-        smpl->n_sample++;
-    }
+
     return result;
 }
 
-llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
-    GGML_ASSERT(smpl);
-
-    const int64_t t_start_sample_us = ggml_time_us();
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+llama_token llama_sampling_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) {
+    llama_sampling_softmax_impl(candidates);
 
     std::vector<float> probs;
     probs.reserve(candidates->size);
 
     for (size_t i = 0; i < candidates->size; ++i) {
         probs.push_back(candidates->data[i].p);
     }
 
     std::discrete_distribution<> dist(probs.begin(), probs.end());
 
-    int idx = dist(rng);
+    const int idx = dist(rng);
 
     llama_token result = candidates->data[idx].id;
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    smpl->n_sample++;
-
     return result;
 }
 
-llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
+void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar) {
+    smpl.prev.push_back(token);
+
+    if (apply_grammar && smpl.grammar) {
+        llama_grammar_accept_impl(*smpl.grammar, token);
+    }
 }
+
+llama_token llama_sampling_prev_impl(const struct llama_sampling & smpl, int ith) {
+    if (ith < 0 || ith >= (int) smpl.prev.size()) {
+        return LLAMA_TOKEN_NULL;
+    }
+
+    return smpl.prev.rat(ith);
+}
+
+int llama_sampling_n_prev_impl(const struct llama_sampling & smpl) {
+    return smpl.prev.size();
+}
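
Finally, both mirostat samplers now take the RNG and the `mu` state explicitly (`mu` by reference, updated in place) instead of reading them from a context pointer. A sketch of driving mirostat v2 across generation steps under that contract; `cur_p`, `smpl` and the loop condition are assumed to come from the surrounding generation code, and the `tau`/`eta` values are only illustrative:

    std::mt19937 rng(1234);

    const float tau = 5.0f;   // target surprise
    const float eta = 0.1f;   // learning rate

    float mu = 2.0f * tau;    // conventional starting point for the mirostat state

    while (keep_generating) {
        // cur_p: llama_token_data_array rebuilt from the latest logits by the caller
        const llama_token tok = llama_sampling_sample_mirostat_v2_impl(&cur_p, rng, tau, eta, mu);

        llama_sampling_accept_impl(*smpl, tok, /*apply_grammar=*/true);
    }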