This commit adds support for performing sampling operations on the backend (e.g. GPU) as part of the model's computation graph, allowing some or all of the sampling to run on-device. For example, the backend sampler chain might select/sample a token directly, in which case only the sampled token needs to be transferred from device memory to host memory. Alternatively, the backend samplers can filter the logits, or compute and filter the probability distribution, in which case only the filtered logits or probabilities need to be transferred back to system memory for further processing by CPU samplers. Currently, backend sampling works in a similar manner to pooling: it is a function called by build_graph, and the sampler operations become part of the model's computation graph.
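For context, a backend sampler implements the new *_ggml callbacks of llama_sampler_i (visible in the interface table in the code below), so that its work is expressed as ggml operations inside the graph instead of CPU code over a llama_token_data_array. The following is a minimal sketch of the idea only, assuming a callback that takes the graph context and the logits tensor and returns the transformed tensor (the actual signature introduced by this change may differ), and borrowing the llama_sampler_temp context struct from llama-sampling.cpp:

// Hypothetical sketch, not the actual callback signature from this commit:
// a temperature sampler expressed as a graph op, so the logits are scaled
// on the device instead of being copied to the host first.
static ggml_tensor * llama_sampler_temp_apply_ggml(
        llama_sampler * smpl, ggml_context * ctx, ggml_tensor * logits) {
    const auto * sctx = (const llama_sampler_temp *) smpl->ctx;
    return ggml_scale(ctx, logits, 1.0f / sctx->temp);  // logits / T on the backend
}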
259 lines
8.5 KiB
C++
#include "sampling.h"
|
|
#include "log.h"
|
|
|
|
#ifdef LLAMA_USE_LLGUIDANCE
|
|
|
|
# include "llguidance.h"
|
|
# include <cmath>
|
|
|
|
struct llama_sampler_llg {
|
|
const llama_vocab * vocab;
|
|
std::string grammar_kind;
|
|
std::string grammar_data;
|
|
LlgTokenizer * tokenizer;
|
|
LlgMatcher * grammar;
|
|
};
|
|
|
|
static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
|
|
const char * grammar_data) {
|
|
LlgConstraintInit cinit;
|
|
llg_constraint_init_set_defaults(&cinit, tokenizer);
|
|
const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
|
|
if (log_level && *log_level) {
|
|
cinit.log_stderr_level = atoi(log_level);
|
|
}
|
|
auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
|
|
if (llg_matcher_get_error(c)) {
|
|
LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
|
|
llg_free_matcher(c);
|
|
return nullptr;
|
|
}
|
|
|
|
return c;
|
|
}
|
|
|
|

static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
    return "llguidance";
}

static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        llg_matcher_consume_token(ctx->grammar, token);
    }
}

static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
        if (mask == nullptr) {
            if (llg_matcher_compute_mask(ctx->grammar) == 0) {
                mask = llg_matcher_get_mask(ctx->grammar);
            } else {
                LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
                llg_free_matcher(ctx->grammar);
                ctx->grammar = nullptr;
                return;
            }
        }

        for (size_t i = 0; i < cur_p->size; ++i) {
            auto token = cur_p->data[i].id;
            if ((mask[token / 32] & (1 << (token % 32))) == 0) {
                cur_p->data[i].logit = -INFINITY;
            }
        }
    }
}
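
// Note: llguidance hands back the token mask as a dense bitmap packed into
// 32-bit words: token t is allowed iff bit (t % 32) of word (t / 32) is set,
// e.g. token id 70 is tested against bit 6 of word 2 (70 / 32 == 2, 70 % 32 == 6).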

static void llama_sampler_llg_reset(llama_sampler * smpl) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        llg_matcher_reset(ctx->grammar);
    }
}

static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_llg *) smpl->ctx;

    auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_llg *) result->ctx;

        if (ctx->grammar) {
            result_ctx->grammar_kind = ctx->grammar_kind;
            result_ctx->grammar_data = ctx->grammar_data;
            result_ctx->grammar      = llg_clone_matcher(ctx->grammar);
            result_ctx->tokenizer    = llg_clone_tokenizer(ctx->tokenizer);
        }
    }

    return result;
}

static void llama_sampler_llg_free(llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_llg *) smpl->ctx;

    if (ctx->grammar) {
        llg_free_matcher(ctx->grammar);
        llg_free_tokenizer(ctx->tokenizer);
    }

    delete ctx;
}

static llama_sampler_i llama_sampler_llg_i = {
    /* .name                = */ llama_sampler_llg_name,
    /* .accept              = */ llama_sampler_llg_accept_impl,
    /* .apply               = */ llama_sampler_llg_apply,
    /* .reset               = */ llama_sampler_llg_reset,
    /* .clone               = */ llama_sampler_llg_clone,
    /* .free                = */ llama_sampler_llg_free,
    /* .apply_ggml          = */ NULL,
    /* .accept_ggml         = */ NULL,
    /* .set_input_ggml      = */ NULL,
    /* .set_backend_context = */ NULL,
};
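
// The *_ggml callbacks are left NULL: this sampler applies its mask on the
// CPU via llama_sampler_llg_apply and contributes no ops to the backend
// sampling graph described in the commit message above.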

static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
                                            uint32_t * output_tokens, size_t output_tokens_len) {
    const llama_vocab * vocab = (const llama_vocab *) user_data;
    int r = 0;
    try {
        r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
                           true);
    } catch (const std::exception & e) {
        GGML_ABORT("llama_tokenize failed: %s\n", e.what());
    }
    if (r < 0) {
        // a negative return means the buffer was too small; report the required size
        return -r;
    }
    return r;
}

static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
    // TODO store the tokenizer in the vocab somehow
    static const llama_vocab * vocab_cache;
    static LlgTokenizer *      tokenizer_cache;

    if (vocab_cache == vocab) {
        return llg_clone_tokenizer(tokenizer_cache);
    }

    auto tok_eos = llama_vocab_eot(vocab);
    if (tok_eos == LLAMA_TOKEN_NULL) {
        tok_eos = llama_vocab_eos(vocab);
    }

    size_t vocab_size = llama_vocab_n_tokens(vocab);

    auto token_lens = new uint32_t[vocab_size];
    // we typically have ~7 bytes per token; let's go on the safe side here
    auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
    auto token_bytes      = new uint8_t[token_bytes_size];

    size_t offset = 0;
    for (size_t i = 0; i < vocab_size; i++) {
        size_t max_token = 1024;
        if (token_bytes_size - offset < max_token) {
            GGML_ABORT("token_bytes buffer too small\n");
        }

        llama_token token = i;
        auto        dp    = (char *) token_bytes + offset;
        auto        size  = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
        if (size < 0) {
            GGML_ABORT("llama_detokenize failed\n");
        }
        if (size == 0) {
            // no plain-text rendering; retry with special tokens enabled
            // and flag the result with a prefix byte
            size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
            if (size < 0) {
                GGML_ABORT("llama_detokenize failed\n");
            }
            if (size != 0) {
                *dp = '\xff';  // special token prefix marker
                size += 1;
            }
        }

        token_lens[i] = size;
        offset += size;
    }

    LlgTokenizerInit tinit = {
        /* .vocab_size                         = */ (uint32_t) vocab_size,
        /* .tok_eos                            = */ (uint32_t) tok_eos,
        /* .token_lens                         = */ token_lens,
        /* .token_bytes                        = */ token_bytes,
        /* .tokenizer_json                     = */ nullptr,
        /* .tokenize_assumes_string            = */ true,
        /* .tokenize_fn                        = */ llama_sampler_llg_tokenize_fn,
        /* .use_approximate_greedy_tokenize_fn = */ false,
        /* .tokenize_user_data                 = */ vocab,
        /* .slices                             = */ nullptr,
    };

    char error_buffer[1024];
    LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));

    delete[] token_bytes;
    delete[] token_lens;

    if (tokenizer == nullptr) {
        LOG_ERR("llg tokenizer error: %s\n", error_buffer);
        return tokenizer;
    }

    if (tokenizer_cache) {
        llg_free_tokenizer(tokenizer_cache);
    }
    vocab_cache     = vocab;
    tokenizer_cache = tokenizer;

    return llg_clone_tokenizer(tokenizer_cache);
}
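
// Note: vocab_cache/tokenizer_cache above form a one-entry static cache so
// that repeated sampler construction for the same vocab is cheap; the statics
// are not synchronized, so concurrent construction against different vocabs
// would race.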

llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
                                       const char * grammar_data) {
    auto * ctx = new llama_sampler_llg;

    if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
        auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ grammar_kind,
            /* .grammar_data = */ grammar_data,
            /* .tokenizer    = */ tokenizer,
            /* .grammar      = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
        };
        if (ctx->grammar) {
            GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
                        llg_matcher_get_mask_byte_size(ctx->grammar));
        }
    } else {
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ {},
            /* .grammar_data = */ {},
            /* .tokenizer    = */ nullptr,
            /* .grammar      = */ nullptr,
        };
    }

    return llama_sampler_init(
        /* .iface = */ &llama_sampler_llg_i,
        /* .ctx   = */ ctx);
}

#else

llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
    LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
    return nullptr;
}

#endif  // LLAMA_USE_LLGUIDANCE
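
For reference, a hedged sketch of how this sampler could be wired into a CPU-side sampler chain. llama_sampler_chain_init, llama_sampler_chain_add, llama_sampler_init_greedy, llama_sampler_sample and llama_sampler_free are the existing llama.cpp sampler-chain API; the "lark" grammar kind and the vocab, lark_grammar and lctx variables are assumptions for illustration (consult the llguidance documentation for the grammar kinds it accepts):

// Hypothetical usage: mask grammar-invalid tokens, then pick greedily.
// "lark" as the grammar kind and the surrounding variables are assumed.
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_llg(vocab, "lark", lark_grammar));
llama_sampler_chain_add(chain, llama_sampler_init_greedy());

const llama_token tok = llama_sampler_sample(chain, lctx, -1);  // sample from the last logits
llama_sampler_free(chain);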