tests : fix batch token position tracking in test_backend_sampler.cpp
```diff
@@ -97,8 +97,6 @@ struct test_model_context {
         last_batch_info.clear();
 
         llama_batch batch = llama_batch_init(512, 0, prompts.size());
 
-        int n_tokens_per_prompt = 0;
-
         auto vocab = get_vocab();
         for (const auto & [seq_id, prompt] : prompts) {
             std::vector<llama_token> tokens;
@@ -108,18 +106,6 @@ struct test_model_context {
             int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(),
                                           prompt_tokens.data(), prompt_tokens.size(),
                                           false, false);
-            //TODO: refactor this function to just handle a single prompt at a time
-            // to avoid this check and complexity.
-            if (n_tokens_per_prompt == 0) {
-                n_tokens_per_prompt = n_tokens;
-            } else {
-                if (n_tokens != n_tokens_per_prompt) {
-                    fprintf(stderr, "Error: prompts must have the same number of tokens\n");
-                    llama_batch_free(batch);
-                    return false;
-                }
-                n_tokens_per_prompt = n_tokens;
-            }
             if (n_tokens < 0) {
                 fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id);
                 llama_batch_free(batch);
@@ -130,11 +116,16 @@ struct test_model_context {
                 tokens.push_back(prompt_tokens[i]);
             }
 
-            for (size_t i = 0; i < tokens.size(); i++) {
-                common_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
             if (seq_positions.find(seq_id) == seq_positions.end()) {
                 seq_positions[seq_id] = 0;
             }
 
-            seq_positions[seq_id] = tokens.size();
+            int32_t start_pos = seq_positions[seq_id];
+            for (size_t i = 0; i < tokens.size(); i++) {
+                common_batch_add(batch, tokens[i], start_pos + i, { seq_id }, i == tokens.size() - 1);
+            }
+
+            seq_positions[seq_id] = start_pos + tokens.size();
         }
@@ -375,7 +366,7 @@ static void test_backend_temp_sampling(const char * model_path) {
         return;
     }
 
-    if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
+    if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) {
         GGML_ASSERT(false && "Failed to decode token");
     }
```
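In short: `decode()` used to hard-code token positions `0..n-1` on every call, so decoding a second batch for the same sequence reused positions the KV cache already held. The patch keeps a per-sequence counter (`seq_positions`), offsets each new batch by it, and advances it afterwards. Below is a minimal standalone sketch of that pattern — the `seq_position_tracker` struct and `next_positions` helper are illustrative names, not the test file's actual API, and the container type of `seq_positions` is an assumption:

```cpp
#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

// Illustrative tracker mirroring the patched logic: each sequence remembers
// the next free position, so a follow-up batch continues where the previous
// batch ended instead of restarting at position 0.
struct seq_position_tracker {
    std::map<int32_t, int32_t> seq_positions; // seq_id -> next unused position

    // Compute the absolute positions for n_tokens new tokens of a sequence,
    // then advance the counter past them.
    std::vector<int32_t> next_positions(int32_t seq_id, size_t n_tokens) {
        int32_t start_pos = seq_positions[seq_id]; // operator[] value-inits to 0
        std::vector<int32_t> positions(n_tokens);
        for (size_t i = 0; i < n_tokens; i++) {
            positions[i] = start_pos + (int32_t) i;
        }
        seq_positions[seq_id] = start_pos + (int32_t) n_tokens;
        return positions;
    }
};

int main() {
    seq_position_tracker tracker;
    // First decode of a 4-token prompt for seq 0: positions 0 1 2 3.
    for (int32_t pos : tracker.next_positions(0, 4)) printf("%d ", pos);
    printf("\n");
    // A follow-up single token for seq 0 now lands at position 4, matching
    // what the KV cache expects; before the fix it would have been 0 again.
    for (int32_t pos : tracker.next_positions(0, 1)) printf("%d ", pos);
    printf("\n");
    return 0;
}
```

One side note on the patched code itself: if `seq_positions` is a `std::map` or `std::unordered_map`, the explicit `find`/insert-zero step in the diff is technically redundant, since `operator[]` already value-initializes a missing entry to 0 (the sketch above relies on exactly that); keeping it does make the first-use case explicit, though.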