Compare commits

...

46 Commits
b5432 ... b5478

Author SHA1 Message Date
Olivier Chafik
f5cd27b71d server: streaming of tool calls and thoughts when --jinja is on (#12379)
* add common_json w/ support for truncated json healing

* add common_chat_msg_diff

* partial common_chat_parse

* refactor parser w/ optionals

* server: wire chat diffs in stream mode

* fix trigger of thinking models (must happen after thoughts are closed)

* fix functionary v3.2 raw python!

* rename: common_chat_syntax (now contains format)

* rm common_regex.at_start

* don't return empty <think></think>

* accommodate yet another deepseek r1 distill fantasy syntax (`<|tool▁calls|>`)

* fix QwQ 32B tool call parsing after thoughts (hermes2)

* better logs for grammar triggers

* consume spaces after parse_json_tool_calls

* fix required tool calls w/ thinking models that have pre-opened thinking tags

* fix thinking model's initial trigger + test qwq's template

* run most test_tool_call tests in stream + non-stream modes

* make functionary v3.2 parsing more strict (differentiate first match from others)

* send final diff from server, to close off raw python arguments

* support partial content streaming in Generic mode

* tool-call: allow content prelude before hermes2 tool calls (for Qwen2.5)

* Update function-calling.md

* Update tool_bench.py

* chat-parser: remove input from exception (llm output may contain PII)

---------

Co-authored-by: ochafik <ochafik@google.com>
Co-authored-by: Olivier Chafik <ochafik@users.noreply.github.com>
2025-05-25 01:48:08 +01:00
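For context — an illustrative OpenAI-compatible stream fragment (values invented, unrelated fields omitted) showing how tool-call arguments now arrive incrementally instead of only in the final message:

data: {"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}
data: {"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"location\":"}}]},"finish_reason":null}]}
data: {"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"Paris\"}"}}]},"finish_reason":null}]}
data: [DONE]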
Diego Devesa
a2d02d5793 releases : bundle llvm omp library in windows release (#13763) 2025-05-25 00:55:16 +02:00
Diego Devesa
17fc817b58 releases : enable openmp in windows cpu backend build (#13756) 2025-05-24 22:27:03 +02:00
Diego Devesa
2bd1b30f69 ggml-cpu : set openmp wait time if not set (#13758) 2025-05-24 22:26:47 +02:00
0cc4m
259469c4b5 Move GLM4 f32 attention fix to the correct function (#13750) 2025-05-24 16:49:12 +02:00
Xuan-Son Nguyen
4c32832c59 ggml : add ggml_gelu_erf() CUDA kernel (#13719)
* ggml : add ggml_gelu_erf() CUDA kernel

* missing semicolon
2025-05-24 13:06:47 +02:00
Sigbjørn Skjæret
c3a2624339 vocab : fix ugm tokenizer precision (#13743) 2025-05-24 12:29:09 +02:00
Johannes Gäßler
ffd0eae60b CUDA: fix race condition in FA vector kernels (#13742) 2025-05-24 11:46:19 +02:00
Diego Devesa
b775345d78 ci : enable winget package updates (#13734) 2025-05-23 23:14:00 +03:00
Diego Devesa
a70a8a69c2 ci : add winget package updater (#13732) 2025-05-23 22:09:38 +02:00
Georgi Gerganov
d13d0f6135 hparams : initialize arrays (#13728)
ggml-ci
2025-05-23 20:16:13 +03:00
Xuan-Son Nguyen
8a2afb7520 llama : allow custom list of swa_layers (#13726) 2025-05-23 17:07:04 +02:00
Xuan-Son Nguyen
9ecf3e66a3 server : support audio input (#13714)
* server : support audio input

* add audio support on webui
2025-05-23 11:03:47 +02:00
Chenguang Li
faaaff5f94 CANN: Support MUL_MAT_ID for q8_0 and q4_0 (#13705)
* [CANN] Support MUL_MAT_ID Q8 && Q4

Signed-off-by: noemotiovon <757486878@qq.com>

* codestyle adjustment

Signed-off-by: noemotiovon <757486878@qq.com>

---------

Signed-off-by: noemotiovon <757486878@qq.com>
2025-05-23 16:47:53 +08:00
Xuan-Son Nguyen
e16c4731c7 ggml : fix the order of ggml_unary_op (#13718) 2025-05-23 08:12:48 +02:00
Jeff Bolz
1dcd01960c vulkan: support CPY from any type to itself (#13695)
Reuse the f16/f32 copy shaders, and just scale the number of elements
according to the type size.
2025-05-23 06:45:02 +02:00
Jeff Bolz
c10ed6cbcc vulkan: Disable coopmat/coopmat2/bfloat extensions if glslc doesn't support it (#13696) 2025-05-23 06:33:45 +02:00
Judd
a127ff1780 use LOG_WARN to replace std::cerr (#13657) 2025-05-23 06:33:08 +02:00
Diego Devesa
3079e9ac8e release : fix windows hip release (#13707)
* release : fix windows hip release

* make single hip release with multiple targets
2025-05-23 00:21:37 +02:00
Georgi Gerganov
8a1d206f1d tts : fix n_ubatch + make WavTokenizer cache-less (#13713)
ggml-ci
2025-05-22 22:21:07 +03:00
Xuan-Son Nguyen
797990c4bc mtmd : add ultravox audio input (#13623)
* convert ok, load ok

* warmup ok

* test

* still does not work?

* fix padding

* temporary give up

* fix merge conflict

* build_ultravox()

* rm test

* fix merge conflict

* add necessary mtmd APIs

* first working version (only 4s of audio)

* will this monster compile?

* fix compile

* please compile

* fPIC

* fix windows

* various fixes

* clean up audio_helpers

* fix conversion

* add some debug stuff

* long audio input ok

* adapt the api

* add --audio arg

* final touch UX

* add miniaudio to readme

* fix typo

* refactor kv metadata

* mtmd_default_marker()
2025-05-22 20:42:48 +02:00
Aaron Teo
ab86335760 common: Include torch package for s390x (#13699)
* common: update requirements.txt to include pytorch nightly for s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* common: fix torch installation via pip for s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

---------

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
2025-05-22 21:31:29 +03:00
Georgi Gerganov
cc74d5be99 server : pad small embedding batches (#13692)
ggml-ci
2025-05-22 16:33:39 +03:00
Sigbjørn Skjæret
5be24af73d gguf-py : correct charsmap parameter typing (#13701) 2025-05-22 14:25:05 +02:00
Nicolò Scipione
d394a9aedc sycl : Remove waits from function calls (#13702)
* removes the waits in async memcpy functions
2025-05-22 12:54:43 +01:00
Ewan Crawford
6b56a64690 SYCL: Avoid using with SYCL-Graph for unsupported nodes (#13587)
Currently, when running `GGML_SYCL_DISABLE_GRAPH=0 ./bin/test-backend-ops -b SYCL0` on the SYCL
CUDA backend, two operations throw an exception from the blocking waits issued while the queue
is being recorded.

* `-o CONCAT` : Use of blocking waits on a queue that's being recorded https://github.com/ggml-org/llama.cpp/blob/master/ggml/src/ggml-sycl/concat.cpp#L185-L187
* `-o MUL_MAT_ID`: Blocking wait on a recording queue for a copy to host memory https://github.com/ggml-org/llama.cpp/blob/master/ggml/src/ggml-sycl/ggml-sycl.cpp#L3072-L3074

We've noticed that `ggml-cuda.cu` has the
[check_node_graph_compatibility_and_refresh_copy_ops](39e73ae0d6/ggml/src/ggml-cuda/ggml-cuda.cu (L2458-L2458))
method for checking whether a graph can be used even when graphs are enabled. This PR takes a
similar approach, adding a method to `ggml-sycl.cpp` that checks whether a graph can be used for
the recorded operations even if the user has asked for graphs to be enabled.
2025-05-22 16:24:09 +08:00
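A minimal sketch of what such a compatibility check might look like (function name and exact op list are illustrative assumptions, not the code added by the PR):

static bool ggml_sycl_graph_compatible(const struct ggml_cgraph * cgraph) {
    // Walk the graph once and refuse SYCL-Graph capture for ops whose current
    // implementations issue blocking waits while the queue is being recorded.
    for (int i = 0; i < cgraph->n_nodes; ++i) {
        switch (cgraph->nodes[i]->op) {
            case GGML_OP_CONCAT:     // blocking waits in concat.cpp
            case GGML_OP_MUL_MAT_ID: // blocking copy to host memory
                return false;        // fall back to eager submission
            default:
                break;
        }
    }
    return true;
}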
Henry Linjamäki
a4e8912dfd opencl: Add support for multiple devices (#12622)
* opencl: Add support for multiple devices

... but limited to one platform. A platform with a GPU will be preferred.

Additionally:

* Filter out devices that lack capabilities needed by the backend
  implementation (half support, OpenCL 2.0+, etc).

* Make ggml_backend_opencl_reg() thread-safe.

* fixup: fix an error in sync_with_other_backends

... when there is only one OpenCL device available.
2025-05-21 16:21:45 -07:00
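A rough sketch of the capability-filtering idea (assumed criteria for illustration; the backend's real checks are broader):

#include <CL/cl.h>
#include <cstring>

// Keep only devices that report OpenCL 2.0+ and fp16 support.
static bool opencl_device_usable(cl_device_id dev) {
    char version[256]     = {0};
    char extensions[4096] = {0};
    clGetDeviceInfo(dev, CL_DEVICE_VERSION,    sizeof(version),    version,    NULL);
    clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, sizeof(extensions), extensions, NULL);
    // CL_DEVICE_VERSION has the form "OpenCL <major>.<minor> ..."
    const bool cl20_plus = strncmp(version, "OpenCL ", 7) == 0 && version[7] >= '2';
    const bool has_fp16  = strstr(extensions, "cl_khr_fp16") != NULL;
    return cl20_plus && has_fp16;
}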
Henry Linjamäki
edbf42edfd opencl: fix couple crashes (#12795)
* opencl: fix couple crashes

* fix kernel launches that failed on devices which do not support
  non-uniform work-groups. When non-uniform work-groups are not
  supported, set `local_work_size` to NULL (i.e. let the driver choose the
  work-group sizes). This patch does not cover everything - just the
  cases tested by test-backend-ops.

* fix sub-buffer creation failing due to `cl_buffer_region::origin` not
  being aligned to `CL_DEVICE_MEM_BASE_ADDR_ALIGN`.

* OpenCL: query non-uniform WG sizes only on OpenCL 3.0+
2025-05-21 13:21:17 -07:00
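The non-uniform work-group fix corresponds to a host-side pattern roughly like the one below (a hedged sketch, not the backend's actual code); the sub-buffer fix is a matter of rounding `cl_buffer_region::origin` to a multiple of `CL_DEVICE_MEM_BASE_ADDR_ALIGN`, which the device reports in bits.

#include <CL/cl.h>

// When the device does not support non-uniform work-groups, pass NULL as the
// local work size so the driver chooses the work-group sizes itself.
static cl_int enqueue_1d(cl_command_queue queue, cl_kernel kernel,
                         size_t global, const size_t * preferred_local,
                         bool non_uniform_wg_supported) {
    const size_t * local = non_uniform_wg_supported ? preferred_local : NULL;
    return clEnqueueNDRangeKernel(queue, kernel, /* work_dim */ 1,
                                  /* offset */ NULL, &global, local,
                                  0, NULL, NULL);
}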
Diego Devesa
d643bb2c79 releases : build CPU backend separately (windows) (#13642) 2025-05-21 22:09:57 +02:00
Georgi Gerganov
8e186ef0e7 hparams : support models for which all layers use SWA (#13682)
ggml-ci
2025-05-21 20:00:49 +03:00
Georgi Gerganov
5fbfe384d4 server : improve error reporting (#13680) 2025-05-21 19:46:56 +03:00
antichristHater
c76532e7ba convert : add qwen2vl support for unsloth merges (#13686) 2025-05-21 18:40:35 +02:00
Sigbjørn Skjæret
2aa777d86d examples : switch retrieval to llama_encode (#13685)
* switch retrieval to llama_encode

* enable --no-warmup for retrieval
2025-05-21 16:57:38 +02:00
Emmanuel Ferdman
eb0f5c28d3 gguf-py : display the invalid gguf type (#13687)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-05-21 16:33:54 +02:00
Xuan-Son Nguyen
cf4cb59e64 ggml : add ggml_gelu_erf() (#13667)
* ggml : add ggml_gelu_na (not approximated)

* fix naming order

* rename na --> erf

* apply review suggestions

* revert naming order
2025-05-21 16:26:33 +02:00
Robin Davidsson
0d5c742161 server : Add the endpoints /api/tags and /api/chat (#13659)
* Add the endpoints /api/tags and /api/chat

Add the endpoints /api/tags and /api/chat, and improve the model metadata response

* Remove trailing whitespaces

* Removed code that is not needed for copilot to work.
2025-05-21 15:15:27 +02:00
Dorin-Andrei Geman
42158ae2e8 server : fix first message identification (#13634)
* server : fix first message identification

When using the OpenAI SDK (https://github.com/openai/openai-node/blob/master/src/lib/ChatCompletionStream.ts#L623-L626) we noticed that the expected assistant role is missing in the first streaming message. Fix this by correctly checking for the first message.

Co-authored-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
Signed-off-by: Dorin Geman <dorin.geman@docker.com>

* server : Fix checks for first role message for stream=True

Co-authored-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
Signed-off-by: Dorin Geman <dorin.geman@docker.com>

---------

Signed-off-by: Dorin Geman <dorin.geman@docker.com>
Co-authored-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
2025-05-21 15:07:57 +02:00
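For reference, an OpenAI-compatible stream is expected to open with a delta that carries the assistant role (illustrative payload, unrelated fields omitted):

data: {"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}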
Georgi Gerganov
797f2ac062 kv-cache : simplify the interface (#13660)
* kv-cache : simplify the interface

ggml-ci

* context : revert llama_batch_allocr position change

ggml-ci
2025-05-21 15:11:13 +03:00
Georgi Gerganov
b44890df2e model : disable SWA for Phi models (#13676)
* model : disable SWA for Phi models

ggml-ci

* model : update warning message

* model : print warning only if n_swa > 0

* model : fix typo
2025-05-21 13:09:21 +03:00
R0CKSTAR
33983057d0 musa: Upgrade MUSA SDK version to rc4.0.1 and use mudnn::Unary::IDENTITY op to accelerate D2D memory copy (#13647)
* musa: fix build warning (unused parameter)

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* musa: upgrade MUSA SDK version to rc4.0.1

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* musa: use mudnn::Unary::IDENTITY op to accelerate D2D memory copy

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* Update ggml/src/ggml-cuda/cpy.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* musa: remove MUDNN_CHECK_GEN and use CUDA_CHECK_GEN instead in MUDNN_CHECK

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

---------

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2025-05-21 09:58:49 +08:00
Eve
fb1cab201c vulkan: fix warnings (#13626)
* small fixes

* remove ifdef
2025-05-20 21:35:16 +00:00
l3utterfly
b7a17463ec mtmd-helper : bug fix to token batching in mtmd (#13650)
* Update mtmd-helper.cpp

* Update tools/mtmd/mtmd-helper.cpp

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2025-05-20 18:55:30 +02:00
Georgi Gerganov
be0239693c model : fix llama4 graph (#13663)
ggml-ci
2025-05-20 19:21:04 +03:00
Georgi Gerganov
a4090d1174 llama : remove llama_kv_cache_view API + remove deprecated (#13653)
ggml-ci
2025-05-20 16:13:16 +03:00
Johannes Gäßler
b69f1647f9 CUDA: skip fully masked-out KV in FA vec kernel (#13584)
* CUDA: skip fully masked-out KV in FA vec kernel
2025-05-20 14:45:07 +02:00
Sigbjørn Skjæret
759e37b0d8 tests : avoid github urls due to throttling (#13654) 2025-05-20 12:03:17 +02:00
102 changed files with 100508 additions and 2484 deletions

View File

@@ -1,10 +1,10 @@
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG MUSA_VERSION=rc3.1.1
ARG MUSA_VERSION=rc4.0.1
# Target the MUSA build image
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-devel-ubuntu${UBUNTU_VERSION}
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-runtime-ubuntu${UBUNTU_VERSION}
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
@@ -21,21 +21,14 @@ RUN apt-get update && \
libcurl4-openssl-dev \
libgomp1
COPY requirements.txt requirements.txt
COPY requirements requirements
RUN pip install --upgrade pip setuptools wheel \
&& pip install -r requirements.txt
WORKDIR /app
COPY . .
# Use the default MUSA archs if not specified
RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \
fi && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
cmake --build build --config Release -j$(nproc)
RUN mkdir -p /app/lib && \

View File

@@ -48,3 +48,7 @@ end_of_line = unset
charset = unset
trim_trailing_whitespace = unset
insert_final_newline = unset
[tools/mtmd/miniaudio.h]
trim_trailing_whitespace = unset
insert_final_newline = unset

View File

@@ -351,7 +351,7 @@ jobs:
ubuntu-22-cmake-musa:
runs-on: ubuntu-22.04
container: mthreads/musa:rc3.1.1-devel-ubuntu22.04
container: mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04
steps:
- name: Clone

View File

@@ -1,4 +1,4 @@
name: Create Release
name: Release
on:
workflow_dispatch: # allows manual triggering
@@ -227,6 +227,69 @@ jobs:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip
name: llama-bin-ubuntu-vulkan-x64.zip
windows-cpu:
runs-on: windows-latest
strategy:
matrix:
include:
- arch: 'x64'
- arch: 'arm64'
steps:
- name: Clone
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: windows-latest-cmake-cpu-${{ matrix.arch }}
variant: ccache
evict-old-files: 1d
- name: Install Ninja
run: |
choco install ninja
- name: libCURL
id: get_libcurl
uses: ./.github/actions/windows-setup-curl
with:
architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}
- name: Build
shell: cmd
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch }}
cmake -S . -B build -G "Ninja Multi-Config" ^
-D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake ^
-DGGML_NATIVE=OFF ^
-DGGML_BACKEND_DL=ON ^
-DGGML_CPU_ALL_VARIANTS=${{ matrix.arch == 'x64' && 'ON' || 'OFF' }} ^
-DGGML_OPENMP=ON ^
-DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" ^
${{ env.CMAKE_ARGS }}
cmake --build build --config Release
- name: Pack artifacts
id: pack_artifacts
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\
Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.42.34433\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\
7z a llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\*
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-bin-win-cpu-${{ matrix.arch }}.zip
name: llama-bin-win-cpu-${{ matrix.arch }}.zip
windows:
runs-on: windows-latest
@@ -237,52 +300,30 @@ jobs:
strategy:
matrix:
include:
- build: 'cpu-x64'
- backend: 'vulkan'
arch: 'x64'
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF'
#- build: 'openblas-x64'
# arch: 'x64'
# defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"'
- build: 'vulkan-x64'
arch: 'x64'
defines: '-DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_VULKAN=ON'
- build: 'cpu-arm64'
arch: 'arm64'
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF'
- build: 'opencl-adreno-arm64'
defines: '-DGGML_VULKAN=ON'
target: 'ggml-vulkan'
- backend: 'opencl-adreno'
arch: 'arm64'
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON'
target: 'ggml-opencl'
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: windows-latest-cmake-${{ matrix.build }}
key: windows-latest-cmake-${{ matrix.backend }}-${{ matrix.arch }}
variant: ccache
evict-old-files: 1d
- name: Download OpenBLAS
id: get_openblas
if: ${{ matrix.build == 'openblas-x64' }}
run: |
curl.exe -o $env:RUNNER_TEMP/openblas.zip -L "https://github.com/xianyi/OpenBLAS/releases/download/v${env:OPENBLAS_VERSION}/OpenBLAS-${env:OPENBLAS_VERSION}-x64.zip"
curl.exe -o $env:RUNNER_TEMP/OpenBLAS.LICENSE.txt -L "https://github.com/xianyi/OpenBLAS/raw/v${env:OPENBLAS_VERSION}/LICENSE"
mkdir $env:RUNNER_TEMP/openblas
tar.exe -xvf $env:RUNNER_TEMP/openblas.zip -C $env:RUNNER_TEMP/openblas
$vcdir = $(vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath)
$msvc = $(join-path $vcdir $('VC\Tools\MSVC\'+$(gc -raw $(join-path $vcdir 'VC\Auxiliary\Build\Microsoft.VCToolsVersion.default.txt')).Trim()))
$lib = $(join-path $msvc 'bin\Hostx64\x64\lib.exe')
& $lib /machine:x64 "/def:${env:RUNNER_TEMP}/openblas/lib/libopenblas.def" "/out:${env:RUNNER_TEMP}/openblas/lib/openblas.lib" /name:openblas.dll
- name: Install Vulkan SDK
id: get_vulkan
if: ${{ matrix.build == 'vulkan-x64' }}
if: ${{ matrix.backend == 'vulkan' }}
run: |
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe"
& "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
@@ -296,7 +337,7 @@ jobs:
- name: Install OpenCL Headers and Libs
id: install_opencl
if: ${{ matrix.build == 'opencl-adreno-arm64' }}
if: ${{ matrix.backend == 'opencl-adreno' && matrix.arch == 'arm64' }}
run: |
git clone https://github.com/KhronosGroup/OpenCL-Headers
cd OpenCL-Headers
@@ -314,46 +355,22 @@ jobs:
-DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release"
cmake --build build-arm64-release --target install --config release
- name: libCURL
id: get_libcurl
uses: ./.github/actions/windows-setup-curl
with:
architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}
- name: Build
id: cmake_build
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
cmake -S . -B build ${{ matrix.defines }} `
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" `
${{ env.CMAKE_ARGS }}
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS}
- name: Add libopenblas.dll
id: add_libopenblas_dll
if: ${{ matrix.build == 'openblas-x64' }}
run: |
cp $env:RUNNER_TEMP/openblas/bin/libopenblas.dll ./build/bin/Release/openblas.dll
cp $env:RUNNER_TEMP/OpenBLAS.LICENSE.txt ./build/bin/Release/OpenBLAS-${env:OPENBLAS_VERSION}.txt
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
cmake -S . -B build ${{ matrix.defines }} -DGGML_NATIVE=OFF -DGGML_CPU=OFF -DGGML_BACKEND_DL=ON -DLLAMA_CURL=OFF
cmake --build build --config Release --target ${{ matrix.target }}
- name: Pack artifacts
id: pack_artifacts
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\
7z a llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip .\build\bin\Release\*
7z a llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip
name: llama-bin-win-${{ matrix.build }}.zip
path: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
name: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
windows-cuda:
runs-on: windows-2019
@@ -366,8 +383,6 @@ jobs:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Install ccache
uses: hendrikmuhs/ccache-action@v1.2.16
@@ -386,45 +401,30 @@ jobs:
run: |
choco install ninja
- name: libCURL
id: get_libcurl
uses: ./.github/actions/windows-setup-curl
- name: Build
id: cmake_build
shell: cmd
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
cmake -S . -B build -G "Ninja Multi-Config" ^
-DGGML_NATIVE=OFF ^
-DGGML_BACKEND_DL=ON ^
-DGGML_CPU_ALL_VARIANTS=ON ^
-DGGML_NATIVE=OFF ^
-DGGML_CPU=OFF ^
-DGGML_CUDA=ON ^
-DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" ^
${{ env.CMAKE_ARGS }}
-DLLAMA_CURL=OFF
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
cmake --build build --config Release -j %NINJA_JOBS% -t ggml
cmake --build build --config Release
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
cmake --build build --config Release -j %NINJA_JOBS% --target ggml-cuda
- name: Pack artifacts
id: pack_artifacts
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
cp $env:CURL_PATH\bin\libcurl-x64.dll .\build\bin\Release\libcurl-x64.dll
7z a llama-${{ steps.tag.outputs.name }}-bin-win-cuda${{ matrix.cuda }}-x64.zip .\build\bin\Release\*
7z a llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-win-cuda${{ matrix.cuda }}-x64.zip
name: llama-bin-win-cuda${{ matrix.cuda }}-x64.zip
path: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
name: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
- name: Copy and pack Cuda runtime
run: |
@@ -432,13 +432,13 @@ jobs:
$dst='.\build\bin\cudart\'
robocopy "${{env.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
robocopy "${{env.CUDA_PATH}}\lib" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
7z a cudart-llama-bin-win-cuda${{ matrix.cuda }}-x64.zip $dst\*
7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\*
- name: Upload Cuda runtime
uses: actions/upload-artifact@v4
with:
path: cudart-llama-bin-win-cuda${{ matrix.cuda }}-x64.zip
name: cudart-llama-bin-win-cuda${{ matrix.cuda }}-x64.zip
path: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
windows-sycl:
runs-on: windows-latest
@@ -451,12 +451,11 @@ jobs:
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel
ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI"
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
@@ -469,15 +468,18 @@ jobs:
run: |
scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL
# TODO: add libcurl support ; we will also need to modify win-build-sycl.bat to accept user-specified args
- name: Build
id: cmake_build
run: examples/sycl/win-build-sycl.bat
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
shell: cmd
run: |
call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
cmake -G "Ninja" -B build ^
-DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx ^
-DCMAKE_BUILD_TYPE=Release ^
-DGGML_BACKEND_DL=ON -DBUILD_SHARED_LIBS=ON ^
-DGGML_CPU=OFF -DGGML_SYCL=ON ^
-DLLAMA_CURL=OFF
cmake --build build --target ggml-sycl -j
- name: Build the release package
id: pack_artifacts
@@ -502,12 +504,12 @@ jobs:
cp "${{ env.ONEAPI_ROOT }}/tbb/latest/bin/tbb12.dll" ./build/bin
echo "cp oneAPI running time dll files to ./build/bin done"
7z a llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip ./build/bin/*
7z a llama-bin-win-sycl-x64.zip ./build/bin/*
- name: Upload the release package
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip
path: llama-bin-win-sycl-x64.zip
name: llama-bin-win-sycl-x64.zip
windows-hip:
@@ -515,14 +517,14 @@ jobs:
strategy:
matrix:
gpu_target: [gfx1100, gfx1101, gfx1030]
include:
- name: "radeon"
gpu_targets: "gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Clone rocWMMA repository
id: clone_rocwmma
@@ -532,7 +534,7 @@ jobs:
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: windows-latest-cmake-hip-release
key: windows-latest-cmake-hip-${{ matrix.name }}-x64
evict-old-files: 1d
- name: Install
@@ -550,50 +552,39 @@ jobs:
run: |
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
- name: libCURL
id: get_libcurl
uses: ./.github/actions/windows-setup-curl
- name: Build
id: cmake_build
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
cmake -G "Unix Makefiles" -B build -S . `
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
-DCMAKE_BUILD_TYPE=Release `
-DAMDGPU_TARGETS=${{ matrix.gpu_target }} `
-DGGML_BACKEND_DL=ON `
-DGGML_NATIVE=OFF `
-DGGML_CPU=OFF `
-DAMDGPU_TARGETS="${{ matrix.gpu_targets }}" `
-DGGML_HIP_ROCWMMA_FATTN=ON `
-DGGML_HIP=ON `
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" `
${{ env.CMAKE_ARGS }}
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
-DLLAMA_CURL=OFF
cmake --build build --target ggml-hip -j ${env:NUMBER_OF_PROCESSORS}
md "build\bin\rocblas\library\"
cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\"
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
- name: Pack artifacts
id: pack_artifacts
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
cp $env:CURL_PATH\bin\libcurl-x64.dll .\build\bin\libcurl-x64.dll
7z a llama-${{ steps.tag.outputs.name }}-bin-win-hip-x64-${{ matrix.gpu_target }}.zip .\build\bin\*
7z a llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\*
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-win-hip-x64-${{ matrix.gpu_target }}.zip
name: llama-bin-win-hip-x64-${{ matrix.gpu_target }}.zip
path: llama-bin-win-hip-${{ matrix.name }}-x64.zip
name: llama-bin-win-hip-${{ matrix.name }}-x64.zip
ios-xcode-build:
runs-on: macos-latest
@@ -655,14 +646,16 @@ jobs:
runs-on: ubuntu-latest
needs:
- ubuntu-22-cpu
- ubuntu-22-vulkan
- windows
- windows-cpu
- windows-cuda
- windows-sycl
- windows-hip
- ubuntu-22-cpu
- ubuntu-22-vulkan
- macOS-arm64
- macOS-x64
- ios-xcode-build
steps:
- name: Clone
@@ -680,10 +673,43 @@ jobs:
uses: actions/download-artifact@v4
with:
path: ./artifact
merge-multiple: true
- name: Move artifacts
id: move_artifacts
run: mkdir -p ./artifact/release && mv ./artifact/*/*.zip ./artifact/release
run: |
mkdir -p release
echo "Adding CPU backend files to existing zips..."
for arch in x64 arm64; do
cpu_zip="artifact/llama-bin-win-cpu-${arch}.zip"
temp_dir=$(mktemp -d)
echo "Extracting CPU backend for $arch..."
unzip "$cpu_zip" -d "$temp_dir"
echo "Adding CPU files to $arch zips..."
for target_zip in artifact/llama-bin-win-*-${arch}.zip; do
if [[ "$target_zip" == "$cpu_zip" ]]; then
continue
fi
echo "Adding CPU backend to $(basename "$target_zip")"
realpath_target_zip=$(realpath "$target_zip")
(cd "$temp_dir" && zip -r "$realpath_target_zip" .)
done
rm -rf "$temp_dir"
done
echo "Renaming and moving zips to release..."
for zip_file in artifact/llama-bin-win-*.zip; do
base_name=$(basename "$zip_file" .zip)
zip_name="llama-${{ steps.tag.outputs.name }}-${base_name#llama-}.zip"
echo "Moving $zip_file to release/$zip_name"
mv "$zip_file" "release/$zip_name"
done
echo "Moving other artifacts..."
mv -v artifact/*.zip release
- name: Create release
id: create_release
@@ -702,7 +728,7 @@ jobs:
const path = require('path');
const fs = require('fs');
const release_id = '${{ steps.create_release.outputs.id }}';
for (let file of await fs.readdirSync('./artifact/release')) {
for (let file of await fs.readdirSync('./release')) {
if (path.extname(file) === '.zip') {
console.log('uploadReleaseAsset', file);
await github.repos.uploadReleaseAsset({
@@ -710,7 +736,7 @@ jobs:
repo: context.repo.repo,
release_id: release_id,
name: file,
data: await fs.readFileSync(`./artifact/release/${file}`)
data: await fs.readFileSync(`./release/${file}`)
});
}
}

.github/workflows/winget.yml (new file, +42 lines)
View File

@@ -0,0 +1,42 @@
name: Update Winget Package
on:
workflow_dispatch: # allows manual triggering
schedule:
- cron: '28 5 * * *' # Update every day at 5:28 UTC
jobs:
update:
name: Update Winget Package
runs-on: ubuntu-latest
steps:
- name: Install cargo binstall
uses: cargo-bins/cargo-binstall@268643a6b5ea099f5718ee5cd3ff7dc89a5eb49b
- name: Install komac
run: |
cargo binstall komac@2.11.2 -y
- name: Find latest release
id: find_latest_release
uses: actions/github-script@v6
with:
script: |
const { data: releases } = await github.rest.repos.listReleases({
owner: context.repo.owner,
repo: context.repo.repo,
});
console.log("Latest release:", releases[0].tag_name);
return releases[0].tag_name;
- name: Update manifest
env:
VERSION: ${{ steps.find_latest_release.outputs.result }}
run: |
echo "Updating manifest..."
komac update --version ${{ env.VERSION }} \
--urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \
--token ${{ secrets.WINGET_GITHUB_TOKEN }} \
--submit \
ggml.llamacpp

View File

@@ -37,7 +37,7 @@ range of hardware - locally and in the cloud.
- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
- AVX, AVX2, AVX512 and AMX support for x86 architectures
- 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads MTT GPUs via MUSA)
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA)
- Vulkan and SYCL backend support
- CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity
@@ -237,7 +237,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
| [BLAS](docs/build.md#blas-build) | All |
| [BLIS](docs/backend/BLIS.md) | All |
| [SYCL](docs/backend/SYCL.md) | Intel and Nvidia GPU |
| [MUSA](docs/build.md#musa) | Moore Threads MTT GPU |
| [MUSA](docs/build.md#musa) | Moore Threads GPU |
| [CUDA](docs/build.md#cuda) | Nvidia GPU |
| [HIP](docs/build.md#hip) | AMD GPU |
| [Vulkan](docs/build.md#vulkan) | GPU |
@@ -580,3 +580,4 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc
- [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License
- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License
- [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html)
- [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain

View File

@@ -54,7 +54,7 @@ docker run --privileged -it \
-v $HOME/llama.cpp/ci-cache:/ci-cache \
-v $HOME/llama.cpp/ci-results:/ci-results \
-v $PWD:/ws -w /ws \
mthreads/musa:rc3.1.1-devel-ubuntu22.04
mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04
```
Inside the container, execute the following commands:

View File

@@ -60,12 +60,16 @@ add_library(${TARGET} STATIC
base64.hpp
chat.cpp
chat.h
chat-parser.cpp
chat-parser.h
common.cpp
common.h
console.cpp
console.h
json-schema-to-grammar.cpp
json.hpp
json-partial.h
json-partial.cpp
llguidance.cpp
log.cpp
log.h

View File

@@ -39,7 +39,7 @@
using json = nlohmann::ordered_json;
std::initializer_list<enum llama_example> mmproj_examples = {
LLAMA_EXAMPLE_LLAVA,
LLAMA_EXAMPLE_MTMD,
LLAMA_EXAMPLE_SERVER,
};
@@ -1452,7 +1452,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.swa_full = true;
}
));
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--no-context-shift"},
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
@@ -1678,7 +1678,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.warmup = false;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING}));
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL}));
add_opt(common_arg(
{"--spm-infill"},
string_format(
@@ -2065,13 +2065,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.grp_attn_w = value;
}
).set_env("LLAMA_ARG_GRP_ATTN_W").set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(common_arg(
{"-dkvc", "--dump-kv-cache"},
"verbose print of the KV cache",
[](common_params & params) {
params.dump_kv_cache = true;
}
));
add_opt(common_arg(
{"-nkvo", "--no-kv-offload"},
"disable KV offload",
@@ -2240,12 +2233,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ_OFFLOAD"));
add_opt(common_arg(
{"--image"}, "FILE",
"path to an image file. use with multimodal models. Specify multiple times for batching",
{"--image", "--audio"}, "FILE",
"path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n",
[](common_params & params, const std::string & value) {
params.image.emplace_back(value);
}
).set_examples({LLAMA_EXAMPLE_LLAVA}));
).set_examples({LLAMA_EXAMPLE_MTMD}));
if (llama_supports_rpc()) {
add_opt(common_arg(
{"--rpc"}, "SERVERS",
@@ -2875,7 +2868,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.chat_template = value;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LLAVA}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
add_opt(common_arg(
{"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
string_format(

common/chat-parser.cpp (new file, +376 lines)
View File

@@ -0,0 +1,376 @@
#include "chat-parser.h"
#include "common.h"
#include "log.h"
#include "regex-partial.h"
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
: input_(input), is_partial_(is_partial), syntax_(syntax)
{
result_.role = "assistant";
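// Pick a random marker string that does not occur anywhere in the input; it is
// later appended to truncated JSON so the "healed" document can be parsed and
// the marker located again (see common_json_parse / consume_json_with_dumped_args).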
while (true) {
std::string id = std::to_string(std::rand());
if (input.find(id) == std::string::npos) {
healing_marker_ = id;
break;
}
}
}
std::string common_chat_msg_parser::str(const common_string_range & rng) const {
GGML_ASSERT(rng.begin <= rng.end);
return input_.substr(rng.begin, rng.end - rng.begin);
}
void common_chat_msg_parser::add_content(const std::string &content) {
result_.content += content;
}
void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) {
result_.reasoning_content += reasoning_content;
}
bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
if (name.empty()) {
return false;
}
common_chat_tool_call tool_call;
tool_call.name = name;
tool_call.arguments = arguments;
tool_call.id = id;
// LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
result_.tool_calls.emplace_back(tool_call);
return true;
}
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
return add_tool_call(name, id, arguments);
}
bool common_chat_msg_parser::add_tool_calls(const json & arr) {
for (const auto & item : arr) {
if (!add_tool_call(item)) {
return false;
}
}
return true;
}
void common_chat_msg_parser::finish() {
if (!is_partial_ && pos_ != input_.size()) {
throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
}
}
bool common_chat_msg_parser::consume_spaces() {
const auto length = input_.size();
auto consumed = false;
while (pos_ < length && std::isspace(input_[pos_])) {
++pos_;
consumed = true;
}
return consumed;
}
bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
auto pos = pos_;
for (auto i = 0u; i < literal.size(); ++i) {
if (pos >= input_.size()) {
return false;
}
if (input_[pos] != literal[i]) {
return false;
}
++pos;
}
pos_ = pos;
return true;
}
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
auto idx = input_.find(literal, pos_);
if (idx != std::string::npos) {
find_regex_result res;
res.prelude = input_.substr(pos_, idx - pos_);
auto end = idx + literal.size();
res.groups.emplace_back(common_string_range{idx, end});
move_to(end);
return res;
}
if (is_partial_) {
idx = string_find_partial_stop(input_, literal);
if (idx != std::string::npos && idx >= pos_) {
find_regex_result res;
res.prelude = input_.substr(pos_, idx - pos_);
auto end = input_.size();
res.groups.emplace_back(common_string_range{idx, end});
move_to(end);
return res;
}
}
return std::nullopt;
}
void common_chat_msg_parser::consume_literal(const std::string & literal) {
if (!try_consume_literal(literal)) {
throw common_chat_msg_partial_exception(literal);
}
}
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
auto stripped_reasoning = string_strip(reasoning);
if (stripped_reasoning.empty()) {
return;
}
if (syntax_.reasoning_in_content) {
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "<think>" : start_think);
add_content(stripped_reasoning);
if (closed) {
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
}
} else {
add_reasoning_content(stripped_reasoning);
}
};
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
if (auto res = try_find_literal(end_think)) {
handle_reasoning(res->prelude, /* closed */ true);
consume_spaces();
return true;
}
auto rest = consume_rest();
if (!rest.empty()) {
handle_reasoning(rest, /* closed */ !is_partial());
}
if (!syntax_.thinking_forced_open) {
throw common_chat_msg_partial_exception(end_think);
}
return true;
}
}
return false;
}
std::string common_chat_msg_parser::consume_rest() {
auto rest = input_.substr(pos_);
pos_ = input_.size();
return rest;
}
// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
return std::nullopt;
}
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
if (is_partial()) {
throw common_chat_msg_partial_exception(regex.str());
}
return std::nullopt;
}
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
pos_ = m.groups[0].end;
return find_regex_result{prelude, m.groups};
}
common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
if (auto result = try_consume_regex(regex)) {
return *result;
}
throw common_chat_msg_partial_exception(regex.str());
}
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
auto m = regex.search(input_, pos_);
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
return std::nullopt;
}
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
if (is_partial()) {
throw common_chat_msg_partial_exception(regex.str());
}
return std::nullopt;
}
if (m.groups[0].begin != pos_) {
// Didn't match at the current position.
return std::nullopt;
}
pos_ = m.groups[0].end;
return find_regex_result {
/* .prelude = */ "",
m.groups,
};
}
std::optional<common_json> common_chat_msg_parser::try_consume_json() {
auto it = input_.cbegin() + pos_;
const auto end = input_.cend();
common_json result;
if (!common_json_parse(it, end, healing_marker_, result)) {
return std::nullopt;
}
pos_ = std::distance(input_.cbegin(), it);
if (result.healing_marker.marker.empty()) {
// No healing marker, just return the parsed json
return result;
}
if (!is_partial()) {
throw common_chat_msg_partial_exception("JSON");
}
return result;
}
common_json common_chat_msg_parser::consume_json() {
if (auto result = try_consume_json()) {
return *result;
}
throw common_chat_msg_partial_exception("JSON");
}
common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths,
const std::vector<std::vector<std::string>> & content_paths
) {
if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) {
return *result;
}
throw common_chat_msg_partial_exception("JSON");
}
std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths,
const std::vector<std::vector<std::string>> & content_paths
) {
auto partial = try_consume_json();
if (!partial) {
return std::nullopt;
}
auto is_arguments_path = [&](const std::vector<std::string> & path) {
return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end();
};
auto is_content_path = [&](const std::vector<std::string> & path) {
return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end();
};
if (partial->healing_marker.marker.empty()) {
if (args_paths.empty()) {
// No arguments to dump, and JSON was parsed fully.
return consume_json_result {
partial->json,
/* .is_partial = */ false,
};
}
if (is_arguments_path({})) {
// Entire JSON is the arguments and was parsed fully.
return consume_json_result {
partial->json.dump(),
/* .is_partial = */ false,
};
}
}
LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
auto found_healing_marker = false;
std::vector<std::string> path;
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
if (is_arguments_path(path)) {
auto arguments = j.dump();
if (is_partial() && !partial->healing_marker.marker.empty()) {
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
if (idx != std::string::npos) {
arguments.resize(idx);
found_healing_marker = true;
}
if (arguments == "\"") {
// This happens because of completing `:"$magic` after `"arguments"`
arguments = "";
}
}
return arguments;
}
if (is_content_path(path)) {
if (!j.is_string()) {
throw std::runtime_error("Content path must be a string");
}
std::string str = j;
auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string
if (idx != std::string::npos) {
str.resize(idx);
found_healing_marker = true;
}
return str;
}
if (j.is_object()) {
auto obj = json::object();
for (const auto & p : j.items()) {
const auto & key = p.key();
const auto & value = p.value();
const std::string key_str = key; // NOLINT
auto idx = key_str.find(healing_marker_);
if (idx != std::string::npos) {
found_healing_marker = true;
break;
}
path.push_back(key_str);
if (value.is_string()) {
const std::string value_str = value;
if (value_str.find(healing_marker_) != std::string::npos) {
found_healing_marker = true;
if (is_content_path(path)) {
if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) {
// The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair.
obj[key] = remove_unsupported_healings_and_dump_args(value);
}
}
break;
}
obj[key] = value;
} else {
obj[key] = remove_unsupported_healings_and_dump_args(value);
}
path.pop_back();
}
return obj;
}
if (j.is_array()) {
auto arr = json::array();
for (const auto & value : j) {
if (value.is_string()) {
std::string str = value;
auto idx = str.find(healing_marker_);
if (idx != std::string::npos) {
// Don't heal array values that aren't in the arguments.
found_healing_marker = true;
break;
}
}
arr.push_back(remove_unsupported_healings_and_dump_args(value));
}
return arr;
}
return j;
};
auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
return consume_json_result {
cleaned,
/* .is_partial = */ found_healing_marker,
};
}

common/chat-parser.h (new file, +116 lines)
View File

@@ -0,0 +1,116 @@
#pragma once
#include "chat.h"
#include "json-partial.h"
#include "json.hpp"
#include "regex-partial.h"
#include <optional>
#include <string>
#include <vector>
class common_chat_msg_partial_exception : public std::runtime_error {
public:
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
};
class common_chat_msg_parser {
std::string input_;
bool is_partial_;
common_chat_syntax syntax_;
std::string healing_marker_;
size_t pos_ = 0;
common_chat_msg result_;
public:
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
const std::string & input() const { return input_; }
size_t pos() const { return pos_; }
const std::string & healing_marker() const { return healing_marker_; }
const bool & is_partial() const { return is_partial_; }
const common_chat_msg & result() const { return result_; }
void move_to(size_t pos) {
if (pos > input_.size()) {
throw std::runtime_error("Invalid position!");
}
pos_ = pos;
}
void move_back(size_t n) {
if (pos_ < n) {
throw std::runtime_error("Can't move back that far!");
}
pos_ -= n;
}
// Get the substring of the input at the given range
std::string str(const common_string_range & rng) const;
// Appends to the result.content field
void add_content(const std::string & content);
// Appends to the result.reasoning_content field
void add_reasoning_content(const std::string & reasoning_content);
// Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
// Adds a tool call using the "name", "id" and "arguments" fields of the json object
bool add_tool_call(const nlohmann::ordered_json & tool_call);
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
bool add_tool_calls(const nlohmann::ordered_json & arr);
void finish();
bool consume_spaces();
void consume_literal(const std::string & literal);
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
std::string consume_rest();
struct find_regex_result {
std::string prelude;
std::vector<common_string_range> groups;
};
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
bool try_consume_literal(const std::string & literal);
std::optional<find_regex_result> try_find_literal(const std::string & literal);
find_regex_result consume_regex(const common_regex & regex);
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
std::optional<common_json> try_consume_json();
common_json consume_json();
struct consume_json_result {
nlohmann::ordered_json value;
bool is_partial;
};
/*
Consumes (possibly partial) JSON and converts specific subtrees to (possibly truncated) JSON strings.
By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`).
But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings:
- with `content_paths={{"foo"}}` -> `{"foo": "b` -> `{"foo": "b"}`
- with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
*/
consume_json_result consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths = {},
const std::vector<std::vector<std::string>> & content_paths = {}
);
std::optional<consume_json_result> try_consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths = {},
const std::vector<std::vector<std::string>> & content_paths = {}
);
};
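
A minimal usage sketch of the new parser API (an assumed calling pattern; the real call sites live in common/chat.cpp):

common_chat_syntax syntax;
syntax.format           = COMMON_CHAT_FORMAT_CONTENT_ONLY;
syntax.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;

common_chat_msg_parser builder("<think>plan the answer</think>Hello!", /* is_partial */ false, syntax);
builder.try_parse_reasoning("<think>", "</think>"); // fills result().reasoning_content
builder.add_content(builder.consume_rest());        // remaining text becomes content
builder.finish();                                   // throws if unconsumed input remains
const common_chat_msg & msg = builder.result();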

File diff suppressed because it is too large.

View File

@@ -3,6 +3,7 @@
#pragma once
#include "common.h"
#include <functional>
#include <chrono>
#include <string>
#include <vector>
@@ -13,11 +14,19 @@ struct common_chat_tool_call {
std::string name;
std::string arguments;
std::string id;
bool operator==(const common_chat_tool_call & other) const {
return name == other.name && arguments == other.arguments && id == other.id;
}
};
struct common_chat_msg_content_part {
std::string type;
std::string text;
bool operator==(const common_chat_msg_content_part & other) const {
return type == other.type && text == other.text;
}
};
struct common_chat_msg {
@@ -28,6 +37,51 @@ struct common_chat_msg {
std::string reasoning_content;
std::string tool_name;
std::string tool_call_id;
template <class T> T to_json_oaicompat() const;
bool empty() const {
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
}
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
for (auto i = 0u; i < tool_calls.size(); i++) {
if (ids_cache.size() <= i) {
auto id = tool_calls[i].id;
if (id.empty()) {
id = gen_tool_call_id();
}
ids_cache.push_back(id);
}
tool_calls[i].id = ids_cache[i];
}
}
bool operator==(const common_chat_msg & other) const {
return role == other.role
&& content == other.content
&& content_parts == other.content_parts
&& tool_calls == other.tool_calls
&& reasoning_content == other.reasoning_content
&& tool_name == other.tool_name
&& tool_call_id == other.tool_call_id;
}
bool operator!=(const common_chat_msg & other) const {
return !(*this == other);
}
};
struct common_chat_msg_diff {
// std::string reasoning_content_delta;
std::string content_delta;
size_t tool_call_index = std::string::npos;
common_chat_tool_call tool_call_delta;
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
bool operator==(const common_chat_msg_diff & other) const {
return content_delta == other.content_delta
&& tool_call_index == other.tool_call_index
&& tool_call_delta == other.tool_call_delta;
}
};
struct common_chat_tool {
@@ -49,14 +103,11 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
@@ -71,7 +122,7 @@ struct common_chat_templates_inputs {
std::vector<common_chat_tool> tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
bool parallel_tool_calls = false;
bool extract_reasoning = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
};
@@ -80,11 +131,20 @@ struct common_chat_params {
std::string prompt;
std::string grammar;
bool grammar_lazy = false;
bool thinking_forced_open = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};
struct common_chat_syntax {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false;
bool thinking_forced_open = false;
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
@@ -122,7 +182,7 @@ std::string common_chat_format_example(
bool use_jinja);
std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
@@ -135,3 +195,5 @@ template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
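
A small usage sketch of the new diff helper (the exact delta split is an assumption for illustration):

common_chat_msg prev, cur;
prev.role = "assistant";  cur.role = "assistant";
prev.content = "Hel";     cur.content = "Hello, wor";
// Each diff carries either a content delta or a tool-call delta; for a pure
// content update we would expect a single diff with content_delta == "lo, wor".
std::vector<common_chat_msg_diff> diffs = common_chat_msg_diff::compute_diffs(prev, cur);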

View File

@@ -1329,81 +1329,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
return text;
}
//
// KV cache utils
//
void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) {
static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
llama_kv_cache_view_cell * c_curr = view.cells;
llama_seq_id * cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
if (i % row_size == 0) {
printf("\n%5d: ", i);
}
int seq_count = 0;
for (int j = 0; j < view.n_seq_max; j++) {
if (cs_curr[j] >= 0) { seq_count++; }
}
putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
}
printf("\n=== Done dumping\n");
}
void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) {
static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
std::unordered_map<llama_seq_id, size_t> seqs;
llama_kv_cache_view_cell * c_curr = view.cells;
llama_seq_id * cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
for (int j = 0; j < view.n_seq_max; j++) {
if (cs_curr[j] < 0) { continue; }
if (seqs.find(cs_curr[j]) == seqs.end()) {
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
const size_t sz = seqs.size();
seqs[cs_curr[j]] = sz;
}
}
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
}
printf("=== Sequence legend: ");
for (const auto & it : seqs) {
printf("%zu=%d, ", it.second, it.first);
}
printf("'+'=other sequence ids");
c_curr = view.cells;
cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
if (i % row_size == 0) {
printf("\n%5d: ", i);
}
for (int j = 0; j < view.n_seq_max; j++) {
if (cs_curr[j] >= 0) {
const auto & it = seqs.find(cs_curr[j]);
putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
} else {
putchar('.');
}
}
putchar(' ');
}
printf("\n=== Done dumping\n");
}
//
// Embedding utils
//

View File

@@ -76,7 +76,7 @@ enum llama_example {
LLAMA_EXAMPLE_SERVER,
LLAMA_EXAMPLE_CVECTOR_GENERATOR,
LLAMA_EXAMPLE_EXPORT_LORA,
LLAMA_EXAMPLE_LLAVA,
LLAMA_EXAMPLE_MTMD,
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,
@@ -115,7 +115,7 @@ enum common_grammar_trigger_type {
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
};
struct common_grammar_trigger {
@@ -330,7 +330,6 @@ struct common_params {
bool use_mlock = false; // use mlock to keep model in memory
bool verbose_prompt = false; // print prompt tokens before generation
bool display_prompt = true; // print prompt before generation
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
bool no_kv_offload = false; // disable KV offloading
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data
@@ -622,16 +621,6 @@ std::string common_detokenize(
const std::vector<llama_token> & tokens,
bool special = true);
//
// KV cache utils
//
// Dump the KV cache view with the number of sequences per cell.
void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
// Dump the KV cache view showing individual sequences in each cell (long output).
void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
//
// Embedding utils
//

common/json-partial.cpp (new file, 255 lines)
View File

@@ -0,0 +1,255 @@
#include <json-partial.h>
#include "ggml.h"
#include "log.h"
#include <string>
#include <json.hpp>
using json = nlohmann::ordered_json;
enum common_json_stack_element_type {
COMMON_JSON_STACK_ELEMENT_OBJECT,
COMMON_JSON_STACK_ELEMENT_KEY,
COMMON_JSON_STACK_ELEMENT_ARRAY,
};
struct common_json_stack_element {
common_json_stack_element_type type;
std::string key;
};
bool common_json_parse(
const std::string & input,
const std::string & healing_marker,
common_json & out)
{
std::string::const_iterator it = input.begin();
const auto end = input.end();
return common_json_parse(it, end, healing_marker, out);
}
bool common_json_parse(
std::string::const_iterator & it,
const std::string::const_iterator & end,
const std::string & healing_marker,
common_json & out)
{
// // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> {
std::size_t position;
bool found_error;
std::string last_token;
std::string exception_message;
std::vector<common_json_stack_element> stack;
json_error_locator() : position(0), found_error(false) {}
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
this->position = position - 1;
this->found_error = true;
this->last_token = last_token;
this->exception_message = ex.what();
return false;
}
void close_value() {
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
stack.pop_back();
}
}
bool null() override { // NOLINT
close_value();
return true;
}
bool boolean(bool) override { // NOLINT
close_value();
return true;
}
bool number_integer(number_integer_t) override { // NOLINT
close_value();
return true;
}
bool number_unsigned(number_unsigned_t) override { // NOLINT
close_value();
return true;
}
bool number_float(number_float_t, const string_t &) override { // NOLINT
close_value();
return true;
}
bool string(string_t &) override { // NOLINT
close_value();
return true;
}
bool binary(binary_t &) override { // NOLINT
close_value();
return true;
}
bool start_object(std::size_t) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
return true;
}
bool end_object() override {
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
stack.pop_back();
close_value();
return true;
}
bool key(string_t & key) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
return true;
}
bool start_array(std::size_t) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
return true;
}
bool end_array() override {
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
stack.pop_back();
close_value();
return true;
}
};
json_error_locator err_loc;
auto start = it;
json::sax_parse(it, end, &err_loc);
if (err_loc.found_error) {
it = start;
auto temptative_end = it + err_loc.position;
// LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
auto input = std::string(it, temptative_end);
try {
out.json = json::parse(input);
// out.json = json::parse(it, temptative_end);
it = temptative_end;
return true;
} catch (const std::exception & ex) {
// No, needs healing.
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
}
auto can_parse = [](const std::string & str) {
try {
auto _ = json::parse(str); // NOLINT
return true;
} catch (const std::exception &) {
return false;
}
};
if (!healing_marker.empty() && !err_loc.stack.empty()) {
std::string str(it, temptative_end);
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
if (last_non_sp_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
}
auto last_non_sp_char = str[last_non_sp_pos];
// Used to detect stops on a number, which may not be complete.
auto was_maybe_number = [&]() {
if (!str.empty() && std::isspace(str.back())) {
return false;
}
return std::isdigit(last_non_sp_char) ||
last_non_sp_char == '.' ||
last_non_sp_char == 'e' ||
last_non_sp_char == 'E' ||
last_non_sp_char == '-';
};
std::string closing;
for (size_t i = err_loc.stack.size(); i > 0; i--) {
auto & el = err_loc.stack[i - 1];
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
closing += "}";
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
closing += "]";
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
throw std::runtime_error("Unexpected stack element type");
}
}
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
// We're inside an object value
if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
// Was about to create an object value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
} else if (can_parse(str + ": 1" + closing)) {
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
} else if (last_non_sp_char == '{' && can_parse(str + closing)) {
// Was about to create an object
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + "\"" + closing)) {
// Was inside an object value string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an object value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else {
// find last :
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
}
// Cutting back to opening : for object value
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
// Was about to create an array value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
} else if (can_parse(str + "\"" + closing)) {
// Was inside an array value string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an array value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
// Had just finished a value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
} else {
auto last_pos = str.find_last_of("[,");
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
}
// Cutting back to last [ or , for array value
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
(last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
// Was about to create an object key+value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
} else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
// Was about to create an object key+value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + "\": 1" + closing)) {
// Was inside an object key string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
// Was inside an object key string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
} else {
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
}
// fprintf(stderr, "Cutting back to last : for object key+value\n");
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else {
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
}
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
out.json = json::parse(str);
it = temptative_end;
return true;
}
// TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
// fprintf(stderr, "Closing: TODO\n");
return false;
}
out.json = json::parse(it, end);
it = end;
return true;
}

common/json-partial.h (new file, 37 lines)
View File

@@ -0,0 +1,37 @@
#pragma once
#include <json.hpp>
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
struct common_healing_marker {
// Raw marker.
std::string marker;
// Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
std::string json_dump_marker;
};
// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
struct common_json {
nlohmann::ordered_json json;
common_healing_marker healing_marker;
};
// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
//
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
// This allows parsing the resulting healed JSON string, while still being able to cut it again at the healing marker if needed.
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
//
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
bool common_json_parse(
const std::string & input,
const std::string & healing_marker,
common_json & out);
// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
bool common_json_parse(
std::string::const_iterator & it,
const std::string::const_iterator & end,
const std::string & healing_marker,
common_json & out);
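A hedged usage sketch of the healing contract documented above; the marker string, the truncated input, and the main() wrapper are made up for illustration:

```cpp
#include <iostream>
#include <string>
// #include "json-partial.h"   // include path assumed

int main() {
    common_json out;
    // A truncated object value, e.g. a partially streamed tool-call argument.
    const std::string partial = R"({"location": "Istan)";

    if (common_json_parse(partial, /*healing_marker=*/"$H$", out)) {
        const std::string dump = out.json.dump();     // e.g. {"location":"Istan$H$"}
        // Cutting at the healing marker recovers the partial prefix of the dump.
        const std::string cut = dump.substr(0, dump.find(out.healing_marker.json_dump_marker));
        std::cout << dump << "\n" << cut << "\n";     // cut == {"location":"Istan
    }
    return 0;
}
```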

View File

@@ -161,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
} else {
std::vector<std::string> patterns_at_start;
std::vector<std::string> trigger_patterns;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens;
for (const auto & trigger : params.grammar_triggers) {
@@ -173,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
{
const auto & pattern = trigger.value;
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
patterns_anywhere.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
trigger_patterns.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -190,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
}
std::vector<std::string> trigger_patterns;
if (!patterns_at_start.empty()) {
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
}
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}
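As a hedged illustration of what this change means for the final regexes (helper names here are ours, not from the diff): "anywhere" patterns are wrapped so they may match after an arbitrary prefix, while PATTERN_FULL triggers are taken as complete patterns.

```cpp
#include <regex>
#include <string>
#include <vector>

// Sketch of the wrapping applied to "anywhere" patterns after this change
// (PATTERN_FULL triggers are used verbatim instead of being wrapped).
static std::string wrap_anywhere(const std::vector<std::string> & patterns) {
    std::string joined;
    for (size_t i = 0; i < patterns.size(); ++i) {
        joined += (i ? "|" : "") + patterns[i];
    }
    // Lazily skip any prefix, then require one of the trigger patterns.
    return "^[\\s\\S]*?(" + joined + ")[\\s\\S]*";
}

// Example: a tool-call opener anywhere in the text triggers the lazy grammar.
static bool triggers(const std::string & text) {
    const std::regex re(wrap_anywhere({"<tool_call>"}));
    return std::regex_match(text, re);
}
```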

View File

@@ -45,7 +45,7 @@ class SentencePieceTokenTypes(IntEnum):
class ModelType(IntEnum):
TEXT = 1
VISION = 2
MMPROJ = 2
AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
@@ -54,7 +54,7 @@ AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
class ModelBase:
_model_classes: dict[ModelType, dict[str, type[ModelBase]]] = {
ModelType.TEXT: {},
ModelType.VISION: {},
ModelType.MMPROJ: {},
}
dir_model: Path
@@ -88,7 +88,7 @@ class ModelBase:
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is VisionModel:
type(self) is MmprojModel:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
self.dir_model = dir_model
@@ -309,6 +309,7 @@ class ModelBase:
gguf.MODEL_TENSOR.POSNET_NORM1,
gguf.MODEL_TENSOR.POSNET_NORM2,
gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
)
)
or not new_name.endswith(".weight")
@@ -438,7 +439,7 @@ class ModelBase:
assert names
def func(modelcls: AnyModel) -> AnyModel:
model_type = ModelType.VISION if modelcls.model_arch == gguf.MODEL_ARCH.CLIP_VISION else ModelType.TEXT
model_type = ModelType.MMPROJ if modelcls.model_arch == gguf.MODEL_ARCH.MMPROJ else ModelType.TEXT
for name in names:
cls._model_classes[model_type][name] = modelcls
return modelcls
@@ -1114,60 +1115,87 @@ class TextModel(ModelBase):
self.gguf_writer.add_pooling_type(pooling_type)
class VisionModel(ModelBase):
model_type = ModelType.VISION
model_arch = gguf.MODEL_ARCH.CLIP_VISION
class MmprojModel(ModelBase):
model_type = ModelType.MMPROJ
model_arch = gguf.MODEL_ARCH.MMPROJ
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]
has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
if self.model_arch != gguf.MODEL_ARCH.MMPROJ:
raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ")
if self.has_vision_encoder and self.has_audio_encoder:
raise NotImplementedError("both vision + audio not supported yet")
# get n_embd of the text model
if "text_config" not in self.hparams:
self.hparams["text_config"] = {}
if "audio_config" not in self.hparams:
self.hparams["audio_config"] = {}
text_config = {**self.hparams, **self.hparams["text_config"]}
self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
assert self.n_embd_text > 0, "n_embd not found in hparams"
if "vision_config" not in self.hparams:
raise ValueError("vision_config not found in hparams")
# move vision config to the top level, while preserving the original hparams in global_config
self.global_config = self.hparams
self.hparams = self.hparams["vision_config"]
if "vision_config" in self.hparams:
self.hparams = self.hparams["vision_config"]
elif "audio_config" in self.hparams:
self.hparams = self.hparams["audio_config"]
else:
raise ValueError("vision_config / audio_config not found in hparams")
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
# load preprocessor config
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
def set_type(self):
self.gguf_writer.add_type(gguf.GGUFType.CLIP_VISION)
self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
def set_gguf_parameters(self):
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
self.gguf_writer.add_vision_has_vision_encoder(True)
# vision config
self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.block_count)
self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
if self.has_vision_encoder:
self.gguf_writer.add_clip_has_vision_encoder(True)
self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
# preprocessor config
self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
# vision config
self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.block_count)
self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
# preprocessor config
self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
elif self.has_audio_encoder:
self.gguf_writer.add_clip_has_audio_encoder(True)
self.gguf_writer.add_audio_projection_dim(self.n_embd_text)
# audio config
self.gguf_writer.add_audio_embedding_length(self.find_hparam(["hidden_size"]))
self.gguf_writer.add_audio_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_audio_block_count(self.block_count)
self.gguf_writer.add_audio_head_count(self.find_hparam(["num_attention_heads"]))
else:
raise ValueError("MmprojModel must have either vision or audio encoder")
def write_vocab(self):
raise ValueError("VisionModel does not support vocab writing")
raise ValueError("MmprojModel does not support vocab writing")
@ModelBase.register("GPTNeoXForCausalLM")
@@ -1951,7 +1979,7 @@ class LlamaModel(TextModel):
"LlavaForConditionalGeneration", # pixtral
"Mistral3ForConditionalGeneration", # mistral small 3.1
)
class LlavaVisionModel(VisionModel):
class LlavaVisionModel(MmprojModel):
img_break_tok_id = -1
def __init__(self, *args, **kwargs):
@@ -1977,7 +2005,7 @@ class LlavaVisionModel(VisionModel):
super().set_gguf_parameters()
hparams = self.hparams
if hparams["model_type"] == "pixtral":
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
# hidden_act
@@ -2016,7 +2044,7 @@ class LlavaVisionModel(VisionModel):
@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
class SmolVLMModel(VisionModel):
class SmolVLMModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.hparams["model_type"] == "smolvlm_vision":
@@ -2028,7 +2056,7 @@ class SmolVLMModel(VisionModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.IDEFICS3)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.IDEFICS3)
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2))
self.gguf_writer.add_vision_use_gelu(True)
@@ -2094,10 +2122,10 @@ class Llama4Model(LlamaModel):
@ModelBase.register("Llama4ForConditionalGeneration")
class Llama4VisionModel(VisionModel):
class Llama4VisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LLAMA4)
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"])
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"]))
assert self.hparams["hidden_act"] == "gelu"
@@ -2645,7 +2673,7 @@ class Qwen2Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
class Qwen2VLModel(TextModel):
model_arch = gguf.MODEL_ARCH.QWEN2VL
@@ -2669,8 +2697,8 @@ class Qwen2VLModel(TextModel):
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
class Qwen2VLVisionModel(VisionModel):
@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
class Qwen2VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hparams["image_size"] = self.hparams.get("image_size", 560)
@@ -2685,9 +2713,9 @@ class Qwen2VLVisionModel(VisionModel):
super().set_gguf_parameters()
hparams = self.hparams
if self.global_config['model_type'] == 'qwen2_vl':
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN2VL)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL)
elif self.global_config['model_type'] == 'qwen2_5_vl':
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN25VL)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL)
self.gguf_writer.add_vision_use_silu(True)
# find n_wa_pattern (window attention pattern)
fullatt_block_indexes = hparams.get("fullatt_block_indexes")
@@ -2746,11 +2774,11 @@ class Qwen2VLVisionModel(VisionModel):
@ModelBase.register("InternVisionModel")
class InternVisionModel(VisionModel):
class InternVisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
# hidden_act
if hparams["hidden_act"] == "silu":
@@ -4008,11 +4036,11 @@ class Gemma3Model(TextModel):
@ModelBase.register("Gemma3ForConditionalGeneration")
class Gemma3VisionModel(VisionModel):
class Gemma3VisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.GEMMA3)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3)
# default values below are taken from HF transformers code
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
self.gguf_writer.add_vision_use_gelu(True)
@@ -5959,6 +5987,52 @@ class ChameleonModel(TextModel):
return data_torch
@ModelBase.register("UltravoxModel")
class UltravoxModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA # dummy
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
raise NotImplementedError("Ultravox does not have text decoder. Please use --mmproj argument")
@ModelBase.register("UltravoxModel")
class UltravoxAudioModel(MmprojModel):
has_vision_encoder = False # no vision encoder
has_audio_encoder = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hparams["hidden_size"] = self.hparams["d_model"]
self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
def tensor_force_quant(self, name, new_name, bid, n_dims):
del bid, new_name, n_dims # unused
if ".conv" in name and ".weight" in name:
return gguf.GGMLQuantizationType.F16
return False
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
# prevent clash naming with vision tensors
if name.startswith("multi_modal_projector"):
name = "audio." + name
if "conv1.bias" in name or "conv2.bias" in name:
# transpose conv1 and conv2 bias
data_torch = data_torch.unsqueeze(-1)
return [(self.map_tensor_name(name), data_torch)]
###### CONVERSION LOGIC ######
@@ -6134,13 +6208,15 @@ def split_str_to_n_bytes(split_str: str) -> int:
def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
# TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
# maybe we should fallback to text model's arch in that case, since not many models have both
text_config = hparams.get("text_config", {})
vision_config = hparams.get("vision_config", {})
arch = hparams["architectures"][0]
# if "architectures" is found in the sub-config, use that instead
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
arch = text_config["architectures"][0]
elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
arch = vision_config["architectures"][0]
return arch
@@ -6203,7 +6279,7 @@ def main() -> None:
with torch.inference_mode():
output_type = ftype_map[args.outtype]
model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT
hparams = ModelBase.load_hparams(dir_model)
model_architecture = get_model_architecture(hparams, model_type)
logger.info(f"Model architecture: {model_architecture}")

View File

@@ -107,7 +107,7 @@ You may want to pass in some different `ARGS`, depending on the MUSA environment
The defaults are:
- `MUSA_VERSION` set to `rc3.1.1`
- `MUSA_VERSION` set to `rc4.0.1`
The resulting images are essentially the same as the non-MUSA images:

View File

@@ -325,36 +325,65 @@ To get the official template from original HuggingFace repos, you can use [scrip
> [!TIP]
> If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills)
> [!CAUTION]
> Beware of extreme KV quantizations (e.g. `-ctk q4_0`), they can substantially degrade the model's tool calling performance.
Test in CLI (or with any library / software that can use OpenAI-compatible API backends):
```bash
curl http://localhost:8080/v1/chat/completions -d '{
"model": "gpt-3.5-turbo",
"tools": [
{
"type":"function",
"function":{
"name":"python",
"description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters":{
"type":"object",
"properties":{
"code":{
"type":"string",
"description":"The code to run in the ipython interpreter."
"model": "gpt-3.5-turbo",
"tools": [
{
"type":"function",
"function":{
"name":"python",
"description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters":{
"type":"object",
"properties":{
"code":{
"type":"string",
"description":"The code to run in the ipython interpreter."
}
},
"required":["code"]
}
},
"required":["code"]
}
}
}
],
"messages": [
{
"role": "user",
"content": "Print a hello world message with python."
}
]
}
],
"messages": [
{
"role": "user",
"content": "Print a hello world message with python."
}
]
}'
curl http://localhost:8080/v1/chat/completions -d '{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
{"role": "user", "content": "What is the weather in Istanbul?"}
],
"tools": [{
"type":"function",
"function":{
"name":"get_current_weather",
"description":"Get the current weather in a given location",
"parameters":{
"type":"object",
"properties":{
"location":{
"type":"string",
"description":"The city and country/state, e.g. `San Francisco, CA`, or `Paris, France`"
}
},
"required":["location"]
}
}
}]
}'
```

View File

@@ -4,7 +4,9 @@ llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools
- [llama-mtmd-cli](../tools/mtmd/README.md)
- [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API
To enable it, can use use one of the 2 methods below:
Currently, we support **image** and **audio** input. Audio is highly experimental and may have reduced quality.
To enable it, you can use one of the 2 methods below:
- Use `-hf` option with a supported model (see a list of pre-quantized model below)
- To load a model using `-hf` while disabling multimodal, use `--no-mmproj`
@@ -37,6 +39,8 @@ Replaces the `(tool_name)` with the name of binary you want to use. For example,
NOTE: some models may require a large context window, for example: `-c 8192`
**Vision models**:
```sh
# Gemma 3
(tool_name) -hf ggml-org/gemma-3-4b-it-GGUF
@@ -78,3 +82,11 @@ NOTE: some models may require large context window, for example: `-c 8192`
# Llama 4 Scout
(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF
```
**Audio models**:
```sh
# Ultravox 0.5
(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF
(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF
```

View File

@@ -50,8 +50,6 @@ int main(int argc, char ** argv) {
const int N = 5; // n-gram size
const int G = 15; // max verification n-grams
const bool dump_kv_cache = params.dump_kv_cache;
// init llama.cpp
llama_backend_init();
llama_numa_init(params.numa);
@@ -152,9 +150,6 @@ int main(int argc, char ** argv) {
// here we keep adding new n-grams as we go
ngram_container ngrams_observed(llama_vocab_n_tokens(vocab), N, G);
// debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);
const auto t_dec_start = ggml_time_us();
// sample first token
@@ -172,12 +167,6 @@ int main(int argc, char ** argv) {
}
while (true) {
// debug
if (dump_kv_cache) {
llama_kv_cache_view_update(ctx, &kvc_view);
common_kv_cache_dump_view_seqs(kvc_view, 40);
}
// build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/
//
// Example for W = 5, N = 4, G = 2:
@@ -473,8 +462,6 @@ int main(int argc, char ** argv) {
common_sampler_free(smpl);
llama_kv_cache_view_free(&kvc_view);
llama_batch_free(batch);
llama_backend_free();

View File

@@ -24,8 +24,6 @@ int main(int argc, char ** argv){
// max. number of additional tokens to draft if match is found
const int n_draft = params.speculative.n_max;
const bool dump_kv_cache = params.dump_kv_cache;
// init llama.cpp
llama_backend_init();
llama_numa_init(params.numa);
@@ -110,18 +108,9 @@ int main(int argc, char ** argv){
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
// debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
const auto t_dec_start = ggml_time_us();
while (true) {
// debug
if (dump_kv_cache) {
llama_kv_cache_view_update(ctx, &kvc_view);
common_kv_cache_dump_view_seqs(kvc_view, 40);
}
// print current draft sequence
LOG_DBG("drafted %s\n", string_from(ctx, draft).c_str());

View File

@@ -178,8 +178,6 @@ int main(int argc, char ** argv) {
// insert new requests as soon as the previous one is done
const bool cont_batching = params.cont_batching;
const bool dump_kv_cache = params.dump_kv_cache;
// is the system prompt shared in the cache
const bool is_sp_shared = params.is_pp_shared;
@@ -241,8 +239,6 @@ int main(int argc, char ** argv) {
int32_t n_total_gen = 0;
int32_t n_cache_miss = 0;
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients);
const auto t_main_start = ggml_time_us();
LOG_INF("%s: Simulating parallel requests from clients:\n", __func__);
@@ -272,11 +268,6 @@ int main(int argc, char ** argv) {
LOG_INF("Processing requests ...\n\n");
while (true) {
if (dump_kv_cache) {
llama_kv_cache_view_update(ctx, &kvc_view);
common_kv_cache_dump_view_seqs(kvc_view, 40);
}
common_batch_clear(batch);
// decode any currently ongoing sequences

View File

@@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
}
}
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_self_clear(ctx);
// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_decode(ctx, batch) < 0) {
LOG_ERR("%s : failed to decode\n", __func__);
if (llama_encode(ctx, batch) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
}
for (int i = 0; i < batch.n_tokens; i++) {
@@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
batch_encode(ctx, batch, out, s, n_embd);
common_batch_clear(batch);
p += s;
s = 0;
@@ -246,7 +246,7 @@ int main(int argc, char ** argv) {
// final batch
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
batch_encode(ctx, batch, out, s, n_embd);
// save embeddings to chunks
for (int i = 0; i < n_chunks; i++) {
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
batch_add_seq(query_batch, query_tokens, 0);
std::vector<float> query_emb(n_embd, 0);
batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
common_batch_clear(query_batch);

View File

@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
auto generate = [&](const std::string & prompt) {
std::string response;
const bool is_first = llama_kv_self_used_cells(ctx) == 0;
const bool is_first = llama_kv_self_seq_pos_max(ctx, 0) == 0;
// tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
while (true) {
// check if we have enough space in the context to evaluate this batch
int n_ctx = llama_n_ctx(ctx);
int n_ctx_used = llama_kv_self_used_cells(ctx);
int n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0);
if (n_ctx_used + batch.n_tokens > n_ctx) {
printf("\033[0m\n");
fprintf(stderr, "context size exceeded\n");

View File

@@ -536,6 +536,7 @@ extern "C" {
GGML_UNARY_OP_HARDSWISH,
GGML_UNARY_OP_HARDSIGMOID,
GGML_UNARY_OP_EXP,
GGML_UNARY_OP_GELU_ERF,
GGML_UNARY_OP_COUNT,
};
@@ -1024,6 +1025,16 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
// GELU using erf (error function) when possible
// some backends may fall back to an approximation based on the Abramowitz and Stegun formula
GGML_API struct ggml_tensor * ggml_gelu_erf(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_gelu_erf_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_gelu_quick(
struct ggml_context * ctx,
struct ggml_tensor * a);
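As a reference for the declarations above, a scalar sketch of the erf-based GELU; the helper name is ours, and it mirrors the CPU vector implementation added later in this diff:

```cpp
#include <cmath>

// Scalar reference for the erf-based GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
// Backends without a native erf may substitute the tanh approximation mentioned above.
static inline float gelu_erf_ref(float x) {
    const float SQRT_2_INV = 0.70710678118654752440084436210484f;
    return 0.5f * x * (1.0f + erff(x * SQRT_2_INV));
}
```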

View File

@@ -2697,14 +2697,10 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
}
}
// GroupedMatmulV2 required tensor_list.size < 128
size_t GROUP_SIZE = 128;
std::vector<std::vector<aclTensor*>> src0_tensor_vec_vec;
std::vector<std::vector<aclTensor*>> src1_tensor_vec_vec;
std::vector<std::vector<aclTensor*>> dst_tensor_vec_vec;
// split and call GroupedMatmulV2
// GroupedMatmulV2 required tensor_list.size < 128
for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
// split and call GroupedMatmulV2
size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
@@ -2722,6 +2718,133 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
return;
}
/**
* @brief Performs expert-specific matrix multiplication (MoE) with
* quantized precision using the CANN backend.
*
* This function executes a matrix multiplication operation tailored for
* Mixture of Experts (MoE) models, where the input tensor is multiplied
* with expert-specific quantized weight matrices. It leverages the CANN
* backend to perform efficient low-precision computations and stores the
* quantized result in the destination tensor `dst`.
*
* Quantization techniques reduce memory footprint and improve performance
* by using lower-bit representations (e.g., int8) instead of floating-point.
* This function is designed to work with such formats and may incorporate
* optimizations like identity-based fast paths or routing masks for sparse
* expert selection.
*
* @param ctx The context for executing CANN backend operations.
* @param dst The destination tensor where the quantized MoE multiplication result
* will be stored.
*
* @note This function assumes quantized data types and is designed for
* MoE architectures with potential sparse expert routing.
*/
static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: Use aclnnGroupedMatMul
//dst [M, K, N, 1]
ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
ggml_tensor * ids = dst->src[2]; //ids [K, N]
GGML_TENSOR_BINARY_OP_LOCALS
// copy index from npu to cpu
int64_t n_as = ne02; // A
int64_t n_ids = ids->ne[0]; // K
std::vector<char> ids_host(ggml_nbytes(ids));
ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
ACL_MEMCPY_DEVICE_TO_HOST);
ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
char * src0_original = (char *) src0->data;
char * src1_original = (char *) src1->data;
char * dst_original = (char *) dst->data;
ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
const enum ggml_type type = dst->src[0]->type;
float weight_elem_size;
if (type == GGML_TYPE_Q4_0) {
weight_elem_size = float(sizeof(uint8_t)) / 2;
} else if (type == GGML_TYPE_Q8_0) {
weight_elem_size = float(sizeof(uint8_t));
} else {
GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
}
// src0_row [D, M, 1, 1] weight without permute
src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
src0_row.nb[0] = weight_elem_size;
src0_row.nb[1] = weight_elem_size * ne00;
src0_row.nb[2] = weight_elem_size * ne00;
src0_row.nb[3] = weight_elem_size * ne00;
size_t weight_stride = ne00 * ne01 * weight_elem_size;
size_t weight_size = weight_stride * ne02 * ne03;
// scale [D, M, 1, 1] -> scale && permute
size_t scale_elem_size = sizeof(uint16_t);
size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
// src1_row [D, 1, 1, 1] -> input
src1_row.ne[1] = 1;
src1_row.ne[2] = 1;
src1_row.ne[3] = 1;
src1_row.nb[2] = nb11;
src1_row.nb[3] = nb11;
// dst_row [M, 1, 1, 1] -> out
dst_row.ne[1] = 1;
dst_row.ne[2] = 1;
dst_row.ne[3] = 1;
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;
//create weight for one row
ggml_cann_pool_alloc weight_allocator(ctx.pool());
void* weight_buffer = weight_allocator.alloc(nb02);
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
// expert index
int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
// If B = 1 (broadcast), always use 0; otherwise, use id.
int64_t i11 = (ne11 == 1 ? 0 : id);
int64_t i12 = iid1;
int64_t i1 = id;
int64_t i2 = i12;
void* src0_tmp_ptr = src0_original + i02*weight_stride;
void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride;
void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
// mem cpy
ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride,
ACL_MEMCPY_DEVICE_TO_DEVICE);
void* scale_buffer = (char*)weight_buffer + weight_stride;
ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride,
ACL_MEMCPY_DEVICE_TO_DEVICE);
src0_row.data = weight_buffer;
src1_row.data = src1_tmp_ptr;
dst_row.data = dst_tmp_ptr;
dst_row.src[0] = &src0_row;
dst_row.src[1] = &src1_row;
ggml_cann_mul_mat(ctx, &dst_row);
}
}
return;
}
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
const enum ggml_type type = dst->src[0]->type;
switch (type) {
@@ -2729,6 +2852,10 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
case GGML_TYPE_F16:
ggml_cann_mul_mat_id_fp(ctx, dst);
break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
ggml_cann_mul_mat_id_quant(ctx, dst);
break;
default:
GGML_ABORT("Unsupported type for mul_mat_id");
break;

View File

@@ -2035,6 +2035,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return true;
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
#ifdef ASCEND_310P
// Q4 && Q8 per group is not suppor on 310p device
return false;
#endif
// only support contiguous for quantized types.
return ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]);
default:
return false;
}

View File

@@ -2202,6 +2202,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
{
@@ -3483,6 +3484,19 @@ void ggml_cpu_init(void) {
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
#ifdef GGML_USE_OPENMP
//if (!getenv("OMP_WAIT_POLICY")) {
// // set the wait policy to active, so that OpenMP threads don't sleep
// putenv("OMP_WAIT_POLICY=active");
//}
if (!getenv("KMP_BLOCKTIME")) {
// set the time to wait before sleeping a thread
// this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
putenv("KMP_BLOCKTIME=200"); // 200ms
}
#endif
}
#if defined(__ARM_ARCH)

View File

@@ -2691,6 +2691,109 @@ static void ggml_compute_forward_gelu(
}
}
// ggml_compute_forward_gelu_erf
static void ggml_compute_forward_gelu_erf_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_are_same_shape(src0, dst));
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_gelu_erf_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
GGML_UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_gelu_erf_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_are_same_shape(src0, dst));
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_gelu_erf_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
const float v = GGML_FP16_TO_FP32(x);
GGML_UNUSED(v);
assert(!isnan(v));
assert(!isinf(v));
}
#endif
}
}
static void ggml_compute_forward_gelu_erf(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_gelu_erf_f32(params, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_gelu_erf_f16(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_gelu_quick
static void ggml_compute_forward_gelu_quick_f32(
@@ -7749,6 +7852,10 @@ void ggml_compute_forward_unary(
{
ggml_compute_forward_gelu(params, dst);
} break;
case GGML_UNARY_OP_GELU_ERF:
{
ggml_compute_forward_gelu_erf(params, dst);
} break;
case GGML_UNARY_OP_GELU_QUICK:
{
ggml_compute_forward_gelu_quick(params, dst);

View File

@@ -428,6 +428,7 @@ inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp
static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
static const float SQRT_2_INV = 0.70710678118654752440084436210484f;
inline static float ggml_gelu_f32(float x) {
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
@@ -440,6 +441,14 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp
}
}
inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
for (int i = 0; i < n; ++i) {
float xi = GGML_FP16_TO_FP32(x[i]);
float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
y[i] = GGML_FP32_TO_FP16(res);
}
}
#ifdef GGML_GELU_FP16
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
uint16_t t;
@@ -463,6 +472,13 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
}
#endif
inline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
for (int i = 0; i < n; ++i) {
float xi = x[i];
y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
}
}
inline static float ggml_gelu_quick_f32(float x) {
return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
}

View File

@@ -1,5 +1,8 @@
#include "cpy.cuh"
#include "dequantize.cuh"
#ifdef GGML_USE_MUSA
#include "ggml-musa/mudnn.cuh"
#endif // GGML_USE_MUSA
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
@@ -597,7 +600,14 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
#endif
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
#ifdef GGML_USE_MUSA
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
} else
#endif // GGML_USE_MUSA
{
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {

View File

@@ -772,7 +772,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
GGML_UNUSED(kb0);
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}

View File

@@ -2,9 +2,9 @@
#include "fattn-common.cuh"
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#ifndef GGML_USE_HIP
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#endif // GGML_USE_HIP
static __global__ void flash_attn_vec_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
@@ -48,6 +48,12 @@ static __global__ void flash_attn_vec_ext_f16(
NO_DEVICE_CODE;
return;
}
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
if (ncols > 1) {
NO_DEVICE_CODE;
return;
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
@@ -91,6 +97,13 @@ static __global__ void flash_attn_vec_ext_f16(
kqsum_shared[j][threadIdx.x] = 0.0f;
}
}
__shared__ half maskh_shared[ncols*D];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskh_shared[j*D + tid] = 0.0f;
}
__syncthreads();
// Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -175,6 +188,36 @@ static __global__ void flash_attn_vec_ext_f16(
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
}
__syncthreads();
// When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
// In such cases, skip the KV slice.
// On AMD __all_sync would not work correctly because it assumes a warp size of 64.
#ifndef GGML_USE_HIP
bool skip = true;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
skip = skip && isinf(tmp.x) && isinf(tmp.y);
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
__syncthreads();
continue;
}
#endif // GGML_USE_HIP
}
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
// Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
@@ -202,7 +245,7 @@ static __global__ void flash_attn_vec_ext_f16(
sum = logit_softcap*tanhf(sum);
}
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
sum += maskh_shared[j*D + i_KQ];
if (ncols == 1) {
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -335,7 +378,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] == 1) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;

View File

@@ -2,9 +2,9 @@
#include "fattn-common.cuh"
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#ifndef GGML_USE_HIP
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#endif // GGML_USE_HIP
static __global__ void flash_attn_vec_ext_f32(
const char * __restrict__ Q,
const char * __restrict__ K,
@@ -60,6 +60,12 @@ static __global__ void flash_attn_vec_ext_f32(
NO_DEVICE_CODE;
return;
}
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
if (ncols > 1) {
NO_DEVICE_CODE;
return;
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
@@ -104,6 +110,13 @@ static __global__ void flash_attn_vec_ext_f32(
kqsum_shared[j][threadIdx.x] = 0.0f;
}
}
__shared__ float maskf_shared[ncols*D];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskf_shared[j*D + tid] = 0.0f;
}
__syncthreads();
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -181,6 +194,35 @@ static __global__ void flash_attn_vec_ext_f32(
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
}
__syncthreads();
// When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
// In such cases, skip the KV slice.
// On AMD __all_sync would not work correctly because it assumes a warp size of 64.
#ifndef GGML_USE_HIP
bool skip = true;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
skip = skip && isinf(maskf_shared[j*D + i]);
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
__syncthreads();
continue;
}
#endif // GGML_USE_HIP
}
float kqmax_new_arr[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@@ -204,7 +246,7 @@ static __global__ void flash_attn_vec_ext_f32(
sum = logit_softcap*tanhf(sum);
}
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
sum += maskf_shared[j*D + i_KQ];
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
@@ -326,7 +368,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] == 1) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;

View File

@@ -2192,6 +2192,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_SILU:
ggml_cuda_op_silu(ctx, dst);
break;
case GGML_UNARY_OP_GELU_ERF:
ggml_cuda_op_gelu_erf(ctx, dst);
break;
case GGML_UNARY_OP_GELU_QUICK:
ggml_cuda_op_gelu_quick(ctx, dst);
break;
@@ -2977,6 +2980,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:

View File

@@ -23,6 +23,12 @@ static __device__ __forceinline__ float op_gelu(float x) {
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
static __device__ __forceinline__ float op_gelu_erf(float x) {
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));
}
static __device__ __forceinline__ float op_gelu_quick(float x) {
const float GELU_QUICK_COEF = -1.702f;
@@ -134,6 +140,10 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_gelu>(ctx, dst);
}
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_gelu_erf>(ctx, dst);
}
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
}
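For reference, op_gelu_erf computes the exact GELU, 0.5*x*(1 + erf(x/sqrt(2))), while the existing op_gelu uses the tanh approximation. A minimal host-side sketch comparing the two reference formulas (plain C++ with made-up helper names, not the CUDA kernels themselves):

    #include <cmath>
    #include <cstdio>

    // Exact GELU, same formula as op_gelu_erf above.
    static float gelu_erf_ref(float x) {
        const float SQRT_2_INV = 0.70710678118654752440f;
        return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));
    }

    // Tanh-based approximation, same formula as the existing op_gelu.
    static float gelu_tanh_ref(float x) {
        const float SQRT_2_OVER_PI = 0.79788456080286535588f;
        const float GELU_COEF_A    = 0.044715f;
        return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
    }

    int main() {
        for (float x : {-3.0f, -1.0f, 0.0f, 1.0f, 3.0f}) {
            printf("x=% .1f  erf=% .6f  tanh=% .6f\n", x, gelu_erf_ref(x), gelu_tanh_ref(x));
        }
        return 0;
    }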

View File

@@ -30,6 +30,8 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -149,6 +149,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SIGMOID,
GGML_METAL_KERNEL_TYPE_GELU,
GGML_METAL_KERNEL_TYPE_GELU_4,
GGML_METAL_KERNEL_TYPE_GELU_ERF,
GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
GGML_METAL_KERNEL_TYPE_SILU,
@@ -1103,6 +1105,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
@@ -1613,6 +1617,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_ELU:
@@ -2251,6 +2256,25 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_GELU_ERF:
{
int64_t n = ggml_nelements(dst);
id<MTLComputePipelineState> pipeline = nil;
if (n % 4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
n /= 4;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_GELU_QUICK:
{
int64_t n = ggml_nelements(dst);

View File

@@ -856,6 +856,7 @@ kernel void kernel_tanh(
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
kernel void kernel_gelu(
device const float * src0,
@@ -897,6 +898,42 @@ kernel void kernel_gelu_quick_4(
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
return sign_x * y;
}
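The polynomial above is the Abramowitz & Stegun 7.1.26 rational approximation, whose published absolute error bound is about 1.5e-7. A small host-side check against the libm erf (plain C++, assuming only <cmath>; the helper name is made up):

    #include <cmath>
    #include <cstdio>

    // Host-side copy of erf_approx() above, for checking against erff().
    static float erf_approx_ref(float x) {
        const float p  =  0.3275911f;
        const float a1 =  0.254829592f, a2 = -0.284496736f, a3 = 1.421413741f;
        const float a4 = -1.453152027f, a5 =  1.061405429f;
        const float sign_x = x < 0.0f ? -1.0f : 1.0f;
        x = fabsf(x);
        const float t = 1.0f / (1.0f + p*x);
        const float y = 1.0f - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t * expf(-x*x);
        return sign_x * y;
    }

    int main() {
        float max_err = 0.0f;
        for (float x = -4.0f; x <= 4.0f; x += 1.0f/256.0f) {
            max_err = fmaxf(max_err, fabsf(erf_approx_ref(x) - erff(x)));
        }
        printf("max |erf_approx - erf| = %g\n", max_err); // expected around 1e-7..1e-6 in float
        return 0;
    }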
kernel void kernel_gelu_erf(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
}
kernel void kernel_gelu_erf_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
}
kernel void kernel_silu(
device const float * src0,
device float * dst,

View File

@@ -27,12 +27,15 @@ if (MUSAToolkit_FOUND)
file(GLOB GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-musa/*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
if (GGML_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
@@ -62,7 +65,9 @@ if (MUSAToolkit_FOUND)
)
# TODO: do not use CUDA definitions for MUSA
target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
if (NOT GGML_BACKEND_DL)
target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
endif()
add_compile_definitions(GGML_USE_MUSA)
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
@@ -92,9 +97,10 @@ if (MUSAToolkit_FOUND)
endif()
if (GGML_STATIC)
# TODO: mudnn does not provide static libraries yet
target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
else()
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
endif()
if (GGML_CUDA_NO_VMM)

ggml/src/ggml-musa/mudnn.cu (new file, 112 lines)
View File

@@ -0,0 +1,112 @@
#include <mutex>
#include <mudnn.h>
#include "mudnn.cuh"
namespace mudnn = musa::dnn;
// Returns a human-readable error string for mudnn::Status
const char* mudnnGetErrorString(mudnn::Status err) {
switch (err) {
case mudnn::Status::SUCCESS:
return "Success";
case mudnn::Status::INVALID_PARAMETER:
return "Invalid parameter";
case mudnn::Status::NOT_INITIALIZED:
return "Not initialized";
case mudnn::Status::ALLOC_FAILED:
return "Allocation failed";
case mudnn::Status::NOT_SUPPORTED:
return "Not supported";
case mudnn::Status::INTERNAL_ERROR:
return "Internal error";
case mudnn::Status::ARCH_MISMATCH:
return "Architecture mismatch";
case mudnn::Status::EXECUTION_FAILED:
return "Execution failed";
default:
return "Unknown mudnn status";
}
}
// Error checking macro for MUDNN calls
#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString)
namespace {
// Thread-safe cache for mudnn::Handle objects per device
std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;
std::mutex handle_cache_mutex;
mudnn::Handle* get_cached_handle(int device_id) {
std::lock_guard<std::mutex> lock(handle_cache_mutex);
auto it = handle_cache.find(device_id);
if (it != handle_cache.end()) {
return it->second.get();
}
auto handle = std::make_unique<mudnn::Handle>(device_id);
mudnn::Handle* handle_ptr = handle.get();
handle_cache[device_id] = std::move(handle);
return handle_ptr;
}
}
// Extracts dimensions and strides from a ggml_tensor
int get_ggml_dims_and_strides(const ggml_tensor* tensor,
std::vector<int64_t>& dims,
std::vector<int64_t>& strides) {
const int ndims = ggml_n_dims(tensor);
const size_t element_size = ggml_element_size(tensor);
dims.resize(ndims);
strides.resize(ndims);
for (int i = 0; i < ndims; ++i) {
dims[i] = tensor->ne[i];
strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);
}
return ndims;
}
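To make the stride conversion concrete, a small standalone sketch with a hypothetical tensor (not taken from the diff): a contiguous F32 tensor with ne = {8, 4, 2, 1} has byte strides nb = {4, 32, 128, 256}, so the element strides handed to muDNN are nb[i] / element_size.

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical contiguous F32 tensor: ne = {8, 4, 2, 1}, element size 4 bytes.
        const int64_t ne[4] = {8, 4, 2, 1};
        const int64_t nb[4] = {4, 4*8, 4*8*4, 4*8*4*2};   // byte strides as ggml lays them out
        const int64_t element_size = 4;
        const int     ndims = 3;                          // ggml_n_dims() ignores trailing dims of size 1
        for (int i = 0; i < ndims; ++i) {
            printf("dim %d: size %lld, element stride %lld\n",
                   i, (long long) ne[i], (long long) (nb[i] / element_size));
        }
        return 0;                                         // prints element strides 1, 8, 32
    }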
// Converts ggml_type to mudnn::Tensor::Type
mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return mudnn::Tensor::Type::FLOAT;
case GGML_TYPE_F16:
return mudnn::Tensor::Type::HALF;
// TODO: Add support for other types
default:
MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);
}
return mudnn::Tensor::Type::FLOAT; // Default fallback
}
// Asynchronous memory copy using mudnn::Unary::IDENTITY
musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {
mudnn::Tensor tensor_dst, tensor_src;
MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type)));
MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type)));
std::vector<int64_t> dims, strides;
const int ndims = get_ggml_dims_and_strides(src, dims, strides);
MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data()));
MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data()));
MUDNN_CHECK(tensor_dst.SetAddr(dst->data));
MUDNN_CHECK(tensor_src.SetAddr(src->data));
mudnn::Unary op;
MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));
MUDNN_CHECK(op.SetAlpha(0.0f));
MUDNN_CHECK(op.SetBeta(0.0f));
mudnn::Handle* handle = get_cached_handle(ctx.device);
MUDNN_CHECK(handle->SetStream(ctx.stream()));
MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src));
return musaSuccess;
}

View File

@@ -0,0 +1,12 @@
#pragma once
#include "../include/ggml.h"
#include "../ggml-cuda/common.cuh"
// Asynchronously copies data from src tensor to dst tensor using the provided context.
// Returns a musaError_t indicating success or failure.
musaError_t mudnnMemcpyAsync(
ggml_backend_cuda_context &ctx,
const ggml_tensor *dst,
const ggml_tensor *src
);

View File

@@ -27,6 +27,7 @@
#include <cmath>
#include <memory>
#include <charconv>
#include <mutex>
#undef MIN
#undef MAX
@@ -74,6 +75,7 @@ struct ggml_cl_version {
cl_uint minor = 0;
};
struct ggml_cl_compiler_version {
ADRENO_CL_COMPILER_TYPE type;
int major = -1;
@@ -91,6 +93,14 @@ struct ggml_cl_compiler_version {
}
};
static size_t align_to(size_t value, size_t to_alignment) {
GGML_ASSERT(to_alignment && "Invalid alignment (must be non-zero)");
GGML_ASSERT((to_alignment & (to_alignment - 1)) == 0 && "to_alignment must be power-of-two");
return ((value + to_alignment - 1) / to_alignment) * to_alignment;
}
// Parses a version string of the form "XX.YY ". On error, returns a ggml_cl_version with all zeroes.
static ggml_cl_version parse_cl_version(std::string_view str) {
size_t major_str_begin = 0;
@@ -221,13 +231,25 @@ static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *drive
return { type, major, minor, patch };
}
struct ggml_backend_opencl_context;
// backend device context
struct ggml_backend_opencl_device_context {
cl_platform_id platform;
std::string platform_name;
cl_device_id device;
std::string device_name;
cl_device_id device;
std::string device_name;
cl_device_type device_type;
std::string device_version;
// Initialized by ggml_cl2_init().
ggml_backend_opencl_context * backend_ctx = nullptr;
// Initialized by ggml_backend_opencl_device_get_buffer_type()
ggml_backend_buffer_type buffer_type;
cl_context context = nullptr;
};
// backend context
@@ -248,6 +270,8 @@ struct ggml_backend_opencl_context {
int adreno_wave_size;
cl_bool non_uniform_workgroups;
cl_context context;
cl_command_queue queue;
@@ -344,15 +368,8 @@ struct ggml_backend_opencl_context {
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
};
static ggml_backend_device g_ggml_backend_opencl_device;
static ggml_backend_opencl_device_context g_ggml_ctx_dev_main {
/*.platform =*/ nullptr,
/*.platform_name =*/ "",
/*.device =*/ nullptr,
/*.device_name =*/ "",
};
static int ggml_backend_opencl_n_devices = 0;
// All registered devices with a default device in the front.
static std::vector<ggml_backend_device> g_ggml_backend_opencl_devices;
// Profiling
#ifdef GGML_OPENCL_PROFILING
@@ -1107,25 +1124,19 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT("\n");
}
static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
static bool initialized = false;
static ggml_backend_opencl_context *backend_ctx = nullptr;
// XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
// XXX static bool initialized = false;
// XXX static ggml_backend_opencl_context *backend_ctx = nullptr;
if (initialized) {
return backend_ctx;
}
static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev);
ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *)dev->context;
GGML_ASSERT(dev_ctx);
GGML_ASSERT(dev_ctx->platform == nullptr);
GGML_ASSERT(dev_ctx->device == nullptr);
GGML_ASSERT(backend_ctx == nullptr);
namespace /* anonymous */ {
extern struct ggml_backend_device_i ggml_backend_opencl_device_i;
}
initialized = true;
backend_ctx = new ggml_backend_opencl_context();
backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN;
cl_int err;
// Look for available and suitable devices.
static std::vector<ggml_backend_device> ggml_opencl_probe_devices(ggml_backend_reg * reg) {
std::vector<ggml_backend_device> found_devices;
#ifdef GGML_OPENCL_PROFILING
GGML_LOG_INFO("ggml_opencl: OpenCL profiling enabled\n");
@@ -1158,11 +1169,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
struct cl_device devices[NDEV];
unsigned n_devices = 0;
struct cl_device * default_device = NULL;
unsigned default_platform_number = 0;
cl_platform_id platform_ids[NPLAT];
if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) {
GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n");
return backend_ctx;
return found_devices;
}
for (unsigned i = 0; i < n_platforms; i++) {
@@ -1197,19 +1209,22 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
}
if (default_device == NULL && p->default_device != NULL) {
default_device = p->default_device;
default_device = p->default_device;
default_platform_number = i;
}
}
if (n_devices == 0) {
GGML_LOG_ERROR("ggml_opencl: could find any OpenCL devices.\n");
return backend_ctx;
return found_devices;
}
char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
char * user_device_string = getenv("GGML_OPENCL_DEVICE");
int user_platform_number = -1;
int user_device_number = -1;
char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
char * user_device_string = getenv("GGML_OPENCL_DEVICE");
int user_platform_number = -1;
int user_device_number = -1;
cl_device * candidate_devices = nullptr;
unsigned n_candidate_devices = 0;
unsigned n;
if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) {
@@ -1224,12 +1239,11 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
GGML_LOG_ERROR("ggml_opencl: invalid device number %d\n", user_device_number);
exit(1);
}
default_device = &platform->devices[user_device_number];
default_device = &platform->devices[user_device_number];
candidate_devices = platform->devices;
n_candidate_devices = platform->n_devices;
} else {
struct cl_device * selected_devices = devices;
unsigned n_selected_devices = n_devices;
// Choose a platform by matching a substring.
if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) {
for (unsigned i = 0; i < n_platforms; i++) {
struct cl_platform * p = &platforms[i];
@@ -1244,20 +1258,20 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
exit(1);
}
}
if (user_platform_number != -1) {
struct cl_platform * p = &platforms[user_platform_number];
selected_devices = p->devices;
n_selected_devices = p->n_devices;
default_device = p->default_device;
if (n_selected_devices == 0) {
GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
exit(1);
}
int platform_idx = user_platform_number != -1 ? user_platform_number : default_platform_number;
struct cl_platform * p = &platforms[platform_idx];
candidate_devices = p->devices;
n_candidate_devices = p->n_devices;
default_device = p->default_device;
if (n_candidate_devices == 0) {
GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
exit(1);
}
if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) {
for (unsigned i = 0; i < n_selected_devices; i++) {
struct cl_device * d = &selected_devices[i];
for (unsigned i = 0; i < n_candidate_devices; i++) {
struct cl_device * d = &candidate_devices[i];
if (strstr(d->name, user_device_string) != NULL) {
user_device_number = d->number;
break;
@@ -1269,71 +1283,145 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
}
}
if (user_device_number != -1) {
selected_devices = &devices[user_device_number];
n_selected_devices = 1;
default_device = &selected_devices[0];
candidate_devices = &devices[user_device_number];
n_candidate_devices = 1;
default_device = &candidate_devices[0];
}
GGML_ASSERT(n_selected_devices > 0);
GGML_ASSERT(n_candidate_devices > 0);
if (default_device == NULL) {
default_device = &selected_devices[0];
default_device = &candidate_devices[0];
}
}
GGML_LOG_INFO("ggml_opencl: selecting platform: '%s'\n", default_device->platform->name);
GGML_LOG_INFO("ggml_opencl: selecting device: '%s (%s)'\n", default_device->name, default_device->version);
if (default_device->type != CL_DEVICE_TYPE_GPU) {
GGML_LOG_WARN("ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name);
GGML_ASSERT(n_candidate_devices != 0 && candidate_devices);
// Put the default device in front.
for (unsigned i = 1; i < n_candidate_devices; i++) {
if (&candidate_devices[i] == default_device) {
std::swap(candidate_devices[0], candidate_devices[i]);
default_device = &candidate_devices[0];
break;
}
}
dev_ctx->platform = default_device->platform->id;
dev_ctx->device = default_device->id;
backend_ctx->device = default_device->id;
GGML_LOG_INFO("ggml_opencl: selected platform: '%s'\n", default_device->platform->name);
if (strstr(default_device->name, "Adreno") ||
strstr(default_device->name, "Qualcomm") ||
strstr(default_device->version, "Adreno")) {
std::vector<cl_device_id> device_ids;
for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) {
device_ids.push_back(dev->id);
}
cl_int err;
cl_context shared_context;
cl_context_properties properties[] = { (intptr_t) CL_CONTEXT_PLATFORM, (intptr_t) default_device->platform->id, 0 };
CL_CHECK(
(shared_context = clCreateContext(properties, device_ids.size(), device_ids.data(), NULL, NULL, &err), err));
for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) {
GGML_LOG_INFO("\nggml_opencl: device: '%s (%s)'\n", dev->name, dev->version);
auto dev_ctx = std::unique_ptr<ggml_backend_opencl_device_context>(new ggml_backend_opencl_device_context{
/*.platform =*/dev->platform->id,
/*.platform_name =*/dev->platform->name,
/*.device =*/dev->id,
/*.device_name =*/dev->name,
/*.device_type =*/dev->type,
/*.device_version =*/dev->version,
/*.backend_ctx =*/nullptr,
/*.buffer_type =*/{},
/*.context =*/shared_context,
});
found_devices.push_back(ggml_backend_device{
/* .iface = */ ggml_backend_opencl_device_i,
/* .reg = */ reg,
/* .context = */ dev_ctx.get(),
});
if (!ggml_cl2_init(&found_devices.back())) {
found_devices.pop_back();
GGML_LOG_INFO("ggml_opencl: drop unsupported device.\n");
continue;
}
dev_ctx.release();
}
if (found_devices.size()) {
auto * dev_ctx = static_cast<ggml_backend_opencl_device_context *>(found_devices.front().context);
GGML_LOG_INFO("ggml_opencl: default device: '%s (%s)'\n", dev_ctx->device_name.c_str(),
dev_ctx->device_version.c_str());
if (dev_ctx->device_type != CL_DEVICE_TYPE_GPU) {
GGML_LOG_WARN("ggml_opencl: warning, the default device is not a GPU: '%s'.\n",
dev_ctx->device_name.c_str());
}
}
return found_devices;
}
// Initialize device if it is supported (returns nullptr if it is not).
static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
GGML_ASSERT(dev);
GGML_ASSERT(dev->context);
ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context;
GGML_ASSERT(dev_ctx->platform);
GGML_ASSERT(dev_ctx->device);
if (dev_ctx->backend_ctx) {
return dev_ctx->backend_ctx;
}
auto backend_ctx = std::make_unique<ggml_backend_opencl_context>();
backend_ctx->device = dev_ctx->device;
backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN;
if (strstr(dev_ctx->device_name.c_str(), "Adreno") ||
strstr(dev_ctx->device_name.c_str(), "Qualcomm") ||
strstr(dev_ctx->device_version.c_str(), "Adreno")) {
backend_ctx->gpu_family = GPU_FAMILY::ADRENO;
// Usually device version contains the detailed device name
backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->version);
backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str());
if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) {
backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->name);
backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str());
}
// Use wave size of 64 for all Adreno GPUs.
backend_ctx->adreno_wave_size = 64;
} else if (strstr(default_device->name, "Intel")) {
} else if (strstr(dev_ctx->device_name.c_str(), "Intel")) {
backend_ctx->gpu_family = GPU_FAMILY::INTEL;
} else {
GGML_LOG_ERROR("Unsupported GPU: %s\n", default_device->name);
GGML_LOG_ERROR("Unsupported GPU: %s\n", dev_ctx->device_name.c_str());
backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN;
return backend_ctx;
return nullptr;
}
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) {
GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; "
"run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n");
return backend_ctx;
return nullptr;
}
#endif
// Populate backend device name
dev_ctx->platform_name = default_device->platform->name;
dev_ctx->device_name = default_device->name;
backend_ctx->device_name = default_device->name;
backend_ctx->device_name = dev_ctx->device_name;
// A local ref of cl_device_id for convenience
cl_device_id device = backend_ctx->device;
ggml_cl_version platform_version = get_opencl_platform_version(default_device->platform->id);
ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform);
// Check device OpenCL version, OpenCL 2.0 or above is required
ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device);
if (opencl_c_version.major < 2) {
GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n");
return backend_ctx;
return nullptr;
}
// Check driver version
@@ -1364,7 +1452,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
// fp16 is required
if (!backend_ctx->fp16_support) {
GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n");
return backend_ctx;
return nullptr;
}
// If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes
@@ -1373,7 +1461,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
strstr(ext_buffer, "cl_intel_subgroups") == NULL) {
GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) "
"(note that subgroups is an optional feature in OpenCL 3.0)\n");
return backend_ctx;
return nullptr;
}
cl_uint base_align_in_bits;
@@ -1397,6 +1485,15 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n",
svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false");
if (opencl_c_version.major >= 3) {
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool),
&backend_ctx->non_uniform_workgroups, 0));
} else {
GGML_ASSERT(opencl_c_version.major == 2);
// Non-uniform workgroup sizes are a mandatory feature in v2.x.
backend_ctx->non_uniform_workgroups = true;
}
// Print out configurations
#ifdef GGML_OPENCL_SOA_Q
GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n");
@@ -1406,14 +1503,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n");
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
cl_context_properties properties[] = {
(intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)dev_ctx->platform, 0
};
CL_CHECK((backend_ctx->context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err));
cl_int err;
// A local ref of cl_context for convenience
cl_context context = backend_ctx->context;
cl_context context = backend_ctx->context = dev_ctx->context;
//CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err),
// (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? err :
@@ -1426,7 +1519,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err));
// Load kernels
load_cl_kernels(backend_ctx, opencl_c_version);
load_cl_kernels(backend_ctx.get(), opencl_c_version);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
// Allocate intermediate buffers and images
@@ -1456,10 +1549,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
// For now we support a single device
ggml_backend_opencl_n_devices = 1;
return backend_ctx;
dev_ctx->backend_ctx = backend_ctx.release();
return dev_ctx->backend_ctx;
}
static void ggml_cl2_free(void) {
@@ -1664,10 +1755,46 @@ static void ggml_backend_opencl_synchronize(ggml_backend_t backend) {
GGML_UNUSED(backend);
}
// Synchronizes the device of 'backend_ctx' with the other devices so that
// commands enqueued to it won't start until commands on the other devices
// have completed.
static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) {
if (g_ggml_backend_opencl_devices.size() < 2)
return; // No other devices to synchronize with.
std::vector<cl_event> events;
events.reserve(g_ggml_backend_opencl_devices.size());
for (ggml_backend_device & backend_dev : g_ggml_backend_opencl_devices) {
auto * other_backend_ctx = ggml_cl2_init(&backend_dev);
if (backend_ctx != other_backend_ctx) {
cl_event ev;
CL_CHECK(clEnqueueMarkerWithWaitList(other_backend_ctx->queue, 0, nullptr, &ev));
CL_CHECK(clFlush(other_backend_ctx->queue));
events.push_back(ev);
}
}
CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, events.size(), events.data(), nullptr));
for (auto ev : events) {
CL_CHECK(clReleaseEvent(ev));
}
}
static void sync_with_other_backends(ggml_backend_t backend) {
auto * backend_ctx = static_cast<ggml_backend_opencl_context *>(backend->context);
sync_with_other_backends(backend_ctx);
}
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
// NOTE: this may oversynchronize by synchronizing with
// backends/devices which don't compute 'cgraph's
// dependencies.
sync_with_other_backends(backend);
if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
@@ -2058,15 +2185,16 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
// The original tensor memory is divided into scales and quants, i.e.,
// we first store scales, then quants.
// Create subbuffer for scales.
region.origin = extra_orig->offset + tensor->view_offs + offset;
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
region.size = size_d;
extra->d = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
auto previous_origin = region.origin;
// Create subbuffer for quants.
region.origin = extra_orig->offset + tensor->view_offs + offset + size_d;
region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);
region.size = size_q;
extra->q = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
@@ -2271,8 +2399,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
cl_context context = backend_ctx->context;
cl_command_queue queue = backend_ctx->queue;
// Make sure all previously submitted commands are finished.
CL_CHECK(clFinish(queue));
// Make sure all previously submitted commands in other devices are finished.
sync_with_other_backends(backend_ctx);
#ifdef GGML_OPENCL_SOA_Q
// In end-to-end runs, get_tensor is usually used to get back the logits,
@@ -2376,13 +2504,8 @@ static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_b
}
static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
// FIXME: not thread safe, device may not be initialized yet
static cl_uint alignment = -1;
if (alignment == (cl_uint)-1) {
ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device);
alignment = backend_ctx->alignment;
}
return alignment;
ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device);
return backend_ctx->alignment;
}
static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
@@ -2409,16 +2532,6 @@ static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = {
/* .is_host = */ NULL,
};
ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type() {
static ggml_backend_buffer_type buffer_type = {
/* .iface = */ ggml_backend_opencl_buffer_type_interface,
/* .device = */ &g_ggml_backend_opencl_device,
/* .context = */ nullptr,
};
return &buffer_type;
}
//
// backend device
//
@@ -2476,9 +2589,15 @@ static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, co
}
static ggml_backend_buffer_type_t ggml_backend_opencl_device_get_buffer_type(ggml_backend_dev_t dev) {
return ggml_backend_opencl_buffer_type();
auto * dev_ctx = static_cast<ggml_backend_opencl_device_context *>(dev->context);
GGML_UNUSED(dev);
dev_ctx->buffer_type = ggml_backend_buffer_type{
/* .iface = */ ggml_backend_opencl_buffer_type_interface,
/* .device = */ dev,
/* .context = */ nullptr,
};
return &dev_ctx->buffer_type;
}
static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
@@ -2494,12 +2613,21 @@ static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const
}
static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_opencl_buffer_type_get_name;
// Bail out if 'dev' or 'buft' is not an object belonging to this backend.
if (dev->iface.get_name != ggml_backend_opencl_device_get_name ||
buft->iface.get_name != ggml_backend_opencl_buffer_type_get_name) {
return false;
}
GGML_UNUSED(dev);
// Check cl_context is the same. clEnqueue* commands may not use
// buffers from another cl_context.
ggml_backend_opencl_context * backend_ctx0 = ggml_cl2_init(dev);
ggml_backend_opencl_context * backend_ctx1 = ggml_cl2_init(buft->device);
return backend_ctx0->context == backend_ctx1->context;
}
static struct ggml_backend_device_i ggml_backend_opencl_device_i = {
namespace /* anonymous */ {
struct ggml_backend_device_i ggml_backend_opencl_device_i = {
/* .get_name = */ ggml_backend_opencl_device_get_name,
/* .get_description = */ ggml_backend_opencl_device_get_description,
/* .get_memory = */ ggml_backend_opencl_device_get_memory,
@@ -2516,6 +2644,7 @@ static struct ggml_backend_device_i ggml_backend_opencl_device_i = {
/* .event_free = */ NULL,
/* .event_synchronize = */ NULL,
};
}
// Backend registry
@@ -2526,15 +2655,15 @@ static const char * ggml_backend_opencl_reg_get_name(ggml_backend_reg_t reg) {
}
static size_t ggml_backend_opencl_reg_device_count(ggml_backend_reg_t reg) {
return ggml_backend_opencl_n_devices;
return g_ggml_backend_opencl_devices.size();
GGML_UNUSED(reg);
}
static ggml_backend_dev_t ggml_backend_opencl_reg_device_get(ggml_backend_reg_t reg, size_t index) {
GGML_ASSERT(index == 0);
GGML_ASSERT(index < ggml_backend_opencl_reg_device_count(reg));
return &g_ggml_backend_opencl_device;
return &g_ggml_backend_opencl_devices[index];
GGML_UNUSED(reg);
GGML_UNUSED(index);
@@ -2548,27 +2677,23 @@ static struct ggml_backend_reg_i ggml_backend_opencl_reg_i = {
};
ggml_backend_reg_t ggml_backend_opencl_reg(void) {
// TODO: make this thread-safe somehow?
static std::mutex mutex;
static ggml_backend_reg reg;
static bool initialized = false;
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
reg = ggml_backend_reg {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_opencl_reg_i,
/* .context = */ NULL,
};
g_ggml_backend_opencl_device = ggml_backend_device {
/* .iface = */ ggml_backend_opencl_device_i,
/* .reg = */ &reg,
/* .context = */ &g_ggml_ctx_dev_main,
};
ggml_cl2_init(&g_ggml_backend_opencl_device);
initialized = true;
if (initialized) {
return &reg;
}
initialized = true;
g_ggml_backend_opencl_devices = ggml_opencl_probe_devices(&reg);
reg = ggml_backend_reg{
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_opencl_reg_i,
/* .context = */ NULL,
};
return &reg;
}
@@ -2942,14 +3067,19 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
} else {
unsigned int nth = MIN(64, ne0);
@@ -3077,14 +3207,19 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
} else {
unsigned int nth = MIN(64, ne0);
@@ -3233,14 +3368,19 @@ static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
}
@@ -3273,14 +3413,19 @@ static void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
}
@@ -3320,14 +3465,19 @@ static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, cons
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
}
@@ -4230,14 +4380,19 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
}
@@ -4418,14 +4573,19 @@ static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * sr
size_t global_work_size[] = {(size_t)ne00, (size_t)ne01, (size_t)ne02};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (ne00 % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
}
}

View File

@@ -3740,7 +3740,7 @@ static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
data, (const char *)tensor->data + offset, size).wait()));
data, (const char *)tensor->data + offset, size)));
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -3760,7 +3760,7 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
*/
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
dst->data, src->data, ggml_nbytes(dst)).wait()));
dst->data, src->data, ggml_nbytes(dst))));
return true;
}
@@ -3809,11 +3809,43 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
}
}
#ifdef GGML_SYCL_GRAPH
static bool check_graph_compatibility(ggml_cgraph * cgraph) {
if (ggml_sycl_info().device_count > 1) {
// A sycl_ex::command_graph object can only be created for a single device
GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__);
return false;
}
for (int i = 0; i < cgraph->n_nodes; i++) {
const ggml_op node_op = cgraph->nodes[i]->op;
switch (node_op) {
default:
break;
case GGML_OP_CONCAT:
// ggml_sycl_op_concat() does a blocking host wait after memcpy operations,
// but wait() can't be called on the events returned by a queue recording
// to a graph.
[[fallthrough]];
case GGML_OP_MUL_MAT_ID:
// ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after
// submitting a memcpy operation, but wait() can't be called on a queue that
// is recording to a graph.
GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
ggml_op_name(node_op));
return false;
}
}
return true;
}
#endif
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
#ifdef GGML_SYCL_GRAPH
if (!g_ggml_sycl_disable_graph) {
bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph);
if (use_sycl_graph) {
const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
if (!graph_support) {
GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);

View File

@@ -2804,23 +2804,29 @@ static vk_device ggml_vk_get_device(size_t idx) {
pipeline_robustness = true;
} else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
device->subgroup_size_control = true;
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT")) {
device->coopmat_support = true;
device->coopmat_m = 0;
device->coopmat_n = 0;
device->coopmat_k = 0;
#endif
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
coopmat2_support = true;
#endif
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
device->integer_dot_product = true;
#endif
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
bfloat16_support = true;
#endif
}
}
@@ -4513,6 +4519,8 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_l : mmp->l;
GGML_UNUSED(src1_type);
}
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
@@ -4668,6 +4676,19 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
}
}
if (src->type == to) {
// Copy two or four bytes at a time, depending on the type size.
// For quantized types, we scale by block size/type size. But
// this path is also used for bf16->bf16 for example, where the
// type size must be exactly 2 or 4.
GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
if ((ggml_type_size(src->type) % 4) == 0) {
return ctx->device->pipeline_contig_cpy_f32_f32;
} else {
return ctx->device->pipeline_contig_cpy_f16_f16;
}
}
std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
GGML_ABORT("fatal error");
}
@@ -6729,7 +6750,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_UNARY:
case GGML_OP_CONV_2D_DW:
{
const uint32_t ne = ggml_nelements(dst);
uint32_t ne = ggml_nelements(dst);
if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
// Convert from number of logical elements to 2- or 4-byte units.
ne /= ggml_blck_size(src0->type);
if ((ggml_type_size(src0->type) % 4) == 0) {
ne *= ggml_type_size(src0->type) / 4;
} else {
ne *= ggml_type_size(src0->type) / 2;
}
}
if (ne > 262144) {
elements = { 512, 512, CEIL_DIV(ne, 262144) };
} else if (ne > 512) {
@@ -7279,8 +7309,19 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
uint32_t ne = (uint32_t)ggml_nelements(src0);
if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
// Convert from number of logical elements to 2- or 4-byte units.
ne /= ggml_blck_size(src0->type);
if ((ggml_type_size(src0->type) % 4) == 0) {
ne *= ggml_type_size(src0->type) / 4;
} else {
ne *= ggml_type_size(src0->type) / 2;
}
}
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
(uint32_t)ggml_nelements(src0),
ne,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
@@ -9262,8 +9303,7 @@ static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_
try {
ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
} catch (vk::SystemError& e) {
std::cerr << "ggml_vulkan: Failed to allocate pinned memory." << std::endl;
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what());
// fallback to cpu buffer
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
}
@@ -9865,6 +9905,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
return true;
}
// We can handle copying from a type to the same type if it's
// contiguous (memcpy). We use f16 or f32 shaders to do the copy,
// so the type/block size must be a multiple of 2.
if (src0_type == src1_type &&
ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) &&
(ggml_type_size(src0_type) % 2) == 0) {
return true;
}
return false;
} break;
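A short worked example of the element-count rescaling introduced above, for a hypothetical Q4_0 self-copy (ggml_blck_size(Q4_0) is 32 elements and ggml_type_size(Q4_0) is 18 bytes per block):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t n_elements = 8192;   // hypothetical logical element count
        const uint32_t blck_size  = 32;     // ggml_blck_size(GGML_TYPE_Q4_0)
        const uint32_t type_size  = 18;     // ggml_type_size(GGML_TYPE_Q4_0), bytes per block

        uint32_t ne = n_elements / blck_size;        // 256 blocks
        if (type_size % 4 == 0) {
            ne *= type_size / 4;                     // 4-byte units -> f32 copy shader
        } else {
            ne *= type_size / 2;                     // 2-byte units -> f16 copy shader
        }
        // 18 % 4 != 0, so ne = 256 * 9 = 2304 two-byte units = 4608 bytes,
        // which is exactly 256 blocks * 18 bytes of quantized data.
        printf("copy dispatched over %u units\n", ne);
        return 0;
    }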
case GGML_OP_REPEAT:

View File

@@ -1,6 +1,6 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#include "dequant_head.comp"

View File

@@ -7,7 +7,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#if defined(DATA_A_IQ1_M)
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif
#if defined(DATA_A_BF16) && defined(COOPMAT)

View File

@@ -1099,9 +1099,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"HARDSWISH",
"HARDSIGMOID",
"EXP",
"GELU_ERF",
};
static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2501,6 +2502,20 @@ struct ggml_tensor * ggml_gelu_inplace(
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
}
// ggml_gelu_erf
struct ggml_tensor * ggml_gelu_erf(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF);
}
struct ggml_tensor * ggml_gelu_erf_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF);
}
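A minimal usage sketch for the new op (assumes the standard ggml graph-building API; backend setup, data filling and error handling are omitted):

    #include "ggml.h"

    // Builds a tiny graph that applies the erf-based GELU to an 8-element F32 tensor.
    void gelu_erf_example(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        struct ggml_tensor * y = ggml_gelu_erf(ctx, x);   // new GGML_UNARY_OP_GELU_ERF

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        // ... fill x->data and compute gf with the CPU, CUDA or Metal backend ...

        ggml_free(ctx);
    }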
// ggml_gelu_quick
struct ggml_tensor * ggml_gelu_quick(

View File

@@ -219,10 +219,13 @@ class Keys:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
class ClipVision:
class Clip:
PROJECTOR_TYPE = "clip.projector_type"
HAS_VISION_ENCODER = "clip.has_vision_encoder"
HAS_AUDIO_ENCODER = "clip.has_audio_encoder"
HAS_LLAVA_PROJECTOR = "clip.has_llava_projector"
class ClipVision:
IMAGE_SIZE = "clip.vision.image_size"
PATCH_SIZE = "clip.vision.patch_size"
EMBEDDING_LENGTH = "clip.vision.embedding_length"
@@ -243,19 +246,33 @@ class Keys:
class Projector:
SCALE_FACTOR = "clip.vision.projector.scale_factor"
class ClipAudio:
NUM_MEL_BINS = "clip.audio.num_mel_bins"
EMBEDDING_LENGTH = "clip.audio.embedding_length"
FEED_FORWARD_LENGTH = "clip.audio.feed_forward_length"
PROJECTION_DIM = "clip.audio.projection_dim"
BLOCK_COUNT = "clip.audio.block_count"
class Attention:
HEAD_COUNT = "clip.audio.attention.head_count"
LAYERNORM_EPS = "clip.audio.attention.layer_norm_epsilon"
class Projector:
STACK_FACTOR = "clip.audio.projector.stack_factor"
#
# recommended mapping of model tensor names for storage in gguf
#
class GGUFType:
MODEL = "model"
ADAPTER = "adapter"
CLIP_VISION = "clip-vision"
MODEL = "model"
ADAPTER = "adapter"
MMPROJ = "mmproj" # dummy, unused for now
class MODEL_ARCH(IntEnum):
CLIP_VISION = auto() # dummy arch for clip.cpp
MMPROJ = auto() # dummy arch for clip.cpp
LLAMA = auto()
LLAMA4 = auto()
DECI = auto()
@@ -514,10 +531,27 @@ class MODEL_TENSOR(IntEnum):
V_RESMPL_QUERY = auto() # minicpmv
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
V_MM_PATCH_MERGER = auto() # mistral small 3.1
# audio (mtmd)
A_ENC_EMBD_POS = auto()
A_ENC_CONV1D = auto()
A_PRE_NORM = auto()
A_POST_NORM = auto()
A_ENC_ATTN_Q = auto()
A_ENC_ATTN_K = auto()
A_ENC_ATTN_V = auto()
A_ENC_INPUT_NORM = auto()
A_ENC_OUTPUT = auto()
A_ENC_OUTPUT_NORM = auto()
A_ENC_FFN_UP = auto()
A_ENC_FFN_GATE = auto()
A_ENC_FFN_DOWN = auto()
A_MMPROJ = auto()
A_MM_NORM_PRE = auto()
A_MM_NORM_MID = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp
MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.LLAMA4: "llama4",
MODEL_ARCH.DECI: "deci",
@@ -776,10 +810,27 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
# audio (mtmd)
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
MODEL_TENSOR.A_POST_NORM: "a.post_ln",
MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q",
MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k",
MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v",
MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1",
MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out",
MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2",
MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up",
MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate",
MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down",
MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}",
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_ARCH.CLIP_VISION: [
MODEL_ARCH.MMPROJ: [
MODEL_TENSOR.V_MMPROJ,
MODEL_TENSOR.V_MMPROJ_FC,
MODEL_TENSOR.V_MMPROJ_MLP,
@@ -819,6 +870,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_RESMPL_QUERY,
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
MODEL_TENSOR.V_MM_PATCH_MERGER,
# audio
MODEL_TENSOR.A_ENC_EMBD_POS,
MODEL_TENSOR.A_ENC_CONV1D,
MODEL_TENSOR.A_PRE_NORM,
MODEL_TENSOR.A_POST_NORM,
MODEL_TENSOR.A_ENC_ATTN_Q,
MODEL_TENSOR.A_ENC_ATTN_K,
MODEL_TENSOR.A_ENC_ATTN_V,
MODEL_TENSOR.A_ENC_INPUT_NORM,
MODEL_TENSOR.A_ENC_OUTPUT,
MODEL_TENSOR.A_ENC_OUTPUT_NORM,
MODEL_TENSOR.A_ENC_FFN_UP,
MODEL_TENSOR.A_ENC_FFN_GATE,
MODEL_TENSOR.A_ENC_FFN_DOWN,
MODEL_TENSOR.A_MMPROJ,
MODEL_TENSOR.A_MM_NORM_PRE,
MODEL_TENSOR.A_MM_NORM_MID,
],
MODEL_ARCH.LLAMA: [
MODEL_TENSOR.TOKEN_EMBD,
@@ -2186,6 +2254,7 @@ class VisionProjectorType:
LLAMA4 = "llama4"
QWEN2VL = "qwen2vl_merger"
QWEN25VL = "qwen2.5vl_merger"
ULTRAVOX = "ultravox"
INTERNVL = "internvl"

View File

@@ -251,7 +251,7 @@ class GGUFReader:
offs += curr_size
return offs - orig_offs, aparts, data_idxs, types
# We can't deal with this one.
raise ValueError('Unknown/unhandled field type {gtype}')
raise ValueError(f'Unknown/unhandled field type {gtype}')
def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
offs = orig_offs

View File

@@ -896,7 +896,7 @@ class GGUFWriter:
def add_remove_extra_whitespaces(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value)
def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
def add_precompiled_charsmap(self, charsmap: bytes) -> None:
self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
@@ -936,12 +936,18 @@ class GGUFWriter:
# for vision models
def add_clip_has_vision_encoder(self, value: bool) -> None:
self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value)
def add_clip_has_audio_encoder(self, value: bool) -> None:
self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value)
def add_clip_projector_type(self, value: str) -> None:
self.add_string(Keys.Clip.PROJECTOR_TYPE, value)
def add_vision_projection_dim(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)
def add_vision_has_vision_encoder(self, value: bool) -> None:
self.add_bool(Keys.ClipVision.HAS_VISION_ENCODER, value)
def add_vision_patch_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)
@@ -957,9 +963,6 @@ class GGUFWriter:
def add_vision_head_count(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)
def add_vision_projector_type(self, value: str) -> None:
self.add_string(Keys.ClipVision.PROJECTOR_TYPE, value)
def add_vision_attention_layernorm_eps(self, value: float) -> None:
self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)
@@ -987,6 +990,32 @@ class GGUFWriter:
def add_vision_n_wa_pattern(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
# audio models
def add_audio_projection_dim(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value)
def add_audio_embedding_length(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value)
def add_audio_feed_forward_length(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value)
def add_audio_block_count(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value)
def add_audio_head_count(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value)
def add_audio_attention_layernorm_eps(self, value: float) -> None:
self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value)
def add_audio_num_mel_bins(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value)
def add_audio_stack_factor(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
pack_prefix = ''
if not skip_pack_prefix:
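These add_audio_* helpers mirror the existing add_vision_* writers and store the audio encoder hyperparameters under the clip.audio.* key namespace. A minimal sketch of how a conversion script might call them, assuming the in-tree gguf-py package; the output path and every numeric value below are illustrative, not taken from a real model:

# Sketch only: writes clip.audio.* metadata with the new helpers; a real
# converter would also add the encoder tensors before writing the file.
import gguf

writer = gguf.GGUFWriter("mmproj-audio.gguf", "clip")  # hypothetical output path / arch
writer.add_clip_has_audio_encoder(True)
writer.add_clip_projector_type("ultravox")
writer.add_audio_num_mel_bins(128)          # example values only
writer.add_audio_embedding_length(1280)
writer.add_audio_feed_forward_length(5120)
writer.add_audio_block_count(32)
writer.add_audio_head_count(20)
writer.add_audio_attention_layernorm_eps(1e-5)
writer.add_audio_projection_dim(4096)
writer.add_audio_stack_factor(8)
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()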

View File

@@ -1110,6 +1110,68 @@ class TensorNameMap:
MODEL_TENSOR.V_MM_PATCH_MERGER: (
"multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
),
# audio (mtmd)
MODEL_TENSOR.A_ENC_EMBD_POS: (
"audio_tower.embed_positions", # ultravox
),
MODEL_TENSOR.A_ENC_CONV1D: (
"audio_tower.conv{bid}", # ultravox
),
MODEL_TENSOR.A_PRE_NORM: (),
MODEL_TENSOR.A_POST_NORM: (
"audio_tower.layer_norm", # ultravox
),
MODEL_TENSOR.A_ENC_ATTN_Q: (
"audio_tower.layers.{bid}.self_attn.q_proj", # ultravox
),
MODEL_TENSOR.A_ENC_ATTN_K: (
"audio_tower.layers.{bid}.self_attn.k_proj", # ultravox
),
MODEL_TENSOR.A_ENC_ATTN_V: (
"audio_tower.layers.{bid}.self_attn.v_proj", # ultravox
),
MODEL_TENSOR.A_ENC_INPUT_NORM: (
"audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox
),
MODEL_TENSOR.A_ENC_OUTPUT: (
"audio_tower.layers.{bid}.self_attn.out_proj", # ultravox
),
MODEL_TENSOR.A_ENC_OUTPUT_NORM: (
"audio_tower.layers.{bid}.final_layer_norm", # ultravox
),
MODEL_TENSOR.A_ENC_FFN_UP: (
"audio_tower.layers.{bid}.fc1", # ultravox
),
MODEL_TENSOR.A_ENC_FFN_GATE: (),
MODEL_TENSOR.A_ENC_FFN_DOWN: (
"audio_tower.layers.{bid}.fc2", # ultravox
),
MODEL_TENSOR.A_MMPROJ: (
"audio.multi_modal_projector.linear_{bid}", # ultravox
),
MODEL_TENSOR.A_MM_NORM_PRE: (
"audio.multi_modal_projector.ln_pre", # ultravox
),
MODEL_TENSOR.A_MM_NORM_MID: (
"audio.multi_modal_projector.ln_mid", # ultravox
),
}
# architecture-specific block mappings
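The audio entries above follow the same convention as the vision mappings: the Hugging Face tensor name on the right is matched per block and rewritten to the short GGUF name from TENSOR_NAMES, with {bid} standing for the block index. A standalone sketch of that naming rule in plain Python (not the gguf-py TensorNameMap API; only a subset of the ultravox names is shown and the rename helper is hypothetical):

# Illustration of the {bid} convention used by the audio mappings above.
ULTRAVOX_TO_GGUF = {
    "audio_tower.layers.{bid}.self_attn.q_proj": "a.blk.{bid}.attn_q",
    "audio_tower.layers.{bid}.self_attn.k_proj": "a.blk.{bid}.attn_k",
    "audio_tower.layers.{bid}.self_attn.v_proj": "a.blk.{bid}.attn_v",
    "audio_tower.layers.{bid}.fc1": "a.blk.{bid}.ffn_up",
    "audio_tower.layers.{bid}.fc2": "a.blk.{bid}.ffn_down",
}

def rename(hf_name: str, n_blocks: int) -> str | None:
    # Try every block index and template until one matches the HF tensor name.
    for bid in range(n_blocks):
        for src, dst in ULTRAVOX_TO_GGUF.items():
            if hf_name == src.format(bid=bid):
                return dst.format(bid=bid)
    return None

print(rename("audio_tower.layers.7.self_attn.q_proj", n_blocks=32))  # a.blk.7.attn_q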

View File

@@ -608,71 +608,14 @@ extern "C" {
// KV cache
//
// TODO: start using struct llama_kv_cache
// Information associated with an individual cell in the KV cache view.
struct llama_kv_cache_view_cell {
// The position for this cell. Takes KV cache shifts into account.
// May be negative if the cell is not populated.
llama_pos pos;
};
// An updateable view of the KV cache.
struct llama_kv_cache_view {
// Number of KV cache cells. This will be the same as the context size.
int32_t n_cells;
// Maximum number of sequences that can exist in a cell. It's not an error
// if there are more sequences in a cell than this value, however they will
// not be visible in the view cells_sequences.
int32_t n_seq_max;
// Number of tokens in the cache. For example, if there are two populated
// cells, the first with 1 sequence id in it and the second with 2 sequence
// ids then you'll have 3 tokens.
int32_t token_count;
// Number of populated cache cells.
int32_t used_cells;
// Maximum contiguous empty slots in the cache.
int32_t max_contiguous;
// Index to the start of the max_contiguous slot range. Can be negative
// when cache is full.
int32_t max_contiguous_idx;
// Information for an individual cell.
struct llama_kv_cache_view_cell * cells;
// The sequences for each cell. There will be n_seq_max items per cell.
llama_seq_id * cells_sequences;
};
// Create an empty KV cache view. (use only for debugging purposes)
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
// Free a KV cache view. (use only for debugging purposes)
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
// TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
///
// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
"use llama_kv_self_n_tokens instead");
DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
"Use llama_kv_self_seq_pos_max() instead");
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
"use llama_kv_self_used_cells instead");
DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
"Use llama_kv_self_seq_pos_max() instead");
// Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_self_clear(
@@ -756,61 +699,6 @@ extern "C" {
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
DEPRECATED(LLAMA_API void llama_kv_cache_clear(
struct llama_context * ctx),
"use llama_kv_self_clear instead");
DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1),
"use llama_kv_self_seq_rm instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1),
"use llama_kv_self_seq_cp instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_kv_self_seq_keep instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta),
"use llama_kv_self_seq_add instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d),
"use llama_kv_self_seq_div instead");
DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_kv_self_seq_pos_max instead");
DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
"use llama_kv_self_defrag instead");
DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
"use llama_kv_self_can_shift instead");
DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
"use llama_kv_self_update instead");
//
// State / sessions
//

View File

@@ -0,0 +1,62 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0]['role'] == 'system' %}
{{- messages[0]['content'] }}
{%- else %}
{{- '' }}
{%- endif %}
{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0]['role'] == 'system' %}
{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" and not message.tool_calls %}
{%- set content = message.content %}
{%- if not loop.last %}
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set content = message.content %}
{%- if not loop.last %}
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{{- '<|im_start|>' + message.role }}
{%- if message.content %}
{{- '\n' + content }}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '\n<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{{- tool_call.arguments | tojson }}
{{- '}\n</tool_call>' }}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n<think>\n' }}
{%- endif %}

View File

@@ -19,4 +19,5 @@ These templates can be updated with the following commands:
./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja
```

View File

@@ -1,3 +1,7 @@
-r ./requirements-convert_legacy_llama.txt
--extra-index-url https://download.pytorch.org/whl/cpu
torch~=2.2.1
torch~=2.2.1; platform_machine != "s390x"
# torch s390x packages can only be found from nightly builds
--extra-index-url https://download.pytorch.org/whl/nightly
torch>=0.0.0.dev0; platform_machine == "s390x"

View File

@@ -1,3 +1,7 @@
-r ./requirements-convert_legacy_llama.txt
--extra-index-url https://download.pytorch.org/whl/cpu
torch~=2.2.1
torch~=2.2.1; platform_machine != "s390x"
# torch s390x packages can only be found from nightly builds
--extra-index-url https://download.pytorch.org/whl/nightly
torch>=0.0.0.dev0; platform_machine == "s390x"

View File

@@ -1,2 +1,4 @@
-r ./requirements-convert_hf_to_gguf.txt
--extra-index-url https://download.pytorch.org/whl/cpu
# torch s390x packages can only be found from nightly builds
--extra-index-url https://download.pytorch.org/whl/nightly

View File

@@ -12,6 +12,7 @@
export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}
./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L
./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M
./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b
@@ -205,6 +206,7 @@ def run(
model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None,
ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None,
n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
@@ -229,6 +231,12 @@ def run(
# n_ctx = 8192
n_ctx = 2048
if model is None:
if hf is not None:
model = hf.split("/")[-1]
elif ollama is not None:
model = ollama
assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite"
with output.open('a' if append else 'w') as output_file:
@@ -320,6 +328,7 @@ def run(
server.model_hf_repo = hf
server.model_hf_file = None
server.chat_template = chat_template
server.chat_template_file = chat_template_file
server.server_path = server_path
if port is not None:
server.server_port = port
@@ -335,6 +344,7 @@ def run(
temp=t,
output_kwargs=dict(
chat_template=chat_template,
chat_template_file=chat_template_file,
),
request_kwargs=dict(
ignore_chat_grammar=ignore_chat_grammar,
@@ -355,6 +365,7 @@ def run(
temp=t,
output_kwargs=dict(
chat_template=None,
chat_template_file=None,
),
request_kwargs=dict(
model=ollama,

View File

@@ -1,5 +1,6 @@
#include "llama-batch.h"
#include <cassert>
#include <cstring>
#include <algorithm>
@@ -281,9 +282,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
batch = in_batch;
GGML_ASSERT(batch.n_tokens > 0);
if (!batch.pos) {
assert(p0 >= 0);
pos.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
pos[i] = i + p0;
pos[i] = p0 + i;
}
batch.pos = pos.data();
}

View File

@@ -857,11 +857,17 @@ int llama_context::decode(llama_batch & inp_batch) {
return -1;
}
if (!inp_batch.pos) {
if (inp_batch.seq_id) {
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
return -1;
}
}
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
const llama_batch & batch = batch_allocr.batch;
@@ -2288,65 +2294,51 @@ int32_t llama_apply_adapter_cvec(
return res ? 0 : -1;
}
//
// kv cache view
//
llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
const auto * kv = ctx->get_kv_self();
if (kv == nullptr) {
LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
return {};
}
return llama_kv_cache_view_init(*kv, n_seq_max);
}
void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
const auto * kv = ctx->get_kv_self();
if (kv == nullptr) {
LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
return;
}
llama_kv_cache_view_update(view, kv);
}
//
// kv cache
//
// deprecated
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
return llama_kv_self_n_tokens(ctx);
}
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self();
if (!kv) {
return 0;
}
return kv->get_n_tokens();
int32_t res = 0;
for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
const llama_pos p0 = kv->seq_pos_min(s);
const llama_pos p1 = kv->seq_pos_max(s);
if (p0 >= 0) {
res += (p1 - p0) + 1;
}
}
return res;
}
// deprecated
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
return llama_kv_self_used_cells(ctx);
}
// note: this is the same as above - will be removed anyway, so it's ok
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self();
if (!kv) {
return 0;
}
return kv->get_used_cells();
}
int32_t res = 0;
// deprecated
void llama_kv_cache_clear(llama_context * ctx) {
llama_kv_self_clear(ctx);
for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
const llama_pos p0 = kv->seq_pos_min(s);
const llama_pos p1 = kv->seq_pos_max(s);
if (p0 >= 0) {
res += (p1 - p0) + 1;
}
}
return res;
}
void llama_kv_self_clear(llama_context * ctx) {
@@ -2358,15 +2350,6 @@ void llama_kv_self_clear(llama_context * ctx) {
kv->clear();
}
// deprecated
bool llama_kv_cache_seq_rm(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
}
bool llama_kv_self_seq_rm(
llama_context * ctx,
llama_seq_id seq_id,
@@ -2380,16 +2363,6 @@ bool llama_kv_self_seq_rm(
return kv->seq_rm(seq_id, p0, p1);
}
// deprecated
void llama_kv_cache_seq_cp(
llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_self_seq_cp(
llama_context * ctx,
llama_seq_id seq_id_src,
@@ -2404,13 +2377,6 @@ void llama_kv_self_seq_cp(
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
// deprecated
void llama_kv_cache_seq_keep(
llama_context * ctx,
llama_seq_id seq_id) {
llama_kv_self_seq_keep(ctx, seq_id);
}
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
auto * kv = ctx->get_kv_self();
if (!kv) {
@@ -2420,16 +2386,6 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
kv->seq_keep(seq_id);
}
// deprecated
void llama_kv_cache_seq_add(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
}
void llama_kv_self_seq_add(
llama_context * ctx,
llama_seq_id seq_id,
@@ -2444,16 +2400,6 @@ void llama_kv_self_seq_add(
kv->seq_add(seq_id, p0, p1, delta);
}
// deprecated
void llama_kv_cache_seq_div(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
}
void llama_kv_self_seq_div(
llama_context * ctx,
llama_seq_id seq_id,
@@ -2477,11 +2423,6 @@ llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
return kv->seq_pos_min(seq_id);
}
// deprecated
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_self_seq_pos_max(ctx, seq_id);
}
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
const auto * kv = ctx->get_kv_self();
if (!kv) {
@@ -2491,11 +2432,6 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return kv->seq_pos_max(seq_id);
}
// deprecated
void llama_kv_cache_defrag(llama_context * ctx) {
llama_kv_self_defrag(ctx);
}
void llama_kv_self_defrag(llama_context * ctx) {
auto * kv = ctx->get_kv_self();
if (!kv) {
@@ -2506,11 +2442,6 @@ void llama_kv_self_defrag(llama_context * ctx) {
kv->defrag_sched(-1.0f);
}
// deprecated
bool llama_kv_cache_can_shift(const llama_context * ctx) {
return llama_kv_self_can_shift(ctx);
}
bool llama_kv_self_can_shift(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self();
if (!kv) {
@@ -2520,11 +2451,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) {
return kv->get_can_shift();
}
// deprecated
void llama_kv_cache_update(llama_context * ctx) {
llama_kv_self_update(ctx);
}
// llama state API
// deprecated
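With the cell-counting getters removed, llama_kv_self_n_tokens is now derived from per-sequence position ranges: every sequence whose seq_pos_min is non-negative contributes an inclusive range of positions. A plain-Python sketch of that accounting, with made-up (seq_pos_min, seq_pos_max) pairs:

# Mirrors the new llama_kv_self_n_tokens logic: sum the inclusive position
# ranges of all active sequences; empty sequences report p0 < 0.
seq_pos_ranges = {0: (0, 99), 1: (0, 14), 2: (-1, -1)}  # seq_id -> (min, max), illustrative

n_tokens = 0
for p0, p1 in seq_pos_ranges.values():
    if p0 >= 0:
        n_tokens += (p1 - p0) + 1

print(n_tokens)  # 100 + 15 = 115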

View File

@@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
for (const auto & trigger_pattern : grammar.trigger_patterns) {
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
grammar.awaiting_trigger = false;
// get from the first match to the end of the string
auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
// get from the first matched capturing group to the end of the string
size_t start = std::string::npos;
for (auto i = 1u; i < match.size(); i++) {
if (match.length(i) > 0) {
start = match.position(i);
break;
}
}
if (start == std::string::npos) {
start = match.position(0);
}
auto constrained_str = grammar.trigger_buffer.substr(start);
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, constrained_str);
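The fix no longer assumes that capturing group 1 participated in the trigger match; it scans for the first group that actually consumed text and falls back to the whole match. The same selection logic expressed with Python's re module, using a made-up trigger pattern and output string for illustration:

import re

# Hypothetical trigger with two alternative capturing groups; for any given
# output only one of them participates in the match.
trigger = re.compile(r"(?:(<tool_call>)|.*?(\[TOOL_CALLS\])).*", re.DOTALL)
buffer = 'some assistant prelude [TOOL_CALLS] {"name": "get_time"}'

m = trigger.match(buffer)
assert m is not None

# Like the C++ loop above: take the first capturing group that matched
# non-empty text, otherwise fall back to the start of the whole match.
start = m.start(0)
for i in range(1, len(m.groups()) + 1):
    if m.group(i):
        start = m.start(i)
        break

print(buffer[start:])  # [TOOL_CALLS] {"name": "get_time"}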

View File

@@ -1236,8 +1236,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
{
GGML_ASSERT(hparams.n_swa_pattern == 1 && "Use llama_kv_cache_unified_iswa for SWA");
GGML_ASSERT(hparams.n_swa == 0 && "Use llama_kv_cache_unified_iswa for SWA");
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
const auto n_kv = kv_self->get_n();
@@ -1288,6 +1287,10 @@ ggml_tensor * llm_graph_context::build_attn(
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {
@@ -1312,8 +1315,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
if (hparams.n_swa_pattern > 1) {
GGML_ASSERT(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA");
{
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
const auto n_kv = kv_self->get_kv_swa()->get_n();
@@ -1368,10 +1371,6 @@ ggml_tensor * llm_graph_context::build_attn(
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {

View File

@@ -2,6 +2,22 @@
#include "ggml.h"
void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
for (uint32_t il = 0; il < n_layer; ++il) {
swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
}
}
bool llama_hparams::is_swa_any() const {
for (uint32_t il = 0; il < n_layer; ++il) {
if (swa_layers[il]) {
return true;
}
}
return false;
}
uint32_t llama_hparams::n_head(uint32_t il) const {
if (il < n_layer) {
return n_head_arr[il];
@@ -72,7 +88,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
bool llama_hparams::is_swa(uint32_t il) const {
if (il < n_layer) {
return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
return swa_layers[il];
}
GGML_ABORT("fatal error");

View File

@@ -102,9 +102,12 @@ struct llama_hparams {
// Sliding Window Attention (SWA)
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA)
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
// the size of the sliding window (0 - no SWA)
uint32_t n_swa = 0;
// if swa_layers[il] == true, then layer il is SWA
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
// by default, all layers are dense
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
// for State Space Models
uint32_t ssm_d_conv = 0;
@@ -142,6 +145,23 @@ struct llama_hparams {
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
// this value n_pattern means that every nth layer is dense (i.e. non-SWA)
// note that if n_pattern == 0, all layers are SWA
// if n_pattern == 1, all layers are dense
// example: n_pattern = 3
// il == 0: swa
// il == 1: swa
// il == 2: dense
// il == 3: swa
// il == 4: swa
// il == 5: dense
// il == 6: swa
// etc ...
void set_swa_pattern(uint32_t n_pattern);
// return true if one of the layers is SWA
bool is_swa_any() const;
uint32_t n_head(uint32_t il = 0) const;
uint32_t n_head_kv(uint32_t il = 0) const;
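The comment above fully determines set_swa_pattern: with pattern n, the last layer of every group of n stays dense and the rest use the sliding window; n == 0 makes every layer SWA and n == 1 makes every layer dense. A small standalone check of that layout (the 12-layer count is arbitrary):

# Mirrors llama_hparams::set_swa_pattern for a hypothetical 12-layer model.
def swa_layers(n_layer: int, n_pattern: int) -> list[bool]:
    # True -> sliding-window attention, False -> dense layer
    return [n_pattern == 0 or (il % n_pattern < (n_pattern - 1)) for il in range(n_layer)]

print(swa_layers(12, 3))  # every 3rd layer dense: [True, True, False, True, True, False, ...]
print(swa_layers(12, 1))  # all False (all layers dense)
print(swa_layers(12, 0))  # all True  (all layers SWA)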

View File

@@ -30,13 +30,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
bool v_trans,
bool offload,
uint32_t kv_size,
uint32_t padding,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
llama_swa_type swa_type) :
model(model), hparams(model.hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
this->type_k = type_k;
this->type_v = type_v;
GGML_ASSERT(kv_size % n_pad == 0);
// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -129,8 +130,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
const size_t memory_size_k = size_k_bytes();
const size_t memory_size_v = size_v_bytes();
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6d cells, %3d layers), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(),
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
@@ -442,7 +443,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
void llama_kv_cache_unified::defrag_sched(float thold) {
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > thold) {
@@ -558,7 +559,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
#ifdef FIND_SLOT_DEBUG
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
@@ -567,20 +568,6 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
return true;
}
int32_t llama_kv_cache_unified::get_n_tokens() const {
int32_t result = 0;
for (uint32_t i = 0; i < size; i++) {
result += cells[i].seq_id.size();
}
return result;
}
int32_t llama_kv_cache_unified::get_used_cells() const {
return used;
}
bool llama_kv_cache_unified::get_can_shift() const {
return true;
}
@@ -802,16 +789,6 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
}
}
llama_pos llama_kv_cache_unified::get_pos_max() const {
llama_pos pos_max = -1;
for (const auto & cell : cells) {
pos_max = std::max(pos_max, cell.pos);
}
return pos_max;
}
size_t llama_kv_cache_unified::total_size() const {
size_t size = 0;
@@ -1501,11 +1478,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
llama_seq_id seq_id;
io.read_to(&seq_id, sizeof(seq_id));
// TODO: llama_kv_cache_unified should have a notion of max sequences
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
if (seq_id < 0) {
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
return false;
}
@@ -1655,17 +1629,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
ggml_type type_v,
bool v_trans,
bool offload,
uint32_t kv_size,
bool swa_full,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_batch,
uint32_t padding) : hparams(model.hparams) {
uint32_t n_pad) : hparams(model.hparams) {
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
const uint32_t size_base = kv_size;
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
if (swa_full) {
@@ -1680,14 +1654,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
kv_base = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_base), type_k, type_v,
v_trans, offload, size_base, padding,
v_trans, offload, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_swa), type_k, type_v,
v_trans, offload, size_swa, padding,
v_trans, offload, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type);
}
@@ -1810,18 +1784,6 @@ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
return res;
}
int32_t llama_kv_cache_unified_iswa::get_n_tokens() const {
return kv_base->get_n_tokens();
}
int32_t llama_kv_cache_unified_iswa::get_used_cells() const {
return kv_base->get_used_cells();
}
llama_pos llama_kv_cache_unified_iswa::get_pos_max() const {
return kv_base->get_pos_max();
}
bool llama_kv_cache_unified_iswa::get_can_shift() const {
return kv_base->get_size() == kv_swa->get_size();
}
@@ -1853,19 +1815,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size) : hparams(model.hparams) {
uint32_t kv_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
const int32_t n_layer = hparams.n_layer;
LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
__func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
__func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
head = 0;
size = kv_size;
used = 0;
this->type_k = type_k;
this->type_v = type_v;
cells.clear();
cells.resize(kv_size);
@@ -2203,8 +2163,8 @@ void llama_kv_cache_recurrent::commit() {
pending.ranges.clear();
}
bool llama_kv_cache_recurrent::update(llama_context & lctx) {
GGML_UNUSED(lctx);
bool llama_kv_cache_recurrent::update(llama_context & ctx) {
GGML_UNUSED(ctx);
return false;
}
@@ -2265,7 +2225,7 @@ bool llama_kv_cache_recurrent::find_slot(
if (seq_id < 0 || (uint32_t) seq_id >= size) {
// too big seq_id
// TODO: would it be possible to resize the cache instead?
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
return false;
}
if (j > 0) {
@@ -2408,29 +2368,6 @@ bool llama_kv_cache_recurrent::find_slot(
return n >= n_seqs;
}
int32_t llama_kv_cache_recurrent::get_n_tokens() const {
int32_t result = 0;
for (uint32_t i = 0; i < size; i++) {
result += cells[i].seq_id.size();
}
return result;
}
int32_t llama_kv_cache_recurrent::get_used_cells() const {
return used;
}
llama_pos llama_kv_cache_recurrent::get_pos_max() const {
llama_pos pos_max = -1;
for (const auto & cell : cells) {
pos_max = std::max(pos_max, cell.pos);
}
return pos_max;
}
bool llama_kv_cache_recurrent::get_can_shift() const {
return false;
}
@@ -2888,38 +2825,3 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
return true;
}
//
// kv cache view
//
llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) {
llama_kv_cache_view result = {
/*.n_cells = */ 0,
/*.n_seq_max = */ n_seq_max,
/*.token_count = */ 0,
/*.used_cells = */ kv.get_used_cells(),
/*.max_contiguous = */ 0,
/*.max_contiguous_idx = */ -1,
/*.cells = */ nullptr,
/*.cells_sequences = */ nullptr,
};
return result;
}
void llama_kv_cache_view_free(llama_kv_cache_view * view) {
if (view->cells != nullptr) {
free(view->cells);
view->cells = nullptr;
}
if (view->cells_sequences != nullptr) {
free(view->cells_sequences);
view->cells_sequences = nullptr;
}
}
void llama_kv_cache_view_update(llama_kv_cache_view * , const llama_kv_cache * ) {
// TODO: will be removed soon, keep this for now to avoid too many changes in
// https://github.com/ggml-org/llama.cpp/pull/13194
}
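Unless swa_full is set, the llama_kv_cache_unified_iswa constructor above sizes the sliding-window cache as min(kv_size, GGML_PAD(n_swa*n_seq_max + n_batch, n_pad)). A worked example of that formula (all numbers arbitrary, chosen only to show the padding step):

# Worked example of the SWA cache sizing used above.
def pad_up(x: int, n: int) -> int:
    # Round x up to the next multiple of n (what GGML_PAD does for power-of-two n).
    return ((x + n - 1) // n) * n

kv_size, n_swa, n_seq_max, n_batch, n_pad = 8192, 1000, 2, 512, 32  # illustrative values
size_swa = min(kv_size, pad_up(n_swa * n_seq_max + n_batch, n_pad))
print(size_swa)  # min(8192, pad_up(2512, 32)) = 2528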

View File

@@ -55,10 +55,7 @@ struct llama_kv_cache : public llama_memory_i {
// =============================================================================================================
// getters
virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual llama_pos get_pos_max() const = 0;
virtual bool get_can_shift() const = 0;
virtual bool get_can_shift() const = 0;
bool get_can_edit() const override { return get_can_shift(); }
@@ -108,7 +105,8 @@ public:
bool v_trans,
bool offload,
uint32_t kv_size,
uint32_t padding,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type);
@@ -150,12 +148,6 @@ public:
// to the first cell of the slot.
bool find_slot(const llama_ubatch & batch) override;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;
bool get_can_shift() const override;
// state write/load
@@ -228,16 +220,15 @@ private:
// computed before each graph build
uint32_t n = 0;
// required padding
uint32_t padding = 1;
const uint32_t n_seq_max = 1;
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
// required padding
const uint32_t n_pad = 1;
// SWA
uint32_t n_swa = 0;
const uint32_t n_swa = 0;
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
@@ -317,11 +308,11 @@ public:
ggml_type type_v,
bool v_trans,
bool offload,
uint32_t kv_size,
bool swa_full,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_batch,
uint32_t padding);
uint32_t n_pad);
~llama_kv_cache_unified_iswa() = default;
@@ -358,12 +349,6 @@ public:
bool find_slot(const llama_ubatch & batch) override;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;
bool get_can_shift() const override;
// state write/load
@@ -432,7 +417,8 @@ public:
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size);
uint32_t kv_size,
uint32_t n_seq_max);
~llama_kv_cache_recurrent() = default;
@@ -444,7 +430,7 @@ public:
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
@@ -458,7 +444,7 @@ public:
void restore() override;
void commit() override;
bool update(llama_context & lctx) override;
bool update(llama_context & ctx) override;
void defrag_sched(float thold) override;
@@ -469,12 +455,6 @@ public:
bool find_slot(const llama_ubatch & batch) override;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;
bool get_can_shift() const override;
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
@@ -514,8 +494,7 @@ private:
std::vector<slot_range> ranges;
} pending;
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
const uint32_t n_seq_max = 1;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
@@ -534,12 +513,3 @@ private:
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
//
// kv cache view
//
llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);

View File

@@ -463,11 +463,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
GGML_ASSERT(hparams.n_expert_used == 0);
}
// zero-out the array hparams
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0);
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false);
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
@@ -574,7 +577,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
switch (hparams.n_expert) {
case 16: type = LLM_TYPE_17B_16E; break;
@@ -853,44 +856,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
// for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931
if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) {
// default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct
LLAMA_LOG_WARN("%s: assuming n_swa = 2047 for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct\n", __func__);
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa = 2047;
} else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
// default value for Phi-3-mini-128k-instruct
LLAMA_LOG_WARN("%s: assuming no SWA for Phi-3-mini-128k-instruct\n", __func__);
if (found_swa && hparams.n_swa > 0) {
LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n",
__func__, "https://github.com/ggml-org/llama.cpp/pull/13676");
// TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern`
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
hparams.n_swa = hparams.n_ctx_train;
hparams.n_swa_pattern = 1;
} else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
// default value for Phi-3-medium-128k-instruct
LLAMA_LOG_WARN("%s: assuming no SWA for Phi-3-medium-128k-instruct\n", __func__);
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
hparams.n_swa = hparams.n_ctx_train;
hparams.n_swa_pattern = 1;
}
bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
if (!found_swa && hparams.n_swa == 0) {
throw std::runtime_error("invalid value for sliding_window");
}
if (hparams.n_swa > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: unexpected n_swa: %d >= %d, disabling SWA\n", __func__, hparams.n_swa, hparams.n_ctx_train);
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
hparams.n_swa = hparams.n_ctx_train;
hparams.n_swa_pattern = 1;
hparams.n_swa = 0;
hparams.set_swa_pattern(1);
}
} break;
case LLM_ARCH_PHIMOE:
@@ -962,7 +938,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa = 4096; // default value of gemma 2
hparams.n_swa_pattern = 2;
hparams.set_swa_pattern(2);
hparams.attn_soft_cap = true;
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
@@ -980,7 +956,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_GEMMA3:
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa_pattern = 6;
hparams.set_swa_pattern(6);
hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f;
@@ -1065,7 +1041,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_COHERE2:
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa_pattern = 4;
hparams.set_swa_pattern(4);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
@@ -4347,7 +4323,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern);
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
@@ -4803,8 +4779,21 @@ struct llm_build_llama_iswa : public llm_graph_context {
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
{
// llama4 MoE
// feed-forward network (non-MoE)
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
ggml_tensor * ffn_inp_normed = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
@@ -7355,8 +7344,9 @@ struct llm_build_phi2 : public llm_graph_context {
}
};
struct llm_build_phi3_iswa : public llm_graph_context {
llm_build_phi3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
template<bool iswa>
struct llm_build_phi3 : public llm_graph_context {
llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -7370,7 +7360,14 @@ struct llm_build_phi3_iswa : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified_iswa();
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
inp_attn_type * inp_attn = nullptr;
if constexpr (iswa) {
inp_attn = build_attn_inp_kv_unified_iswa();
} else {
inp_attn = build_attn_inp_kv_unified();
}
for (int il = 0; il < n_layer; ++il) {
auto * residual = inpL;
@@ -13195,6 +13192,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_WAVTOKENIZER_DEC:
{
res = nullptr;
} break;
@@ -13209,7 +13207,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
GGML_TYPE_F32,
GGML_TYPE_F32,
cparams.offload_kqv,
std::max((uint32_t) 1, cparams.n_seq_max));
std::max((uint32_t) 1, cparams.n_seq_max),
cparams.n_seq_max);
} break;
default:
{
@@ -13219,19 +13218,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
if (hparams.n_swa > 0) {
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
res = new llama_kv_cache_unified_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.n_ctx,
params.swa_full,
cparams.n_ctx,
cparams.n_seq_max,
cparams.n_batch,
padding);
} else {
GGML_ASSERT(!hparams.is_swa_any());
res = new llama_kv_cache_unified(
*this,
nullptr,
@@ -13240,6 +13243,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.n_ctx,
cparams.n_seq_max,
padding,
hparams.n_swa,
hparams.swa_type);
@@ -13340,7 +13344,11 @@ llm_graph_result_ptr llama_model::build_graph(
case LLM_ARCH_PHI3:
case LLM_ARCH_PHIMOE:
{
llm = std::make_unique<llm_build_phi3_iswa>(*this, params, gf);
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
llm = std::make_unique<llm_build_phi3<true>> (*this, params, gf);
} else {
llm = std::make_unique<llm_build_phi3<false>>(*this, params, gf);
}
} break;
case LLM_ARCH_PLAMO:
{

View File

@@ -835,7 +835,7 @@ struct llm_tokenizer_ugm_session {
}
// initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX});
std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX});
// at the beginning tokenization score is zero
tokenization_results[0] = { vocab.token_unk(), 0, 0 };
@@ -867,7 +867,7 @@ struct llm_tokenizer_ugm_session {
const double challenger_score = current_best.score_sum + token_score;
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
if (challenger_score > current_champ.score_sum) {
struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
struct best_tokenization challenger = { token_id, input_offset, challenger_score };
current_champ = challenger;
}
}
@@ -881,7 +881,7 @@ struct llm_tokenizer_ugm_session {
prefix_offset = input_offset + n_utf8_code_units;
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
if (challenger_score > current_champ.score_sum) {
struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score };
struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score };
current_champ = challenger;
}
}
@@ -1007,7 +1007,7 @@ private:
struct best_tokenization {
llama_token token_id;
size_t input_offset;
float score_sum;
double score_sum;
};
struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
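The UGM change above widens best_tokenization::score_sum from float to double: the search accumulates many small log-scores, and float's roughly 7 significant digits can make two close candidates compare as equal. A tiny standalone demonstration of the failure mode (synthetic values; requires numpy):

import numpy as np

# Synthetic numbers, only to show the failure mode: with a float accumulator a
# sufficiently small score increment is absorbed entirely, so two different
# tokenizations can tie; a double accumulator still sees the difference.
base = np.float32(-15.3333)  # running score_sum of a long prefix
increment = 1e-7             # tiny score difference between two candidates

print(base + np.float32(increment) == base)              # True  (difference lost in float32)
print(np.float64(base) + increment == np.float64(base))  # False (difference kept in double)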

View File

@@ -142,8 +142,10 @@ if (NOT WIN32)
# llama_build_and_test(test-double-float.cpp) # SLOW
endif()
llama_build_and_test(test-log.cpp)
llama_build_and_test(test-chat-parser.cpp)
llama_build_and_test(test-chat-template.cpp)
llama_build_and_test(test-json-partial.cpp)
llama_build_and_test(test-log.cpp)
llama_build_and_test(test-regex-partial.cpp)
# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)

View File

@@ -128,7 +128,7 @@ int main(void) {
if (common_has_curl()) {
printf("test-arg-parser: test curl-related functions\n\n");
const char * GOOD_URL = "https://raw.githubusercontent.com/ggml-org/llama.cpp/refs/heads/master/README.md";
const char * GOOD_URL = "https://ggml.ai/";
const char * BAD_URL = "https://www.google.com/404";
const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin";

tests/test-chat-parser.cpp Normal file (355 lines)
View File

@@ -0,0 +1,355 @@
// Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
//
// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
// e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
//
// cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
//
#include <exception>
#include <iostream>
#include <json.hpp>
#include <string>
#include "chat-parser.h"
#include "common.h"
#include "log.h"
#include "regex-partial.h"
using json = nlohmann::ordered_json;
template <class T>
static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
static void assert_equals(const char * expected, const std::string & actual) {
return assert_equals<std::string>(expected, actual);
}
static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") {
try {
fn();
} catch (const std::exception & e) {
if (expected_exception_pattern.empty()) {
return;
}
std::regex expected_exception_regex(expected_exception_pattern);
std::string actual_message = e.what();
if (std::regex_search(actual_message, expected_exception_regex)) {
return;
}
throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
}
throw std::runtime_error("Exception was expected but not thrown");
}
static void test_reasoning() {
{
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
assert_equals("Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ true,
});
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
assert_equals("Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ true,
/* .thinking_forced_open = */ true,
});
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("<think>Cogito</think>", builder.result().content);
assert_equals("Ergo sum", builder.consume_rest());
}
}
static void test_regex() {
auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
common_chat_msg_parser builder(input, /* is_partial= */ false, {});
assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
};
test_throws("Hello, world!", "abc", "^abc$");
test_throws("Hello, world!", "e", "^e$");
{
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
builder.consume_regex(common_regex("Hello"));
assert_equals(", world!", builder.consume_rest());
}
{
// When in non-partial mode, we can say whether the regex was consumed or not.
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
}
{
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
assert_equals(true, res.has_value());
// Verify captures
assert_equals<size_t>(2, res->groups.size());
assert_equals("Hell", builder.str(res->groups[0]));
assert_equals("el", builder.str(res->groups[1]));
// Verify position is after the match
assert_equals<size_t>(4, builder.pos());
assert_equals("o,", builder.consume_rest());
}
{
// But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
assert_throws([&]() {
builder.try_consume_regex(common_regex("Hello, world!"));
}, "^Hello, world!$");
}
// Now regardless of the mode, we can tell these aren't a match.
for (const auto is_partial : {false, true}) {
common_chat_msg_parser builder("Hello,", is_partial, {});
assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
}
for (const auto is_partial : {false, true}) {
common_chat_msg_parser builder("Hello,", is_partial, {});
assert_equals(false, builder.try_consume_literal("Oh"));
}
}
const std::vector<std::string> barely_healable_jsons = {
"{",
"{\"",
"{\"\\",
"{\"n",
"{\"name\"",
"{\"name\":",
"{\"name\":\"",
"{\"name\":\"\\",
"{\"name\":\"python",
"{\"name\":\"python\\",
"{\",",
"{\":",
"{\"[",
"{\"]",
"{\"{",
"{\"}",
"{\"1",
"{\"name\":\",",
"{\"name\":\":",
"{\"name\":\"[",
"{\"name\":\"]",
"{\"name\":\"{",
"{\"name\":\"}",
"{\"name\":\"1",
};
static void test(const std::string & input, bool is_partial, const std::vector<std::vector<std::string>> & args_paths, const std::vector<std::vector<std::string>> & content_paths, const std::string & expected) {
common_chat_msg_parser builder(input, is_partial, {});
auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
assert_equals(true, js.has_value());
assert_equals(is_partial, js->is_partial);
assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get<std::string>() : js->value.dump());
}
static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
common_chat_msg_parser builder(input, parse_as_partial, {});
auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
assert_equals(true, js.has_value());
assert_equals(is_partial, js->is_partial);
assert_equals(expected, js->value.dump());
}
static void test_json_with_dumped_args_no_args() {
// Normal JSON, nothing to heal, nothing to dump
test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
// Full json is args
test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");
// If the arguments are further down, don't heal partial content.
for (const auto & src : barely_healable_jsons) {
test(src, true, {{"arguments"}}, {}, "{}");
}
// But heal content that isn't partial.
test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
}
static void test_json_with_dumped_args() {
// Partial content.
test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
test("{\"content\": ", true, {}, {{"content"}}, "{}");
// If the entire JSON is the arguments, healing it then dumping it produces the same output as the input (just reformatted).
test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
for (const auto & src : barely_healable_jsons) {
test(src, true, {{}}, {}, src);
}
// Full JSON w/ args
for (auto parse_as_partial : {true, false}) {
test_with_args(
R"({"name": "python", "args": {"arg1": 1}})",
R"({"name":"python","args":"{\"arg1\":1}"})",
parse_as_partial,
/* is_partial= */ false
);
}
// Partial JSON w/ partial args
test_with_args(
R"({"foo": "bar", "args": {")",
R"({"foo":"bar","args":"{\""})"
);
// Partial args broken in object key
test_with_args(
R"({"foo": "bar", "args": {"ar)",
R"({"foo":"bar","args":"{\"ar"})"
);
// Partial args broken after object key
test_with_args(
R"({"foo": "bar", "args": {"arg1")",
R"({"foo":"bar","args":"{\"arg1\""})"
);
// Partial args broken before object value
test_with_args(
R"({"foo": "bar", "args": {"arg1":)",
R"({"foo":"bar","args":"{\"arg1\":"})"
);
// Partial args broken before object value (space)
test_with_args(
R"({"foo": "bar", "args": {"arg1": )",
R"({"foo":"bar","args":"{\"arg1\":"})"
);
// Partial args broken in object value that may not be complete (int)
test_with_args(
R"({"foo": "bar", "args": {"arg1": 1)",
R"({"foo":"bar","args":"{\"arg1\":"})"
);
// Partial args broken in object value that is complete (int)
test_with_args(
R"({"foo": "bar", "args": {"arg1": 1 )",
R"({"foo":"bar","args":"{\"arg1\":1"})"
);
// Partial args broken in object value that is incomplete (string)
test_with_args(
R"({"foo": "bar", "args": {"arg1": ")",
R"({"foo":"bar","args":"{\"arg1\":\""})"
);
// Partial args broken in object value that is complete (string)
test_with_args(
R"({"foo": "bar", "args": {"arg1": "1")",
R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
);
// Partial args broken on array opening
test_with_args(
R"({"foo": "bar", "args": [)",
R"({"foo":"bar","args":"["})"
);
// Partial args broken on array value that is incomplete (int)
test_with_args(
R"({"foo": "bar", "args": [1)",
R"({"foo":"bar","args":"["})"
);
// Partial args broken on array value that is complete (int)
test_with_args(
R"({"foo": "bar", "args": [1 )",
R"({"foo":"bar","args":"[1"})"
);
// Partial args broken on array value that is complete (string)
test_with_args(
R"({"foo": "bar", "args": ["1")",
R"({"foo":"bar","args":"[\"1\""})"
);
// Partial args broken after array value
test_with_args(
R"({"foo": "bar", "args": [1,)",
R"({"foo":"bar","args":"[1,"})"
);
// Partial args broken on nested array
test_with_args(
R"({"foo": "bar", "args": {"arg1": [)",
R"({"foo":"bar","args":"{\"arg1\":["})"
);
}
static void test_positions() {
{
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
assert_equals<size_t>(0, builder.pos());
assert_throws([&]() { builder.move_to(100); });
assert_equals<size_t>(0, builder.pos());
assert_throws([&]() { builder.move_back(1); });
assert_equals<size_t>(0, builder.pos());
builder.move_to(8);
assert_equals<size_t>(8, builder.pos());
builder.move_back(1);
assert_equals<size_t>(7, builder.pos());
assert_equals("world!", builder.consume_rest());
builder.move_to(0);
assert_equals<size_t>(0, builder.pos());
assert_throws([&]() { builder.finish(); });
assert_equals<size_t>(0, builder.pos());
builder.move_to(builder.input().size());
builder.finish();
}
{
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});
builder.move_to(builder.input().size());
assert_equals<size_t>(builder.input().size(), builder.pos());
builder.finish();
}
}
int main() {
test_positions();
test_json_with_dumped_args_no_args();
test_json_with_dumped_args();
test_reasoning();
test_regex();
std::cout << "All tests passed!\n";
return 0;
}

File diff suppressed because it is too large.

tests/test-json-partial.cpp (new file, 237 lines)

@@ -0,0 +1,237 @@
#include "common.h"
#include "json-partial.h"
#include <exception>
#include <iostream>
#include <stdexcept>
template <class T> static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
static void test_json_healing() {
auto parse = [](const std::string & str) {
std::cerr << "# Parsing: " << str << '\n';
std::string::const_iterator it = str.begin();
const auto end = str.end();
common_json out;
std::string healing_marker = "$llama.cpp.json$";
if (common_json_parse(it, end, healing_marker, out)) {
auto dump = out.json.dump();
std::cerr << "Parsed: " << dump << '\n';
std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
std::string result;
if (!out.healing_marker.json_dump_marker.empty()) {
auto i = dump.find(out.healing_marker.json_dump_marker);
if (i == std::string::npos) {
throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
}
result = dump.substr(0, i);
} else {
result = dump;
}
std::cerr << "Result: " << result << '\n';
if (string_starts_with(str, result)) {
std::cerr << "Failure!\n";
}
// return dump;
} else {
throw std::runtime_error("Failed to parse: " + str);
}
};
auto parse_all = [&](const std::string & str) {
for (size_t i = 1; i < str.size(); i++) {
parse(str.substr(0, i));
}
};
parse_all("{\"a\": \"b\"}");
parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
parse_all("[{\"a\": \"b\"}]");
auto test = [&](const std::vector<std::string> & inputs, const std::string & expected, const std::string & expected_marker) {
for (const auto & input : inputs) {
common_json out;
assert_equals(true, common_json_parse(input, "$foo", out));
assert_equals<std::string>(expected, out.json.dump());
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
}
};
// No healing needed:
test(
{
R"([{"a":"b"}, "y"])",
},
R"([{"a":"b"},"y"])",
""
);
// Partial literals can't be healed:
test(
{
R"([1)",
R"([tru)",
R"([n)",
R"([nul)",
R"([23.2)",
},
R"(["$foo"])",
R"("$foo)"
);
test(
{
R"({"a": 1)",
R"({"a": tru)",
R"({"a": n)",
R"({"a": nul)",
R"({"a": 23.2)",
},
R"({"a":"$foo"})",
R"("$foo)"
);
test(
{
R"({)",
},
R"({"$foo":1})",
R"("$foo)"
);
test(
{
R"([)",
},
R"(["$foo"])",
R"("$foo)"
);
// Healing right after a full literal
test(
{
R"(1 )",
},
R"(1)",
""
);
test(
{
R"(true)",
R"(true )",
},
R"(true)",
""
);
test(
{
R"(null)",
R"(null )",
},
R"(null)",
""
);
test(
{
R"([1 )",
},
R"([1,"$foo"])",
R"(,"$foo)"
);
test(
{
R"([{})",
R"([{} )",
},
R"([{},"$foo"])",
R"(,"$foo)"
);
test(
{
R"([true)",
},
// TODO: detect the true/false/null literal was complete
R"(["$foo"])",
R"("$foo)"
);
test(
{
R"([true )",
},
R"([true,"$foo"])",
R"(,"$foo)"
);
test(
{
R"([true,)",
},
R"([true,"$foo"])",
R"("$foo)"
);
// Test nesting
test(
{
R"([{"a": [{"b": [{)",
},
R"([{"a":[{"b":[{"$foo":1}]}]}])",
R"("$foo)"
);
test(
{
R"([{"a": [{"b": [)",
},
R"([{"a":[{"b":["$foo"]}]}])",
R"("$foo)"
);
test(
{
R"([{"a": "b"})",
R"([{"a": "b"} )",
},
R"([{"a":"b"},"$foo"])",
R"(,"$foo)"
);
test(
{
R"([{"a": "b"},)",
R"([{"a": "b"}, )",
},
R"([{"a":"b"},"$foo"])",
R"("$foo)"
);
test(
{
R"({ "code)",
},
R"({"code$foo":1})",
R"($foo)"
);
test(
{
R"({ "code\)",
},
R"({"code\\$foo":1})",
R"(\$foo)"
);
test(
{
R"({ "code")",
},
R"({"code":"$foo"})",
R"(:"$foo)"
);
test(
{
R"({ "key")",
},
R"({"key":"$foo"})",
R"(:"$foo)"
);
}
int main() {
test_json_healing();
std::cerr << "All tests passed.\n";
return 0;
}
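For orientation, here is a minimal sketch (not part of the diff) of how a caller might use common_json_parse and the healing marker to recover the stable prefix of a partially streamed JSON value. It relies only on the string overload and the common_json fields exercised by the tests above; the helper name stable_prefix is made up for illustration.

#include "json-partial.h"
#include <string>

// Returns the longest prefix of the healed dump that does not depend on the healing
// marker, i.e. the part of the value that is already certain.
static std::string stable_prefix(const std::string & partial_json) {
    common_json out;
    if (!common_json_parse(partial_json, /* healing_marker= */ "$llama.cpp.json$", out)) {
        return ""; // not healable (e.g. empty input)
    }
    const std::string dump = out.json.dump();
    const std::string & marker = out.healing_marker.json_dump_marker;
    if (marker.empty()) {
        return dump; // the input was already complete JSON
    }
    return dump.substr(0, dump.find(marker)); // drop the healed tail
}

// e.g. stable_prefix(R"({"a": [1, tru)") == R"({"a":[1,)"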


@@ -1,5 +1,15 @@
# mtmd
# compile mtmd-audio separately to avoid long compile times with miniaudio.h
# TODO @ngxson : move miniaudio.h and stb_image.h to mtmd-helper.cpp, then compile the helper as a separate library
add_library(mtmd_audio STATIC mtmd-audio.cpp mtmd-audio.h)
if (BUILD_SHARED_LIBS)
set_target_properties(mtmd_audio PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
target_link_libraries(mtmd_audio PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(mtmd_audio PRIVATE cxx_std_17)
target_include_directories(mtmd_audio PRIVATE .)
add_library(mtmd OBJECT
mtmd.cpp
mtmd-helper.cpp
@@ -9,7 +19,7 @@ add_library(mtmd OBJECT
clip-impl.h
)
target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(mtmd PRIVATE ggml llama mtmd_audio ${CMAKE_THREAD_LIBS_INIT})
target_include_directories(mtmd PUBLIC .)
target_include_directories(mtmd PRIVATE ../..)
@@ -22,12 +32,13 @@ if (BUILD_SHARED_LIBS)
set_target_properties(mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(mtmd PRIVATE LLAMA_SHARED LLAMA_BUILD)
add_library(mtmd_shared SHARED $<TARGET_OBJECTS:mtmd>)
target_link_libraries(mtmd_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(mtmd_shared PRIVATE ggml llama mtmd_audio ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS mtmd_shared LIBRARY)
endif()
if (NOT MSVC)
target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
target_compile_options(mtmd_audio PRIVATE -Wno-cast-qual) # miniaudio.h
endif()
if(TARGET BUILD_INFO)


@@ -16,22 +16,26 @@
#define KEY_FTYPE "general.file_type"
#define KEY_NAME "general.name"
#define KEY_DESCRIPTION "general.description"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_HAS_AUDIO_ENC "clip.has_audio_encoder"
#define KEY_HAS_VISION_ENC "clip.has_vision_encoder"
#define KEY_USE_GELU "clip.use_gelu"
#define KEY_USE_SILU "clip.use_silu"
#define KEY_N_EMBD "clip.vision.embedding_length"
#define KEY_N_FF "clip.vision.feed_forward_length"
#define KEY_N_BLOCK "clip.vision.block_count"
#define KEY_N_HEAD "clip.vision.attention.head_count"
#define KEY_LAYER_NORM_EPS "clip.vision.attention.layer_norm_epsilon"
#define KEY_PROJ_DIM "clip.vision.projection_dim"
#define KEY_N_EMBD "clip.%s.embedding_length"
#define KEY_N_FF "clip.%s.feed_forward_length"
#define KEY_N_BLOCK "clip.%s.block_count"
#define KEY_PROJ_DIM "clip.%s.projection_dim"
#define KEY_N_HEAD "clip.%s.attention.head_count"
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
// vision-specific
#define KEY_IMAGE_SIZE "clip.vision.image_size"
#define KEY_PATCH_SIZE "clip.vision.patch_size"
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
@@ -39,13 +43,18 @@
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
// audio-specific
#define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins"
#define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor"
//
// tensor name constants
//
#define TN_POS_EMBD "v.position_embd.weight"
#define TN_POS_EMBD "%s.position_embd.weight"
#define TN_CLASS_EMBD "v.class_embd"
#define TN_PATCH_EMBD "v.patch_embd.weight" // tensor not renamed with ".0" postfix, for backward compat
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
@@ -95,6 +104,12 @@
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
// ultravox
#define TN_CONV1D "a.conv1d.%d.%s"
#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
#define TN_MM_NORM_PRE "mm.a.norm_pre.%s"
#define TN_MM_NORM_MID "mm.a.norm_mid.%s"
// align x to upper multiple of n
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
@@ -110,6 +125,7 @@ enum projector_type {
PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL,
PROJECTOR_TYPE_QWEN25VL,
PROJECTOR_TYPE_ULTRAVOX,
PROJECTOR_TYPE_INTERNVL,
PROJECTOR_TYPE_LLAMA4,
PROJECTOR_TYPE_UNKNOWN,
@@ -126,6 +142,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
{ PROJECTOR_TYPE_ULTRAVOX, "ultravox"},
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
};
@@ -147,8 +164,10 @@ struct clip_image_u8 {
std::vector<uint8_t> buf;
};
// RGB float32 image (NHWC)
// Memory layout: RGBRGBRGB...
// For images, buf.size() == nx*ny*3
// Memory layout: RGBRGBRGB...
// For audio, only one channel is used, buf.size() == nx*ny
// nx will be n_frames and ny will be n_mel
struct clip_image_f32 {
int nx;
int ny;
@@ -242,6 +261,7 @@ struct clip_image_u8_batch {
struct clip_image_f32_batch {
std::vector<clip_image_f32_ptr> entries;
bool is_audio = false;
// for llava-uhd style models, we need to know the grid size
// note: entries.size() == grid_x * grid_y + 1 (one overview image)
@@ -249,7 +269,12 @@ struct clip_image_f32_batch {
int grid_y = 0;
clip_image_f32_batch clone() const {
clip_image_f32_batch new_batch;
clip_image_f32_batch new_batch{
/* entries */ {},
/* is_audio */ is_audio,
/* grid_x */ grid_x,
/* grid_y */ grid_y,
};
new_batch.entries.reserve(entries.size());
for (const auto & entry : entries) {
new_batch.entries.emplace_back(new clip_image_f32(*entry));


@@ -35,6 +35,7 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac
enum ffn_op_type {
FFN_GELU,
FFN_GELU_ERF,
FFN_SILU,
FFN_GELU_QUICK,
};
@@ -165,6 +166,9 @@ enum patch_merge_type {
};
struct clip_hparams {
bool has_vision = false;
bool has_audio = false;
int32_t image_size;
int32_t patch_size;
int32_t n_embd;
@@ -191,6 +195,10 @@ struct clip_hparams {
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
int32_t spatial_merge_size = 0;
// audio
int32_t n_mel_bins = 0; // whisper preprocessor
int32_t proj_stack_factor = 0; // ultravox
};
struct clip_layer {
@@ -332,6 +340,14 @@ struct clip_vision_model {
// pixtral
ggml_tensor * token_embd_img_break = nullptr;
ggml_tensor * mm_patch_merger_w = nullptr;
// ultravox / whisper encoder
ggml_tensor * conv1d_1_w = nullptr;
ggml_tensor * conv1d_1_b = nullptr;
ggml_tensor * conv1d_2_w = nullptr;
ggml_tensor * conv1d_2_b = nullptr;
ggml_tensor * mm_norm_pre_w = nullptr;
ggml_tensor * mm_norm_mid_w = nullptr;
};
struct clip_ctx {
@@ -1408,6 +1424,104 @@ struct clip_graph {
return gf;
}
// whisper encoder with custom projector
ggml_cgraph * build_whisper_enc() {
const int n_frames = img.nx;
const int n_pos = n_frames / 2;
GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
ggml_tensor * inp = build_inp_raw(1);
// conv1d block
{
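// note: the second convolution below has stride 2, so the transformer sees
// n_frames / 2 positions (this is why n_pos above is n_frames / 2)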
// convolution + gelu
ggml_tensor * cur = ggml_conv_1d_ph(ctx0, model.conv1d_1_w, inp, 1, 1);
cur = ggml_add(ctx0, cur, model.conv1d_1_b);
cur = ggml_gelu_erf(ctx0, cur);
cur = ggml_conv_1d_ph(ctx0, model.conv1d_2_w, cur, 2, 1);
cur = ggml_add(ctx0, cur, model.conv1d_2_b);
cur = ggml_gelu_erf(ctx0, cur);
// transpose
inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
cb(inp, "after_conv1d", -1);
}
// sanity check (only check one layer, but it should be the same for all)
GGML_ASSERT(model.layers[0].ln_1_w && model.layers[0].ln_1_b);
GGML_ASSERT(model.layers[0].ln_2_w && model.layers[0].ln_2_b);
GGML_ASSERT(model.layers[0].q_b);
GGML_ASSERT(model.layers[0].v_b);
GGML_ASSERT(!model.layers[0].k_b); // no bias for k
GGML_ASSERT(model.post_ln_w && model.post_ln_b);
ggml_tensor * pos_embd_selected = ggml_view_2d(
ctx0, model.position_embeddings,
model.position_embeddings->ne[0], n_pos,
model.position_embeddings->nb[1], 0
);
ggml_tensor * cur = build_vit(
inp, n_pos,
NORM_TYPE_NORMAL,
hparams.ffn_op,
pos_embd_selected,
nullptr);
cb(cur, "after_transformer", -1);
// StackAudioFrames
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
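// conceptually: flatten, zero-pad to a multiple of (n_embd * proj_stack_factor),
// then reshape so that each output row stacks proj_stack_factor consecutive frames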
{
int64_t stride = n_embd * hparams.proj_stack_factor;
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
int64_t pad = padded_len - ggml_nelements(cur);
if (pad > 0) {
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
}
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
ggml_row_size(cur->type, stride), 0);
}
cb(cur, "after_stacked", -1);
// UltravoxProjector
{
// pre-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
// ffn in
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
// swiglu
{
int64_t split_point = cur->ne[0] / 2;
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
x1 = ggml_silu(ctx0, x1);
cur = ggml_mul(ctx0, x0, x1);
}
// mid-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
// ffn out
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
}
cb(cur, "projected", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
private:
//
// utility functions
@@ -1562,8 +1676,8 @@ private:
return inp;
}
ggml_tensor * build_inp_raw() {
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3);
ggml_tensor * build_inp_raw(int channels = 3) {
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
return inp_raw;
@@ -1641,6 +1755,11 @@ private:
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_gelu", il);
} break;
case FFN_GELU_ERF:
{
cur = ggml_gelu_erf(ctx0, cur);
cb(cur, "ggml_gelu_erf", il);
} break;
case FFN_GELU_QUICK:
{
cur = ggml_gelu_quick(ctx0, cur);
@@ -1832,6 +1951,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_llama4();
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
res = graph.build_whisper_enc();
} break;
default:
{
res = graph.build_llava();
@@ -1915,18 +2038,30 @@ struct clip_model_loader {
// other hparams
{
get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
get_bool(KEY_HAS_AUDIO_ENC, hparams.has_audio, false);
get_bool(KEY_HAS_VISION_ENC, hparams.has_vision, false);
get_u32(KEY_N_EMBD, hparams.n_embd);
get_u32(KEY_N_HEAD, hparams.n_head);
get_u32(KEY_N_FF, hparams.n_ff);
get_u32(KEY_N_BLOCK, hparams.n_layer);
get_u32(KEY_PROJ_DIM, hparams.projection_dim);
get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
const char * prefix = hparams.has_vision ? "vision" : "audio";
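// e.g. this selects "clip.vision.embedding_length" for vision models and
// "clip.audio.embedding_length" for audio models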
get_u32(string_format(KEY_N_EMBD, prefix), hparams.n_embd);
get_u32(string_format(KEY_N_HEAD, prefix), hparams.n_head);
get_u32(string_format(KEY_N_FF, prefix), hparams.n_ff);
get_u32(string_format(KEY_N_BLOCK, prefix), hparams.n_layer);
get_u32(string_format(KEY_PROJ_DIM, prefix), hparams.projection_dim);
get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps);
if (hparams.has_vision) {
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
} else if (hparams.has_audio) {
get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins);
} else {
throw std::runtime_error(string_format("%s: neither vision nor audio encoder is present\n", __func__));
}
// default warmup value
hparams.warmup_image_size = hparams.image_size;
@@ -1964,7 +2099,7 @@ struct clip_model_loader {
}
}
{
if (hparams.has_vision) {
int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
@@ -2050,30 +2185,43 @@ struct clip_model_loader {
isize, isize*3, // 336, 1008
};
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor);
if (hparams.n_mel_bins != 128) {
throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
}
hparams.ffn_op = FFN_GELU_ERF;
log_ffn_op = "gelu_erf"; // temporary solution for logging
} break;
default:
break;
}
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
LOG_INF("%s: has_vision_encoder: %d\n", __func__, hparams.has_vision);
LOG_INF("%s: has_audio_encoder: %d\n", __func__, hparams.has_audio);
LOG_INF("%s: n_embd: %d\n", __func__, hparams.n_embd);
LOG_INF("%s: n_head: %d\n", __func__, hparams.n_head);
LOG_INF("%s: n_ff: %d\n", __func__, hparams.n_ff);
LOG_INF("%s: n_layer: %d\n", __func__, hparams.n_layer);
LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
LOG_INF("\n");
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str());
LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
LOG_INF("\n");
if (hparams.has_vision) {
LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
} else if (hparams.has_audio) {
LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins);
LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor);
}
LOG_INF("\n");
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
if (ctx_clip.proj_type == PROJECTOR_TYPE_LLAMA4) {
LOG_WRN("%s: llama 4 vision is known to have degraded quality: https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
}
}
}
@@ -2082,6 +2230,9 @@ struct clip_model_loader {
std::map<std::string, size_t> tensor_offset;
std::vector<ggml_tensor *> tensors_to_load;
// TODO @ngxson : support both audio and video in the future
const char * prefix = hparams.has_audio ? "a" : "v";
// get offsets
for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) {
const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
@@ -2119,47 +2270,47 @@ struct clip_model_loader {
vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, "v", "weight"), false);
vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, "v", "bias"), false);
vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"), false);
vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, "v", "weight"), false);
vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, "v", "bias"), false);
vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false);
vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"), false);
vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
vision_model.position_embeddings = get_tensor(TN_POS_EMBD, false);
vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
// layers
vision_model.layers.resize(hparams.n_layer);
for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = vision_model.layers[il];
layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight"));
layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight"));
layer.v_w = get_tensor(string_format(TN_ATTN_V, "v", il, "weight"));
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, "v", il, "weight"), false);
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, "v", il, "weight"), false);
layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias
layer.ls_2_w = get_tensor(string_format(TN_LS_2, "v", il, "weight"), false); // no bias
layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, prefix, il, "weight"), false);
layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias
layer.ls_2_w = get_tensor(string_format(TN_LS_2, prefix, il, "weight"), false); // no bias
layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
layer.k_b = get_tensor(string_format(TN_ATTN_K, prefix, il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false);
// ffn
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false);
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, prefix, il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, prefix, il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, prefix, il, "weight"), false);
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, prefix, il, "bias"), false);
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false);
// some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
// note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
@@ -2301,6 +2452,17 @@ struct clip_model_loader {
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
vision_model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
vision_model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
vision_model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
vision_model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
} break;
case PROJECTOR_TYPE_INTERNVL:
{
vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
@@ -2358,13 +2520,19 @@ struct clip_model_loader {
}
void alloc_compute_meta() {
const auto & hparams = ctx_clip.vision_model.hparams;
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
// create a fake batch
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
img->nx = ctx_clip.vision_model.hparams.warmup_image_size;
img->ny = ctx_clip.vision_model.hparams.warmup_image_size;
if (hparams.has_vision) {
img->nx = hparams.warmup_image_size;
img->ny = hparams.warmup_image_size;
} else {
img->nx = 1024; // TODO @ngxson : use a better default
img->ny = hparams.n_mel_bins;
}
img->buf.resize(img->nx * img->ny * 3);
batch.entries.push_back(std::move(img));
@@ -3278,6 +3446,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
} else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
n_patches /= (scale_factor * scale_factor);
} else if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
const int proj_stack_factor = ctx->vision_model.hparams.proj_stack_factor;
const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
n_patches = n_len / proj_stack_factor / 2;
}
return n_patches;
@@ -3435,7 +3607,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
};
// set input pixel values
{
if (!imgs.is_audio) {
size_t nelem = 0;
for (const auto & img : imgs.entries) {
nelem += img->nx * img->ny * 3;
@@ -3472,6 +3644,16 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
}
set_input_f32("inp_raw", inp_raw);
} else {
// audio input
GGML_ASSERT(imgs.entries.size() == 1);
const auto & mel_inp = imgs.entries[0];
const int n_step = mel_inp->nx;
const int n_mel = mel_inp->ny;
std::vector<float> inp_raw(n_step * n_mel);
std::memcpy(inp_raw.data(), mel_inp->buf.data(), n_step * n_mel * sizeof(float));
set_input_f32("inp_raw", inp_raw);
}
// set input per projector
@@ -3668,6 +3850,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_ULTRAVOX:
{
// do nothing
} break;
@@ -3766,6 +3949,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.mm_input_proj_w->ne[0];
case PROJECTOR_TYPE_IDEFICS3:
return ctx->vision_model.projection->ne[1];
case PROJECTOR_TYPE_ULTRAVOX:
return ctx->vision_model.mm_2_w->ne[1];
case PROJECTOR_TYPE_INTERNVL:
return ctx->vision_model.mm_3_w->ne[1];
case PROJECTOR_TYPE_LLAMA4:
@@ -3798,6 +3983,14 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
}
bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.has_vision;
}
bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.has_audio;
}
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img;
clip_img.buf.resize(h * w * 3);
@@ -3818,3 +4011,14 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,
projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
return ctx->proj_type;
}
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) {
clip_image_f32 * audio = new clip_image_f32;
audio->nx = n_frames;
audio->ny = n_mel;
audio->buf.resize(n_frames * n_mel);
std::memcpy(audio->buf.data(), mel, n_frames * n_mel * sizeof(float));
batch->entries.push_back(clip_image_f32_ptr(audio));
batch->is_audio = true;
}


@@ -93,3 +93,9 @@ bool clip_is_llava(const struct clip_ctx * ctx);
bool clip_is_gemma3(const struct clip_ctx * ctx);
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
// used by audio input
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel);
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
bool clip_has_audio_encoder(const struct clip_ctx * ctx);

tools/mtmd/miniaudio.h (new file, 93468 lines): diff suppressed because it is too large.

tools/mtmd/mtmd-audio.cpp (new file, 855 lines)

@@ -0,0 +1,855 @@
// fix problem with std::min and std::max
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#endif
#include "mtmd-audio.h"
//#define MTMD_AUDIO_DEBUG
#define MINIAUDIO_IMPLEMENTATION
#ifndef MTMD_AUDIO_DEBUG
# define MA_NO_ENCODING
#endif
#define MA_NO_DEVICE_IO
#define MA_NO_RESOURCE_MANAGER
#define MA_NO_NODE_GRAPH
#define MA_NO_ENGINE
#define MA_NO_GENERATION
#define MA_API static
#include "miniaudio.h"
#define _USE_MATH_DEFINES // for M_PI
#include <cmath>
#include <cstdint>
#include <cstring>
#include <thread>
#include <vector>
#include <fstream>
#include <algorithm>
// most of the code here is copied from whisper.cpp
// align x to upper multiple of n
#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
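// e.g. _ALIGN(3001, 3000) == 6000, _ALIGN(3000, 3000) == 3000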
namespace whisper_preprocessor {
#define SIN_COS_N_COUNT WHISPER_N_FFT
namespace {
struct whisper_global_cache {
// In FFT, we frequently use sine and cosine operations with the same values.
// We can use precalculated values to speed up the process.
float sin_vals[SIN_COS_N_COUNT];
float cos_vals[SIN_COS_N_COUNT];
// Hann window (Use cosf to eliminate difference)
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
float hann_window[WHISPER_N_FFT];
whisper_global_cache() {
fill_sin_cos_table();
fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
}
void fill_sin_cos_table() {
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
sin_vals[i] = sinf(theta);
cos_vals[i] = cosf(theta);
}
}
void fill_hann_window(int length, bool periodic, float * output) {
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
}
}
} global_cache;
}
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
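// out[k] = sum_n in[n] * (cos(2*pi*k*n/N) - i*sin(2*pi*k*n/N)), stored as interleaved (re, im) pairs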
static void dft(const float* in, int N, float* out) {
const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < N; k++) {
float re = 0;
float im = 0;
for (int n = 0; n < N; n++) {
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
}
out[k*2 + 0] = re;
out[k*2 + 1] = im;
}
}
// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
static void fft(float* in, int N, float* out) {
if (N == 1) {
out[0] = in[0];
out[1] = 0;
return;
}
const int half_N = N / 2;
if (N - half_N*2 == 1) {
dft(in, N, out);
return;
}
float* even = in + N;
for (int i = 0; i < half_N; ++i) {
even[i]= in[2*i];
}
float* even_fft = out + 2 * N;
fft(even, half_N, even_fft);
float* odd = even;
for (int i = 0; i < half_N; ++i) {
odd[i] = in[2*i + 1];
}
float* odd_fft = even_fft + N;
fft(odd, half_N, odd_fft);
const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < half_N; k++) {
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
float re = global_cache.cos_vals[idx]; // cos(t)
float im = -global_cache.sin_vals[idx]; // sin(t)
float re_odd = odd_fft[2*k + 0];
float im_odd = odd_fft[2*k + 1];
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
}
}
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
int n_samples, int frame_size, int frame_step, int n_threads,
const whisper_filters & filters, whisper_mel & mel) {
std::vector<float> fft_in(frame_size * 2, 0.0);
std::vector<float> fft_out(frame_size * 2 * 2 * 2);
int n_fft = filters.n_fft;
int i = ith;
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
WHISPER_ASSERT(n_fft == 1 + (frame_size / 2));
// calculate FFT only when fft_in are not all zero
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
const int offset = i * frame_step;
// apply Hann window (~10% faster)
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
fft_in[j] = hann[j] * samples[offset + j];
}
// fill the rest with zeros
if (n_samples - offset < frame_size) {
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
}
// FFT
fft(fft_in.data(), frame_size, fft_out.data());
// Calculate modulus^2 of complex numbers
// Using pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) here causes an inference quality problem? Interesting.
for (int j = 0; j < n_fft; j++) {
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
}
// mel spectrogram
for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0;
// unroll loop (suggested by GH user @lunixbochs)
int k = 0;
for (k = 0; k < n_fft - 3; k += 4) {
sum +=
fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
fft_out[k + 3] * filters.data[j * n_fft + k + 3];
}
// handle n_fft remainder
for (; k < n_fft; k++) {
sum += fft_out[k] * filters.data[j * n_fft + k];
}
sum = log10(std::max(sum, 1e-10));
mel.data[j * mel.n_len + i] = sum;
}
}
// Otherwise fft_out are all zero
double sum = log10(1e-10);
for (; i < mel.n_len; i += n_threads) {
for (int j = 0; j < mel.n_mel; j++) {
mel.data[j * mel.n_len + i] = sum;
}
}
}
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
static bool log_mel_spectrogram(
const float * samples,
const int n_samples,
const int /*sample_rate*/,
const int frame_size,
const int frame_step,
const int n_mel,
const int n_threads,
const whisper_filters & filters,
const bool debug,
whisper_mel & mel) {
//const int64_t t_start_us = ggml_time_us();
// Hann window
WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
const float * hann = global_cache.hann_window;
// Calculate the length of padding
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
int64_t stage_2_pad = frame_size / 2;
// Initialize a vector and copy data from C array to it.
std::vector<float> samples_padded;
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
// pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
// reflective pad 200 samples at the beginning of audio
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
mel.n_mel = n_mel;
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
// Calculate number of frames + remove the last frame
mel.n_len = (samples_padded.size() - frame_size) / frame_step;
// Calculate semi-padded sample length to ensure compatibility
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
mel.data.resize(mel.n_mel * mel.n_len);
{
std::vector<std::thread> workers(n_threads - 1);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw] = std::thread(
log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
std::cref(filters), std::ref(mel));
}
// main thread
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw].join();
}
}
// clamping and normalization
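// clamp each value to at least (global max - 8), then rescale as (x + 4) / 4, as in whisper.cpp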
double mmax = -1e20;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] > mmax) {
mmax = mel.data[i];
}
}
mmax -= 8.0;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] < mmax) {
mel.data[i] = mmax;
}
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}
// Dump log_mel_spectrogram
if (debug) {
std::ofstream outFile("log_mel_spectrogram.json");
outFile << "[";
for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
outFile << mel.data[i] << ", ";
}
outFile << mel.data[mel.data.size() - 1] << "]";
outFile.close();
}
return true;
}
bool preprocess_audio(
const float * samples,
size_t n_samples,
const whisper_filters & filters,
std::vector<whisper_mel> & output) {
if (n_samples == 0) {
// empty audio
return false;
}
whisper_mel out_full;
bool ok = log_mel_spectrogram(
samples,
n_samples,
COMMON_SAMPLE_RATE,
WHISPER_N_FFT,
WHISPER_HOP_LENGTH,
filters.n_mel,
4, // n_threads
filters,
false, // debug
out_full);
if (!ok) {
return false;
}
// because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel
// we always expect the mel to have 3000 silent frames at the end
// printf("n_len %d\n", out_full.n_len);
const size_t frames_per_chunk = 3000;
GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
if ((size_t)n_len < frames_per_chunk) {
break; // the last incomplete chunk will always be a padded chunk, so it is safe to ignore
}
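// out_chunk.data is laid out as [n_mel][frames_per_chunk]: row j holds
// frames [off, off + frames_per_chunk) of mel bin j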
whisper_mel out_chunk;
out_chunk.n_len = n_len;
out_chunk.n_mel = out_full.n_mel;
out_chunk.n_len_org = out_full.n_mel; // unused
out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);
for (int i = 0; i < out_full.n_mel; i++) {
auto src = out_full.data.begin() + i*out_full.n_len + off;
out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
}
output.push_back(std::move(out_chunk));
}
return true;
}
} // namespace whisper_preprocessor
namespace audio_helpers {
bool is_audio_file(const char * buf, size_t len) {
if (len < 12) {
return false;
}
// RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
// WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
bool is_wav = memcmp(buf, "RIFF", 4) == 0 && memcmp(buf + 8, "WAVE", 4) == 0;
bool is_mp3 = len >= 3 && (
memcmp(buf, "ID3", 3) == 0 ||
// Check for MPEG sync word (simplified check)
((unsigned char)buf[0] == 0xFF && ((unsigned char)buf[1] & 0xE0) == 0xE0)
);
bool is_flac = memcmp(buf, "fLaC", 4) == 0;
return is_wav || is_mp3 || is_flac;
}
// decodes audio from a memory buffer into mono float PCM at the target sample rate; returns true on success
bool decode_audio_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate, std::vector<float> & pcmf32_mono) {
ma_result result;
const int channels = 1;
ma_decoder_config decoder_config = ma_decoder_config_init(ma_format_f32, channels, target_sampler_rate);
ma_decoder decoder;
result = ma_decoder_init_memory(buf_in, len, &decoder_config, &decoder);
if (result != MA_SUCCESS) {
return false;
}
ma_uint64 frame_count;
ma_uint64 frames_read;
result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count);
if (result != MA_SUCCESS) {
ma_decoder_uninit(&decoder);
return false;
}
pcmf32_mono.resize(frame_count);
result = ma_decoder_read_pcm_frames(&decoder, pcmf32_mono.data(), frame_count, &frames_read);
if (result != MA_SUCCESS) {
ma_decoder_uninit(&decoder);
return false;
}
#ifdef MTMD_AUDIO_DEBUG
// save audio to wav file
ma_encoder_config config = ma_encoder_config_init(ma_encoding_format_wav, ma_format_f32, 1, target_sampler_rate);
ma_encoder encoder;
ma_encoder_init_file("output.wav", &config, &encoder);
ma_encoder_write_pcm_frames(&encoder, pcmf32_mono.data(), pcmf32_mono.size(), &frames_read);
ma_encoder_uninit(&encoder);
#endif
ma_decoder_uninit(&decoder);
return true;
}
} // namespace audio_helpers
// precalculated mel filter banks
// values are multiplied by 1000.0 to save space, and will be divided by 1000.0 in the end of the function
//
// generated from python code:
//
// from numpy import load
// data = load('mel_filters.npz')
// lst = data.files
// for item in lst:
// print(item)
// print(data[item].shape)
// n_mel = data[item].shape[0]
// n_fft = data[item].shape[1]
// for i, row in enumerate(data[item]):
// for j, val in enumerate(row):
// val = val * 1000.0
// if val != 0:
// print(f"data[{i*n_fft + j}] = {val:.6f};")
namespace whisper_precalc_filters {
whisper_preprocessor::whisper_filters get_128_bins() {
whisper_preprocessor::whisper_filters filters;
filters.n_mel = 128;
filters.n_fft = 201;
std::vector<float> data(filters.n_mel * filters.n_fft, 0.0f);
data[1] = 12.37398665;
data[202] = 30.39256483;
data[404] = 24.74797331;
data[605] = 18.01857911;
data[807] = 37.12195903;
data[1008] = 5.64459199;
data[1009] = 6.72939420;
data[1210] = 36.03715822;
data[1412] = 19.10337992;
data[1613] = 23.66316877;
data[1815] = 31.47736564;
data[2016] = 11.28918398;
data[2017] = 1.08480197;
data[2218] = 41.68175161;
data[2420] = 13.45878839;
data[2621] = 29.30776216;
data[2823] = 25.83277412;
data[3024] = 16.93377644;
data[3226] = 38.20675984;
data[3427] = 4.55979025;
data[3428] = 7.81419594;
data[3629] = 34.95235741;
data[3831] = 20.18818259;
data[4032] = 22.57836796;
data[4234] = 32.56217018;
data[4435] = 10.20438317;
data[4436] = 2.16960395;
data[4637] = 40.59694707;
data[4839] = 14.54358920;
data[5040] = 28.22295949;
data[5242] = 26.91757679;
data[5443] = 15.84897563;
data[5645] = 39.29156065;
data[5846] = 3.47498828;
data[5847] = 8.89899861;
data[6048] = 33.86755288;
data[6250] = 21.27298526;
data[6451] = 21.49356715;
data[6653] = 33.64697099;
data[6854] = 9.11958050;
data[6855] = 3.25440569;
data[7056] = 39.51214626;
data[7258] = 15.62839188;
data[7459] = 27.13815868;
data[7661] = 28.00237760;
data[7862] = 14.76417296;
data[8064] = 40.37636518;
data[8265] = 2.38068704;
data[8266] = 10.20263787;
data[8467] = 31.61146119;
data[8669] = 24.54700135;
data[8870] = 15.32919332;
data[8871] = 1.66583748;
data[9072] = 36.72905266;
data[9274] = 20.09709924;
data[9475] = 16.93102531;
data[9476] = 2.90265540;
data[9677] = 32.84499049;
data[9879] = 23.52004871;
data[10080] = 11.03894413;
data[10081] = 10.72582975;
data[10282] = 22.71829173;
data[10484] = 32.27872774;
data[10685] = 0.11626833;
data[10686] = 22.85348251;
data[10887] = 8.56344029;
data[10888] = 14.97978810;
data[11089] = 15.51398356;
data[11090] = 8.51490628;
data[11291] = 21.10680379;
data[11292] = 3.32652032;
data[11493] = 25.47064796;
data[11695] = 27.35907957;
data[11896] = 0.65853616;
data[11897] = 23.83812517;
data[12098] = 3.44359246;
data[12099] = 21.22455277;
data[12300] = 5.35842171;
data[12301] = 19.42555793;
data[12502] = 6.49324711;
data[12503] = 18.35542172;
data[12704] = 6.93138083;
data[12705] = 17.93504693;
data[12906] = 6.74968259;
data[12907] = 18.09151843;
data[13108] = 6.01899112;
data[13109] = 18.75767298;
data[13310] = 4.80452832;
data[13311] = 19.87172849;
data[13512] = 3.16627859;
data[13513] = 21.37690969;
data[13514] = 1.25317345;
data[13714] = 1.15934468;
data[13715] = 20.80361731;
data[13716] = 4.04486805;
data[13917] = 17.55363122;
data[13918] = 7.08320038;
data[14119] = 14.07538634;
data[14120] = 10.32655034;
data[14321] = 10.40921453;
data[14322] = 13.73696327;
data[14523] = 6.59187697;
data[14524] = 17.27988198;
data[14525] = 1.46804214;
data[14725] = 2.65681883;
data[14726] = 18.09193194;
data[14727] = 5.85655728;
data[14928] = 13.34277913;
data[14929] = 10.28267574;
data[15130] = 8.56800377;
data[15131] = 14.72230814;
data[15132] = 1.04039861;
data[15332] = 3.79085587;
data[15333] = 17.14678481;
data[15334] = 6.11609267;
data[15535] = 11.75929047;
data[15536] = 11.13393717;
data[15737] = 6.43857848;
data[15738] = 16.07806236;
data[15739] = 4.23917221;
data[15939] = 1.19989377;
data[15940] = 12.75671553;
data[15941] = 9.65298992;
data[16142] = 7.06935255;
data[16143] = 14.94054683;
data[16144] = 4.19024844;
data[16344] = 1.51483389;
data[16345] = 12.00899947;
data[16346] = 9.84823331;
data[16547] = 6.10224018;
data[16548] = 15.33857174;
data[16549] = 5.57676842;
data[16749] = 0.36827257;
data[16750] = 9.89749376;
data[16751] = 11.35340426;
data[16752] = 2.05122307;
data[16952] = 3.89297144;
data[16953] = 12.97352277;
data[16954] = 8.06631614;
data[17155] = 6.74493238;
data[17156] = 13.85874674;
data[17157] = 5.41190524;
data[17357] = 0.74220158;
data[17358] = 8.98779090;
data[17359] = 11.37871388;
data[17360] = 3.32958088;
data[17560] = 2.82313535;
data[17561] = 10.68049297;
data[17562] = 9.43340641;
data[17563] = 1.76325557;
data[17763] = 4.39018616;
data[17764] = 11.87758986;
data[17765] = 7.97005836;
data[17766] = 0.66104700;
data[17966] = 5.49466675;
data[17967] = 12.62953598;
data[17968] = 6.93987962;
data[18169] = 6.18401915;
data[18170] = 12.93473132;
data[18171] = 6.29778765;
data[18371] = 0.02325210;
data[18372] = 6.50206627;
data[18373] = 12.32661773;
data[18374] = 6.00216538;
data[18574] = 0.31548753;
data[18575] = 6.48925547;
data[18576] = 12.04130240;
data[18577] = 6.01462880;
data[18777] = 0.29979556;
data[18778] = 6.18288014;
data[18779] = 12.04272825;
data[18780] = 6.29981188;
data[18781] = 0.55689598;
data[18980] = 0.01120471;
data[18981] = 5.61729167;
data[18982] = 11.22337859;
data[18983] = 6.82516303;
data[18984] = 1.35264499;
data[19184] = 4.82410006;
data[19185] = 10.16623247;
data[19186] = 7.56075513;
data[19187] = 2.34590308;
data[19387] = 3.83235747;
data[19388] = 8.92296247;
data[19389] = 8.47910438;
data[19390] = 3.50978645;
data[19590] = 2.66873185;
data[19591] = 7.51965167;
data[19592] = 9.55500547;
data[19593] = 4.81966138;
data[19594] = 0.08431751;
data[19793] = 1.35767367;
data[19794] = 5.98019501;
data[19795] = 10.60271543;
data[19796] = 6.25298498;
data[19797] = 1.74059917;
data[19997] = 4.32644226;
data[19998] = 8.73131864;
data[19999] = 7.78916525;
data[20000] = 3.48923868;
data[20200] = 2.57835095;
data[20201] = 6.77582854;
data[20202] = 9.40941647;
data[20203] = 5.31194592;
data[20204] = 1.21447595;
data[20403] = 0.75411191;
data[20404] = 4.75395704;
data[20405] = 8.75380263;
data[20406] = 7.19209015;
data[20407] = 3.28754401;
data[20607] = 2.68179690;
data[20608] = 6.49331464;
data[20609] = 9.11457930;
data[20610] = 5.39387390;
data[20611] = 1.67316827;
data[20810] = 0.57394296;
data[20811] = 4.20600036;
data[20812] = 7.83805829;
data[20813] = 7.52023002;
data[20814] = 3.97470826;
data[20815] = 0.42918732;
data[21014] = 1.90464477;
data[21015] = 5.36569161;
data[21016] = 8.82673822;
data[21017] = 6.27609482;
data[21018] = 2.89750961;
data[21218] = 2.89885257;
data[21219] = 6.19694078;
data[21220] = 8.56699049;
data[21221] = 5.34748193;
data[21222] = 2.12797290;
data[21421] = 0.44750227;
data[21422] = 3.59030394;
data[21423] = 6.73310598;
data[21424] = 7.77023612;
data[21425] = 4.70231380;
data[21426] = 1.63439126;
data[21625] = 1.01536023;
data[21626] = 4.01018746;
data[21627] = 7.00501446;
data[21628] = 7.23442994;
data[21629] = 4.31095669;
data[21630] = 1.38748321;
data[21829] = 1.33348850;
data[21830] = 4.18730825;
data[21831] = 7.04112789;
data[21832] = 6.93188375;
data[21833] = 4.14605811;
data[21834] = 1.36023236;
data[22033] = 1.42879714;
data[22034] = 4.14824858;
data[22035] = 6.86769979;
data[22036] = 6.83705276;
data[22037] = 4.18239459;
data[22038] = 1.52773573;
data[22237] = 1.32610439;
data[22238] = 3.91751388;
data[22239] = 6.50892360;
data[22240] = 6.92639686;
data[22241] = 4.39672917;
data[22242] = 1.86706171;
data[22441] = 1.04827771;
data[22442] = 3.51767405;
data[22443] = 5.98707050;
data[22444] = 7.17824046;
data[22445] = 4.76767914;
data[22446] = 2.35711760;
data[22645] = 0.61636406;
data[22646] = 2.96949223;
data[22647] = 5.32262027;
data[22648] = 7.57265091;
data[22649] = 5.27558755;
data[22650] = 2.97852419;
data[22651] = 0.68146095;
data[22849] = 0.04971400;
data[22850] = 2.29204819;
data[22851] = 4.53438237;
data[22852] = 6.77671656;
data[22853] = 5.90240723;
data[22854] = 3.71349836;
data[22855] = 1.52458926;
data[23054] = 1.50285335;
data[23055] = 3.63961048;
data[23056] = 5.77636715;
data[23057] = 6.63159089;
data[23058] = 4.54574358;
data[23059] = 2.45989650;
data[23060] = 0.37404924;
data[23258] = 0.61795861;
data[23259] = 2.65410915;
data[23260] = 4.69025923;
data[23261] = 6.72641024;
data[23262] = 5.46034705;
data[23263] = 3.47270933;
data[23264] = 1.48507138;
data[23463] = 1.59233576;
data[23464] = 3.53261665;
data[23465] = 5.47289755;
data[23466] = 6.44368259;
data[23467] = 4.54962999;
data[23468] = 2.65557761;
data[23469] = 0.76152512;
data[23667] = 0.46749352;
data[23668] = 2.31641904;
data[23669] = 4.16534441;
data[23670] = 6.01426978;
data[23671] = 5.67844696;
data[23672] = 3.87357362;
data[23673] = 2.06870004;
data[23674] = 0.26382666;
data[23872] = 1.05349103;
data[23873] = 2.81536230;
data[23874] = 4.57723346;
data[23875] = 6.33910485;
data[23876] = 5.12815686;
data[23877] = 3.40826320;
data[23878] = 1.68837002;
data[24077] = 1.43350090;
data[24078] = 3.11241671;
data[24079] = 4.79133241;
data[24080] = 6.40943693;
data[24081] = 4.77052201;
data[24082] = 3.13160778;
data[24083] = 1.49269309;
data[24281] = 0.02932359;
data[24282] = 1.62918994;
data[24283] = 3.22905602;
data[24284] = 4.82892245;
data[24285] = 6.14671456;
data[24286] = 4.58496623;
data[24287] = 3.02321767;
data[24288] = 1.46146910;
data[24486] = 0.13601698;
data[24487] = 1.66055572;
data[24488] = 3.18509457;
data[24489] = 4.70963307;
data[24490] = 6.04072399;
data[24491] = 4.55250870;
data[24492] = 3.06429295;
data[24493] = 1.57607743;
data[24494] = 0.08786193;
data[24691] = 0.09328097;
data[24692] = 1.54603878;
data[24693] = 2.99879676;
data[24694] = 4.45155473;
data[24695] = 5.90431225;
data[24696] = 4.65566106;
data[24697] = 3.23751615;
data[24698] = 1.81937125;
data[24699] = 0.40122634;
data[24897] = 1.30262633;
data[24898] = 2.68698297;
data[24899] = 4.07133950;
data[24900] = 5.45569602;
data[24901] = 4.87832492;
data[24902] = 3.52695142;
data[24903] = 2.17557792;
data[24904] = 0.82420459;
data[25102] = 0.94595028;
data[25103] = 2.26512621;
data[25104] = 3.58430226;
data[25105] = 4.90347855;
data[25106] = 5.20569785;
data[25107] = 3.91795207;
data[25108] = 2.63020652;
data[25109] = 1.34246063;
data[25110] = 0.05471494;
data[25307] = 0.49037894;
data[25308] = 1.74744334;
data[25309] = 3.00450763;
data[25310] = 4.26157191;
data[25311] = 5.51863620;
data[25312] = 4.39707236;
data[25313] = 3.16995848;
data[25314] = 1.94284460;
data[25315] = 0.71573065;
data[25513] = 1.14698056;
data[25514] = 2.34485767;
data[25515] = 3.54273478;
data[25516] = 4.74061165;
data[25517] = 4.95198462;
data[25518] = 3.78264743;
data[25519] = 2.61331047;
data[25520] = 1.44397374;
data[25521] = 0.27463681;
data[25718] = 0.47569509;
data[25719] = 1.61717169;
data[25720] = 2.75864848;
data[25721] = 3.90012516;
data[25722] = 5.04160160;
data[25723] = 4.45712078;
data[25724] = 3.34284059;
data[25725] = 2.22856039;
data[25726] = 1.11428020;
for (auto & val : data) {
val /= 1000.0f;
}
filters.data = std::move(data);
return filters;
}
} // namespace whisper_precalc_filters

tools/mtmd/mtmd-audio.h (new file, 62 lines)

@@ -0,0 +1,62 @@
#pragma once
#include "ggml.h"
#include <cstdint>
#include <vector>
#include <string>
#define WHISPER_ASSERT GGML_ASSERT
#define WHISPER_SAMPLE_RATE 16000
#define WHISPER_N_FFT 400
#define WHISPER_HOP_LENGTH 160
#define WHISPER_CHUNK_SIZE 30
#define COMMON_SAMPLE_RATE 16000
namespace whisper_preprocessor {
struct whisper_mel {
int n_len;
int n_len_org;
int n_mel;
std::vector<float> data;
};
struct whisper_filters {
int32_t n_mel;
int32_t n_fft;
std::vector<float> data;
};
extern bool preprocess_audio(
const float * samples,
size_t n_samples,
const whisper_filters & filters,
std::vector<whisper_mel> & output);
} // namespace whisper_preprocessor
// TODO @ngxson : move this helper to mtmd-helpers.cpp
namespace audio_helpers {
extern bool is_audio_file(const char * buf, size_t len);
extern bool decode_audio_from_buf(
const unsigned char * buf_in,
size_t len,
int target_sampler_rate,
std::vector<float> & pcmf32_mono);
} // namespace audio_helpers
namespace whisper_precalc_filters {
extern whisper_preprocessor::whisper_filters get_128_bins();
} // namespace whisper_precalc_filters
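A minimal sketch (not part of the diff) of how the preprocessing API declared above might be driven, assuming 16 kHz mono float PCM is already available; only names declared in this header are used.
#include "mtmd-audio.h"
#include <cstdio>
#include <vector>

static void example_mel_from_pcm(const std::vector<float> & pcmf32) {
    // precalculated 128-bin mel filter bank shipped with this file
    const whisper_preprocessor::whisper_filters filters = whisper_precalc_filters::get_128_bins();
    std::vector<whisper_preprocessor::whisper_mel> mel_chunks;
    if (!whisper_preprocessor::preprocess_audio(pcmf32.data(), pcmf32.size(), filters, mel_chunks)) {
        fprintf(stderr, "audio preprocessing failed\n");
        return;
    }
    for (const auto & mel : mel_chunks) {
        // each chunk is an n_mel x n_len spectrogram, consumed by the audio encoder
        printf("mel chunk: n_mel = %d, n_len = %d\n", mel.n_mel, mel.n_len);
    }
}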

View File

@@ -37,10 +37,10 @@ static volatile bool g_is_interrupted = false;
static void show_additional_info(int /*argc*/, char ** argv) {
LOG(
"Experimental CLI for multimodal\n\n"
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> --audio <audio> -p <prompt>\n\n"
" -m and --mmproj are required\n"
" -hf user/repo can replace both -m and --mmproj in most cases\n"
" --image and -p are optional, if NOT provided, the CLI will run in chat mode\n"
" --image, --audio and -p are optional, if NOT provided, the CLI will run in chat mode\n"
" to disable using GPU for mmproj model, add --no-mmproj-offload\n",
argv[0]
);
@@ -142,7 +142,7 @@ struct mtmd_cli_context {
);
}
bool load_image(const std::string & fname) {
bool load_media(const std::string & fname) {
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(fname.c_str()));
if (!bmp.ptr) {
return false;
@@ -243,7 +243,7 @@ int main(int argc, char ** argv) {
common_params params;
params.sampling.temp = 0.2; // lower temp by default for better quality
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
return 1;
}
@@ -283,14 +283,14 @@ int main(int argc, char ** argv) {
if (is_single_turn) {
g_is_generating = true;
if (params.prompt.find("<__image__>") == std::string::npos) {
params.prompt += " <__image__>";
if (params.prompt.find(mtmd_default_marker()) == std::string::npos) {
params.prompt += mtmd_default_marker();
}
common_chat_msg msg;
msg.role = "user";
msg.content = params.prompt;
for (const auto & image : params.image) {
if (!ctx.load_image(image)) {
if (!ctx.load_media(image)) {
return 1; // error is already printed by libmtmd
}
}
@@ -303,7 +303,12 @@ int main(int argc, char ** argv) {
} else {
LOG("\n Running in chat mode, available commands:");
LOG("\n /image <path> load an image");
if (mtmd_support_vision(ctx.ctx_vision.get())) {
LOG("\n /image <path> load an image");
}
if (mtmd_support_audio(ctx.ctx_vision.get())) {
LOG("\n /audio <path> load an audio");
}
LOG("\n /clear clear the chat history");
LOG("\n /quit or /exit exit the program");
LOG("\n");
@@ -333,15 +338,17 @@ int main(int argc, char ** argv) {
continue;
}
g_is_generating = true;
if (line == "/image" || line.find("/image ") == 0) {
bool is_image = line == "/image" || line.find("/image ") == 0;
bool is_audio = line == "/audio" || line.find("/audio ") == 0;
if (is_image || is_audio) {
if (line.size() < 8) {
LOG_ERR("ERR: Missing image filename\n");
LOG_ERR("ERR: Missing media filename\n");
continue;
}
std::string image = line.substr(7);
if (ctx.load_image(image)) {
LOG("Image %s loaded\n", image.c_str());
content += "<__image__>";
std::string media_path = line.substr(7);
if (ctx.load_media(media_path)) {
LOG("%s %s loaded\n", media_path.c_str(), is_image ? "image" : "audio");
content += mtmd_default_marker();
}
// else, error is already printed by libmtmd
continue;

View File

@@ -12,17 +12,7 @@ size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
size_t n_tokens = 0;
for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
auto chunk = mtmd_input_chunks_get(chunks, i);
auto chunk_type = mtmd_input_chunk_get_type(chunk);
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
size_t n_tokens_text;
mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
n_tokens += n_tokens_text;
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image);
} else {
GGML_ASSERT(false && "chunk type not supported");
}
n_tokens += mtmd_input_chunk_get_n_tokens(chunk);
}
return n_tokens;
}
@@ -31,17 +21,7 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
llama_pos n_pos = 0;
for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
auto chunk = mtmd_input_chunks_get(chunks, i);
auto chunk_type = mtmd_input_chunk_get_type(chunk);
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
size_t n_tokens_text;
mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
n_pos += n_tokens_text;
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
n_pos += mtmd_image_tokens_get_n_pos(tokens_image);
} else {
GGML_ASSERT(false && "chunk type not supported");
}
n_pos += mtmd_input_chunk_get_n_pos(chunk);
}
return n_pos;
}
@@ -149,13 +129,10 @@ int32_t mtmd_helper_decode_image_chunk(
llama_seq_id seq_id,
int32_t n_batch,
llama_pos * new_n_past) {
if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
return -1;
}
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
if (!image_tokens) {
LOG_ERR("failed to decode image chunk: image tokens are null\n");
auto chunk_type = mtmd_input_chunk_get_type(chunk);
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
LOG_ERR("failed to decode chunk: input chunk not of image/audio type\n");
return -1;
}
@@ -163,15 +140,23 @@ int32_t mtmd_helper_decode_image_chunk(
int n_mmproj_embd = llama_model_n_embd(model);
int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
int32_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
int32_t i_batch = 0;
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
const int nx = mtmd_image_tokens_get_nx(image_tokens);
const int ny = mtmd_image_tokens_get_ny(image_tokens);
if (mtmd_decode_use_mrope(ctx)) {
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
if (chunk_type != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
LOG_ERR("failed to decode chunk: M-RoPE only accepts image chunk\n");
return -1;
}
if (!image_tokens) {
LOG_ERR("failed to decode chunk: image tokens are null\n");
return -1;
}
const int nx = mtmd_image_tokens_get_nx(image_tokens);
const int ny = mtmd_image_tokens_get_ny(image_tokens);
batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
} else {
batch_embd.set_position_normal(n_past, seq_id);
@@ -187,22 +172,22 @@ int32_t mtmd_helper_decode_image_chunk(
int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
LOG_INF("decoding %s batch %d/%d, n_tokens_batch = %d\n", name, i_batch+1, n_img_batches, n_tokens_batch);
int64_t t1 = ggml_time_ms();
int32_t ret = llama_decode(lctx, batch_embd_view);
if (ret != 0) {
LOG_ERR("failed to decode image\n");
LOG_ERR("failed to decode %s\n", name);
llama_set_causal_attn(lctx, true); // restore causal attn
return ret;
}
LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1);
i_batch++;
}
n_past += mtmd_image_tokens_get_n_pos(image_tokens);
n_past += mtmd_input_chunk_get_n_pos(chunk);
*new_n_past = n_past;
if (mtmd_decode_use_non_causal(ctx)) {
@@ -231,12 +216,14 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
while (i < n_tokens) { // split into batches
text_batch.n_tokens = 0; // clear the batch
for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
int32_t j = text_batch.n_tokens;
text_batch.token [j] = tokens[i];
text_batch.pos [j] = n_past++;
text_batch.n_seq_id[j] = 1;
text_batch.seq_id [j][0] = seq_id;
text_batch.logits [j] = false;
text_batch.n_tokens++;
text_batch.token [i] = tokens[i];
text_batch.pos [i] = n_past++;
text_batch.n_seq_id[i] = 1;
text_batch.seq_id [i][0] = seq_id;
text_batch.logits [i] = false;
}
bool is_last_token = (i == n_tokens);
if (logits_last && is_last_token) {
@@ -251,25 +238,25 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
*new_n_past += text_batch.n_tokens;
}
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE || chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
int64_t t0 = ggml_time_ms();
LOG_INF("encoding image or slice...\n");
LOG_INF("encoding %s slice...\n", name);
ret = mtmd_encode(ctx, image_tokens);
ret = mtmd_encode_chunk(ctx, chunk);
if (ret != 0) {
LOG_ERR("failed to encode image\n");
LOG_ERR("failed to encode %s slice\n", name);
llama_batch_free(text_batch);
return ret;
}
LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
float * embd = mtmd_get_output_embd(ctx);
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
if (ret != 0) {
LOG_ERR("failed to decode image\n");
LOG_ERR("failed to decode %s\n", name);
llama_batch_free(text_batch);
return ret;
}

View File

@@ -1,6 +1,7 @@
#include "clip.h"
#include "clip-impl.h"
#include "mtmd.h"
#include "mtmd-audio.h"
#include "llama.h"
@@ -19,17 +20,49 @@ struct mtmd_bitmap {
uint32_t ny;
std::vector<unsigned char> data;
std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
bool is_audio = false; // true if the bitmap is audio
};
struct mtmd_image_tokens_deleter {
void operator()(mtmd_image_tokens * val); // forward declaration
struct mtmd_image_tokens {
uint32_t nx; // number of tokens in x direction
uint32_t ny; // number of tokens in y direction
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
uint32_t n_tokens() const { return nx * ny; }
clip_image_f32_batch batch_f32; // preprocessed image patches
std::string id; // optional user-defined ID, useful for KV cache tracking
mtmd_image_tokens clone() {
return mtmd_image_tokens{
nx,
ny,
use_mrope_pos,
batch_f32.clone(),
id
};
}
};
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens>;
struct mtmd_audio_tokens {
uint32_t n_tokens; // number of tokens
clip_image_f32_batch batch_f32; // preprocessed mel spectrogram chunks
std::string id; // optional user-defined ID, useful for KV cache tracking
mtmd_audio_tokens clone() {
return mtmd_audio_tokens{
n_tokens,
batch_f32.clone(),
id
};
}
};
using mtmd_audio_tokens_ptr = std::unique_ptr<mtmd_audio_tokens>;
struct mtmd_input_chunk {
mtmd_input_chunk_type type;
std::vector<llama_token> tokens_text;
mtmd_image_tokens_ptr tokens_image;
mtmd_audio_tokens_ptr tokens_audio;
};
struct mtmd_input_chunks {
@@ -46,6 +79,10 @@ enum mtmd_slice_tmpl {
// TODO @ngxson : add support for idefics (SmolVLM)
};
const char * mtmd_default_marker() {
return "<__media__>";
}
mtmd_context_params mtmd_context_params_default() {
mtmd_context_params params;
params.use_gpu = true;
@@ -53,6 +90,7 @@ mtmd_context_params mtmd_context_params_default() {
params.n_threads = 4;
params.verbosity = GGML_LOG_LEVEL_INFO;
params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
params.media_marker = mtmd_default_marker();
return params;
}
@@ -63,7 +101,9 @@ struct mtmd_context {
bool print_timings;
int n_threads;
std::string image_marker;
std::string media_marker;
bool has_vision;
bool has_audio;
// for llava-uhd style models, we need special tokens in-between slices
// minicpmv calls them "slices", llama 4 calls them "tiles"
@@ -81,6 +121,9 @@ struct mtmd_context {
bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
// for whisper, we pre-calculate the mel filter bank
whisper_preprocessor::whisper_filters w_filters;
// TODO @ngxson : add timings
mtmd_context(const char * mmproj_fname,
@@ -89,8 +132,12 @@ struct mtmd_context {
text_model (text_model),
print_timings(ctx_params.print_timings),
n_threads (ctx_params.n_threads),
image_marker (ctx_params.image_marker)
media_marker (ctx_params.media_marker)
{
if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) {
throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
}
clip_context_params ctx_clip_params;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
@@ -99,7 +146,9 @@ struct mtmd_context {
throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
}
use_mrope = clip_is_qwen2vl(ctx_clip);
has_vision = clip_has_vision_encoder(ctx_clip);
has_audio = clip_has_audio_encoder(ctx_clip);
use_mrope = clip_is_qwen2vl(ctx_clip);
projector_type proj = clip_get_projector_type(ctx_clip);
int minicpmv_version = clip_is_minicpmv(ctx_clip);
@@ -146,6 +195,21 @@ struct mtmd_context {
tok_row_end_trail = true; // add trailing end-of-row token
ov_img_first = false; // overview image is last
}
if (proj == PROJECTOR_TYPE_ULTRAVOX) {
// TODO @ngxson : check if model n_mel is 128 or 80
w_filters = whisper_precalc_filters::get_128_bins();
}
// warning messages
if (proj == PROJECTOR_TYPE_LLAMA4) {
LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
" https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
}
if (has_audio) {
LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
" https://github.com/ggml-org/llama.cpp/pull/13623\n", __func__);
}
}
~mtmd_context() {
@@ -179,29 +243,6 @@ private:
}
};
struct mtmd_image_tokens_data {
clip_image_f32_batch batch_f32; // preprocessed image patches
};
struct mtmd_image_tokens {
uint32_t nx; // number of tokens in x direction
uint32_t ny; // number of tokens in y direction
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
uint32_t n_tokens() const { return nx * ny; }
clip_image_f32_batch batch_f32; // preprocessed image patches
std::string id; // optional user-defined ID, useful for KV cache tracking
mtmd_image_tokens clone() {
return mtmd_image_tokens{
nx,
ny,
use_mrope_pos,
batch_f32.clone(),
id
};
}
};
mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
const struct llama_model * text_model,
const struct mtmd_context_params ctx_params) {
@@ -247,59 +288,63 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
auto vocab = llama_model_get_vocab(ctx->text_model);
std::string prompt_modified(text->text);
std::string marker_modified(ctx->image_marker);
std::string marker_modified(ctx->media_marker);
projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
// for compatibility, we convert image marker to media marker
string_replace_all(prompt_modified, MTMD_DEFAULT_IMAGE_MARKER, ctx->media_marker);
// a bit hacky here, but works for now
// for some models, we need to add prefix and suffix to the image embeddings
if (clip_is_gemma3(ctx->ctx_clip)) {
// gemma 3
// <start_of_image> ... (image embeddings) ... <end_of_image>
marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<start_of_image>" + ctx->media_marker + "<end_of_image>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
marker_modified = "<fake_token_around_image><global-img>" + ctx->image_marker + "<fake_token_around_image>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<fake_token_around_image><global-img>" + ctx->media_marker + "<fake_token_around_image>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_PIXTRAL) {
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
marker_modified = ctx->image_marker + "[IMG_END]";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = ctx->media_marker + "[IMG_END]";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
// <|vision_start|> ... (image embeddings) ... <|vision_end|>
marker_modified = "<|vision_start|>" + ctx->image_marker + "<|vision_end|>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<|vision_start|>" + ctx->media_marker + "<|vision_end|>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_LLAMA4) {
// (more details in mtmd_context constructor)
marker_modified = "<|image_start|>" + ctx->image_marker + "<|image_end|>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<|image_start|>" + ctx->media_marker + "<|image_end|>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
// <img> ... (image embeddings) ... </img>
marker_modified = "<img>" + ctx->image_marker + "</img>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<img>" + ctx->media_marker + "</img>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
}
// llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
// for glm-edge, BOI and EOI token's embeddings are not present in the text model
std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
std::vector<std::string> parts = string_split_str(prompt_modified, ctx->media_marker);
output->entries.clear();
output->entries.reserve(parts.size());
size_t i_img = 0;
size_t i_bm = 0;
// utility for adding raw tokens
auto add_text_chunk = [&output](std::vector<llama_token> && tokens) {
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_TEXT,
std::move(tokens),
{},
nullptr, // image tokens
nullptr, // audio tokens
};
output->entries.emplace_back(std::move(chunk));
};
@@ -317,8 +362,9 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{},
{}, // text tokens
std::move(image_tokens),
nullptr, // audio tokens
};
chunks.emplace_back(std::move(chunk));
}
@@ -336,24 +382,36 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_TEXT,
std::move(tokens),
{},
nullptr, // image tokens
nullptr, // audio tokens
};
output->entries.emplace_back(std::move(chunk));
if (&parts.back() != &part) {
// add image token to middle of 2 parts
// only add image/audio tokens to middle of 2 parts
// therefore, we skip handling image/audio if this is the last part
if (&parts.back() == &part) {
continue;
}
if (i_img >= n_bitmaps) {
if (!bitmaps[i_bm]->is_audio) {
// handle image
if (i_bm >= n_bitmaps) {
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
return 1;
}
if (!ctx->has_vision) {
LOG_ERR("%s: error: model does not support vision input\n", __func__);
return 2;
}
// convert mtmd_bitmap to clip_image_u8
clip_image_u8_ptr img_u8(clip_image_u8_init());
img_u8->nx = bitmaps[i_img]->nx;
img_u8->ny = bitmaps[i_img]->ny;
img_u8->buf.resize(bitmaps[i_img]->data.size());
std::memcpy(img_u8->buf.data(), bitmaps[i_img]->data.data(), img_u8->nx * img_u8->ny * 3);
img_u8->nx = bitmaps[i_bm]->nx;
img_u8->ny = bitmaps[i_bm]->ny;
img_u8->buf.resize(bitmaps[i_bm]->data.size());
std::memcpy(img_u8->buf.data(), bitmaps[i_bm]->data.data(), img_u8->nx * img_u8->ny * 3);
// preprocess image
clip_image_f32_batch batch_f32;
@@ -370,7 +428,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
) {
// split batch into chunks of single images
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id);
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_bm]->id);
GGML_ASSERT(chunks.size() > 0);
auto ov_chunk = std::move(chunks.front());
@@ -446,7 +504,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
image_tokens->ny = 1;
}
image_tokens->batch_f32 = std::move(batch_f32);
image_tokens->id = bitmaps[i_img]->id; // optional
image_tokens->id = bitmaps[i_bm]->id; // optional
LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx);
LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny);
@@ -454,23 +512,101 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{},
{}, // text tokens
std::move(image_tokens),
nullptr, // audio tokens
};
output->entries.emplace_back(std::move(chunk));
}
i_img++; // move to next image
i_bm++; // move to next image
continue;
} else {
// handle audio
if (i_bm >= n_bitmaps) {
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
return 1;
}
if (!ctx->has_audio) {
LOG_ERR("%s: error: model does not support audio input\n", __func__);
return 2;
}
if (bitmaps[i_bm]->data.size() == 0) {
LOG_ERR("%s: error: empty audio data\n", __func__);
return 2;
}
// preprocess audio
GGML_ASSERT(ctx->w_filters.n_mel); // make sure we have filter preloaded
std::vector<whisper_preprocessor::whisper_mel> mel_spec_chunks;
const float * samples = (const float *)bitmaps[i_bm]->data.data();
size_t n_samples = bitmaps[i_bm]->data.size() / sizeof(float);
bool ok = whisper_preprocessor::preprocess_audio(samples, n_samples, ctx->w_filters, mel_spec_chunks);
if (!ok) {
LOG_ERR("Unable to preprocess audio\n");
return 2;
}
// consider each mel_spec as a separate audio chunk
// TODO: maybe support batching, but this may come with memory cost
for (auto & mel_spec : mel_spec_chunks) {
clip_image_f32_ptr mel_f32(clip_image_f32_init());
mel_f32->nx = mel_spec.n_len;
mel_f32->ny = mel_spec.n_mel;
mel_f32->buf = std::move(mel_spec.data);
size_t n_tokens = clip_n_output_tokens(ctx->ctx_clip, mel_f32.get());
clip_image_f32_batch batch_f32;
batch_f32.is_audio = true;
batch_f32.entries.push_back(std::move(mel_f32));
mtmd_audio_tokens_ptr audio_tokens(new mtmd_audio_tokens);
audio_tokens->n_tokens = n_tokens;
audio_tokens->batch_f32 = std::move(batch_f32);
audio_tokens->id = bitmaps[i_bm]->id; // optional
LOG_DBG("audio_tokens->n_tokens = %d\n", audio_tokens->n_tokens);
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_AUDIO,
{}, // text tokens
nullptr, // image tokens
std::move(audio_tokens),
};
output->entries.emplace_back(std::move(chunk));
}
i_bm++;
continue;
}
}
return 0;
}
static void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
if (image_tokens) {
delete image_tokens;
int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
LOG_WRN("mtmd_encode_chunk has no effect for text chunks\n");
return 0;
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
return mtmd_encode(ctx, chunk->tokens_image.get());
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
ctx->image_embd_v.resize(chunk->tokens_audio->n_tokens * n_mmproj_embd);
bool ok = clip_image_batch_encode(
ctx->ctx_clip,
ctx->n_threads,
&chunk->tokens_audio->batch_f32,
ctx->image_embd_v.data());
return ok ? 0 : 1;
}
LOG_ERR("mtmd_encode_chunk: unknown chunk type %d\n", (int)chunk->type);
return 1;
}
int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
@@ -516,8 +652,12 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) {
return ctx->use_mrope;
}
void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
mtmd_image_tokens_free(val);
bool mtmd_support_vision(mtmd_context * ctx) {
return ctx->has_vision;
}
bool mtmd_support_audio(mtmd_context * ctx) {
return ctx->has_audio;
}
// these 2 helpers below use internal clip_image_u8_ptr,
@@ -526,6 +666,15 @@ void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
// whichever library they want, and then use mtmd_bitmap_init() to create bitmap
mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len) {
if (audio_helpers::is_audio_file((const char *)buf, len)) {
std::vector<float> pcmf32;
if (!audio_helpers::decode_audio_from_buf(buf, len, COMMON_SAMPLE_RATE, pcmf32)) {
LOG_ERR("Unable to read WAV audio file from buffer\n");
return nullptr;
}
return mtmd_bitmap_init_from_audio(pcmf32.size(), pcmf32.data());
}
clip_image_u8_ptr img_u8(clip_image_u8_init());
bool ok = clip_image_load_from_bytes(buf, len, img_u8.get());
if (!ok) {
@@ -538,15 +687,26 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t
}
mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname) {
clip_image_u8_ptr img_u8(clip_image_u8_init());
bool ok = clip_image_load_from_file(fname, img_u8.get());
if (!ok) {
LOG_ERR("Unable to load image %s\n", fname);
std::vector<unsigned char> buf;
FILE * f = fopen(fname, "rb");
if (!f) {
LOG_ERR("Unable to open file %s: %s\n", fname, strerror(errno));
return nullptr;
}
uint32_t nx, ny;
unsigned char * data = clip_image_u8_get_data(img_u8.get(), &nx, &ny);
return mtmd_bitmap_init(nx, ny, data);
fseek(f, 0, SEEK_END);
long file_size = ftell(f);
fseek(f, 0, SEEK_SET);
buf.resize(file_size);
size_t n_read = fread(buf.data(), 1, file_size, f);
fclose(f);
if (n_read != (size_t)file_size) {
LOG_ERR("Failed to read entire file %s", fname);
return nullptr;
}
return mtmd_helper_bitmap_init_from_buf(buf.data(), buf.size());
}
//
@@ -567,6 +727,18 @@ mtmd_bitmap * mtmd_bitmap_init(uint32_t nx,
return bitmap;
}
mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples,
const float * data) {
mtmd_bitmap * bitmap = new mtmd_bitmap;
bitmap->nx = n_samples;
bitmap->ny = 1;
bitmap->is_audio = true;
size_t data_size = n_samples * sizeof(float);
bitmap->data.resize(data_size);
std::memcpy(bitmap->data.data(), data, data_size);
return bitmap;
}
uint32_t mtmd_bitmap_get_nx(const mtmd_bitmap * bitmap) {
return bitmap->nx;
}
@@ -579,6 +751,14 @@ const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) {
return bitmap->data.data();
}
size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap) {
return bitmap->data.size();
}
bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap) {
return bitmap->is_audio;
}
const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap) {
return bitmap->id.c_str();
}
@@ -642,17 +822,56 @@ const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chu
return nullptr;
}
size_t mtmd_input_chunk_get_n_tokens(const mtmd_input_chunk * chunk) {
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
return chunk->tokens_text.size();
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
return mtmd_image_tokens_get_n_tokens(chunk->tokens_image.get());
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
return chunk->tokens_audio->n_tokens;
} else {
GGML_ABORT("invalid chunk type");
}
}
llama_pos mtmd_input_chunk_get_n_pos(const mtmd_input_chunk * chunk) {
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
return chunk->tokens_text.size();
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
return mtmd_image_tokens_get_n_pos(chunk->tokens_image.get());
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
return chunk->tokens_audio->n_tokens;
} else {
GGML_ABORT("invalid chunk type");
}
}
const char * mtmd_input_chunk_get_id(const mtmd_input_chunk * chunk) {
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
return chunk->tokens_image->id.c_str();
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
return chunk->tokens_audio->id.c_str();
}
return nullptr;
}
mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk) {
mtmd_input_chunk * copy = new mtmd_input_chunk{
chunk->type,
chunk->tokens_text,
mtmd_image_tokens_ptr(),
nullptr,
nullptr,
};
if (chunk->tokens_image) {
// copy the image tokens
copy->tokens_image = mtmd_image_tokens_ptr(new mtmd_image_tokens());
*copy->tokens_image = chunk->tokens_image->clone();
}
if (chunk->tokens_audio) {
// copy the audio tokens
copy->tokens_audio = mtmd_audio_tokens_ptr(new mtmd_audio_tokens());
*copy->tokens_audio = chunk->tokens_audio->clone();
}
return copy;
}
@@ -700,7 +919,8 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
mtmd_input_chunk chunk_text{
MTMD_INPUT_CHUNK_TYPE_TEXT,
std::move(tokens_text),
{},
nullptr, // image tokens
nullptr, // audio tokens
};
chunks->entries.emplace_back(std::move(chunk_text));
@@ -712,8 +932,9 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
image_tokens->id = "image_1";
mtmd_input_chunk chunk_image{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{},
{}, // text tokens
std::move(image_tokens),
nullptr, // audio tokens
};
chunks->entries.emplace_back(std::move(chunk_image));

View File

@@ -39,6 +39,7 @@
# define MTMD_API
#endif
// deprecated marker, use mtmd_default_marker() instead
#define MTMD_DEFAULT_IMAGE_MARKER "<__image__>"
#ifdef __cplusplus
@@ -48,6 +49,7 @@ extern "C" {
enum mtmd_input_chunk_type {
MTMD_INPUT_CHUNK_TYPE_TEXT,
MTMD_INPUT_CHUNK_TYPE_IMAGE,
MTMD_INPUT_CHUNK_TYPE_AUDIO,
};
// opaque types
@@ -79,9 +81,12 @@ struct mtmd_context_params {
bool print_timings;
int n_threads;
enum ggml_log_level verbosity;
const char * image_marker;
const char * image_marker; // deprecated, use media_marker instead
const char * media_marker;
};
MTMD_API const char * mtmd_default_marker(void);
MTMD_API struct mtmd_context_params mtmd_context_params_default(void);
// initialize the mtmd context
@@ -98,18 +103,28 @@ MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
// whether the current model use M-RoPE for llama_decode
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
// whether the current model supports vision input
MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
// whether the current model supports audio input
MTMD_API bool mtmd_support_audio(mtmd_context * ctx);
// mtmd_bitmap
//
// length of data must be nx * ny * 3
// the data is in RGBRGBRGB... format
MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx,
uint32_t ny,
const unsigned char * data);
MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap);
MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap);
MTMD_API const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap);
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
// if bitmap is image:
// length of data must be nx * ny * 3
// the data is in RGBRGBRGB... format
// if bitmap is audio:
// length of data must be n_samples * sizeof(float)
// the data is in float format (PCM F32)
MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx, uint32_t ny, const unsigned char * data);
MTMD_API mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples, const float * data);
MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap);
MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap);
MTMD_API const unsigned char * mtmd_bitmap_get_data (const mtmd_bitmap * bitmap);
MTMD_API size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap);
MTMD_API bool mtmd_bitmap_is_audio (const mtmd_bitmap * bitmap);
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
// bitmap ID is optional, but useful for KV cache tracking
// these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data()
MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap);
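Hypothetical usage sketch (not part of the diff) for the bitmap constructors above; nx, ny, rgb and the PCM buffer stand in for caller-provided data, and mtmd_bitmap_set_id is the setter counterpart of the getter declared here.
#include "mtmd.h"
#include <cstdio>

static void example_bitmaps(uint32_t nx, uint32_t ny, const unsigned char * rgb,
                            size_t n_samples, const float * pcmf32) {
    mtmd_bitmap * img = mtmd_bitmap_init(nx, ny, rgb);                  // nx * ny * 3 bytes, RGBRGB...
    mtmd_bitmap * aud = mtmd_bitmap_init_from_audio(n_samples, pcmf32); // PCM F32 samples
    if (mtmd_bitmap_is_audio(aud)) {
        // for audio bitmaps, n_bytes == n_samples * sizeof(float)
        printf("audio bitmap: %zu bytes\n", mtmd_bitmap_get_n_bytes(aud));
    }
    mtmd_bitmap_set_id(img, "image-0"); // optional, e.g. a content hash for KV cache tracking
    mtmd_bitmap_free(img);
    mtmd_bitmap_free(aud);
}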
@@ -132,6 +147,11 @@ MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chu
MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type (const mtmd_input_chunk * chunk);
MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text (const mtmd_input_chunk * chunk, size_t * n_tokens_output);
MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk);
MTMD_API size_t mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
// returns nullptr for ID on text chunk
MTMD_API const char * mtmd_input_chunk_get_id (const mtmd_input_chunk * chunk);
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
MTMD_API llama_pos mtmd_input_chunk_get_n_pos (const mtmd_input_chunk * chunk);
// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
// you can move the chunk ownership to your own code by copying it
@@ -144,27 +164,28 @@ MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk);
//
// the instance will be constructed via mtmd_tokenize()
// it will be freed along with mtmd_input_chunk
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); // TODO: deprecate
MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens);
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// tokenize an input text prompt and an image
// the prompt must have the input image marker (default: "<__image__>") in it
// the marker will be replaced with the image tokens
// tokenize an input text prompt and a list of bitmaps (images/audio)
// the prompt must have the input image marker (default: "<__media__>") in it
// the default marker is defined by mtmd_default_marker()
// the marker will be replaced with the image/audio chunk
// for example:
// "here is an image: <__image__>\ndescribe it in detail."
// "here is an image: <__media__>\ndescribe it in detail."
// this will give 3 chunks:
// 1. "here is an image: <start_of_image>"
// 2. (image tokens)
// 2. (image/audio tokens)
// 3. "<end_of_image>\ndescribe it in detail."
// number of bitmaps must be equal to the number of image markers in the prompt
// number of bitmaps must be equal to the number of markers in the prompt
// this function is thread-safe (shared ctx)
// return values:
// 0 on success
// 1 on number of images not matching the number of markers
// 1 on number of bitmaps not matching the number of markers
// 2 on image preprocessing error
MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
mtmd_input_chunks * output,
@@ -173,9 +194,14 @@ MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
size_t n_bitmaps);
// returns 0 on success
// TODO: deprecate
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
const mtmd_image_tokens * image_tokens);
// returns 0 on success
MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
const mtmd_input_chunk * chunk);
// get output embeddings from the last encode pass
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
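A sketch (not part of the diff) of the tokenize-then-encode flow using the declarations above; mtmd_input_chunks_init() and the mtmd_input_text fields are assumed from the existing API and are not shown in this diff.
#include "mtmd.h"
#include <cstdio>
#include <string>

static int example_tokenize_encode(mtmd_context * ctx, const mtmd_bitmap * bmp) {
    std::string prompt = std::string("describe this: ") + mtmd_default_marker();
    mtmd_input_text text;                       // assumption: fields text/add_special/parse_special
    text.text          = prompt.c_str();
    text.add_special   = true;
    text.parse_special = true;
    mtmd_input_chunks * chunks = mtmd_input_chunks_init(); // assumption: chunks constructor
    const mtmd_bitmap * bitmaps[] = { bmp };
    if (mtmd_tokenize(ctx, chunks, &text, bitmaps, 1) != 0) {
        mtmd_input_chunks_free(chunks);
        return 1; // 1 = marker/bitmap count mismatch, 2 = preprocessing error
    }
    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
        const mtmd_input_chunk * chunk = mtmd_input_chunks_get(chunks, i);
        if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_TEXT) {
            if (mtmd_encode_chunk(ctx, chunk) != 0) {
                mtmd_input_chunks_free(chunks);
                return 1;
            }
            float * embd = mtmd_get_output_embd(ctx); // n_tokens * n_embd floats
            printf("encoded %zu media tokens (%p)\n",
                   mtmd_input_chunk_get_n_tokens(chunk), (void *) embd);
        }
    }
    mtmd_input_chunks_free(chunks);
    return 0;
}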
@@ -189,12 +215,16 @@ MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
//
// helper function to construct a mtmd_bitmap from a file
// it calls mtmd_helper_bitmap_init_from_buf() internally
// returns nullptr on failure
// this function is thread-safe
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname);
// helper function to construct a mtmd_bitmap from a buffer containing a file
// the file content must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.)
// supported formats:
// image: formats supported by stb_image: jpg, png, bmp, gif, etc.
// audio: formats supported by miniaudio: wav, mp3, flac
// note: audio files will be auto-detected based on magic bytes
// returns nullptr on failure
// this function is thread-safe
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len);
@@ -293,6 +323,7 @@ struct bitmap {
uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); }
uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); }
const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); }
size_t n_bytes() { return mtmd_bitmap_get_n_bytes(ptr.get()); }
std::string id() { return mtmd_bitmap_get_id(ptr.get()); }
void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); }
};
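Sketch (not part of the diff) of the C++ wrapper combined with the buffer helper, mirroring how the server loads uploaded media; the helper auto-detects audio by magic bytes, and n_bytes() sizes both image and audio payloads. `file` is a placeholder for caller-provided bytes.
#include "mtmd.h"
#include <cstdio>
#include <vector>

static bool example_load_media(const std::vector<unsigned char> & file) {
    mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
    if (!bmp.ptr) {
        return false; // neither a supported image nor a supported audio format
    }
    // nx*ny*3 bytes for images, n_samples*sizeof(float) for audio
    printf("loaded media: %zu bytes\n", bmp.n_bytes());
    return true;
}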

View File

@@ -936,7 +936,7 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
// Function to tokenize the prompt
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
const bool is_first = llama_kv_self_used_cells(llama_data.context.get()) == 0;
const bool is_first = llama_kv_self_seq_pos_max(llama_data.context.get(), 0) == 0;
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
prompt_tokens.resize(n_prompt_tokens);
@@ -952,7 +952,7 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
// Check if we have enough space in the context to evaluate this batch
static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
const int n_ctx = llama_n_ctx(ctx.get());
const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
const int n_ctx_used = llama_kv_self_seq_pos_max(ctx.get(), 0);
if (n_ctx_used + batch.n_tokens > n_ctx) {
printf(LOG_COL_DEFAULT "\n");
printe("context size exceeded\n");

Binary file not shown.

View File

@@ -1,3 +1,4 @@
#include "chat.h"
#include "utils.hpp"
#include "arg.h"
@@ -114,11 +115,11 @@ struct slot_params {
struct common_params_speculative speculative;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
json to_json() const {
std::vector<std::string> samplers;
@@ -176,7 +177,10 @@ struct slot_params {
{"grammar_lazy", sampling.grammar_lazy},
{"grammar_triggers", grammar_triggers},
{"preserved_tokens", sampling.preserved_tokens},
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
{"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
{"reasoning_format", (oaicompat_chat_syntax.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "deepseek" : "none")},
{"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
{"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
@@ -352,11 +356,14 @@ struct server_task {
{
auto it = data.find("chat_format");
if (it != data.end()) {
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format).c_str());
} else {
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
}
params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
}
{
@@ -396,7 +403,14 @@ struct server_task {
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
}
} else {
params.sampling.grammar_triggers.push_back(std::move(ct.value));
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
} else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
} else {
throw std::runtime_error("Unknown grammar trigger type");
}
params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
}
}
}
@@ -639,11 +653,12 @@ struct server_task_result_cmpl_final : server_task_result {
slot_params generation_params;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
return index;
@@ -738,47 +753,20 @@ struct server_task_result_cmpl_final : server_task_result {
json to_json_oaicompat_chat() {
std::string finish_reason = "length";
common_chat_msg msg;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
SRV_DBG("Parsing chat message: %s\n", content.c_str());
msg = common_chat_parse(content, oaicompat_chat_format);
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
if (!oaicompat_msg.empty()) {
msg = oaicompat_msg;
} else {
msg.role = "assistant";
msg.content = content;
}
json message {
{"role", "assistant"},
};
if (!msg.reasoning_content.empty()) {
message["reasoning_content"] = msg.reasoning_content;
}
if (msg.content.empty() && !msg.tool_calls.empty()) {
message["content"] = json();
} else {
message["content"] = msg.content;
}
if (!msg.tool_calls.empty()) {
auto tool_calls = json::array();
for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
{"function", {
{"name", tc.name},
{"arguments", tc.arguments},
}},
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
// We only generate a random id for the ones that don't generate one by themselves
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
{"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
});
}
message["tool_calls"] = tool_calls;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
}
json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"message", message},
{"message", msg.to_json_oaicompat<json>()},
};
if (!stream && probs_output.size() > 0) {
@@ -818,17 +806,35 @@ struct server_task_result_cmpl_final : server_task_result {
std::time_t t = std::time(0);
std::string finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
}
json choice = json {
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}
};
json deltas = json::array();
for (const auto & diff : oaicompat_msg_diffs) {
deltas.push_back({
{"choices", json::array({
json {
{"finish_reason", nullptr},
{"index", 0},
{"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
});
}
json ret = json {
{"choices", json::array({choice})},
deltas.push_back({
{"choices", json::array({
json {
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
@@ -839,18 +845,18 @@ struct server_task_result_cmpl_final : server_task_result {
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}},
};
});
if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
deltas.back().push_back({"timings", timings.to_json()});
}
// extra fields for debugging purposes
if (verbose) {
ret["__verbose"] = to_json_non_oaicompat();
if (verbose && !deltas.empty()) {
deltas.front()["__verbose"] = to_json_non_oaicompat();
}
return ret;
return deltas;
}
};
@@ -868,10 +874,11 @@ struct server_task_result_cmpl_partial : server_task_result {
result_timings timings;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
return index;
@@ -951,74 +958,54 @@ struct server_task_result_cmpl_partial : server_task_result {
}
json to_json_oaicompat_chat() {
bool first = n_decoded == 0;
bool first = n_decoded == 1;
std::time_t t = std::time(0);
json choices;
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json {
{"content", content}}}
}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}};
return std::vector<json>({initial_ret, second_ret});
}
} else {
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json {
{"content", content},
}},
}});
}
GGML_ASSERT(choices.size() >= 1);
if (prob_output.probs.size() > 0) {
choices[0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
json ret = json {
{"choices", choices},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}
std::vector<json> deltas;
auto add_delta = [&](const json & delta) {
deltas.push_back({
{"choices", json::array({
json {
{"finish_reason", nullptr},
{"index", 0},
{"delta", delta},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
});
};
if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
// We have to send an initial update to conform to openai behavior
if (first) {
add_delta({
{"role", "assistant"},
{"content", nullptr},
});
}
return std::vector<json>({ret});
for (const auto & diff : oaicompat_msg_diffs) {
add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
}
if (!deltas.empty()) {
GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
if (prob_output.probs.size() > 0) {
deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json {
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
if (timings.prompt_n >= 0) {
deltas[deltas.size() - 1].push_back({"timings", timings.to_json()});
}
}
return deltas;
}
};
@@ -1137,9 +1124,6 @@ struct server_task_result_metrics : server_task_result {
int n_tasks_deferred;
int64_t t_start;
int32_t kv_cache_tokens_count;
int32_t kv_cache_used_cells;
// TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
uint64_t n_prompt_tokens_processed_total = 0;
uint64_t t_prompt_processing_total = 0;
@@ -1179,9 +1163,6 @@ struct server_task_result_metrics : server_task_result {
{ "n_decode_total", n_decode_total },
{ "n_busy_slots_total", n_busy_slots_total },
{ "kv_cache_tokens_count", kv_cache_tokens_count },
{ "kv_cache_used_cells", kv_cache_used_cells },
{ "slots", slots_data },
};
}
@@ -1285,6 +1266,7 @@ struct server_slot {
std::string generated_text;
llama_tokens generated_tokens;
common_chat_msg chat_msg;
server_tokens cache_tokens;
@@ -1305,6 +1287,7 @@ struct server_slot {
llama_token sampled;
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::vector<std::string> generated_tool_call_ids;
// stats
size_t n_sent_text = 0; // number of sent text characters
@@ -1334,9 +1317,13 @@ struct server_slot {
n_past = 0;
n_sent_text = 0;
task_type = SERVER_TASK_TYPE_COMPLETION;
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
generated_tokens.clear();
generated_token_probs.clear();
chat_msg = {};
json_schema = json();
generated_tool_call_ids.clear();
// clear speculative decoding stats
n_draft_total = 0;
@@ -1416,6 +1403,21 @@ struct server_slot {
return timings;
}
const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
auto previous_msg = chat_msg;
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
auto new_msg = common_chat_parse(
generated_text,
/* is_partial= */ stop != STOP_TYPE_EOS,
params.oaicompat_chat_syntax);
if (!new_msg.empty()) {
new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
chat_msg = new_msg;
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
}
return chat_msg;
}
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
size_t stop_pos = std::string::npos;
@@ -1883,6 +1885,7 @@ struct server_context {
float slot_prompt_similarity = 0.0f;
common_chat_templates_ptr chat_templates;
oaicompat_parser_options oai_parser_opt;
~server_context() {
mtmd_free(mctx);
@@ -2078,6 +2081,15 @@ struct server_context {
}
metrics.init();
oai_parser_opt = {
/* use_jinja */ params_base.use_jinja,
/* prefill_assistant */ params_base.prefill_assistant,
/* reasoning_format */ params_base.reasoning_format,
/* common_chat_templates */ chat_templates.get(),
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
};
}
server_slot * get_slot_by_id(int id) {
@@ -2457,10 +2469,12 @@ struct server_context {
res->n_prompt_tokens = slot.n_prompt_tokens;
res->post_sampling_probs = slot.params.post_sampling_probs;
res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@@ -2481,7 +2495,7 @@ struct server_context {
res->id_slot = slot.id;
res->index = slot.index;
res->content = std::move(slot.generated_text);
res->content = slot.generated_text;
res->tokens = std::move(slot.generated_tokens);
res->timings = slot.get_timings();
res->prompt = slot.prompt_tokens.detokenize(ctx, true);
@@ -2501,7 +2515,8 @@ struct server_context {
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
@@ -2771,9 +2786,6 @@ struct server_context {
res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
res->t_start = metrics.t_start;
res->kv_cache_tokens_count = llama_kv_self_n_tokens(ctx);
res->kv_cache_used_cells = llama_kv_self_used_cells(ctx);
res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
res->t_prompt_processing_total = metrics.t_prompt_processing_total;
res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
@@ -3336,6 +3348,37 @@ struct server_context {
common_set_adapter_lora(ctx, slot_batched->lora);
}
const bool do_encode = (params_base.embedding || params_base.reranking);
// pad the batch so that batch.n_tokens >= n_slots
// TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
if (do_encode) {
const int n_slots = slots.size();
if (batch.n_tokens < n_slots) {
std::set<llama_seq_id> seq_ids;
for (int j = 0; j < batch.n_tokens; ++j) {
seq_ids.insert(batch.seq_id[j][0]);
}
// find unused sequence id
llama_seq_id seq_id = -1;
for (int i = 0; i < n_slots; ++i) {
if (seq_ids.find(i) == seq_ids.end()) {
seq_id = i;
}
}
const int n_add = n_slots - batch.n_tokens;
SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id);
for (int j = 0; j < n_add; ++j) {
common_batch_add(batch, 0, j, { seq_id }, false);
}
}
}
// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -3352,7 +3395,7 @@ struct server_context {
int ret = 0;
if (params_base.embedding || params_base.reranking) {
if (do_encode) {
ret = llama_encode(ctx, batch_view);
} else {
ret = llama_decode(ctx, batch_view);
@@ -3361,14 +3404,29 @@ struct server_context {
metrics.on_decoded(slots);
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size
SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
for (auto & slot : slots) {
slot.release();
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
{
std::string err;
if (n_batch == 1 && ret == 1) {
err = "Context size has been exceeded.";
}
if (ret == -1) {
err = "Invalid input batch.";
}
if (ret < -1) {
err = "Compute error.";
}
if (!err.empty()) {
SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
for (auto & slot : slots) {
slot.release();
send_error(slot, err);
}
break;
}
break; // break loop of n_batch
}
// retry with half the batch size to try to find a free slot in the KV cache
@@ -3702,6 +3760,7 @@ int main(int argc, char ** argv) {
"/health",
"/models",
"/v1/models",
"/api/tags"
};
// If API key is not set, skip validation
@@ -3740,7 +3799,7 @@ int main(int argc, char ** argv) {
if (req.path == "/" || tmp.back() == "html") {
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
res.status = 503;
} else if (req.path == "/models" || req.path == "/v1/models") {
} else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
// allow the models endpoint to be accessed during loading
return true;
} else {
@@ -3883,14 +3942,6 @@ int main(int argc, char ** argv) {
{"name", "predicted_tokens_seconds"},
{"help", "Average generation throughput in tokens/s."},
{"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.}
},{
{"name", "kv_cache_usage_ratio"},
{"help", "KV-cache usage. 1 means 100 percent usage."},
{"value", 1. * res_metrics->kv_cache_used_cells / params.n_ctx}
},{
{"name", "kv_cache_tokens"},
{"help", "KV-cache tokens."},
{"value", (uint64_t) res_metrics->kv_cache_tokens_count}
},{
{"name", "requests_processing"},
{"help", "Number of requests processing."},
@@ -4048,7 +4099,10 @@ int main(int argc, char ** argv) {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model.path },
{ "modalities", json{{"vision", ctx_server.mctx != nullptr}} }, // TODO: add more in the future
{ "modalities", json{
{"vision", ctx_server.oai_parser_opt.allow_image},
{"audio", ctx_server.oai_parser_opt.allow_audio},
} },
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
@@ -4086,6 +4140,19 @@ int main(int argc, char ** argv) {
{ "llama.context_length", ctx_server.slots.back().n_ctx, },
}
},
{"modelfile", ""},
{"parameters", ""},
{"template", common_chat_templates_source(ctx_server.chat_templates.get())},
{"details", {
{"parent_model", ""},
{"format", "gguf"},
{"family", ""},
{"families", {""}},
{"parameter_size", ""},
{"quantization_level", ""}
}},
{"model_info", ""},
{"capabilities", {"completion"}}
};
res_ok(res, data);
@@ -4126,10 +4193,10 @@ int main(int argc, char ** argv) {
for (auto & file : files) {
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
if (!bmp.ptr) {
throw std::runtime_error("Failed to load image");
throw std::runtime_error("Failed to load image or audio file");
}
// calculate bitmap hash (for KV caching)
std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
bmp.set_id(hash.c_str());
bitmaps.entries.push_back(std::move(bmp));
}
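Hashing `bmp.n_bytes()` instead of a hard-coded `nx*ny*3` matters because audio bitmaps are not RGB images, so their size cannot be derived from width and height. A sketch of the id computation, assuming `fnv_hash` is the usual 64-bit FNV-1a over the raw bytes (the constants below are that variant's, not copied from the server):

```python
def fnv1a_64(data: bytes) -> str:
    """64-bit FNV-1a, returned as a decimal string to serve as a media id."""
    h = 0xcbf29ce484222325
    for byte in data:
        h ^= byte
        h = (h * 0x100000001b3) & 0xFFFFFFFFFFFFFFFF
    return str(h)
```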
@@ -4361,7 +4428,7 @@ int main(int argc, char ** argv) {
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
};
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
const auto handle_chat_completions = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
LOG_DBG("request: %s\n", req.body.c_str());
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
@@ -4370,13 +4437,9 @@ int main(int argc, char ** argv) {
auto body = json::parse(req.body);
std::vector<raw_buffer> files;
json data = oaicompat_completion_params_parse(
json data = oaicompat_chat_params_parse(
body,
params.use_jinja,
params.prefill_assistant,
params.reasoning_format,
ctx_server.chat_templates.get(),
ctx_server.mctx,
ctx_server.oai_parser_opt,
files);
handle_completions_impl(
@@ -4389,16 +4452,12 @@ int main(int argc, char ** argv) {
};
// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
std::vector<raw_buffer> files; // dummy, unused
json data = oaicompat_completion_params_parse(
json data = oaicompat_chat_params_parse(
body,
params.use_jinja,
params.prefill_assistant,
params.reasoning_format,
ctx_server.chat_templates.get(),
ctx_server.mctx,
ctx_server.oai_parser_opt,
files);
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};
@@ -4411,6 +4470,28 @@ int main(int argc, char ** argv) {
}
json models = {
{"models", {
{
{"name", params.model_alias.empty() ? params.model.path : params.model_alias},
{"model", params.model_alias.empty() ? params.model.path : params.model_alias},
{"modified_at", ""},
{"size", ""},
{"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash
{"type", "model"},
{"description", ""},
{"tags", {""}},
{"capabilities", {"completion"}},
{"parameters", ""},
{"details", {
{"parent_model", ""},
{"format", "gguf"},
{"family", ""},
{"families", {""}},
{"parameter_size", ""},
{"quantization_level", ""}
}}
}
}},
{"object", "list"},
{"data", {
{
@@ -4420,7 +4501,7 @@ int main(int argc, char ** argv) {
{"owned_by", "llamacpp"},
{"meta", model_meta},
},
}}
}}
};
res_ok(res, models);
@@ -4748,11 +4829,13 @@ int main(int argc, char ** argv) {
svr->Post("/api/show", handle_api_show);
svr->Get ("/models", handle_models); // public endpoint (no API key check)
svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
svr->Get ("/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check)
svr->Post("/completion", handle_completions); // legacy
svr->Post("/completions", handle_completions);
svr->Post("/v1/completions", handle_completions_oai);
svr->Post("/chat/completions", handle_chat_completions);
svr->Post("/v1/chat/completions", handle_chat_completions);
svr->Post("/api/chat", handle_chat_completions); // ollama specific endpoint
svr->Post("/infill", handle_infill);
svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings);

View File

@@ -71,8 +71,14 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
})
content = ""
last_cmpl_id = None
for data in res:
for i, data in enumerate(res):
choice = data["choices"][0]
if i == 0:
# Check first role message for stream=True
assert choice["delta"]["content"] is None
assert choice["delta"]["role"] == "assistant"
else:
assert "role" not in choice["delta"]
assert data["system_fingerprint"].startswith("b")
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
if last_cmpl_id is None:
@@ -86,7 +92,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
assert choice["finish_reason"] == finish_reason
else:
assert choice["finish_reason"] is None
content += choice["delta"]["content"]
content += choice["delta"]["content"] or ''
def test_chat_completion_with_openai_library():
@@ -242,12 +248,19 @@ def test_chat_completion_with_timings_per_token():
"stream": True,
"timings_per_token": True,
})
for data in res:
assert "timings" in data
assert "prompt_per_second" in data["timings"]
assert "predicted_per_second" in data["timings"]
assert "predicted_n" in data["timings"]
assert data["timings"]["predicted_n"] <= 10
for i, data in enumerate(res):
if i == 0:
# Check first role message for stream=True
assert data["choices"][0]["delta"]["content"] is None
assert data["choices"][0]["delta"]["role"] == "assistant"
assert "timings" not in data, f'First event should not have timings: {data}'
else:
assert "role" not in data["choices"][0]["delta"]
assert "timings" in data
assert "prompt_per_second" in data["timings"]
assert "predicted_per_second" in data["timings"]
assert "predicted_n" in data["timings"]
assert data["timings"]["predicted_n"] <= 10
def test_logprobs():
@@ -295,17 +308,23 @@ def test_logprobs_stream():
)
output_text = ''
aggregated_text = ''
for data in res:
for i, data in enumerate(res):
choice = data.choices[0]
if choice.finish_reason is None:
if choice.delta.content:
output_text += choice.delta.content
assert choice.logprobs is not None
assert choice.logprobs.content is not None
for token in choice.logprobs.content:
aggregated_text += token.token
assert token.logprob <= 0.0
assert token.bytes is not None
assert token.top_logprobs is not None
assert len(token.top_logprobs) > 0
if i == 0:
# Check first role message for stream=True
assert choice.delta.content is None
assert choice.delta.role == "assistant"
else:
assert choice.delta.role is None
if choice.finish_reason is None:
if choice.delta.content:
output_text += choice.delta.content
assert choice.logprobs is not None
assert choice.logprobs.content is not None
for token in choice.logprobs.content:
aggregated_text += token.token
assert token.logprob <= 0.0
assert token.bytes is not None
assert token.top_logprobs is not None
assert len(token.top_logprobs) > 0
assert aggregated_text == output_text

View File

@@ -8,6 +8,7 @@ path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))
from utils import *
from enum import Enum
server: ServerProcess
@@ -20,7 +21,11 @@ def create_server():
server = ServerPreset.tinyllama2()
server.model_alias = "tinyllama-2-tool-call"
server.server_port = 8081
server.n_slots = 1
class CompletionMode(Enum):
NORMAL = "normal"
STREAMED = "streamed"
TEST_TOOL = {
"type":"function",
@@ -73,9 +78,8 @@ WEATHER_TOOL = {
}
}
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
@@ -86,13 +90,13 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
"parallel_tool_calls": False,
**kwargs,
})
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
# assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
assert expected_function_name == tool_call["function"]["name"]
actual_arguments = tool_call["function"]["arguments"]
@@ -102,12 +106,16 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,tool,argument_key", [
("google-gemma-2-2b-it", TEST_TOOL, "success"),
("google-gemma-2-2b-it", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
global server
n_predict = 1024
# server = ServerPreset.stories15m_moe()
@@ -115,31 +123,43 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
server.n_predict = n_predict
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,tool,argument_key", [
("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
# Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own.
# ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "codeFalse), True),
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
global server
n_predict = 512
# server = ServerPreset.stories15m_moe()
@@ -147,10 +167,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
server.n_predict = n_predict
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
(TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
@@ -184,9 +205,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
# (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
@@ -203,10 +224,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
(TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
n_predict = 512
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
@@ -219,7 +239,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
@@ -228,12 +248,12 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
"tool_choice": "required",
"tools": [tool],
"parallel_tool_calls": False,
"stream": stream == CompletionMode.STREAMED,
"temperature": 0.0,
"top_k": 1,
"top_p": 1.0,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
@@ -248,7 +268,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
@@ -258,26 +278,27 @@ def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int,
"tool_choice": tool_choice,
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
global server
server.jinja = True
server.n_predict = n_predict
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
("meetkai-functionary-medium-v3.2", 256, [], None),
("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
@@ -289,16 +310,17 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
global server
server.jinja = True
server.n_predict = n_predict
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
@@ -321,11 +343,11 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
@@ -339,10 +361,9 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
])
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
n_predict = 512
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
@@ -355,11 +376,11 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_weather(server, max_tokens=n_predict)
do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
def do_test_weather(server: ServerProcess, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"messages": [
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
{"role": "user", "content": "What is the weather in Istanbul?"},
@@ -367,14 +388,13 @@ def do_test_weather(server: ServerProcess, **kwargs):
"tools": [WEATHER_TOOL],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
location = actual_arguments["location"]
@@ -383,6 +403,7 @@ def do_test_weather(server: ServerProcess, **kwargs):
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
@@ -400,9 +421,8 @@ def do_test_weather(server: ServerProcess, **kwargs):
# (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192 * 2
server.n_predict = n_predict
@@ -415,11 +435,11 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_calc_result(server, result_override, n_predict)
do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
@@ -466,8 +486,7 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")
@@ -480,18 +499,18 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
@pytest.mark.slow
@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
(128, 'deepseek', "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(128, None, "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "I need to calculate the sum of 102 and 7[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'none', "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
@pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [
(128, 'deepseek', CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(128, None, CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', CompletionMode.NORMAL, "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>I need to calculate [\\s\\S]*?</think>To find the sum of [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', CompletionMode.NORMAL, "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>First, I [\\s\\S]*?</think>To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
# (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
])
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
server.n_slots = 1
server.reasoning_format = reasoning_format
server.jinja = True
server.n_ctx = 8192 * 2
@@ -505,14 +524,14 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "user", "content": "What's the sum of 102 and 7?"},
]
],
"stream": stream == CompletionMode.STREAMED,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")
@@ -529,6 +548,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
@@ -562,10 +582,9 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
])
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
n_predict = 512 # High because of DeepSeek R1
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
@@ -579,11 +598,11 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_hello_world(server, max_tokens=n_predict)
do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
def do_test_hello_world(server: ServerProcess, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"messages": [
{"role": "system", "content": "You are a tool-calling agent."},
{"role": "user", "content": "say hello world with python"},
@@ -591,16 +610,15 @@ def do_test_hello_world(server: ServerProcess, **kwargs):
"tools": [PYTHON_TOOL],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
code = actual_arguments["code"]
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}'

View File

@@ -30,6 +30,7 @@ def create_server():
("What is this:\n", "malformed", False, None),
("What is this:\n", "https://google.com/404", False, None), # non-existent image
("What is this:\n", "https://ggml.ai", False, None), # non-image data
# TODO @ngxson : test with multiple images, no images and with audio
]
)
def test_vision_chat_completion(prompt, image_url, success, re_content):

View File

@@ -294,6 +294,77 @@ class ServerProcess:
print("Partial response from server", json.dumps(data, indent=2))
yield data
def make_any_request(
self,
method: str,
path: str,
data: dict | None = None,
headers: dict | None = None,
timeout: float | None = None,
) -> dict:
stream = data.get('stream', False)
if stream:
content: list[str] = []
tool_calls: list[dict] = []
finish_reason: Optional[str] = None
content_parts = 0
tool_call_parts = 0
arguments_parts = 0
for chunk in self.make_stream_request(method, path, data, headers):
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
choice = chunk['choices'][0]
if choice['delta'].get('content') is not None:
assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
content.append(choice['delta']['content'])
content_parts += 1
if choice['delta'].get('finish_reason') is not None:
finish_reason = choice['delta']['finish_reason']
for tc in choice['delta'].get('tool_calls', []):
if 'function' not in tc:
raise ValueError(f"Expected function type, got {tc['type']}")
if tc['index'] >= len(tool_calls):
tool_calls.append(dict(
id="",
type="function",
function=dict(
name="",
arguments="",
)
))
tool_call = tool_calls[tc['index']]
if tc.get('id') is not None:
tool_call['id'] = tc['id']
fct = tc['function']
if fct.get('name') is not None:
tool_call['function']['name'] = fct['name']
if fct.get('arguments') is not None:
assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!'
tool_call['function']['arguments'] += fct['arguments']
print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
result = dict(
choices=[
dict(
index=0,
finish_reason=finish_reason,
message=dict(
role='assistant',
content=''.join(content) if content else None,
tool_calls=tool_calls if tool_calls else None,
),
)
],
)
print("Final response from server", json.dumps(result, indent=2))
return result
else:
response = self.make_request(method, path, data, headers, timeout=timeout)
assert response.status_code == 200, f"Server returned error: {response.status_code}"
return response.body
server_instances: Set[ServerProcess] = set()

View File

@@ -474,26 +474,6 @@ static std::string gen_tool_call_id() {
// other common utils
//
static bool ends_with(const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}
static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
if (!text.empty() && !stop.empty()) {
const char text_last_char = text.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const std::string current_partial = stop.substr(0, char_index + 1);
if (ends_with(text, current_partial)) {
return text.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
// TODO: reuse llama_detokenize
template <class Iter>
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
@@ -536,6 +516,7 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
// OAI utils
//
// used by /completions endpoint
static json oaicompat_completion_params_parse(const json & body) {
json llama_params;
@@ -580,31 +561,34 @@ static json oaicompat_completion_params_parse(const json & body) {
return llama_params;
}
static json oaicompat_completion_params_parse(
struct oaicompat_parser_options {
bool use_jinja;
bool prefill_assistant;
common_reasoning_format reasoning_format;
common_chat_templates * tmpls;
bool allow_image;
bool allow_audio;
};
// used by /chat/completions endpoint
static json oaicompat_chat_params_parse(
const json & body, /* openai api json semantics */
bool use_jinja,
bool prefill_assistant,
common_reasoning_format reasoning_format,
const struct common_chat_templates * tmpls,
bool allow_non_text,
const oaicompat_parser_options & opt,
std::vector<raw_buffer> & out_files)
{
json llama_params;
auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty();
auto stream = json_value(body, "stream", false);
auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
if (tools.is_array() && !tools.empty()) {
if (stream) {
throw std::runtime_error("Cannot use tools with stream");
}
if (!use_jinja) {
if (!opt.use_jinja) {
if (has_tools) {
throw std::runtime_error("tools param requires --jinja flag");
}
}
if (!use_jinja) {
if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
throw std::runtime_error("Unsupported param: tool_choice");
if (tool_choice != "auto") {
throw std::runtime_error("tool_choice param requires --jinja flag");
}
}
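With the "Cannot use tools with stream" guard removed, tool calls can now be requested with `stream: true`; `tools` and a non-`auto` `tool_choice` only require the server to run with `--jinja`. A client-side sketch using the `openai` package (the base URL, API key and model name are placeholders; llama-server serves whatever model it loaded):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key-needed")

weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}

stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "What is the weather in Istanbul?"}],
    tools=[weather_tool],
    stream=True,  # previously rejected with "Cannot use tools with stream"
)
for chunk in stream:
    for tc in chunk.choices[0].delta.tool_calls or []:
        if tc.function is not None:
            # the name typically arrives once; arguments arrive as partial JSON fragments
            print(tc.index, tc.function.name, tc.function.arguments)
```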
@@ -667,12 +651,12 @@ static json oaicompat_completion_params_parse(
for (auto & p : content) {
std::string type = json_value(p, "type", std::string());
json image_url = json_value(p, "image_url", json::object());
if (type == "image_url") {
if (!allow_non_text) {
throw std::runtime_error("image input is not supported by this server");
if (!opt.allow_image) {
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
json image_url = json_value(p, "image_url", json::object());
std::string url = json_value(image_url, "url", std::string());
if (string_starts_with(url, "http")) {
// download remote image
@@ -710,8 +694,31 @@ static json oaicompat_completion_params_parse(
// replace this chunk with a marker
p["type"] = "text";
p["text"] = MTMD_DEFAULT_IMAGE_MARKER;
p["text"] = mtmd_default_marker();
p.erase("image_url");
} else if (type == "input_audio") {
if (!opt.allow_audio) {
throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
json input_audio = json_value(p, "input_audio", json::object());
std::string data = json_value(input_audio, "data", std::string());
std::string format = json_value(input_audio, "format", std::string());
// while we also support flac, we don't allow it here so that we match the OAI spec
if (format != "wav" && format != "mp3") {
throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'");
}
auto decoded_data = base64_decode(data); // expected to be base64 encoded
out_files.push_back(decoded_data);
// replace this chunk with a marker
p["type"] = "text";
p["text"] = mtmd_default_marker();
p.erase("input_audio");
} else if (type != "text") {
throw std::runtime_error("unsupported content[].type");
}
}
}
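Audio content parts follow the OpenAI `input_audio` shape enforced above: base64 data plus a `wav` or `mp3` format tag, which the parser then replaces with the default media marker before templating. A minimal request sketch (server URL and file name are placeholders):

```python
import base64
import requests

with open("question.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode()

resp = requests.post("http://localhost:8080/v1/chat/completions", json={
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is being said in this recording?"},
            {"type": "input_audio", "input_audio": {"data": audio_b64, "format": "wav"}},
        ],
    }],
})
print(resp.json()["choices"][0]["message"]["content"])
```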
@@ -719,21 +726,19 @@ static json oaicompat_completion_params_parse(
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = common_chat_tools_parse_oaicompat(tools);
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
inputs.grammar = grammar;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.use_jinja = use_jinja;
inputs.use_jinja = opt.use_jinja;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.reasoning_format = opt.reasoning_format;
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
// if the assistant message appears at the end of list, we do not add end-of-turn token
// for ex. this can be useful to modify the reasoning process in reasoning models
bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && prefill_assistant;
bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant;
common_chat_msg last_message;
if (prefill_assistant_message) {
last_message = inputs.messages.back();
@@ -744,12 +749,13 @@ static json oaicompat_completion_params_parse(
throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
}
inputs.extract_reasoning = false;
/* TODO: test this properly */
inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
inputs.add_generation_prompt = true;
}
// Apply chat template to the list of messages
auto chat_params = common_chat_templates_apply(tmpls, inputs);
auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);
/* Append assistant prefilled message */
if (prefill_assistant_message) {
@@ -769,6 +775,7 @@ static json oaicompat_completion_params_parse(
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
llama_params["thinking_forced_open"] = chat_params.thinking_forced_open;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}
@@ -782,6 +789,9 @@ static json oaicompat_completion_params_parse(
// Handle "logprobs" field
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
if (json_value(body, "logprobs", false)) {
if (has_tools && stream) {
throw std::runtime_error("logprobs is not supported with tools + stream");
}
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
@@ -1040,7 +1050,7 @@ struct server_tokens {
private: // disallow accessing these members directly, risking out-of-sync
// map a **start** position in tokens to the image chunk
std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_image;
std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
// list of tokens
// it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
@@ -1051,7 +1061,7 @@ private: // disallow accessing these members directly, risking out-of-sync
// for ex. with input of 5 text tokens and 2 images:
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
// pos 0 1 2 3 4 5 6 7 8 9
// map_pos_to_image will contain: {5, img0}, {8, img1}
// map_pos_to_media will contain: {5, img0}, {8, img1}
public:
server_tokens() = default;
@@ -1090,15 +1100,15 @@ public:
}
oss << "\n";
oss << "image pos: ";
for (const auto & it : map_pos_to_image) {
for (const auto & it : map_pos_to_media) {
oss << it.first << ", ";
}
return oss.str();
}
const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
auto it = map_pos_to_image.find(pos);
if (it != map_pos_to_image.end()) {
auto it = map_pos_to_media.find(pos);
if (it != map_pos_to_media.end()) {
return it->second;
} else {
throw std::runtime_error("Chunk not found");
@@ -1115,16 +1125,15 @@ public:
// will create a copy of the chunk if it contains non-text data
void push_back(const mtmd_input_chunk * chunk) {
auto type = mtmd_input_chunk_get_type(chunk);
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
GGML_ASSERT(has_mtmd);
auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk);
const int n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
llama_pos start_pos = tokens.size();
for (int i = 0; i < n_pos; ++i) {
tokens.emplace_back(LLAMA_TOKEN_NULL);
}
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
map_pos_to_image[start_pos] = std::move(new_chunk);
map_pos_to_media[start_pos] = std::move(new_chunk);
} else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
size_t n_tokens;
auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
@@ -1169,6 +1178,9 @@ public:
void keep_first(size_t n) {
GGML_ASSERT(n <= tokens.size());
if (has_mtmd) {
if (n == tokens.size()) {
return; // nothing to do
}
// we throw an error if we try to remove a token in the middle of an image
// for ex. with input of 5 text tokens and 2 images:
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
@@ -1183,10 +1195,10 @@ public:
}
}
// remove all image chunks that are not used anymore
for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) {
for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) {
llama_pos pos = it->first;
if (pos >= (llama_pos)n) {
it = map_pos_to_image.erase(it);
it = map_pos_to_media.erase(it);
} else {
++it;
}
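The early return above covers the no-op case; otherwise `keep_first` truncates the token list and drops any media chunk whose start position falls outside the kept prefix, refusing cuts that would land inside a chunk. A rough Python stand-in for that bookkeeping (a list plus a dict replace the server types; the mid-chunk check is an approximation, not the server's exact code):

```python
def keep_first(tokens: list, media: dict[int, object], n: int):
    """Keep the first n positions; drop media chunks starting at or past n.

    tokens holds text tokens plus None placeholders at media positions, and
    media maps a chunk's start position to the chunk (cf. map_pos_to_media).
    """
    if n == len(tokens):
        return tokens, media  # nothing to do
    if tokens[n] is None and n not in media:
        # position n sits inside a chunk that started earlier: refuse the cut
        raise ValueError("cannot truncate in the middle of a media chunk")
    kept_media = {pos: chunk for pos, chunk in media.items() if pos < n}
    return tokens[:n], kept_media
```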
@@ -1217,14 +1229,12 @@ public:
const auto & a_chunk = find_chunk(i);
const auto & b_chunk = b.find_chunk(i);
GGML_ASSERT(a_chunk && b_chunk);
const auto * a_img = mtmd_input_chunk_get_tokens_image(a_chunk.get());
const auto * b_img = mtmd_input_chunk_get_tokens_image(b_chunk.get());
std::string ai_id = mtmd_image_tokens_get_id(a_img);
std::string bi_id = mtmd_image_tokens_get_id(b_img);
size_t a_pos = mtmd_image_tokens_get_n_pos(a_img);
size_t b_pos = mtmd_image_tokens_get_n_pos(b_img);
std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get());
std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get());
size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get());
size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get());
if (ai_id == bi_id && a_pos == b_pos) {
GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen
GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
i += a_pos - 1; // will be +1 by the for loop
continue;
} else {
@@ -1250,8 +1260,7 @@ public:
if (t == LLAMA_TOKEN_NULL) {
try {
const auto & chunk = find_chunk(i);
const auto * img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get());
size_t n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
i += n_pos - 1; // will be +1 by the for loop
} catch (const std::exception & e) {
return false;
@@ -1270,22 +1279,21 @@ public:
llama_pos n_past,
int32_t seq_id,
llama_pos & n_pos_out) {
auto it = map_pos_to_image.find(n_past);
if (it == map_pos_to_image.end()) {
throw std::runtime_error("Chunk not found");
}
SRV_INF("%s\n", "processing image...");
auto & chunk = find_chunk(n_past);
const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
? "image" : "audio";
SRV_INF("processing %s...\n", name);
int32_t n_batch = llama_n_batch(ctx);
int64_t t0 = ggml_time_ms();
llama_pos new_n_past = n_past;
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
it->second.get(), // chunk
chunk.get(),
n_past,
seq_id,
n_batch,
true, // logits last
&new_n_past);
SRV_INF("image processed in %" PRId64 " ms\n", ggml_time_ms() - t0);
SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
if (result != 0) {
LOG_ERR("mtmd_helper_eval failed with status %d", result);
n_pos_out = n_past;

View File

@@ -1,4 +1,8 @@
import { DocumentTextIcon, XMarkIcon } from '@heroicons/react/24/outline';
import {
DocumentTextIcon,
SpeakerWaveIcon,
XMarkIcon,
} from '@heroicons/react/24/outline';
import { MessageExtra } from '../utils/types';
import { useState } from 'react';
import { classNames } from '../utils/misc';
@@ -66,7 +70,11 @@ export default function ChatInputExtraContextItem({
className="w-14 h-14 flex items-center justify-center"
aria-description="Document icon"
>
<DocumentTextIcon className="h-8 w-14 text-base-content/50" />
{item.type === 'audioFile' ? (
<SpeakerWaveIcon className="h-8 w-8 text-gray-500" />
) : (
<DocumentTextIcon className="h-8 w-8 text-gray-500" />
)}
</div>
<div className="text-xs pr-4">
@@ -98,6 +106,19 @@ export default function ChatInputExtraContextItem({
src={showingItem.base64Url}
alt={`Preview image for ${showingItem.name}`}
/>
) : showingItem.type === 'audioFile' ? (
<audio
controls
className="w-full"
aria-description={`Audio file ${showingItem.name}`}
>
<source
src={`data:${showingItem.mimeType};base64,${showingItem.base64Data}`}
type={showingItem.mimeType}
aria-description={`Audio file ${showingItem.name}`}
/>
Your browser does not support the audio element.
</audio>
) : (
<div className="overflow-x-auto">
<pre className="whitespace-pre-wrap break-words text-sm">

View File

@@ -278,6 +278,13 @@ export default function ChatScreen() {
function ServerInfo() {
const { serverProps } = useAppContext();
const modalities = [];
if (serverProps?.modalities?.audio) {
modalities.push('audio');
}
if (serverProps?.modalities?.vision) {
modalities.push('vision');
}
return (
<div
className="card card-sm shadow-sm border-1 border-base-content/20 text-base-content/70 mb-6"
@@ -291,6 +298,13 @@ function ServerInfo() {
<br />
<b>Build</b>: {serverProps?.build_info}
<br />
{modalities.length > 0 ? (
<>
<b>Supported modalities:</b> {modalities.join(', ')}
</>
) : (
''
)}
</p>
</div>
</div>

View File

@@ -11,6 +11,7 @@ pdfjs.GlobalWorkerOptions.workerSrc = pdfjsWorkerSrc;
// This file handles uploading extra context items (a.k.a files)
// It allows processing these kinds of files:
// - image files (converted to base64)
// - audio files (converted to base64)
// - text files (including code files)
// - pdf (converted to text)
@@ -41,96 +42,73 @@ export function useChatExtraContext(): ChatExtraContextApi {
const isSupportVision = serverProps?.modalities?.vision;
const onFileAdded = (files: File[]) => {
for (const file of files) {
const mimeType = file.type;
console.debug({ mimeType, file });
if (file.size > 10 * 1024 * 1024) {
toast.error('File is too large. Maximum size is 10MB.');
break;
}
if (mimeType.startsWith('image/')) {
if (!isSupportVision) {
toast.error('Multimodal is not supported by this server or model.');
const onFileAdded = async (files: File[]) => {
try {
for (const file of files) {
const mimeType = file.type;
if (file.size > 10 * 1024 * 1024) {
toast.error('File is too large. Maximum size is 10MB.');
break;
}
const reader = new FileReader();
reader.onload = async (event) => {
if (event.target?.result) {
let base64Url = event.target.result as string;
if (mimeType === 'image/svg+xml') {
// Convert SVG to PNG
base64Url = await svgBase64UrlToPngDataURL(base64Url);
}
if (mimeType.startsWith('image/')) {
if (!isSupportVision) {
toast.error('Multimodal is not supported by this server or model.');
break;
}
addItems([
{
let base64Url = await getFileAsBase64(file);
if (mimeType === 'image/svg+xml') {
// Convert SVG to PNG
base64Url = await svgBase64UrlToPngDataURL(base64Url);
}
addItems([
{
type: 'imageFile',
name: file.name,
base64Url,
},
]);
} else if (mimeType.startsWith('video/')) {
toast.error('Video files are not supported yet.');
break;
} else if (mimeType.startsWith('audio/')) {
if (!/mpeg|wav/.test(mimeType)) {
toast.error('Only mp3 and wav audio files are supported.');
break;
}
// plain base64, not a data URL
const base64Data = await getFileAsBase64(file, false);
addItems([
{
type: 'audioFile',
name: file.name,
mimeType,
base64Data,
},
]);
} else if (mimeType.startsWith('application/pdf')) {
if (config.pdfAsImage && !isSupportVision) {
toast(
'Multimodal is not supported, PDF will be converted to text instead of image.'
);
break;
}
if (config.pdfAsImage && isSupportVision) {
// Convert PDF to images
const base64Urls = await convertPDFToImage(file);
addItems(
base64Urls.map((base64Url) => ({
type: 'imageFile',
name: file.name,
base64Url,
},
]);
}
};
reader.readAsDataURL(file);
} else if (
mimeType.startsWith('video/') ||
mimeType.startsWith('audio/')
) {
toast.error('Video and audio files are not supported yet.');
break;
} else if (mimeType.startsWith('application/pdf')) {
if (config.pdfAsImage && !isSupportVision) {
toast(
'Multimodal is not supported, PDF will be converted to text instead of image.'
);
break;
}
const promise =
config.pdfAsImage && isSupportVision
? convertPDFToImage(file).then((base64Urls) => {
addItems(
base64Urls.map((base64Url) => ({
type: 'imageFile',
name: file.name,
base64Url,
}))
);
})
: convertPDFToText(file).then((content) => {
if (isSupportVision) {
toast.success(
'PDF file converted to text. You can also convert it to image, see in Settings.'
);
}
addItems([
{
type: 'textFile',
name: file.name,
content,
},
]);
});
promise.catch((error) => {
console.error(error);
toast.error('Failed to parse PDF file.');
});
break;
} else {
// Because there can be many text file types (like code file), we will not check the mime type
// and will just check if the file is not binary.
const reader = new FileReader();
reader.onload = (event) => {
if (event.target?.result) {
const content = event.target.result as string;
if (!isLikelyNotBinary(content)) {
toast.error('File is binary. Please upload a text file.');
return;
}
}))
);
} else {
// Convert PDF to text
const content = await convertPDFToText(file);
addItems([
{
type: 'textFile',
@@ -138,10 +116,40 @@ export function useChatExtraContext(): ChatExtraContextApi {
content,
},
]);
if (isSupportVision) {
toast.success(
'PDF file converted to text. You can also convert it to image, see in Settings.'
);
}
}
};
reader.readAsText(file);
break;
} else {
// Because there can be many text file types (like code file), we will not check the mime type
// and will just check if the file is not binary.
const reader = new FileReader();
reader.onload = (event) => {
if (event.target?.result) {
const content = event.target.result as string;
if (!isLikelyNotBinary(content)) {
toast.error('File is binary. Please upload a text file.');
return;
}
addItems([
{
type: 'textFile',
name: file.name,
content,
},
]);
}
};
reader.readAsText(file);
}
}
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
const errorMessage = `Error processing file: ${message}`;
toast.error(errorMessage);
}
};
@@ -154,6 +162,25 @@ export function useChatExtraContext(): ChatExtraContextApi {
};
}
async function getFileAsBase64(file: File, outputUrl = true): Promise<string> {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = (event) => {
if (event.target?.result) {
let result = event.target.result as string;
if (!outputUrl) {
// strip the data-URL prefix, keeping only the raw base64 characters
result = result.substring(result.indexOf(',') + 1);
}
resolve(result);
} else {
reject(new Error('Failed to read file.'));
}
};
reader.readAsDataURL(file);
});
}
async function getFileAsBuffer(file: File): Promise<ArrayBuffer> {
return new Promise((resolve, reject) => {
const reader = new FileReader();

View File

@@ -89,6 +89,14 @@ export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
type: 'image_url',
image_url: { url: extra.base64Url },
});
} else if (extra.type === 'audioFile') {
contentArr.push({
type: 'input_audio',
input_audio: {
data: extra.base64Data,
format: /wav/.test(extra.mimeType) ? 'wav' : 'mp3',
},
});
} else {
throw new Error('Unknown extra type');
}

Some files were not shown because too many files have changed in this diff Show More