Files
llama.cpp/tools/server/tests/unit/test_chat_completion.py
Pascal 95d469a915 server, webui: accept continue_final_message flag for vLLM API compat (#23012)
* server, webui: accept continue_final_message flag for vLLM API compat

Add the continue_final_message request body flag from the vLLM and
transformers APIs. When set to true together with add_generation_prompt
false, it triggers the existing prefill_assistant code path, regardless of
the server-side opt.prefill_assistant option. Setting it together with
add_generation_prompt true is rejected (mutual exclusion), matching vLLM
behavior.

The WebUI sends continue_final_message true and add_generation_prompt
false when the Continue button is used, via a matching opt-in option on
the chat service.

Pure API alignment, no change to the prefill logic itself. Paves the way
for the upcoming per-template prefill plumbing in common/chat.

* test: add coverage for continue_final_message vLLM compat flag

Adds two cases on top of the existing assistant prefill coverage. First,
continue_final_message true with add_generation_prompt false produces
the same rendered prompt as the prefill_assistant heuristic, proving
the new flag is a correct alias of the existing path. Second, setting
both flags to true is rejected with HTTP 400, matching the
vLLM/transformers mutual-exclusion contract.

* chore: update webui build output
2026-05-13 20:47:58 +02:00

574 lines
23 KiB
Python

import pytest
from openai import OpenAI
from utils import *
server: ServerProcess  # (re)assigned by the autouse fixture below before every test


@pytest.fixture(autouse=True)
def create_server():
    """Reset the module-level ``server`` to a fresh tinyllama2 preset for each test.

    The preset is only configured here; each test adjusts its settings and
    calls ``server.start()`` itself.
    """
    global server
    server = ServerPreset.tinyllama2()
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
    [
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
        (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
    """Non-streamed /chat/completions: check id, fingerprint, usage counts, content and finish reason."""
    global server
    server.jinja = jinja
    server.chat_template = chat_template
    server.start()
    payload = {
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    }
    res = server.make_request("POST", "/chat/completions", data=payload)
    assert res.status_code == 200
    body = res.body
    # the completion id must carry the expected "cmpl" marker
    assert "cmpl" in body["id"]
    assert body["system_fingerprint"].startswith("b")
    # the model name is no longer echoed back, see https://github.com/ggml-org/llama.cpp/pull/17668
    # assert body["model"] == model if model is not None else server.model_alias
    usage = body["usage"]
    assert usage["prompt_tokens"] == n_prompt
    assert usage["completion_tokens"] == n_predicted
    choice = body["choices"][0]
    message = choice["message"]
    assert message["role"] == "assistant"
    assert match_regex(re_content, message["content"]), f'Expected {re_content}, got {message["content"]}'
    assert choice["finish_reason"] == finish_reason
def test_chat_completion_cached_tokens():
    """Prompt-cache reuse across consecutive requests on a single slot.

    Each step sends the same system prompt with a user prompt that diverges
    from the previous one progressively earlier, and checks how many prompt
    tokens the server reports as served from cache.
    """
    global server
    server.n_slots = 1
    server.start()
    # (user prompt, expected prompt_tokens, expected cached_tokens)
    seq = [
        ("1 2 3 4 5 6", 77, 0),
        ("1 2 3 4 5 6", 77, 76),
        ("1 2 3 4 5 9", 77, 51),
        ("1 2 3 9 9 9", 77, 47),
    ]
    for user_prompt, n_prompt, n_cache in seq:
        res = server.make_request("POST", "/chat/completions", data={
            "max_tokens": 8,
            "messages": [
                {"role": "system", "content": "Test"},
                {"role": "user", "content": user_prompt},
            ],
        })
        # fail with the server's error body instead of an opaque KeyError on "usage"
        assert res.status_code == 200, res.body
        assert res.body["usage"]["prompt_tokens"] == n_prompt
        assert res.body["usage"]["prompt_tokens_details"]["cached_tokens"] == n_cache
@pytest.mark.parametrize(
    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
    ]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
    """Streamed /chat/completions: validate every chunk, the accumulated content and the usage chunk.

    The content-regex / finish_reason checks only run on the chunk carrying a
    finish reason, and the usage checks only run on the chunk with an empty
    ``choices`` list. Both chunks are therefore required explicitly at the end,
    so the test cannot pass vacuously if either is missing.
    """
    global server
    server.model_alias = "llama-test-model"
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
        # ask for the trailing usage chunk (empty "choices") explicitly
        "stream_options": {"include_usage": True},
    })
    content = ""
    last_cmpl_id = None
    finish_seen = False
    usage_seen = False
    for i, data in enumerate(res):
        if data["choices"]:
            choice = data["choices"][0]
            if i == 0:
                # Check first role message for stream=True
                assert choice["delta"]["content"] is None
                assert choice["delta"]["role"] == "assistant"
            else:
                assert "role" not in choice["delta"]
            assert data["system_fingerprint"].startswith("b")
            assert data["model"] == "llama-test-model"
            if last_cmpl_id is None:
                last_cmpl_id = data["id"]
            assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
            if choice["finish_reason"] in ["stop", "length"]:
                assert "content" not in choice["delta"]
                assert match_regex(re_content, content)
                assert choice["finish_reason"] == finish_reason
                finish_seen = True
            else:
                assert choice["finish_reason"] is None
                content += choice["delta"]["content"] or ''
        else:
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted
            usage_seen = True
    assert finish_seen, "stream ended without a chunk carrying finish_reason"
    assert usage_seen, "stream ended without a usage chunk"
def test_chat_completion_with_openai_library():
    """The official OpenAI client can drive the server's /v1 chat endpoint end to end."""
    global server
    server.start()
    base_url = f"http://{server.server_host}:{server.server_port}/v1"
    client = OpenAI(api_key="dummy", base_url=base_url)
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    fingerprint = res.system_fingerprint
    assert fingerprint is not None and fingerprint.startswith("b")
    first = res.choices[0]
    assert first.finish_reason == "length"
    text = first.message.content
    assert text is not None
    assert match_regex("(Suddenly)+", text)
def test_chat_template():
    """Forcing the llama3 template renders the expected raw prompt (read back via ``__verbose``)."""
    global server
    server.chat_template = "llama3"
    server.debug = True  # debug mode adds the "__verbose" object to the response
    server.start()
    messages = [
        {"role": "system", "content": "Book"},
        {"role": "user", "content": "What is the best book"},
    ]
    res = server.make_request("POST", "/chat/completions", data={"max_tokens": 8, "messages": messages})
    assert res.status_code == 200
    assert "__verbose" in res.body
    expected_prompt = "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    assert res.body["__verbose"]["prompt"] == expected_prompt
@pytest.mark.parametrize("prefill,re_prefill", [
    ("Whill", "Whill"),
    ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
])
def test_chat_template_assistant_prefill(prefill, re_prefill):
    """A trailing assistant message is rendered as an open prefix right after the assistant header.

    Both a plain string and a list of text parts are accepted as the prefill content.
    """
    global server
    server.chat_template = "llama3"
    server.debug = True  # debug mode adds the "__verbose" object to the response
    server.start()
    conversation = [
        {"role": "system", "content": "Book"},
        {"role": "user", "content": "What is the best book"},
        {"role": "assistant", "content": prefill},
    ]
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": conversation,
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    rendered_head = "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    assert res.body["__verbose"]["prompt"] == rendered_head + re_prefill
def test_chat_template_continue_final_message_vllm_compat():
    """continue_final_message is the vLLM/transformers explicit alias for the prefill_assistant heuristic.

    With ``add_generation_prompt=False`` and ``continue_final_message=True`` the
    final assistant message must be rendered as an open prefix. The prompt must
    also be identical to the one produced for the same messages when neither
    flag is sent (the implicit prefill path exercised by
    test_chat_template_assistant_prefill with the default server settings).
    """
    global server
    server.chat_template = "llama3"
    server.debug = True  # debug mode adds the "__verbose" object to the response
    server.start()
    messages = [
        {"role": "system", "content": "Book"},
        {"role": "user", "content": "What is the best book"},
        {"role": "assistant", "content": "Whill"},
    ]
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "add_generation_prompt": False,
        "continue_final_message": True,
        "messages": messages,
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    explicit_prompt = res.body["__verbose"]["prompt"]
    assert explicit_prompt == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nWhill"

    # Render the same conversation through the implicit heuristic path, so the
    # test proves the two paths agree instead of only matching a literal.
    res_implicit = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": messages,
    })
    assert res_implicit.status_code == 200
    assert "__verbose" in res_implicit.body
    assert res_implicit.body["__verbose"]["prompt"] == explicit_prompt
