context : fix output reorder with backend sampling (#19638)

Author:    Georgi Gerganov
Date:      2026-02-15 14:57:40 +02:00
Committed: GitHub
Parent:    08e6d914b8
Commit:    341bc7d23c

2 changed files with 30 additions and 31 deletions

src/llama-context.cpp

@@ -878,6 +878,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
         }
     } catch (const std::exception & err) {
         // fallback to full vocab list
+        GGML_UNUSED(err);
     }
 
     return sampling.token_ids_full_vocab.data();
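
The one-line addition avoids an unused-variable warning when the catch handler exists only for its fallback behavior. A minimal sketch of the idiom, assuming GGML_UNUSED is ggml's usual void-cast helper (the exact definition lives in ggml.h):

    #include <stdexcept>

    // sketch: the macro is essentially a cast to void (assumption)
    #define GGML_UNUSED(x) (void)(x)

    int fallback_demo() {
        try {
            throw std::runtime_error("backend sampling failed"); // stand-in for the real work
        } catch (const std::exception & err) {
            GGML_UNUSED(err); // keep the named binding without -Wunused-variable noise
        }
        return 0; // fallback value
    }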
@@ -1809,7 +1810,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
 //
 uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto & hparams = model.hparams;
     const auto & vocab = model.vocab;
@@ -1893,11 +1893,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
     offset += embd.size * sizeof(float);
 
-    sampling.logits = {nullptr, 0};
-    sampling.probs = {nullptr, 0};
-    sampling.sampled = {nullptr, 0};
-    sampling.candidates = {nullptr, 0};
-
     if (has_sampling) {
         sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
         offset += sampling.logits.size * sizeof(float);
@@ -1923,6 +1918,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
         std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
+    } else {
+        sampling.logits = {nullptr, 0};
+        sampling.probs = {nullptr, 0};
+        sampling.sampled = {nullptr, 0};
+        sampling.candidates = {nullptr, 0};
+
+        sampling.logits_count.clear();
+        sampling.probs_count.clear();
+        sampling.candidates_count.clear();
     }
 
     // set all ids as invalid (negative)
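
This reserve change replaces an unconditional reset of the sampling buffer views with an explicit else branch, so the views and their companion count vectors are always populated or cleared together. A stand-alone sketch of the all-or-nothing pattern, using hypothetical names (view, sampling_state, reserve) rather than the actual llama.cpp types:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    template <typename T> struct view { T * data = nullptr; std::size_t size = 0; };

    struct sampling_state {
        view<float>                logits;
        std::vector<std::uint32_t> logits_count;
    };

    // keep the view and its counter vector in one consistent state:
    // either both are sized for n output rows, or both are empty
    void reserve(sampling_state & s, float * base, std::size_t n, bool active) {
        if (active) {
            s.logits = {base, n};        // bind the view to backing storage
            s.logits_count.assign(n, 0); // one counter per output row
        } else {
            s.logits = {nullptr, 0};     // clear the view ...
            s.logits_count.clear();      // ... and its counters together
        }
    }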
@@ -1953,37 +1957,30 @@ void llama_context::output_reorder() {
             }
         }
 
-        if (sampling.logits.has_data()) {
+        if (!sampling.samplers.empty()) {
+            assert(sampling.logits.size > 0);
+            assert(sampling.probs.size > 0);
+            assert(sampling.candidates.size > 0);
+            assert(sampling.sampled.size > 0);
+            assert(sampling.logits_count.size() > 0);
+            assert(sampling.probs_count.size() > 0);
+            assert(sampling.candidates_count.size() > 0);
+
             for (uint64_t k = 0; k < n_vocab; ++k) {
                 std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
             }
-        }
 
-        if (sampling.probs.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
                 std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
             }
-        }
 
-        if (sampling.candidates.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
                 std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
             }
-        }
 
-        if (sampling.sampled.has_data()) {
-            std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
-        }
-
-        if (!sampling.logits_count.empty()) {
-            std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
-        }
-
-        if (!sampling.probs_count.empty()) {
-            std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
-        }
-
-        if (!sampling.candidates_count.empty()) {
+            std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
+            std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
+            std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
             std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
         }
     }
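
This is the core of the fix: the per-buffer has_data()/empty() guards are consolidated into a single !sampling.samplers.empty() check, so when backend sampling is active the logits, probs, candidates, sampled tokens, and all count vectors are swapped as one unit and no buffer can fall out of step during reordering. A stand-alone sketch of the row-swap pattern (illustrative names, not the llama.cpp API):

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    // swap output rows i0 and i1 across parallel buffers; every buffer indexed
    // per output row must move together, or the rows become inconsistent
    void swap_rows(std::vector<float> & logits, std::vector<int> & sampled,
                   std::size_t n_vocab, std::size_t i0, std::size_t i1) {
        assert((std::max(i0, i1) + 1) * n_vocab <= logits.size());
        assert(std::max(i0, i1) < sampled.size());

        for (std::size_t k = 0; k < n_vocab; ++k) {
            std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); // row i0 <-> row i1
        }
        std::swap(sampled[i0], sampled[i1]); // keep per-row scalars in step
    }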

src/llama-context.h

@@ -265,24 +265,26 @@ private:
     std::unique_ptr<llama_memory_i> memory;
 
     // decode output (2-dimensional array: [n_outputs][n_vocab])
-    struct buffer_view<float> logits = {nullptr, 0};
+    buffer_view<float> logits = {nullptr, 0};
 
     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
-    struct buffer_view<float> embd = {nullptr, 0};
+    buffer_view<float> embd = {nullptr, 0};
 
     struct sampling_info {
+        // !samplers.empty() to check if any samplers are active
         std::map<llama_seq_id, llama_sampler *> samplers;
 
-        struct buffer_view<float> logits = {nullptr, 0};
-        struct buffer_view<llama_token> sampled = {nullptr, 0};
-        struct buffer_view<float> probs = {nullptr, 0};
-        struct buffer_view<llama_token> candidates = {nullptr, 0};
+        buffer_view<float> logits = {nullptr, 0};
+        buffer_view<llama_token> sampled = {nullptr, 0};
+        buffer_view<float> probs = {nullptr, 0};
+        buffer_view<llama_token> candidates = {nullptr, 0};
 
         std::vector<uint32_t> logits_count;
         std::vector<uint32_t> probs_count;
         std::vector<uint32_t> candidates_count;
 
+        // optimization
         std::vector<llama_token> token_ids_full_vocab;
     };
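
The header cleanup drops the redundant struct keyword from the buffer_view members (legal C++, but the elaborated-type-specifier adds noise) and documents that !samplers.empty() is the canonical test for active backend sampling. For orientation, a plausible shape of buffer_view inferred from its use in the diff ({nullptr, 0} initialization, .data, .size, has_data()); this is a guess at the layout, not code copied from the repository:

    #include <cstddef>

    // non-owning pointer/size pair over a region of the output buffer
    template <typename T>
    struct buffer_view {
        T *         data = nullptr;
        std::size_t size = 0;

        bool has_data() const { return data != nullptr; }
    };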