mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-15 05:24:06 +00:00
* server, webui: accept continue_final_message flag for vLLM API compat Add the continue_final_message body flag from the vLLM and transformers API. When set together with add_generation_prompt false, it triggers the existing prefill_assistant code path, regardless of the server side opt.prefill_assistant option. Mutual exclusion with add_generation_prompt true is enforced, matching vLLM behavior. WebUI sends continue_final_message and add_generation_prompt false on the Continue button, with the matching opt in option on the chat service. Pure API alignment, no change to the prefill logic itself. Paves the way for the upcoming per-template prefill plumbing in common/chat. * test: add coverage for continue_final_message vLLM compat flag Two cases on top of the existing assistant prefill coverage. First, continue_final_message true with add_generation_prompt false produces the same rendered prompt as the prefill_assistant heuristic, proving the new flag is a correct alias of the existing path. Second, both flags set to true is rejected with HTTP 400, matching the vLLM/transformers mutual exclusion contract. * chore: update webui build output
574 lines
23 KiB
Python
import pytest
|
|
from openai import OpenAI
|
|
from utils import *
|
|
|
|
# Shared server handle used by every test in this module; recreated by the
# autouse fixture below so tests never see each other's configuration.
server: ServerProcess


@pytest.fixture(autouse=True)
def create_server():
    """Reset the global ``server`` to the tinyllama2 preset before each test."""
    global server
    server = ServerPreset.tinyllama2()
|
|
|
|
|
|
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
    [
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
        (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
        # multi-part ("content parts") user messages
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
    """Exercise non-streaming /chat/completions over jinja on/off and several
    chat templates, pinning exact prompt/completion token counts and a regex
    on the generated content."""
    global server
    server.jinja = jinja
    server.chat_template = chat_template
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    })
    assert res.status_code == 200
    assert "cmpl" in res.body["id"]  # make sure the completion id has the expected format
    assert res.body["system_fingerprint"].startswith("b")
    # we no longer reflect back the model name, see https://github.com/ggml-org/llama.cpp/pull/17668
    # assert res.body["model"] == model if model is not None else server.model_alias
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
    choice = res.body["choices"][0]
    assert "assistant" == choice["message"]["role"]
    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
    assert choice["finish_reason"] == finish_reason
|
|
|
|
|
|
def test_chat_completion_cached_tokens():
    """Requests sharing a prompt prefix must report the expected
    ``cached_tokens`` counts (single slot keeps the cache deterministic)."""
    global server
    server.n_slots = 1
    server.start()
    # (user prompt, expected prompt_tokens, expected cached_tokens)
    cases = [
        ("1 2 3 4 5 6", 77, 0),
        ("1 2 3 4 5 6", 77, 76),
        ("1 2 3 4 5 9", 77, 51),
        ("1 2 3 9 9 9", 77, 47),
    ]
    for prompt_text, expected_prompt_tokens, expected_cached in cases:
        reply = server.make_request("POST", "/chat/completions", data={
            "max_tokens": 8,
            "messages": [
                {"role": "system", "content": "Test"},
                {"role": "user", "content": prompt_text},
            ],
        })
        usage = reply.body["usage"]
        assert usage["prompt_tokens"] == expected_prompt_tokens
        assert usage["prompt_tokens_details"]["cached_tokens"] == expected_cached
|
|
|
|
@pytest.mark.parametrize(
    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
    ]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
    """Validate the streamed event sequence: role-only first delta, a stable
    completion id across events, finish_reason on the final content event,
    and a trailing usage-only event (no choices)."""
    global server
    server.model_alias = "llama-test-model"
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })
    content = ""
    last_cmpl_id = None
    for i, data in enumerate(res):
        if data["choices"]:
            choice = data["choices"][0]
            if i == 0:
                # Check first role message for stream=True
                assert choice["delta"]["content"] is None
                assert choice["delta"]["role"] == "assistant"
            else:
                # the role is only sent once, on the first event
                assert "role" not in choice["delta"]
                assert data["system_fingerprint"].startswith("b")
                assert data["model"] == "llama-test-model"
                if last_cmpl_id is None:
                    last_cmpl_id = data["id"]
                assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
                if choice["finish_reason"] in ["stop", "length"]:
                    # final content event: no more delta text, accumulated
                    # content must match the expected regex
                    assert "content" not in choice["delta"]
                    assert match_regex(re_content, content)
                    assert choice["finish_reason"] == finish_reason
                else:
                    assert choice["finish_reason"] is None
                    content += choice["delta"]["content"] or ''
        else:
            # an event without choices carries the usage statistics
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted
|
|
|
|
|
|
def test_chat_completion_with_openai_library():
    """Smoke-test the /v1 endpoint through the official OpenAI client."""
    global server
    server.start()
    base_url = f"http://{server.server_host}:{server.server_port}/v1"
    client = OpenAI(api_key="dummy", base_url=base_url)
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    fingerprint = completion.system_fingerprint
    assert fingerprint is not None and fingerprint.startswith("b")
    first_choice = completion.choices[0]
    assert first_choice.finish_reason == "length"
    assert first_choice.message.content is not None
    assert match_regex("(Suddenly)+", first_choice.message.content)
