Compare commits

...

27 Commits
b7248 ... b7275

Author SHA1 Message Date
Gabe Goodhart
bde188d60f metal: TRI, FILL, EXPM1, SOFTPLUS (#16623)
* feat(wip): Port initial TRI impl from pervious work

The kernel does not work and is not optimized, but the
code compiles and runs, so this will be the starting point
now that the core op has been merged.

Branch: ggml-cumsum-tri

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove argument for constant val override

This was added in the original draft, but later removed. With this, the
kernel now passes tests.

Branch: ggml-cumsum-tri

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Move the ttype conditional to templating to avoid conditional in kernel

Branch: ggml-cumsum-tri

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Type fixes

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* feat: Add softplus for metal

Branch: ggml-cumsum-tri

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add EXPM1 for metal

Branch: ggml-cumsum-tri

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add FILL for metal

Branch: ggml-cumsum-tri

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Branchless version of tri using _ggml_vec_tri_cmp as a mask

Branch: ggml-cumsum-tri

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove unused arguments

Branch: ggml-cumsum-tri

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Use select instead of branch for softplus non-vec

Branch: ggml-cumsum-tri

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-12-04 19:12:19 +02:00
Xuan-Son Nguyen
9d0229967a server: strip content-length header on proxy (#17734) 2025-12-04 16:32:57 +01:00
Xuan-Son Nguyen
c4c10bfb86 server: move msg diffs tracking to HTTP thread (#17740)
* server: move msg diffs tracking to HTTP thread

* wip

* tool call tests ok

* minor : style

* cont : fix

* move states to server_response_reader

* add safe-guard

* fix

* fix 2

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-12-04 15:46:08 +01:00
Daniel Bevenius
817d743cc1 examples : add missing code block end marker [no ci] (#17756)
This commit adds the missing code block end marker in simple-cmake-pkg
to correct the formatting.
2025-12-04 14:17:30 +01:00
Daniel Bevenius
bd4ef13476 common : skip model validation when --help is requested (#17755)
This commit skips the model validation check when the user specifies the
--help option.

The motivation for this is that currently and error is thrown before the
--help could be processed. Now skips validation if params.usage is set,
allowing help to display without requiring --model.

Resolves: https://github.com/ggml-org/llama.cpp/issues/17754
2025-12-04 13:36:50 +01:00
Alberto Cabrera Pérez
87a2084c45 ggml-cpu : remove asserts always evaluating to false (#17728) 2025-12-04 13:16:38 +01:00
SmartestWashingMachine
3659aa28e9 convert: use existing local chat_template if mistral-format model has one. (#17749)
* conversion: use existing local chat_template.jinja file if mistral-format model has one.

* fix --mistral-format mistakenly assuming some <=v7 chat template names are file paths and reading them.

* Update convert_hf_to_gguf.py - change from exists() to is_file()

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-12-04 12:12:45 +01:00
Adrien Gallouët
2a73f81f8a cmake : simplify build info detection using standard variables (#17423)
The current approach has several drawbacks. Mostly, when
cross-compiling, invoking the compiler binary directly to query the
machine hardware can behave unexpectedly depending on the toolchain
wrapper (using COMPILER_TARGET, CFLAGS, etc).

As CMake is the official tool to build llama.cpp, I propose to only rely
on it to get those variables (`CMAKE_SYSTEM_NAME` and
`CMAKE_SYSTEM_PROCESSOR`).

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-12-04 12:42:13 +02:00
Sigbjørn Skjæret
7dba049b07 ci : disable ggml-ci-x64-amd-* (#17753) 2025-12-04 11:25:08 +01:00
Adrien Gallouët
83c1171529 common: use native MultiByteToWideChar (#17738)
`std::codecvt_utf8<wchar_t>` is deprecated and produces warnings:

    common/common.cpp:792:31: warning: 'codecvt_utf8<wchar_t>' is deprecated [-Wdeprecated-declarations]
      792 |     std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
          |

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-12-04 12:06:49 +02:00
Georgi Gerganov
0d1324856f metal : use params per pipeline instance (#17739) 2025-12-04 10:34:11 +02:00
Georgi Gerganov
a67ef0f47f llama : fix sanity checks during quantization (#17721) 2025-12-04 10:33:42 +02:00
Adrien Gallouët
ef75a89fdb build : move _WIN32_WINNT definition to headers (#17736)
Previously, cmake was forcing `_WIN32_WINNT=0x0A00` for MinGW builds,
This caused "macro redefined" warnings with toolchains that define the version.

This also removes the `GGML_WIN_VER` variable as it is no longer needed.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-12-04 07:04:02 +01:00
Jeff Bolz
d8b5cdc4fe build: enable parallel builds in msbuild using MTT (#17708)
* build: enable parallel builds in msbuild using MTT

* check LLAMA_STANDALONE
2025-12-03 22:42:29 -06:00
Herman Semenoff
dea9ba27cb ggml-cpu: remove duplicate conditional check 'iid' (#17650) 2025-12-04 05:03:19 +08:00
Piotr Wilkin (ilintar)
c6d1a00aa7 Add a couple of file types to the text section (#17670)
* Add a couple of file types to the text section

* Format + regenerate index

* Rebuild after rebase
2025-12-03 21:45:06 +01:00
SmartestWashingMachine
424c579455 convert : support latest mistral-common (fix conversion with --mistral-format) (#17712)
* fix convert_hf_to_gguf.py failing with --mistral-format using later mistral-common versions.

* use get_one_valid_tokenizer_file from mistral-common if available and fallback to old logic otherwise.

* use file name instead of file path for get_one_valid_tokenizer_file.

* fix --mistral-format tokenizer file failing for tokenizers in subdirectories.

* move get_one_valid_tokenizer_file import to avoid nested try-except.
2025-12-03 21:15:04 +01:00
Aleksander Grygier
e9f9483464 Use OpenAI-compatible /v1/models endpoint by default (#17689)
* refactor: Data fetching via stores

* chore: update webui build output

* refactor: Use OpenAI compat `/v1/models` endpoint by default to list models

* chore: update webui build output

* chore: update webui build output
2025-12-03 20:49:09 +01:00
Andika Wasisto
41c5e02f42 webui: Fix zero pasteLongTextToFileLen to disable conversion being overridden (#17445)
* webui: Fix zero pasteLongTextToFileLen to disable conversion being overridden

Zero pasteLongTextToFileLen should disable the conversion, but it was
overwritten with 2500.

* Apply suggestions from code review

* Update webui build
2025-12-03 20:45:17 +01:00
Johannes Gäßler
2e1c9cd814 CUDA: generalized (mma) FA, add Volta support (#17505)
* CUDA: generalized (mma) FA, add Volta support

* use struct for MMA FA kernel config

---------

Co-authored-by: Aman Gupta <aman>
2025-12-03 16:57:05 +01:00
Georgi Gerganov
190c4838bd chat : reserve memory in compute_diffs and improve naming (#17729) 2025-12-03 17:22:10 +02:00
Pascal
e7c2cf1356 server: add router multi-model tests (#17704) (#17722)
* llama-server: add router multi-model tests (#17704)

Add 4 test cases for model router:
- test_router_unload_model: explicit model unloading
- test_router_models_max_evicts_lru: LRU eviction with --models-max
- test_router_no_models_autoload: --no-models-autoload flag behavior
- test_router_api_key_required: API key authentication

Tests use async model loading with polling and graceful skip when
insufficient models available for eviction testing.

utils.py changes:
- Add models_max, models_dir, no_models_autoload attributes to ServerProcess
- Handle JSONDecodeError for non-JSON error responses (fallback to text)

* llama-server: update test models to new HF repos

* add offline

* llama-server: fix router LRU eviction test and add preloading

Fix eviction test: load 2 models first, verify state, then load
3rd to trigger eviction. Previous logic loaded all 3 at once,
causing first model to be evicted before verification could occur.

Add module fixture to preload models via ServerPreset.load_all()
and mark test presets as offline to use cached models

* llama-server: fix split model download on Windows

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2025-12-03 15:10:37 +01:00
Adrien Gallouët
1257491047 server : fix bad fmt, size() is a size_type (#17735)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-12-03 15:47:22 +02:00
Adrien Gallouët
083e18b11c cmake: explicitly link against crypt32 on non-MSVC Windows builds (#17727)
Some toolchains do not support linking via pragmas such as:

    #pragma comment(lib, "crypt32.lib")

so we need to add the library explicitly.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-12-03 15:47:02 +02:00
Georgi Gerganov
3d94e967a1 metal : fix data race in pipeline library (#17731) 2025-12-03 14:03:40 +02:00
jiahao su
7feb0a1005 ci : remove the build of openeuler-cann in release (#17724)
* Remove the build of openeuler-cann in release

* Remove the relevant release files
2025-12-03 12:24:59 +01:00
Aldehir Rojas
0a8026e768 common : introduce composable PEG parser combinators for chat parsing (#17136)
* common : implement parser combinators to simplify chat parsing

* add virtual destructor to parser_base

* fix memory leak from circular references of rules

* implement gbnf grammar building

* remove unused private variable

* create a base visitor and implement id assignment as a visitor

* fix const ref for grammar builder

* clean up types, friend classes, and class declarations

* remove builder usage from until_parser

* Use a counter class to help assign rule ids

* cache everything

* add short description for each parser

* create a type for the root parser

* implement repetition parser

* Make optional, one_or_more, and zero_or_more subclasses of repetition

* improve context constructor

* improve until parsing and add benchmarks

* remove cached() pattern, cache in parser_base with specialized parsing functions for each parser

* improve json parsing performance to better match legacy parsing

* fix const auto * it for windows

* move id assignment to classes instead of using a visitor

* create named rules in the command r7b example

* use '.' for any in GBNF

* fix parens around choices in gbnf grammar

* add convenience operators to turn strings to literals

* add free-form operators for const char * to simplify defining literals

* simplify test case parser

* implement semantic actions

* remove groups in favor of actions and a scratchpad

* add built in actions for common operations

* add actions to command r7b example

* use std::default_searcher for platforms that don't have bm

* improve parser_type handling and add cast helper

* add partial result type to better control when to run actions

* fix bug in until()

* run actions on partial results by default

* use common_chat_msg for result

* add qwen3 example wip

* trash partial idea and simplify

* move action arguments to a struct

* implement aho-corasick matcher for until_parser and to build exclusion grammars

* use std::string for input, since std::string_view is incompatible with std::regex

* Refactor tests

* improve qwen3 example

* implement sax-style parsing and refactor

* fix json string in test

* rename classes to use common_chat_ prefix

* remove is_ suffix from functions

* rename from id_counter to just counter

* Final refactored tests

* Fix executable name and editorconfig-checker

* Third time's the charm...

* add trigger parser to begin lazy grammar rule generation

* working lazy grammar

* refactor json rules now that we check for reachability

* reduce pointer usage

* print out grammars in example

* rename to chat-peg-parser* and common_chat_peg_parser*

* Revert unrelated changes

* New macros for CMakeLists to enable multi-file compilations

* starting unicode support

* add unicode support to char_parser

* use unparsed args as additional sources

* Refactor tests to new harness

* Fix CMakeLists

* fix rate calculation

* add unicode tests

* fix trailing whitespace and line endings

skip-checks: true

* Helpers + rewrite qwen3 with helpers

* Fix whitespace

* extract unicode functions to separate file

* refactor parse unicode function

* fix compiler error

* improve construction of sequence/choice parsers

* be less clever

* add make_parser helper function

* expand usage of make_parser, alias common_chat_msg_peg_parser_builder to builder in source

* lower bench iterations

* add unicode support to until_parser

* add unicode support to json_string_parser

* clean up unicode tests

* reduce unicode details to match src/unicode.cpp

* simplify even further

* remove unused functions

* fix type

* reformat char class parsing

* clean up json string parser

* clean up + fix diagnostics

* reorder includes

* compact builder functions

* replace action_parser with capture_parser, rename env to semantics

* rename env to semantics

* clean up common_chat_parse_context

* move type() to below constant

* use default constructor for common_chat_peg_parser

* make all operators functions for consistency

* fix compilation errors in test-optional.cpp

* simplify result values

* rename json_string_unquoted to json_string_content

* Move helper to separate class, add separate explicit and helper classes

* Whitespace

* Change + to append()

* Reformat

* Add extra helpers, tests and Minimax example

* Add some extra optional debugging prints + real example of how to use them

* fix bug in repetitions when min_count = 0 reports failures

* dump rule in debug

* fix token accumulation and assert parsing never fails

* indent debug by depth

* use LOG_* in tests so logs sync up with test logs

* - Add selective testing
- Refactor all messaging to use LOG_ERR
- Fix lack of argument / tool name capturing
- Temporary fix for double event capture

* refactor rule() and introduce ref()

* clean up visitor

* clean up indirection in root parser w.r.t rules

* store shared ptr directly in parser classes

* replace aho-corasick automation with a simple trie

* Reset prev for qwen3 helper example variant

* refactor to use value semantics with std::variant/std::visit

* simplify trie_matcher result

* fix linting issues

* add annotations to rules

* revert test workaround

* implement serializing the parser

* remove redundant parsers

* remove tests

* gbnf generation fixes

* remove LOG_* use in tests

* update gbnf tests to test entire grammar

* clean up gbnf generation and fix a few bugs

* fix typo in test output

* remove implicit conversion rules

* improve test output

* rename trie_matcher to trie

* simplify trie to just know if a node is the end of a word

* remove common_chat_ prefix and ensure a common_peg_ prefix to all types

* rename chat-peg-parser -> peg-parser

* promote chat-peg-parser-helper to chat-peg-parser

* checkpoint

* use a static_assert to ensure we handle every branch

* inline trivial peg parser builders

* use json strings for now

* implement basic and native chat peg parser builders/extractors

* resolve refs to their rules

* remove packrat caching (for now)

* update tests

* compare parsers with incremental input

* benchmark both complete and incremental parsing

* add raw string generation from json schema

* add support for string schemas in gbnf generation

* fix qwen example to include \n

* tidy up example

* rename extractor to mapper

* rename ast_arena to ast

* place basic tests into one

* use gbnf_format_literal from json-schema-to-grammar

* integrate parser with common/chat and server

* clean up schema and serialization

* add json-schema raw string tests

* clean up json creation and remove capture parser

* trim spaces from reasoning and content

* clean up redundant rules and comments

* rename input_is_complete to is_partial to match rest of project

* simplify json rules

* remove extraneous file

* remove comment

* implement += and |= operators

* add comments to qwen3 implementation

* reorder arguments to common_chat_peg_parse

* remove commented outdated tests

* add explicit copy constructor

* fix operators and constness

* wip: update test-chat for qwen3-coder

* bring json parser closer to json-schema-to-grammar rules

* trim trailing space for most things

* fix qwen3 coder rules w.r.t. trailing spaces

* group rules

* do not trim trailing space from string args

* tweak spacing of qwen3 grammar

* update qwen3-coder tests

* qwen3-coder small fixes

* place parser in common_chat_syntax to simplify invocation

* use std::set to collect rules to keep order predictable for tests

* initialize parser to make certain platforms happy

* revert back to std::unordered_set, sort rule names at the end instead

* uncomment rest of chat tests

* define explicit default constructor

* improve arena init and server integration

* fix chat test

* add json_member()

* add a comprehensive native example

* clean up example qwen test and add response_format example to native test

* make build_peg_parser accept std::function instead of template

* change peg parser parameters into const ref

* push tool call on tool open for constructed parser

* add parsing documentation

* clean up some comments

* add json schema support to qwen3-coder

* add id initializer in tests

* remove grammar debug line from qwen3-coder

* refactor qwen3-coder to use sequence over operators

* only call common_chat_peg_parse if appropriate format

* simplify qwen3-coder space handling

* revert qwen3-coder implementation

* revert json-schema-to-grammar changes

* remove unnecessary forward declaration

* small adjustment to until_parser

* rename C/C++ files to use dashes

* codeowners : add aldehir to peg-parser and related files

---------

Co-authored-by: Piotr Wilkin <piotr.wilkin@syndatis.com>
2025-12-03 12:45:32 +02:00
84 changed files with 7596 additions and 1758 deletions

View File

@@ -1602,33 +1602,33 @@ jobs:
run: |
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-x64-amd-vulkan:
runs-on: [self-hosted, Linux, X64, AMD]
# ggml-ci-x64-amd-vulkan:
# runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
# steps:
# - name: Clone
# id: checkout
# uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
# - name: Test
# id: ggml-ci
# run: |
# vulkaninfo --summary
# GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-x64-amd-rocm:
runs-on: [self-hosted, Linux, X64, AMD]
# ggml-ci-x64-amd-rocm:
# runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
# steps:
# - name: Clone
# id: checkout
# uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
amd-smi static
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
# - name: Test
# id: ggml-ci
# run: |
# amd-smi static
# GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-mac-metal:
runs-on: [self-hosted, macOS, ARM64]

View File

@@ -728,58 +728,6 @@ jobs:
path: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz
name: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz
openEuler-cann:
strategy:
matrix:
arch: [x86, aarch64]
chip_type: ['910b', '310p']
build: ['Release']
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Dependencies
run: |
yum update -y
yum install -y git gcc gcc-c++ make cmake libcurl-devel
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Build
run: |
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
-DGGML_CANN=on \
-DSOC_TYPE=ascend${{ matrix.chip_type }}
cmake --build build -j $(nproc)
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
- name: Pack artifacts
run: |
cp LICENSE ./build/bin/
zip -y -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz -C ./build/bin .
- name: Upload artifacts (zip)
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
- name: Upload artifacts (tar)
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz
release:
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
@@ -801,7 +749,6 @@ jobs:
- macOS-arm64
- macOS-x64
- ios-xcode-build
- openEuler-cann
steps:
- name: Clone
@@ -893,12 +840,6 @@ jobs:
- [Windows x64 (SYCL)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip)
- [Windows x64 (HIP)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-hip-radeon-x64.zip)
**openEuler:**
- [openEuler x86 (310p)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-310p-openEuler-x86.tar.gz)
- [openEuler x86 (910b)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-910b-openEuler-x86.tar.gz)
- [openEuler aarch64 (310p)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-310p-openEuler-aarch64.tar.gz)
- [openEuler aarch64 (910b)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-910b-openEuler-aarch64.tar.gz)
- name: Upload release
id: upload_release
uses: actions/github-script@v3

View File

@@ -72,6 +72,12 @@ if (MSVC)
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
endif()
if (LLAMA_STANDALONE)
# enable parallel builds for msbuild
list(APPEND CMAKE_VS_GLOBALS UseMultiToolTask=true)
list(APPEND CMAKE_VS_GLOBALS EnforceProcessCountAcrossBuilds=true)
endif()
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
set(LLAMA_TOOLS_INSTALL_DEFAULT OFF)
else()
@@ -193,11 +199,6 @@ if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML)
# ... otherwise assume ggml is added by a parent CMakeLists.txt
endif()
if (MINGW)
# Target Windows 8 for PrefetchVirtualMemory
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
endif()
#
# build the library
#

View File

@@ -10,13 +10,16 @@
/common/arg.* @ggerganov
/common/base64.hpp.* @ggerganov
/common/build-info.* @ggerganov
/common/chat-peg-parser.* @aldehir
/common/common.* @ggerganov
/common/console.* @ggerganov
/common/http.* @angt
/common/llguidance.* @ggerganov
/common/log.* @ggerganov
/common/peg-parser.* @aldehir
/common/sampling.* @ggerganov
/common/speculative.* @ggerganov
/common/unicode.* @aldehir
/convert_*.py @CISC
/examples/batched.swift/ @ggerganov
/examples/batched/ @ggerganov

View File

@@ -39,26 +39,10 @@ if(Git_FOUND)
endif()
endif()
if(MSVC)
set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
if (CMAKE_VS_PLATFORM_NAME)
set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME})
else()
set(BUILD_TARGET "${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}")
endif()
else()
execute_process(
COMMAND ${CMAKE_C_COMPILER} --version
OUTPUT_VARIABLE OUT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
string(REGEX REPLACE " *\n.*" "" OUT "${OUT}")
set(BUILD_COMPILER ${OUT})
set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
execute_process(
COMMAND ${CMAKE_C_COMPILER} -dumpmachine
OUTPUT_VARIABLE OUT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(BUILD_TARGET ${OUT})
if(CMAKE_VS_PLATFORM_NAME)
set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME})
else()
set(BUILD_TARGET "${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}")
endif()

View File

@@ -52,6 +52,8 @@ add_library(${TARGET} STATIC
chat-parser.h
chat-parser-xml-toolcall.h
chat-parser-xml-toolcall.cpp
chat-peg-parser.cpp
chat-peg-parser.h
chat.cpp
chat.h
common.cpp
@@ -69,12 +71,16 @@ add_library(${TARGET} STATIC
log.h
ngram-cache.cpp
ngram-cache.h
peg-parser.cpp
peg-parser.h
regex-partial.cpp
regex-partial.h
sampling.cpp
sampling.h
speculative.cpp
speculative.h
unicode.cpp
unicode.h
)
if (BUILD_SHARED_LIBS)

View File

@@ -427,7 +427,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER) {
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage) {
throw std::invalid_argument("error: --model is required\n");
}

View File

@@ -1,6 +1,8 @@
#include "chat-parser.h"
#include "chat-peg-parser.h"
#include "common.h"
#include "log.h"
#include "peg-parser.h"
#include "regex-partial.h"
#include <algorithm>
@@ -1483,6 +1485,11 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
}
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE ||
syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE ||
syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) {
return common_chat_peg_parse(syntax.parser, input, is_partial, syntax);
}
common_chat_msg_parser builder(input, is_partial, syntax);
try {
common_chat_parse(builder);
@@ -1500,3 +1507,36 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
}
return msg;
}
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
if (parser.empty()) {
throw std::runtime_error("Failed to parse due to missing parser definition.");
}
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(syntax.format), input.c_str());
common_peg_parse_context ctx(input, is_partial);
auto result = parser.parse(ctx);
if (result.fail()) {
throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end));
}
common_chat_msg msg;
msg.role = "assistant";
if (syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE) {
auto mapper = common_chat_peg_native_mapper(msg);
mapper.from_ast(ctx.ast, result);
} else if (syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) {
auto mapper = common_chat_peg_constructed_mapper(msg);
mapper.from_ast(ctx.ast, result);
} else {
// Generic mapper
auto mapper = common_chat_peg_mapper(msg);
mapper.from_ast(ctx.ast, result);
}
if (!is_partial) {
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
}
return msg;
}

114
common/chat-peg-parser.cpp Normal file
View File

@@ -0,0 +1,114 @@
#include "chat-peg-parser.h"
#include <nlohmann/json.hpp>
using json = nlohmann::json;
static std::string_view trim_trailing_space(std::string_view sv) {
while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.back()))) {
sv.remove_suffix(1);
}
return sv;
}
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) {
arena.visit(result, [this](const common_peg_ast_node & node) {
map(node);
});
}
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
bool is_reasoning = node.tag == common_chat_peg_builder::REASONING;
bool is_content = node.tag == common_chat_peg_builder::CONTENT;
if (is_reasoning) {
result.reasoning_content = std::string(trim_trailing_space(node.text));
}
if (is_content) {
result.content = std::string(trim_trailing_space(node.text));
}
}
void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) {
common_chat_peg_mapper::map(node);
bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN;
bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME;
bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID;
bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS;
if (is_tool_open) {
result.tool_calls.emplace_back();
current_tool = &result.tool_calls.back();
}
if (is_tool_id && current_tool) {
current_tool->id = std::string(trim_trailing_space(node.text));
}
if (is_tool_name && current_tool) {
current_tool->name = std::string(trim_trailing_space(node.text));
}
if (is_tool_args && current_tool) {
current_tool->arguments = std::string(trim_trailing_space(node.text));
}
}
void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
common_chat_peg_mapper::map(node);
bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN;
bool is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME;
bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE;
bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN;
bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE;
bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME;
bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE;
bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE;
if (is_tool_open) {
result.tool_calls.emplace_back();
current_tool = &result.tool_calls.back();
arg_count = 0;
}
if (is_tool_name) {
current_tool->name = std::string(node.text);
current_tool->arguments = "{";
}
if (is_arg_open) {
needs_closing_quote = false;
}
if (is_arg_name && current_tool) {
if (arg_count > 0) {
current_tool->arguments += ",";
}
current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":";
++arg_count;
}
if (is_arg_string && current_tool) {
// Serialize to JSON, but exclude the end quote
std::string dumped = json(node.text).dump();
current_tool->arguments += dumped.substr(0, dumped.size() - 1);
needs_closing_quote = true;
}
if (is_arg_close && current_tool) {
if (needs_closing_quote) {
current_tool->arguments += "\"";
}
}
if (is_arg_json && current_tool) {
current_tool->arguments += std::string(trim_trailing_space(node.text));
}
if (is_tool_close && current_tool) {
current_tool->arguments += "}";
}
}

105
common/chat-peg-parser.h Normal file
View File

@@ -0,0 +1,105 @@
#pragma once
#include "chat.h"
#include "peg-parser.h"
class common_chat_peg_builder : public common_peg_parser_builder {
public:
static constexpr const char * REASONING_BLOCK = "reasoning-block";
static constexpr const char * REASONING = "reasoning";
static constexpr const char * CONTENT = "content";
common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); }
common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); }
common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); }
};
inline common_peg_arena build_chat_peg_parser(const std::function<common_peg_parser(common_chat_peg_builder & builder)> & fn) {
common_chat_peg_builder builder;
builder.set_root(fn(builder));
return builder.build();
}
class common_chat_peg_mapper {
public:
common_chat_msg & result;
common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {}
virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
virtual void map(const common_peg_ast_node & node);
};
class common_chat_peg_native_builder : public common_chat_peg_builder {
public:
static constexpr const char * TOOL = "tool";
static constexpr const char * TOOL_OPEN = "tool-open";
static constexpr const char * TOOL_CLOSE = "tool-close";
static constexpr const char * TOOL_ID = "tool-id";
static constexpr const char * TOOL_NAME = "tool-name";
static constexpr const char * TOOL_ARGS = "tool-args";
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); }
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); }
};
class common_chat_peg_native_mapper : public common_chat_peg_mapper {
common_chat_tool_call * current_tool;
public:
common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
void map(const common_peg_ast_node & node) override;
};
inline common_peg_arena build_chat_peg_native_parser(const std::function<common_peg_parser(common_chat_peg_native_builder & builder)> & fn) {
common_chat_peg_native_builder builder;
builder.set_root(fn(builder));
return builder.build();
}
class common_chat_peg_constructed_builder : public common_chat_peg_builder {
public:
static constexpr const char * TOOL = "tool";
static constexpr const char * TOOL_OPEN = "tool-open";
static constexpr const char * TOOL_CLOSE = "tool-close";
static constexpr const char * TOOL_NAME = "tool-name";
static constexpr const char * TOOL_ARG = "tool-arg";
static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open";
static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close";
static constexpr const char * TOOL_ARG_NAME = "tool-arg-name";
static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value";
static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value";
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); }
common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); }
common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); }
common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); }
common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); }
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); }
};
class common_chat_peg_constructed_mapper : public common_chat_peg_mapper {
common_chat_tool_call * current_tool;
int arg_count = 0;
bool needs_closing_quote = false;
public:
common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
void map(const common_peg_ast_node & node) override;
};
inline common_peg_arena build_chat_peg_constructed_parser(const std::function<common_peg_parser(common_chat_peg_constructed_builder & builder)> & fn) {
common_chat_peg_constructed_builder builder;
builder.set_root(fn(builder));
return builder.build();
}

