Compare commits

...

3 Commits
b8889 ... b8892

Author SHA1 Message Date
Nikhil Jain
0d0764dfd2 [WebGPU] Implement async tensor api and event api (#22099)
* Only run webgpu CI on my fork

* Implement set_tensor_async

* Implement synchronize api

* Implement event creation and deletion API

* Cleanup

* Cleanup

* Comment out jobs for local CI run

* Add webgpu only workflow

* Delete .github/workflows/build-webgpu.yml

* Cleanup

* Cleanup

* Update API with function handlers

* Run clang-format

* Replace one-shot buffer with a direct queue.WriteBuffer using the buffer context
2026-04-22 10:52:01 -07:00
Masashi Yoshimura
6da7168312 ggml-webgpu: Add fused RMS_NORM + MUL (#21983)
* fused rms_norm_mul + mul

* Add GGML_WEBGPU_DISABLE_FUSION for being able to disable kernel fusion.

* Decouple num_fused_ops from webgpu_context; misc cleanup

* Fix eps handling and remove disable_fusion.

* Fix not to use c++20 initializers.
2026-04-22 10:51:40 -07:00
Piotr Wilkin (ilintar)
8bccdbbff9 chat: fix parallel_tool_calls default setting based on model capabilities, add tests for parallel tool calls and structured outputs (#22217)
* chat: fix parallel_tool_calls default setting based on model capabilities, add tests for parallel tool calls and structured outputs

* Fix ty errors.

* Fix flake8 err
2026-04-22 18:10:56 +02:00
7 changed files with 2418 additions and 27 deletions

View File

@@ -194,6 +194,26 @@ struct ggml_webgpu_row_norm_pipeline_key_hash {
}
};
/** RMS_NORM + MUL **/
// Cache key identifying one compiled variant of the fused RMS_NORM + MUL shader.
struct ggml_webgpu_rms_norm_mul_pipeline_key {
    bool inplace;      // variant built with the INPLACE define
    bool src_overlap;  // variant built with the SRC_OVERLAP define

    // Keys compare equal only when both flags match.
    bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const {
        if (inplace != other.inplace) {
            return false;
        }
        return src_overlap == other.src_overlap;
    }
};
// Hash functor for ggml_webgpu_rms_norm_mul_pipeline_key: folds both flags into one value.
struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const {
        size_t hash = 0;
        ggml_webgpu_hash_combine(hash, key.inplace);
        ggml_webgpu_hash_combine(hash, key.src_overlap);
        return hash;
    }
};
/** Pad **/
struct ggml_webgpu_pad_pipeline_key {
bool circular;
@@ -517,7 +537,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
const size_t q_tile = context.sg_mat_m;
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
size_t bytes_per_kv = 0;
if (!key.kv_direct) {
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
}
@@ -755,16 +775,17 @@ class ggml_webgpu_shader_lib {
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
row_norm_pipelines; // op/inplace
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
get_rows_pipelines; // src_type, vectorized
get_rows_pipelines; // src_type, vectorized
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
unary_pipelines; // type/op/inplace
unary_pipelines; // type/op/inplace
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
scale_pipelines; // inplace
scale_pipelines; // inplace
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
solve_tri_pipelines; // type
solve_tri_pipelines; // type
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
ssm_conv_pipelines; // type/vectorized
ssm_conv_pipelines; // type/vectorized
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
webgpu_pipeline,
ggml_webgpu_gated_delta_net_pipeline_key_hash>
@@ -813,6 +834,11 @@ class ggml_webgpu_shader_lib {
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
conv2d_pipelines;
std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key,
webgpu_pipeline,
ggml_webgpu_rms_norm_mul_pipeline_key_hash>
rms_norm_mul_pipelines;
public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@@ -1828,6 +1854,39 @@ class ggml_webgpu_shader_lib {
return unary_pipelines[key];
}
// Returns the fused RMS_NORM + MUL pipeline for the variant requested in `context`,
// compiling and caching it on first use.
//
// Reads context.inplace, context.src_overlap and context.max_wg_size.
webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_rms_norm_mul_pipeline_key key = {};
    key.inplace                               = context.inplace;
    // INPLACE takes precedence over SRC_OVERLAP when the variant is selected below, so
    // src_overlap is only meaningful when inplace is false. Normalizing it here keeps
    // (inplace, overlap) and (inplace, no overlap) from compiling the same shader twice.
    key.src_overlap                           = !key.inplace && context.src_overlap;

    auto it = rms_norm_mul_pipelines.find(key);
    if (it != rms_norm_mul_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "RMS_NORM_MUL";

    if (key.inplace) {
        defines.push_back("INPLACE");
        variant += "_inplace";
    } else if (key.src_overlap) {
        defines.push_back("SRC_OVERLAP");
        variant += "_src_overlap";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);

    // Record the workgroup size the shader was built with alongside the pipeline.
    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;

    rms_norm_mul_pipelines[key] = pipeline;
    return pipeline;
}
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_binary_pipeline_key key = {};
key.type = context.dst->type;

View File

@@ -1972,6 +1972,94 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor *
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
// Encodes the fused RMS_NORM + MUL kernel.
//
//   rn_src   - input of the RMS_NORM node
//   rn_dst   - the RMS_NORM node itself; must be one of the two MUL operands
//   mul_src0 - first operand of the MUL node
//   mul_src1 - second operand of the MUL node
//   dst      - the MUL node (final output)
//
// Three binding layouts are used, matching the shader variants:
//   inplace     - the non-RMS_NORM MUL operand is also dst: bind rn_src + mul_src
//   src_overlap - rn_src and mul_src overlap: bind one merged range + dst
//   default     - bind rn_src, mul_src and dst separately
//
// Aborts if rn_dst is neither MUL operand.
//
// NOTE(review): the non-inplace layouts bind rn_src and dst as separate buffers; a dst that
// aliases rn_src is not detected here - confirm the caller/allocator rules that case out.
static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context & ctx,
                                                                 ggml_tensor *    rn_src,
                                                                 ggml_tensor *    rn_dst,
                                                                 ggml_tensor *    mul_src0,
                                                                 ggml_tensor *    mul_src1,
                                                                 ggml_tensor *    dst) {
    const bool rn_is_src0 = ggml_webgpu_tensor_equal(rn_dst, mul_src0);
    const bool rn_is_src1 = ggml_webgpu_tensor_equal(rn_dst, mul_src1);

    // mul_src is the MUL operand that is NOT the RMS_NORM output.
    ggml_tensor * mul_src = nullptr;
    if (rn_is_src0) {
        mul_src = mul_src1;
    } else if (rn_is_src1) {
        mul_src = mul_src0;
    } else {
        GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
    }

    const bool inplace = (rn_is_src0 && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
                         (rn_is_src1 && ggml_webgpu_tensor_equal(mul_src0, dst));

    // The merged-buffer layout is only used when the inplace layout is not. Masking the flag
    // here keeps the merged offsets at zero for the inplace shader, which adds them to its
    // row index unconditionally although rn_src and mul_src are bound separately there.
    const bool src_overlap = !inplace && ggml_webgpu_tensor_overlap(rn_src, mul_src);

    const size_t rn_src_align_offset  = ggml_webgpu_tensor_align_offset(ctx, rn_src);
    const size_t mul_src_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src);

    // Element offsets of each source relative to the start of the merged binding.
    uint32_t offset_merged_rn_src  = 0;
    uint32_t offset_merged_mul_src = 0;
    size_t   merged_offset         = 0;
    if (src_overlap) {
        merged_offset         = std::min(rn_src_align_offset, mul_src_align_offset);
        offset_merged_rn_src  = (uint32_t) ((rn_src_align_offset - merged_offset) / ggml_type_size(rn_src->type));
        offset_merged_mul_src = (uint32_t) ((mul_src_align_offset - merged_offset) / ggml_type_size(mul_src->type));
    }

    // Order must match the Params struct of the rms_norm_mul shader.
    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)),
        offset_merged_rn_src,
        offset_merged_mul_src,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)),
        (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)),
        (uint32_t) (rn_src->nb[3] / ggml_type_size(rn_src->type)),
        (uint32_t) (mul_src->nb[1] / ggml_type_size(mul_src->type)),
        (uint32_t) (mul_src->nb[2] / ggml_type_size(mul_src->type)),
        (uint32_t) (mul_src->nb[3] / ggml_type_size(mul_src->type)),
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        (uint32_t) mul_src->ne[0],
        (uint32_t) mul_src->ne[1],
        (uint32_t) mul_src->ne[2],
        (uint32_t) mul_src->ne[3],
        (uint32_t) dst->ne[0],
        (uint32_t) dst->ne[1],
        (uint32_t) dst->ne[2],
        (uint32_t) dst->ne[3],
        ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(rn_dst, 0))  // epsilon, treated as f32 in the shader
    };

    std::vector<wgpu::BindGroupEntry> entries;
    if (inplace) {
        // dst shares mul_src's binding, so only the two sources are bound.
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
    } else if (src_overlap) {
        // Bind one range covering both sources: [min start, max end).
        const size_t merged_end = std::max(rn_src_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src),
                                           mul_src_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src));
        entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset,
                                                            merged_end - merged_offset));
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
    } else {
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
    }

    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
    shader_lib_ctx.inplace     = inplace;
    shader_lib_ctx.src_overlap = src_overlap;

    webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);

    // One workgroup per output row.
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst));
}
static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool inplace = ggml_webgpu_tensor_equal(src, dst);
@@ -2468,15 +2556,48 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
// Returns true when nodes[node_idx] (RMS_NORM) and nodes[node_idx + 1] (MUL) can be executed
// by the fused RMS_NORM + MUL kernel. When it returns false the caller encodes the nodes
// separately.
static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, int node_idx) {
    if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
        return false;
    }

    // additional constraints specific to this fusion
    const ggml_tensor * rms_norm = cgraph->nodes[node_idx];
    const ggml_tensor * mul      = cgraph->nodes[node_idx + 1];

    // rms_norm only supports f32
    GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);

    // the fused shader declares every buffer as array<f32>
    if (mul->src[0]->type != GGML_TYPE_F32 || mul->src[1]->type != GGML_TYPE_F32 || mul->type != GGML_TYPE_F32) {
        return false;
    }

    // if rms_norm is the B operand, then we don't handle broadcast
    if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
        return false;
    }

    // The fused shader addresses elements as row_start + col for the RMS_NORM input, the other
    // MUL operand and the output, so all of them need contiguous rows - not only the two MUL
    // operands.
    if (!ggml_is_contiguous_rows(rms_norm->src[0]) || !ggml_is_contiguous_rows(mul->src[0]) ||
        !ggml_is_contiguous_rows(mul->src[1]) || !ggml_is_contiguous_rows(mul)) {
        return false;
    }

    return true;
}
// Returns the encoded command, or std::nullopt if the operation is a no-op
static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
ggml_cgraph * cgraph,
int node_idx,
int & num_encoded_ops) {
ggml_tensor ** nodes = cgraph->nodes;
ggml_tensor * node = nodes[node_idx];
if (ggml_is_empty(node)) {
return std::nullopt;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
return std::nullopt;
}
WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
WEBGPU_LOG_DEBUG("ggml_webgpu_encode(" << node << ", " << ggml_op_name(node->op) << ")");
ggml_tensor * src0 = node->src[0];
ggml_tensor * src1 = node->src[1];
@@ -2519,6 +2640,13 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context c
case GGML_OP_REPEAT:
return ggml_webgpu_repeat(ctx, src0, node);
case GGML_OP_RMS_NORM:
if (ggml_webgpu_can_fuse_rms_norm_mul(cgraph, node_idx)) {
num_encoded_ops = 2;
ggml_tensor * mul_node = nodes[node_idx + 1];
return ggml_webgpu_rms_norm_mul(ctx, src0, node, mul_node->src[0], mul_node->src[1], mul_node);
} else {
return ggml_webgpu_row_norm(ctx, src0, node);
}
case GGML_OP_L2_NORM:
return ggml_webgpu_row_norm(ctx, src0, node);
case GGML_OP_ROPE:
@@ -2629,6 +2757,8 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
uint32_t num_inflight_batches = 0;
bool contains_set_rows = false;
bool batch_compute_passes = true;
int num_encoded_ops = 1;
int node_idx = 0;
#ifdef GGML_WEBGPU_GPU_PROFILE
ctx->profile_timestamp_query_count = 0;
@@ -2641,11 +2771,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass();
}
for (int i = 0; i < cgraph->n_nodes; i++) {
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
while (node_idx < cgraph->n_nodes) {
if (cgraph->nodes[node_idx]->op == GGML_OP_SET_ROWS) {
contains_set_rows = true;
}
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
if (auto cmd = ggml_webgpu_encode(ctx, cgraph, node_idx, num_encoded_ops)) {
commands.push_back(*cmd);
num_batched_kernels += cmd.value().num_kernels;
#ifdef GGML_WEBGPU_GPU_PROFILE
@@ -2670,6 +2800,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
ctx->param_arena.reset();
commands.clear();
}
node_idx += num_encoded_ops;
num_encoded_ops = 1;
}
if (ctx->active_compute_pass) {
@@ -2699,22 +2832,107 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
return GGML_STATUS_SUCCESS;
}
// Backend-private state attached to a ggml_backend_event.
struct ggml_backend_webgpu_event_context {
    webgpu_global_context global_ctx;         // copied from the creating device's context
    wgpu::Future          future;             // assigned when the event is recorded
    bool                  recorded{ false };  // true between event_record and the next synchronize
};
// Allocates a new, not-yet-recorded event bound to `device`.
// The returned event and its context are released by ggml_backend_webgpu_device_event_free.
static ggml_backend_event_t ggml_backend_webgpu_device_event_new(ggml_backend_dev_t device) {
    auto * dev_ctx = (ggml_backend_webgpu_device_context *) device->context;

    auto * event_ctx      = new ggml_backend_webgpu_event_context();
    event_ctx->global_ctx = dev_ctx->webgpu_global_ctx;

    ggml_backend_event_t event = new ggml_backend_event;
    event->context             = event_ctx;
    event->device              = device;
    return event;
}
// Destroys an event created by ggml_backend_webgpu_device_event_new, including its context.
static void ggml_backend_webgpu_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
    GGML_UNUSED(dev);
    auto * event_ctx = static_cast<ggml_backend_webgpu_event_context *>(event->context);
    delete event_ctx;
    delete event;
}
// Blocks until the future stored by the last event_record completes, then clears the
// recorded flag. Does nothing for an event that has not been recorded. Aborts on timeout.
static void ggml_backend_webgpu_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
    GGML_UNUSED(dev);

    auto * event_ctx = (ggml_backend_webgpu_event_context *) event->context;
    if (event_ctx->recorded) {
        const wgpu::WaitStatus wait_status =
            event_ctx->global_ctx->instance.WaitAny(event_ctx->future, WEBGPU_RUNTIME_WAIT_TIMEOUT_NS);
        if (wait_status == wgpu::WaitStatus::TimedOut) {
            GGML_ABORT("ggml_webgpu: event_synchronize timed out after %u ms\n", WEBGPU_RUNTIME_WAIT_TIMEOUT_MS);
        }
        event_ctx->recorded = false;
    }
}
// Records `event` on the backend's queue: stores a future from OnSubmittedWorkDone
// (with a no-op callback) and marks the event as recorded.
static void ggml_backend_webgpu_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
    auto * event_ctx   = (ggml_backend_webgpu_event_context *) event->context;
    auto * backend_ctx = (ggml_backend_webgpu_context *) backend->context;

    auto & queue = backend_ctx->webgpu_ctx->global_ctx->queue;

    event_ctx->future   = queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
                                                    [](wgpu::QueueWorkDoneStatus, wgpu::StringView) {});
    event_ctx->recorded = true;
}
// Waits for `event` by synchronizing on it directly. The device argument of the synchronize
// call is unused there, so nullptr is passed.
static void ggml_backend_webgpu_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
    ggml_backend_webgpu_device_event_synchronize(nullptr, event);
    GGML_UNUSED(backend);
}
// Queues a write of `size` bytes from `data` into `tensor` at byte `offset`.
// The leading multiple-of-4 part goes through queue.WriteBuffer. Any 1-3 trailing bytes are
// packed into a uint32_t and written with ggml_backend_webgpu_buffer_memset.
static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend,
                                                 ggml_tensor *  tensor,
                                                 const void *   data,
                                                 size_t         offset,
                                                 size_t         size) {
    GGML_UNUSED(backend);

    auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;

    const size_t dst_offset   = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
    const size_t tail_bytes   = size % 4;
    const size_t aligned_size = size - tail_bytes;

    // Aligned portion
    buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, dst_offset, data, aligned_size);

    if (tail_bytes == 0) {
        return;
    }

    // Pack the trailing bytes into the low bytes of a zero-initialized word.
    uint32_t        packed       = 0;
    uint8_t *       packed_bytes = (uint8_t *) &packed;
    const uint8_t * tail_src     = (const uint8_t *) data + aligned_size;
    for (size_t i = 0; i < tail_bytes; i++) {
        packed_bytes[i] = tail_src[i];
    }

    ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, packed, dst_offset + aligned_size,
                                      tail_bytes);
}
// Backend synchronize hook: delegates to ggml_backend_webgpu_wait_queue on the global context.
static void ggml_backend_webgpu_synchronize(ggml_backend_t backend) {
    auto * backend_ctx = (ggml_backend_webgpu_context *) backend->context;
    auto & global_ctx  = backend_ctx->webgpu_ctx->global_ctx;
    ggml_backend_webgpu_wait_queue(global_ctx);
}
static ggml_backend_i ggml_backend_webgpu_i = {
/* .get_name = */ ggml_backend_webgpu_name,
/* .free = */ ggml_backend_webgpu_free,
/* .set_tensor_async = */ NULL,
/* .set_tensor_async = */ ggml_backend_webgpu_set_tensor_async,
/* .get_tensor_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .synchronize = */ ggml_backend_webgpu_synchronize,
/* .graph_plan_create = */ NULL,
/* .graph_plan_free = */ NULL,
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_webgpu_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .event_record = */ ggml_backend_webgpu_event_record,
/* .event_wait = */ ggml_backend_webgpu_event_wait,
/* .graph_optimize = */ NULL,
};
@@ -3237,7 +3455,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
webgpu_ctx->param_arena.init(
webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES,
webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN,
@@ -3487,12 +3705,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
}
// Head dimensions must fit in workgroup memory with minimum tile sizes
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
if (min_bytes > limit_bytes) {
@@ -3677,9 +3895,9 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
/* .supports_op = */ ggml_backend_webgpu_device_supports_op,
/* .supports_buft = */ ggml_backend_webgpu_device_supports_buft,
/* .offload_op = */ NULL,
/* .event_new = */ NULL,
/* .event_free = */ NULL,
/* .event_synchronize = */ NULL,
/* .event_new = */ ggml_backend_webgpu_device_event_new,
/* .event_free = */ ggml_backend_webgpu_device_event_free,
/* .event_synchronize = */ ggml_backend_webgpu_device_event_synchronize,
};
/* End GGML Backend Device Interface */