|
|
|
|
|
|
def test_chat_template():
    """Forcing the built-in llama3 template must produce the exact expected
    rendered prompt (visible via the debug "__verbose" payload)."""
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    payload = {
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ]
    }
    response = server.make_request("POST", "/chat/completions", data=payload)
    assert response.status_code == 200
    assert "__verbose" in response.body
    expected_prompt = "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    assert response.body["__verbose"]["prompt"] == expected_prompt
|
|
|
|
|
|
@pytest.mark.parametrize("prefill,re_prefill", [
    ("Whill", "Whill"),
    # content-parts form must be concatenated into the same prefill text
    ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
])
def test_chat_template_assistant_prefill(prefill, re_prefill):
    """A trailing assistant message (string or content parts) is rendered
    verbatim at the end of the prompt with no end-of-turn token after it,
    so generation continues the assistant's text."""
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
            {"role": "assistant", "content": prefill},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
|
|
|
|
|
|
def test_chat_template_continue_final_message_vllm_compat():
    """continue_final_message is the vLLM/transformers explicit alias for the prefill_assistant heuristic.

    Both must produce the same prompt: the trailing assistant message is kept
    open (no end-of-turn token) so generation continues it.
    """
    global server
    server.chat_template = "llama3"
    server.debug = True
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        # explicit vLLM-style flags instead of relying on the server heuristic
        "add_generation_prompt": False,
        "continue_final_message": True,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
            {"role": "assistant", "content": "Whill"},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    # identical rendering to test_chat_template_assistant_prefill above
    assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nWhill"
|
|
|
|
|
|
def test_chat_template_continue_final_message_mutual_exclusion():
    """add_generation_prompt and continue_final_message both set to true must be rejected"""
    global server
    server.chat_template = "llama3"
    server.start()
    body = {
        "max_tokens": 8,
        "add_generation_prompt": True,
        "continue_final_message": True,
        "messages": [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello"},
        ]
    }
    response = server.make_request("POST", "/chat/completions", data=body)
    # the two flags are contradictory, so the request must fail validation
    assert response.status_code == 400
|
|
|
|
|
|
def test_apply_chat_template():
    """/apply-template renders the messages through the configured template
    and returns the prompt without running any generation."""
    global server
    server.chat_template = "command-r"
    server.start()
    response = server.make_request("POST", "/apply-template", data={
        "messages": [
            {"role": "system", "content": "You are a test."},
            {"role": "user", "content": "Hi there"},
        ]
    })
    assert response.status_code == 200
    assert "prompt" in response.body
    rendered = response.body["prompt"]
    assert rendered == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
|
|
|
|
|
|
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
    ({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
    ({"type": "json_object"}, 10, "(\\{|John)+"),
    ({"type": "sound"}, 0, None),
    # invalid response format (expected to fail)
    ({"type": "json_object", "schema": 123}, 0, None),
    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
    """OpenAI-style ``response_format`` constrains output when valid; invalid
    formats must be rejected with HTTP 400 (``re_content is None`` marks the
    failure cases)."""
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "response_format": response_format,
    })
    if re_content is not None:
        assert res.status_code == 200
        choice = res.body["choices"][0]
        assert match_regex(re_content, choice["message"]["content"])
    else:
        assert res.status_code == 400
        assert "error" in res.body
|
|
|
|
|
|
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
    (False, {"const": "42"}, 6, "\"42\""),
    (True, {"const": "42"}, 6, "\"42\""),
])
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
    """A top-level ``json_schema`` field must constrain the output, with and
    without jinja templating."""
    global server
    server.jinja = jinja
    server.debug = True
    server.start()
    payload = {
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "json_schema": json_schema,
    }
    res = server.make_request("POST", "/chat/completions", data=payload)
    assert res.status_code == 200, f'Expected 200, got {res.status_code}'
    answer = res.body["choices"][0]["message"]["content"]
    assert match_regex(re_content, answer), f'Expected {re_content}, got {answer}'
