diff --git a/common/reasoning-budget.cpp b/common/reasoning-budget.cpp
index 74fce53677..c6e1f86c91 100644
--- a/common/reasoning-budget.cpp
+++ b/common/reasoning-budget.cpp
@@ -232,34 +232,6 @@ static struct llama_sampler * common_reasoning_budget_init_state(
     );
 }
 
-struct llama_sampler * common_reasoning_budget_init(
-        const struct llama_vocab * vocab,
-        const std::vector<llama_token> & start_tokens,
-        const std::vector<llama_token> & end_tokens,
-        const std::vector<llama_token> & forced_tokens,
-        int32_t budget,
-        const std::vector<llama_token> & prefill_tokens) {
-    // Determine initial state from prefill: COUNTING if the prefill begins with
-    // the start sequence but does not also contain the end sequence after it.
-    common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
-    if (!prefill_tokens.empty() && !start_tokens.empty() &&
-        prefill_tokens.size() >= start_tokens.size() &&
-        std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
-        initial_state = REASONING_BUDGET_COUNTING;
-        // If the end sequence also follows the start in the prefill, reasoning
-        // was opened and immediately closed — stay IDLE.
-        if (!end_tokens.empty() &&
-            prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
-            auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
-            if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
-                std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
-                initial_state = REASONING_BUDGET_IDLE;
-            }
-        }
-    }
-    return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
-}
-
 struct llama_sampler * common_reasoning_budget_init(
         const struct llama_vocab * vocab,
         const std::vector<llama_token> & start_tokens,
diff --git a/common/reasoning-budget.h b/common/reasoning-budget.h
index ee1a30ed3c..ef37f46ee4 100644
--- a/common/reasoning-budget.h
+++ b/common/reasoning-budget.h
@@ -29,10 +29,7 @@ enum common_reasoning_budget_state {
 // end_tokens    - token sequence for natural deactivation
 // forced_tokens - token sequence forced when budget expires
 // budget        - max tokens allowed in the reasoning block
-// prefill_tokens - tokens already present in the prompt (generation prompt);
-//                  used to determine the initial state: COUNTING if they begin
-//                  with start_tokens (but don't also end with end_tokens),
-//                  IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
+// initial_state - initial state; COUNTING with budget <= 0 is promoted to FORCING
 //
 struct llama_sampler * common_reasoning_budget_init(
         const struct llama_vocab * vocab,
@@ -40,16 +37,6 @@ struct llama_sampler * common_reasoning_budget_init(
         const std::vector<llama_token> & end_tokens,
         const std::vector<llama_token> & forced_tokens,
         int32_t budget,
-        const std::vector<llama_token> & prefill_tokens = {});
-
-// Variant that takes an explicit initial state (used by tests and clone).
-// COUNTING with budget <= 0 is promoted to FORCING.
-struct llama_sampler * common_reasoning_budget_init(
-        const struct llama_vocab * vocab,
-        const std::vector<llama_token> & start_tokens,
-        const std::vector<llama_token> & end_tokens,
-        const std::vector<llama_token> & forced_tokens,
-        int32_t budget,
-        common_reasoning_budget_state initial_state);
+        common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE);
 
 common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
diff --git a/common/sampling.cpp b/common/sampling.cpp
index b2e6d8e8d8..d4a2fdcdac 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -260,32 +260,35 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
         }
     }
 
+    // Compute prefill tokens from the generation prompt
+    std::vector<llama_token> prefill_tokens;
+    if (!params.generation_prompt.empty()) {
+        GGML_ASSERT(vocab != nullptr);
+        auto tokens = common_tokenize(vocab, params.generation_prompt, false, true);
+        for (size_t i = 0; i < tokens.size(); i++) {
+            std::string piece = common_token_to_piece(vocab, tokens[i], true);
+            if (i == 0 && std::isspace(piece[0]) && !std::isspace(params.generation_prompt[0])) {
+                // Some tokenizers will add a space before the first special token, need to exclude
+                continue;
+            }
+            LOG_DBG("%s: prefill token: %d = %s\n", __func__, tokens[i], piece.c_str());
+            prefill_tokens.push_back(tokens[i]);
+        }
+    }
+
     // Feed generation prompt tokens to the grammar sampler so it advances past
     // tokens the template already placed in the prompt.
     // Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
-    std::vector<llama_token> prefill_tokens;
-    if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
-        GGML_ASSERT(vocab != nullptr);
-        prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
-        if (!prefill_tokens.empty()) {
-            std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
-            if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
-                // Some tokenizers will add a space before the first special token, need to remove
-                prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
-            }
-        }
-
-        if (grmr && !params.grammar_lazy) {
-            try {
-                for (const auto & token : prefill_tokens) {
-                    llama_sampler_accept(grmr, token);
-                    LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
-                }
-            } catch (std::exception &e) {
-                LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
-                        common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
-                throw e;
+    if (grmr && !params.grammar_lazy && common_grammar_needs_prefill(params.grammar)) {
+        try {
+            for (const auto & token : prefill_tokens) {
+                llama_sampler_accept(grmr, token);
+                LOG_DBG("%s: grammar accepted prefill token (%d)\n", __func__, token);
             }
+        } catch (std::exception &e) {
+            LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
+                    common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
+            throw e;
         }
     }
 
@@ -296,8 +299,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
             params.reasoning_budget_start,
             params.reasoning_budget_end,
             params.reasoning_budget_forced,
-            params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
-            prefill_tokens);
+            params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens);
+
+        for (const auto & token : prefill_tokens) {
+            llama_sampler_accept(rbudget, token);
+            LOG_DBG("%s: reasoning-budget accepted prefill token (%d)\n", __func__, token);
+        }
     }
 
     if (params.has_logit_bias()) {
@@ -431,7 +438,7 @@ static bool grammar_should_apply(struct common_sampler * gsmpl) {
     return true;
 }
 
-void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
+void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated) {
     if (!gsmpl) {
         return;
     }
@@ -439,9 +446,11 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
    const auto tm = gsmpl->tm();
 
     // grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
-    accept_grammar = accept_grammar && grammar_should_apply(gsmpl);
+    const auto accept_grammar = is_generated && grammar_should_apply(gsmpl);
 
-    llama_sampler_accept(gsmpl->rbudget, token);
+    if (gsmpl->rbudget && is_generated) {
+        llama_sampler_accept(gsmpl->rbudget, token);
+    }
 
     if (gsmpl->grmr && accept_grammar) {
         llama_sampler_accept(gsmpl->grmr, token);
diff --git a/common/sampling.h b/common/sampling.h
index 5b57ad6581..49506a00cd 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -41,8 +41,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
 
 void common_sampler_free(struct common_sampler * gsmpl);
 
-// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
-void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
+// if is_generated is true, the token is accepted by the sampling chain, the reasoning budget sampler, and the grammar sampler
+void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated);
 
 void common_sampler_reset (struct common_sampler * gsmpl);
 struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
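
Usage sketch (not part of the patch): with this change, common_sampler_init
tokenizes params.generation_prompt once and prefills both the grammar and the
reasoning-budget samplers, while common_sampler_accept advances them only for
generated tokens. A minimal caller, assuming the surrounding common API
(common_params_sampling, common_sampler_sample, llama_decode) is otherwise
unchanged; generate_n and n_predict are illustrative names:

    #include "sampling.h"

    // Sample n_predict tokens with the prefilled sampler state.
    static void generate_n(const llama_model * model, llama_context * ctx,
                           const common_params_sampling & params, int n_predict) {
        // params.generation_prompt carries the text the chat template already
        // placed in the prompt (e.g. an opened reasoning block); init feeds
        // its tokens to the grammar and reasoning-budget samplers.
        common_sampler * smpl = common_sampler_init(model, params);

        for (int i = 0; i < n_predict; i++) {
            const llama_token tok = common_sampler_sample(smpl, ctx, /*idx=*/-1);
            // is_generated = true: advance the sampling chain, the
            // reasoning-budget sampler, and (when applicable) the grammar.
            common_sampler_accept(smpl, tok, /*is_generated=*/true);
            // Re-feeding prompt tokens would pass false instead, so budget
            // and grammar state are not advanced twice.
            // ... llama_decode(ctx, ...) the token, stop on end-of-generation.
        }

        common_sampler_free(smpl);
    }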