View File

@@ -0,0 +1,139 @@
#ifdef INPLACE
// In-place variant: the result is written back into mul_src.
@group(0) @binding(0) var<storage, read_write> rn_src: array<f32>;
@group(0) @binding(1) var<storage, read_write> mul_src: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) {
    let normed = scale * rn_src[rn_src_offset];
    mul_src[dst_offset] = normed * mul_src[mul_src_offset];
}
#elif SRC_OVERLAP
// Overlap variant: both sources are read from a single merged binding.
@group(0) @binding(0) var<storage, read_write> merged_src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) {
    let normed = scale * merged_src[rn_src_offset];
    dst[dst_offset] = normed * merged_src[mul_src_offset];
}
#else
// Default variant: separate bindings for both sources and the destination.
@group(0) @binding(0) var<storage, read_write> rn_src: array<f32>;
@group(0) @binding(1) var<storage, read_write> mul_src: array<f32>;
@group(0) @binding(2) var<storage, read_write> dst: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) {
    let normed = scale * rn_src[rn_src_offset];
    dst[dst_offset] = normed * mul_src[mul_src_offset];
}
#endif
// Uniform parameters for the fused RMS_NORM + MUL kernel.
// Field order must match the params vector built on the host side.
struct Params {
// Element offsets of each tensor inside its own binding.
offset_rn_src: u32,
offset_mul_src: u32,
// Extra element offsets added to the source row indices; the host leaves them at 0
// unless the two sources overlap.
offset_merged_rn_src: u32,
offset_merged_mul_src: u32,
offset_dst: u32,
// Strides in elements for dimensions 1..3 of each tensor.
stride_rn_src1: u32,
stride_rn_src2: u32,
stride_rn_src3: u32,
stride_mul_src1: u32,
stride_mul_src2: u32,
stride_mul_src3: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Extents of mul_src; main() takes indices modulo these.
mul_src_ne0: u32,
mul_src_ne1: u32,
mul_src_ne2: u32,
mul_src_ne3: u32,
// Extents of dst.
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
// RMS_NORM epsilon, added to the mean of squares before the square root.
eps: f32
};
var<workgroup> scratch: array<f32, WG_SIZE>;
// Fused RMS_NORM + MUL entry point.
// Each workgroup handles one dst row: its invocations accumulate partial sums of squares,
// reduce them through `scratch`, then write scale * rn_src * mul_src for their columns.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    // one workgroup per row; decompose the flat row index into (i3, i2, i1)
    var i = wid.x;
    let i3 = i / (params.ne2 * params.ne1);
    i = i % (params.ne2 * params.ne1);
    let i2 = i / params.ne1;
    let i1 = i % params.ne1;

    let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
    // mul_src indices wrap by its own extents, which is what lets it be broadcast over dst
    let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;

    // number of columns each invocation visits (ceil(ne0 / WG_SIZE))
    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;

    // per-invocation partial sum of squares
    var sum = 0.0f;
    var col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
#ifdef SRC_OVERLAP
        let x = merged_src[i_rn_src_row + col];
#else
        let x = rn_src[i_rn_src_row + col];
#endif
        // x * x instead of pow(x, 2.0): inputs can be negative, and a plain multiply
        // does not depend on how pow treats a negative base.
        sum += x * x;
        col += WG_SIZE;
    }

    scratch[lid.x] = sum;
    workgroupBarrier();

    // Halving tree reduction into scratch[0]; covers every slot only when WG_SIZE is a power of two.
    var offset: u32 = WG_SIZE / 2;
    while (offset > 0) {
        if (lid.x < offset) {
            scratch[lid.x] += scratch[lid.x + offset];
        }
        offset = offset / 2;
        workgroupBarrier();
    }
    sum = scratch[0];

    // 1 / sqrt(mean(x^2) + eps)
    let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);

    col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
        update(i_rn_src_row + col, i_dst_row + col, scale, i_mul_src_row + col % params.mul_src_ne0);
        col += WG_SIZE;
    }
}

View File