|
|
|
|
|
|
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
    (False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
    (True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
])
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
    """A GBNF ``grammar`` must constrain the generated text regardless of the
    jinja setting."""
    global server
    server.jinja = jinja
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "user", "content": "Does not matter what I say, does it?"},
        ],
        "grammar": grammar,
    })
    assert res.status_code == 200, res.body
    generated = res.body["choices"][0]["message"]["content"]
    assert match_regex(re_content, generated), generated
|
|
|
|
|
|
@pytest.mark.parametrize("messages", [
    None,
    "string",
    [123],
    [{}],
    [{"role": 123}],
    [{"role": "system", "content": 123}],
    # [{"content": "hello"}], # TODO: should not be a valid case
    [{"role": "system", "content": "test"}, {}],
    [{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
])
def test_invalid_chat_completion_req(messages):
    """Malformed ``messages`` payloads must be rejected with an error body.

    The test accepts either 400 or 500 — presumably depending on where the
    validation trips server-side (TODO: confirm and tighten to 400 only).
    """
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": messages,
    })
    # idiomatic membership test instead of a chained `or` comparison
    assert res.status_code in (400, 500)
    assert "error" in res.body
|
|
|
|
|
|
def test_chat_completion_with_timings_per_token():
    """With ``timings_per_token`` enabled, streamed events after the first
    carry timing data, and the final usage event reports aggregate timings."""
    global server
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "stream_options": {"include_usage": True},
        "timings_per_token": True,
    })
    stats_received = False
    for i, data in enumerate(res):
        if i == 0:
            # Check first role message for stream=True
            assert data["choices"][0]["delta"]["content"] is None
            assert data["choices"][0]["delta"]["role"] == "assistant"
            assert "timings" not in data, f'First event should not have timings: {data}'
        else:
            if data["choices"]:
                assert "role" not in data["choices"][0]["delta"]
            else:
                # final event (no choices) must expose the timing stats
                assert "timings" in data
                assert "prompt_per_second" in data["timings"]
                assert "predicted_per_second" in data["timings"]
                assert "predicted_n" in data["timings"]
                assert data["timings"]["predicted_n"] <= 10
                stats_received = True
    assert stats_received
|
|
|
|
|
|
def test_logprobs():
    """Non-streaming logprobs: every token must carry logprob/bytes/top_logprobs,
    and concatenating the tokens must reproduce the message content exactly."""
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
    )
    output_text = res.choices[0].message.content
    aggregated_text = ''
    assert res.choices[0].logprobs is not None
    assert res.choices[0].logprobs.content is not None
    for token in res.choices[0].logprobs.content:
        aggregated_text += token.token
        assert token.logprob <= 0.0  # log-probabilities are never positive
        assert token.bytes is not None
        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text
|
|
|
|
|
|
def test_logprobs_stream():
    """Streaming variant of test_logprobs: each content delta must carry
    per-token logprobs, and the tokens must concatenate to the streamed text."""
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
        stream=True,
    )
    output_text = ''
    aggregated_text = ''
    for i, data in enumerate(res):
        if data.choices:
            choice = data.choices[0]
            if i == 0:
                # Check first role message for stream=True
                assert choice.delta.content is None
                assert choice.delta.role == "assistant"
            else:
                assert choice.delta.role is None
                if choice.finish_reason is None:
                    if choice.delta.content:
                        output_text += choice.delta.content
                    assert choice.logprobs is not None
                    assert choice.logprobs.content is not None
                    for token in choice.logprobs.content:
                        aggregated_text += token.token
                        assert token.logprob <= 0.0  # log-probabilities are never positive
                        assert token.bytes is not None
                        assert token.top_logprobs is not None
                        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text
|
|
|
|
|
|
def test_logit_bias():
    """Applying a -100 logit bias to the tokens of common words must keep
    those words out of the generated text."""
    global server
    server.start()

    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]

    # tokenize the words (space-separated, with surrounding spaces) to obtain
    # the token ids that should be suppressed
    res = server.make_request("POST", "/tokenize", data={
        "content": " " + " ".join(exclude) + " ",
    })
    assert res.status_code == 200
    tokens = res.body["tokens"]
    # -100 is the strongest negative bias in the OpenAI logit_bias contract
    logit_bias = {tok: -100 for tok in tokens}

    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=64,
        logit_bias=logit_bias
    )
    output_text = res.choices[0].message.content
    assert output_text
    # none of the banned words may appear as a standalone word
    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