def test_chat_template_continue_final_message_mutual_exclusion():
    """add_generation_prompt and continue_final_message both set to true must be rejected"""
    global server
    server.chat_template = "llama3"
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "add_generation_prompt": True,
        "continue_final_message": True,
        "messages": [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello"},
        ]
    })
    assert res.status_code == 400
    # like the other rejection tests in this file, the 400 must carry an error object
    assert "error" in res.body
def test_apply_chat_template():
    """/apply-template renders the messages with the command-r template and returns the prompt."""
    global server
    server.chat_template = "command-r"
    server.start()
    messages = [
        {"role": "system", "content": "You are a test."},
        {"role": "user", "content": "Hi there"},
    ]
    res = server.make_request("POST", "/apply-template", data={"messages": messages})
    assert res.status_code == 200
    assert "prompt" in res.body
    expected = "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
    assert res.body["prompt"] == expected
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
    ({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
    ({"type": "json_object"}, 10, "(\\{|John)+"),
    # unknown type or invalid schema: expected to fail (re_content is None)
    ({"type": "sound"}, 0, None),
    ({"type": "json_object", "schema": 123}, 0, None),
    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
    """response_format constrains the output; unknown or malformed formats must be refused with 400."""
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "response_format": response_format,
    })
    if re_content is None:
        # negative case: the request must be rejected with an error payload
        assert res.status_code == 400
        assert "error" in res.body
        return
    assert res.status_code == 200
    content = res.body["choices"][0]["message"]["content"]
    assert match_regex(re_content, content)
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
    (False, {"const": "42"}, 6, "\"42\""),
    (True, {"const": "42"}, 6, "\"42\""),
])
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
    """The top-level ``json_schema`` field constrains the output, with and without jinja templating."""
    global server
    server.jinja = jinja
    server.debug = True
    server.start()
    request_body = {
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "json_schema": json_schema,
    }
    res = server.make_request("POST", "/chat/completions", data=request_body)
    assert res.status_code == 200, f'Expected 200, got {res.status_code}'
    content = res.body["choices"][0]["message"]["content"]
    assert match_regex(re_content, content), f'Expected {re_content}, got {content}'
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
    (False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
    (True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
])
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
    """A GBNF ``grammar`` field forces the output regardless of the user message."""
    global server
    server.jinja = jinja
    server.start()
    request_body = {
        "max_tokens": n_predicted,
        "messages": [
            {"role": "user", "content": "Does not matter what I say, does it?"},
        ],
        "grammar": grammar,
    }
    res = server.make_request("POST", "/chat/completions", data=request_body)
    assert res.status_code == 200, res.body
    content = res.body["choices"][0]["message"]["content"]
    assert match_regex(re_content, content), content