@@ -0,0 +1,991 @@
#!/usr/bin/env python3
"""
Test parallel tool-calling capability via chat completions endpoint.
Only run this against models that actually support parallel tool calls — this
script does not attempt to toggle that setting on the server. Each scenario is
explicitly worded so that a capable model SHOULD emit multiple tool calls in a
single assistant turn (either the same tool N times, or several different
tools at once).
Each test case contains:
- tools: list of tool definitions (OpenAI-compatible)
- messages: initial conversation messages
- mock_tool_responses: dict mapping tool_name -> callable(arguments) -> str (JSON)
- expected_parallel: dict describing what constitutes a successful parallel turn
{"min_parallel": int, # minimum tool_calls in one turn
"require_same_tool": Optional[str], # all parallel calls must be this tool
"require_distinct_tools": Optional[int], # >= N distinct tool names in one turn
"min_distinct_args_key": Optional[str]} # parallel calls must span this
# many distinct values of this arg key
- validate: callable(turns, all_tool_calls, final_content) -> (passed, reason)
"""
import argparse
import json
import requests
import sys
# ---------------------------------------------------------------------------
# Color / formatting helpers
# ---------------------------------------------------------------------------
# ANSI escape sequences used by the print_* helpers below.
# RESET clears all attributes; BOLD and DIM are text attributes; the rest are foreground colors.
RESET = "\x1b[0m"
BOLD = "\x1b[1m"
DIM = "\x1b[2m"
CYAN = "\x1b[36m"
YELLOW = "\x1b[33m"
GREEN = "\x1b[32m"
RED = "\x1b[31m"
BLUE = "\x1b[34m"
WHITE = "\x1b[97m"
MAGENTA = "\x1b[35m"
def _print(text="", end="\n"):
sys.stdout.write(text + end)
sys.stdout.flush()
def print_header(title):
    """Print `title` as a bold cyan banner framed by two horizontal rules."""
    # The rule glyph must be non-empty: "" * 60 evaluates to "" and draws nothing.
    bar = "─" * 60
    _print(f"\n{BOLD}{CYAN}{bar}{RESET}")
    _print(
        f"{BOLD}{CYAN}{WHITE}{title}{CYAN}{' ' * max(0, 58 - len(title))}{RESET}"
    )
    _print(f"{BOLD}{CYAN}{bar}{RESET}")
def print_turn_banner(turn_idx, n_calls):
    """Print a one-line banner for a turn; highlighted when it has two or more tool calls."""
    color = MAGENTA if n_calls >= 2 else DIM
    # Keep a separator between the two numbers: "turn 1" followed directly by "2 tool call(s)"
    # would read as "turn 12 tool call(s)".
    _print(f"\n {BOLD}{color}▶ turn {turn_idx} · {n_calls} tool call(s){RESET}")
def print_tool_call(name, args):
    """Print a tool invocation as `name(<JSON-encoded args>)`."""
    rendered = json.dumps(args)
    line = f" {BOLD}{YELLOW}{name}{RESET}{DIM}({rendered}){RESET}"
    _print(line)
def print_tool_result(result):
    """Print at most the first 140 characters of a tool result, marking any truncation."""
    # Append an ellipsis only when something was cut off; an empty marker in both branches
    # makes a truncated result indistinguishable from a complete one.
    preview = result[:140] + ("…" if len(result) > 140 else "")
    _print(f" {DIM}{BLUE}{preview}{RESET}")
def print_model_output(text):
    """Write model text to stdout verbatim (no newline added) and flush."""
    stream = sys.stdout
    stream.write(text)
    stream.flush()
def print_pass(reason):
    """Print a green PASS line followed by `reason`."""
    badge = f"{BOLD}{GREEN}✔ PASS{RESET}"
    _print(f"\n{badge} {reason}")
def print_fail(reason):
    """Print a red FAIL line followed by `reason`."""
    badge = f"{BOLD}{RED}✘ FAIL{RESET}"
    _print(f"\n{badge} {reason}")
def print_info(msg):
    """Print `msg` in dim text."""
    line = f"{DIM}{msg}{RESET}"
    _print(line)
def print_warn(msg):
    """Print `msg` in bold yellow text."""
    line = f"{BOLD}{YELLOW}{msg}{RESET}"
    _print(line)
# ---------------------------------------------------------------------------
# HTTP helpers
# ---------------------------------------------------------------------------
def chat_completion(url, messages, tools=None, stream=False):
    """
    POST one chat-completion request and collect the assistant's reply.

    Args:
        url: endpoint to POST to.
        messages: conversation so far.
        tools: optional tool definitions; when given, tool_choice is set to "auto".
        stream: when True, consume the "data: " event stream and echo content as it arrives.

    Returns:
        None if the request failed (the error is printed), otherwise a dict with
        "content" (str), "tool_calls" (list) and, when non-empty, "reasoning_content".

    Fields that arrive as explicit JSON null (tool_calls, id, function, name, arguments,
    delta, message, choices) are treated the same as absent ones instead of raising.
    """
    payload = {
        "messages": messages,
        "stream": stream,
        "max_tokens": 4096,
    }
    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = "auto"

    try:
        response = requests.post(url, json=payload, stream=stream)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        body = e.response.content if (e.response is not None) else b""
        print_fail(f"Request error: {e} | body: {body}")
        return None

    full_content = ""
    reasoning_content = ""
    tool_calls: list[dict] = []

    if stream:
        for line in response.iter_lines():
            if not line:
                continue
            decoded = line.decode("utf-8")
            if not decoded.startswith("data: "):
                continue
            data_str = decoded[6:]
            if data_str == "[DONE]":
                break
            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                continue

            choices = data.get("choices") or []
            if not choices:
                continue
            delta = choices[0].get("delta") or {}

            if delta.get("reasoning_content"):
                reasoning_content += delta["reasoning_content"]
            if delta.get("content"):
                full_content += delta["content"]
                print_model_output(delta["content"])

            # Tool calls arrive as fragments keyed by index; concatenate them per slot.
            for tc in delta.get("tool_calls") or []:
                idx = tc.get("index") or 0
                while len(tool_calls) <= idx:
                    tool_calls.append(
                        {
                            "id": "",
                            "type": "function",
                            "function": {"name": "", "arguments": ""},
                        }
                    )
                slot = tool_calls[idx]
                slot["id"] += tc.get("id") or ""
                fn = tc.get("function") or {}
                slot["function"]["name"] += fn.get("name") or ""
                slot["function"]["arguments"] += fn.get("arguments") or ""
    else:
        data = response.json()
        choices = data.get("choices") or []
        if choices:
            msg = choices[0].get("message") or {}
            full_content = msg.get("content") or ""
            reasoning_content = msg.get("reasoning_content") or ""
            tool_calls = msg.get("tool_calls") or []
            if full_content:
                print_model_output(full_content)

    result = {"content": full_content, "tool_calls": tool_calls}
    if reasoning_content:
        result["reasoning_content"] = reasoning_content
    return result
def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turns=6):
    """
    Run the request / tool-response loop for up to `max_turns` turns, keeping a
    per-turn record of tool calls so parallelism can be checked afterwards.

    Returns (turns, all_tool_calls, final_content):
      - turns: list of {"index": int, "tool_calls": [...], "content": str}
      - all_tool_calls: every tool call from every turn, in order
      - final_content: content of the first turn without tool calls, or None when
        a request failed or `max_turns` ran out first
    """
    conversation = list(messages)
    turns: list[dict] = []
    all_tool_calls: list[dict] = []

    for turn_idx in range(max_turns):
        reply = chat_completion(url, conversation, tools=tools, stream=stream)
        if reply is None:
            return turns, all_tool_calls, None

        calls = reply.get("tool_calls") or []
        text = reply.get("content") or ""
        turns.append({"index": turn_idx, "tool_calls": list(calls), "content": text})

        # A turn without tool calls ends the loop.
        if not calls:
            if text:
                _print(f"\n{DIM}{'·' * 60}{RESET}")
                _print(f"{DIM} model response:{RESET}\n")
            return turns, all_tool_calls, text

        print_turn_banner(turn_idx, len(calls))
        all_tool_calls.extend(calls)

        assistant_msg: dict = {
            "role": "assistant",
            "content": text,
            "tool_calls": calls,
        }
        reasoning = reply.get("reasoning_content")
        if reasoning:
            assistant_msg["reasoning_content"] = reasoning
        conversation.append(assistant_msg)

        # Answer every call of this turn with its mock before asking the model again.
        for call in calls:
            fn = call["function"]
            tool_name = fn["name"]
            try:
                args = json.loads(fn["arguments"])
            except json.JSONDecodeError:
                args = {}
            print_tool_call(tool_name, args)

            handler = mock_tool_responses.get(tool_name)
            if handler:
                tool_result = handler(args)
            else:
                tool_result = json.dumps({"error": f"Unknown tool: {tool_name}"})
            print_tool_result(tool_result)

            tool_msg = {
                "role": "tool",
                "tool_call_id": call.get("id", ""),
                "content": tool_result,
            }
            conversation.append(tool_msg)

    return turns, all_tool_calls, None
# ---------------------------------------------------------------------------
# Parallelism helpers
# ---------------------------------------------------------------------------
def _best_parallel_turn(turns):
"""Return the turn (dict) with the most tool calls, or None if no tools."""
tool_turns = [t for t in turns if t["tool_calls"]]
if not tool_turns:
return None
return max(tool_turns, key=lambda t: len(t["tool_calls"]))
def _distinct_tool_names(turn):
return {tc["function"]["name"] for tc in turn["tool_calls"]}
def _distinct_arg_values(turn, key):
values = set()
for tc in turn["tool_calls"]:
try:
args = json.loads(tc["function"]["arguments"])
except json.JSONDecodeError:
continue
v = args.get(key)
if v is not None:
if isinstance(v, str):
values.add(v.strip().lower())
else:
values.add(v)
return values
def _check_parallel(turns, expected):
    """
    Check the turn with the most tool calls against the parallel-call expectations.

    Only that single turn is evaluated. `expected` may contain min_parallel
    (default 2), require_same_tool, require_distinct_tools, min_distinct_args_key
    and min_distinct_args_count (default: min_parallel).

    Returns (ok, reason).
    """
    candidate = _best_parallel_turn(turns)
    if candidate is None:
        return False, "No tool calls were made at all"

    calls = candidate["tool_calls"]

    min_parallel = expected.get("min_parallel", 2)
    if len(calls) < min_parallel:
        by_turn = [len(t["tool_calls"]) for t in turns]
        return False, (
            f"No turn had >= {min_parallel} parallel tool calls "
            f"(per-turn counts: {by_turn})"
        )

    require_same = expected.get("require_same_tool")
    if require_same is not None:
        names = [tc["function"]["name"] for tc in calls]
        if not all(n == require_same for n in names):
            return False, (
                f"Parallel turn mixed tools; expected all {require_same!r}, got {names}"
            )

    require_distinct = expected.get("require_distinct_tools")
    if require_distinct is not None:
        distinct = _distinct_tool_names(candidate)
        if len(distinct) < require_distinct:
            return False, (
                f"Parallel turn had only {len(distinct)} distinct tool names "
                f"({distinct}); need >= {require_distinct}"
            )

    distinct_key = expected.get("min_distinct_args_key")
    distinct_count = expected.get("min_distinct_args_count", min_parallel)
    if distinct_key is not None:
        values = _distinct_arg_values(candidate, distinct_key)
        if len(values) < distinct_count:
            return False, (
                f"Parallel turn had only {len(values)} distinct {distinct_key!r} "
                f"values ({values}); need >= {distinct_count}"
            )

    return True, (
        f"Parallel turn had {len(calls)} calls across "
        f"{len(_distinct_tool_names(candidate))} distinct tool(s)"
    )
# ---------------------------------------------------------------------------
# Test case runner
# ---------------------------------------------------------------------------
def run_test(url, test_case, stream):
    """
    Run one parallel-tool-call test case and print its PASS/FAIL verdict.

    The case passes only if the agentic loop produced at least one turn, the
    parallel-call expectations hold, and the case's own validator accepts the result.

    Returns the validator's verdict on success paths, False otherwise.
    """
    name = test_case["name"]
    mode = "stream" if stream else "non-stream"
    print_header(f"{name} [{mode}]")

    turns, all_tool_calls, final_content = run_agentic_loop(
        url,
        messages=test_case["messages"],
        tools=test_case["tools"],
        mock_tool_responses=test_case["mock_tool_responses"],
        stream=stream,
    )
    if not turns:
        print_fail("No response from server.")
        return False

    # Structural check first: the model must actually have batched its calls.
    parallel_ok, parallel_reason = _check_parallel(turns, test_case["expected_parallel"])
    if not parallel_ok:
        print_fail(parallel_reason)
        return False

    # Then the case-specific check on the tool calls and final answer.
    passed, reason = test_case["validate"](turns, all_tool_calls, final_content)
    if not passed:
        print_fail(reason)
        return passed
    print_pass(f"{parallel_reason}; {reason}")
    return passed
