Compare commits

...

3 Commits
b6119 ... b6122

Author SHA1 Message Date
Aman Gupta
34c9d765bf CUDA: add attention sinks for tile and wmma (#15178)
* CUDA: add attention sinks for tile and wmma

* Review: formatting changes + remove syncthreads from tile + remove warp_reduce_max from wmma
2025-08-09 20:00:24 +08:00
compilade
e54d41befc gguf-py : add Numpy MXFP4 de/quantization support (#15111)
* gguf-py : add MXFP4 de/quantization support

* ggml-quants : handle zero amax for MXFP4
2025-08-08 17:48:26 -04:00
Johannes Gäßler
4850b52aed server-bench: external OAI servers, sqlite (#15179)
* server-bench: external OAI servers, sqlite

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* raise_for_status

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-08-08 23:04:36 +02:00
8 changed files with 230 additions and 55 deletions

View File

@@ -49,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16(
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
@@ -242,6 +243,31 @@ static __global__ void flash_attn_tile_ext_f16(
__syncthreads();
}
//Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const half sink = __float2half(sinksf[head]);
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
kqmax[j0/nwarps] = kqmax_new_j;
const half val = hexp(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
}
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll

View File

@@ -60,10 +60,11 @@ static __global__ void flash_attn_tile_ext_f32(
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
@@ -252,6 +253,33 @@ static __global__ void flash_attn_tile_ext_f32(
__syncthreads();
}
//Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
kqmax[j0/nwarps] = kqmax_new_j;
const float val = expf(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps] += val;
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
}
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll

View File

@@ -82,11 +82,12 @@ static __global__ void flash_attn_ext_f16(
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half2 * mask2 = (const half2 *) maskh;
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half2 * mask2 = (const half2 *) maskh;
const float * sinksf = (const float *) sinks;
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
@@ -381,6 +382,53 @@ static __global__ void flash_attn_ext_f16(
__syncthreads();
}
// Apply attention sinks
if (sinksf && blockIdx.y == 0) {
const float sinkf = sinksf[head];
const half sinkh = __float2half(sinkf);
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (std::is_same<KQ_acc_t, float>::value) {
float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
KQ_max_f[j0/nwarps] = kqmax_new;
KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + warp_size > D/2 && i >= D/2) break;
VKQ2[j*(D_padded/2) + i] *= scale_h2;
}
} else {
half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
half kqmax_new = fmaxf(kqmax_old, sinkh);
KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
const half KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
const half2 KQ_max_scale = __half2half2(KQ_max_scale_h);
KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
const half val = hexp(sinkh - kqmax_new);
KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + warp_size > D/2 && i >= D/2) break;
VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
}
}
}
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j_VKQ = j0 + threadIdx.y;

View File

@@ -274,23 +274,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
// TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
if (sinks && !fp16_mma_available(cc)) {
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
}
return;
}
#if defined(GGML_HIP_ROCWMMA_FATTN)
if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);

View File

@@ -288,7 +288,7 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE
}
}
const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
const float d = GGML_E8M0_TO_FP32_HALF(e);

View File