@pytest.mark.parametrize("messages", [
    None,
    "string",
    [123],
    [{}],
    [{"role": 123}],
    [{"role": "system", "content": 123}],
    # [{"content": "hello"}], # TODO: should not be a valid case
    [{"role": "system", "content": "test"}, {}],
    [{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
])
def test_invalid_chat_completion_req(messages):
    """Malformed ``messages`` payloads must be refused with an error object (400 or 500)."""
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={"messages": messages})
    assert res.status_code in (400, 500)
    assert "error" in res.body
def test_chat_completion_with_timings_per_token():
    """With timings_per_token, the trailing stats chunk carries timings and the first chunk does not."""
    global server
    server.start()
    stream = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "stream_options": {"include_usage": True},
        "timings_per_token": True,
    })
    stats_received = False
    for idx, chunk in enumerate(stream):
        if idx == 0:
            # the first event only announces the assistant role
            first_delta = chunk["choices"][0]["delta"]
            assert first_delta["content"] is None
            assert first_delta["role"] == "assistant"
            assert "timings" not in chunk, f'First event should not have timings: {chunk}'
            continue
        if chunk["choices"]:
            assert "role" not in chunk["choices"][0]["delta"]
            continue
        # a chunk without choices is the stats chunk
        assert "timings" in chunk
        timings = chunk["timings"]
        assert "prompt_per_second" in timings
        assert "predicted_per_second" in timings
        assert "predicted_n" in timings
        assert timings["predicted_n"] <= 10
        stats_received = True
    assert stats_received