View File

@@ -85,29 +85,36 @@ json common_chat_msg::to_json_oaicompat() const
return message;
}
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) {
std::vector<common_chat_msg_diff> diffs;
if (previous_msg.reasoning_content != new_msg.reasoning_content) {
auto & diff = diffs.emplace_back();
diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content);
}
if (previous_msg.content != new_msg.content) {
auto & diff = diffs.emplace_back();
diff.content_delta = string_diff(previous_msg.content, new_msg.content);
if (msg_new.tool_calls.size() > msg_prv.tool_calls.size()) {
diffs.reserve(msg_new.tool_calls.size() - msg_prv.tool_calls.size() + 3);
} else {
diffs.reserve(3);
}
if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) {
// TODO: these can become expensive for long messages - how to optimize?
if (msg_prv.reasoning_content != msg_new.reasoning_content) {
auto & diff = diffs.emplace_back();
diff.reasoning_content_delta = string_diff(msg_prv.reasoning_content, msg_new.reasoning_content);
}
if (msg_prv.content != msg_new.content) {
auto & diff = diffs.emplace_back();
diff.content_delta = string_diff(msg_prv.content, msg_new.content);
}
if (msg_new.tool_calls.size() < msg_prv.tool_calls.size()) {
throw std::runtime_error("Invalid diff: now finding less tool calls!");
}
if (!previous_msg.tool_calls.empty()) {
auto idx = previous_msg.tool_calls.size() - 1;
const auto & pref = previous_msg.tool_calls[idx];
const auto & newf = new_msg.tool_calls[idx];
if (!msg_prv.tool_calls.empty()) {
const auto idx = msg_prv.tool_calls.size() - 1;
const auto & pref = msg_prv.tool_calls[idx];
const auto & newf = msg_new.tool_calls[idx];
if (pref.name != newf.name) {
throw std::runtime_error("Invalid diff: tool call mismatch!");
}
auto args_diff = string_diff(pref.arguments, newf.arguments);
const auto args_diff = string_diff(pref.arguments, newf.arguments);
if (!args_diff.empty() || pref.id != newf.id) {
auto & diff = diffs.emplace_back();
diff.tool_call_index = idx;
@@ -118,11 +125,12 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
diff.tool_call_delta.arguments = args_diff;
}
}
for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) {
for (size_t idx = msg_prv.tool_calls.size(); idx < msg_new.tool_calls.size(); ++idx) {
auto & diff = diffs.emplace_back();
diff.tool_call_index = idx;
diff.tool_call_delta = new_msg.tool_calls[idx];
diff.tool_call_delta = msg_new.tool_calls[idx];
}
return diffs;
}
@@ -649,6 +657,9 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple";
case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native";
case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed";
default:
throw std::runtime_error("Unknown chat format");
}

View File

@@ -3,6 +3,7 @@
#pragma once
#include "common.h"
#include "peg-parser.h"
#include <functional>
#include <chrono>
#include <string>
@@ -76,7 +77,7 @@ struct common_chat_msg_diff {
size_t tool_call_index = std::string::npos;
common_chat_tool_call tool_call_delta;
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new);
bool operator==(const common_chat_msg_diff & other) const {
return content_delta == other.content_delta
@@ -124,6 +125,11 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_APRIEL_1_5,
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
// These are intended to be parsed by the PEG parser
COMMON_CHAT_FORMAT_PEG_SIMPLE,
COMMON_CHAT_FORMAT_PEG_NATIVE,
COMMON_CHAT_FORMAT_PEG_CONSTRUCTED,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
@@ -154,6 +160,7 @@ struct common_chat_params {
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
std::string parser;
};
struct common_chat_syntax {
@@ -163,6 +170,7 @@ struct common_chat_syntax {
bool reasoning_in_content = false;
bool thinking_forced_open = false;
bool parse_tool_calls = true;
common_peg_arena parser = {};
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
@@ -206,6 +214,7 @@ const char* common_chat_format_name(common_chat_format format);
const char* common_reasoning_format_name(common_reasoning_format format);
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);

View File

@@ -786,11 +786,29 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
#include <iostream>
#ifdef _WIN32
static std::wstring utf8_to_wstring(const std::string & str) {
if (str.empty()) {
return std::wstring();
}
int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
if (size <= 0) {
return std::wstring();
}
std::wstring wstr(size, 0);
MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);
return wstr;
}
#endif
// returns true if successful, false otherwise
bool fs_create_directory_with_parents(const std::string & path) {
#ifdef _WIN32
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
std::wstring wpath = converter.from_bytes(path);
std::wstring wpath = utf8_to_wstring(path);
// if the path already exists, check whether it's a directory
const DWORD attributes = GetFileAttributesW(wpath.c_str());

View File

@@ -12,6 +12,10 @@
#include <vector>
#include <map>
#if defined(_WIN32) && !defined(_WIN32_WINNT)
#define _WIN32_WINNT 0x0A00
#endif
#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
#else

1712
common/peg-parser.cpp Normal file

File diff suppressed because it is too large Load Diff

459
common/peg-parser.h Normal file
View File

@@ -0,0 +1,459 @@
#pragma once
#include <nlohmann/json_fwd.hpp>
#include <memory>
#include <unordered_map>
#include <string>
#include <string_view>
#include <functional>
#include <vector>
#include <variant>
struct common_grammar_builder;
class common_peg_parser_builder;
using common_peg_parser_id = size_t;
constexpr common_peg_parser_id COMMON_PEG_INVALID_PARSER_ID = static_cast<common_peg_parser_id>(-1);
using common_peg_ast_id = size_t;
constexpr common_peg_ast_id COMMON_PEG_INVALID_AST_ID = static_cast<common_peg_ast_id>(-1);
// Lightweight wrapper around common_peg_parser_id for convenience
class common_peg_parser {
common_peg_parser_id id_;
common_peg_parser_builder & builder_;
public:
common_peg_parser(const common_peg_parser & other) : id_(other.id_), builder_(other.builder_) {}
common_peg_parser(common_peg_parser_id id, common_peg_parser_builder & builder) : id_(id), builder_(builder) {}
common_peg_parser & operator=(const common_peg_parser & other);
common_peg_parser & operator+=(const common_peg_parser & other);
common_peg_parser & operator|=(const common_peg_parser & other);
operator common_peg_parser_id() const { return id_; }
common_peg_parser_id id() const { return id_; }
common_peg_parser_builder & builder() const { return builder_; }
// Creates a sequence
common_peg_parser operator+(const common_peg_parser & other) const;
// Creates a sequence separated by spaces.
common_peg_parser operator<<(const common_peg_parser & other) const;
// Creates a choice
common_peg_parser operator|(const common_peg_parser & other) const;
common_peg_parser operator+(const char * str) const;
common_peg_parser operator+(const std::string & str) const;
common_peg_parser operator<<(const char * str) const;
common_peg_parser operator<<(const std::string & str) const;
common_peg_parser operator|(const char * str) const;
common_peg_parser operator|(const std::string & str) const;
};
common_peg_parser operator+(const char * str, const common_peg_parser & p);
common_peg_parser operator+(const std::string & str, const common_peg_parser & p);
common_peg_parser operator<<(const char * str, const common_peg_parser & p);
common_peg_parser operator<<(const std::string & str, const common_peg_parser & p);
common_peg_parser operator|(const char * str, const common_peg_parser & p);
common_peg_parser operator|(const std::string & str, const common_peg_parser & p);
enum common_peg_parse_result_type {
COMMON_PEG_PARSE_RESULT_FAIL = 0,
COMMON_PEG_PARSE_RESULT_SUCCESS = 1,
COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT = 2,
};
const char * common_peg_parse_result_type_name(common_peg_parse_result_type type);
struct common_peg_ast_node {
common_peg_ast_id id;
std::string rule;
std::string tag;
size_t start;
size_t end;
std::string_view text;
std::vector<common_peg_ast_id> children;
bool is_partial = false;
};
struct common_peg_parse_result;
using common_peg_ast_visitor = std::function<void(const common_peg_ast_node & node)>;
class common_peg_ast_arena {
std::vector<common_peg_ast_node> nodes_;
public:
common_peg_ast_id add_node(
const std::string & rule,
const std::string & tag,
size_t start,
size_t end,
std::string_view text,
std::vector<common_peg_ast_id> children,
bool is_partial = false
) {
common_peg_ast_id id = nodes_.size();
nodes_.push_back({id, rule, tag, start, end, text, std::move(children), is_partial});
return id;
}
const common_peg_ast_node & get(common_peg_ast_id id) const { return nodes_.at(id); }
size_t size() const { return nodes_.size(); }
void clear() { nodes_.clear(); }
void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const;
void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const;
};
struct common_peg_parse_result {
common_peg_parse_result_type type = COMMON_PEG_PARSE_RESULT_FAIL;
size_t start = 0;
size_t end = 0;
std::vector<common_peg_ast_id> nodes;
common_peg_parse_result() = default;
common_peg_parse_result(common_peg_parse_result_type type, size_t start)
: type(type), start(start), end(start) {}
common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end)
: type(type), start(start), end(end) {}
common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end, std::vector<common_peg_ast_id> nodes)
: type(type), start(start), end(end), nodes(std::move(nodes)) {}
bool fail() const { return type == COMMON_PEG_PARSE_RESULT_FAIL; }
bool need_more_input() const { return type == COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT; }
bool success() const { return type == COMMON_PEG_PARSE_RESULT_SUCCESS; }
};
struct common_peg_parse_context {
std::string input;
bool is_partial;
common_peg_ast_arena ast;
int parse_depth;
common_peg_parse_context()
: is_partial(false), parse_depth(0) {}
common_peg_parse_context(const std::string & input)
: input(input), is_partial(false), parse_depth(0) {}
common_peg_parse_context(const std::string & input, bool is_partial)
: input(input), is_partial(is_partial), parse_depth(0) {}
};
class common_peg_arena;
// Parser variants
struct common_peg_epsilon_parser {};
struct common_peg_start_parser {};
struct common_peg_end_parser {};
struct common_peg_literal_parser {
std::string literal;
};
struct common_peg_sequence_parser {
std::vector<common_peg_parser_id> children;
};
struct common_peg_choice_parser {
std::vector<common_peg_parser_id> children;
};
struct common_peg_repetition_parser {
common_peg_parser_id child;
int min_count;
int max_count; // -1 for unbounded
};
struct common_peg_and_parser {
common_peg_parser_id child;
};
struct common_peg_not_parser {
common_peg_parser_id child;
};
struct common_peg_any_parser {};
struct common_peg_space_parser {};
struct common_peg_chars_parser {
struct char_range {
uint32_t start;
uint32_t end;
bool contains(uint32_t codepoint) const { return codepoint >= start && codepoint <= end; }
};
std::string pattern;
std::vector<char_range> ranges;
bool negated;
int min_count;
int max_count; // -1 for unbounded
};
struct common_peg_json_string_parser {};
struct common_peg_until_parser {
std::vector<std::string> delimiters;
};
struct common_peg_schema_parser {
common_peg_parser_id child;
std::string name;
std::shared_ptr<nlohmann::ordered_json> schema;
// Indicates if the GBNF should accept a raw string that matches the schema.
bool raw;
};
struct common_peg_rule_parser {
std::string name;
common_peg_parser_id child;
bool trigger;
};
struct common_peg_ref_parser {
std::string name;
};
struct common_peg_atomic_parser {
common_peg_parser_id child;
};
struct common_peg_tag_parser {
common_peg_parser_id child;
std::string tag;
};
// Variant holding all parser types
using common_peg_parser_variant = std::variant<
common_peg_epsilon_parser,
common_peg_start_parser,
common_peg_end_parser,
common_peg_literal_parser,
common_peg_sequence_parser,
common_peg_choice_parser,
common_peg_repetition_parser,
common_peg_and_parser,
common_peg_not_parser,
common_peg_any_parser,
common_peg_space_parser,
common_peg_chars_parser,
common_peg_json_string_parser,
common_peg_until_parser,
common_peg_schema_parser,
common_peg_rule_parser,
common_peg_ref_parser,
common_peg_atomic_parser,
common_peg_tag_parser
>;
class common_peg_arena {
std::vector<common_peg_parser_variant> parsers_;
std::unordered_map<std::string, common_peg_parser_id> rules_;
common_peg_parser_id root_ = COMMON_PEG_INVALID_PARSER_ID;
public:
const common_peg_parser_variant & get(common_peg_parser_id id) const { return parsers_.at(id); }
common_peg_parser_variant & get(common_peg_parser_id id) { return parsers_.at(id); }
size_t size() const { return parsers_.size(); }
bool empty() const { return parsers_.empty(); }
common_peg_parser_id get_rule(const std::string & name) const;
bool has_rule(const std::string & name) const { return rules_.find(name) != rules_.end(); }
common_peg_parser_id root() const { return root_; }
void set_root(common_peg_parser_id id) { root_ = id; }
common_peg_parse_result parse(common_peg_parse_context & ctx, size_t start = 0) const;
common_peg_parse_result parse(common_peg_parser_id id, common_peg_parse_context & ctx, size_t start) const;
void resolve_refs();
void build_grammar(const common_grammar_builder & builder, bool lazy = false) const;
std::string dump(common_peg_parser_id id) const;
nlohmann::json to_json() const;
static common_peg_arena from_json(const nlohmann::json & j);
std::string save() const;
void load(const std::string & data);
friend class common_peg_parser_builder;
private:
common_peg_parser_id add_parser(common_peg_parser_variant parser);
void add_rule(const std::string & name, common_peg_parser_id id);
common_peg_parser_id resolve_ref(common_peg_parser_id id);
};
class common_peg_parser_builder {
common_peg_arena arena_;
common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); }
common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); }
public:
common_peg_parser_builder();
// Match nothing, always succeed.
// S -> ε
common_peg_parser eps() { return add(common_peg_epsilon_parser{}); }
// Matches the start of the input.
// S -> ^
common_peg_parser start() { return add(common_peg_start_parser{}); }
// Matches the end of the input.
// S -> $
common_peg_parser end() { return add(common_peg_end_parser{}); }
// Matches an exact literal string.
// S -> "hello"
common_peg_parser literal(const std::string & literal) { return add(common_peg_literal_parser{literal}); }
// Matches a sequence of parsers in order, all must succeed.
// S -> A B C
common_peg_parser sequence() { return add(common_peg_sequence_parser{}); }
common_peg_parser sequence(const std::vector<common_peg_parser_id> & parsers);
common_peg_parser sequence(const std::vector<common_peg_parser> & parsers);
common_peg_parser sequence(std::initializer_list<common_peg_parser> parsers);
// Matches the first parser that succeeds from a list of alternatives.
// S -> A | B | C
common_peg_parser choice() { return add(common_peg_choice_parser{}); }
common_peg_parser choice(const std::vector<common_peg_parser_id> & parsers);
common_peg_parser choice(const std::vector<common_peg_parser> & parsers);
common_peg_parser choice(std::initializer_list<common_peg_parser> parsers);
// Matches one or more repetitions of a parser.
// S -> A+
common_peg_parser one_or_more(const common_peg_parser & p) { return repeat(p, 1, -1); }
// Matches zero or more repetitions of a parser, always succeeds.
// S -> A*
common_peg_parser zero_or_more(const common_peg_parser & p) { return repeat(p, 0, -1); }
// Matches zero or one occurrence of a parser, always succeeds.
// S -> A?
common_peg_parser optional(const common_peg_parser & p) { return repeat(p, 0, 1); }
// Positive lookahead: succeeds if child parser succeeds, consumes no input.
// S -> &A
common_peg_parser peek(const common_peg_parser & p) { return add(common_peg_and_parser{p}); }
// Negative lookahead: succeeds if child parser fails, consumes no input.
// S -> !A
common_peg_parser negate(const common_peg_parser & p) { return add(common_peg_not_parser{p}); }
// Matches any single character.
// S -> .
common_peg_parser any() { return add(common_peg_any_parser{}); }
// Matches between min and max repetitions of characters from a character class.
// S -> [a-z]{m,n}
//
// Use -1 for max to represent unbounded repetition (equivalent to {m,})
common_peg_parser chars(const std::string & classes, int min = 1, int max = -1);
// Creates a lightweight reference to a named rule (resolved during build()).
// Use this for forward references in recursive grammars.
// expr_ref -> expr
common_peg_parser ref(const std::string & name) { return add(common_peg_ref_parser{name}); }
// Matches zero or more whitespace characters (space, tab, newline).
// S -> [ \t\n]*
common_peg_parser space() { return add(common_peg_space_parser{}); }
// Matches all characters until a delimiter is found (delimiter not consumed).
// S -> (!delim .)*
common_peg_parser until(const std::string & delimiter) { return add(common_peg_until_parser{{delimiter}}); }
// Matches all characters until one of the delimiters in the list is found (delimiter not consumed).
// S -> (!delim .)*
common_peg_parser until_one_of(const std::vector<std::string> & delimiters) { return add(common_peg_until_parser{delimiters}); }
// Matches everything
// S -> .*
common_peg_parser rest() { return until_one_of({}); }
// Matches between min and max repetitions of a parser (inclusive).
// S -> A{m,n}
// Use -1 for max to represent unbounded repetition (equivalent to {m,})
common_peg_parser repeat(const common_peg_parser & p, int min, int max) { return add(common_peg_repetition_parser{p, min,max}); }
// Matches exactly n repetitions of a parser.
// S -> A{n}
common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); }
// Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null.
// value -> object | array | string | number | true | false | null
common_peg_parser json();
common_peg_parser json_object();
common_peg_parser json_string();
common_peg_parser json_array();
common_peg_parser json_number();
common_peg_parser json_bool();
common_peg_parser json_null();
// Matches JSON string content without the surrounding quotes.
// Useful for extracting content within a JSON string.
common_peg_parser json_string_content();
// Matches a JSON object member with a key and associated parser as the
// value.
common_peg_parser json_member(const std::string & key, const common_peg_parser & p);
// Wraps a parser with JSON schema metadata for grammar generation.
// Used internally to convert JSON schemas to GBNF grammar rules.
common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false);
// Creates a named rule, stores it in the grammar, and returns a ref.
// If trigger=true, marks this rule as an entry point for lazy grammar generation.
// auto json = p.rule("json", json_obj | json_arr | ...)
common_peg_parser rule(const std::string & name, const common_peg_parser & p, bool trigger = false);
// Creates a named rule using a builder function, and returns a ref.
// If trigger=true, marks this rule as an entry point for lazy grammar generation.
// auto json = p.rule("json", [&]() { return json_object() | json_array() | ... })
common_peg_parser rule(const std::string & name, const std::function<common_peg_parser()> & builder, bool trigger = false);
// Creates a trigger rule. When generating a lazy grammar from the parser,
// only trigger rules and descendents are emitted.
common_peg_parser trigger_rule(const std::string & name, const common_peg_parser & p) { return rule(name, p, true); }
common_peg_parser trigger_rule(const std::string & name, const std::function<common_peg_parser()> & builder) { return rule(name, builder, true); }
// Creates an atomic parser. Atomic parsers do not create an AST node if
// the child results in a partial parse, i.e. NEEDS_MORE_INPUT. This is
// intended for situations where partial output is undesirable.
common_peg_parser atomic(const common_peg_parser & p) { return add(common_peg_atomic_parser{p}); }
// Tags create nodes in the generated AST for semantic purposes.
// Unlike rules, you can tag multiple nodes with the same tag.
common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); }
void set_root(const common_peg_parser & p);
common_peg_arena build();
};
// Helper function for building parsers
common_peg_arena build_peg_parser(const std::function<common_peg_parser(common_peg_parser_builder & builder)> & fn);

64
common/unicode.cpp Normal file
View File