# ---------------------------------------------------------------------------
# Test case definitions
# ---------------------------------------------------------------------------
# ---- Test 1: Multi-file read (same tool, multiple distinct paths) ----
# Tool schema offered to the model in the multi-file read test: a single
# `read_file` function whose only (required) parameter is a string `path`.
# The description explicitly nudges the model towards one call per path, in parallel.
_FILE_TOOLS = [
{
"type": "function",
"function": {
"name": "read_file",
"description": (
"Read the full contents of a file from the local filesystem. "
"Call this tool in parallel when asked to read several files — "
"each path needs its own call."
),
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Absolute or repo-relative path to a file",
},
},
"required": ["path"],
},
},
},
]
_FILE_CONTENTS = {
"config/database.yml": "host: db.internal\nport: 5432\nuser: svc_app\n",
"config/redis.yml": "host: cache.internal\nport: 6379\ndb: 0\n",
"config/queue.yml": "broker: rabbitmq.internal\nport: 5672\nvhost: prod\n",
"config/auth.yml": "provider: oidc\nissuer: https://auth.internal\n",
}
def _read_file_mock(args):
path = args.get("path", "")
norm = path.lstrip("./").lstrip("/")
content = _FILE_CONTENTS.get(norm)
if content is None:
for k, v in _FILE_CONTENTS.items():
if path.endswith(k):
content = v
break
if content is None:
return json.dumps({"path": path, "error": "not found"})
return json.dumps({"path": path, "content": content})
# Test case: the model must issue >= 4 `read_file` calls in a single turn, all to
# the same tool and covering 4 distinct `path` values (enforced through
# "expected_parallel"). The final answer is then checked by _validate_multifile.
# The validator is wrapped in a lambda because it is defined after this dict.
MULTIFILE_READ_TEST = {
"name": "Parallel multi-file read (same tool, 4 distinct paths)",
"tools": _FILE_TOOLS,
"messages": [
{
"role": "user",
"content": (
"Please read all four of these config files so I can review them "
"together: config/database.yml, config/redis.yml, config/queue.yml, "
"and config/auth.yml. Call read_file for every path in parallel in "
"a single batch — do NOT read them one by one sequentially across "
"turns. After you have all four, give me a one-line summary of each."
),
}
],
"mock_tool_responses": {"read_file": _read_file_mock},
"expected_parallel": {
"min_parallel": 4,
"require_same_tool": "read_file",
"min_distinct_args_key": "path",
"min_distinct_args_count": 4,
},
"validate": lambda turns, tcs, content: _validate_multifile(turns, tcs, content),
}
def _validate_multifile(turns, tcs, content):
del turns
if not content:
return False, "No final summary produced"
return True, f"{len(tcs)} total read_file calls; content length={len(content)}"
# ---- Test 2: Batch TODO marking (same tool, N calls in one turn) ----
# Tool schema for the batch TODO test: a single `mark_todo_complete` function
# taking a required string `todo_id` and an optional string `note`.
_TODO_TOOLS = [
{
"type": "function",
"function": {
"name": "mark_todo_complete",
"description": (
"Mark a single TODO item as complete by ID. When the user wants "
"several items marked at once, call this tool in parallel — "
"one call per item — rather than sequentially across turns."
),
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "string",
"description": "Identifier of the TODO item",
},
"note": {
"type": "string",
"description": "Optional completion note",
},
},
"required": ["todo_id"],
},
},
},
]
_TODO_DB = {
"T-101": "Draft onboarding doc",
"T-102": "Update dependency lockfile",
"T-103": "Fix flaky login test",
"T-104": "Rotate service credentials",
"T-105": "Archive Q4 reports",
}
def _mark_todo_mock(args):
tid = args.get("todo_id", "")
if tid in _TODO_DB:
return json.dumps({"todo_id": tid, "title": _TODO_DB[tid], "status": "done"})
return json.dumps({"todo_id": tid, "error": "unknown id"})
# Test case: the model must issue >= 5 `mark_todo_complete` calls in one turn,
# covering 5 distinct `todo_id` values; the final answer is checked by _validate_todo.
# The validator is wrapped in a lambda because it is defined after this dict.
TODO_BATCH_TEST = {
"name": "Batch TODO completion (same tool, 5 IDs in one turn)",
"tools": _TODO_TOOLS,
"messages": [
{
"role": "user",
"content": (
"I finished every item on today's list. Please mark all of the "
"following TODOs as complete, in one parallel batch: T-101, T-102, "
"T-103, T-104, T-105. Don't mark them one at a time across separate "
"turns — issue all five mark_todo_complete calls at once. Afterwards "
"confirm which ones succeeded."
),
}
],
"mock_tool_responses": {"mark_todo_complete": _mark_todo_mock},
"expected_parallel": {
"min_parallel": 5,
"require_same_tool": "mark_todo_complete",
"min_distinct_args_key": "todo_id",
"min_distinct_args_count": 5,
},
"validate": lambda turns, tcs, content: _validate_todo(turns, tcs, content),
}
def _validate_todo(turns, tcs, content):
del turns
if not content:
return False, "No confirmation summary produced"
return True, f"{len(tcs)} total mark_todo_complete calls"
# ---- Test 3: Multi-city weather (same tool, N parallel locations) ----
# Tool schema for the multi-city weather test: a single `get_weather` function
# taking a required string `city` and an optional `units` enum (metric/imperial).
_WEATHER_TOOLS = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": (
"Fetch current weather for ONE city. When the user asks about "
"several cities, call this tool in parallel — one call per city — "
"instead of sequentially."
),
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "City name"},
"units": {
"type": "string",
"enum": ["metric", "imperial"],
"default": "metric",
},
},
"required": ["city"],
},
},
},
]
_WEATHER_DB = {
"tokyo": {"city": "Tokyo", "temp_c": 18.4, "condition": "partly cloudy", "humidity": 64},
"london": {"city": "London", "temp_c": 9.1, "condition": "overcast", "humidity": 81},
"new york": {"city": "New York", "temp_c": 12.7, "condition": "clear", "humidity": 55},
"paris": {"city": "Paris", "temp_c": 11.3, "condition": "light rain", "humidity": 78},
}
def _weather_mock(args):
city = args.get("city", "").strip().lower()
if city.startswith("new york"):
city = "new york"
if city in _WEATHER_DB:
return json.dumps(_WEATHER_DB[city])
return json.dumps({"city": args.get("city", ""), "error": "unknown city"})
# Test case: the model must issue >= 4 `get_weather` calls in one turn, covering
# 4 distinct `city` values; the final answer is checked by _validate_weather.
# The validator is wrapped in a lambda because it is defined after this dict.
MULTI_WEATHER_TEST = {
"name": "Parallel multi-city weather (same tool, 4 cities)",
"tools": _WEATHER_TOOLS,
"messages": [
{
"role": "user",
"content": (
"I'm comparing today's weather across four cities for a travel "
"decision: Tokyo, London, New York, and Paris. Please call "
"get_weather for all four in parallel in a single turn — don't "
"fetch them one at a time. Then rank them from warmest to coolest."
),
}
],
"mock_tool_responses": {"get_weather": _weather_mock},
"expected_parallel": {
"min_parallel": 4,
"require_same_tool": "get_weather",
"min_distinct_args_key": "city",
"min_distinct_args_count": 4,
},
"validate": lambda turns, tcs, content: _validate_weather(turns, tcs, content),
}
def _validate_weather(turns, tcs, content):
del turns
if not content or not any(
kw in content.lower() for kw in ("warmest", "rank", "hot", "cool")
):
return False, f"Final content missing a ranking: {content!r}"
return True, f"{len(tcs)} total get_weather calls; ranking produced"
# ---- Test 4: Trip planning (different tools, parallel in one turn) ----
# Tool schemas for the trip-planning test: three independent search functions
# (`search_flights`, `search_hotels`, `search_restaurants`) that the model is
# expected to call together in a single turn.
_TRIP_TOOLS = [
{
"type": "function",
"function": {
"name": "search_flights",
"description": "Search one-way flights between two airports on a given date.",
"parameters": {
"type": "object",
"properties": {
"from_airport": {"type": "string", "description": "IATA code, e.g. SFO"},
"to_airport": {"type": "string", "description": "IATA code, e.g. JFK"},
"date": {"type": "string", "description": "YYYY-MM-DD"},
},
"required": ["from_airport", "to_airport", "date"],
},
},
},
{
"type": "function",
"function": {
"name": "search_hotels",
"description": "Search hotels in a city for a date range.",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string"},
"check_in": {"type": "string", "description": "YYYY-MM-DD"},
"check_out": {"type": "string", "description": "YYYY-MM-DD"},
"max_price": {"type": "integer"},
},
"required": ["city", "check_in", "check_out"],
},
},
},
{
"type": "function",
"function": {
"name": "search_restaurants",
"description": "Search restaurants in a city by cuisine.",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string"},
"cuisine": {"type": "string"},
},
"required": ["city"],
},
},
},
]
# Canned payloads for the trip-planning mocks. TRIP_PLAN_TEST returns each one
# JSON-encoded regardless of the arguments the model supplied.
_FLIGHTS_RESULT = {
"results": [
{"flight": "UA 1552", "depart": "08:15", "arrive": "16:45", "price": 389},
{"flight": "AA 20", "depart": "10:00", "arrive": "18:35", "price": 412},
]
}
_HOTELS_RESULT = {
"results": [
{"name": "Midtown Grand", "nightly_rate": 245, "rating": 4.3},
{"name": "Harbour Boutique", "nightly_rate": 312, "rating": 4.6},
]
}
_RESTAURANTS_RESULT = {
"results": [
{"name": "Trattoria Nona", "cuisine": "italian", "rating": 4.5},
{"name": "Osteria Blu", "cuisine": "italian", "rating": 4.4},
]
}
# Test case: the model must issue >= 3 calls in one turn using >= 3 distinct tool
# names; _validate_trip then checks that all three tools were called and a summary
# was produced. Each mock ignores its arguments and returns a fixed payload.
TRIP_PLAN_TEST = {
"name": "Trip planning (3 different tools in parallel)",
"tools": _TRIP_TOOLS,
"messages": [
{
"role": "user",
"content": (
"I'm flying from SFO to JFK on 2026-06-12 and staying four nights "
"(check out 2026-06-16). I'd also like some Italian restaurant "
"suggestions in New York. Please call search_flights, search_hotels, "
"and search_restaurants in parallel — all three in a single turn, "
"since they don't depend on each other. Then give me a concise "
"travel summary."
),
}
],
"mock_tool_responses": {
"search_flights": lambda _: json.dumps(_FLIGHTS_RESULT),
"search_hotels": lambda _: json.dumps(_HOTELS_RESULT),
"search_restaurants": lambda _: json.dumps(_RESTAURANTS_RESULT),
},
"expected_parallel": {
"min_parallel": 3,
"require_distinct_tools": 3,
},
"validate": lambda turns, tcs, content: _validate_trip(turns, tcs, content),
}
def _validate_trip(turns, tcs, content):
del turns
names = {tc["function"]["name"] for tc in tcs}
required = {"search_flights", "search_hotels", "search_restaurants"}
missing = required - names
if missing:
return False, f"Missing tool calls: {missing}"
if not content:
return False, "No travel summary produced"
return True, f"All three tools called; summary length={len(content)}"
# ---- Test 5: Portfolio check (same tool, parallel tickers) ----
# Tool schema for the portfolio test: a single `get_stock_quote` function whose
# only (required) parameter is a string `symbol`.
_STOCK_TOOLS = [
{
"type": "function",
"function": {
"name": "get_stock_quote",
"description": (
"Get the latest quote for ONE ticker. When the user asks about "
"multiple tickers, call this tool in parallel — one per symbol — "
"rather than sequentially."
),
"parameters": {
"type": "object",
"properties": {
"symbol": {"type": "string", "description": "Ticker symbol"},
},
"required": ["symbol"],
},
},
},
]
_STOCK_DB = {
"AAPL": {"symbol": "AAPL", "price": 218.45, "change_pct": "+0.8%"},
"MSFT": {"symbol": "MSFT", "price": 421.10, "change_pct": "+1.2%"},
"GOOGL":{"symbol": "GOOGL","price": 175.22, "change_pct": "-0.3%"},
"AMZN": {"symbol": "AMZN", "price": 189.76, "change_pct": "+0.5%"},
"NVDA": {"symbol": "NVDA", "price": 140.88, "change_pct": "+2.4%"},
}
def _stock_mock(args):
sym = args.get("symbol", "").strip().upper()
if sym in _STOCK_DB:
return json.dumps(_STOCK_DB[sym])
return json.dumps({"symbol": sym, "error": "unknown ticker"})
# Test case: the model must issue >= 5 `get_stock_quote` calls in one turn,
# covering 5 distinct `symbol` values; _validate_portfolio then checks that the
# final answer mentions NVDA.
PORTFOLIO_TEST = {
"name": "Portfolio check (same tool, 5 tickers in parallel)",
"tools": _STOCK_TOOLS,
"messages": [
{
"role": "user",
"content": (
"Pull the latest quote for every ticker in my portfolio — AAPL, "
"MSFT, GOOGL, AMZN, and NVDA — in a single parallel batch. These "
"lookups are independent, so please don't chain them across turns. "
"Once you have all five, tell me which ticker had the biggest "
"percentage change today."
),
}
],
"mock_tool_responses": {"get_stock_quote": _stock_mock},
"expected_parallel": {
"min_parallel": 5,
"require_same_tool": "get_stock_quote",
"min_distinct_args_key": "symbol",
"min_distinct_args_count": 5,
},
"validate": lambda turns, tcs, content: _validate_portfolio(turns, tcs, content),
}
def _validate_portfolio(turns, tcs, content):
del turns
if not content or ("nvda" not in content.lower() and "NVDA" not in content):
return False, f"Expected NVDA to be identified as the biggest mover: {content!r}"
return True, f"{len(tcs)} total quotes pulled"
# ---- Test 6: Mixed — translate + dictionary in parallel for the same word ----
# Tool schemas for the language-toolkit test: three independent look-up functions
# (`translate_text`, `get_definition`, `get_synonyms`) that the model is expected
# to call together in a single turn.
_LANG_TOOLS = [
{
"type": "function",
"function": {
"name": "translate_text",
"description": "Translate a short text into a target language.",
"parameters": {
"type": "object",
"properties": {
"text": {"type": "string"},
"target_language": {"type": "string",
"description": "ISO 639-1 language code, e.g. 'es'"},
},
"required": ["text", "target_language"],
},
},
},
{
"type": "function",
"function": {
"name": "get_definition",
"description": "Get the English dictionary definition of a word.",
"parameters": {
"type": "object",
"properties": {
"word": {"type": "string"},
},
"required": ["word"],
},
},
},
{
"type": "function",
"function": {
"name": "get_synonyms",
"description": "Get English synonyms for a word.",
"parameters": {
"type": "object",
"properties": {
"word": {"type": "string"},
},
"required": ["word"],
},
},
},
]
def _translate_mock(args):
t = args.get("text", "")
lang = args.get("target_language", "")
return json.dumps({"source": t, "target_language": lang, "translation": f"[{lang}] {t}"})
def _definition_mock(args):
w = args.get("word", "")
return json.dumps({
"word": w,
"definition": f"A standard dictionary definition of {w!r}.",
})
def _synonyms_mock(args):
w = args.get("word", "")
return json.dumps({
"word": w,
"synonyms": ["synonym_a", "synonym_b", "synonym_c"],
})
# Test case: the model must issue >= 3 calls in one turn using >= 3 distinct tool
# names; _validate_lang then checks that all three look-up tools were called and
# a final language note was produced.
LANG_TOOLKIT_TEST = {
"name": "Language toolkit (translate + definition + synonyms in parallel)",
"tools": _LANG_TOOLS,
"messages": [
{
"role": "user",
"content": (
"For the English word 'resilient', I need three independent "
"look-ups at once: (a) translate it into Spanish, (b) fetch its "
"dictionary definition, and (c) list its synonyms. These three "
"calls don't depend on each other — please issue them in parallel "
"in a single turn. Then present the combined results as a short "
"language note."
),
}
],
"mock_tool_responses": {
"translate_text": _translate_mock,
"get_definition": _definition_mock,
"get_synonyms": _synonyms_mock,
},
"expected_parallel": {
"min_parallel": 3,
"require_distinct_tools": 3,
},
"validate": lambda turns, tcs, content: _validate_lang(turns, tcs, content),
}
def _validate_lang(turns, tcs, content):
del turns
names = {tc["function"]["name"] for tc in tcs}
required = {"translate_text", "get_definition", "get_synonyms"}
missing = required - names
if missing:
return False, f"Missing tool calls: {missing}"
if not content:
return False, "No language note produced"
return True, f"All three lookup tools called; note length={len(content)}"
# ---------------------------------------------------------------------------
# All test cases
# ---------------------------------------------------------------------------
# Every test case run by main(), in execution order. The --test option filters
# this list by a case-insensitive substring match on each case's "name".
ALL_TEST_CASES = [
MULTIFILE_READ_TEST,
TODO_BATCH_TEST,
MULTI_WEATHER_TEST,
TRIP_PLAN_TEST,
PORTFOLIO_TEST,
LANG_TOOLKIT_TEST,
]
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main():
    """
    Command-line entry point: run the selected test cases against a server.

    Each selected case is run once per enabled mode (non-streaming and/or
    streaming). Exits with status 0 when every run passed, 1 otherwise (also
    when --test matched nothing); argparse exits with status 2 on invalid options.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Test llama-server parallel tool-calling capability. Run this only "
            "against models configured for parallel tool calls — this script "
            "does not configure that itself."
        )
    )
    parser.add_argument("--host", default="localhost")
    parser.add_argument("--port", default=8080, type=int)
    # The two mode flags contradict each other. Accepting both used to select no
    # mode at all, so zero tests ran and the script reported success (0/0 passed).
    mode_group = parser.add_mutually_exclusive_group()
    mode_group.add_argument(
        "--no-stream", action="store_true", help="Disable streaming mode tests"
    )
    mode_group.add_argument(
        "--stream-only", action="store_true", help="Only run streaming mode tests"
    )
    parser.add_argument(
        "--test",
        help="Run only the test whose name contains this substring (case-insensitive)",
    )
    args = parser.parse_args()

    url = f"http://{args.host}:{args.port}/v1/chat/completions"
    print_info(f"Testing server at {url}")
    print_warn(
        "This script expects the target model to emit multiple tool calls in a "
        "single assistant turn. Run it only against parallel-tool-capable models."
    )

    # False = non-streaming run, True = streaming run.
    modes: list[bool] = []
    if not args.stream_only:
        modes.append(False)
    if not args.no_stream:
        modes.append(True)

    cases: list[dict] = ALL_TEST_CASES
    if args.test:
        name_filter = args.test.lower()
        cases = [c for c in cases if name_filter in str(c["name"]).lower()]
        if not cases:
            print_fail(f"No test cases matched '{args.test}'")
            sys.exit(1)

    total = 0
    passed = 0
    for stream in modes:
        for case in cases:
            total += 1
            if run_test(url, case, stream=stream):
                passed += 1

    color = GREEN if passed == total else RED
    # Horizontal rule around the summary; the rule character was an empty string
    # before, so nothing was drawn.
    rule = "─" * 60
    _print(f"\n{BOLD}{color}{rule}{RESET}")
    _print(f"{BOLD}{color} Results: {passed}/{total} passed{RESET}")
    _print(f"{BOLD}{color}{rule}{RESET}\n")
    sys.exit(0 if passed == total else 1)


if __name__ == "__main__":
    main()

980
scripts/server-test-structured.py Executable file
View File

@@ -0,0 +1,980 @@
#!/usr/bin/env python3
"""
Test structured output capability via chat completions endpoint.
Each test case contains:
- response_format: OpenAI-compatible response_format specification
(json_schema only — llama.cpp does not support json_object)
- messages: initial conversation messages
- tools (optional): tool definitions (for mixed tool + structured tests)
- mock_tool_responses (optional): dict mapping tool_name -> callable(arguments) -> str (JSON)
- apply_stage: "always" to apply response_format to every request,
"after_tools" to run the tool loop plain, then request a
structured summary in a follow-up user turn.
- followup (optional, for after_tools): user message appended before the
final structured call.
- validate: callable(parsed_json, tool_calls_history, raw_content) -> (passed: bool, reason: str)
"""
import argparse
import json
import requests
import sys
from typing import Any, cast
# ---------------------------------------------------------------------------
# Color / formatting helpers
# ---------------------------------------------------------------------------
# ANSI escape sequences (ESC [ ... m) used by the print helpers below to style
# terminal output. RESET must follow styled text to restore the default rendering.
RESET = "\x1b[0m"
BOLD = "\x1b[1m"
DIM = "\x1b[2m"
CYAN = "\x1b[36m"
YELLOW = "\x1b[33m"
GREEN = "\x1b[32m"
RED = "\x1b[31m"
BLUE = "\x1b[34m"
WHITE = "\x1b[97m"
MAGENTA = "\x1b[35m"
def _print(text="", end="\n"):
sys.stdout.write(text + end)
sys.stdout.flush()
def print_header(title):
    """
    Print `title` inside a bold cyan box, preceded by a blank line.

    The box interior is 60 columns wide: two leading spaces, the title, then
    padding up to column 60. Titles longer than 58 characters get no padding and
    overflow the right border.
    """
    # The border characters were empty strings before, so no box was drawn and
    # the title padding had no visible effect.
    bar = "─" * 60
    _print(f"\n{BOLD}{CYAN}┌{bar}┐{RESET}")
    _print(
        f"{BOLD}{CYAN}│  {WHITE}{title}{CYAN}{' ' * max(0, 58 - len(title))}│{RESET}"
    )
    _print(f"{BOLD}{CYAN}└{bar}┘{RESET}")
def print_tool_call(name, args):
    """Print a one-line notice for a tool call: the tool name and its JSON-encoded arguments."""
    args_str = json.dumps(args)
    line = f"\n {BOLD}{YELLOW}⚙ tool call{RESET} {CYAN}{name}{RESET}{DIM}({args_str}){RESET}"
    _print(line)
def print_tool_result(result):
    """
    Print a dimmed one-line preview of a tool result.

    Args:
        result: the tool's response text; only the first 160 characters are shown.
    """
    limit = 160
    # Append an ellipsis when the text was cut, so a truncated preview is not
    # mistaken for the complete result (the marker was an empty string before).
    preview = result[:limit] + ("…" if len(result) > limit else "")
    _print(f" {DIM}{BLUE}↳ result{RESET} {DIM}{preview}{RESET}")
def print_model_output(text):
    """Write model-generated text to stdout verbatim (no newline added) and flush."""
    out = sys.stdout
    out.write(text)
    out.flush()
def print_pass(reason):
    """Print a bold green PASS marker followed by `reason`."""
    message = f"\n{BOLD}{GREEN}✔ PASS{RESET} {reason}"
    _print(message)
def print_fail(reason):
    """Print a bold red FAIL marker followed by `reason`."""
    message = f"\n{BOLD}{RED}✘ FAIL{RESET} {reason}"
    _print(message)
def print_info(msg):
    """Print an informational message in dim text."""
    message = f"{DIM}{msg}{RESET}"
    _print(message)
def print_schema_note(label, rf):
    """
    Print a dim one-line note describing the response_format in use.

    Args:
        label: tag shown in brackets (run_test passes the apply stage).
        rf: response_format dict; its "type" is shown ("?" when absent) and, for
            json_schema formats, the schema name when one is set.
    """
    kind = rf.get("type", "?")
    # Only json_schema formats carry a schema name.
    name = rf.get("json_schema", {}).get("name", "") if kind == "json_schema" else ""
    _print(f"{DIM}{MAGENTA} ⟐ response_format [{label}]: {kind}"
           f"{(' / ' + name) if name else ''}{RESET}")
# ---------------------------------------------------------------------------
# HTTP helpers
# ---------------------------------------------------------------------------
def _merge_tool_call_delta(tool_calls, tc):
    """Fold one streamed tool-call fragment `tc` into the accumulated `tool_calls` list, in place."""
    idx = tc.get("index") or 0
    # Grow the list so that slot `idx` exists; fragments are addressed by index.
    while len(tool_calls) <= idx:
        tool_calls.append(
            {
                "id": "",
                "type": "function",
                "function": {"name": "", "arguments": ""},
            }
        )
    slot = tool_calls[idx]
    # A fragment may omit a field or carry an explicit null for it; `or ""` keeps
    # the string concatenation safe in both cases.
    slot["id"] += tc.get("id") or ""
    fn = tc.get("function") or {}
    slot["function"]["name"] += fn.get("name") or ""
    slot["function"]["arguments"] += fn.get("arguments") or ""


def chat_completion(url, messages, tools=None, response_format=None, stream=False,
                    timeout=None):
    """
    Send one chat-completions request and collect the assistant's reply.

    Args:
        url: full URL of the chat-completions endpoint.
        messages: conversation so far, as a list of message dicts.
        tools: optional tool definitions; when given, tool_choice is set to "auto".
        response_format: optional response_format object forwarded to the server.
        stream: when True, request a streamed response and reassemble content and
            tool calls from the "data: " event lines.
        timeout: optional timeout forwarded to requests.post. The default None
            waits indefinitely, which was the behaviour before this parameter existed.

    Returns:
        dict with "content" (str) and "tool_calls" (list), plus "reasoning_content"
        when the server sent any; or None when the request or the reading of the
        response failed (the failure is reported through print_fail).

    Model content is echoed to stdout as it arrives.
    """
    payload = {
        "messages": messages,
        "stream": stream,
        "max_tokens": 4096,
    }
    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = "auto"
    if response_format is not None:
        payload["response_format"] = response_format

    try:
        response = requests.post(url, json=payload, stream=stream, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        body = e.response.content if (e.response is not None) else b""
        print_fail(f"Request error: {e} | body: {body}")
        return None

    full_content = ""
    reasoning_content = ""
    tool_calls: list[dict] = []

    try:
        if stream:
            for line in response.iter_lines():
                if not line:
                    continue
                decoded = line.decode("utf-8")
                if not decoded.startswith("data: "):
                    continue
                data_str = decoded[6:]
                if data_str == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip an undecodable event rather than abort the whole stream.
                    continue
                choices = data.get("choices") or []
                if not choices:
                    continue
                # `or {}` / `or []` also cover explicit nulls, not just missing keys.
                delta = choices[0].get("delta") or {}
                if delta.get("reasoning_content"):
                    reasoning_content += delta["reasoning_content"]
                if delta.get("content"):
                    full_content += delta["content"]
                    print_model_output(delta["content"])
                for tc in delta.get("tool_calls") or []:
                    _merge_tool_call_delta(tool_calls, tc)
        else:
            data = response.json()
            choices = data.get("choices") or []
            if choices:
                msg = choices[0].get("message") or {}
                full_content = msg.get("content") or ""
                reasoning_content = msg.get("reasoning_content") or ""
                tool_calls = msg.get("tool_calls") or []
                if full_content:
                    print_model_output(full_content)
    except (requests.exceptions.RequestException, ValueError) as e:
        # Connection dropped mid-stream, or the body was not decodable: report it
        # as a failed request instead of letting the traceback end the test run.
        print_fail(f"Failed to read response: {e}")
        return None
    finally:
        # Release the connection; a streamed response abandoned at [DONE] would
        # otherwise stay open.
        response.close()

    result = {"content": full_content, "tool_calls": tool_calls}
    if reasoning_content:
        result["reasoning_content"] = reasoning_content
    return result
def run_tool_loop(
    url, messages, tools, mock_tool_responses, stream, response_format=None,
    max_turns=6,
):
    """
    Drive the tool-call loop until the model answers without requesting tools.

    Each turn sends the conversation to the server; requested tool calls are
    answered from `mock_tool_responses` and appended to the conversation. If
    response_format is provided it is applied to every request.

    Args:
        url: chat-completions endpoint.
        messages: initial conversation; the caller's list is not modified.
        tools: tool definitions offered to the model (may be None).
        mock_tool_responses: dict mapping tool name -> callable(args_dict) -> str.
            Tools without an entry get an {"error": "Unknown tool: ..."} result.
        stream: whether to use streaming requests.
        response_format: optional response_format applied to every request.
        max_turns: maximum number of requests before giving up.

    Returns:
        (all_tool_calls, final_messages, final_content). final_content is None
        when a request failed or max_turns was exhausted. final_messages does not
        include the final assistant answer.
    """
    msgs = list(messages)
    all_tool_calls: list[dict] = []

    for _ in range(max_turns):
        result = chat_completion(
            url, msgs, tools=tools, response_format=response_format, stream=stream
        )
        if result is None:
            return all_tool_calls, msgs, None

        tcs = result.get("tool_calls") or []
        content = result.get("content") or ""

        if not tcs:
            if content:
                _print(f"\n{DIM}{'·' * 60}{RESET}")
            return all_tool_calls, msgs, content

        all_tool_calls.extend(tcs)

        assistant_msg: dict = {
            "role": "assistant",
            "content": content,
            "tool_calls": tcs,
        }
        reasoning = result.get("reasoning_content")
        if reasoning:
            assistant_msg["reasoning_content"] = reasoning
        msgs.append(assistant_msg)

        for tc in tcs:
            tool_name = tc["function"]["name"]
            raw_args = tc["function"]["arguments"]
            if isinstance(raw_args, dict):
                # Arguments already decoded into an object: use them directly.
                args = raw_args
            else:
                try:
                    args = json.loads(raw_args)
                except (json.JSONDecodeError, TypeError):
                    # Malformed JSON or a non-string payload such as None.
                    args = {}
            if not isinstance(args, dict):
                # Mock callables expect a dict; valid JSON that is not an object
                # (list, null, number) would crash them on args.get().
                args = {}

            print_tool_call(tool_name, args)

            mock_fn = mock_tool_responses.get(tool_name) if mock_tool_responses else None
            if mock_fn:
                tool_result = mock_fn(args)
            else:
                tool_result = json.dumps({"error": f"Unknown tool: {tool_name}"})

            print_tool_result(tool_result)
            msgs.append(
                {
                    "role": "tool",
                    "tool_call_id": tc.get("id", ""),
                    "content": tool_result,
                }
            )

    return all_tool_calls, msgs, None
# ---------------------------------------------------------------------------
# Test case runner
# ---------------------------------------------------------------------------
def _try_parse_json(text):
"""Attempt to parse text as JSON, trimming common markdown fences."""
if text is None:
return None
stripped = text.strip()
if stripped.startswith("```"):
lines = stripped.splitlines()
if lines and lines[0].startswith("```"):
lines = lines[1:]
if lines and lines[-1].strip().startswith("```"):
lines = lines[:-1]
stripped = "\n".join(lines).strip()
try:
return json.loads(stripped)
except json.JSONDecodeError:
return None
def run_test(url, test_case, stream):
    """
    Run one structured-output test case and print its PASS/FAIL verdict.

    With apply_stage "always" the response_format is sent on every request of the
    tool loop. With "after_tools" the tool loop runs without it, then a follow-up
    user message is appended and a final request (without tools) is made with the
    response_format applied.

    Returns the validator's verdict, or False when no final content was obtained,
    the content is not JSON, or apply_stage is unknown.
    """
    name = test_case["name"]
    mode = "stream" if stream else "non-stream"
    apply_stage = test_case.get("apply_stage", "always")
    print_header(f"{name} [{mode}] ({apply_stage})")

    response_format = test_case["response_format"]
    print_schema_note(apply_stage, response_format)

    tools = test_case.get("tools")
    mocks = test_case.get("mock_tool_responses") or {}

    if apply_stage not in ("always", "after_tools"):
        print_fail(f"Unknown apply_stage: {apply_stage}")
        return False

    if apply_stage == "always":
        all_tcs, _msgs, final_content = run_tool_loop(
            url,
            messages=list(test_case["messages"]),
            tools=tools,
            mock_tool_responses=mocks,
            stream=stream,
            response_format=response_format,
        )
    else:
        # Phase 1: plain tool loop, no response_format applied yet.
        all_tcs, msgs, interim_content = run_tool_loop(
            url,
            messages=list(test_case["messages"]),
            tools=tools,
            mock_tool_responses=mocks,
            stream=stream,
            response_format=None,
        )
        # Keep the model's free-form answer in the history before asking for JSON.
        if interim_content:
            msgs.append({"role": "assistant", "content": interim_content})
        followup = test_case.get(
            "followup",
            "Now output the answer strictly as JSON matching the provided schema. "
            "Do not include commentary.",
        )
        msgs.append({"role": "user", "content": followup})

        # Phase 2: request final structured output. Tools are not passed so the
        # model focuses on producing the schema-constrained answer.
        _print(f"\n{DIM}{MAGENTA} ⟐ follow-up turn with response_format applied{RESET}")
        result = chat_completion(
            url, msgs, tools=None, response_format=response_format, stream=stream
        )
        final_content = None if not result else result["content"]

    if final_content is None:
        print_fail("No final content from server.")
        return False

    parsed = _try_parse_json(final_content)
    if parsed is None:
        print_fail(f"Final content is not valid JSON: {final_content[:200]!r}")
        return False

    passed, reason = test_case["validate"](parsed, all_tcs, final_content)
    report = print_pass if passed else print_fail
    report(reason)
    return passed
# ---------------------------------------------------------------------------
# Test case definitions
# ---------------------------------------------------------------------------
# ---- Test 1: Book metadata extraction (always / json_schema) ----
# response_format for the book-metadata test: a strict json_schema describing an
# object with five required fields and no additional properties; `genre` is
# restricted to a fixed enum.
_BOOK_SCHEMA = {
"type": "json_schema",
"json_schema": {
"name": "book_metadata",
"strict": True,
"schema": {
"type": "object",
"additionalProperties": False,
"properties": {
"title": {"type": "string"},
"author": {"type": "string"},
"year": {"type": "integer"},
"genre": {
"type": "string",
"enum": [
"fiction",
"non-fiction",
"fantasy",
"sci-fi",
"mystery",
"biography",
"history",
"other",
],
},
"page_count": {"type": "integer"},
},
"required": ["title", "author", "year", "genre", "page_count"],
},
},
}
# Test case: no tools; the book schema is applied to every request ("always").
# The validator only inspects the parsed JSON; tool-call history and raw content
# are ignored.
BOOK_TEST_CASE = {
"name": "Book metadata extraction (json_schema, always)",
"response_format": _BOOK_SCHEMA,
"apply_stage": "always",
"messages": [
{
"role": "user",
"content": (
"Extract book metadata from this description: "
"'Dune is a 1965 science fiction epic by Frank Herbert, spanning roughly "
"688 pages in its first edition, set on the desert planet Arrakis.' "
"Return the data as JSON."
),
}
],
"validate": lambda parsed, tcs, raw: _validate_book(parsed),
}
def _validate_book(parsed):
required = {"title", "author", "year", "genre", "page_count"}
missing = required - parsed.keys()
if missing:
return False, f"Missing fields: {missing}"
if not isinstance(parsed["title"], str) or not parsed["title"]:
return False, "title must be a non-empty string"
if not isinstance(parsed["author"], str) or "herbert" not in parsed["author"].lower():
return False, f"author unexpected: {parsed['author']!r}"
if not isinstance(parsed["year"], int) or parsed["year"] != 1965:
return False, f"year should be 1965, got {parsed['year']!r}"
if parsed["genre"] not in {
"fiction", "non-fiction", "fantasy", "sci-fi", "mystery",
"biography", "history", "other",
}:
return False, f"genre not in enum: {parsed['genre']!r}"
if not isinstance(parsed["page_count"], int) or parsed["page_count"] <= 0:
return False, f"page_count should be positive int: {parsed['page_count']!r}"
return True, f"Book: {parsed['title']} ({parsed['year']}) / {parsed['genre']}"
# ---- Test 2: Sentiment classification (always / enum-constrained) ----
# response_format for the sentiment test: a strict json_schema with an
# enum-constrained `sentiment`, a numeric `confidence`, and a `keywords` array of
# 1 to 5 strings. All three fields are required.
_SENTIMENT_SCHEMA = {
"type": "json_schema",
"json_schema": {
"name": "sentiment_analysis",
"strict": True,
"schema": {
"type": "object",
"additionalProperties": False,
"properties": {
"sentiment": {
"type": "string",
"enum": ["positive", "negative", "neutral"],
},
"confidence": {"type": "number"},
"keywords": {
"type": "array",
"items": {"type": "string"},
"minItems": 1,
"maxItems": 5,
},
},
"required": ["sentiment", "confidence", "keywords"],
},
},
}
# Test case: no tools; the sentiment schema is applied to every request
# ("always"). The review text is clearly favourable, and _validate_sentiment
# requires the label "positive".
SENTIMENT_TEST_CASE = {
"name": "Sentiment analysis with enum and array",
"response_format": _SENTIMENT_SCHEMA,
"apply_stage": "always",
"messages": [
{
"role": "user",
"content": (
"Analyse the sentiment of this review and return JSON with the "
"detected sentiment label, a confidence score between 0 and 1, "
"and up to five keyword strings that drove the classification:\n\n"
"'This product completely exceeded my expectations. The build "
"quality is phenomenal, it arrived a day early, and customer "
"support was delightful when I had a setup question.'"
),
}
],
"validate": lambda parsed, tcs, raw: _validate_sentiment(parsed),
}
def _validate_sentiment(parsed):
if parsed.get("sentiment") not in {"positive", "negative", "neutral"}:
return False, f"sentiment not in enum: {parsed.get('sentiment')!r}"
if parsed["sentiment"] != "positive":
return False, f"expected positive sentiment, got {parsed['sentiment']}"
conf = parsed.get("confidence")
if not isinstance(conf, (int, float)) or not (0.0 <= conf <= 1.0):
return False, f"confidence not in [0,1]: {conf!r}"
kws = parsed.get("keywords")
if not isinstance(kws, list) or not (1 <= len(kws) <= 5):
return False, f"keywords length out of range: {kws!r}"
if not all(isinstance(k, str) and k for k in kws):
return False, f"keywords must be non-empty strings: {kws!r}"
return True, f"sentiment={parsed['sentiment']} conf={conf} kws={kws}"
# ---- Test 3: Nested recipe schema (always) ----
# response_format for the recipe test: a strict json_schema with nested
# structure, namely an `ingredients` array of >= 2 {item, quantity} objects and a
# `steps` array of >= 2 strings, alongside scalar name/servings/prep-time fields.
_RECIPE_SCHEMA = {
"type": "json_schema",
"json_schema": {
"name": "recipe",
"strict": True,
"schema": {
"type": "object",
"additionalProperties": False,
"properties": {
"name": {"type": "string"},
"servings": {"type": "integer"},
"ingredients": {
"type": "array",
"minItems": 2,
"items": {
"type": "object",
"additionalProperties": False,
"properties": {
"item": {"type": "string"},
"quantity": {"type": "string"},
},
"required": ["item", "quantity"],
},
},
"steps": {
"type": "array",
"minItems": 2,
"items": {"type": "string"},
},
"prep_time_minutes": {"type": "integer"},
},
"required": ["name", "servings", "ingredients", "steps", "prep_time_minutes"],
},
},
}
# Test case: no tools; the recipe schema is applied to every request ("always").
# The validator only inspects the parsed JSON; tool-call history and raw content
# are ignored.
RECIPE_TEST_CASE = {
"name": "Nested recipe with arrays of objects",
"response_format": _RECIPE_SCHEMA,
"apply_stage": "always",
"messages": [
{
"role": "user",
"content": (
"Give me a simple 4-serving scrambled eggs recipe as structured JSON. "
"Include the recipe name, servings, ingredients (each with item and "
"quantity), preparation steps, and total prep time in minutes."
),
}
],
"validate": lambda parsed, tcs, raw: _validate_recipe(parsed),
}
def _validate_recipe(parsed):
required = {"name", "servings", "ingredients", "steps", "prep_time_minutes"}
missing = required - parsed.keys()
if missing:
return False, f"Missing fields: {missing}"
if not isinstance(parsed["name"], str) or not parsed["name"]:
return False, "name must be a non-empty string"
if not isinstance(parsed["servings"], int) or parsed["servings"] <= 0:
return False, f"servings must be positive int: {parsed['servings']!r}"
ings = parsed["ingredients"]
if not isinstance(ings, list) or len(ings) < 2:
return False, f"ingredients must be array of >=2: got {ings!r}"
for i, ing in enumerate(ings):
if not isinstance(ing, dict):
return False, f"ingredient[{i}] is not an object: {ing!r}"
ing_d = cast(dict[str, Any], ing)
item_val = ing_d.get("item")
qty_val = ing_d.get("quantity")
if item_val is None or qty_val is None:
return False, f"ingredient[{i}] missing item/quantity: {ing!r}"
if not isinstance(item_val, str) or not isinstance(qty_val, str):
return False, f"ingredient[{i}] fields must be strings: {ing!r}"
steps = parsed["steps"]
if not isinstance(steps, list) or len(steps) < 2:
return False, f"steps must be array of >=2 strings: got {steps!r}"
if not all(isinstance(s, str) and s for s in steps):
return False, "all steps must be non-empty strings"
pt = parsed["prep_time_minutes"]
if not isinstance(pt, int) or pt <= 0:
return False, f"prep_time_minutes must be positive int: {pt!r}"
return True, f"recipe '{parsed['name']}' with {len(ings)} ingredients, {len(steps)} steps"
# ---- Test 4: Tool call -> structured product comparison (after_tools) ----
# Tool definitions offered to the model in the laptop-comparison test: a
# keyword search and a per-product detail lookup, each taking one required
# string argument.
_SHOP_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "search_products",
            "description": "Search a product catalogue by keyword.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string"},
                },
                "required": ["query"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_product_details",
            "description": "Get detailed specs for a product by ID.",
            "parameters": {
                "type": "object",
                "properties": {
                    "product_id": {"type": "string"},
                },
                "required": ["product_id"],
            },
        },
    },
]
# Canned payload returned for every search_products call, regardless of the
# query. The product_id values match the keys of _SHOP_PRODUCT_DETAILS.
_SHOP_SEARCH_RESULT = {
    "results": [
        {"product_id": "LAP-001", "title": "AeroBook 13 Pro", "price": 1399.0, "rating": 4.7},
        {"product_id": "LAP-002", "title": "QuantumSlim 14", "price": 1199.0, "rating": 4.4},
        {"product_id": "LAP-003", "title": "NimbusWork Ultra 15", "price": 999.0, "rating": 4.2},
    ],
}
# Mock catalogue keyed by product_id. It backs the get_product_details mock,
# and its keys are the ids _validate_shop_comparison accepts as valid.
_SHOP_PRODUCT_DETAILS = {
    "LAP-001": {
        "product_id": "LAP-001",
        "title": "AeroBook 13 Pro",
        "cpu": "M-series 10-core",
        "ram_gb": 16,
        "storage_gb": 512,
        "battery_hours": 18,
        "weight_kg": 1.24,
        "price": 1399.0,
    },
    "LAP-002": {
        "product_id": "LAP-002",
        "title": "QuantumSlim 14",
        "cpu": "Core i7 12-core",
        "ram_gb": 16,
        "storage_gb": 512,
        "battery_hours": 12,
        "weight_kg": 1.35,
        "price": 1199.0,
    },
    "LAP-003": {
        "product_id": "LAP-003",
        "title": "NimbusWork Ultra 15",
        "cpu": "Ryzen 7 8-core",
        "ram_gb": 16,
        "storage_gb": 1024,
        "battery_hours": 10,
        "weight_kg": 1.70,
        "price": 999.0,
    },
}
def _shop_details_mock(args):
    """Mock for the ``get_product_details`` tool.

    Args:
        args: Mapping of tool-call arguments; ``product_id`` is read from it
            and defaults to an empty string when absent.

    Returns:
        A JSON string: the catalogue entry for a known id, otherwise an
        object with a single ``error`` key.
    """
    product_id = args.get("product_id", "")
    payload = (
        _SHOP_PRODUCT_DETAILS[product_id]
        if product_id in _SHOP_PRODUCT_DETAILS
        else {"error": f"unknown product_id: {product_id}"}
    )
    return json.dumps(payload)
# response_format payload for the laptop-comparison test: a strict JSON schema
# named "laptop_comparison" with a recommendation string and an array of at
# least two ranked candidate objects.
_SHOP_COMPARISON_SCHEMA = {
    "type": "json_schema",
    "json_schema": {
        "name": "laptop_comparison",
        "strict": True,
        "schema": {
            "type": "object",
            "additionalProperties": False,
            "properties": {
                "recommendation": {"type": "string"},
                "ranked_candidates": {
                    "type": "array",
                    "minItems": 2,
                    "items": {
                        "type": "object",
                        "additionalProperties": False,
                        "properties": {
                            "product_id": {"type": "string"},
                            "title": {"type": "string"},
                            "score": {"type": "number"},
                            "reason": {"type": "string"},
                        },
                        "required": ["product_id", "title", "score", "reason"],
                    },
                },
            },
            "required": ["recommendation", "ranked_candidates"],
        },
    },
}
# Test case: the model is asked to call the shop tools first, then a follow-up
# message requests the final answer constrained by _SHOP_COMPARISON_SCHEMA.
SHOP_COMPARISON_TEST_CASE = {
    "name": "Tool calls then structured laptop comparison (after_tools)",
    "response_format": _SHOP_COMPARISON_SCHEMA,
    # NOTE(review): presumably means response_format is only applied once the
    # tool-calling phase is over — confirm against run_test.
    "apply_stage": "after_tools",
    "tools": _SHOP_TOOLS,
    # Tool name -> callable producing the mocked tool result as a JSON string.
    # Each callable takes a single argument (assumed to be the call's
    # arguments — verify against run_test).
    "mock_tool_responses": {
        "search_products": lambda _: json.dumps(_SHOP_SEARCH_RESULT),
        "get_product_details": _shop_details_mock,
    },
    "messages": [
        {
            "role": "user",
            "content": (
                "I need a lightweight laptop for travel. Please search the catalogue "
                "for 'ultraportable laptop', then fetch detailed specs for at least two "
                "of the top candidates. Once you've gathered the data I'll ask you to "
                "produce a structured comparison."
            ),
        }
    ],
    "followup": (
        "Thanks. Now produce the final comparison strictly as JSON matching the "
        "laptop_comparison schema: your single best recommendation (the product_id), "
        "and a ranked_candidates array of at least two laptops, each with "
        "product_id, title, a numeric score, and a short reason."
    ),
    # The validator checks both the parsed payload and the recorded tool calls.
    "validate": lambda parsed, tcs, raw: _validate_shop_comparison(parsed, tcs),
}
def _validate_shop_comparison(parsed, tcs):
names = [tc["function"]["name"] for tc in tcs]
if "search_products" not in names:
return False, f"expected search_products tool call, got {names}"
if "get_product_details" not in names:
return False, f"expected get_product_details tool call, got {names}"
if "recommendation" not in parsed or not isinstance(parsed["recommendation"], str):
return False, f"recommendation missing or not a string: {parsed!r}"
cands = parsed.get("ranked_candidates")
if not isinstance(cands, list) or len(cands) < 2:
return False, f"ranked_candidates must be >=2: {cands!r}"
valid_ids = set(_SHOP_PRODUCT_DETAILS.keys())
candidate_pids: list = []
for i, c in enumerate(cands):
if not isinstance(c, dict):
return False, f"candidate[{i}] not an object: {c!r}"
c_d = cast(dict[str, Any], c)
pid = c_d.get("product_id")
title = c_d.get("title")
score = c_d.get("score")
reason = c_d.get("reason")
for k, v in (("product_id", pid), ("title", title),
("score", score), ("reason", reason)):
if v is None:
return False, f"candidate[{i}] missing {k}: {c!r}"
if pid not in valid_ids:
return False, f"candidate[{i}].product_id not in catalogue: {pid!r}"
if not isinstance(score, (int, float)):
return False, f"candidate[{i}].score not numeric: {score!r}"
candidate_pids.append(pid)
recommendation = parsed["recommendation"]
if recommendation not in valid_ids and recommendation not in candidate_pids:
return False, f"recommendation {recommendation!r} not in candidates"
return True, (
f"tools={names}; recommended={parsed['recommendation']}; "
f"{len(cands)} ranked candidates"
)
# ---- Test 5: Multi-step research then structured report (after_tools) ----
# Tool definitions offered to the model in the country-report test: one lookup
# for basic statistics and one for climate, each taking a required "country"
# string.
_RESEARCH_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_country_stats",
            "description": "Fetch basic statistics for a country (population, GDP, capital).",
            "parameters": {
                "type": "object",
                "properties": {
                    "country": {"type": "string"},
                },
                "required": ["country"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_climate_info",
            "description": "Fetch climate information for a country.",
            "parameters": {
                "type": "object",
                "properties": {
                    "country": {"type": "string"},
                },
                "required": ["country"],
            },
        },
    },
]
# Mock data for get_country_stats, keyed by lower-cased country name
# (_country_stats_mock lower-cases and strips the argument before lookup).
_COUNTRY_STATS = {
    "norway": {
        "country": "Norway",
        "capital": "Oslo",
        "population": 5_480_000,
        "gdp_usd_trillion": 0.48,
        "currency": "NOK",
    }
}
# Mock data for get_climate_info, keyed by lower-cased country name
# (_climate_info_mock lower-cases and strips the argument before lookup).
_CLIMATE_INFO = {
    "norway": {
        "country": "Norway",
        "climate_zone": "subarctic / temperate coastal",
        "avg_winter_temp_c": -4.5,
        "avg_summer_temp_c": 16.0,
        "annual_precipitation_mm": 1400,
    }
}
def _country_stats_mock(args):
    """Mock for the ``get_country_stats`` tool.

    Args:
        args: Mapping of tool-call arguments; ``country`` is read from it
            (default empty string), then stripped and lower-cased.

    Returns:
        A JSON string: the stored statistics for a known country, otherwise
        an object with a single ``error`` key naming the normalised input.
    """
    key = args.get("country", "").strip().lower()
    payload = (
        _COUNTRY_STATS[key]
        if key in _COUNTRY_STATS
        else {"error": f"unknown country: {key}"}
    )
    return json.dumps(payload)
def _climate_info_mock(args):
    """Mock for the ``get_climate_info`` tool.

    Args:
        args: Mapping of tool-call arguments; ``country`` is read from it
            (default empty string), then stripped and lower-cased.

    Returns:
        A JSON string: the stored climate record for a known country,
        otherwise an object with a single ``error`` key naming the
        normalised input.
    """
    key = args.get("country", "").strip().lower()
    payload = (
        _CLIMATE_INFO[key]
        if key in _CLIMATE_INFO
        else {"error": f"unknown country: {key}"}
    )
    return json.dumps(payload)
# response_format payload for the country-report test: a strict JSON schema
# named "country_report" with six required properties, including a highlights
# array bounded to 2-5 strings.
_RESEARCH_REPORT_SCHEMA = {
    "type": "json_schema",
    "json_schema": {
        "name": "country_report",
        "strict": True,
        "schema": {
            "type": "object",
            "additionalProperties": False,
            "properties": {
                "country": {"type": "string"},
                "capital": {"type": "string"},
                "population": {"type": "integer"},
                "climate_summary": {"type": "string"},
                "highlights": {
                    "type": "array",
                    "minItems": 2,
                    "maxItems": 5,
                    "items": {"type": "string"},
                },
                "suitable_for_tourism": {"type": "boolean"},
            },
            "required": [
                "country", "capital", "population",
                "climate_summary", "highlights", "suitable_for_tourism",
            ],
        },
    },
}
# Test case: the model is asked to call both research tools for Norway, then a
# follow-up message requests a report constrained by _RESEARCH_REPORT_SCHEMA.
COUNTRY_REPORT_TEST_CASE = {
    "name": "Research pipeline then structured country report (after_tools)",
    "response_format": _RESEARCH_REPORT_SCHEMA,
    # NOTE(review): presumably means response_format is only applied once the
    # tool-calling phase is over — confirm against run_test.
    "apply_stage": "after_tools",
    "tools": _RESEARCH_TOOLS,
    # Tool name -> callable producing the mocked tool result as a JSON string.
    "mock_tool_responses": {
        "get_country_stats": _country_stats_mock,
        "get_climate_info": _climate_info_mock,
    },
    "messages": [
        {
            "role": "user",
            "content": (
                "I'm preparing a short briefing on Norway. Please call the "
                "get_country_stats and get_climate_info tools to gather data "
                "first. Afterwards I'll ask for a structured summary."
            ),
        }
    ],
    "followup": (
        "Based on the tool results, produce the briefing as JSON matching the "
        "country_report schema. Populate every required field and provide between "
        "two and five highlights."
    ),
    # The validator checks both the parsed payload and the recorded tool calls.
    "validate": lambda parsed, tcs, raw: _validate_country_report(parsed, tcs),
}
def _validate_country_report(parsed, tcs):
names = [tc["function"]["name"] for tc in tcs]
for required_tool in ("get_country_stats", "get_climate_info"):
if required_tool not in names:
return False, f"missing tool call {required_tool!r}: got {names}"
required = {
"country", "capital", "population",
"climate_summary", "highlights", "suitable_for_tourism",
}
missing = required - parsed.keys()
if missing:
return False, f"missing report fields: {missing}"
if "norway" not in parsed["country"].lower():
return False, f"country should reference Norway: {parsed['country']!r}"
if "oslo" not in parsed["capital"].lower():
return False, f"capital should be Oslo: {parsed['capital']!r}"
if not isinstance(parsed["population"], int) or parsed["population"] < 1_000_000:
return False, f"population implausible: {parsed['population']!r}"
if not isinstance(parsed["climate_summary"], str) or not parsed["climate_summary"]:
return False, "climate_summary must be a non-empty string"
hls = parsed["highlights"]
if not isinstance(hls, list) or not (2 <= len(hls) <= 5):
return False, f"highlights length out of range: {hls!r}"
if not all(isinstance(h, str) and h for h in hls):
return False, "each highlight must be a non-empty string"
if not isinstance(parsed["suitable_for_tourism"], bool):
return False, f"suitable_for_tourism must be bool: {parsed['suitable_for_tourism']!r}"
return True, (
f"tools={names}; report for {parsed['country']} "
f"(pop {parsed['population']}, {len(hls)} highlights)"
)
# ---------------------------------------------------------------------------
# All test cases
# ---------------------------------------------------------------------------
# Cases executed by main(), in this order, once per selected streaming mode.
# The --test option filters this list by case name.
ALL_TEST_CASES = [
    BOOK_TEST_CASE,
    SENTIMENT_TEST_CASE,
    RECIPE_TEST_CASE,
    SHOP_COMPARISON_TEST_CASE,
    COUNTRY_REPORT_TEST_CASE,
]
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main():
    """Parse command-line options, run the selected test cases and exit.

    Each selected case is run once per selected mode (non-streaming and/or
    streaming) against ``http://<host>:<port>/v1/chat/completions``. The
    process exits with status 0 when every run passed and 1 otherwise; it
    also exits with 1 when ``--test`` matches no case. Conflicting mode flags
    are rejected by argparse (exit status 2).
    """
    parser = argparse.ArgumentParser(
        description="Test llama-server structured-output capability."
    )
    parser.add_argument("--host", default="localhost")
    parser.add_argument("--port", default=8080, type=int)
    # The two mode flags cancel each other out: given together they would
    # select no mode at all, run zero tests and still report success, so
    # argparse is told to refuse the combination.
    mode_group = parser.add_mutually_exclusive_group()
    mode_group.add_argument(
        "--no-stream", action="store_true", help="Disable streaming mode tests"
    )
    mode_group.add_argument(
        "--stream-only", action="store_true", help="Only run streaming mode tests"
    )
    parser.add_argument(
        "--test",
        help="Run only the test whose name contains this substring (case-insensitive)",
    )
    args = parser.parse_args()

    url = f"http://{args.host}:{args.port}/v1/chat/completions"
    print_info(f"Testing server at {url}")

    # False = non-streaming request, True = streaming request.
    modes: list[bool] = []
    if not args.stream_only:
        modes.append(False)
    if not args.no_stream:
        modes.append(True)

    cases: list[dict] = ALL_TEST_CASES
    if args.test:
        name_filter = args.test.lower()
        cases = [c for c in cases if name_filter in str(c["name"]).lower()]
        if not cases:
            print_fail(f"No test cases matched '{args.test}'")
            sys.exit(1)

    total = 0
    passed = 0
    for stream in modes:
        for case in cases:
            total += 1
            if run_test(url, case, stream=stream):
                passed += 1

    color = GREEN if passed == total else RED
    # NOTE(review): '' * 60 is always the empty string, so the two rule lines
    # below print nothing visible. The separator character appears to have
    # been lost; confirm the intended glyph before changing it.
    _print(f"\n{BOLD}{color}{'' * 60}{RESET}")
    _print(f"{BOLD}{color} Results: {passed}/{total} passed{RESET}")
    _print(f"{BOLD}{color}{'' * 60}{RESET}\n")
    sys.exit(0 if passed == total else 1)


if __name__ == "__main__":
    main()

View File

@@ -207,6 +207,8 @@ struct cli_context {
auto meta = ctx_server.get_meta();
auto & chat_params = meta.chat_params;
auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = {}; // TODO
@@ -214,7 +216,7 @@ struct cli_context {
inputs.json_schema = ""; // TODO
inputs.grammar = ""; // TODO
inputs.use_jinja = chat_params.use_jinja;
inputs.parallel_tool_calls = false;
inputs.parallel_tool_calls = caps["supports_parallel_tool_calls"];
inputs.add_generation_prompt = true;
inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
inputs.force_pure_content = chat_params.force_pure_content;

View File

@@ -1027,6 +1027,8 @@ json oaicompat_chat_params_parse(
}
}
auto caps = common_chat_templates_get_caps(opt.tmpls.get());
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = common_chat_tools_parse_oaicompat(tools);
@@ -1034,7 +1036,7 @@ json oaicompat_chat_params_parse(
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
inputs.grammar = grammar;
inputs.use_jinja = opt.use_jinja;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]);
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.reasoning_format = opt.reasoning_format;
if (body.contains("reasoning_format")) {