def test_logprobs():
    """Non-streamed logprobs: the per-token pieces must concatenate back to the message text."""
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
    )
    choice = res.choices[0]
    output_text = choice.message.content
    assert choice.logprobs is not None
    assert choice.logprobs.content is not None
    pieces = []
    for token in choice.logprobs.content:
        pieces.append(token.token)
        # a log-probability can never exceed log(1) == 0
        assert token.logprob <= 0.0
        assert token.bytes is not None
        assert len(token.top_logprobs) > 0
    assert "".join(pieces) == output_text
def test_logprobs_stream():
    """Streamed logprobs: every content chunk carries logprobs whose tokens rebuild the streamed text."""
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
        stream=True,
    )
    output_text = ''
    aggregated_text = ''
    for idx, chunk in enumerate(res):
        if not chunk.choices:
            continue
        choice = chunk.choices[0]
        if idx == 0:
            # the first chunk only announces the assistant role
            assert choice.delta.content is None
            assert choice.delta.role == "assistant"
            continue
        assert choice.delta.role is None
        if choice.finish_reason is not None:
            # the closing chunk is not expected to carry logprobs
            continue
        if choice.delta.content:
            output_text += choice.delta.content
        assert choice.logprobs is not None
        assert choice.logprobs.content is not None
        for token in choice.logprobs.content:
            aggregated_text += token.token
            assert token.logprob <= 0.0
            assert token.bytes is not None
            assert token.top_logprobs is not None
            assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text
def test_logit_bias():
    """A -100 logit_bias on the tokens of common words keeps those words out of the output."""
    global server
    server.start()
    banned_words = [
        "i", "I", "the", "The", "to", "a", "an", "be", "is", "was",
        "but", "But", "and", "And", "so", "So", "you", "You", "he", "He",
        "she", "She", "we", "We", "they", "They", "it", "It", "his", "His",
        "her", "Her", "book", "Book",
    ]
    tokenized = server.make_request("POST", "/tokenize", data={
        "content": f" {' '.join(banned_words)} ",
    })
    assert tokenized.status_code == 200
    logit_bias = dict.fromkeys(tokenized.body["tokens"], -100)
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=64,
        logit_bias=logit_bias
    )
    output_text = res.choices[0].message.content
    assert output_text
    # no banned word may appear as a standalone, space-delimited word
    assert not any(f" {word} " in output_text for word in banned_words)
def test_context_size_exceeded():
    """An over-long prompt is rejected with a structured exceed_context_size_error."""
    global server
    server.start()
    conversation = [
        {"role": "system", "content": "Book"},
        {"role": "user", "content": "What is the best book"},
    ] * 100  # repeated to make the prompt too long
    res = server.make_request("POST", "/chat/completions", data={"messages": conversation})
    assert res.status_code == 400
    assert "error" in res.body
    err = res.body["error"]
    assert err["type"] == "exceed_context_size_error"
    assert err["n_prompt_tokens"] > 0
    assert server.n_ctx is not None
    assert server.n_slots is not None
    # the reported limit is the per-slot share of the total context
    assert err["n_ctx"] == server.n_ctx // server.n_slots