@@ -0,0 +1,64 @@
#include "unicode.h"
// implementation adopted from src/unicode.cpp
size_t utf8_sequence_length(unsigned char first_byte) {
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
uint8_t highbits = static_cast<uint8_t>(first_byte) >> 4;
return lookup[highbits];
}
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) {
if (offset >= input.size()) {
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
}
// ASCII fast path
if (!(input[offset] & 0x80)) {
return utf8_parse_result(utf8_parse_result::SUCCESS, input[offset], 1);
}
// Invalid: continuation byte as first byte
if (!(input[offset] & 0x40)) {
return utf8_parse_result(utf8_parse_result::INVALID);
}
// 2-byte sequence
if (!(input[offset] & 0x20)) {
if (offset + 1 >= input.size()) {
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
}
if ((input[offset + 1] & 0xc0) != 0x80) {
return utf8_parse_result(utf8_parse_result::INVALID);
}
auto result = ((input[offset] & 0x1f) << 6) | (input[offset + 1] & 0x3f);
return utf8_parse_result(utf8_parse_result::SUCCESS, result, 2);
}
// 3-byte sequence
if (!(input[offset] & 0x10)) {
if (offset + 2 >= input.size()) {
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
}
if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80) {
return utf8_parse_result(utf8_parse_result::INVALID);
}
auto result = ((input[offset] & 0x0f) << 12) | ((input[offset + 1] & 0x3f) << 6) | (input[offset + 2] & 0x3f);
return utf8_parse_result(utf8_parse_result::SUCCESS, result, 3);
}
// 4-byte sequence
if (!(input[offset] & 0x08)) {
if (offset + 3 >= input.size()) {
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
}
if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80 || (input[offset + 3] & 0xc0) != 0x80) {
return utf8_parse_result(utf8_parse_result::INVALID);
}
auto result = ((input[offset] & 0x07) << 18) | ((input[offset + 1] & 0x3f) << 12) | ((input[offset + 2] & 0x3f) << 6) | (input[offset + 3] & 0x3f);
return utf8_parse_result(utf8_parse_result::SUCCESS, result, 4);
}
// Invalid first byte
return utf8_parse_result(utf8_parse_result::INVALID);
}

22
common/unicode.h Normal file
View File

@@ -0,0 +1,22 @@
#pragma once
#include <cstdint>
#include <string_view>
// UTF-8 parsing utilities for streaming-aware unicode support
struct utf8_parse_result {
uint32_t codepoint; // Decoded codepoint (only valid if status == SUCCESS)
size_t bytes_consumed; // How many bytes this codepoint uses (1-4)
enum status { SUCCESS, INCOMPLETE, INVALID } status;
utf8_parse_result(enum status s, uint32_t cp = 0, size_t bytes = 0)
: codepoint(cp), bytes_consumed(bytes), status(s) {}
};
// Determine the expected length of a UTF-8 sequence from its first byte
// Returns 0 for invalid first bytes
size_t utf8_sequence_length(unsigned char first_byte);
// Parse a single UTF-8 codepoint from input
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset);

View File

