From 2e97c5f96f9fe2bb26f794a348e05d7a1c74baa1 Mon Sep 17 00:00:00 2001
From: Tim Neumann
Date: Sun, 10 May 2026 19:12:02 +0200
Subject: [PATCH] backend sampling: support returning post-sampling probs
 (#22622)

* server: Never return 0.0 post-sampling probabilities

* backend sampling: support returning post-sampling probs
---
 common/sampling.cpp                        | 14 +++--
 tools/server/server-context.cpp            | 12 +++-
 tools/server/tests/unit/test_completion.py | 67 +++++++++++++++++++---
 tools/server/tests/utils.py                |  3 +
 4 files changed, 80 insertions(+), 16 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index d4a2fdcdac..5665d0a706 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -547,6 +547,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
+    gsmpl->set_logits(ctx, idx);
+
     // Check if a backend sampler has already sampled a token in which case we
     // return that token id directly.
     {
@@ -558,17 +560,17 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
             GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
             GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");
 
-            // TODO: simplify
-            gsmpl->cur.resize(1);
-            gsmpl->cur[0] = { id, 0.0f, 1.0f };
-            cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
+            for (size_t i = 0; i < cur_p.size; ++i) {
+                if (cur_p.data[i].id == id) {
+                    cur_p.selected = i;
+                    break;
+                }
+            }
 
             return id;
         }
     }
 
-    gsmpl->set_logits(ctx, idx);
-
     // apply reasoning budget first
     llama_sampler_apply(rbudget, &cur_p);
 
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 3f20c94c55..637f8d2164 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -1317,7 +1317,7 @@ private:
             return false;
         }
 
-        const bool need_logits = task.params.sampling.n_probs > 0;
+        const bool need_pre_sample_logits = task.params.sampling.n_probs > 0 && !task.params.post_sampling_probs;
 
         bool backend_sampling = true;
 
@@ -1326,8 +1326,8 @@
         // TODO: speculative decoding requires multiple samples per batch - not supported yet
         backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
 
-        // TODO: getting post/pre sampling logits is not yet supported with backend sampling
-        backend_sampling &= !need_logits;
+        // TODO: getting pre-sampling logits is not yet supported with backend sampling
+        backend_sampling &= !need_pre_sample_logits;
 
         // TODO: tmp until backend sampling is fully implemented
         if (backend_sampling) {
@@ -1504,6 +1504,12 @@
             // set probability for top n_probs tokens
             result.probs.reserve(n_probs);
             for (size_t i = 0; i < n_probs; i++) {
+                // Some samplers return 0.0 probabilities and others do not.
+                // Filter out 0.0 probabilities so the behavior is consistent.
+                if (cur_p->data[i].p == 0.0) {
+                    break;
+                }
+
                 result.probs.push_back({
                     cur_p->data[i].id,
                     common_token_to_piece(ctx, cur_p->data[i].id, special),
diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py
index c1a1978543..1e0891987a 100644
--- a/tools/server/tests/unit/test_completion.py
+++ b/tools/server/tests/unit/test_completion.py
@@ -491,29 +491,82 @@ def test_n_probs_post_sampling():
     global server
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "prompt": "I believe the meaning of life is",
+        "prompt": "Today was the day. Today I would finally become a",
         "n_probs": 10,
-        "temperature": 0.0,
+        "temperature": 1.0,
         "n_predict": 5,
         "post_sampling_probs": True,
     })
     assert res.status_code == 200
     assert "completion_probabilities" in res.body
     assert len(res.body["completion_probabilities"]) == 5
-    for tok in res.body["completion_probabilities"]:
+    for (i, tok) in enumerate(res.body["completion_probabilities"]):
         assert "id" in tok and tok["id"] > 0
         assert "token" in tok and type(tok["token"]) == str
         assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
         assert "bytes" in tok and type(tok["bytes"]) == list
-        assert len(tok["top_probs"]) == 10
+        assert "top_probs" in tok and type(tok["top_probs"]) == list
+
         for prob in tok["top_probs"]:
             assert "id" in prob and prob["id"] > 0
             assert "token" in prob and type(prob["token"]) == str
-            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
+            # 0.0 probability tokens should never be returned by the server
+            assert "prob" in prob and 0.0 < prob["prob"] <= 1.0
             assert "bytes" in prob and type(prob["bytes"]) == list
-        # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
-        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
+        if i == 0:
+            # The prompt is vague enough that we should get at least 10 possibilities
+            # for the first token.
+            assert len(tok["top_probs"]) == 10
+
+        if len(tok["top_probs"]) < 10:
+            # Getting fewer than the requested number of probabilities should only happen
+            # if the ones we did get already sum to 1.0.
+            assert sum(p["prob"] for p in tok["top_probs"]) == pytest.approx(1.0)
 
+
+def test_n_probs_post_backend_sampling():
+    """Verify that the same probabilities are returned with and without backend sampling."""
+    global server
+    server.backend_sampling = True
+    server.start()
+
+    def make_request(backend_sampling):
+        n_predict = 20
+
+        res = server.make_request("POST", "/completion", data={
+            "prompt": "The countries of Europe, in random order, are:",
+            "n_probs": 10,
+            "n_predict": n_predict,
+            "post_sampling_probs": True,
+            "seed": 4242,
+            "backend_sampling": backend_sampling,
+        })
+        assert res.status_code == 200
+
+        total_probs = 0
+        completions = res.body["completion_probabilities"]
+        assert len(completions) == n_predict
+        for tok in completions:
+            # Handling of 0.0 probabilities differs between samplers and backend sampling.
+            # Filter them out to normalize the data.
+            tok["top_probs"] = [x for x in tok["top_probs"] if x["prob"] > 0.0]
+            total_probs += len(tok["top_probs"])
+        # Verify that we got at least two top probs per token on average, to ensure the test is meaningful.
+        assert total_probs >= 2 * n_predict
+        return completions
+
+    def verify_token(a, b):
+        assert a["id"] == b["id"]
+        assert a["token"] == b["token"]
+        assert a["bytes"] == b["bytes"]
+        assert a["prob"] == pytest.approx(b["prob"], abs=0.01)
+
+    for (a, b) in zip(make_request(True), make_request(False)):
+        verify_token(a, b)
+        assert len(a["top_probs"]) == len(b["top_probs"])
+
+        for (aa, bb) in zip(a["top_probs"], b["top_probs"]):
+            verify_token(aa, bb)
 
 @pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
 def test_logit_bias(tokenize, openai_style):
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index ce93903872..c5dba1c139 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -108,6 +108,7 @@ class ServerProcess:
     no_cache_idle_slots: bool = False
     log_path: str | None = None
     webui_mcp_proxy: bool = False
+    backend_sampling: bool = False
    gcp_compat: bool = False
 
     # session variables
@@ -252,6 +253,8 @@ class ServerProcess:
             server_args.append("--no-cache-idle-slots")
         if self.webui_mcp_proxy:
             server_args.append("--webui-mcp-proxy")
+        if self.backend_sampling:
+            server_args.append("--backend_sampling")
         if self.gcp_compat:
             env["AIP_MODE"] = "PREDICTION"
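Below, separate from the patch itself, is a minimal usage sketch of how a client could request post-sampling probabilities from a running server. It only uses request and response fields exercised by the tests above; the host/port and the use of the requests package are assumptions.

import requests  # assumption: the requests package is installed

# Assumption: llama-server is already running locally (default port 8080) with a model loaded.
res = requests.post("http://localhost:8080/completion", json={
    "prompt": "The countries of Europe, in random order, are:",
    "n_predict": 5,
    "n_probs": 10,
    "post_sampling_probs": True,  # ask for probabilities as observed after sampling
})
res.raise_for_status()

for tok in res.json()["completion_probabilities"]:
    # Each entry holds the sampled token, its post-sampling probability, and up to
    # n_probs alternatives in "top_probs"; 0.0-probability entries are filtered out.
    print(tok["token"], tok["prob"], len(tok["top_probs"]))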