def test_context_size_exceeded_stream():
    """Streaming variant of test_context_size_exceeded: the request must fail with a ServerError.

    Uses ``pytest.raises`` instead of ``assert False`` inside ``try``; the
    latter is stripped under ``python -O`` and is flagged by ruff (B011).
    """
    global server
    server.start()
    with pytest.raises(ServerError) as exc_info:
        # the error may only surface once the stream is consumed, so drain it here
        for _ in server.make_stream_request("POST", "/chat/completions", data={
            "messages": [
                {"role": "system", "content": "Book"},
                {"role": "user", "content": "What is the best book"},
            ] * 100,  # make the prompt too long
            "stream": True,
        }):
            pass
    e = exc_info.value
    assert e.code == 400
    assert "error" in e.body
    assert e.body["error"]["type"] == "exceed_context_size_error"
    assert e.body["error"]["n_prompt_tokens"] > 0
    assert server.n_ctx is not None
    assert server.n_slots is not None
    assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
@pytest.mark.parametrize(
    "n_batch,batch_count,reuse_cache",
    [
        (64, 4, False),
        (64, 2, True),
    ]
)
def test_return_progress(n_batch, batch_count, reuse_cache):
    """return_progress streams prompt_progress events; their number must equal ``batch_count``.

    All events share the same total and cache counts, the processed count grows
    strictly, and the last event reports every prompt token as processed.
    """
    global server
    server.n_batch = n_batch
    server.n_ctx = 256
    server.n_slots = 1
    server.start()

    def make_cmpl_request():
        return server.make_stream_request("POST", "/chat/completions", data={
            "max_tokens": 10,
            "messages": [
                {"role": "user", "content": "This is a test" * 10},
            ],
            "stream": True,
            "return_progress": True,
        })

    if reuse_cache:
        # warm the cache with an identical request; its output is discarded
        for _ in make_cmpl_request():
            pass

    progress_events = [
        chunk["prompt_progress"]
        for chunk in make_cmpl_request()
        if chunk.get("prompt_progress", None) is not None
    ]
    assert progress_events

    first = progress_events[0]
    # first progress report must have n_cache == n_processed
    assert first["total"] > 0
    assert first["cache"] == first["processed"]
    if reuse_cache:
        # when reusing cache, we expect some cached tokens
        assert first["cache"] > 0

    for prev, cur in zip(progress_events, progress_events[1:]):
        assert cur["total"] == prev["total"]
        assert cur["cache"] == prev["cache"]
        assert cur["processed"] > prev["processed"]

    # last progress should indicate completion (all tokens processed)
    last = progress_events[-1]
    assert last["total"] > 0
    assert last["processed"] == last["total"]
    assert len(progress_events) == batch_count
def test_chat_completions_multiple_choices():
    """n=2 yields two choices, repeatedly, even with every request pinned to slot 0."""
    global server
    server.start()
    # make sure cache can be reused across multiple choices and multiple requests
    # ref: https://github.com/ggml-org/llama.cpp/pull/18663
    for _attempt in range(2):
        res = server.make_request("POST", "/chat/completions", data={
            "max_tokens": 8,
            "n": 2,
            "messages": [
                {"role": "system", "content": "Book"},
                {"role": "user", "content": "What is the best book"},
            ],
            # pin the request to one slot; the scheduler must not lock up in this case
            "id_slot": 0,
        })
        assert res.status_code == 200
        choices = res.body["choices"]
        assert len(choices) == 2
        for choice in choices:
            assert choice["message"]["role"] == "assistant"
            assert choice["finish_reason"] == "length"