@@ -2341,9 +2341,18 @@ class LlamaModel(TextModel):
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(False)
template_dir = Path(__file__).parent / "models/templates/"
local_template_file_path = self.dir_model / "chat_template.jinja"
if self.is_mistral_format and local_template_file_path.is_file():
# Ministral-3 and other new Mistral models come with chat templates.
# ref: https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512/tree/main
logger.info("Using an existing Mistral local chat template.")
with open(local_template_file_path, "r", encoding="utf-8") as f:
template = f.read()
elif not self.is_mistral_format or not self.disable_mistral_community_chat_template:
template_dir = Path(__file__).parent / "models/templates/"
if not self.is_mistral_format or not self.disable_mistral_community_chat_template:
# Log only for Mistral format that the official tokenization and detokenization is via `mistral-common`.
if self.is_mistral_format:
logger.info(
@@ -2351,9 +2360,12 @@ class LlamaModel(TextModel):
"Mistral recommends to use `mistral-common` to perform tokenization and detokenization."
)
template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
self.gguf_writer.add_chat_template(template)
else:
logger.info("Not using a Mistral community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
logger.info("Not using a Mistral local or community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
template = None
if template is not None:
self.gguf_writer.add_chat_template(template)
def set_vocab(self):
if self.is_mistral_format:

288
docs/development/parsing.md Normal file
View File

@@ -0,0 +1,288 @@
# Parsing Model Output
The `common` library contains a PEG parser implementation suitable for parsing
model output.
Types with the prefix `common_peg_*` are intended for general use and may have
applications beyond parsing model output, such as parsing user-provided regex
patterns.
Types with the prefix `common_chat_peg_*` are specialized helpers for model
output.
The parser features:
- Partial parsing of streaming input
- Built-in JSON parsers
- AST generation with semantics via "tagged" nodes
## Example
Below is a contrived example demonstrating how to use the PEG parser to parse
output from a model that emits arguments as JSON.
```cpp
auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
// Build a choice of all available tools
auto tool_choice = p.choice();
for (const auto & tool : tools) {
const auto & function = tool.at("function");
std::string name = function.at("name");
const auto & schema = function.at("parameters");
auto tool_name = p.json_member("name", "\"" + p.literal(name) + "\"");
auto tool_args = p.json_member("arguments", p.schema(p.json(), "tool-" + name + "-schema", schema));
tool_choice |= p.rule("tool-" + name, "{" << tool_name << "," << tool_args << "}");
}
// Define the tool call structure: <tool_call>[{tool}]</tool_call>
auto tool_call = p.trigger_rule("tool-call",
p.sequence({
p.literal("<tool_call>["),
tool_choice,
p.literal("]</tool_call>")
})
);
// Parser accepts content, optionally followed by a tool call
return p.sequence({
p.content(p.until("<tool_call>")),
p.optional(tool_call),
p.end()
});
});
```
For a more complete example, see `test_example_native()` in
[tests/test-chat-peg-parser.cpp](tests/test-chat-peg-parser.cpp).
## Parsers/Combinators
### Basic Matchers
- **`eps()`** - Matches nothing and always succeeds (epsilon/empty match)
- **`start()`** - Matches the start of input (anchor `^`)
- **`end()`** - Matches the end of input (anchor `$`)
- **`literal(string)`** - Matches an exact literal string
- **`any()`** - Matches any single character (`.`)
### Combinators
- **`sequence(...)`** - Matches parsers in order; all must succeed
- **`choice(...)`** - Matches the first parser that succeeds from alternatives (ordered choice)
- **`one_or_more(p)`** - Matches one or more repetitions (`+`)
- **`zero_or_more(p)`** - Matches zero or more repetitions (`*`)
- **`optional(p)`** - Matches zero or one occurrence (`?`)
- **`repeat(p, min, max)`** - Matches between min and max repetitions (use `-1` for unbounded)
- **`repeat(p, n)`** - Matches exactly n repetitions
### Lookahead
- **`peek(p)`** - Positive lookahead: succeeds if parser succeeds without consuming input (`&`)
- **`negate(p)`** - Negative lookahead: succeeds if parser fails without consuming input (`!`)
### Character Classes & Utilities
- **`chars(classes, min, max)`** - Matches repetitions of characters from a character class
- **`space()`** - Matches zero or more whitespace characters (space, tab, newline)
- **`until(delimiter)`** - Matches characters until delimiter is found (delimiter not consumed)
- **`until_one_of(delimiters)`** - Matches characters until any delimiter in the list is found
- **`rest()`** - Matches everything remaining (`.*`)
### JSON Parsers
- **`json()`** - Complete JSON parser (objects, arrays, strings, numbers, booleans, null)
- **`json_object()`** - JSON object parser
- **`json_array()`** - JSON array parser
- **`json_string()`** - JSON string parser
- **`json_number()`** - JSON number parser
- **`json_bool()`** - JSON boolean parser
- **`json_null()`** - JSON null parser
- **`json_string_content()`** - JSON string content without surrounding quotes
- **`json_member(key, p)`** - JSON object member with specific key and value parser
### Grammar Building
- **`ref(name)`** - Creates a lightweight reference to a named rule (for recursive grammars)
- **`rule(name, p, trigger)`** - Creates a named rule and returns a reference
- **`trigger_rule(name, p)`** - Creates a trigger rule (entry point for lazy grammar generation)
- **`schema(p, name, schema, raw)`** - Wraps parser with JSON schema metadata for grammar generation
### AST Control
- **`atomic(p)`** - Prevents AST node creation for partial parses
- **`tag(tag, p)`** - Creates AST nodes with semantic tags (multiple nodes can share tags)
## GBNF Grammar Generation
The PEG parser also acts as a convenient DSL for generating GBNF grammars, with
some exceptions.
```cpp
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(params.tools, [&](const json & fn) {
builder.resolve_refs(fn.at("parameters"));
});
parser.build_grammar(builder, data.grammar_lazy);
});
```
The notable exception is the `negate(p)` lookahead parser, which cannot be
defined as a CFG grammar and therefore does not produce a rule. Its usage
should be limited and preferably hidden behind a `schema()` parser. In many
cases, `until(delimiter)` or `until_one_of(delimiters)` is a better choice.
Another limitation is that the PEG parser requires an unambiguous grammar. In
contrast, the `llama-grammar` implementation can support ambiguous grammars,
though they are difficult to parse.
### Lazy Grammars
During lazy grammar generation, only rules reachable from a `trigger_rule(p)`
are emitted in the grammar. All trigger rules are added as alternations in the
root rule. It is still necessary to define trigger patterns, as the parser has
no interaction with the grammar sampling.
### JSON Schema
The `schema(p, name, schema, raw)` parser will use the `json-schema-to-grammar`
implementation to generate the grammar instead of the underlying parser.
The `raw` option emits a grammar suitable for a raw string instead of a JSON
string. In other words, it won't be wrapped in quotes or require escaping
quotes. It should only be used when `type == "string"`.
The downside is that it can potentially lead to ambiguous grammars. For
example, if a user provides the pattern `^.*$`, the following grammar may be
generated:
```
root ::= "<arg>" .* "</arg>"
```
This creates an ambiguous grammar that cannot be parsed by the PEG parser. To
help mitigate this, if `.*` is found in the pattern, the grammar from the
underlying parser will be emitted instead.
## Common AST Shapes for Chat Parsing
Most model output can be placed in one of the following categories:
- Content only
- Tool calling with arguments emitted as a single JSON object
- Tool calling with arguments emitted as separate entities, either XML
(Qwen3-Coder, MiniMax M2) or pseudo-function calls (LFM2)
To provide broad coverage,
[`common/chat-peg-parser.h`](common/chat-peg-parser.h) contains builders and
mappers that help create parsers and visitors/extractors for these types. They
require parsers to tag nodes to conform to an AST "shape". This normalization
makes it easy to extract information and generalize parsing.
### Simple
The `common_chat_peg_builder` builds a `simple` parser that supports
content-only models with optional reasoning.
- **`reasoning(p)`** - Tag node for extracting `reasoning_content`
- **`content(p)`** - Tag node for extracting `content`
```cpp
build_chat_peg_parser([&](common_chat_peg_parser & p) {
return p.sequence({
p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>"),
p.content(p.until("<tool_call>")),
p.end()
});
});
```
Use `common_chat_peg_mapper` to extract the content. Note that this is already
done for you in `common_chat_peg_parser` when
`chat_format == COMMON_CHAT_FORMAT_PEG_SIMPLE`.
```cpp
auto result = parser.parse(ctx);
common_chat_msg msg;
auto mapper = common_chat_peg_mapper(msg);
mapper.from_ast(ctx.ast, result);
```
### Native
The `common_chat_peg_native_builder` builds a `native` parser suitable for
models that emit tool arguments as a direct JSON object.
- **`reasoning(p)`** - Tag node for `reasoning_content`
- **`content(p)`** - Tag node for `content`
- **`tool(p)`** - Tag entirety of a single tool call
- **`tool_open(p)`** - Tag start of a tool call
- **`tool_close(p)`** - Tag end of a tool call
- **`tool_id(p)`** - Tag the tool call ID (optional)
- **`tool_name(p)`** - Tag the tool name
- **`tool_args(p)`** - Tag the tool arguments
```cpp
build_chat_peg_native_parser([&](common_chat_peg_native_parser & p) {
auto get_weather_tool = p.tool(p.sequence({
p.tool_open(p.literal("{")),
p.json_member("name", "\"" + p.tool_name(p.literal("get_weather")) + "\""),
p.literal(","),
p.json_member("arguments", p.tool_args(p.json())),
p.tool_close(p.literal("}"))
}));
return p.sequence({
p.content(p.until("<tool_call>")),
p.literal("<tool_call>"),
get_weather_tool,
p.literal("</tool_call>"),
p.end()
});
});
```
### Constructed
The `common_chat_peg_constructed_builder` builds a `constructed` parser
suitable for models that emit tool arguments as separate entities, such as XML
tags.
- **`reasoning(p)`** - Tag node for `reasoning_content`
- **`content(p)`** - Tag node for `content`
- **`tool(p)`** - Tag entirety of a single tool call
- **`tool_open(p)`** - Tag start of a tool call
- **`tool_close(p)`** - Tag end of a tool call
- **`tool_name(p)`** - Tag the tool name
- **`tool_arg(p)`** - Tag a complete tool argument (name + value)
- **`tool_arg_open(p)`** - Tag start of a tool argument
- **`tool_arg_close(p)`** - Tag end of a tool argument
- **`tool_arg_name(p)`** - Tag the argument name
- **`tool_arg_string_value(p)`** - Tag string value for the argument
- **`tool_arg_json_value(p)`** - Tag JSON value for the argument
```cpp
build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) {
auto location_arg = p.tool_arg(
p.tool_arg_open("<parameter name=\"" + p.tool_arg_name(p.literal("location")) + "\">"),
p.tool_arg_string_value(p.until("</parameter>")),
p.tool_arg_close(p.literal("</parameter>"))
);
auto get_weather_tool = p.tool(p.sequence({
p.tool_open("<function name=\"" + p.tool_name(p.literal("get_weather")) + "\">"),
location_arg,
p.tool_close(p.literal("</function>"))
}));
return p.sequence({
p.content(p.until("<tool_call>")),
p.literal("<tool_call>"),
get_weather_tool,
p.literal("</tool_call>"),
p.end()
});
});
```

View File

@@ -18,6 +18,7 @@ cd llama.cpp
cmake -S . -B build
cmake --build build
cmake --install build --prefix inst
```
### Build simple-cmake-pkg

View File

@@ -175,11 +175,6 @@ option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requi
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC")
if (MINGW)
set(GGML_WIN_VER "0xA00" CACHE STRING "ggml: Windows version")
endif()
# ggml core
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
option(GGML_CPU "ggml: enable CPU backend" ON)

View File

@@ -204,6 +204,10 @@
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#if defined(_WIN32) && !defined(_WIN32_WINNT)
# define _WIN32_WINNT 0x0A00
#endif
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
@@ -2279,7 +2283,7 @@ extern "C" {
float stop,
float step);
#define GGML_KQ_MASK_PAD 64
#define GGML_KQ_MASK_PAD 1
// q: [n_embd_k, n_batch, n_head, ne3 ]
// k: [n_embd_k, n_kv, n_head_kv, ne3 ]

View File

@@ -127,10 +127,6 @@ if (NOT MSVC)
endif()
endif()
if (MINGW)
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
endif()
#
# POSIX conformance
#

View File

@@ -505,7 +505,6 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
@@ -645,7 +644,6 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);

View File

@@ -6383,7 +6383,7 @@ static void ggml_compute_forward_im2col_3d_f16(
const int64_t iih = ioh*s1 + ikh*d1 - p1;
const int64_t iid = iod*s2 + ikd*d2 - p2;
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
} else {
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]

View File

@@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -621,7 +621,8 @@ static __global__ void flash_attn_mask_to_KV_max(
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
const int nbatch_fa) {
constexpr int ncols = ncols1*ncols2;
const int bidx0 = blockIdx.x;
@@ -632,8 +633,8 @@ static __global__ void flash_attn_stream_k_fixup(
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
@@ -765,7 +766,7 @@ static __global__ void flash_attn_combine_results(
template <int DV, int ncols1, int ncols2>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
) {
constexpr int ncols = ncols1 * ncols2;
@@ -790,8 +791,6 @@ void launch_fattn(
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
@@ -915,7 +914,7 @@ void launch_fattn(
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
} else {
const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
// parallel_blocks must not be larger than what the tensor size allows:
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -970,6 +969,9 @@ void launch_fattn(
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// TODO other tensor dimensions after removal of WMMA kernel:
const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
GGML_ASSERT(block_dim.x % warp_size == 0);
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
(const char *) Q->data,
@@ -980,7 +982,7 @@ void launch_fattn(
KV_max.ptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
nb21, nb22, nb23,
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
@@ -995,7 +997,7 @@ void launch_fattn(
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
}
} else if (parallel_blocks > 1) {
const dim3 block_dim_combine(DV, 1, 1);

File diff suppressed because it is too large Load Diff

View File

@@ -501,6 +501,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half * const __restrict__ mask,
const uint3 ne01,
const float logit_softcap,
const float slope,
T_KQ * const KQ,
@@ -512,7 +513,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
float * const KQ_sum,
T_acc * const VKQ,
const int k_VKQ_0,
const int k_VKQ_max) {
const int k_VKQ_max,
const int col_Q_0) {
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
@@ -556,7 +558,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
// Apply logit softcap + mask, update KQ_max:
#pragma unroll
for (int jc0 = 0; jc0 < cpw; ++jc0) {
const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2;
const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01);
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
@@ -736,7 +738,7 @@ static __global__ void flash_attn_tile(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -781,11 +783,11 @@ static __global__ void flash_attn_tile(
const int sequence = blockIdx.z / (ne02/ncols2);
const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2)
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0);
const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape
const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr;
const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr;
const int stride_K2 = nb11 / sizeof(half2);
const int stride_V2 = nb21 / sizeof(half2);
@@ -842,11 +844,9 @@ static __global__ void flash_attn_tile(
for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
float tmp_f[cpy_ne_D] = {0.0f};
if (ncols1 == 1 || col_Q_0 + j < ne01) {
ggml_cuda_memcpy_1<sizeof(tmp_f)>
(tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float))
+ i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
}
ggml_cuda_memcpy_1<sizeof(tmp_f)>
(tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
+ i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -881,23 +881,23 @@ static __global__ void flash_attn_tile(
while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
constexpr bool oob_check = false;
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
(Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
k_VKQ_0 += gridDim.y*nbatch_fa;
}
if (k_VKQ_0 < k_VKQ_max) {
constexpr bool oob_check = true;
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
(Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
}
} else {
// Branch without out-of-bounds checks.
for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {
constexpr bool oob_check = false;
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
(Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
}
}
@@ -1010,13 +1010,13 @@ static __global__ void flash_attn_tile(
const int j = jc / ncols2;
const int c = jc % ncols2;
if (ncols1 > 1 && col_Q_0 + j >= ne01) {
if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) {
return;
}
const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
#ifdef FAST_FP16_AVAILABLE
constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;

View File

@@ -33,7 +33,7 @@ static __global__ void flash_attn_ext_vec(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -150,7 +150,7 @@ static __global__ void flash_attn_ext_vec(
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
// Set memory to zero if out of bounds:
if (ncols > 1 && ic0 + j >= ne01) {
if (ncols > 1 && ic0 + j >= int(ne01.z)) {
#pragma unroll
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
@@ -201,7 +201,7 @@ static __global__ void flash_attn_ext_vec(
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
if (ncols == 1 || ic0 + j < ne01) {
if (ncols == 1 || ic0 + j < int(ne01.z)) {
ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
}
@@ -222,7 +222,7 @@ static __global__ void flash_attn_ext_vec(
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
if (ncols == 1 || ic0 + j < ne01) {
if (ncols == 1 || ic0 + j < int(ne01.z)) {
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
}
@@ -266,7 +266,7 @@ static __global__ void flash_attn_ext_vec(
sum = logit_softcap*tanhf(sum);
}
if (mask) {
if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) {
sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
}
@@ -412,7 +412,7 @@ static __global__ void flash_attn_ext_vec(
#pragma unroll
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
if (ncols > 1 && ic0 + j_VKQ >= ne01) {
if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) {
break;
}
@@ -479,7 +479,7 @@ static __global__ void flash_attn_ext_vec(
if (gridDim.y == 1) {
dst_val /= KQ_sum[j_VKQ];
}
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
}
}
@@ -489,8 +489,8 @@ static __global__ void flash_attn_ext_vec(
}
if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) {
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) {
dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
}
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,

View File

@@ -38,14 +38,14 @@ static __global__ void flash_attn_ext_f16(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
@@ -149,7 +149,7 @@ static __global__ void flash_attn_ext_f16(
if (i0 + warp_size > D && i >= D) {
break;
}
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f;
}
}
@@ -218,7 +218,8 @@ static __global__ void flash_attn_ext_f16(
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
__half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
}
KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
@@ -270,7 +271,7 @@ static __global__ void flash_attn_ext_f16(
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
}
KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
@@ -431,7 +432,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j_VKQ = j0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
if (ic0 + j_VKQ >= int(ne01.z)) {
return;
}
@@ -442,7 +443,7 @@ static __global__ void flash_attn_ext_f16(
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
}
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size) {
@@ -481,7 +482,7 @@ static __global__ void flash_attn_ext_f16(
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
}
constexpr int get_max_power_of_2(int x) {

View File

@@ -2,9 +2,9 @@
#include "common.cuh"
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
#if defined(GGML_USE_MUSA)
#define GGML_USE_WMMA_FATTN
#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
#endif // defined(GGML_USE_MUSA)
#if defined(GGML_HIP_ROCWMMA_FATTN)
#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)

View File

@@ -12,13 +12,13 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
const ggml_tensor * Q = dst->src[0];
if constexpr (ncols2 <= 8) {
if (Q->ne[1] <= 8/ncols2) {
if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
return;
}
}
if (Q->ne[1] <= 16/ncols2) {
if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
return;
}
@@ -41,7 +41,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
const bool use_gqa_opt = mask && max_bias == 0.0f;
const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
@@ -275,8 +275,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
// If Turing tensor cores available, use them:
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) {
// If Turing tensor cores are available, use them:
if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
if (can_use_vector_kernel) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
@@ -297,7 +297,21 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_VEC;
}
}
return BEST_FATTN_KERNEL_MMA_F16;
}
if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
int gqa_ratio_eff = 1;
const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
gqa_ratio_eff *= 2;
}
if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
if (Q->ne[1] * gqa_ratio_eff <= 16) {
return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices.
}
return BEST_FATTN_KERNEL_MMA_F16;
}

View File

@@ -68,10 +68,31 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
namespace ggml_cuda_mma {
// Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
// effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
// In those cases the data can be split in different ways across the warp.
enum data_layout {
// By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
// For the A/C matrices this means I major == row major, J major == column major.
// For the B matrix this means I major == column major, J major == row major.
// MIRRORED == Each data value is held exactly once per thread subgroup.
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
};
// Implemented mma combinations are:
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
struct tile {};
template <int I_, int J_, typename T>
struct tile {
static constexpr int I = I_;
static constexpr int J = J_;
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
#if defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = I * J / 64;
@@ -131,9 +152,9 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 32 && J == 8) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
#else
return (l & 2) | (threadIdx.x & ~2);
return (l & 2) + (threadIdx.x & ~2);
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
} else {
NO_DEVICE_CODE;
@@ -143,7 +164,7 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 32 && J == 8) {
return (threadIdx.x & 2) | (l & (4 + 1));
return (threadIdx.x & 2) + (l & (4 + 1));
} else {
NO_DEVICE_CODE;
return -1;
@@ -196,9 +217,9 @@ namespace ggml_cuda_mma {
} else if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return ((l / 2) * 8) | (threadIdx.x / 4);
return ((l / 2) * 8) + (threadIdx.x / 4);
} else if constexpr (I == 16 && J == 16) {
return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
} else if constexpr (I == 32 && J == 8) {
return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
} else {
@@ -211,11 +232,11 @@ namespace ggml_cuda_mma {
if constexpr (I == 8 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 8 && J == 8) {
return (l * 4) | (threadIdx.x % 4);
return (l * 4) + (threadIdx.x % 4);
} else if constexpr (I == 16 && J == 8) {
return ((threadIdx.x % 4) * 2) | (l % 2);
return ((threadIdx.x % 4) * 2) + (l % 2);
} else if constexpr (I == 16 && J == 16) {
return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
} else if constexpr (I == 32 && J == 8) {
return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
} else {
@@ -227,26 +248,24 @@ namespace ggml_cuda_mma {
};
template <int I_, int J_>
struct tile<I_, J_, half2> {
static constexpr int I = I_;
static constexpr int J = J_;
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
static constexpr int ne = I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
if (I == 8 && J == 8) return true;
if (I == 32 && J == 8) return true;
if (I == 32 && J == 4) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 8) {
return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
} else if constexpr (I == 32 && J == 8) {
if constexpr (I == 32 && J == 4) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
#else
return threadIdx.x;
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
@@ -257,7 +276,7 @@ namespace ggml_cuda_mma {
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr ((I == 8 || I == 32) && J == 8) {
if constexpr (I == 32 && J == 4) {
return l;
} else {
NO_DEVICE_CODE;
@@ -307,11 +326,11 @@ namespace ggml_cuda_mma {
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 4) {
return (l * 8) | (threadIdx.x / 4);
return (l * 8) + (threadIdx.x / 4);
} else if constexpr (I == 16 && J == 8) {
return ((l % 2) * 8) | (threadIdx.x / 4);
return ((l % 2) * 8) + (threadIdx.x / 4);
} else if constexpr (I == 32 && J == 8) {
return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
} else {
NO_DEVICE_CODE;
return -1;
@@ -320,13 +339,13 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 8) {
return (l * 4) | (threadIdx.x % 4);
return (l * 4) + (threadIdx.x % 4);
} else if constexpr (I == 16 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return ((l / 2) * 4) | (threadIdx.x % 4);
return ((l / 2) * 4) + (threadIdx.x % 4);
} else if constexpr (I == 32 && J == 8) {
return ((l & 2) * 2) | (threadIdx.x % 4);
return ((l & 2) * 2) + (threadIdx.x % 4);
} else {
NO_DEVICE_CODE;
return -1;
@@ -336,14 +355,15 @@ namespace ggml_cuda_mma {
};
template <int I_, int J_>
struct tile<I_, J_, nv_bfloat162> {
static constexpr int I = I_;
static constexpr int J = J_;
struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
static constexpr int ne = I * J / WARP_SIZE;
#if defined(AMD_WMMA_AVAILABLE)
static constexpr int ne = I * J / 32;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
#if defined(AMD_WMMA_AVAILABLE)
static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
@@ -367,9 +387,6 @@ namespace ggml_cuda_mma {
}
}
#else
static constexpr int ne = I * J / WARP_SIZE;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
if (I == 8 && J == 8) return true;
if (I == 16 && J == 4) return true;
@@ -381,9 +398,9 @@ namespace ggml_cuda_mma {
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 4) {
return (l * 8) | (threadIdx.x / 4);
return (l * 8) + (threadIdx.x / 4);
} else if constexpr (I == 16 && J == 8) {
return ((l % 2) * 8) | (threadIdx.x / 4);
return ((l % 2) * 8) + (threadIdx.x / 4);
} else {
NO_DEVICE_CODE;
return -1;
@@ -392,11 +409,11 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 8) {
return (l * 4) | (threadIdx.x % 4);
return (l * 4) + (threadIdx.x % 4);
} else if constexpr (I == 16 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return ((l / 2) * 4) | (threadIdx.x % 4);
return ((l / 2) * 4) + (threadIdx.x % 4);
} else {
NO_DEVICE_CODE;
return -1;
@@ -405,6 +422,73 @@ namespace ggml_cuda_mma {
#endif // defined(AMD_WMMA_AVAILABLE)
};
template <int I_, int J_>
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
static constexpr int ne = I * J / (WARP_SIZE/4);
half2 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
if (I == 8 && J == 4) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int /*l*/) {
if constexpr (I == 8 && J == 4) {
return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 4) {
return l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
};
template <int I_, int J_>
struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
static constexpr int ne = I * J / (WARP_SIZE/4);
half2 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
if (I == 8 && J == 4) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 4) {
return ((l / 2) * 4) + (threadIdx.x % 4);
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 4) {
return ((threadIdx.x / 16) * 2) + (l % 2);
} else {
NO_DEVICE_CODE;
return -1;
}
}
};
#if defined(TURING_MMA_AVAILABLE)
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
tile<I, J/2, half2> ret;
@@ -422,9 +506,26 @@ namespace ggml_cuda_mma {
return ret;
}
#else // Volta
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
tile<I, J/2, half2> ret;
#pragma unroll
for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
template <int I, int J, typename T>
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
// On Volta FP16 and FP32 tiles have a different memory layout,
// for the conversion threads with an offset of 2 need to exchange half their values:
ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
}
return ret;
}
#endif // defined(TURING_MMA_AVAILABLE)
template <int I, int J, typename T, data_layout dl>
static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
#pragma unroll
@@ -511,18 +612,6 @@ namespace ggml_cuda_mma {
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
: "l"(xs));
#else
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
GGML_UNUSED_VARS(t, xs0, stride);
NO_DEVICE_CODE;
#else
load_generic(t, xs0, stride);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // TURING_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if 1
// TODO: more generic handling
@@ -533,9 +622,31 @@ namespace ggml_cuda_mma {
load_generic(t, xs0, stride);
#endif // 1
#else
tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
load_ldmatrix(t16[0], xs0 + 0*stride, stride);
load_ldmatrix(t16[1], xs0 + 16*stride, stride);
load_generic(t, xs0, stride);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void load_ldmatrix(
tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
}
static __device__ __forceinline__ void load_ldmatrix(
tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
#pragma unroll
for (int l0 = 0; l0 < t.ne; l0 += 2) {
ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
}
}
static __device__ __forceinline__ void load_ldmatrix(
tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
#else
GGML_UNUSED_VARS(t, xs0, stride);
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
@@ -860,14 +971,14 @@ namespace ggml_cuda_mma {
template <typename T1, typename T2, int J, int K>
static __device__ __forceinline__ void mma(
tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D);
const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
mma(D16[0], A16[0], B);
mma(D16[1], A16[1], B);
}
static __device__ __forceinline__ void mma(
tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
@@ -880,20 +991,30 @@ namespace ggml_cuda_mma {
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
#else
tile <16, 8, float> * D16 = reinterpret_cast<tile <16, 8, float> *>(&D);
const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
mma(D16[0], A16[0], B);
mma(D16[1], A16[1], B);
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
}
static __device__ __forceinline__ void mma(
tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
"{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
"{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
}
static __device__ __forceinline__ void mma(

View File

@@ -37,23 +37,19 @@ static __global__ void mul_mat_f(
typedef tile<16, 8, T> tile_A;
typedef tile<tile_B_I, 8, T> tile_B;
typedef tile<16, tile_C_J, float> tile_C;
constexpr bool a_supported = tile_A::supported();
constexpr bool b_supported = tile_B::supported();
constexpr bool c_supported = tile_C::supported();
constexpr bool supported = a_supported && b_supported && c_supported;
#else
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
constexpr bool supported = I_16_supported || I_32_supported;
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
typedef tile<I_preferred, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<I_preferred, 8, float> tile_C;
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
#endif // VOLTA_MMA_AVAILABLE
#endif // defined(AMD_WMMA_AVAILABLE)
if constexpr (!supported) {
if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
NO_DEVICE_CODE;
return;
}
@@ -248,6 +244,9 @@ static __global__ void mul_mat_f(
}
}
}
#ifdef VOLTA_MMA_AVAILABLE
}
#endif //VOLTA_MMA_AVAILABLE
#else
GGML_UNUSED_VARS(x, y, ids, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
@@ -278,27 +277,24 @@ static __global__ void mul_mat_f_ids(
typedef tile<16, 8, T> tile_A;
typedef tile<tile_B_I, 8, T> tile_B;
typedef tile<16, tile_C_J, float> tile_C;
constexpr bool a_supported = tile_A::supported();
constexpr bool b_supported = tile_B::supported();
constexpr bool c_supported = tile_C::supported();
constexpr bool supported = a_supported && b_supported && c_supported;
#else
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
constexpr bool supported = I_16_supported || I_32_supported;
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
typedef tile<I_preferred, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<I_preferred, 8, float> tile_C;
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
#endif // VOLTA_MMA_AVAILABLE
#endif // defined(AMD_WMMA_AVAILABLE)
if constexpr (!supported) {
if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
NO_DEVICE_CODE;
return;
}
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
@@ -517,6 +513,9 @@ static __global__ void mul_mat_f_ids(
}
}
}
#ifdef VOLTA_MMA_AVAILABLE
}
#endif // VOLTA_MMA_AVAILABLE
#else
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,

File diff suppressed because it is too large Load Diff

View File

@@ -35,20 +35,6 @@ typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;
ggml_metal_pipeline_t ggml_metal_pipeline_init(void);
void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline);
void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg);
int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline);
void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0);
int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline);
void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1);
int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline);
void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem);
size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline);
int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline);
// a collection of pipelines
typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t;
@@ -58,6 +44,19 @@ void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);
void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline);
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name);
struct ggml_metal_pipeline_with_params {
ggml_metal_pipeline_t pipeline;
int nsg;
int nr0;
int nr1;
size_t smem;
};
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
//
// MTLCommandBuffer wrapper
//
@@ -76,7 +75,7 @@ void ggml_metal_encoder_free(ggml_metal_encoder_t encoder);
void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name);
void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder);
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline);
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline);
void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx);
void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx);
@@ -100,66 +99,67 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
void ggml_metal_library_free(ggml_metal_library_t lib);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
int32_t ncpsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
int32_t nqptg,
int32_t ncpsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
@@ -169,7 +169,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
bool has_kvpad,
int32_t nsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
@@ -180,7 +180,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
int32_t nsg,
int32_t nwg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
int32_t dv,

View File

@@ -75,14 +75,6 @@ void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) {
struct ggml_metal_pipeline {
id<MTLComputePipelineState> obj;
// suggested dispatch sizes
int nsg;
int nr0;
int nr1;
size_t smem;
};
ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
@@ -90,10 +82,6 @@ ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
*res = (struct ggml_metal_pipeline) {
/*.obj =*/ nil,
/*.nsg =*/ 0,
/*.nr0 =*/ 0,
/*.nr1 =*/ 0,
/*.smem =*/ 0,
};
return res;
@@ -105,40 +93,8 @@ void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) {
free(pipeline);
}
void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg) {
pipeline->nsg = nsg;
}
int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline) {
return pipeline->nsg;
}
void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0) {
pipeline->nr0 = nr0;
}
int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline) {
return pipeline->nr0;
}
void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1) {
pipeline->nr1 = nr1;
}
int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline) {
return pipeline->nr1;
}
void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem) {
pipeline->smem = smem;
}
size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline) {
return pipeline->smem;
}
int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline) {
return pipeline->obj.maxTotalThreadsPerThreadgroup;
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) {
return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
}
struct ggml_metal_library {
@@ -146,6 +102,8 @@ struct ggml_metal_library {
id<MTLDevice> device;
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
NSLock * lock;
};
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
@@ -296,9 +254,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
res->obj = library;
res->device = device;
res->obj = library;
res->device = device;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
return res;
}
@@ -365,6 +324,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
res->obj = library;
res->device = device;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
return res;
}
@@ -380,26 +340,47 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
ggml_metal_pipelines_free(lib->pipelines);
[lib->lock release];
free(lib);
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
return ggml_metal_pipelines_get(lib->pipelines, name);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
[lib->lock lock];
struct ggml_metal_pipeline_with_params res = {
/*.pipeline =*/ nil,
/*.nr0 =*/ 0,
/*.nr1 =*/ 0,
/*.nsg =*/ 0,
/*.smem =*/ 0,
};
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
[lib->lock unlock];
return res;
}
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
// note: the pipelines are cached in the library per device, so they are shared across all metal contexts
ggml_critical_section_start();
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
struct ggml_metal_pipeline_with_params res = {
/*.pipeline =*/ nil,
/*.nr0 =*/ 0,
/*.nr1 =*/ 0,
/*.nsg =*/ 0,
/*.smem =*/ 0,
};
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
ggml_critical_section_end();
[lib->lock lock];
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
if (res.pipeline) {
[lib->lock unlock];
return res;
}
res = ggml_metal_pipeline_init();
@autoreleasepool {
NSError * error = nil;
@@ -414,36 +395,53 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
}
if (!mtl_function) {
ggml_critical_section_end();
[lib->lock unlock];
GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
if (error) {
GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
}
return nil;
return res;
}
res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
id<MTLComputePipelineState> obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
[mtl_function release];
GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj,
(int) res->obj.maxTotalThreadsPerThreadgroup,
(int) res->obj.threadExecutionWidth);
if (!obj) {
[lib->lock unlock];
if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
ggml_critical_section_end();
GGML_LOG_ERROR("%s: failed to create pipeline state: base = '%s', name = '%s'\n", __func__, base, name);
if (error) {
GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
}
return res;
}
GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name,
(void *) obj,
(int) obj.maxTotalThreadsPerThreadgroup,
(int) obj.threadExecutionWidth);
if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) {
[obj release];
[lib->lock unlock];
GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
return nil;
return res;
}
ggml_metal_pipelines_add(lib->pipelines, name, res);
res.pipeline = ggml_metal_pipeline_init();
res.pipeline->obj = obj;
ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline);
}
ggml_critical_section_end();
[lib->lock unlock];
return res;
}
@@ -485,8 +483,8 @@ void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) {
[encoder->obj popDebugGroup];
}
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline) {
[encoder->obj setComputePipelineState:pipeline->obj];
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) {
[encoder->obj setComputePipelineState:pipeline.pipeline->obj];
}
void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) {
@@ -611,8 +609,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
dev->props.has_tensor = false;
} else {
ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
if (!ppl) {
struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
if (!ppl.pipeline) {
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
dev->props.has_tensor = false;
}
@@ -661,8 +659,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
dev->props.has_bfloat = false;
} else {
ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
if (!ppl) {
struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
if (!ppl.pipeline) {
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
dev->props.has_bfloat = false;
}
@@ -820,6 +818,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_EXPM1:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default:
return false;
@@ -852,6 +852,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_ACC:
case GGML_OP_REPEAT:
case GGML_OP_SCALE:
case GGML_OP_FILL:
case GGML_OP_CONV_TRANSPOSE_1D:
return true;
case GGML_OP_CONV_TRANSPOSE_2D:
@@ -869,6 +870,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_TRI:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_SUM_ROWS:
case GGML_OP_CUMSUM:
case GGML_OP_MEAN:

View File

@@ -182,6 +182,10 @@ typedef struct {
float bias;
} ggml_metal_kargs_scale;
typedef struct {
float val;
} ggml_metal_kargs_fill;
typedef struct {
float min;
float max;
@@ -831,6 +835,25 @@ typedef struct {
float slope;
} ggml_metal_kargs_leaky_relu;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_tri;
typedef struct {
int32_t ne00;
int32_t ne01;

View File

@@ -286,6 +286,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_scale(ctx, idx);
} break;
case GGML_OP_FILL:
{
n_fuse = ggml_metal_op_fill(ctx, idx);
} break;
case GGML_OP_CLAMP:
{
n_fuse = ggml_metal_op_clamp(ctx, idx);
@@ -414,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
} break;
case GGML_OP_TRI:
{
n_fuse = ggml_metal_op_tri(ctx, idx);
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
@@ -524,7 +532,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
/*.dim =*/ dim,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -550,7 +558,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
ggml_metal_kargs_repeat args = {
/*.ne00 =*/ ne00,
@@ -616,7 +624,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
// TODO: make a simpler cpy_bytes kernel
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ ne00,
@@ -679,7 +687,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
/*.o1 =*/ { 0 },
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -721,7 +729,42 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
n /= 4;
}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const float val = ggml_get_op_params_f32(op, 0);
ggml_metal_kargs_fill args = {
/*.val =*/ val
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -760,7 +803,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
n /= 4;
}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -789,7 +832,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
n /= 4;
}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
@@ -817,7 +860,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
const int32_t swp = ggml_get_op_params_i32(op, 1);
const float alpha = ggml_get_op_params_f32(op, 2);
@@ -870,7 +913,7 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
/*.np =*/ n,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
int nth = 32; // SIMD width
@@ -925,7 +968,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
/*.nb3 =*/ nb3,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
int nth = 32; // SIMD width
@@ -936,7 +979,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -963,7 +1006,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
int nth = 1;
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
@@ -1060,7 +1103,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx);
{
ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
ggml_metal_kargs_cumsum_add args = {
/*.ne00 =*/ ne00,
@@ -1106,7 +1149,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
ggml_metal_kargs_get_rows args = {
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
@@ -1151,7 +1194,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
const int32_t nk0 = ne0/ggml_blck_size(op->type);
@@ -1252,7 +1295,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
/*.n_head_log2 =*/ n_head_log2,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
int nth = 32; // SIMD width
@@ -1266,7 +1309,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
}
}
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
@@ -1322,7 +1365,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
/*.nb2 =*/ nb2,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
@@ -1409,11 +1452,11 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
/*.nb0 =*/ nb0,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -1426,7 +1469,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8);
ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
@@ -1449,7 +1492,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
const int64_t C = op->ne[0];
const int64_t H = op->src[0]->ne[1];
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
int ida = 0;
@@ -1485,7 +1528,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
@@ -1592,7 +1635,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
/* .np = */ np
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
const int ntg = (np + nth - 1) / nth;
@@ -1701,7 +1744,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
GGML_ABORT("unsupported ne11");
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
ggml_metal_kargs_mul_mv_ext args = {
/*.ne00 =*/ ne00,
@@ -1748,7 +1791,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
// default: break;
//}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
ggml_metal_kargs_mul_mm args = {
/*.ne00 =*/ ne00,
@@ -1773,18 +1816,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
} else {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
const int nr0 = pipeline.nr0;
const int nr1 = pipeline.nr1;
const int nsg = pipeline.nsg;
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
const size_t smem = pipeline.smem;
ggml_metal_kargs_mul_mv args = {
/*.ne00 =*/ ne00,
@@ -1915,9 +1958,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
nb21,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
const size_t smem = pipeline.smem;
GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
@@ -1938,7 +1981,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx);
{
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
ggml_metal_kargs_mul_mm_id args = {
/*.ne00 =*/ ne00,
@@ -1967,20 +2010,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
}
} else {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
const int nr0 = pipeline.nr0;
const int nr1 = pipeline.nr1;
const int nsg = pipeline.nsg;
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
const size_t smem = pipeline.smem;
ggml_metal_kargs_mul_mv_id args = {
/*.nei0 =*/ ne20,
@@ -2064,7 +2107,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
/*.nb21 =*/ nb21,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2308,7 +2351,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb33 =*/nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@@ -2339,7 +2382,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb33 =*/ nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@@ -2424,7 +2467,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2476,7 +2519,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb33 =*/nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@@ -2578,7 +2621,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
@@ -2630,7 +2673,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
nrows,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@@ -2762,7 +2805,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
// the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
bid_src1.offs = 0;
ggml_metal_pipeline_t pipeline = nullptr;
struct ggml_metal_pipeline_with_params pipeline;
if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
@@ -2835,7 +2878,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
/*.eps =*/ eps,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
@@ -2844,7 +2887,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00/4);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
const size_t smem = pipeline.smem;
const int64_t nrows = ggml_nrows(op->src[0]);
@@ -2887,7 +2930,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
/*.eps =*/ eps,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
int nth = 32; // SIMD width
//while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
@@ -2897,7 +2940,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
//nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
//nth = std::min(nth, ne00/4);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3022,7 +3065,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
}
}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
int nth = 32; // SIMD width
@@ -3033,7 +3076,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, args.ne00_t);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3127,7 +3170,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
/* src2 =*/ op->src[2] != nullptr,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3199,7 +3242,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
/*.KHW =*/ KH * KW,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
@@ -3270,7 +3313,7 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
/*.d1 =*/ d1,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
nth = std::min(nth, 256);
@@ -3325,7 +3368,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
/*.nb1 =*/ nb1,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3377,7 +3420,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
/*.nb2 =*/ nb2,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3433,7 +3476,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
/*.sf3 =*/ sf3
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
@@ -3477,7 +3520,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
/*.nb3 =*/ nb3
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
const int nth = std::min(1024, ne0);
@@ -3523,7 +3566,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
/*.p1 =*/ ((const int32_t *)(op->op_params))[1]
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
const int nth = std::min(1024, ne0);
@@ -3560,7 +3603,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
const int nth = std::min(1024, ne0);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3591,7 +3634,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
/*.max_period =*/ max_period,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
const int nth = std::max(1, std::min(1024, dim/2));
@@ -3621,7 +3664,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
/*.nb01 = */ nb01,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
const int64_t nrows = ggml_nrows(op->src[0]);
@@ -3630,7 +3673,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
nth *= 2;
}
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3657,7 +3700,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
// bitonic sort requires the number of elements to be power of 2
int nth = 1;
@@ -3706,7 +3749,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
int len = nth;
@@ -3764,7 +3807,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
// bitonic sort requires the number of elements to be power of 2
int nth = 1;
@@ -3818,7 +3861,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
int len = args.top_k;
@@ -3881,7 +3924,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
/*.slope =*/ slope
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
int64_t n = ggml_nelements(op);
@@ -3899,6 +3942,57 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_tri args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
return 1;
}
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@@ -3910,7 +4004,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_adamw args = {
@@ -3946,7 +4040,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_sgd args = {

View File

@@ -47,6 +47,7 @@ int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
@@ -83,6 +84,7 @@ int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);

View File

@@ -1249,6 +1249,22 @@ kernel void kernel_scale_f32_4(
dst[tpig] = src0[tpig] * args.scale + args.bias;
}
kernel void kernel_fill_f32(
constant ggml_metal_kargs_fill & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
kernel void kernel_fill_f32_4(
constant ggml_metal_kargs_fill & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
kernel void kernel_clamp_f32(
constant ggml_metal_kargs_clamp & args,
device const float * src0,
@@ -1595,6 +1611,36 @@ kernel void kernel_exp_f32_4(
dst[tpig] = exp(src0[tpig]);
}
kernel void kernel_softplus_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
}
kernel void kernel_softplus_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
}
kernel void kernel_expm1_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]) - 1.0f;
}
kernel void kernel_expm1_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]) - 1.0f;
}
kernel void kernel_reglu_f32(
constant ggml_metal_kargs_glu & args,
device const char * src0,
@@ -1943,6 +1989,75 @@ typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
template<uint32_t ttype>
bool _ggml_vec_tri_cmp(const int i, const int r);
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
return i < r;
}
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
return i <= r;
}
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
return i > r;
}
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
return i >= r;
}
template<typename T, int ttype>
kernel void kernel_tri(
constant ggml_metal_kargs_tri & args,
device const char * src0,
device const char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
// Each thread is a single element of the row if ne00 < max threads per
// threadgroup, so this will loop once for each index that this thread is
// responsible for
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
// Use the comparison as a mask for branchless
dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
}
}
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
#endif
template<typename T>
kernel void kernel_soft_max(
constant ggml_metal_kargs_soft_max & args,

View File

@@ -31,6 +31,14 @@ except ImportError:
else:
_mistral_common_installed = True
try:
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
get_one_valid_tokenizer_file,
)
except ImportError:
# We still want the conversion to work with older mistral-common versions.
get_one_valid_tokenizer_file = None
import gguf
@@ -673,24 +681,30 @@ class MistralVocab(Vocab):
# Find the tokenizer files
all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()]
valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
if len(valid_tokenizer_files) == 0:
raise ValueError(f"No tokenizer file found in the directory: {base_path}")
# If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
if len(valid_tokenizer_files) > 1:
if "tekken.json" in valid_tokenizer_files:
tokenizer_file = "tekken.json"
else:
tokenizer_file = sorted(valid_tokenizer_files)[-1]
logger.warning(
f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
)
if get_one_valid_tokenizer_file is not None:
tokenizer_file_path = get_one_valid_tokenizer_file(all_files)
else:
tokenizer_file = valid_tokenizer_files[0]
valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
if len(valid_tokenizer_files) == 0:
raise ValueError(f"No tokenizer file found in the directory: {base_path}")
# If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
if len(valid_tokenizer_files) > 1:
if "tekken.json" in valid_tokenizer_files:
tokenizer_file = "tekken.json"
else:
tokenizer_file = sorted(valid_tokenizer_files)[-1]
logger.warning(
f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
)
else:
tokenizer_file = valid_tokenizer_files[0]
tokenizer_file_path = base_path / tokenizer_file
self.tokenizer = MistralTokenizer.from_file(
base_path / tokenizer_file
tokenizer_file_path
).instruct_tokenizer.tokenizer
self.tokenizer_type = (
MistralTokenizerType.tekken
@@ -698,7 +712,7 @@ class MistralVocab(Vocab):
else MistralTokenizerType.spm
)
self.vocab_size = self.tokenizer.n_words
self.fname_tokenizer = base_path / tokenizer_file
self.fname_tokenizer = tokenizer_file_path
self._name = (
"mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
)

View File

@@ -726,21 +726,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// sanity checks for models that have attention layers
if (qs.n_attention_wv != 0 && !is_clip_model)
{
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
// attention layers have a non-zero number of kv heads
int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
int32_t n_layer_all = model.hparams.n_layer;
if (llama_model_has_encoder(&model)) {
// now n_layer_attn is the number of attention layers in the encoder
// now n_layer_all is the number of attention layers in the encoder
// for each decoder block, there are 2 attention layers
n_layer_attn += 2 * model.hparams.dec_n_layer;
n_layer_all += 2 * model.hparams.dec_n_layer;
}
// note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w);
LLAMA_LOG_INFO("%s: n_layer_all = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_all, n_layer_recr, pruned_attention_w);
GGML_ASSERT((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
GGML_ASSERT((qs.n_attention_wv == n_layer_all - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
}
size_t total_size_org = 0;

1
tests/.gitignore vendored
View File

@@ -3,3 +3,4 @@
*.o
ggml-common.h
**/*.swp
!peg-parser

View File

@@ -1,13 +1,15 @@
llama_add_compile_flags()
function(llama_build source)
set(TEST_SOURCES ${source} ${ARGN})
if (DEFINED LLAMA_TEST_NAME)
set(TEST_TARGET ${LLAMA_TEST_NAME})
else()
get_filename_component(TEST_TARGET ${source} NAME_WE)
endif()
add_executable(${TEST_TARGET} ${source})
add_executable(${TEST_TARGET} ${TEST_SOURCES})
target_link_libraries(${TEST_TARGET} PRIVATE common)
install(TARGETS ${TEST_TARGET} RUNTIME)
endfunction()
@@ -83,6 +85,8 @@ function(llama_build_and_test source)
set(multiValueArgs ARGS)
cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(TEST_SOURCES ${source} ${LLAMA_TEST_UNPARSED_ARGUMENTS} get-model.cpp)
if (NOT DEFINED LLAMA_TEST_LABEL)
set(LLAMA_TEST_LABEL "main")
endif()
@@ -95,7 +99,7 @@ function(llama_build_and_test source)
get_filename_component(TEST_TARGET ${source} NAME_WE)
endif()
add_executable(${TEST_TARGET} ${source} get-model.cpp)
add_executable(${TEST_TARGET} ${TEST_SOURCES})
install(TARGETS ${TEST_TARGET} RUNTIME)
target_link_libraries(${TEST_TARGET} PRIVATE common)
@@ -180,9 +184,21 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
endif()
llama_build_and_test(test-chat-parser.cpp)
llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp)
llama_build_and_test(test-chat-template.cpp)
llama_build_and_test(test-json-partial.cpp)
llama_build_and_test(test-log.cpp)
llama_build_and_test(
test-peg-parser.cpp
peg-parser/simple-tokenize.cpp
peg-parser/test-basic.cpp
peg-parser/test-gbnf-generation.cpp
peg-parser/test-json-parser.cpp
peg-parser/test-json-serialization.cpp
peg-parser/test-unicode.cpp
peg-parser/testing.h
peg-parser/tests.h
)
llama_build_and_test(test-regex-partial.cpp)
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")

View File

@@ -0,0 +1,37 @@
#include "simple-tokenize.h"
std::vector<std::string> simple_tokenize(const std::string & input) {
std::vector<std::string> result;
std::string current;
for (size_t i = 0; i < input.size(); i++) {
switch (input[i]) {
case ' ':
case '\n':
case '\t':
case '{':
case '}':
case ',':
case '[':
case '"':
case ']':
case '.':
case '<':
case '>':
case '=':
case '/':
if (!current.empty()) {
result.push_back(current);
current.clear();
}
default:;
}
current += input[i];
}
if (!current.empty()) {
result.push_back(current);
}
return result;
}

View File

@@ -0,0 +1,6 @@
#pragma once
#include <string>
#include <vector>
std::vector<std::string> simple_tokenize(const std::string &);

View File

@@ -0,0 +1,454 @@
#include "tests.h"
void test_basic(testing & t) {
t.test("chars", [](testing & t) {
// Test common escape sequences - newline
t.test("escape_sequence_newline", [](testing &t) {
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("\n");
result = common_chat_combinator_parser.parse(ctx);
t.assert_equal("escape_sequence_newline", true, result.success());
});
// Test common escape sequences - tab
t.test("escape_sequence_tab", [](testing &t) {
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("\t");
result = common_chat_combinator_parser.parse(ctx);
t.assert_equal("escape_sequence_tab", true, result.success());
});
// Test common escape sequences - backslash
t.test("escape_sequence_backslash", [](testing &t) {
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("\\");
result = common_chat_combinator_parser.parse(ctx);
t.assert_equal("escape_sequence_backslash", true, result.success());
});
// Test common escape sequences - space (should ())
t.test("escape_sequence_space_fail", [](testing &t) {
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context(" ");
result = common_chat_combinator_parser.parse(ctx);
t.assert_equal("escape_sequence_space_fail", true, result.fail());
});
// Test escaped dash - 'a' should succeed
t.test("escaped_dash_a", [](testing &t) {
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("a");
result = common_chat_combinator_parser.parse(ctx);
t.assert_equal("escaped_dash_a", true, result.success());
});
// Test escaped dash - '-' should succeed (literal dash)
t.test("escaped_dash_literal", [](testing &t) {
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("-");
result = common_chat_combinator_parser.parse(ctx);
t.assert_equal("escaped_dash_literal", true, result.success());
});
// Test escaped dash - 'z' should succeed
t.test("escaped_dash_z", [](testing &t) {
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("z");
result = common_chat_combinator_parser.parse(ctx);
t.assert_equal("escaped_dash_z", true, result.success());
});
// Test escaped dash - 'b' should NOT match (since \- is literal dash, not range)
t.test("escaped_dash_b_fail", [](testing &t) {
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("b");
result = common_chat_combinator_parser.parse(ctx);
t.assert_equal("escaped_dash_b_fail", true, result.fail());
});
});
t.test("optional", [](testing & t) {
// Full match with optional part present
t.test("optional_present", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("hello") + p.optional(p.literal(" world"));
});
auto ctx = common_peg_parse_context("hello world");
auto result = parser.parse(ctx);
t.assert_equal("optional_present", true, result.success());
t.assert_equal("optional_present_end", 11u, result.end);
});
// Full match with optional part absent
t.test("optional_absent", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("hello") + p.optional(p.literal(" world"));
});
auto ctx = common_peg_parse_context("hello", false);
auto result = parser.parse(ctx);
t.assert_equal("optional_absent", true, result.success());
t.assert_equal("optional_absent_end", 5u, result.end);
});
// Partial match - waiting for more input to determine if optional matches
t.test("partial_match_need_more", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("hello") + p.optional(p.literal(" world"));
});
auto ctx = common_peg_parse_context("hello ", true);
auto result = parser.parse(ctx);
t.assert_equal("partial_match_need_more", true, result.need_more_input());
});
});
t.test("partial parsing", [](testing & t) {
// Literals - Basic Success
t.test("literal_success", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("hello");
result = parser.parse(ctx);
t.assert_equal("literal_success", true, result.success());
});
// Char Classes - Basic Lowercase Success
t.test("char_class_lowercase_success", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("a");
result = parser.parse(ctx);
t.assert_equal("char_class_lowercase_success", true, result.success());
});
// Char Classes - Uppercase Fail
t.test("char_class_uppercase_fail", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("A");
result = parser.parse(ctx);
t.assert_equal("char_class_uppercase_fail", true, result.fail());
});
// Char Classes with Dash - Lowercase Success
t.test("char_class_with_dash_lowercase", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("f");
result = parser.parse(ctx);
t.assert_equal("char_class_with_dash_lowercase", true, result.success());
});
// Char Classes with Dash - Literal Dash Success
t.test("char_class_with_dash_literal_dash", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("-");
result = parser.parse(ctx);
t.assert_equal("char_class_with_dash_literal_dash", true, result.success());
});
// Char Classes with Dash - Uppercase Fail
t.test("char_class_with_dash_uppercase_fail", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); });
common_peg_parse_context ctx;
common_peg_parse_result result;
ctx = common_peg_parse_context("A");
result = parser.parse(ctx);
t.assert_equal("char_class_with_dash_uppercase_fail", true, result.fail());
});
// Sequences - Partial Match 1
t.test("sequence_partial_match_1", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("<think>") + p.literal("</think>"); });
auto ctx = common_peg_parse_context("<thi", true);
auto result = parser.parse(ctx);
t.assert_equal("sequence_partial_match_1", true, result.need_more_input());
});
// Sequences - Partial Match 2
t.test("sequence_partial_match_2", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("begin") + p.literal("end"); });
auto ctx = common_peg_parse_context("begin", true);
auto result = parser.parse(ctx);
t.assert_equal("sequence_partial_match_2", true, result.need_more_input());
});
// Sequences - Partial Match 3
t.test("sequence_partial_match_3", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("<think>") + p.literal("</think>"); });
auto ctx = common_peg_parse_context("<think></", true);
auto result = parser.parse(ctx);
t.assert_equal("sequence_partial_match_3", true, result.need_more_input());
});
// Sequences - Full Match
t.test("sequence_full_match", [&](testing & t) {
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello") + p.literal("world"); });
auto ctx = common_peg_parse_context("helloworld", false);
auto result = common_chat_combinator_parser.parse(ctx);
t.assert_equal("sequence_full_match", true, result.success());
});
// Sequences - No Match
t.test("sequence_no_match", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("<think>") + p.literal("</think>"); });
auto ctx = common_peg_parse_context("<think>I am common_chat_combinator_parser", true);
auto result = parser.parse(ctx);
t.assert_equal("sequence_no_match", true, result.fail());
});
// Choices - Partial Match 1
t.test("choices_partial_match_1", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("option1") | p.literal("option2"); });
auto ctx = common_peg_parse_context("opt", true);
auto result = parser.parse(ctx);
t.assert_equal("choices_partial_match_1", true, result.need_more_input());
});
// Choices - Partial Match 2
t.test("choices_partial_match_2", [&](testing & t) {
auto parser =
build_peg_parser([](common_peg_parser_builder & p) { return p.literal("choice_a") | p.literal("choice_b"); });
auto ctx = common_peg_parse_context("choice", true);
auto result = parser.parse(ctx);
t.assert_equal("choices_partial_match_2", true, result.need_more_input());
});
// Choices - Full Match 1
t.test("choices_full_match_1", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("first") | p.literal("second"); });
auto ctx = common_peg_parse_context("first", false);
auto result = parser.parse(ctx);
t.assert_equal("choices_full_match_1", true, result.success());
});
// Choices - Full Match 2
t.test("choices_full_match_2", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("alpha") | p.literal("beta"); });
auto ctx = common_peg_parse_context("beta", false);
auto result = parser.parse(ctx);
t.assert_equal("choices_full_match_2", true, result.success());
});
// Choices - No Match
t.test("choices_no_match", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("good") | p.literal("better"); });
auto ctx = common_peg_parse_context("best", false);
auto result = parser.parse(ctx);
t.assert_equal("choices_no_match", true, result.fail());
});
// Zero or More - Partial Match 1
t.test("zero_or_more_partial_match_1", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("ab")); });
auto ctx = common_peg_parse_context("a", true);
auto result = parser.parse(ctx);
t.assert_equal("zero_or_more_partial_match_1", true, result.need_more_input());
});
// Zero or More - Partial Match 2
t.test("zero_or_more_partial_match_2", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("xy")); });
auto ctx = common_peg_parse_context("xyx", true);
auto result = parser.parse(ctx);
t.assert_equal("zero_or_more_partial_match_2", true, result.need_more_input());
});
// Zero or More - Full Match
t.test("zero_or_more_full_match", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("test")); });
auto ctx = common_peg_parse_context("test", false);
auto result = parser.parse(ctx);
t.assert_equal("zero_or_more_full_match", true, result.success());
});
// One or More - Partial Match 1
t.test("one_or_more_partial_match_1", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("repeat")); });
auto ctx = common_peg_parse_context("rep", true);
auto result = parser.parse(ctx);
t.assert_equal("one_or_more_partial_match_1", true, result.need_more_input());
});
// One or More - Partial Match 2
t.test("one_or_more_partial_match_2", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("ab")); });
auto ctx = common_peg_parse_context("aba", true);
auto result = parser.parse(ctx);
t.assert_equal("one_or_more_partial_match_2", true, result.need_more_input());
});
// One or More - Full Match
t.test("one_or_more_full_match", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("single")); });
auto ctx = common_peg_parse_context("single", false);
auto result = parser.parse(ctx);
t.assert_equal("one_or_more_full_match", true, result.success());
});
// One or More - No Match
t.test("one_or_more_no_match", [&](testing & t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("()")); });
auto ctx = common_peg_parse_context("success", false);
auto result = parser.parse(ctx);
t.assert_equal("one_or_more_no_match", true, result.fail());
});
});
t.test("recursive rules", [](testing &t) {
// Test simple number
t.test("simple_number", [](testing &t) {
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
p.rule("number", p.chars("0-9"));
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
return p.rule("value", p.ref("number") | p.ref("list"));
});
common_peg_parse_context ctx("1", false);
auto result = value_parser.parse(ctx);
t.assert_equal("result_is_success", true, result.success());
});
// Test simple list
t.test("simple_list", [](testing &t) {
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
p.rule("number", p.chars("0-9"));
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
return p.rule("value", p.ref("number") | p.ref("list"));
});
common_peg_parse_context ctx("[1]", false);
auto result = value_parser.parse(ctx);
t.assert_equal("result_is_success", true, result.success());
});
// Test nested list
t.test("nested_list", [](testing &t) {
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
p.rule("number", p.chars("0-9"));
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
return p.rule("value", p.ref("number") | p.ref("list"));
});
common_peg_parse_context ctx("[[2]]", false);
auto result = value_parser.parse(ctx);
t.assert_equal("result_is_success", true, result.success());
});
// Test deeply nested list
t.test("deeply_nested_list", [](testing &t) {
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
p.rule("number", p.chars("0-9"));
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
return p.rule("value", p.ref("number") | p.ref("list"));
});
common_peg_parse_context ctx("[[[3]]]", false);
auto result = value_parser.parse(ctx);
t.assert_equal("result_is_success", true, result.success());
});
// Test need_more_input match
t.test("need_more_input_match", [](testing &t) {
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
p.rule("number", p.chars("0-9"));
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
return p.rule("value", p.ref("number") | p.ref("list"));
});
common_peg_parse_context ctx("[[", true);
auto result = value_parser.parse(ctx);
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
});
// Test no match
t.test("no_match", [](testing &t) {
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
p.rule("number", p.chars("0-9"));
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
return p.rule("value", p.ref("number") | p.ref("list"));
});
common_peg_parse_context ctx("[a]", false);
auto result = value_parser.parse(ctx);
t.assert_equal("result_is_fail", true, result.fail());
});
});
}

View File

@@ -0,0 +1,250 @@
#include "tests.h"
#include "json-schema-to-grammar.h"
#include <regex>
static std::string trim_leading_space(const std::string & s) {
static const std::regex leading_ws_re = std::regex(R"((^|\n)\s+)");
return std::regex_replace(s, leading_ws_re, "$1");
}
static void assert_gbnf_equal(testing & t, const std::string & expected, const std::string & actual) {
t.assert_equal("gbnf are equal", trim_leading_space(expected), trim_leading_space(actual));
}
void test_gbnf_generation(testing &t) {
t.test("literal grammar generation", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("hello");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= "hello"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("char class grammar", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.chars("[a-z]", 1, 1);
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= [a-z]
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("sequence grammar", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("hello") + p.literal(" ") + p.literal("world");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= "hello" " " "world"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("choice grammar", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("cat") | p.literal("dog");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= "cat" | "dog"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("one_or_more grammar", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.one_or_more(p.literal("a"));
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= "a"+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("zero_or_more grammar", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.zero_or_more(p.literal("a"));
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= "a"*
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("optional grammar", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("hello") + p.optional(p.literal(" world"));
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= "hello" " world"?
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("until grammar", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.until("</tag>");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= ([^<] | "<" [^/] | "</" [^t] | "</t" [^a] | "</ta" [^g] | "</tag" [^>])*
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("complex expressions with parentheses", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.one_or_more(p.literal("a") | p.literal("b"));
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= ("a" | "b")+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("rule references", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
auto digit = p.rule("digit", p.chars("[0-9]", 1, 1));
return p.one_or_more(digit);
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
digit ::= [0-9]
root ::= digit+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("escaping in literals", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("hello\nworld\n!");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= "hello\nworld\n!"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("operator<< (whitespace insertion)", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.literal("hello") << p.literal("world");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= "hello" space "world"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("emit only reachable rules", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
p.rule("orphan", p.literal("orphan"));
return p.literal("hello") + p.rule("child", p.literal(" world"));
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
child ::= " world"
root ::= "hello" child
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("emit only trigger rules (and references)", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
auto rule1 = p.rule("rule-1", p.literal("a") + p.ref("rule-2"));
p.rule("rule-2", p.literal("b") + p.ref("rule-3"), true);
p.rule("rule-3", p.literal("c") + p.ref("rule-4"));
p.rule("rule-4", p.literal("d"), true);
return rule1;
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= rule-1
rule-1 ::= "a" rule-2
rule-2 ::= "b" rule-3
rule-3 ::= "c" rule-4
rule-4 ::= "d"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
auto gbnf_lazy = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder, true);
});
assert_gbnf_equal(t, R"""(
root ::= rule-2 | rule-4
rule-2 ::= "b" rule-3
rule-3 ::= "c" rule-4
rule-4 ::= "d"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf_lazy);
});
}

View File

@@ -0,0 +1,109 @@
#include "tests.h"
void test_json_parser(testing &t) {
// Test parsing a simple JSON object
t.test("simple JSON object parsing", [](testing &t) {
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
std::string input = R"({"name": "test", "value": 42, "flag": true})";
common_peg_parse_context ctx(input);
auto result = json.parse(ctx);
t.assert_equal("result_is_success", true, result.success());
t.assert_equal("result_end", input.size(), result.end);
});
// Test parsing a JSON array with mixed types
t.test("JSON array with mixed types", [](testing &t) {
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
std::string input = R"([1, "hello", true, null, 3.14])";
common_peg_parse_context ctx(input);
auto result = json.parse(ctx);
t.assert_equal("result_is_success", true, result.success());
t.assert_equal("result_end", input.size(), result.end);
});
// Test parsing nested JSON with objects and arrays
t.test("nested JSON with objects and arrays", [](testing &t) {
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
std::string input =
R"({"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], "count": 2, "metadata": {"version": "1.0", "tags": ["admin", "user"]}})";
common_peg_parse_context ctx(input);
auto result = json.parse(ctx);
t.assert_equal("result_is_success", true, result.success());
t.assert_equal("result_end", input.size(), result.end);
});
// Test need_more_input() parsing - incomplete object
t.test("need_more_input() parsing - incomplete object", [](testing &t) {
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
std::string input = R"({"name": "test", "value": )";
common_peg_parse_context ctx(input, true);
auto result = json.parse(ctx);
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
});
// Test need_more_input() parsing - incomplete array
t.test("need_more_input() parsing - incomplete array", [](testing &t) {
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
std::string input = R"([1, 2, 3, )";
common_peg_parse_context ctx(input, true);
auto result = json.parse(ctx);
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
});
// Test need_more_input() parsing - incomplete nested structure
t.test("need_more_input() parsing - incomplete nested structure", [](testing &t) {
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
std::string input = R"({"data": {"nested": )";
common_peg_parse_context ctx(input, true);
auto result = json.parse(ctx);
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
});
t.test("object member", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.json_member("name", "\"" + p.chars("[a-z]") + "\"");
});
t.test("success", [&](testing &t) {
std::string input = R"("name": "bob")";
common_peg_parse_context ctx(input, false);
auto result = parser.parse(ctx);
t.assert_true("success", result.success());
});
t.test("partial", [&](testing &t) {
std::string input = R"("name": "bo)";
common_peg_parse_context ctx(input, true);
auto result = parser.parse(ctx);
t.assert_true("need more input", result.need_more_input());
});
t.test("failed", [&](testing &t) {
std::string input = R"([])";
common_peg_parse_context ctx(input, false);
auto result = parser.parse(ctx);
t.assert_true("fail", result.fail());
});
});
}

View File

@@ -0,0 +1,28 @@
#include "tests.h"
void test_json_serialization(testing &t) {
auto original = build_peg_parser([](common_peg_parser_builder & p) {
return "<tool_call>" + p.json() + "</tool_call>";
});
auto json_serialized = original.to_json().dump();
t.test("compare before/after", [&](testing &t) {
auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized));
// Test complex JSON
std::string input = R"({"name": "test", "values": [1, 2, 3], "nested": {"a": true}})";
common_peg_parse_context ctx1(input);
common_peg_parse_context ctx2(input);
auto result1 = original.parse(ctx1);
auto result2 = deserialized.parse(ctx2);
t.assert_equal("both_succeed", result1.success(), result2.success());
t.assert_equal("same_end_pos", result1.end, result2.end);
});
t.bench("deserialize", [&]() {
auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized));
}, 100);
}

View File

@@ -0,0 +1,449 @@
#include "tests.h"
#include "peg-parser.h"
#include <string>
#include <sstream>
#include <iomanip>
#include <cctype>
static void assert_result_equal(testing & t, common_peg_parse_result_type expected, common_peg_parse_result_type actual) {
t.assert_equal(common_peg_parse_result_type_name(expected), common_peg_parse_result_type_name(actual));
}
static std::string hex_dump(const std::string& str) {
std::ostringstream oss;
for (unsigned char c : str) {
if (std::isprint(c)) {
oss << c;
} else {
oss << "\\x" << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(c);
}
}
return oss.str();
}
void test_unicode(testing &t) {
struct test_case {
std::string input;
std::string expected_text;
common_peg_parse_result_type expected_result;
};
t.test("any", [](testing &t) {
std::vector<test_case> test_cases {
// Valid UTF-8 sequences
{"Hello", "Hello", COMMON_PEG_PARSE_RESULT_SUCCESS},
{std::string("Caf\xC3\xA9"), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS},
{std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
{std::string("\xF0\x9F\x9A\x80"), std::string("\xF0\x9F\x9A\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS},
// Incomplete UTF-8 sequences (partial bytes at end)
{std::string("Caf\xC3"), "Caf", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
{std::string("\xE4\xBD"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
{std::string("\xF0\x9F\x9A"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
// Invalid/malformed UTF-8 sequences
{std::string("\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL},
{std::string("Hello\x80World"), "Hello", COMMON_PEG_PARSE_RESULT_FAIL},
{std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL},
};
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
return p.sequence({p.one_or_more(p.any()), p.end()});
});
for (size_t i = 0; i < test_cases.size(); i++) {
const auto & tc = test_cases[i];
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
t.test(test_name, [&](testing &t) {
common_peg_parse_context ctx(tc.input, true);
auto result = parser.parse(ctx);
// Assert result type matches
assert_result_equal(t, tc.expected_result, result.type);
// Assert matched text if success or need_more_input
if (result.success() || result.need_more_input()) {
std::string matched = tc.input.substr(result.start, result.end - result.start);
t.assert_equal(tc.expected_text, matched);
}
});
}
});
t.test("char classes", [](testing &t) {
t.test("unicode range U+4E00-U+9FFF (CJK)", [](testing &t) {
std::vector<test_case> test_cases {
// Within range - CJK Unified Ideographs
{std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00
{std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60
{std::string("\xE5\xA5\xBD"), std::string("\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+597D
{std::string("\xE9\xBF\xBF"), std::string("\xE9\xBF\xBF"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+9FFF
// Outside range - should fail
{"a", "", COMMON_PEG_PARSE_RESULT_FAIL}, // ASCII
{std::string("\xE4\xB7\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+4DFF (before range)
{std::string("\xEA\x80\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+A000 (after range)
// Incomplete sequences in range
{std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+4E00
{std::string("\xE5\xA5"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+597D
};
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
return p.sequence({p.chars(R"([\u4E00-\u9FFF])"), p.end()});
});
for (size_t i = 0; i < test_cases.size(); i++) {
const auto & tc = test_cases[i];
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
t.test(test_name, [&](testing &t) {
common_peg_parse_context ctx(tc.input, true);
auto result = parser.parse(ctx);
// Assert result type matches
assert_result_equal(t, tc.expected_result, result.type);
// Assert matched text if success or need_more_input
if (result.success() || result.need_more_input()) {
std::string matched = tc.input.substr(result.start, result.end - result.start);
t.assert_equal(tc.expected_text, matched);
}
});
}
});
t.test("unicode range U+1F600-U+1F64F (emoticons)", [](testing &t) {
std::vector<test_case> test_cases {
// Within range - Emoticons (all 4-byte UTF-8)
{std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600
{std::string("\xF0\x9F\x98\x81"), std::string("\xF0\x9F\x98\x81"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F601
{std::string("\xF0\x9F\x99\x8F"), std::string("\xF0\x9F\x99\x8F"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F64F
// Outside range
{std::string("\xF0\x9F\x97\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F5FF (before range)
{std::string("\xF0\x9F\x99\x90"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F650 (after range)
{std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680 (outside range)
// Incomplete sequences
{std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete emoji
{std::string("\xF0\x9F"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Very incomplete
};
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
return p.sequence({p.chars(R"([\U0001F600-\U0001F64F])"), p.end()});
});
for (size_t i = 0; i < test_cases.size(); i++) {
const auto & tc = test_cases[i];
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
t.test(test_name, [&](testing &t) {
common_peg_parse_context ctx(tc.input, true);
auto result = parser.parse(ctx);
// Assert result type matches
assert_result_equal(t, tc.expected_result, result.type);
// Assert matched text if success or need_more_input
if (result.success() || result.need_more_input()) {
std::string matched = tc.input.substr(result.start, result.end - result.start);
t.assert_equal(tc.expected_text, matched);
}
});
}
});
t.test("mixed unicode ranges", [](testing &t) {
std::vector<test_case> test_cases {
// Match CJK
{std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00
{std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60
// Match emoticons
{std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600
// Match ASCII digits
{"5", "5", COMMON_PEG_PARSE_RESULT_SUCCESS},
// Don't match outside any range
{"a", "", COMMON_PEG_PARSE_RESULT_FAIL},
{std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680
// Incomplete
{std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
{std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
};
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
return p.sequence({p.chars(R"([\u4E00-\u9FFF\U0001F600-\U0001F64F0-9])"), p.end()});
});
for (size_t i = 0; i < test_cases.size(); i++) {
const auto & tc = test_cases[i];
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
t.test(test_name, [&](testing &t) {
common_peg_parse_context ctx(tc.input, true);
auto result = parser.parse(ctx);
// Assert result type matches
assert_result_equal(t, tc.expected_result, result.type);
// Assert matched text if success or need_more_input
if (result.success() || result.need_more_input()) {
std::string matched = tc.input.substr(result.start, result.end - result.start);
t.assert_equal(tc.expected_text, matched);
}
});
}
});
});
t.test("until parser", [](testing &t) {
t.test("ASCII delimiter with Unicode content", [](testing &t) {
std::vector<test_case> test_cases {
// CJK characters before delimiter
{std::string("\xE4\xBD\xA0\xE5\xA5\xBD</tag>"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
// Emoji before delimiter
{std::string("\xF0\x9F\x98\x80</tag>"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS},
// Mixed content
{std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!</tag>"), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS},
};
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
return p.until("</tag>");
});
for (size_t i = 0; i < test_cases.size(); i++) {
const auto & tc = test_cases[i];
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
t.test(test_name, [&](testing &t) {
common_peg_parse_context ctx(tc.input, false);
auto result = parser.parse(ctx);
assert_result_equal(t, tc.expected_result, result.type);
if (result.success()) {
std::string matched = tc.input.substr(result.start, result.end - result.start);
t.assert_equal(tc.expected_text, matched);
}
});
}
});
t.test("incomplete UTF-8 at end", [](testing &t) {
std::vector<test_case> test_cases {
// Incomplete emoji at end, no delimiter
{std::string("content\xF0\x9F\x98"), std::string("content"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
// Incomplete CJK at end, no delimiter
{std::string("hello\xE4\xB8"), std::string("hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
// Complete content, no delimiter (should consume all valid UTF-8)
{std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
};
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
return p.until("</tag>");
});
for (size_t i = 0; i < test_cases.size(); i++) {
const auto & tc = test_cases[i];
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
t.test(test_name, [&](testing &t) {
common_peg_parse_context ctx(tc.input, true);
auto result = parser.parse(ctx);
assert_result_equal(t, tc.expected_result, result.type);
if (result.success() || result.need_more_input()) {
std::string matched = tc.input.substr(result.start, result.end - result.start);
t.assert_equal(tc.expected_text, matched);
}
});
}
});
t.test("malformed UTF-8", [](testing &t) {
std::vector<test_case> test_cases {
// Invalid UTF-8 bytes
{std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL},
// Continuation byte without lead byte
{std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL},
// Invalid continuation byte
{std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL},
};
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
return p.until("</tag>");
});
for (size_t i = 0; i < test_cases.size(); i++) {
const auto & tc = test_cases[i];
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
t.test(test_name, [&](testing &t) {
common_peg_parse_context ctx(tc.input, false);
auto result = parser.parse(ctx);
assert_result_equal(t, tc.expected_result, result.type);
});
}
});
});
t.test("json_string parser", [](testing &t) {
t.test("valid UTF-8 characters", [](testing &t) {
std::vector<test_case> test_cases {
// ASCII only
{"Hello World\"", "Hello World", COMMON_PEG_PARSE_RESULT_SUCCESS},
// 2-byte UTF-8 (accented characters)
{std::string("Caf\xC3\xA9\""), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS},
// 3-byte UTF-8 (CJK)
{std::string("\xE4\xBD\xA0\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
// 4-byte UTF-8 (emoji)
{std::string("\xF0\x9F\x98\x80\""), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS},
// Mixed content
{std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!\""), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS},
};
for (size_t i = 0; i < test_cases.size(); i++) {
const auto & tc = test_cases[i];
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
t.test(test_name, [&](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
return p.sequence({p.json_string_content(), p.literal("\"")});
});
common_peg_parse_context ctx(tc.input, false);
auto result = parser.parse(ctx);
assert_result_equal(t, tc.expected_result, result.type);
if (result.success()) {
std::string matched = tc.input.substr(result.start, result.end - result.start - 1); // -1 to exclude closing quote
t.assert_equal(tc.expected_text, matched);
}
});
}
});
t.test("incomplete UTF-8", [](testing &t) {
std::vector<test_case> test_cases {
// Incomplete 2-byte sequence
{std::string("Caf\xC3"), std::string("Caf"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
// Incomplete 3-byte sequence
{std::string("Hello\xE4\xB8"), std::string("Hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
// Incomplete 4-byte sequence
{std::string("Text\xF0\x9F\x98"), std::string("Text"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
// Incomplete at very start
{std::string("\xE4\xBD"), std::string(""), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
};
for (size_t i = 0; i < test_cases.size(); i++) {
const auto & tc = test_cases[i];
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
t.test(test_name, [&](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
return p.json_string_content();
});
common_peg_parse_context ctx(tc.input, true);
auto result = parser.parse(ctx);
assert_result_equal(t, tc.expected_result, result.type);
if (result.need_more_input()) {
std::string matched = tc.input.substr(result.start, result.end - result.start);
t.assert_equal(tc.expected_text, matched);
}
});
}
});
t.test("malformed UTF-8", [](testing &t) {
std::vector<test_case> test_cases {
// Invalid UTF-8 bytes
{std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL},
// Continuation byte without lead byte
{std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL},
// Invalid continuation byte
{std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL},
// Overlong encoding (security issue)
{std::string("\xC0\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL},
};
for (size_t i = 0; i < test_cases.size(); i++) {
const auto & tc = test_cases[i];
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
t.test(test_name, [&](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
return p.json_string_content();
});
common_peg_parse_context ctx(tc.input, false);
auto result = parser.parse(ctx);
assert_result_equal(t, tc.expected_result, result.type);
});
}
});
t.test("escape sequences with UTF-8", [](testing &t) {
std::vector<test_case> test_cases {
// Unicode escape sequence
{"Hello\\u0041\"", "Hello\\u0041", COMMON_PEG_PARSE_RESULT_SUCCESS},
// Mix of UTF-8 and escape sequences
{std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
// Escaped quote in UTF-8 string
{std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
};
for (size_t i = 0; i < test_cases.size(); i++) {
const auto & tc = test_cases[i];
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
t.test(test_name, [&](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
return p.sequence({p.json_string_content(), p.literal("\"")});
});
common_peg_parse_context ctx(tc.input, false);
auto result = parser.parse(ctx);
assert_result_equal(t, tc.expected_result, result.type);
if (result.success()) {
std::string matched = tc.input.substr(result.start, result.end - result.start - 1); // -1 to exclude closing quote
t.assert_equal(tc.expected_text, matched);
}
});
}
});
});
}

243
tests/peg-parser/testing.h Normal file
View File

@@ -0,0 +1,243 @@
#pragma once
#include "common.h"
#include <chrono>
#include <exception>
#include <iostream>
#include <string>
#include <regex>
#include <vector>
struct testing {
std::ostream &out;
std::vector<std::string> stack;
std::regex filter;
bool filter_tests = false;
bool throw_exception = false;
bool verbose = false;
int tests = 0;
int assertions = 0;
int failures = 0;
int unnamed = 0;
int exceptions = 0;
static constexpr std::size_t status_column = 80;
explicit testing(std::ostream &os = std::cout) : out(os) {}
std::string indent() const {
if (stack.empty()) {
return "";
}
return std::string((stack.size() - 1) * 2, ' ');
}
std::string full_name() const {
return string_join(stack, ".");
}
void log(const std::string & msg) {
if (verbose) {
out << indent() << " " << msg << "\n";
}
}
void set_filter(const std::string & re) {
filter = std::regex(re);
filter_tests = true;
}
bool should_run() const {
if (filter_tests) {
if (!std::regex_match(full_name(), filter)) {
return false;
}
}
return true;
}
template <typename F>
void run_with_exceptions(F &&f, const char *ctx) {
try {
f();
} catch (const std::exception &e) {
++failures;
++exceptions;
out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): " << e.what() << "\n";
if (throw_exception) {
throw;
}
} catch (...) {
++failures;
++exceptions;
out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): unknown\n";
if (throw_exception) {
throw;
}
}
}
void print_result(const std::string &label, int new_failures, int new_assertions, const std::string &extra = "") const {
std::string line = indent() + label;
std::string details;
if (new_assertions > 0) {
if (new_failures == 0) {
details = std::to_string(new_assertions) + " assertion(s)";
} else {
details = std::to_string(new_failures) + " of " +
std::to_string(new_assertions) + " assertion(s) failed";
}
}
if (!extra.empty()) {
if (!details.empty()) {
details += ", ";
}
details += extra;
}
if (!details.empty()) {
line += " (" + details + ")";
}
std::string status = (new_failures == 0) ? "[PASS]" : "[FAIL]";
if (line.size() + 1 < status_column) {
line.append(status_column - line.size(), ' ');
} else {
line.push_back(' ');
}
out << line << status << "\n";
}
template <typename F>
void test(const std::string &name, F f) {
stack.push_back(name);
if (!should_run()) {
stack.pop_back();
return;
}
++tests;
out << indent() << name << "\n";
int before_failures = failures;
int before_assertions = assertions;
run_with_exceptions([&] { f(*this); }, "test");
int new_failures = failures - before_failures;
int new_assertions = assertions - before_assertions;
print_result(name, new_failures, new_assertions);
stack.pop_back();
}
template <typename F>
void test(F f) {
test("test #" + std::to_string(++unnamed), f);
}
template <typename F>
void bench(const std::string &name, F f, int iterations = 100) {
stack.push_back(name);
if (!should_run()) {
stack.pop_back();
return;
}
++tests;
out << indent() << "[bench] " << name << "\n";
int before_failures = failures;
int before_assertions = assertions;
using clock = std::chrono::high_resolution_clock;
std::chrono::microseconds duration(0);
run_with_exceptions([&] {
for (auto i = 0; i < iterations; i++) {
auto start = clock::now();
f();
duration += std::chrono::duration_cast<std::chrono::microseconds>(clock::now() - start);
}
}, "bench");
auto avg_elapsed = duration.count() / iterations;
auto avg_elapsed_s = std::chrono::duration_cast<std::chrono::duration<double>>(duration).count() / iterations;
auto rate = (avg_elapsed_s > 0.0) ? (1.0 / avg_elapsed_s) : 0.0;
int new_failures = failures - before_failures;
int new_assertions = assertions - before_assertions;
std::string extra =
"n=" + std::to_string(iterations) +
" avg=" + std::to_string(avg_elapsed) + "us" +
" rate=" + std::to_string(int(rate)) + "/s";
print_result("[bench] " + name, new_failures, new_assertions, extra);
stack.pop_back();
}
template <typename F>
void bench(F f, int iterations = 100) {
bench("bench #" + std::to_string(++unnamed), f, iterations);
}
// Assertions
bool assert_true(bool cond) {
return assert_true("", cond);
}
bool assert_true(const std::string &msg, bool cond) {
++assertions;
if (!cond) {
++failures;
out << indent() << "ASSERT TRUE FAILED";
if (!msg.empty()) {
out << " : " << msg;
}
out << "\n";
return false;
}
return true;
}
template <typename A, typename B>
bool assert_equal(const A &expected, const B &actual) {
return assert_equal("", expected, actual);
}
template <typename A, typename B>
bool assert_equal(const std::string &msg, const A &expected, const B &actual) {
++assertions;
if (!(actual == expected)) {
++failures;
out << indent() << "ASSERT EQUAL FAILED";
if (!msg.empty()) {
out << " : " << msg;
}
out << "\n";
out << indent() << " expected: " << expected << "\n";
out << indent() << " actual : " << actual << "\n";
return false;
}
return true;
}
// Print summary and return an exit code
int summary() const {
out << "\n";
out << "tests : " << tests << "\n";
out << "assertions : " << assertions << "\n";
out << "failures : " << failures << "\n";
out << "exceptions : " << exceptions << "\n";
return failures == 0 ? 0 : 1;
}
};

24
tests/peg-parser/tests.h Normal file
View File

@@ -0,0 +1,24 @@
#pragma once
// Common includes for all test files
#include <nlohmann/json.hpp>
#include <string>
#include <vector>
#include "testing.h"
#include "peg-parser.h"
#include "chat-peg-parser.h"
#include "simple-tokenize.h"
struct bench_tool_call {
std::string id;
std::string name;
nlohmann::ordered_json args;
};
// Test function declarations
void test_basic(testing &t);
void test_json_parser(testing &t);
void test_gbnf_generation(testing &t);
void test_unicode(testing &t);
void test_json_serialization(testing &t);

View File

@@ -0,0 +1,768 @@
#include <string>
#include <iostream>
#include <numeric>
#include "chat-parser.h"
#include "chat-peg-parser.h"
#include "chat.h"
#include "common.h"
#include "json-schema-to-grammar.h"
#include "peg-parser.h"
#include "peg-parser/testing.h"
#include "peg-parser/simple-tokenize.h"
#include "nlohmann/json.hpp"
using json = nlohmann::ordered_json;
static json create_tools();
static void test_example_native(testing & t);
static void test_example_qwen3_coder(testing & t);
static void test_command7_parser_compare(testing & t);
int main(int argc, char *argv[]) {
testing t(std::cout);
if (argc >= 2) {
t.set_filter(argv[1]);
}
const char * verbose = getenv("LLAMA_TEST_VERBOSE");
if (verbose) {
t.verbose = std::string(verbose) == "1";
}
t.test("native", test_example_native);
t.test("qwen3 coder", test_example_qwen3_coder);
t.test("comparison", test_command7_parser_compare);
return t.summary();
}
static json create_tools() {
json tools = json::array();
json tool_weather = {
{"type", "function"},
{"function", {
{"name", "get_current_weather"},
{"description", "Get the current weather in a given location"},
{"parameters", {
{"type", "object"},
{"properties", {
{"location", {
{"type", "string"},
{"description", "The city and state, e.g. San Francisco, CA"}
}},
{"unit", {
{"type", "string"},
{"enum", {"celsius", "fahrenheit"}},
{"description", "The temperature unit to use. Infer this from the users location."}
}}
}},
{"required", {"location", "unit"}},
}},
}}
};
tools.push_back(tool_weather);
json tool_forecast = {
{"type", "function"},
{"function", {
{"name", "get_forecast"},
{"description", "Get the weather forecast for a given location"},
{"parameters", {
{"type", "object"},
{"properties", {
{"location", {
{"type", "string"},
{"description", "The city and state, e.g. San Francisco, CA"}
}},
{"unit", {
{"type", "string"},
{"enum", {"celsius", "fahrenheit"}},
{"description", "The temperature unit to use. Infer this from the users location."}
}},
{"days", {
{"type", "integer"},
{"description", "Number of days to forecast (1-10)"},
{"minimum", 1},
{"maximum", 10}
}}
}},
{"required", {"location", "unit"}},
}},
}}
};
tools.push_back(tool_forecast);
json tool_search = {
{"type", "function"},
{"function", {
{"name", "search_knowledge_base"},
{"description", "Search the internal technical documentation knowledge base."},
{"parameters", {
{"type", "object"},
{"properties", {
{"query", {
{"type", "string"},
{"description", "The search query string."}
}},
{"max_results", {
{"type", "integer"},
{"description", "The maximum number of results to return."},
{"default", 5}
}},
{"category", {
{"type", "string"},
{"enum", {"api", "troubleshooting", "billing", "general"}},
{"description", "Filter search by specific category."}
}}
}},
{"required", {"query", "category"}},
{"additionalProperties", false}
}},
{"strict", true}
}}
};
tools.push_back(tool_search);
return tools;
}
struct tool_argument {
std::string name;
std::string type;
bool is_required;
json schema;
};
struct tool_definition {
std::string name;
std::vector<tool_argument> arguments;
json schema;
};
// Test fictitious model output that emits arguments as JSON.
static void test_example_native(testing & t) {
struct test_case {
// Parameters
std::string name;
json tools;
common_chat_tool_choice tool_choice;
common_reasoning_format reasoning_format;
json json_schema;
bool parallel_tool_calls;
bool thinking_forced_open;
std::string input;
// Expect
std::string expect_reasoning;
std::string expect_content;
std::vector<common_chat_tool_call> expect_tool_calls;
};
auto build_parser = [](const test_case & tc) {
return build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
auto reasoning_in_content = (tc.reasoning_format == COMMON_REASONING_FORMAT_NONE);
auto reasoning = p.eps();
if (tc.thinking_forced_open) {
// If thinking is forced open, expect a closing tag
reasoning = p.reasoning(p.until("</think>")) + "</think>" + p.space();
} else {
// Otherwise, optionally accept thinking wrapped in tags
reasoning = p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>" + p.space());
}
// tool calling parser
if (tc.tools.is_array() && !tc.tools.empty()) {
auto tools = p.choice();
for (const auto & tool : tc.tools) {
const auto & function = tool.at("function");
std::string name = function.at("name");
const auto & schema = function.at("parameters");
auto tool_name = p.json_member("name", "\"" + p.tool_name(p.literal(name)) + "\"");
auto tool_args = p.json_member("arguments", p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));
tools |= p.rule("tool-" + name, p.tool_open(p.literal("{")) << tool_name << "," << tool_args << "}");
};
auto parallel_calls = p.eps();
if (tc.parallel_tool_calls) {
parallel_calls = p.zero_or_more("," << tools);
}
auto tool_call = p.trigger_rule("tool-call",
p.sequence({
p.literal("<tool_call>["),
tools,
parallel_calls,
p.literal("]</tool_call>")
})
);
return p.sequence({
(reasoning_in_content ? p.eps() : reasoning),
p.content(p.until("<tool_call>")),
p.optional(p.space() + tool_call),
p.space(),
p.end()
});
}
// response_format parser
if (tc.json_schema.is_object() && !tc.json_schema.empty()) {
return p.sequence({
(reasoning_in_content ? p.eps() : reasoning),
p.content(p.schema(p.json(), "response-output", tc.json_schema)),
p.space(),
p.end()
});
}
// Content-only parser
return p.sequence({
(reasoning_in_content ? p.eps() : reasoning),
p.content(p.rest()),
p.end()
});
});
};
std::vector<test_case> test_cases = std::vector<test_case>{
{
/* .name = */ "content with thinking_forced_open = false",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ false,
/* .input = */ (
"<think>The user said hello, I must say hello back</think>\nHello"
),
/* .expect_reasoning = */ "The user said hello, I must say hello back",
/* .expect_content = */ "Hello",
/* .expect_tool_calls = */ {},
},
{
/* .name = */ "content with thinking_forced_open = false and no reasoning",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ false,
/* .input = */ (
"Hello"
),
/* .expect_reasoning = */ "",
/* .expect_content = */ "Hello",
/* .expect_tool_calls = */ {},
},
{
/* .name = */ "content with thinking_forced_open = false and reasoning_format = none",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true,
/* .input = */ (
"<think>The user said hello, I must say hello back</think>\nHello"
),
/* .expect_reasoning = */ "",
/* .expect_content = */ "<think>The user said hello, I must say hello back</think>\nHello",
/* .expect_tool_calls = */ {},
},
{
/* .name = */ "content with thinking_forced_open = true",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true,
/* .input = */ (
"The user said hello, I must say hello back</think>\nHello"
),
/* .expect_reasoning = */ "The user said hello, I must say hello back",
/* .expect_content = */ "Hello",
/* .expect_tool_calls = */ {},
},
{
/* .name = */ "content with thinking_forced_open = true and reasoning_format = none",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true,
/* .input = */ (
"The user said hello, I must say hello back</think>\nHello"
),
/* .expect_reasoning = */ "",
/* .expect_content = */ "The user said hello, I must say hello back</think>\nHello",
/* .expect_tool_calls = */ {},
},
{
/* .name = */ "tools with tool_choice = auto and no parallel_tool_calls",
/* .tools = */ create_tools(),
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true,
/* .input = */ (
"I must get the weather in New York</think>\n"
"<tool_call>["
R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})"
"]</tool_call>"
),
/* .expect_reasoning = */ "I must get the weather in New York",
/* .expect_content = */ "",
/* .expect_tool_calls = */ {{
/* .name = */ "get_current_weather",
/* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})",
/* .id = */ "",
}},
},
{
/* .name = */ "tools with tool_choice = auto and parallel_tool_calls",
/* .tools = */ create_tools(),
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ true,
/* .thinking_forced_open = */ true,
/* .input = */ (
"I must get the weather in New York and San Francisco and a 3 day forecast of each.</think>\nLet me search that for you."
"<tool_call>["
R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})"
", "
R"({"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}})"
", "
R"({"name": "get_forecast", "arguments": {"location": "New York City, NY", "unit": "fahrenheit", "days": 3}})"
", "
R"({"name": "get_forecast", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}})"
"]</tool_call>"
),
/* .expect_reasoning = */ "I must get the weather in New York and San Francisco and a 3 day forecast of each.",
/* .expect_content = */ "Let me search that for you.",
/* .expect_tool_calls = */ {{
/* .name = */ "get_current_weather",
/* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})",
/* .id = */ "",
}, {
/* .name = */ "get_current_weather",
/* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit"})",
/* .id = */ "",
}, {
/* .name = */ "get_forecast",
/* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit", "days": 3})",
/* .id = */ "",
}, {
/* .name = */ "get_forecast",
/* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3})",
/* .id = */ "",
}},
},
{
/* .name = */ "response_format with thinking_forced_open = true",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {
{"type", "object"},
{"properties", {
{"invoice_number", {{"type", "string"}}},
{"amount", {{"type", "number"}}},
{"due_date", {{"type", "string"}}}
}},
{"required", {"invoice_number", "amount", "due_date"}}
},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true,
/* .input = */ (
"I must produce the invoice in the requested format</think>\n"
R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})"
),
/* .expect_reasoning = */ "I must produce the invoice in the requested format",
/* .expect_content = */ R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})",
/* .expect_tool_calls = */ {},
},
};
for (const auto & tc : test_cases) {
t.test(tc.name, [&](testing & t) {
auto parser = build_parser(tc);
auto lazy = !tc.tools.empty() && tc.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
auto grammar = build_grammar([&](const common_grammar_builder & builder) {
for (auto const & def : tc.tools) {
auto function = def.at("function");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
};
parser.build_grammar(builder, lazy);
});
t.log("Grammar:");
for (auto const & line : string_split(grammar, "\n")) {
t.log(line);
}
common_peg_parse_context ctx(tc.input, false);
auto result = parser.parse(ctx);
t.assert_true("success", result.success());
common_chat_msg msg;
auto mapper = common_chat_peg_native_mapper(msg);
mapper.from_ast(ctx.ast, result);
t.assert_equal("content equal", tc.expect_content, msg.content);
t.assert_equal("reasoning equal", tc.expect_reasoning, msg.reasoning_content);
t.assert_equal("number of tool calls", tc.expect_tool_calls.size(), msg.tool_calls.size());
for (auto i = 0u; i < std::min(tc.expect_tool_calls.size(), msg.tool_calls.size()); i++) {
t.assert_equal("tool name", tc.expect_tool_calls[i].name, msg.tool_calls[i].name);
t.assert_equal("tool args", tc.expect_tool_calls[i].arguments, msg.tool_calls[i].arguments);
}
});
}
}
static void test_example_qwen3_coder(testing & t) {
auto tools = create_tools();
auto parser = build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) {
auto content = p.rule("content", p.content(p.until("<tool_call>")));
std::vector<common_peg_parser> tool_parsers;
for (auto const & def : tools) {
auto function = def.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
auto properties = parameters.at("properties");
std::set<std::string> required_properties;
if (function.contains("required")) {
function.at("required").get_to(required_properties);
}
std::vector<common_peg_parser> arg_parsers;
for (const auto & [param_name, param_schema] : properties.items()) {
bool is_required = required_properties.find(param_name) != required_properties.end();
auto type = param_schema.value("type", "object");
auto arg = p.tool_arg(p.sequence({
p.tool_arg_open("<parameter=" + p.tool_arg_name(p.literal(param_name)) + ">"),
(type == "string" ?
p.tool_arg_string_value(
p.schema(
p.until_one_of({
"</parameter>\n<parameter=",
"</parameter>\n</function>"
}),
"tool-" + name + "-arg-" + param_name + "-schema",
param_schema,
true
)
) : p.tool_arg_json_value(
p.schema(
p.json(),
"tool-" + name + "-arg-" + param_name + "-schema",
param_schema
)
)
),
p.tool_arg_close(
"</parameter>\n" +
p.peek(p.literal("<parameter=") | p.literal("</function>"))
)
}));
arg_parsers.push_back(is_required ?
p.rule("tool-" + name + "-arg-" + param_name, arg) :
p.optional(p.rule("tool-" + name + "-arg-" + param_name, arg)));
}
tool_parsers.push_back(p.rule("tool-" + name,
p.tool_open("<function=" + p.tool_name(p.literal(name)) + ">")
<< p.sequence(arg_parsers)
<< p.tool_close(p.literal("</function>"))
));
};
auto tool_call = p.trigger_rule("tool-call",
"<tool_call>"
<< p.choice(tool_parsers)
<< "</tool_call>"
);
return content + p.zero_or_more(p.space() + tool_call) + p.end();
});
auto grammar = build_grammar([&](const common_grammar_builder & builder) {
for (auto const & def : tools) {
auto function = def.at("function");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
};
parser.build_grammar(builder);
});
t.log("Grammar:");
for (auto const & line : string_split(grammar, "\n")) {
t.log(line);
}
t.test("incremental parsing", [&](testing &t) {
std::string input =
"Let me search the knowledge base for cat pictures."
"<tool_call>\n"
"<function=search_knowledge_base>\n"
"<parameter=query>cat pictures</parameter>\n"
"<parameter=category>general</parameter>\n"
"</function>\n"
"</tool_call>";
std::vector<std::string> tokens = simple_tokenize(input);
common_chat_msg prev;
for (auto it = tokens.begin(); it != tokens.end(); it++) {
std::string in = std::accumulate(tokens.begin(), it + 1, std::string());
common_peg_parse_context ctx(in, it + 1 < tokens.end());
auto result = parser.parse(ctx);
if (!t.assert_equal("not fail", false, result.fail())) {
t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end));
}
common_chat_msg msg;
auto mapper = common_chat_peg_constructed_mapper(msg);
mapper.from_ast(ctx.ast, result);
//t.log("Input: " + input);
t.log("===========================================");
t.log("Iteration " + std::to_string(in.size()));
t.log("Reasoning: " + msg.reasoning_content);
t.log("Content : " + msg.content);
for (const auto & tc : msg.tool_calls) {
t.log("Tool name: " + tc.name);
t.log("Tool args: " + tc.arguments);
}
try {
// This shouldn't emit any runtime errors
auto diffs = common_chat_msg_diff::compute_diffs(prev, msg);
} catch(const std::exception & e) {
t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end));
t.assert_true(std::string("failed with ") + e.what(), false);
}
prev = msg;
}
});
}
void test_command7_parser_compare(testing & t) {
auto parser = build_chat_peg_native_parser([](common_chat_peg_native_builder & p) {
auto thinking = p.reasoning_block(
"<|START_THINKING|>" << p.reasoning(p.until("<|END_THINKING|>")) << "<|END_THINKING|>");
auto response = "<|START_RESPONSE|>" << p.content(p.until("<|END_RESPONSE|>")) << "<|END_RESPONSE|>";
auto tool_call_id = p.atomic("\"tool_call_id\"" << (":" << ("\"" + p.tool_id(p.json_string_content()) + "\"")));
auto tool_call_name = p.atomic("\"tool_name\"" << (":" << ("\"" + p.tool_name(p.json_string_content()) + "\"")));
auto tool_call_args = "\"parameters\"" << (":" << p.tool_args(p.json()));
auto tool_call_fields = p.rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args);
auto tool_call = p.rule("tool-call", p.tool(
p.tool_open(p.literal("{"))
<< tool_call_fields
<< p.zero_or_more( p.literal(",") << tool_call_fields)
<< p.tool_close(p.literal("}"))
));
auto tool_calls = p.rule("tool-calls",
"<|START_ACTION|>"
<< ("[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]")
<< "<|END_ACTION|>");
return p.optional(thinking) << (tool_calls | response) + p.end();
});
auto test_current = [&](const common_peg_arena & p, const std::string & input, bool is_partial, bool print_results) {
common_peg_parse_context ctx(input, is_partial);
auto result = p.parse(ctx);
common_chat_msg msg;
auto mapper = common_chat_peg_native_mapper(msg);
mapper.from_ast(ctx.ast, result);
if (print_results) {
std::cout << "== Parsed (new) ==\n";
std::cout << "=== Reasoning ===\n";
std::cout << msg.reasoning_content << "\n";
std::cout << "\n\n=== Content ===\n";
std::cout << msg.content << "\n";
std::cout << "\n\n=== Tool Calls ===\n";
for (const auto & tc : msg.tool_calls) {
std::cout << "id: " << tc.id << "\n";
std::cout << "name: " << tc.name << "\n";
std::cout << "args: " << tc.arguments << "\n";
}
}
};
auto test_legacy = [&](const std::string & input, bool need_more_input, bool print_results) {
// Original common_chat_combinator_parser taken from chat.cpp
common_chat_msg_parser builder(
input,
/* .is_partial = */ need_more_input,
{
/* .format = */ COMMON_CHAT_FORMAT_GENERIC,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
}
);
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
static const common_regex start_action_regex("<\\|START_ACTION\\|>");
static const common_regex end_action_regex("<\\|END_ACTION\\|>");
static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
if (auto res = builder.try_find_regex(start_action_regex)) {
// If we didn't extract thoughts, prelude includes them.
auto tool_calls = builder.consume_json_with_dumped_args({ { "parameters" } });
for (const auto & tool_call : tool_calls.value) {
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
if (tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(end_action_regex);
} else if (auto res = builder.try_find_regex(start_response_regex)) {
if (!builder.try_find_regex(end_response_regex)) {
builder.add_content(builder.consume_rest());
throw common_chat_msg_partial_exception(end_response_regex.str());
}
} else {
builder.add_content(builder.consume_rest());
}
if (print_results) {
std::cout << "== Parsed (legacy) ==\n";
std::cout << "=== Reasoning ===\n";
std::cout << builder.result().reasoning_content << "\n";
std::cout << "\n\n=== Content ===\n";
std::cout << builder.result().content << "\n";
std::cout << "\n\n=== Tool Calls ===\n";
for (const auto & tc : builder.result().tool_calls) {
std::cout << "id: " << tc.id << "\n";
std::cout << "name: " << tc.name << "\n";
std::cout << "args: " << tc.arguments << "\n";
}
}
};
std::string reasoning = "To plan an effective trip to Japan that includes both historical sites and modern attractions within a "
"budget of $4000 for a two-week stay, we need to:\n\n"
"1. Identify key historical sites and modern attractions in Japan.\n"
"2. Find affordable accommodation options that provide a balance between comfort and cost.\n"
"3. Determine the best modes of transportation for getting around Japan.\n"
"4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without "
"overspending.\n"
"5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees "
"to attractions.";
std::vector<std::tuple<std::string, std::string, nlohmann::json>> tool_calls = {{
"call_0",
"plan_trip",
nlohmann::json::parse(R"({
"destination": "Japan",
"duration": 14,
"budget": 4000,
"interests": ["historical sites", "modern attractions"],
"accommodation_preferences": "affordable",
"transportation_preferences": "efficient",
"meal_preferences": "local cuisine"
})")
}};
std::vector<std::string> tokens;
// Build tokens
if (!reasoning.empty()) {
auto tokenized = simple_tokenize(reasoning);
tokens.emplace_back("<|START_THINKING|>");
tokens.insert(tokens.end(), tokenized.begin(), tokenized.end());
tokens.emplace_back("<|END_THINKING|>");
}
if (!tool_calls.empty()) {
tokens.emplace_back("<|START_ACTION|>");
auto json = nlohmann::json::array();
for (const auto & tc : tool_calls) {
auto tc_json = nlohmann::json::object();
tc_json["tool_call_id"] = std::get<0>(tc);
tc_json["tool_name"] = std::get<1>(tc);
tc_json["parameters"] = std::get<2>(tc);
json.push_back(tc_json);
}
auto tokenized = simple_tokenize(json.dump(-1, ' ', true));
tokens.insert(tokens.end(), tokenized.begin(), tokenized.end());
tokens.emplace_back("<|END_ACTION|>");
}
std::string input = std::accumulate(tokens.begin(), tokens.end(), std::string());
// Run tests
t.test("legacy_parse", [&](testing & /* t */) {
test_legacy(input, false, false);
});
t.test("current_parse", [&](testing & /* t */) {
test_current(parser, input, false, false);
});
// Run benchmarks
t.bench("legacy_parse_benchmark complete", [&]() {
test_legacy(input, false, false);
});
t.bench("legacy_parse_benchmark incremental", [&]() {
std::string in;
for (auto i = 0u; i < tokens.size(); i++) {
in += tokens[i];
try {
test_legacy(in, i + 1 < tokens.size(), false);
} catch (common_chat_msg_partial_exception & /* e */) {
// Do nothing, this is expected
}
}
}, 20);
t.bench("current_parse_benchmark complete", [&]() {
test_current(parser, input, false, false);
}, 100);
t.bench("current_parse_benchmark incremental", [&]() {
std::string in;
for (auto i = 0u; i < tokens.size(); i++) {
in += tokens[i];
test_current(parser, in, i + 1 < tokens.size(), false);
}
}, 20);
}

25
tests/test-peg-parser.cpp Normal file
View File

@@ -0,0 +1,25 @@
#include <cstdlib>
#include <string>
#include <iostream>
#include "peg-parser/tests.h"
int main(int argc, char *argv[]) {
testing t(std::cout);
if (argc >= 2) {
t.set_filter(argv[1]);
}
const char * verbose = getenv("LLAMA_TEST_VERBOSE");
if (verbose) {
t.verbose = std::string(verbose) == "1";
}
t.test("basic", test_basic);
t.test("unicode", test_unicode);
t.test("json", test_json_parser);
t.test("gbnf", test_gbnf_generation);
t.test("serialization", test_json_serialization);
return t.summary();
}

View File

@@ -2,11 +2,6 @@ set(TARGET llama-server)
include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR})
if (MINGW)
# fix: https://github.com/ggml-org/llama.cpp/actions/runs/9651004652/job/26617901362?pr=8006
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
endif()
if (NOT LLAMA_HTTPLIB)
message(FATAL_ERROR "LLAMA_HTTPLIB is OFF, cannot build llama-server. Hint: to skip building server, set -DLLAMA_BUILD_SERVER=OFF")
endif()

Binary file not shown.

View File

@@ -791,7 +791,7 @@ static void handle_media(
SRV_INF("downloading image from '%s'\n", url.c_str());
auto res = common_remote_get_content(url, params);
if (200 <= res.first && res.first < 300) {
SRV_INF("downloaded %ld bytes\n", res.second.size());
SRV_INF("downloaded %zu bytes\n", res.second.size());
raw_buffer data;
data.insert(data.end(), res.second.begin(), res.second.end());
out_files.push_back(data);
@@ -1045,6 +1045,9 @@ json oaicompat_chat_params_parse(
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}
if (!chat_params.parser.empty()) {
llama_params["chat_parser"] = chat_params.parser;
}
// Handle "n" field
int n_choices = json_value(body, "n", 1);

View File

@@ -101,8 +101,6 @@ struct server_slot {
std::string generated_text;
llama_tokens generated_tokens;
common_chat_msg chat_msg;
std::vector<completion_token_output> generated_token_probs;
bool has_next_token = true;
@@ -153,9 +151,6 @@ struct server_slot {
llama_token sampled;
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::vector<std::string> generated_tool_call_ids;
// stats
size_t n_sent_text = 0; // number of sent text character
@@ -183,13 +178,10 @@ struct server_slot {
stop = STOP_TYPE_NONE;
stopping_word = "";
n_sent_text = 0;
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
generated_tokens.clear();
generated_token_probs.clear();
chat_msg = {};
json_schema = json();
generated_tool_call_ids.clear();
// clear speculative decoding stats
n_draft_total = 0;
@@ -302,23 +294,6 @@ struct server_slot {
return timings;
}
const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
GGML_ASSERT(task);
auto previous_msg = chat_msg;
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
auto new_msg = common_chat_parse(
generated_text,
/* is_partial= */ stop != STOP_TYPE_EOS,
task->params.oaicompat_chat_syntax);
if (!new_msg.empty()) {
new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
chat_msg = new_msg;
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
}
return chat_msg;
}
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
GGML_ASSERT(task);
@@ -1284,8 +1259,6 @@ struct server_context_impl {
} else {
res->content = tkn.text_to_send;
res->tokens = { tkn.tok };
slot.update_chat_msg(res->oaicompat_msg_diffs);
}
res->n_decoded = slot.n_decoded;
@@ -1317,8 +1290,14 @@ struct server_context_impl {
res->id_slot = slot.id;
res->index = slot.task->index;
res->content = slot.generated_text;
res->tokens = std::move(slot.generated_tokens);
// in stream mode, content and tokens are already in last partial chunk
if (slot.task->params.stream) {
res->content = "";
res->tokens = llama_tokens{};
} else {
res->content = std::move(slot.generated_text);
res->tokens = std::move(slot.generated_tokens);
}
res->timings = slot.get_timings();
res->prompt = slot.task->tokens.detokenize(ctx, true);
res->response_fields = std::move(slot.task->params.response_fields);
@@ -1338,7 +1317,6 @@ struct server_context_impl {
res->res_type = slot.task->params.res_type;
res->oaicompat_model = slot.task->params.oaicompat_model;
res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output
if (slot.task->params.sampling.n_probs > 0) {
@@ -2596,6 +2574,9 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
try {
std::vector<server_task> tasks;
// tracking generation state and partial tool calls
std::vector<task_result_state> states;
const auto & prompt = data.at("prompt");
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2611,6 +2592,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
tasks.reserve(inputs.size());
states.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
@@ -2628,10 +2610,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
task.params.res_type = res_type;
task.params.oaicompat_cmpl_id = completion_id;
task.params.oaicompat_model = ctx_server.model_name;
states.push_back(task.params.oaicompat_chat_syntax);
tasks.push_back(std::move(task));
}
rd.set_states(std::move(states));
rd.post_tasks(std::move(tasks));
} catch (const std::exception & e) {
res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -2657,7 +2641,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
// if single request, return single object instead of array
res->ok(arr.size() == 1 ? arr[0] : arr);
}
} else {
// in streaming mode, the first error must be treated as non-stream response
// this is to match the OAI API behavior
@@ -2676,76 +2659,92 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
}
// next responses are streamed
// to be sent immediately
json first_result_json = first_result->to_json();
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
res->data = format_anthropic_sse(first_result->to_json());
res->data = format_anthropic_sse(first_result_json);
} else {
res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
res->data = format_oai_sse(first_result_json);
}
res->status = 200;
res->content_type = "text/event-stream";
res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
if (should_stop()) {
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
return false; // should_stop condition met
}
if (!res_this->data.empty()) {
// flush the first chunk
output = std::move(res_this->data);
res_this->data.clear();
return true;
}
server_response_reader & rd = res_this->rd;
// check if there is more data
if (!rd.has_next()) {
static auto format_error = [](task_response_type res_type, const json & res_json) {
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
// Anthropic doesn't send [DONE], message_stop was already sent
output = "";
} else if (res_type != TASK_RESPONSE_TYPE_NONE) {
output = "data: [DONE]\n\n";
} else {
output = "";
}
SRV_DBG("%s", "all results received, terminating stream\n");
return false; // no more data, terminate
}
// receive subsequent results
auto result = rd.next(should_stop);
if (result == nullptr) {
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
return false; // should_stop condition met
}
// send the results
json res_json = result->to_json();
if (result->is_error()) {
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
output = format_anthropic_sse({
return format_anthropic_sse({
{"event", "error"},
{"data", res_json},
});
} else {
output = format_oai_sse(json {{ "error", res_json }});
return format_oai_sse(json {{ "error", res_json }});
}
SRV_DBG("%s", "error received during streaming, terminating stream\n");
return false; // terminate on error
} else {
GGML_ASSERT(
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
);
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
output = format_anthropic_sse(res_json);
} else {
output = format_oai_sse(res_json);
}
}
};
// has next data, continue
return true;
try {
if (should_stop()) {
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
return false; // should_stop condition met
}
if (!res_this->data.empty()) {
// flush the first chunk
output = std::move(res_this->data);
res_this->data.clear();
return true;
}
server_response_reader & rd = res_this->rd;
// check if there is more data
if (!rd.has_next()) {
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
// Anthropic doesn't send [DONE], message_stop was already sent
output = "";
} else if (res_type != TASK_RESPONSE_TYPE_NONE) {
output = "data: [DONE]\n\n";
} else {
output = "";
}
SRV_DBG("%s", "all results received, terminating stream\n");
return false; // no more data, terminate
}
// receive subsequent results
auto result = rd.next(should_stop);
if (result == nullptr) {
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
return false; // should_stop condition met
}
// send the results
if (result->is_error()) {
json res_json = result->to_json();
output = format_error(res_type, res_json);
SRV_DBG("%s", "error received during streaming, terminating stream\n");
return false; // terminate on error
} else {
GGML_ASSERT(
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
);
json res_json = result->to_json();
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
output = format_anthropic_sse(res_json);
} else {
output = format_oai_sse(res_json);
}
}
// has next data, continue
return true;
} catch (const std::exception & e) {
json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
output = format_error(res_type, error_json);
// terminate on exception
return false;
}
};
}

View File

@@ -900,6 +900,7 @@ static bool should_strip_proxy_header(const std::string & header_name) {
// Headers that get duplicated when router forwards child responses
if (header_name == "server" ||
header_name == "transfer-encoding" ||
header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
header_name == "keep-alive") {
return true;
}

View File

@@ -271,6 +271,10 @@ void server_response::terminate() {
// server_response_reader
//
void server_response_reader::set_states(std::vector<task_result_state> && states) {
this->states = std::move(states);
}
void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
id_tasks = server_task::get_list_id(tasks);
queue_results.add_waiting_tasks(tasks);
@@ -298,6 +302,12 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
SRV_DBG("%s", "received error result, stopping further processing\n");
return result;
}
if (!states.empty()) {
// update the generation state if needed
size_t idx = result->get_index();
GGML_ASSERT(idx < states.size());
result->update(states[idx]);
}
if (result->is_stop()) {
received_count++;
}

View File

@@ -7,6 +7,8 @@
#include <mutex>
#include <unordered_set>
// struct for managing server tasks
// in most cases, use server_response_reader to post new tasks and retrieve results
struct server_queue {
private:
int id = 0;
@@ -67,6 +69,8 @@ private:
void cleanup_pending_task(int id_target);
};
// struct for managing server responses
// in most cases, use server_response_reader to retrieve results
struct server_response {
private:
bool running = true;
@@ -120,6 +124,10 @@ struct server_response_reader {
bool cancelled = false;
int polling_interval_seconds;
// tracking generation state and partial tool calls
// only used by streaming completions
std::vector<task_result_state> states;
// should_stop function will be called each polling_interval_seconds
server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
: queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
@@ -127,6 +135,7 @@ struct server_response_reader {
stop();
}
void set_states(std::vector<task_result_state> && states);
void post_tasks(std::vector<server_task> && tasks);
bool has_next() const;

View File

@@ -297,6 +297,9 @@ task_params server_task::params_from_json_cmpl(
params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
if (data.contains("chat_parser")) {
params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get<std::string>());
}
}
{
@@ -562,6 +565,7 @@ std::vector<unsigned char> completion_token_output::str_to_bytes(const std::stri
// server_task_result_cmpl_final
//
json server_task_result_cmpl_final::to_json() {
GGML_ASSERT(is_updated && "update() must be called before to_json()");
switch (res_type) {
case TASK_RESPONSE_TYPE_NONE:
return to_json_non_oaicompat();
@@ -579,8 +583,8 @@ json server_task_result_cmpl_final::to_json() {
json server_task_result_cmpl_final::to_json_non_oaicompat() {
json res = json {
{"index", index},
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
{"tokens", stream ? llama_tokens {} : tokens},
{"content", content},
{"tokens", tokens},
{"id_slot", id_slot},
{"stop", true},
{"model", oaicompat_model},
@@ -616,7 +620,7 @@ json server_task_result_cmpl_final::to_json_oaicompat() {
json res = json {
{"choices", json::array({
json{
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
{"text", content},
{"index", index},
{"logprobs", logprobs},
{"finish_reason", finish_reason},
@@ -697,6 +701,25 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {
return res;
}
common_chat_msg task_result_state::update_chat_msg(
const std::string & text_added,
bool is_partial,
std::vector<common_chat_msg_diff> & diffs) {
generated_text += text_added;
auto msg_prv_copy = chat_msg;
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
auto new_msg = common_chat_parse(
generated_text,
is_partial,
oaicompat_chat_syntax);
if (!new_msg.empty()) {
new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
chat_msg = new_msg;
diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? msg_prv_copy : new_msg);
}
return chat_msg;
}
json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
std::time_t t = std::time(0);
std::string finish_reason = "length";
@@ -953,6 +976,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
// server_task_result_cmpl_partial
//
json server_task_result_cmpl_partial::to_json() {
GGML_ASSERT(is_updated && "update() must be called before to_json()");
switch (res_type) {
case TASK_RESPONSE_TYPE_NONE:
return to_json_non_oaicompat();

View File

@@ -161,6 +161,25 @@ struct result_prompt_progress {
json to_json() const;
};
// struct for tracking the state of a task (e.g., for streaming)
struct task_result_state {
// tracking diffs for partial tool calls
std::vector<common_chat_msg_diff> diffs;
common_chat_syntax oaicompat_chat_syntax;
common_chat_msg chat_msg;
std::string generated_text; // append new chunks of generated text here
std::vector<std::string> generated_tool_call_ids;
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}
// parse partial tool calls and update the internal state
common_chat_msg update_chat_msg(
const std::string & text_added,
bool is_partial,
std::vector<common_chat_msg_diff> & diffs);
};
struct server_task_result {
int id = -1;
int id_slot = -1;
@@ -175,6 +194,9 @@ struct server_task_result {
virtual int get_index() {
return -1;
}
virtual void update(task_result_state &) {
// only used by server_task_result_cmpl_*
}
virtual json to_json() = 0;
virtual ~server_task_result() = default;
};
@@ -233,9 +255,10 @@ struct server_task_result_cmpl_final : server_task_result {
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg;
common_chat_msg oaicompat_msg; // to be populated by update()
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
bool is_updated = false;
virtual int get_index() override {
return index;
@@ -247,6 +270,11 @@ struct server_task_result_cmpl_final : server_task_result {
virtual json to_json() override;
virtual void update(task_result_state & state) override {
is_updated = true;
oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
}
json to_json_non_oaicompat();
json to_json_oaicompat();
@@ -280,7 +308,8 @@ struct server_task_result_cmpl_partial : server_task_result {
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
bool is_updated = false;
virtual int get_index() override {
return index;
@@ -292,6 +321,11 @@ struct server_task_result_cmpl_partial : server_task_result {
virtual json to_json() override;
virtual void update(task_result_state & state) override {
is_updated = true;
state.update_chat_msg(content, true, oaicompat_msg_diffs);
}
json to_json_non_oaicompat();
json to_json_oaicompat();

View File

@@ -65,6 +65,7 @@ def test_server_slots():
def test_load_split_model():
global server
server.offline = False
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
server.model_alias = "tinyllama-split"

View File

@@ -17,7 +17,6 @@ def create_server():
]
)
def test_router_chat_completion_stream(model: str, success: bool):
# TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server
global server
server.start()
content = ""
@@ -48,3 +47,148 @@ def test_router_chat_completion_stream(model: str, success: bool):
else:
assert ex is not None
assert content == ""
def _get_model_status(model_id: str) -> str:
res = server.make_request("GET", "/models")
assert res.status_code == 200
for item in res.body.get("data", []):
if item.get("id") == model_id or item.get("model") == model_id:
return item["status"]["value"]
raise AssertionError(f"Model {model_id} not found in /models response")
def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str:
deadline = time.time() + timeout
last_status = None
while time.time() < deadline:
last_status = _get_model_status(model_id)
if last_status in desired:
return last_status
time.sleep(1)
raise AssertionError(
f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}"
)
def _load_model_and_wait(
model_id: str, timeout: int = 60, headers: dict | None = None
) -> None:
load_res = server.make_request(
"POST", "/models/load", data={"model": model_id}, headers=headers
)
assert load_res.status_code == 200
assert isinstance(load_res.body, dict)
assert load_res.body.get("success") is True
_wait_for_model_status(model_id, {"loaded"}, timeout=timeout)
def test_router_unload_model():
global server
server.start()
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
_load_model_and_wait(model_id)
unload_res = server.make_request("POST", "/models/unload", data={"model": model_id})
assert unload_res.status_code == 200
assert unload_res.body.get("success") is True
_wait_for_model_status(model_id, {"unloaded"})
def test_router_models_max_evicts_lru():
global server
server.models_max = 2
server.start()
candidate_models = [
"ggml-org/tinygemma3-GGUF:Q8_0",
"ggml-org/test-model-stories260K",
"ggml-org/test-model-stories260K-infill",
]
# Load only the first 2 models to fill the cache
first, second, third = candidate_models[:3]
_load_model_and_wait(first, timeout=120)
_load_model_and_wait(second, timeout=120)
# Verify both models are loaded
assert _get_model_status(first) == "loaded"
assert _get_model_status(second) == "loaded"
# Load the third model - this should trigger LRU eviction of the first model
_load_model_and_wait(third, timeout=120)
# Verify eviction: third is loaded, first was evicted
assert _get_model_status(third) == "loaded"
assert _get_model_status(first) == "unloaded"
def test_router_no_models_autoload():
global server
server.no_models_autoload = True
server.start()
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
res = server.make_request(
"POST",
"/v1/chat/completions",
data={
"model": model_id,
"messages": [{"role": "user", "content": "hello"}],
"max_tokens": 4,
},
)
assert res.status_code == 400
assert "error" in res.body
_load_model_and_wait(model_id)
success_res = server.make_request(
"POST",
"/v1/chat/completions",
data={
"model": model_id,
"messages": [{"role": "user", "content": "hello"}],
"max_tokens": 4,
},
)
assert success_res.status_code == 200
assert "error" not in success_res.body
def test_router_api_key_required():
global server
server.api_key = "sk-router-secret"
server.start()
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
auth_headers = {"Authorization": f"Bearer {server.api_key}"}
res = server.make_request(
"POST",
"/v1/chat/completions",
data={
"model": model_id,
"messages": [{"role": "user", "content": "hello"}],
"max_tokens": 4,
},
)
assert res.status_code == 401
assert res.body.get("error", {}).get("type") == "authentication_error"
_load_model_and_wait(model_id, headers=auth_headers)
authed = server.make_request(
"POST",
"/v1/chat/completions",
headers=auth_headers,
data={
"model": model_id,
"messages": [{"role": "user", "content": "hello"}],
"max_tokens": 4,
},
)
assert authed.status_code == 200
assert "error" not in authed.body

View File

@@ -7,6 +7,7 @@ import subprocess
import os
import re
import json
from json import JSONDecodeError
import sys
import requests
import time
@@ -83,6 +84,9 @@ class ServerProcess:
pooling: str | None = None
draft: int | None = None
api_key: str | None = None
models_dir: str | None = None
models_max: int | None = None
no_models_autoload: bool | None = None
lora_files: List[str] | None = None
enable_ctx_shift: int | None = False
draft_min: int | None = None
@@ -143,6 +147,10 @@ class ServerProcess:
server_args.extend(["--hf-repo", self.model_hf_repo])
if self.model_hf_file:
server_args.extend(["--hf-file", self.model_hf_file])
if self.models_dir:
server_args.extend(["--models-dir", self.models_dir])
if self.models_max is not None:
server_args.extend(["--models-max", self.models_max])
if self.n_batch:
server_args.extend(["--batch-size", self.n_batch])
if self.n_ubatch:
@@ -204,6 +212,8 @@ class ServerProcess:
server_args.extend(["--draft-min", self.draft_min])
if self.no_webui:
server_args.append("--no-webui")
if self.no_models_autoload:
server_args.append("--no-models-autoload")
if self.jinja:
server_args.append("--jinja")
else:
@@ -295,7 +305,13 @@ class ServerProcess:
result = ServerResponse()
result.headers = dict(response.headers)
result.status_code = response.status_code
result.body = response.json() if parse_body else None
if parse_body:
try:
result.body = response.json()
except JSONDecodeError:
result.body = response.text
else:
result.body = None
print("Response from server", json.dumps(result.body, indent=2))
return result
@@ -434,8 +450,9 @@ class ServerPreset:
@staticmethod
def tinyllama2() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/stories260K.gguf"
server.offline = True # will be downloaded by load_all()
server.model_hf_repo = "ggml-org/test-model-stories260K"
server.model_hf_file = None
server.model_alias = "tinyllama-2"
server.n_ctx = 512
server.n_batch = 32
@@ -479,8 +496,8 @@ class ServerPreset:
def tinyllama_infill() -> ServerProcess:
server = ServerProcess()
server.offline = True # will be downloaded by load_all()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
server.model_hf_file = None
server.model_alias = "tinyllama-infill"
server.n_ctx = 2048
server.n_batch = 1024
@@ -537,6 +554,7 @@ class ServerPreset:
@staticmethod
def router() -> ServerProcess:
server = ServerProcess()
server.offline = True # will be downloaded by load_all()
# router server has no models
server.model_file = None
server.model_alias = None

View File

@@ -15,7 +15,7 @@ sequenceDiagram
Stores->>DB: load conversations
Stores->>API: GET /props
API-->>Stores: {role: "router"}
Stores->>API: GET /models
Stores->>API: GET /v1/models
API-->>Stores: models[] with status (loaded/available)
loop each loaded model
Stores->>API: GET /props?model=X
@@ -28,7 +28,7 @@ sequenceDiagram
alt model not loaded
Stores->>API: POST /models/load
loop poll status
Stores->>API: GET /models
Stores->>API: GET /v1/models
API-->>Stores: check if loaded
end
Stores->>API: GET /props?model=X

View File

@@ -56,7 +56,7 @@ sequenceDiagram
UI->>modelsStore: fetchRouterModels()
activate modelsStore
modelsStore->>ModelsSvc: listRouter()
ModelsSvc->>API: GET /models
ModelsSvc->>API: GET /v1/models
API-->>ModelsSvc: ApiRouterModelsListResponse
Note right of API: {data: [{id, status, path, in_cache}]}
modelsStore->>modelsStore: routerModels = $state(data)
@@ -132,7 +132,7 @@ sequenceDiagram
loop poll every 500ms (max 60 attempts)
modelsStore->>modelsStore: fetchRouterModels()
modelsStore->>ModelsSvc: listRouter()
ModelsSvc->>API: GET /models
ModelsSvc->>API: GET /v1/models
API-->>ModelsSvc: models[]
modelsStore->>modelsStore: getModelStatus(modelId)
alt status === LOADED
@@ -165,7 +165,7 @@ sequenceDiagram
modelsStore->>modelsStore: pollForModelStatus(modelId, UNLOADED)
loop poll until unloaded
modelsStore->>ModelsSvc: listRouter()
ModelsSvc->>API: GET /models
ModelsSvc->>API: GET /v1/models
end
modelsStore->>modelsStore: modelLoadingStates.set(modelId, false)

View File

@@ -64,7 +64,10 @@
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
let isRecording = $state(false);
let message = $state('');
let pasteLongTextToFileLength = $derived(Number(currentConfig.pasteLongTextToFileLen) || 2500);
let pasteLongTextToFileLength = $derived.by(() => {
const n = Number(currentConfig.pasteLongTextToFileLen);
return Number.isNaN(n) ? 2500 : n;
});
let previousIsLoading = $state(isLoading);
let recordingSupported = $state(false);
let textareaRef: ChatFormTextarea | undefined = $state(undefined);

View File

@@ -1,6 +1,5 @@
<script lang="ts">
import { ChatMessage } from '$lib/components/app';
import { DatabaseService } from '$lib/services/database';
import { chatStore } from '$lib/stores/chat.svelte';
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
import { getMessageSiblings } from '$lib/utils';
@@ -19,7 +18,7 @@
const conversation = activeConversation();
if (conversation) {
DatabaseService.getConversationMessages(conversation.id).then((messages) => {
conversationsStore.getConversationMessages(conversation.id).then((messages) => {
allConversationMessages = messages;
});
} else {

View File

@@ -7,7 +7,6 @@
import { Textarea } from '$lib/components/ui/textarea';
import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config';
import { settingsStore } from '$lib/stores/settings.svelte';
import { ParameterSyncService } from '$lib/services/parameter-sync';
import { ChatSettingsParameterSourceIndicator } from '$lib/components/app';
import type { Component } from 'svelte';
@@ -22,7 +21,7 @@
// Helper function to get parameter source info for syncable parameters
function getParameterSourceInfo(key: string) {
if (!ParameterSyncService.canSyncParameter(key)) {
if (!settingsStore.canSyncParameter(key)) {
return null;
}

View File

@@ -2,9 +2,8 @@
import { Download, Upload } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { DialogConversationSelection } from '$lib/components/app';
import { DatabaseService } from '$lib/services/database';
import { createMessageCountMap } from '$lib/utils';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
let exportedConversations = $state<DatabaseConversation[]>([]);
let importedConversations = $state<DatabaseConversation[]>([]);
@@ -21,15 +20,15 @@
async function handleExportClick() {
try {
const allConversations = await DatabaseService.getAllConversations();
const allConversations = conversations();
if (allConversations.length === 0) {
alert('No conversations to export');
return;
}
const conversationsWithMessages = await Promise.all(
allConversations.map(async (conv) => {
const messages = await DatabaseService.getConversationMessages(conv.id);
allConversations.map(async (conv: DatabaseConversation) => {
const messages = await conversationsStore.getConversationMessages(conv.id);
return { conv, messages };
})
);
@@ -47,7 +46,7 @@
try {
const allData: ExportedConversations = await Promise.all(
selectedConversations.map(async (conv) => {
const messages = await DatabaseService.getConversationMessages(conv.id);
const messages = await conversationsStore.getConversationMessages(conv.id);
return { conv: $state.snapshot(conv), messages: $state.snapshot(messages) };
})
);
@@ -135,9 +134,7 @@
.snapshot(fullImportData)
.filter((item) => selectedIds.has(item.conv.id));
await DatabaseService.importConversations(selectedData);
await conversationsStore.loadConversations();
await conversationsStore.importConversationsData(selectedData);
importedConversations = selectedConversations;
showImportSummary = true;

View File

@@ -3,8 +3,7 @@
import * as Table from '$lib/components/ui/table';
import { BadgeModality, CopyToClipboardIcon } from '$lib/components/app';
import { serverStore } from '$lib/stores/server.svelte';
import { modelsStore } from '$lib/stores/models.svelte';
import { ChatService } from '$lib/services/chat';
import { modelsStore, modelOptions, modelsLoading } from '$lib/stores/models.svelte';
import { formatFileSize, formatParameters, formatNumber } from '$lib/utils';
interface Props {
@@ -16,38 +15,24 @@
let serverProps = $derived(serverStore.props);
let modelName = $derived(modelsStore.singleModelName);
let models = $derived(modelOptions());
let isLoadingModels = $derived(modelsLoading());
// Get the first model for single-model mode display
let firstModel = $derived(models[0] ?? null);
// Get modalities from modelStore using the model ID from the first model
// For now it supports only for single-model mode, will be extended with further improvements for multi-model functioanlities
let modalities = $derived.by(() => {
if (!modelsData?.data?.[0]?.id) return [];
return modelsStore.getModelModalitiesArray(modelsData.data[0].id);
if (!firstModel?.id) return [];
return modelsStore.getModelModalitiesArray(firstModel.id);
});
let modelsData = $state<ApiModelListResponse | null>(null);
let isLoadingModels = $state(false);
// Fetch models data when dialog opens
// Ensure models are fetched when dialog opens
$effect(() => {
if (open && !modelsData) {
loadModelsData();
if (open && models.length === 0) {
modelsStore.fetch();
}
});
async function loadModelsData() {
isLoadingModels = true;
try {
modelsData = await ChatService.getModels();
} catch (error) {
console.error('Failed to load models data:', error);
// Set empty data to prevent infinite loading
modelsData = { object: 'list', data: [] };
} finally {
isLoadingModels = false;
}
}
</script>
<Dialog.Root bind:open {onOpenChange}>
@@ -70,8 +55,8 @@
<div class="flex items-center justify-center py-8">
<div class="text-sm text-muted-foreground">Loading model information...</div>
</div>
{:else if modelsData && modelsData.data.length > 0}
{@const modelMeta = modelsData.data[0].meta}
{:else if firstModel}
{@const modelMeta = firstModel.meta}
{#if serverProps}
<Table.Root>

View File

@@ -126,8 +126,13 @@ export const TEXT_FILE_TYPES = {
mimeTypes: [MimeTypeText.JAVA]
},
[FileTypeText.CPP]: {
extensions: [FileExtensionText.CPP, FileExtensionText.C, FileExtensionText.H],
mimeTypes: [MimeTypeText.CPP_SRC, MimeTypeText.C_SRC, MimeTypeText.C_HDR]
extensions: [
FileExtensionText.CPP,
FileExtensionText.C,
FileExtensionText.H,
FileExtensionText.HPP
],
mimeTypes: [MimeTypeText.CPP_SRC, MimeTypeText.CPP_HDR, MimeTypeText.C_SRC, MimeTypeText.C_HDR]
},
[FileTypeText.PHP]: {
extensions: [FileExtensionText.PHP],
@@ -183,10 +188,30 @@ export const TEXT_FILE_TYPES = {
},
[FileTypeText.LATEX]: {
extensions: [FileExtensionText.TEX],
mimeTypes: [MimeTypeText.LATEX]
mimeTypes: [MimeTypeText.LATEX, MimeTypeText.TEX, MimeTypeText.TEX_APP]
},
[FileTypeText.BIBTEX]: {
extensions: [FileExtensionText.BIB],
mimeTypes: [MimeTypeText.BIBTEX]
},
[FileTypeText.CUDA]: {
extensions: [FileExtensionText.CU, FileExtensionText.CUH],
mimeTypes: [MimeTypeText.CUDA]
},
[FileTypeText.VULKAN]: {
extensions: [FileExtensionText.COMP],
mimeTypes: [MimeTypeText.PLAIN]
},
[FileTypeText.HASKELL]: {
extensions: [FileExtensionText.HS],
mimeTypes: [MimeTypeText.HASKELL]
},
[FileTypeText.CSHARP]: {
extensions: [FileExtensionText.CS],
mimeTypes: [MimeTypeText.CSHARP]
},
[FileTypeText.PROPERTIES]: {
extensions: [FileExtensionText.PROPERTIES],
mimeTypes: [MimeTypeText.PROPERTIES]
}
} as const;

View File

@@ -62,7 +62,12 @@ export enum FileTypeText {
VUE = 'vue',
SVELTE = 'svelte',
LATEX = 'latex',
BIBTEX = 'bibtex'
BIBTEX = 'bibtex',
CUDA = 'cuda',
VULKAN = 'vulkan',
HASKELL = 'haskell',
CSHARP = 'csharp',
PROPERTIES = 'properties'
}
// File extension enums
@@ -121,7 +126,14 @@ export enum FileExtensionText {
VUE = '.vue',
SVELTE = '.svelte',
TEX = '.tex',
BIB = '.bib'
BIB = '.bib',
CU = '.cu',
CUH = '.cuh',
COMP = '.comp',
HPP = '.hpp',
HS = '.hs',
PROPERTIES = '.properties',
CS = '.cs'
}
// MIME type enums
@@ -165,7 +177,10 @@ export enum MimeTypeText {
CSV = 'text/csv',
PYTHON = 'text/x-python',
JAVA = 'text/x-java-source',
CPP_HDR = 'text/x-c++hdr',
CPP_SRC = 'text/x-c++src',
CSHARP = 'text/x-csharp',
HASKELL = 'text/x-haskell',
C_SRC = 'text/x-csrc',
C_HDR = 'text/x-chdr',
PHP = 'text/x-php',
@@ -182,6 +197,10 @@ export enum MimeTypeText {
DART = 'text/x-dart',
VUE = 'text/x-vue',
SVELTE = 'text/x-svelte',
LATEX = 'text/x-tex',
BIBTEX = 'text/x-bibtex'
TEX = 'text/x-tex',
TEX_APP = 'application/x-tex',
LATEX = 'application/x-latex',
BIBTEX = 'text/x-bibtex',
CUDA = 'text/x-cuda',
PROPERTIES = 'text/properties'
}

View File

@@ -677,48 +677,6 @@ export class ChatService {
// Utilities
// ─────────────────────────────────────────────────────────────────────────────
/**
* Get server properties - static method for API compatibility (to be refactored)
*/
static async getServerProps(): Promise<ApiLlamaCppServerProps> {
try {
const response = await fetch(`./props`, {
headers: getJsonHeaders()
});
if (!response.ok) {
throw new Error(`Failed to fetch server props: ${response.status}`);
}
const data = await response.json();
return data;
} catch (error) {
console.error('Error fetching server props:', error);
throw error;
}
}
/**
* Get model information from /models endpoint (to be refactored)
*/
static async getModels(): Promise<ApiModelListResponse> {
try {
const response = await fetch(`./models`, {
headers: getJsonHeaders()
});
if (!response.ok) {
throw new Error(`Failed to fetch models: ${response.status} ${response.statusText}`);
}
const data = await response.json();
return data;
} catch (error) {
console.error('Error fetching models:', error);
throw error;
}
}
/**
* Injects a system message at the beginning of the conversation if provided.
* Checks for existing system messages to avoid duplication.

View File

@@ -7,7 +7,7 @@ import { getJsonHeaders } from '$lib/utils';
*
* This service handles communication with model-related endpoints:
* - `/v1/models` - OpenAI-compatible model list (MODEL + ROUTER mode)
* - `/models` - Router-specific model management (ROUTER mode only)
* - `/models/load`, `/models/unload` - Router-specific model management (ROUTER mode only)
*
* **Responsibilities:**
* - List available models
@@ -43,7 +43,7 @@ export class ModelsService {
* Returns models with load status, paths, and other metadata
*/
static async listRouter(): Promise<ApiRouterModelsListResponse> {
const response = await fetch(`${base}/models`, {
const response = await fetch(`${base}/v1/models`, {
headers: getJsonHeaders()
});

View File

@@ -519,6 +519,19 @@ class ConversationsStore {
return await DatabaseService.getConversationMessages(convId);
}
/**
* Imports conversations from provided data (without file picker)
* @param data - Array of conversation data with messages
* @returns Import result with counts
*/
async importConversationsData(
data: ExportedConversations
): Promise<{ imported: number; skipped: number }> {
const result = await DatabaseService.importConversations(data);
await this.loadConversations();
return result;
}
/**
* Adds a message to the active messages array
* Used by chatStore when creating new messages

View File

@@ -370,6 +370,10 @@ class SettingsStore {
return { ...this.config };
}
canSyncParameter(key: string): boolean {
return ParameterSyncService.canSyncParameter(key);
}
/**
* Get parameter information including source for a specific parameter
*/

View File

@@ -77,6 +77,13 @@ export function getFileTypeCategory(mimeType: string): FileTypeCategory | null {
case MimeTypeText.SVELTE:
case MimeTypeText.LATEX:
case MimeTypeText.BIBTEX:
case MimeTypeText.CUDA:
case MimeTypeText.CPP_HDR:
case MimeTypeText.CSHARP:
case MimeTypeText.HASKELL:
case MimeTypeText.PROPERTIES:
case MimeTypeText.TEX:
case MimeTypeText.TEX_APP:
return FileTypeCategory.TEXT;
default:
@@ -144,6 +151,12 @@ export function getFileTypeCategoryByExtension(filename: string): FileTypeCatego
case FileExtensionText.SVELTE:
case FileExtensionText.TEX:
case FileExtensionText.BIB:
case FileExtensionText.COMP:
case FileExtensionText.CU:
case FileExtensionText.CUH:
case FileExtensionText.HPP:
case FileExtensionText.HS:
case FileExtensionText.PROPERTIES:
return FileTypeCategory.TEXT;
default:

View File

@@ -144,4 +144,7 @@ if (CPPHTTPLIB_OPENSSL_SUPPORT)
find_library(SECURITY_FRAMEWORK Security REQUIRED)
target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
endif()
if (WIN32 AND NOT MSVC)
target_link_libraries(${TARGET} PUBLIC crypt32)
endif()
endif()