mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-03-17 16:44:07 +00:00
context : fix output reorder with backend sampling (#19638)
This commit is contained in:
@@ -878,6 +878,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
|
|||||||
}
|
}
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
// fallback to full vocab list
|
// fallback to full vocab list
|
||||||
|
GGML_UNUSED(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
return sampling.token_ids_full_vocab.data();
|
return sampling.token_ids_full_vocab.data();
|
||||||
@@ -1809,7 +1810,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||||||
//
|
//
|
||||||
|
|
||||||
uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||||
|
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
const auto & vocab = model.vocab;
|
const auto & vocab = model.vocab;
|
||||||
|
|
||||||
@@ -1893,11 +1893,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|||||||
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
|
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
|
||||||
offset += embd.size * sizeof(float);
|
offset += embd.size * sizeof(float);
|
||||||
|
|
||||||
sampling.logits = {nullptr, 0};
|
|
||||||
sampling.probs = {nullptr, 0};
|
|
||||||
sampling.sampled = {nullptr, 0};
|
|
||||||
sampling.candidates = {nullptr, 0};
|
|
||||||
|
|
||||||
if (has_sampling) {
|
if (has_sampling) {
|
||||||
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
||||||
offset += sampling.logits.size * sizeof(float);
|
offset += sampling.logits.size * sizeof(float);
|
||||||
@@ -1923,6 +1918,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|||||||
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
|
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
|
||||||
|
|
||||||
std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
|
std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
|
||||||
|
} else {
|
||||||
|
sampling.logits = {nullptr, 0};
|
||||||
|
sampling.probs = {nullptr, 0};
|
||||||
|
sampling.sampled = {nullptr, 0};
|
||||||
|
sampling.candidates = {nullptr, 0};
|
||||||
|
|
||||||
|
sampling.logits_count.clear();
|
||||||
|
sampling.probs_count.clear();
|
||||||
|
sampling.candidates_count.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
// set all ids as invalid (negative)
|
// set all ids as invalid (negative)
|
||||||
@@ -1953,37 +1957,30 @@ void llama_context::output_reorder() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sampling.logits.has_data()) {
|
if (!sampling.samplers.empty()) {
|
||||||
|
assert(sampling.logits.size > 0);
|
||||||
|
assert(sampling.probs.size > 0);
|
||||||
|
assert(sampling.candidates.size > 0);
|
||||||
|
assert(sampling.sampled.size > 0);
|
||||||
|
assert(sampling.logits_count.size() > 0);
|
||||||
|
assert(sampling.probs_count.size() > 0);
|
||||||
|
assert(sampling.candidates_count.size() > 0);
|
||||||
|
|
||||||
for (uint64_t k = 0; k < n_vocab; ++k) {
|
for (uint64_t k = 0; k < n_vocab; ++k) {
|
||||||
std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
|
std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (sampling.probs.has_data()) {
|
|
||||||
for (uint64_t k = 0; k < n_vocab; ++k) {
|
for (uint64_t k = 0; k < n_vocab; ++k) {
|
||||||
std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
|
std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (sampling.candidates.has_data()) {
|
|
||||||
for (uint64_t k = 0; k < n_vocab; ++k) {
|
for (uint64_t k = 0; k < n_vocab; ++k) {
|
||||||
std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
|
std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (sampling.sampled.has_data()) {
|
std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
|
||||||
std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
|
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
|
||||||
}
|
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
|
||||||
|
|
||||||
if (!sampling.logits_count.empty()) {
|
|
||||||
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!sampling.probs_count.empty()) {
|
|
||||||
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!sampling.candidates_count.empty()) {
|
|
||||||
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
|
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -265,24 +265,26 @@ private:
|
|||||||
std::unique_ptr<llama_memory_i> memory;
|
std::unique_ptr<llama_memory_i> memory;
|
||||||
|
|
||||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||||
struct buffer_view<float> logits = {nullptr, 0};
|
buffer_view<float> logits = {nullptr, 0};
|
||||||
|
|
||||||
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||||
struct buffer_view<float> embd = {nullptr, 0};
|
buffer_view<float> embd = {nullptr, 0};
|
||||||
|
|
||||||
struct sampling_info {
|
struct sampling_info {
|
||||||
|
// !samplers.empty() to check if any samplers are active
|
||||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||||
|
|
||||||
struct buffer_view<float> logits = {nullptr, 0};
|
buffer_view<float> logits = {nullptr, 0};
|
||||||
struct buffer_view<llama_token> sampled = {nullptr, 0};
|
buffer_view<llama_token> sampled = {nullptr, 0};
|
||||||
struct buffer_view<float> probs = {nullptr, 0};
|
buffer_view<float> probs = {nullptr, 0};
|
||||||
struct buffer_view<llama_token> candidates = {nullptr, 0};
|
buffer_view<llama_token> candidates = {nullptr, 0};
|
||||||
|
|
||||||
std::vector<uint32_t> logits_count;
|
std::vector<uint32_t> logits_count;
|
||||||
std::vector<uint32_t> probs_count;
|
std::vector<uint32_t> probs_count;
|
||||||
std::vector<uint32_t> candidates_count;
|
std::vector<uint32_t> candidates_count;
|
||||||
|
|
||||||
|
// optimization
|
||||||
std::vector<llama_token> token_ids_full_vocab;
|
std::vector<llama_token> token_ids_full_vocab;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user