|
|
|
|
def test_context_size_exceeded():
    """An over-long prompt must be rejected with a structured
    exceed_context_size_error describing the per-slot context limit."""
    global server
    server.start()
    oversized_messages = [
        {"role": "system", "content": "Book"},
        {"role": "user", "content": "What is the best book"},
    ] * 100  # make the prompt too long
    res = server.make_request("POST", "/chat/completions", data={
        "messages": oversized_messages,
    })
    assert res.status_code == 400
    assert "error" in res.body
    err = res.body["error"]
    assert err["type"] == "exceed_context_size_error"
    assert err["n_prompt_tokens"] > 0
    assert server.n_ctx is not None
    assert server.n_slots is not None
    # the reported limit is the per-slot share of the total context
    assert err["n_ctx"] == server.n_ctx // server.n_slots
|
|
|
|
|
|
def test_context_size_exceeded_stream():
    """Streaming variant of test_context_size_exceeded: consuming the stream
    must raise ServerError with the same structured context-size error.

    Uses pytest.raises instead of the try/except + `assert False` pattern
    (bare `assert False` is stripped under `python -O` and is the
    non-idiomatic way to require an exception).
    """
    global server
    server.start()
    with pytest.raises(ServerError) as exc_info:
        for _ in server.make_stream_request("POST", "/chat/completions", data={
            "messages": [
                {"role": "system", "content": "Book"},
                {"role": "user", "content": "What is the best book"},
            ] * 100,  # make the prompt too long
            "stream": True}):
            pass
    e = exc_info.value
    assert e.code == 400
    assert "error" in e.body
    assert e.body["error"]["type"] == "exceed_context_size_error"
    assert e.body["error"]["n_prompt_tokens"] > 0
    assert server.n_ctx is not None
    assert server.n_slots is not None
    # the reported limit is the per-slot share of the total context
    assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
|
|
|
|
|
|
@pytest.mark.parametrize(
    "n_batch,batch_count,reuse_cache",
    [
        (64, 4, False),
        (64, 2, True),
    ]
)
def test_return_progress(n_batch, batch_count, reuse_cache):
    """With ``return_progress`` set, streamed events report prompt-processing
    progress: one report per batch, monotonically increasing ``processed``,
    constant ``total``/``cache``, ending with processed == total."""
    global server
    server.n_batch = n_batch
    server.n_ctx = 256
    server.n_slots = 1
    server.start()
    def make_cmpl_request():
        # identical request each time so a repeat can hit the prompt cache
        return server.make_stream_request("POST", "/chat/completions", data={
            "max_tokens": 10,
            "messages": [
                {"role": "user", "content": "This is a test" * 10},
            ],
            "stream": True,
            "return_progress": True,
        })
    if reuse_cache:
        # make a first request to populate the cache
        res0 = make_cmpl_request()
        for _ in res0:
            pass  # discard the output

    res = make_cmpl_request()
    last_progress = None
    total_batch_count = 0

    for data in res:
        cur_progress = data.get("prompt_progress", None)
        if cur_progress is None:
            continue
        if total_batch_count == 0:
            # first progress report must have n_cache == n_processed
            assert cur_progress["total"] > 0
            assert cur_progress["cache"] == cur_progress["processed"]
            if reuse_cache:
                # when reusing cache, we expect some cached tokens
                assert cur_progress["cache"] > 0
        if last_progress is not None:
            # totals and cache stay fixed while processed strictly advances
            assert cur_progress["total"] == last_progress["total"]
            assert cur_progress["cache"] == last_progress["cache"]
            assert cur_progress["processed"] > last_progress["processed"]
        total_batch_count += 1
        last_progress = cur_progress

    # last progress should indicate completion (all tokens processed)
    assert last_progress is not None
    assert last_progress["total"] > 0
    assert last_progress["processed"] == last_progress["total"]
    assert total_batch_count == batch_count
|
|
|
|
|
|
def test_chat_completions_multiple_choices():
    """Requesting n=2 choices pinned to a single slot must succeed repeatedly.

    # make sure cache can be reused across multiple choices and multiple requests
    # ref: https://github.com/ggml-org/llama.cpp/pull/18663
    """
    global server
    server.start()
    request_body = {
        "max_tokens": 8,
        "n": 2,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        # test forcing the same slot to be used
        # the scheduler should not be locked up in this case
        "id_slot": 0,
    }
    for _ in range(2):
        res = server.make_request("POST", "/chat/completions", data=request_body)
        assert res.status_code == 200
        assert len(res.body["choices"]) == 2
        for generated in res.body["choices"]:
            assert "assistant" == generated["message"]["role"]
            assert generated["finish_reason"] == "length"