@@ -228,8 +228,7 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
d = max / -8
with np.errstate(divide="ignore"):
id = np.where(d == 0, 0, 1 / d)
# FIXME: Q4_0's reference rounding is cursed and depends on FMA
qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
qs = np.trunc((blocks * id) + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
@@ -300,8 +299,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
d = max / -16
with np.errstate(divide="ignore"):
id = np.where(d == 0, 0, 1 / d)
# FIXME: Q5_0's reference rounding is cursed and depends on FMA
q = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
q = np.trunc((blocks * id) + np.float32(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
qs = q.reshape((n_blocks, 2, cls.block_size // 2))
qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
@@ -655,6 +653,57 @@ class TQ2_0(__Quant, qtype=GGMLQuantizationType.TQ2_0):
return (d * qs.astype(np.float32))
class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
# e2m1 values (doubled)
# ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)
@staticmethod
# see ggml_e8m0_to_fp32_half in ggml-impl.h
def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
bits = np.where(x < 2, np.uint32(0x00200000) << np.uint32(x), np.uint32(x - 1) << np.uint32(23))
return bits.view(np.float32)
@classmethod
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
n_blocks = blocks.shape[0]
d = abs(blocks).max(axis=-1, keepdims=True)
with np.errstate(divide="ignore"):
e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8)
d = cls.e8m0_to_fp32_half(e)
kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16))
errs = np.abs(d.reshape((n_blocks, 1, 1)) * kvalues.astype(np.float32) - blocks.reshape((n_blocks, cls.block_size, 1)))
best = np.argmin(errs, axis=-1, keepdims=True)
qs = best.reshape(n_blocks, 2, cls.block_size // 2).astype(np.uint8)
qs = qs[:, 0] | (qs[:, 1] << np.uint8(4))
qs = qs.reshape((n_blocks, cls.block_size // 2))
return np.concatenate([e, qs], axis=-1)
@classmethod
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
n_blocks = blocks.shape[0]
e, qs = np.hsplit(blocks, [1])
d = cls.e8m0_to_fp32_half(e)
qs = qs.reshape((n_blocks, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
qs = (qs & np.uint8(0x0F)).view(np.int8)
kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
qs = np.take_along_axis(kvalues, qs, axis=-1).reshape((n_blocks, cls.block_size))
return (d * qs.astype(np.float32))
class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
ksigns: bytes = (
b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f"

View File

@@ -67,6 +67,7 @@ class GGMLQuants:
"q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
"q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
"tq1_0", "tq2_0",
"mxfp4",
"iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
"iq4_nl", "iq4_xs",
):
@@ -140,14 +141,21 @@ def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType)
return False
def do_test(libggml_path: Path, quick: bool = False):
def do_test(libggml_path: Path, quick: bool = False, user_type: GGMLQuantizationType | None = None):
ggml_quants = GGMLQuants(libggml_path)
np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})
r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)
# test zero blocks
r[0, 0, :] = 0
## Maybe test infinities? (can make NANs, not really useful in practice)
# r[0, 1, 0] = np.inf
# r[0, 2, 0] = -np.inf
# r[0, 3, 0] = np.inf
# r[0, 3, 1] = -np.inf
for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()):
for qtype in ((GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()) if user_type is None else (user_type,)):
has_dequantize = False
has_quantize = False
@@ -228,11 +236,12 @@ def do_test(libggml_path: Path, quick: bool = False):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so")
parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "bin" / "libggml.so", help="The path to libggml.so")
parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")
parser.add_argument("--type", type=str, help="The quant type to test (all by default)")
args = parser.parse_args()
logging.basicConfig(level=logging.DEBUG)
do_test(args.libggml, args.quick)
do_test(args.libggml, args.quick, GGMLQuantizationType[args.type.upper()] if args.type is not None else None)

View File

@@ -4,6 +4,7 @@ import argparse
import json
import os
import random
import sqlite3
import subprocess
from time import sleep, time
from typing import Optional, Union
@@ -47,6 +48,8 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
def get_server(path_server: str, path_log: Optional[str]) -> dict:
if path_server.startswith("http://") or path_server.startswith("https://"):
return {"process": None, "address": path_server, "fout": None}
if os.environ.get("LLAMA_ARG_HOST") is None:
logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
@@ -89,15 +92,13 @@ def get_prompt_length(data: dict) -> int:
f"{server_address}/apply-template",
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
)
if response.status_code != 200:
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
response.raise_for_status()
prompt: str = json.loads(response.text)["prompt"]
response = session.post(
f"{server_address}/tokenize",
json={"content": prompt, "add_special": True}
)
if response.status_code != 200:
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
response.raise_for_status()
tokens: list[str] = json.loads(response.text)["tokens"]
return len(tokens)
@@ -107,7 +108,12 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
server_address: str = data["server_address"]
t_submit = time()
if data["synthetic_prompt"]:
if data["external_server"]:
json_data: dict = {
"prompt": data["prompt"], "ignore_eos": True,
"seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
elif data["synthetic_prompt"]:
json_data: dict = {
"prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
"seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
@@ -117,34 +123,38 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
f"{server_address}/apply-template",
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
)
if response.status_code != 200:
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
response.raise_for_status()
prompt: str = json.loads(response.text)["prompt"]
json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
response = session.post(f"{server_address}/completion", json=json_data, stream=True)
response.raise_for_status()
lines = []
token_arrival_times: list[float] = []
for line in response.iter_lines(decode_unicode=False):
if not line.startswith(b"data: "):
continue
lines.append(line)
token_arrival_times.append(time())
token_arrival_times = token_arrival_times[:-1]
if response.status_code != 200:
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
token_arrival_times = token_arrival_times[:-1]
return (t_submit, token_arrival_times)
def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
def benchmark(
path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
n_predict: int, n_predict_min: int, seed_offset: int):
external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
if not external_server and os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
if not external_server and os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
@@ -165,7 +175,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
else:
n_predict_min = n_predict
if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
context_total: int = context_per_slot * parallel
os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
@@ -176,6 +186,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
try:
server = get_server(path_server, path_log)
server_address: str = server["address"]
assert external_server == (server["process"] is None)
adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore
session = requests.Session()
@@ -188,8 +199,9 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
if seed_offset >= 0:
random.seed(3 * (seed_offset + 1000 * i) + 1)
data.append({
"session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
"n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
"session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
"synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
"seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
if not synthetic_prompts:
logger.info("Getting the prompt lengths...")
@@ -199,7 +211,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
t0 = time()
results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
finally:
if server is not None:
if server is not None and server["process"] is not None:
server["process"].terminate()
server["process"].wait()
if session is not None:
@@ -233,15 +245,24 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
logger.info("")
logger.info(
"The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
"particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
if path_db is not None:
con = sqlite3.connect(path_db)
cursor = con.cursor()
cursor.execute(
"CREATE TABLE IF NOT EXISTS server_bench"
"(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
"n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
cursor.execute(
"INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
[name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])
con.commit()
plt.figure()
plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
plt.xlim(0, 1.05e0 * np.max(prompt_n))
plt.ylim(0, 1.05e3 * np.max(prompt_t))
plt.title(name or "")
plt.xlabel("Prompt length [tokens]")
plt.ylabel("Time to first token [ms]")
plt.savefig("prompt_time.png", dpi=240)
@@ -250,6 +271,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
plt.figure()
plt.hist(token_t, np.arange(0, bin_max))
plt.xlim(0, bin_max + 1)
plt.title(name or "")
plt.xlabel("Time [s]")
plt.ylabel("Num. tokens generated per second")
plt.savefig("gen_rate.png", dpi=240)
@@ -259,9 +281,13 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
"Results are printed to console and visualized as plots (saved to current working directory). "
"To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
"To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
"The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
"particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the model to use for the benchmark")
parser.add_argument("--path_db", type=str, default=None, help="Path to an sqlite database to store the benchmark results in")
parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
parser.add_argument(
"--prompt_source", type=str, default="rng-1024-2048",
help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "