Compare commits

...

18 Commits

Author SHA1 Message Date
Francis Couture-Harpin
64b7d85891 llama : fix command-r inference 2024-03-28 06:22:24 -04:00
Ting Sun
cfc4d75df6 doc: fix outdated default value of batch size (#6336)
* doc: fix outdated default value of batch size

* doc: add doc for ubatch-size
2024-03-28 09:51:06 +01:00
Eric Zhang
6902cb7f2e server : stop gracefully on SIGTERM (#6348) 2024-03-28 09:50:48 +01:00
hutli
d2d8f38996 nix: removed unnessesary indentation 2024-03-28 07:48:27 +00:00
hutli
d39b308eaf nix: moved blas availability check to package inputs so it is still overridable 2024-03-28 07:48:27 +00:00
hutli
c873976649 using blas.meta.available to check host platform 2024-03-28 07:48:27 +00:00
hutli
dbb03e2b9c only using explicit blas if hostPlatform is allowed 2024-03-28 07:48:27 +00:00
Someone Serge
e9f17dc3bf nix: .#windows: proper cross-compilation set-up
Take all dependencies from the cross stage, rather tha only stdenv
2024-03-28 07:48:27 +00:00
Someone Serge
22a462cc1f nix: package: don't introduce the dependency on python
- The generic /usr/bin/env shebangs are good enough
- Python deps are provisioned in the devShells
- We need to be able to leave python out at least on windows (currently breaks eval)
2024-03-28 07:48:27 +00:00
hutli
f6a0f5c642 nix: .#widnows: init
initial nix build for windows using zig

mingwW64 build

removes nix zig windows build

removes nix zig windows build

removed unnessesary glibc.static

removed unnessesary import of pkgs in nix

fixed missing trailing newline on non-windows nix builds

overriding stdenv when building for crosscompiling to windows in nix

better variables when crosscompiling windows in nix

cross compile windows on macos

removed trailing whitespace

remove unnessesary overwrite of "CMAKE_SYSTEM_NAME" in nix windows build

nix: keep file extension when copying result files during cross compile for windows

nix: better checking for file extensions when using MinGW

nix: using hostPlatform instead of targetPlatform when cross compiling for Windows

using hostPlatform.extensions.executable to extract executable format
2024-03-28 07:48:27 +00:00
Ziang Wu
d0e2f6416b doc: fix typo in MobileVLM-README.md (#6181) 2024-03-28 13:03:30 +09:00
Neo Zhang Jianyu
25f4a613c4 [SYCL] fix set main gpu crash (#6339) 2024-03-28 08:55:24 +08:00
Pierrick Hymbert
a016026a3a server: continuous performance monitoring and PR comment (#6283)
* server: bench: init

* server: bench: reduce list of GPU nodes

* server: bench: fix graph, fix output artifact

* ci: bench: add mermaid in case of image cannot be uploaded

* ci: bench: more resilient, more metrics

* ci: bench: trigger build

* ci: bench: fix duration

* ci: bench: fix typo

* ci: bench: fix mermaid values, markdown generated

* typo on the step name

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>

* ci: bench: trailing spaces

* ci: bench: move images in a details section

* ci: bench: reduce bullet point size

---------

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
2024-03-27 20:26:49 +01:00
Someone Serge
53c7ec53d5 nix: ci: dont test cuda and rocm (for now)
Until https://github.com/ggerganov/llama.cpp/issues/6346 is resolved
2024-03-27 19:18:55 +00:00
slaren
e5b89a441a ggml : fix bounds checking of zero size views (#6347) 2024-03-27 15:07:50 +01:00
Georgi Gerganov
3a0345970e make : whitespace 2024-03-27 15:02:49 +02:00
howlger
1e13987fba embedding : show full embedding for single prompt (#6342)
* embedding : show full embedding for single prompt

To support the use case of creating an embedding for a given prompt, the entire embedding and not just the first part needed to be printed.

Also, show cosine similarity matrix only if there is more than one prompt, as the cosine similarity matrix for a single prompt is always `1.00`.

* Update examples/embedding/embedding.cpp

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-03-27 13:15:44 +02:00
AidanBeltonS
e82f9e2b83 [SYCL] Fix batched impl for NVidia GPU (#6164)
* Fix batched impl

* Maintain previous behaviour for igpu

* retrigger CI

---------

Co-authored-by: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
2024-03-27 13:46:40 +05:30
15 changed files with 689 additions and 41 deletions

View File

@@ -24,7 +24,7 @@
useOpenCL
useRocm
useVulkan
],
] && blas.meta.available,
useCuda ? config.cudaSupport,
useMetalKit ? stdenv.isAarch64 && stdenv.isDarwin && !useOpenCL,
useMpi ? false, # Increases the runtime closure size by ~700M
@@ -67,10 +67,15 @@ let
strings.optionalString (suffices != [ ])
", accelerated with ${strings.concatStringsSep ", " suffices}";
executableSuffix = effectiveStdenv.hostPlatform.extensions.executable;
# TODO: package the Python in this repository in a Nix-like way.
# It'd be nice to migrate to buildPythonPackage, as well as ensure this repo
# is PEP 517-compatible, and ensure the correct .dist-info is generated.
# https://peps.python.org/pep-0517/
#
# TODO: Package up each Python script or service appropriately, by making
# them into "entrypoints"
llama-python = python3.withPackages (
ps: [
ps.numpy
@@ -159,11 +164,6 @@ effectiveStdenv.mkDerivation (
--replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";"
substituteInPlace ./ggml-metal.m \
--replace '[bundle pathForResource:@"default" ofType:@"metallib"];' "@\"$out/bin/default.metallib\";"
# TODO: Package up each Python script or service appropriately.
# If we were to migrate to buildPythonPackage and prepare the `pyproject.toml`,
# we could make those *.py into setuptools' entrypoints
substituteInPlace ./*.py --replace "/usr/bin/env python" "${llama-python}/bin/python"
'';
# With PR#6015 https://github.com/ggerganov/llama.cpp/pull/6015,
@@ -244,8 +244,8 @@ effectiveStdenv.mkDerivation (
# TODO(SomeoneSerge): It's better to add proper install targets at the CMake level,
# if they haven't been added yet.
postInstall = ''
mv $out/bin/main $out/bin/llama
mv $out/bin/server $out/bin/llama-server
mv $out/bin/main${executableSuffix} $out/bin/llama${executableSuffix}
mv $out/bin/server${executableSuffix} $out/bin/llama-server${executableSuffix}
mkdir -p $out/include
cp $src/llama.h $out/include/
'';

279
.github/workflows/bench.yml vendored Normal file
View File

@@ -0,0 +1,279 @@
# Benchmark
name: Benchmark
on:
workflow_dispatch:
inputs:
gpu-series:
description: 'Azure GPU series to run with'
required: true
type: choice
options:
- Standard_NC4as_T4_v3
- Standard_NC24ads_A100_v4
- Standard_NC80adis_H100_v5
sha:
description: 'Commit SHA1 to build'
required: false
type: string
duration:
description: 'Duration of the bench'
type: string
default: 10m
push:
branches:
- master
paths: ['.github/workflows/bench.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/bench/**.*']
pull_request:
types: [opened, synchronize, reopened]
paths: ['.github/workflows/bench.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/bench/**.*']
schedule:
- cron: '04 2 * * *'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
bench-server-baseline:
runs-on: Standard_NC4as_T4_v3
env:
RUNNER_LABEL: Standard_NC4as_T4_v3 # FIXME Do not find a way to not duplicate it
N_USERS: 8
DURATION: 10m
if: ${{ github.event.inputs.gpu-series == 'Standard_NC4as_T4_v3' || github.event.schedule || github.event.pull_request || github.event.push.ref == 'refs/heads/master' }}
steps:
- name: Clone
id: checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Install python env
id: pipenv
run: |
cd examples/server/bench
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
- name: Prometheus
id: install_prometheus
run: |
wget --quiet https://github.com/prometheus/prometheus/releases/download/v2.51.0/prometheus-2.51.0.linux-amd64.tar.gz
tar xzf prometheus*.tar.gz --strip-components=1
./prometheus --config.file=examples/server/bench/prometheus.yml &
while ! nc -z localhost 9090; do
sleep 0.1
done
- name: Install k6
id: k6_installation
run: |
cd examples/server/bench
wget --quiet https://github.com/grafana/k6/releases/download/v0.49.0/k6-v0.49.0-linux-amd64.tar.gz
tar xzf k6*.tar.gz --strip-components=1
- name: Build
id: cmake_build
run: |
set -eux
mkdir build
cd build
cmake .. \
-DLLAMA_NATIVE=OFF \
-DLLAMA_BUILD_SERVER=ON \
-DLLAMA_CURL=ON \
-DLLAMA_CUBLAS=ON \
-DCUDAToolkit_ROOT=/usr/local/cuda \
-DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc \
-DCMAKE_CUDA_ARCHITECTURES=75 \
-DLLAMA_FATAL_WARNINGS=OFF \
-DLLAMA_ALL_WARNINGS=OFF \
-DCMAKE_BUILD_TYPE=Release;
cmake --build . --config Release -j $(nproc) --target server
- name: Download the dataset
id: download_dataset
run: |
cd examples/server/bench
wget --quiet https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
- name: Server bench
id: server_bench
run: |
set -eux
cd examples/server/bench
source venv/bin/activate
BENCH_K6_BIN_PATH=./k6 python bench.py \
--runner-label ${{ env.RUNNER_LABEL }} \
--name ${{ github.job }} \
--branch ${{ github.head_ref || github.ref_name }} \
--commit ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha }} \
--scenario script.js \
--duration ${{ github.event.inputs.duration || env.DURATION }} \
--hf-repo ggml-org/models \
--hf-file phi-2/ggml-model-q4_0.gguf \
--model-path-prefix /models \
--parallel ${{ env.N_USERS }} \
-ngl 33 \
--batch-size 2048 \
--ubatch-size 256 \
--ctx-size 16384 \
--n-prompts 1000 \
--max-prompt-tokens 1024 \
--max-tokens 2048
cat results.github.env >> $GITHUB_ENV
# Remove dataset as we do not want it in the artefact
rm ShareGPT_V3_unfiltered_cleaned_split.json
- uses: actions/upload-artifact@v4
with:
name: benchmark-results
compression-level: 9
path: |
examples/server/bench/*.jpg
examples/server/bench/*.json
examples/server/bench/*.log
- name: Commit status
uses: Sibz/github-status-action@v1
with:
authToken: ${{secrets.GITHUB_TOKEN}}
sha: ${{ inputs.sha || github.event.pull_request.head.sha || github.sha }}
context: bench-server-baseline
description: |
${{ env.BENCH_RESULTS }}
state: 'success'
- name: Upload benchmark images
uses: devicons/public-upload-to-imgur@v2.2.2
continue-on-error: true # Important as it looks unstable: 503
id: imgur_step
with:
client_id: ${{secrets.IMGUR_CLIENT_ID}}
path: |
examples/server/bench/prompt_tokens_seconds.jpg
examples/server/bench/predicted_tokens_seconds.jpg
examples/server/bench/kv_cache_usage_ratio.jpg
examples/server/bench/requests_processing.jpg
- name: Extract mermaid
id: set_mermaid
run: |
set -eux
cd examples/server/bench
PROMPT_TOKENS_SECONDS=$(cat prompt_tokens_seconds.mermaid)
echo "PROMPT_TOKENS_SECONDS<<EOF" >> $GITHUB_ENV
echo "$PROMPT_TOKENS_SECONDS" >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV
PREDICTED_TOKENS_SECONDS=$(cat predicted_tokens_seconds.mermaid)
echo "PREDICTED_TOKENS_SECONDS<<EOF" >> $GITHUB_ENV
echo "$PREDICTED_TOKENS_SECONDS" >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV
KV_CACHE_USAGE_RATIO=$(cat kv_cache_usage_ratio.mermaid)
echo "KV_CACHE_USAGE_RATIO<<EOF" >> $GITHUB_ENV
echo "$KV_CACHE_USAGE_RATIO" >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV
REQUESTS_PROCESSING=$(cat requests_processing.mermaid)
echo "REQUESTS_PROCESSING<<EOF" >> $GITHUB_ENV
echo "$REQUESTS_PROCESSING" >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV
- name: Extract image url
id: extract_image_url
continue-on-error: true
run: |
set -eux
echo "IMAGE_O=${{ fromJSON(steps.imgur_step.outputs.imgur_urls)[0] }}" >> $GITHUB_ENV
echo "IMAGE_1=${{ fromJSON(steps.imgur_step.outputs.imgur_urls)[1] }}" >> $GITHUB_ENV
echo "IMAGE_2=${{ fromJSON(steps.imgur_step.outputs.imgur_urls)[2] }}" >> $GITHUB_ENV
echo "IMAGE_3=${{ fromJSON(steps.imgur_step.outputs.imgur_urls)[3] }}" >> $GITHUB_ENV
- name: Comment PR
uses: mshick/add-pr-comment@v2
id: comment_pr
if: ${{ github.event.pull_request != '' }}
with:
message-id: bench-${{ github.job }}-${{ env.RUNNER_LABEL }}
message: |
📈 **llama.cpp server** for _${{ github.job }}_ on _${{ env.RUNNER_LABEL }}_: **${{ env.BENCH_ITERATIONS}} iterations** 🚀
- Concurrent users: ${{ env.N_USERS }}, duration: ${{ github.event.inputs.duration || env.DURATION }}
- HTTP request : avg=${{ env.HTTP_REQ_DURATION_AVG }}ms p(90)=${{ env.HTTP_REQ_DURATION_P_90_ }}ms fails=${{ env.HTTP_REQ_FAILED_PASSES }}, finish reason: stop=${{ env.LLAMACPP_COMPLETIONS_STOP_RATE_PASSES }} truncated=${{ env.LLAMACPP_COMPLETIONS_TRUNCATED_RATE_PASSES }}
- Prompt processing (pp): avg=${{ env.LLAMACPP_PROMPT_TOKENS_AVG }}tk/s p(90)=${{ env.LLAMACPP_PROMPT_TOKENS_P_90_ }}tk/s **total=${{ env.LLAMACPP_PROMPT_TOKENS_TOTAL_COUNTER_RATE }}tk/s**
- Token generation (tg): avg=${{ env.LLAMACPP_TOKENS_SECOND_AVG }}tk/s p(90)=${{ env.LLAMACPP_TOKENS_SECOND_P_90_ }}tk/s **total=${{ env.LLAMACPP_COMPLETION_TOKENS_TOTAL_COUNTER_RATE }}tk/s**
- ${{ env.BENCH_GRAPH_XLABEL }}
<details>
<summary>Time series</summary>
<p align="center">
<img width="100%" height="100%" src="${{ env.IMAGE_O }}" alt="prompt_tokens_seconds" />
<details>
<summary>More</summary>
```mermaid
${{ env.PROMPT_TOKENS_SECONDS }}
```
</details>
<img width="100%" height="100%" src="${{ env.IMAGE_1 }}" alt="predicted_tokens_seconds"/>
<details>
<summary>More</summary>
```mermaid
${{ env.PREDICTED_TOKENS_SECONDS }}
```
</details>
</p>
<details>
<summary>Details</summary>
<p align="center">
<img width="100%" height="100%" src="${{ env.IMAGE_2 }}" alt="kv_cache_usage_ratio" />
<details>
<summary>More</summary>
```mermaid
${{ env.KV_CACHE_USAGE_RATIO }}
```
</details>
<img width="100%" height="100%" src="${{ env.IMAGE_3 }}" alt="requests_processing"/>
<details>
<summary>More</summary>
```mermaid
${{ env.REQUESTS_PROCESSING }}
```
</details>
</p>
</details>
</details>

View File

@@ -556,6 +556,7 @@ ifdef LLAMA_CUDA_NO_PEER_COPY
endif # LLAMA_CUDA_NO_PEER_COPY
OBJS += ggml-cuda.o
OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<

View File

@@ -178,25 +178,27 @@ int main(int argc, char ** argv) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
// print the first part of the embeddings
// print the first part of the embeddings or for a single prompt, the full embedding
fprintf(stdout, "\n");
for (int j = 0; j < n_prompts; j++) {
fprintf(stdout, "embedding %d: ", j);
for (int i = 0; i < std::min(16, n_embd); i++) {
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
}
fprintf(stdout, "\n");
}
// print cosine similarity matrix
fprintf(stdout, "\n");
printf("cosine similarity matrix:\n\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f ", sim);
}
if (n_prompts > 1) {
fprintf(stdout, "\n");
printf("cosine similarity matrix:\n\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f ", sim);
}
fprintf(stdout, "\n");
}
}
// clean up

View File

@@ -6,7 +6,7 @@ for more information, please go to [Meituan-AutoML/MobileVLM](https://github.com
The implementation is based on llava, and is compatible with llava and mobileVLM. The usage is basically same as llava.
Notice: The overall process of model inference for both **MobileVLM** and **MobileVLM_V2** models is the same, but the process of model conversion is a little different. Therefore, using MobiVLM as an example, the different conversion step will be shown.
Notice: The overall process of model inference for both **MobileVLM** and **MobileVLM_V2** models is the same, but the process of model conversion is a little different. Therefore, using MobileVLM as an example, the different conversion step will be shown.
## Usage
Build with cmake or run `make llava-cli` to build it.

View File

@@ -296,7 +296,9 @@ These options help improve the performance and memory usage of the LLaMA models.
### Batch Size
- `-b N, --batch-size N`: Set the batch size for prompt processing (default: 512). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations.
- `-b N, --batch-size N`: Set the batch size for prompt processing (default: `2048`). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations.
- `-ub N`, `--ubatch-size N`: physical maximum batch size. This is for pipeline parallelization. Default: `512`.
### Prompt Caching

View File

@@ -0,0 +1,303 @@
import argparse
import json
import os
import re
import signal
import socket
import subprocess
import sys
import threading
import time
import traceback
from contextlib import closing
from datetime import datetime
import matplotlib
import matplotlib.dates
import matplotlib.pyplot as plt
import requests
def main(args_in: list[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="Start server benchmark scenario")
parser.add_argument("--name", type=str, help="Bench name", required=True)
parser.add_argument("--runner-label", type=str, help="Runner label", required=True)
parser.add_argument("--branch", type=str, help="Branch name", default="detached")
parser.add_argument("--commit", type=str, help="Commit name", default="dirty")
parser.add_argument("--host", type=str, help="Server listen host", default="0.0.0.0")
parser.add_argument("--port", type=int, help="Server listen host", default="8080")
parser.add_argument("--model-path-prefix", type=str, help="Prefix where to store the model files", default="models")
parser.add_argument("--n-prompts", type=int,
help="SERVER_BENCH_N_PROMPTS: total prompts to randomly select in the benchmark", required=True)
parser.add_argument("--max-prompt-tokens", type=int,
help="SERVER_BENCH_MAX_PROMPT_TOKENS: maximum prompt tokens to filter out in the dataset",
required=True)
parser.add_argument("--max-tokens", type=int,
help="SERVER_BENCH_MAX_CONTEXT: maximum context size of the completions request to filter out in the dataset: prompt + predicted tokens",
required=True)
parser.add_argument("--hf-repo", type=str, help="Hugging Face model repository", required=True)
parser.add_argument("--hf-file", type=str, help="Hugging Face model file", required=True)
parser.add_argument("-ngl", "--n-gpu-layers", type=int, help="layers to the GPU for computation", required=True)
parser.add_argument("--ctx-size", type=int, help="Set the size of the prompt context", required=True)
parser.add_argument("--parallel", type=int, help="Set the number of slots for process requests", required=True)
parser.add_argument("--batch-size", type=int, help="Set the batch size for prompt processing", required=True)
parser.add_argument("--ubatch-size", type=int, help="physical maximum batch size", required=True)
parser.add_argument("--scenario", type=str, help="Scenario to run", required=True)
parser.add_argument("--duration", type=str, help="Bench scenario", required=True)
args = parser.parse_args(args_in)
start_time = time.time()
# Start the server and performance scenario
try:
server_process = start_server(args)
except Exception:
print("bench: server start error :")
traceback.print_exc(file=sys.stdout)
sys.exit(1)
# start the benchmark
try:
start_benchmark(args)
iterations = 0
with open("results.github.env", 'w') as github_env:
# parse output
with open('k6-results.json', 'r') as bench_results:
# Load JSON data from file
data = json.load(bench_results)
for metric_name in data['metrics']:
for metric_metric in data['metrics'][metric_name]:
value = data['metrics'][metric_name][metric_metric]
if isinstance(value, float) or isinstance(value, int):
value = round(value, 2)
data['metrics'][metric_name][metric_metric]=value
github_env.write(
f"{escape_metric_name(metric_name)}_{escape_metric_name(metric_metric)}={value}\n")
token_seconds = data['metrics']['llamacpp_tokens_second']['avg']
iterations = data['root_group']['checks']['success completion']['passes']
except Exception:
print("bench: error :")
traceback.print_exc(file=sys.stdout)
# Stop the server
if server_process:
try:
print(f"bench: shutting down server pid={server_process.pid} ...")
if os.name == 'nt':
interrupt = signal.CTRL_C_EVENT
else:
interrupt = signal.SIGINT
server_process.send_signal(interrupt)
server_process.wait(0.5)
except subprocess.TimeoutExpired:
print(f"server still alive after 500ms, force-killing pid={server_process.pid} ...")
server_process.kill() # SIGKILL
server_process.wait()
while is_server_listening(args.host, args.port):
time.sleep(0.1)
title = (f"llama.cpp {args.name} on {args.runner_label}\n "
f"duration={args.duration} {iterations} iterations")
xlabel = (f"{args.hf_repo}/{args.hf_file}\n"
f"parallel={args.parallel} ctx-size={args.ctx_size} ngl={args.n_gpu_layers} batch-size={args.batch_size} ubatch-size={args.ubatch_size} pp={args.max_prompt_tokens} pp+tg={args.max_tokens}\n"
f"branch={args.branch} commit={args.commit}")
# Prometheus
end_time = time.time()
if is_server_listening("0.0.0.0", 9090):
metrics = ['prompt_tokens_seconds', 'predicted_tokens_seconds',
'kv_cache_usage_ratio', 'requests_processing', 'requests_deferred']
for metric in metrics:
resp = requests.get(f"http://localhost:9090/api/v1/query_range",
params={'query': 'llamacpp:' + metric, 'start': start_time, 'end': end_time, 'step': 2})
with open(f"{metric}.json", 'w') as metric_json:
metric_json.write(resp.text)
if resp.status_code != 200:
print(f"bench: unable to extract prometheus metric {metric}: {resp.text}")
else:
metric_data = resp.json()
values = metric_data['data']['result'][0]['values']
timestamps, metric_values = zip(*values)
metric_values = [float(value) for value in metric_values]
timestamps_dt = [datetime.fromtimestamp(int(ts)) for ts in timestamps]
plt.figure(figsize=(16, 10), dpi=80)
plt.plot(timestamps_dt, metric_values, label=metric)
plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7)
plt.yticks(fontsize=12, alpha=.7)
ylabel = f"llamacpp:{metric}"
plt.title(title,
fontsize=14, wrap=True)
plt.grid(axis='both', alpha=.3)
plt.ylabel(ylabel, fontsize=22)
plt.xlabel(xlabel, fontsize=14, wrap=True)
plt.gca().xaxis.set_major_locator(matplotlib.dates.MinuteLocator())
plt.gca().xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d %H:%M:%S"))
plt.gcf().autofmt_xdate()
# Remove borders
plt.gca().spines["top"].set_alpha(0.0)
plt.gca().spines["bottom"].set_alpha(0.3)
plt.gca().spines["right"].set_alpha(0.0)
plt.gca().spines["left"].set_alpha(0.3)
# Save the plot as a jpg image
plt.savefig(f'{metric}.jpg', dpi=60)
plt.close()
# Mermaid format in case images upload failed
with (open(f"{metric}.mermaid", 'w') as mermaid_f):
mermaid = (
f"""---
config:
xyChart:
titleFontSize: 12
width: 900
height: 600
themeVariables:
xyChart:
titleColor: "#000000"
---
xychart-beta
title "{title}"
y-axis "llamacpp:{metric}"
x-axis "llamacpp:{metric}" {int(min(timestamps))} --> {int(max(timestamps))}
line [{', '.join([str(round(float(value), 2)) for value in metric_values])}]
""")
mermaid_f.write(mermaid)
# 140 chars max for commit status description
bench_results = {
"req": {
"p90": data['metrics']["http_req_duration"]["p(90)"],
"avg": data['metrics']["http_req_duration"]["avg"],
},
"pp": {
"p90": data['metrics']["llamacpp_prompt_tokens"]["p(90)"],
"avg": data['metrics']["llamacpp_prompt_tokens"]["avg"],
},
"tg": {
"p90": data['metrics']["llamacpp_tokens_second"]["p(90)"],
"avg": data['metrics']["llamacpp_tokens_second"]["avg"],
},
}
with open("results.github.env", 'a') as github_env:
github_env.write(f"BENCH_RESULTS={json.dumps(bench_results, indent=None, separators=(',', ':') )}\n")
github_env.write(f"BENCH_ITERATIONS={iterations}\n")
title = title.replace('\n', ' ')
xlabel = xlabel.replace('\n', ' ')
github_env.write(f"BENCH_GRAPH_TITLE={title}\n")
github_env.write(f"BENCH_GRAPH_XLABEL={xlabel}\n")
def start_benchmark(args):
k6_path = 'k6'
if 'BENCH_K6_BIN_PATH' in os.environ:
k6_path = os.environ['BENCH_K6_BIN_PATH']
k6_args = [
'run', args.scenario,
'--no-color',
]
k6_args.extend(['--duration', args.duration])
k6_args.extend(['--iterations', args.n_prompts])
k6_args.extend(['--vus', args.parallel])
k6_args.extend(['--summary-export', 'k6-results.json'])
args = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} "
args = args + ' '.join([str(arg) for arg in [k6_path, *k6_args]])
print(f"bench: starting k6 with: {args}")
k6_completed = subprocess.run(args, shell=True, stdout=sys.stdout, stderr=sys.stderr)
if k6_completed.returncode != 0:
raise Exception("bench: unable to run k6")
def start_server(args):
server_process = start_server_background(args)
attempts = 0
max_attempts = 20
if 'GITHUB_ACTIONS' in os.environ:
max_attempts *= 2
while not is_server_listening(args.host, args.port):
attempts += 1
if attempts > max_attempts:
assert False, "server not started"
print(f"bench: waiting for server to start ...")
time.sleep(0.5)
print("bench: server started.")
return server_process
def start_server_background(args):
# Start the server
server_path = '../../../build/bin/server'
if 'LLAMA_SERVER_BIN_PATH' in os.environ:
server_path = os.environ['LLAMA_SERVER_BIN_PATH']
server_args = [
'--host', args.host,
'--port', args.port,
]
model_file = args.model_path_prefix + os.path.sep + args.hf_file
model_dir = os.path.dirname(model_file)
if not os.path.exists(model_dir):
os.makedirs(model_dir)
server_args.extend(['--model', model_file])
server_args.extend(['--hf-repo', args.hf_repo])
server_args.extend(['--hf-file', args.hf_file])
server_args.extend(['--n-gpu-layers', args.n_gpu_layers])
server_args.extend(['--ctx-size', args.ctx_size])
server_args.extend(['--parallel', args.parallel])
server_args.extend(['--batch-size', args.batch_size])
server_args.extend(['--ubatch-size', args.ubatch_size])
server_args.extend(['--n-predict', args.max_tokens * 2])
server_args.extend(['--defrag-thold', "0.1"])
server_args.append('--cont-batching')
server_args.append('--metrics')
server_args.extend(['--log-format', "text"])
args = [str(arg) for arg in [server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
pkwargs = {
'stdout': subprocess.PIPE,
'stderr': subprocess.PIPE
}
server_process = subprocess.Popen(
args,
**pkwargs)
def server_log(in_stream, out_stream):
for line in iter(in_stream.readline, b''):
print(line.decode('utf-8'), end='', file=out_stream)
thread_stdout = threading.Thread(target=server_log, args=(server_process.stdout, sys.stdout))
thread_stdout.start()
thread_stderr = threading.Thread(target=server_log, args=(server_process.stderr, sys.stderr))
thread_stderr.start()
return server_process
def is_server_listening(server_fqdn, server_port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
result = sock.connect_ex((server_fqdn, server_port))
_is_server_listening = result == 0
if _is_server_listening:
print(f"server is listening on {server_fqdn}:{server_port}...")
return _is_server_listening
def escape_metric_name(metric_name):
return re.sub('[^A-Z0-9]', '_', metric_name.upper())
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,9 @@
global:
scrape_interval: 10s
external_labels:
llamacpp: 'server'
scrape_configs:
- job_name: 'llama.cpp server'
static_configs:
- targets: ['localhost:8080']

View File

@@ -0,0 +1,2 @@
matplotlib
requests

View File

@@ -3566,6 +3566,7 @@ int main(int argc, char ** argv) {
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;

View File

@@ -1114,7 +1114,10 @@ def start_server_background(context):
server_args.append('--verbose')
if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
server_args.extend(['--log-format', "text"])
print(f"starting server with: {context.server_path} {server_args}")
args = [str(arg) for arg in [context.server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
flags = 0
if 'nt' == os.name:
flags |= subprocess.DETACHED_PROCESS
@@ -1130,16 +1133,14 @@ def start_server_background(context):
[str(arg) for arg in [context.server_path, *server_args]],
**pkwargs)
def log_stdout(process):
for line in iter(process.stdout.readline, b''):
print(line.decode('utf-8'), end='')
thread_stdout = threading.Thread(target=log_stdout, args=(context.server_process,))
def server_log(in_stream, out_stream):
for line in iter(in_stream.readline, b''):
print(line.decode('utf-8'), end='', file=out_stream)
thread_stdout = threading.Thread(target=server_log, args=(context.server_process.stdout, sys.stdout))
thread_stdout.start()
def log_stderr(process):
for line in iter(process.stderr.readline, b''):
print(line.decode('utf-8'), end='', file=sys.stderr)
thread_stderr = threading.Thread(target=log_stderr, args=(context.server_process,))
thread_stderr = threading.Thread(target=server_log, args=(context.server_process.stderr, sys.stderr))
thread_stderr.start()
print(f"server pid={context.server_process.pid}, behave pid={os.getpid()}")

View File

@@ -145,6 +145,7 @@
# the same path you would with an overlay.
legacyPackages = {
llamaPackages = pkgs.callPackage .devops/nix/scope.nix { inherit llamaVersion; };
llamaPackagesWindows = pkgs.pkgsCross.mingwW64.callPackage .devops/nix/scope.nix { inherit llamaVersion; };
llamaPackagesCuda = pkgsCuda.callPackage .devops/nix/scope.nix { inherit llamaVersion; };
llamaPackagesRocm = pkgsRocm.callPackage .devops/nix/scope.nix { inherit llamaVersion; };
};
@@ -155,6 +156,7 @@
{
default = config.legacyPackages.llamaPackages.llama-cpp;
vulkan = config.packages.default.override { useVulkan = true; };
windows = config.legacyPackages.llamaPackagesWindows.llama-cpp;
}
// lib.optionalAttrs pkgs.stdenv.isLinux {
opencl = config.packages.default.override { useOpenCL = true; };
@@ -168,9 +170,14 @@
};
# Packages exposed in `.#checks` will be built by the CI and by
# `nix flake check`. Currently we expose all packages, but we could
# make more granular choices
checks = config.packages;
# `nix flake check`.
#
# We could test all outputs e.g. as `checks = confg.packages`.
#
# TODO: Build more once https://github.com/ggerganov/llama.cpp/issues/6346 has been addressed
checks = {
inherit (config.packages) default vulkan;
};
};
};
}

View File

@@ -2968,7 +2968,7 @@ namespace dpct
#include "ggml-common.h"
static int g_ggml_sycl_debug=0;
#define GGML_SYCL_DEBUG(...) do{if(g_ggml_sycl_debug) printf(__VA_ARGS__);}while(0)
#define GGML_SYCL_DEBUG(...) do{if(g_ggml_sycl_debug) fprintf(stderr, __VA_ARGS__);}while(0)
#define CHECK_TRY_ERROR(expr) \
[&]() { \
@@ -12868,6 +12868,7 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
}
void ggml_backend_sycl_print_sycl_devices() {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
int device_count = dpct::dev_mgr::instance().device_count();
std::map<std::string, size_t> DeviceNums;
fprintf(stderr, "found %d SYCL devices:\n", device_count);
@@ -12925,7 +12926,9 @@ static void ggml_init_sycl() try {
static bool initialized = false;
if (!initialized) {
fprintf(stderr, "[SYCL] call ggml_init_sycl\n");
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
fprintf(stderr, "%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug);
#if defined(GGML_SYCL_F16)
@@ -14986,6 +14989,9 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
SYCL_CHECK(ggml_sycl_set_device(g_main_device));
dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda ||
main_stream->get_backend() == sycl::backend::ext_oneapi_hip;
SYCL_CHECK(
CHECK_TRY_ERROR(g_sycl_handles[g_main_device] = main_stream));
@@ -15016,24 +15022,38 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
if (no_mixed_dtypes) {
cu_compute_type = dpct::library_data_t::real_half;
cu_data_type = dpct::library_data_t::real_half;
}
// dst strides
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];
const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;
const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f;
const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;
const void * alpha = &alpha_f32;
const void * beta = &beta_f32;
if (no_mixed_dtypes) {
alpha = &alpha_f16;
beta = &beta_f16;
}
// TODO: Renable (dst->op_params[0] =! GGML_PREC_DEFAULT) pathway
// oneMKL open source supports half, half, float, float: datatypes
// when oneMKL open source supports half, half, float, float: datatypes
dst_t = (char *) dst_ddf;
if (no_mixed_dtypes) {
dst_t = (char *) dst_f16.alloc(ne_dst);
nbd2 /= sizeof(float) / sizeof(sycl::half);
nbd3 /= sizeof(float) / sizeof(sycl::half);
}
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
@@ -15119,6 +15139,10 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
}
#endif
if (no_mixed_dtypes) {
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -16018,6 +16042,7 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_
}
GGML_API GGML_CALL void ggml_sycl_get_gpu_list(int *id_list, int max_len) try {
GGML_SYCL_DEBUG("[SYCL] call ggml_sycl_get_gpu_list\n");
for(int i=0;i<max_len;i++) id_list[i] = -1;
if (!g_sycl_gpu_mgr) {
@@ -16052,6 +16077,7 @@ catch (sycl::exception const &exc) {
GGML_API GGML_CALL void ggml_sycl_get_device_description(int device, char *description,
size_t description_size) try {
GGML_SYCL_DEBUG("[SYCL] call ggml_sycl_get_device_description\n");
dpct::device_info prop;
int device_id = g_sycl_gpu_mgr->gpus[device];
SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
@@ -16066,6 +16092,7 @@ catch (sycl::exception const &exc) {
GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free,
size_t *total) try {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
ggml_sycl_set_device(device);
/*
@@ -16417,7 +16444,8 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
};
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device_index) {
ggml_init_sycl();
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
if (device_index>=g_device_count or device_index<0) {
printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
device_index, g_device_count-1);
@@ -16787,6 +16815,7 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface
};
GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
ggml_init_sycl();
// FIXME: this is not thread safe
static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
@@ -16859,6 +16888,7 @@ static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggm
}
ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n");
static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = {
/* .iface = */ {
/* .get_name = */ ggml_backend_sycl_host_buffer_type_name,
@@ -17155,6 +17185,7 @@ static ggml_guid_t ggml_backend_sycl_guid() {
}
GGML_CALL ggml_backend_t ggml_backend_sycl_init(int device) {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_init\n");
ggml_init_sycl();
check_allow_gpu_index(device);
@@ -17181,6 +17212,7 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) {
}
GGML_CALL int ggml_backend_sycl_get_device_count() {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
if (!g_sycl_gpu_mgr) g_sycl_gpu_mgr = new sycl_gpu_mgr();
return g_sycl_gpu_mgr->get_gpu_count();
}
@@ -17193,16 +17225,21 @@ GGML_CALL static ggml_backend_t ggml_backend_reg_sycl_init(const char * params,
}
GGML_API GGML_CALL int ggml_backend_sycl_get_device_index(int device_id) {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_index\n");
return g_sycl_gpu_mgr->get_index(device_id);
}
GGML_API GGML_CALL int ggml_backend_sycl_get_device_id(int device_index) {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_id\n");
return g_sycl_gpu_mgr->gpus[device_index];
}
GGML_API GGML_CALL void ggml_backend_sycl_set_single_device_mode(int main_gpu_id) {
GGML_ASSERT(main_gpu_id<g_all_sycl_device_count);
ggml_init_sycl();
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_set_single_device_mode\n");
fprintf(stderr, "ggml_backend_sycl_set_single_device: use single device: [%d]\n", main_gpu_id);
GGML_ASSERT(main_gpu_id<g_all_sycl_device_count);
if (g_sycl_gpu_mgr) {
delete g_sycl_gpu_mgr;
}
@@ -17213,6 +17250,9 @@ GGML_API GGML_CALL void ggml_backend_sycl_set_single_device_mode(int main_gpu_id
}
GGML_API GGML_CALL void ggml_backend_sycl_set_mul_device_mode() {
ggml_init_sycl();
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_set_mul_device_mode\n");
if (g_ggml_sycl_backend_gpu_mode == SYCL_MUL_GPU_MODE) {
return;
}

2
ggml.c
View File

@@ -2938,7 +2938,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
data_size *= ne[i];
}
GGML_ASSERT(view_src == NULL || data_size + view_offs <= ggml_nbytes(view_src));
GGML_ASSERT(view_src == NULL || data_size == 0 || data_size + view_offs <= ggml_nbytes(view_src));
void * data = view_src != NULL ? view_src->data : NULL;
if (data != NULL) {

View File

@@ -9152,8 +9152,9 @@ struct llm_build_context {
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
}
struct ggml_tensor * attn_out = cur;